From 971a2aade701f72d565ff10d675d2aa7c0acce72 Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Mon, 29 Jul 2024 00:29:16 -0400 Subject: [PATCH] Add some godoc comments --- pkg/gordio/admin/admin.go | 3 ++ pkg/gordio/auth/auth.go | 75 +++++++++++---------------------- pkg/gordio/auth/jwt.go | 1 - pkg/gordio/database/database.go | 6 +++ pkg/gordio/ingestors/http.go | 39 +++++++---------- pkg/gordio/server/routes.go | 48 +++------------------ pkg/gordio/server/server.go | 23 +++++----- 7 files changed, 68 insertions(+), 127 deletions(-) diff --git a/pkg/gordio/admin/admin.go b/pkg/gordio/admin/admin.go index 25d65ba..758951a 100644 --- a/pkg/gordio/admin/admin.go +++ b/pkg/gordio/admin/admin.go @@ -23,6 +23,7 @@ var ( ErrInvalidArguments = errors.New("invalid arguments") ) +// AddUser adds a new user to the database. It asks for the password on the terminal. func AddUser(ctx context.Context, username, email string, isAdmin bool) error { if username == "" || email == "" { return ErrInvalidArguments @@ -60,6 +61,7 @@ func AddUser(ctx context.Context, username, email string, isAdmin bool) error { return err } +// Passwd changes a user's password. It asks for the password on the terminal. func Passwd(ctx context.Context, username string) error { if username == "" { return ErrInvalidArguments @@ -106,6 +108,7 @@ func readPassword(prompt string) (string, error) { return string(pw), err } +// Command is the users command. func Command(cfg *config.Config) []*cobra.Command { userCmd := &cobra.Command{ Use: "users", diff --git a/pkg/gordio/auth/auth.go b/pkg/gordio/auth/auth.go index fe5be0e..83a967c 100644 --- a/pkg/gordio/auth/auth.go +++ b/pkg/gordio/auth/auth.go @@ -1,67 +1,40 @@ -package server +package auth import ( - "context" "errors" - "golang.org/x/crypto/bcrypt" "net/http" - "time" - - "dynatron.me/x/stillbox/pkg/gordio/database" "github.com/go-chi/jwtauth/v5" - "github.com/rs/zerolog/log" ) -type claims map[string]interface{} +type Authenticator struct { + domain string + jwt *jwtauth.JWTAuth +} -func (s *Server) Authenticated(r *http.Request) (claims, bool) { - // TODO: check IP against ACL, or conf.Public, and against map of routes - tok, cl, err := jwtauth.FromContext(r.Context()) - return cl, err != nil && tok != nil +func NewAuthenticator(jwtSecret string, domain string) *Authenticator { + return &Authenticator{ + domain: domain, + jwt: jwtauth.New("HS256", []byte(jwtSecret), nil), + } } var ( - ErrLoginFailed = errors.New("Login failed") + ErrLoginFailed = errors.New("Login failed") + ErrInternal = errors.New("Internal server error") + ErrUnauthorized = errors.New("Unauthorized") + ErrBadRequest = errors.New("Bad request") ) -func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) { - q := database.New(database.FromCtx(ctx)) - users, err := q.GetUsers(ctx) - if err != nil { - log.Error().Err(err).Msg("getUsers failed") - return "", ErrLoginFailed +func ErrorResponse(w http.ResponseWriter, err error) { + switch err { + case ErrLoginFailed, ErrUnauthorized: + http.Error(w, err.Error(), http.StatusUnauthorized) + case ErrBadRequest: + http.Error(w, err.Error(), http.StatusBadRequest) + case ErrInternal: + fallthrough + default: + http.Error(w, err.Error(), http.StatusInternalServerError) } - - var found *database.User - - for _, u := range users { - if u.Username == username { - found = &u - } - } - - if found == nil { - _ = bcrypt.CompareHashAndPassword([]byte("lol@timing"), []byte(password)) - return "", ErrLoginFailed - } else { - err = bcrypt.CompareHashAndPassword([]byte(found.Password), []byte(password)) - if err != nil { - return "", ErrLoginFailed - } - } - - return s.NewToken(found.ID), nil -} - -func (s *Server) NewToken(uid int32) string { - claims := claims{ - "user_id": uid, - } - jwtauth.SetExpiryIn(claims, time.Hour*24*30) // one month - _, tokenString, err := s.jwt.Encode(claims) - if err != nil { - panic(err) - } - return tokenString } diff --git a/pkg/gordio/auth/jwt.go b/pkg/gordio/auth/jwt.go index 4463aa7..c0a1887 100644 --- a/pkg/gordio/auth/jwt.go +++ b/pkg/gordio/auth/jwt.go @@ -14,7 +14,6 @@ import ( "github.com/rs/zerolog/log" ) - type claims map[string]interface{} func (a *Authenticator) Authenticated(r *http.Request) (claims, bool) { diff --git a/pkg/gordio/database/database.go b/pkg/gordio/database/database.go index 62a93d8..4ed9e61 100644 --- a/pkg/gordio/database/database.go +++ b/pkg/gordio/database/database.go @@ -13,11 +13,13 @@ import ( "github.com/jackc/pgx/v5/pgxpool" ) +// DB is a database handle. type DB struct { *pgxpool.Pool *Queries } +// NewClient creates a new DB using the provided config. func NewClient(conf config.DB) (*DB, error) { dir, err := iofs.New(sqlembed.Migrations, "postgres/migrations") if err != nil { @@ -53,6 +55,7 @@ type DBCtxKey string const DBCTXKeyValue DBCtxKey = "dbctx" +// FromCtx returns the database handle from the provided Context. func FromCtx(ctx context.Context) *DB { c, ok := ctx.Value(DBCTXKeyValue).(*DB) if !ok { @@ -62,10 +65,13 @@ func FromCtx(ctx context.Context) *DB { return c } +// CtxWithDB returns a Context with the provided database handle. func CtxWithDB(ctx context.Context, conn *DB) context.Context { return context.WithValue(ctx, DBCTXKeyValue, conn) } +// IsNoRows is a convenience function that returns whether a returned error is a database +// no rows error. func IsNoRows(err error) bool { return strings.Contains(err.Error(), "no rows in result set") } diff --git a/pkg/gordio/ingestors/http.go b/pkg/gordio/ingestors/http.go index bce8f3c..9bdd39f 100644 --- a/pkg/gordio/ingestors/http.go +++ b/pkg/gordio/ingestors/http.go @@ -10,19 +10,25 @@ import ( "time" "dynatron.me/x/stillbox/internal/common" + "dynatron.me/x/stillbox/pkg/gordio/auth" "dynatron.me/x/stillbox/pkg/gordio/database" "github.com/go-chi/chi/v5" - "github.com/google/uuid" "github.com/rs/zerolog/log" ) +// HTTPIngestor is an ingestor that accepts calls over HTTP. type HTTPIngestor struct { + auth *auth.Authenticator } -func NewHTTPIngestor() *HTTPIngestor { - return new(HTTPIngestor) +// NewHTTPIngestor creates a new HTTPIngestor. It requires an Authenticator. +func NewHTTPIngestor(auth *auth.Authenticator) *HTTPIngestor { + return &HTTPIngestor{ + auth: auth, + } } +// InstallRoutes installs the HTTP ingestor's routes to the provided chi Router. func (h *HTTPIngestor) InstallRoutes(r chi.Router) { r.Post("/api/call-upload", h.routeCallUpload) } @@ -46,7 +52,7 @@ type callUploadRequest struct { TalkgroupTag string `form:"talkgroupTag"` } -func (car *callUploadRequest) ToAddCallParams(submitter int) database.AddCallParams { +func (car *callUploadRequest) toAddCallParams(submitter int) database.AddCallParams { return database.AddCallParams{ Submitter: common.PtrTo(int32(submitter)), System: car.System, @@ -72,28 +78,15 @@ func (h *HTTPIngestor) routeCallUpload(w http.ResponseWriter, r *http.Request) { return } - keyUuid, err := uuid.Parse(r.Form.Get("key")) - if err != nil { - http.Error(w, "cannot parse key "+err.Error(), http.StatusBadRequest) - return - } - db := database.FromCtx(r.Context()) - apik, err := db.GetAPIKey(r.Context(), keyUuid) - if err != nil { - if database.IsNoRows(err) { - http.Error(w, "bad key", http.StatusUnauthorized) - return - } + ctx := r.Context() - http.Error(w, "Internal server error", http.StatusInternalServerError) + apik, err := h.auth.CheckAPIKey(ctx, r.Form.Get("key")) + if err != nil { + auth.ErrorResponse(w, err) return } - if (apik.Disabled != nil && *apik.Disabled) || (apik.Expires.Valid && time.Now().After(apik.Expires.Time)) { - http.Error(w, "disabled", http.StatusUnauthorized) - log.Error().Str("key", apik.ApiKey.String()).Msg("key disabled") - return - } + db := database.FromCtx(ctx) if strings.Trim(r.Form.Get("test"), "\r\n") == "1" { // fudge the official response @@ -108,7 +101,7 @@ func (h *HTTPIngestor) routeCallUpload(w http.ResponseWriter, r *http.Request) { return } - dbCall, err := db.AddCall(r.Context(), call.ToAddCallParams(apik.Owner)) + dbCall, err := db.AddCall(ctx, call.toAddCallParams(apik.Owner)) if err != nil { http.Error(w, "internal error", http.StatusInternalServerError) log.Error().Err(err).Msg("add call") diff --git a/pkg/gordio/server/routes.go b/pkg/gordio/server/routes.go index 7423ba6..db452bd 100644 --- a/pkg/gordio/server/routes.go +++ b/pkg/gordio/server/routes.go @@ -8,7 +8,6 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/httprate" - "github.com/go-chi/jwtauth/v5" "github.com/go-chi/render" ) @@ -17,22 +16,22 @@ func (s *Server) setupRoutes() { r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db)) r.Group(func(r chi.Router) { - r.Use(jwtauth.Verifier(s.jwt)) - r.Use(jwtauth.Authenticator(s.jwt)) - + // authenticated routes + s.auth.InstallVerifyMiddleware(r) + s.auth.InstallAuthMiddleware(r) }) r.Group(func(r chi.Router) { r.Use(rateLimiter()) r.Use(render.SetContentType(render.ContentTypeJSON)) // public routes - r.Post("/auth", s.routeAuth) - s.hi.InstallRoutes(r) + s.auth.InstallRoutes(r) + s.httpIngestor.InstallRoutes(r) }) r.Group(func(r chi.Router) { r.Use(rateLimiter()) - r.Use(jwtauth.Verifier(s.jwt)) + s.auth.InstallVerifyMiddleware(r) // optional auth routes @@ -45,41 +44,8 @@ func rateLimiter() func(http.Handler) http.Handler { } func (s *Server) routeIndex(w http.ResponseWriter, r *http.Request) { - if cl, authenticated := s.Authenticated(r); authenticated { + if cl, authenticated := s.auth.Authenticated(r); authenticated { w.Write([]byte("Hello " + cl["user"].(string) + "\n")) } w.Write([]byte("Welcome to gordio\n")) } - -func (s *Server) routeAuth(w http.ResponseWriter, r *http.Request) { - err := r.ParseForm() - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - username, password := r.PostFormValue("username"), r.PostFormValue("password") - if username == "" || password == "" { - http.Error(w, "blank credentials", http.StatusBadRequest) - return - } - - tok, err := s.Login(r.Context(), username, password) - if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) - return - } - http.SetCookie(w, &http.Cookie{ - Name: "jwt", - Value: tok, - HttpOnly: true, - Secure: true, - Domain: s.conf.Domain, - }) - - jr := struct { - JWT string `json:"jwt"` - }{ - JWT: tok, - } - render.JSON(w, r, &jr) -} diff --git a/pkg/gordio/server/server.go b/pkg/gordio/server/server.go index 91c98c6..b8a797f 100644 --- a/pkg/gordio/server/server.go +++ b/pkg/gordio/server/server.go @@ -3,20 +3,20 @@ package server import ( "net/http" + "dynatron.me/x/stillbox/pkg/gordio/auth" "dynatron.me/x/stillbox/pkg/gordio/config" "dynatron.me/x/stillbox/pkg/gordio/database" "dynatron.me/x/stillbox/pkg/gordio/ingestors" "github.com/go-chi/chi/middleware" "github.com/go-chi/chi/v5" - "github.com/go-chi/jwtauth/v5" ) type Server struct { - conf *config.Config - db *database.DB - r *chi.Mux - jwt *jwtauth.JWTAuth - hi *ingestors.HTTPIngestor + auth *auth.Authenticator + conf *config.Config + db *database.DB + r *chi.Mux + httpIngestor *ingestors.HTTPIngestor } func New(cfg *config.Config) (*Server, error) { @@ -26,12 +26,13 @@ func New(cfg *config.Config) (*Server, error) { } r := chi.NewRouter() + authenticator := auth.NewAuthenticator(cfg.JWTSecret, cfg.Domain) srv := &Server{ - conf: cfg, - db: db, - r: r, - jwt: jwtauth.New("HS256", []byte(cfg.JWTSecret), nil), - hi: ingestors.NewHTTPIngestor(), + auth: authenticator, + conf: cfg, + db: db, + r: r, + httpIngestor: ingestors.NewHTTPIngestor(authenticator), } r.Use(middleware.RequestID) r.Use(middleware.RealIP)