Improve DB

This commit is contained in:
Daniel 2024-07-23 21:35:02 -04:00
parent 709cfceec4
commit f7c3a66a43
3 changed files with 30 additions and 8 deletions

View file

@ -24,7 +24,7 @@ var (
ErrInvalidArguments = errors.New("invalid arguments")
)
func addUser(cfg *config.Config, username, email string, isAdmin bool) error {
func AddUser(ctx context.Context, cfg *config.Config, username, email string, isAdmin bool) error {
if username == "" || email == "" {
return ErrInvalidArguments
}
@ -54,7 +54,7 @@ func addUser(cfg *config.Config, username, email string, isAdmin bool) error {
hashpw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
_, err = database.New(db).CreateUser(context.Background(), database.CreateUserParams{
_, err = db.CreateUser(context.Background(), database.CreateUserParams{
Username: username,
Password: string(hashpw),
Email: email,
@ -64,7 +64,7 @@ func addUser(cfg *config.Config, username, email string, isAdmin bool) error {
return err
}
func passwd(cfg *config.Config, username string) error {
func Passwd(ctx context.Context, cfg *config.Config, username string) error {
if username == "" {
return ErrInvalidArguments
}
@ -74,6 +74,15 @@ func passwd(cfg *config.Config, username string) error {
return err
}
_, err = db.GetUserByUsername(ctx, username)
if err != nil && database.IsNoRows(err) {
return fmt.Errorf("no such user %s", username)
}
if err != nil {
return err
}
pw, err := readPassword(PromptPassword)
if err != nil {
return err
@ -130,7 +139,7 @@ func addUserCommand(cfg *config.Config) *cobra.Command {
return err
}
return addUser(cfg, username, email, isAdmin)
return AddUser(context.Background(), cfg, username, email, isAdmin)
},
Args: cobra.ExactArgs(1),
}
@ -146,7 +155,7 @@ func passwdCommand(cfg *config.Config) *cobra.Command {
Short: "changes password for a user",
RunE: func(cmd *cobra.Command, args []string) error {
username := args[0]
return passwd(cfg, username)
return Passwd(context.Background(), cfg, username)
},
Args: cobra.ExactArgs(1),
}

View file

@ -13,9 +13,13 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
)
type DB struct {
*pgxpool.Pool
*Queries
}
type Conn = *pgxpool.Pool
func NewClient(conf config.DB) (Conn, error) {
func NewClient(conf config.DB) (*DB, error) {
dir, err := iofs.New(sqlembed.Migrations, "postgres/migrations")
if err != nil {
return nil, err
@ -33,11 +37,16 @@ func NewClient(conf config.DB) (Conn, error) {
m.Close()
db, err := pgxpool.New(context.Background(), conf.Connect)
pool, err := pgxpool.New(context.Background(), conf.Connect)
if err != nil {
return nil, err
}
db := &DB{
Pool: pool,
Queries: New(pool),
}
return db, nil
}
@ -57,3 +66,7 @@ func FromCtx(ctx context.Context) Conn {
func CtxWithDB(ctx context.Context, conn Conn) context.Context {
return context.WithValue(ctx, DBCTXKeyValue, conn)
}
func IsNoRows(err error) bool {
return strings.Contains(err.Error(), "no rows in result set")
}

View file

@ -12,7 +12,7 @@ import (
type Server struct {
conf *config.Config
db database.Conn
db *database.DB
r *chi.Mux
jwt *jwtauth.JWTAuth
}