Improve DB

This commit is contained in:
Daniel 2024-07-23 21:35:02 -04:00
parent 709cfceec4
commit f7c3a66a43
3 changed files with 30 additions and 8 deletions

View file

@ -24,7 +24,7 @@ var (
ErrInvalidArguments = errors.New("invalid arguments") ErrInvalidArguments = errors.New("invalid arguments")
) )
func addUser(cfg *config.Config, username, email string, isAdmin bool) error { func AddUser(ctx context.Context, cfg *config.Config, username, email string, isAdmin bool) error {
if username == "" || email == "" { if username == "" || email == "" {
return ErrInvalidArguments return ErrInvalidArguments
} }
@ -54,7 +54,7 @@ func addUser(cfg *config.Config, username, email string, isAdmin bool) error {
hashpw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost) hashpw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
_, err = database.New(db).CreateUser(context.Background(), database.CreateUserParams{ _, err = db.CreateUser(context.Background(), database.CreateUserParams{
Username: username, Username: username,
Password: string(hashpw), Password: string(hashpw),
Email: email, Email: email,
@ -64,7 +64,7 @@ func addUser(cfg *config.Config, username, email string, isAdmin bool) error {
return err return err
} }
func passwd(cfg *config.Config, username string) error { func Passwd(ctx context.Context, cfg *config.Config, username string) error {
if username == "" { if username == "" {
return ErrInvalidArguments return ErrInvalidArguments
} }
@ -74,6 +74,15 @@ func passwd(cfg *config.Config, username string) error {
return err return err
} }
_, err = db.GetUserByUsername(ctx, username)
if err != nil && database.IsNoRows(err) {
return fmt.Errorf("no such user %s", username)
}
if err != nil {
return err
}
pw, err := readPassword(PromptPassword) pw, err := readPassword(PromptPassword)
if err != nil { if err != nil {
return err return err
@ -130,7 +139,7 @@ func addUserCommand(cfg *config.Config) *cobra.Command {
return err return err
} }
return addUser(cfg, username, email, isAdmin) return AddUser(context.Background(), cfg, username, email, isAdmin)
}, },
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
} }
@ -146,7 +155,7 @@ func passwdCommand(cfg *config.Config) *cobra.Command {
Short: "changes password for a user", Short: "changes password for a user",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
username := args[0] username := args[0]
return passwd(cfg, username) return Passwd(context.Background(), cfg, username)
}, },
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
} }

View file

@ -13,9 +13,13 @@ import (
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
) )
type DB struct {
*pgxpool.Pool
*Queries
}
type Conn = *pgxpool.Pool type Conn = *pgxpool.Pool
func NewClient(conf config.DB) (Conn, error) { func NewClient(conf config.DB) (*DB, error) {
dir, err := iofs.New(sqlembed.Migrations, "postgres/migrations") dir, err := iofs.New(sqlembed.Migrations, "postgres/migrations")
if err != nil { if err != nil {
return nil, err return nil, err
@ -33,11 +37,16 @@ func NewClient(conf config.DB) (Conn, error) {
m.Close() m.Close()
db, err := pgxpool.New(context.Background(), conf.Connect) pool, err := pgxpool.New(context.Background(), conf.Connect)
if err != nil { if err != nil {
return nil, err return nil, err
} }
db := &DB{
Pool: pool,
Queries: New(pool),
}
return db, nil return db, nil
} }
@ -57,3 +66,7 @@ func FromCtx(ctx context.Context) Conn {
func CtxWithDB(ctx context.Context, conn Conn) context.Context { func CtxWithDB(ctx context.Context, conn Conn) context.Context {
return context.WithValue(ctx, DBCTXKeyValue, conn) return context.WithValue(ctx, DBCTXKeyValue, conn)
} }
func IsNoRows(err error) bool {
return strings.Contains(err.Error(), "no rows in result set")
}

View file

@ -12,7 +12,7 @@ import (
type Server struct { type Server struct {
conf *config.Config conf *config.Config
db database.Conn db *database.DB
r *chi.Mux r *chi.Mux
jwt *jwtauth.JWTAuth jwt *jwtauth.JWTAuth
} }