admin is closer

This commit is contained in:
Daniel Ponte 2024-07-24 07:55:36 -04:00
parent f7c3a66a43
commit 8b5e38d08c
3 changed files with 23 additions and 21 deletions

View file

@ -24,15 +24,12 @@ var (
ErrInvalidArguments = errors.New("invalid arguments") ErrInvalidArguments = errors.New("invalid arguments")
) )
func AddUser(ctx context.Context, cfg *config.Config, username, email string, isAdmin bool) error { func AddUser(ctx context.Context, username, email string, isAdmin bool) error {
if username == "" || email == "" { if username == "" || email == "" {
return ErrInvalidArguments return ErrInvalidArguments
} }
db, err := database.NewClient(cfg.DB) db := database.FromCtx(ctx)
if err != nil {
return err
}
pw, err := readPassword(PromptPassword) pw, err := readPassword(PromptPassword)
if err != nil { if err != nil {
@ -64,17 +61,14 @@ func AddUser(ctx context.Context, cfg *config.Config, username, email string, is
return err return err
} }
func Passwd(ctx context.Context, cfg *config.Config, username string) error { func Passwd(ctx context.Context, username string) error {
if username == "" { if username == "" {
return ErrInvalidArguments return ErrInvalidArguments
} }
db, err := database.NewClient(cfg.DB) db := database.FromCtx(ctx)
if err != nil {
return err
}
_, err = db.GetUserByUsername(ctx, username) _, err := db.GetUserByUsername(ctx, username)
if err != nil && database.IsNoRows(err) { if err != nil && database.IsNoRows(err) {
return fmt.Errorf("no such user %s", username) return fmt.Errorf("no such user %s", username)
} }
@ -103,7 +97,7 @@ func Passwd(ctx context.Context, cfg *config.Config, username string) error {
hashpw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost) hashpw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
return database.New(db).UpdatePassword(context.Background(), username, string(hashpw)) return db.UpdatePassword(context.Background(), username, string(hashpw))
} }
func readPassword(prompt string) (string, error) { func readPassword(prompt string) (string, error) {
@ -129,6 +123,11 @@ func addUserCommand(cfg *config.Config) *cobra.Command {
Use: "add", Use: "add",
Short: "adds a user", Short: "adds a user",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
db, err := database.NewClient(cfg.DB)
if err != nil {
return err
}
username := args[0] username := args[0]
isAdmin, err := cmd.Flags().GetBool("admin") isAdmin, err := cmd.Flags().GetBool("admin")
if err != nil { if err != nil {
@ -139,7 +138,7 @@ func addUserCommand(cfg *config.Config) *cobra.Command {
return err return err
} }
return AddUser(context.Background(), cfg, username, email, isAdmin) return AddUser(database.CtxWithDB(context.Background(), db), username, email, isAdmin)
}, },
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
} }
@ -151,13 +150,17 @@ func addUserCommand(cfg *config.Config) *cobra.Command {
func passwdCommand(cfg *config.Config) *cobra.Command { func passwdCommand(cfg *config.Config) *cobra.Command {
c := &cobra.Command{ c := &cobra.Command{
Use: "passwd", Use: "passwd userid",
Short: "changes password for a user", Short: "changes password for a user",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
db, err := database.NewClient(cfg.DB)
if err != nil {
return err
}
username := args[0] username := args[0]
return Passwd(context.Background(), cfg, username) return Passwd(database.CtxWithDB(context.Background(), db), username)
}, },
Args: cobra.ExactArgs(1),
} }
return c return c

View file

@ -2,9 +2,9 @@ package config
import ( import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"os"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"os"
) )
type Config struct { type Config struct {

View file

@ -17,7 +17,6 @@ type DB struct {
*pgxpool.Pool *pgxpool.Pool
*Queries *Queries
} }
type Conn = *pgxpool.Pool
func NewClient(conf config.DB) (*DB, error) { func NewClient(conf config.DB) (*DB, error) {
dir, err := iofs.New(sqlembed.Migrations, "postgres/migrations") dir, err := iofs.New(sqlembed.Migrations, "postgres/migrations")
@ -54,8 +53,8 @@ type DBCtxKey string
const DBCTXKeyValue DBCtxKey = "dbctx" const DBCTXKeyValue DBCtxKey = "dbctx"
func FromCtx(ctx context.Context) Conn { func FromCtx(ctx context.Context) *DB {
c, ok := ctx.Value(DBCTXKeyValue).(Conn) c, ok := ctx.Value(DBCTXKeyValue).(*DB)
if !ok { if !ok {
panic("no DB in context") panic("no DB in context")
} }
@ -63,7 +62,7 @@ func FromCtx(ctx context.Context) Conn {
return c return c
} }
func CtxWithDB(ctx context.Context, conn Conn) context.Context { func CtxWithDB(ctx context.Context, conn *DB) context.Context {
return context.WithValue(ctx, DBCTXKeyValue, conn) return context.WithValue(ctx, DBCTXKeyValue, conn)
} }