Add some godoc comments
This commit is contained in:
parent
992542d9c6
commit
971a2aade7
7 changed files with 68 additions and 127 deletions
|
@ -23,6 +23,7 @@ var (
|
||||||
ErrInvalidArguments = errors.New("invalid arguments")
|
ErrInvalidArguments = errors.New("invalid arguments")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AddUser adds a new user to the database. It asks for the password on the terminal.
|
||||||
func AddUser(ctx context.Context, username, email string, isAdmin bool) error {
|
func AddUser(ctx context.Context, username, email string, isAdmin bool) error {
|
||||||
if username == "" || email == "" {
|
if username == "" || email == "" {
|
||||||
return ErrInvalidArguments
|
return ErrInvalidArguments
|
||||||
|
@ -60,6 +61,7 @@ func AddUser(ctx context.Context, username, email string, isAdmin bool) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Passwd changes a user's password. It asks for the password on the terminal.
|
||||||
func Passwd(ctx context.Context, username string) error {
|
func Passwd(ctx context.Context, username string) error {
|
||||||
if username == "" {
|
if username == "" {
|
||||||
return ErrInvalidArguments
|
return ErrInvalidArguments
|
||||||
|
@ -106,6 +108,7 @@ func readPassword(prompt string) (string, error) {
|
||||||
return string(pw), err
|
return string(pw), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Command is the users command.
|
||||||
func Command(cfg *config.Config) []*cobra.Command {
|
func Command(cfg *config.Config) []*cobra.Command {
|
||||||
userCmd := &cobra.Command{
|
userCmd := &cobra.Command{
|
||||||
Use: "users",
|
Use: "users",
|
||||||
|
|
|
@ -1,67 +1,40 @@
|
||||||
package server
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"dynatron.me/x/stillbox/pkg/gordio/database"
|
|
||||||
|
|
||||||
"github.com/go-chi/jwtauth/v5"
|
"github.com/go-chi/jwtauth/v5"
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type claims map[string]interface{}
|
type Authenticator struct {
|
||||||
|
domain string
|
||||||
|
jwt *jwtauth.JWTAuth
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) Authenticated(r *http.Request) (claims, bool) {
|
func NewAuthenticator(jwtSecret string, domain string) *Authenticator {
|
||||||
// TODO: check IP against ACL, or conf.Public, and against map of routes
|
return &Authenticator{
|
||||||
tok, cl, err := jwtauth.FromContext(r.Context())
|
domain: domain,
|
||||||
return cl, err != nil && tok != nil
|
jwt: jwtauth.New("HS256", []byte(jwtSecret), nil),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrLoginFailed = errors.New("Login failed")
|
ErrLoginFailed = errors.New("Login failed")
|
||||||
|
ErrInternal = errors.New("Internal server error")
|
||||||
|
ErrUnauthorized = errors.New("Unauthorized")
|
||||||
|
ErrBadRequest = errors.New("Bad request")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) Login(ctx context.Context, username, password string) (token string, err error) {
|
func ErrorResponse(w http.ResponseWriter, err error) {
|
||||||
q := database.New(database.FromCtx(ctx))
|
switch err {
|
||||||
users, err := q.GetUsers(ctx)
|
case ErrLoginFailed, ErrUnauthorized:
|
||||||
if err != nil {
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
log.Error().Err(err).Msg("getUsers failed")
|
case ErrBadRequest:
|
||||||
return "", ErrLoginFailed
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
case ErrInternal:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
var found *database.User
|
|
||||||
|
|
||||||
for _, u := range users {
|
|
||||||
if u.Username == username {
|
|
||||||
found = &u
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if found == nil {
|
|
||||||
_ = bcrypt.CompareHashAndPassword([]byte("lol@timing"), []byte(password))
|
|
||||||
return "", ErrLoginFailed
|
|
||||||
} else {
|
|
||||||
err = bcrypt.CompareHashAndPassword([]byte(found.Password), []byte(password))
|
|
||||||
if err != nil {
|
|
||||||
return "", ErrLoginFailed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.NewToken(found.ID), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) NewToken(uid int32) string {
|
|
||||||
claims := claims{
|
|
||||||
"user_id": uid,
|
|
||||||
}
|
|
||||||
jwtauth.SetExpiryIn(claims, time.Hour*24*30) // one month
|
|
||||||
_, tokenString, err := s.jwt.Encode(claims)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return tokenString
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,6 @@ import (
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
type claims map[string]interface{}
|
type claims map[string]interface{}
|
||||||
|
|
||||||
func (a *Authenticator) Authenticated(r *http.Request) (claims, bool) {
|
func (a *Authenticator) Authenticated(r *http.Request) (claims, bool) {
|
||||||
|
|
|
@ -13,11 +13,13 @@ import (
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DB is a database handle.
|
||||||
type DB struct {
|
type DB struct {
|
||||||
*pgxpool.Pool
|
*pgxpool.Pool
|
||||||
*Queries
|
*Queries
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new DB using the provided config.
|
||||||
func NewClient(conf config.DB) (*DB, error) {
|
func NewClient(conf config.DB) (*DB, error) {
|
||||||
dir, err := iofs.New(sqlembed.Migrations, "postgres/migrations")
|
dir, err := iofs.New(sqlembed.Migrations, "postgres/migrations")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -53,6 +55,7 @@ type DBCtxKey string
|
||||||
|
|
||||||
const DBCTXKeyValue DBCtxKey = "dbctx"
|
const DBCTXKeyValue DBCtxKey = "dbctx"
|
||||||
|
|
||||||
|
// FromCtx returns the database handle from the provided Context.
|
||||||
func FromCtx(ctx context.Context) *DB {
|
func FromCtx(ctx context.Context) *DB {
|
||||||
c, ok := ctx.Value(DBCTXKeyValue).(*DB)
|
c, ok := ctx.Value(DBCTXKeyValue).(*DB)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -62,10 +65,13 @@ func FromCtx(ctx context.Context) *DB {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CtxWithDB returns a Context with the provided database handle.
|
||||||
func CtxWithDB(ctx context.Context, conn *DB) context.Context {
|
func CtxWithDB(ctx context.Context, conn *DB) context.Context {
|
||||||
return context.WithValue(ctx, DBCTXKeyValue, conn)
|
return context.WithValue(ctx, DBCTXKeyValue, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsNoRows is a convenience function that returns whether a returned error is a database
|
||||||
|
// no rows error.
|
||||||
func IsNoRows(err error) bool {
|
func IsNoRows(err error) bool {
|
||||||
return strings.Contains(err.Error(), "no rows in result set")
|
return strings.Contains(err.Error(), "no rows in result set")
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,19 +10,25 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dynatron.me/x/stillbox/internal/common"
|
"dynatron.me/x/stillbox/internal/common"
|
||||||
|
"dynatron.me/x/stillbox/pkg/gordio/auth"
|
||||||
"dynatron.me/x/stillbox/pkg/gordio/database"
|
"dynatron.me/x/stillbox/pkg/gordio/database"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// HTTPIngestor is an ingestor that accepts calls over HTTP.
|
||||||
type HTTPIngestor struct {
|
type HTTPIngestor struct {
|
||||||
|
auth *auth.Authenticator
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPIngestor() *HTTPIngestor {
|
// NewHTTPIngestor creates a new HTTPIngestor. It requires an Authenticator.
|
||||||
return new(HTTPIngestor)
|
func NewHTTPIngestor(auth *auth.Authenticator) *HTTPIngestor {
|
||||||
|
return &HTTPIngestor{
|
||||||
|
auth: auth,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InstallRoutes installs the HTTP ingestor's routes to the provided chi Router.
|
||||||
func (h *HTTPIngestor) InstallRoutes(r chi.Router) {
|
func (h *HTTPIngestor) InstallRoutes(r chi.Router) {
|
||||||
r.Post("/api/call-upload", h.routeCallUpload)
|
r.Post("/api/call-upload", h.routeCallUpload)
|
||||||
}
|
}
|
||||||
|
@ -46,7 +52,7 @@ type callUploadRequest struct {
|
||||||
TalkgroupTag string `form:"talkgroupTag"`
|
TalkgroupTag string `form:"talkgroupTag"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (car *callUploadRequest) ToAddCallParams(submitter int) database.AddCallParams {
|
func (car *callUploadRequest) toAddCallParams(submitter int) database.AddCallParams {
|
||||||
return database.AddCallParams{
|
return database.AddCallParams{
|
||||||
Submitter: common.PtrTo(int32(submitter)),
|
Submitter: common.PtrTo(int32(submitter)),
|
||||||
System: car.System,
|
System: car.System,
|
||||||
|
@ -72,28 +78,15 @@ func (h *HTTPIngestor) routeCallUpload(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
keyUuid, err := uuid.Parse(r.Form.Get("key"))
|
ctx := r.Context()
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "cannot parse key "+err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
db := database.FromCtx(r.Context())
|
|
||||||
apik, err := db.GetAPIKey(r.Context(), keyUuid)
|
|
||||||
if err != nil {
|
|
||||||
if database.IsNoRows(err) {
|
|
||||||
http.Error(w, "bad key", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
apik, err := h.auth.CheckAPIKey(ctx, r.Form.Get("key"))
|
||||||
|
if err != nil {
|
||||||
|
auth.ErrorResponse(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (apik.Disabled != nil && *apik.Disabled) || (apik.Expires.Valid && time.Now().After(apik.Expires.Time)) {
|
db := database.FromCtx(ctx)
|
||||||
http.Error(w, "disabled", http.StatusUnauthorized)
|
|
||||||
log.Error().Str("key", apik.ApiKey.String()).Msg("key disabled")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.Trim(r.Form.Get("test"), "\r\n") == "1" {
|
if strings.Trim(r.Form.Get("test"), "\r\n") == "1" {
|
||||||
// fudge the official response
|
// fudge the official response
|
||||||
|
@ -108,7 +101,7 @@ func (h *HTTPIngestor) routeCallUpload(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dbCall, err := db.AddCall(r.Context(), call.ToAddCallParams(apik.Owner))
|
dbCall, err := db.AddCall(ctx, call.toAddCallParams(apik.Owner))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||||
log.Error().Err(err).Msg("add call")
|
log.Error().Err(err).Msg("add call")
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
"github.com/go-chi/httprate"
|
"github.com/go-chi/httprate"
|
||||||
"github.com/go-chi/jwtauth/v5"
|
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,22 +16,22 @@ func (s *Server) setupRoutes() {
|
||||||
r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db))
|
r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db))
|
||||||
|
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(jwtauth.Verifier(s.jwt))
|
// authenticated routes
|
||||||
r.Use(jwtauth.Authenticator(s.jwt))
|
s.auth.InstallVerifyMiddleware(r)
|
||||||
|
s.auth.InstallAuthMiddleware(r)
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(rateLimiter())
|
r.Use(rateLimiter())
|
||||||
r.Use(render.SetContentType(render.ContentTypeJSON))
|
r.Use(render.SetContentType(render.ContentTypeJSON))
|
||||||
// public routes
|
// public routes
|
||||||
r.Post("/auth", s.routeAuth)
|
s.auth.InstallRoutes(r)
|
||||||
s.hi.InstallRoutes(r)
|
s.httpIngestor.InstallRoutes(r)
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(rateLimiter())
|
r.Use(rateLimiter())
|
||||||
r.Use(jwtauth.Verifier(s.jwt))
|
s.auth.InstallVerifyMiddleware(r)
|
||||||
|
|
||||||
// optional auth routes
|
// optional auth routes
|
||||||
|
|
||||||
|
@ -45,41 +44,8 @@ func rateLimiter() func(http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) routeIndex(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) routeIndex(w http.ResponseWriter, r *http.Request) {
|
||||||
if cl, authenticated := s.Authenticated(r); authenticated {
|
if cl, authenticated := s.auth.Authenticated(r); authenticated {
|
||||||
w.Write([]byte("Hello " + cl["user"].(string) + "\n"))
|
w.Write([]byte("Hello " + cl["user"].(string) + "\n"))
|
||||||
}
|
}
|
||||||
w.Write([]byte("Welcome to gordio\n"))
|
w.Write([]byte("Welcome to gordio\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) routeAuth(w http.ResponseWriter, r *http.Request) {
|
|
||||||
err := r.ParseForm()
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
username, password := r.PostFormValue("username"), r.PostFormValue("password")
|
|
||||||
if username == "" || password == "" {
|
|
||||||
http.Error(w, "blank credentials", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tok, err := s.Login(r.Context(), username, password)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: "jwt",
|
|
||||||
Value: tok,
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: true,
|
|
||||||
Domain: s.conf.Domain,
|
|
||||||
})
|
|
||||||
|
|
||||||
jr := struct {
|
|
||||||
JWT string `json:"jwt"`
|
|
||||||
}{
|
|
||||||
JWT: tok,
|
|
||||||
}
|
|
||||||
render.JSON(w, r, &jr)
|
|
||||||
}
|
|
||||||
|
|
|
@ -3,20 +3,20 @@ package server
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"dynatron.me/x/stillbox/pkg/gordio/auth"
|
||||||
"dynatron.me/x/stillbox/pkg/gordio/config"
|
"dynatron.me/x/stillbox/pkg/gordio/config"
|
||||||
"dynatron.me/x/stillbox/pkg/gordio/database"
|
"dynatron.me/x/stillbox/pkg/gordio/database"
|
||||||
"dynatron.me/x/stillbox/pkg/gordio/ingestors"
|
"dynatron.me/x/stillbox/pkg/gordio/ingestors"
|
||||||
"github.com/go-chi/chi/middleware"
|
"github.com/go-chi/chi/middleware"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/jwtauth/v5"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
conf *config.Config
|
auth *auth.Authenticator
|
||||||
db *database.DB
|
conf *config.Config
|
||||||
r *chi.Mux
|
db *database.DB
|
||||||
jwt *jwtauth.JWTAuth
|
r *chi.Mux
|
||||||
hi *ingestors.HTTPIngestor
|
httpIngestor *ingestors.HTTPIngestor
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(cfg *config.Config) (*Server, error) {
|
func New(cfg *config.Config) (*Server, error) {
|
||||||
|
@ -26,12 +26,13 @@ func New(cfg *config.Config) (*Server, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
authenticator := auth.NewAuthenticator(cfg.JWTSecret, cfg.Domain)
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
conf: cfg,
|
auth: authenticator,
|
||||||
db: db,
|
conf: cfg,
|
||||||
r: r,
|
db: db,
|
||||||
jwt: jwtauth.New("HS256", []byte(cfg.JWTSecret), nil),
|
r: r,
|
||||||
hi: ingestors.NewHTTPIngestor(),
|
httpIngestor: ingestors.NewHTTPIngestor(authenticator),
|
||||||
}
|
}
|
||||||
r.Use(middleware.RequestID)
|
r.Use(middleware.RequestID)
|
||||||
r.Use(middleware.RealIP)
|
r.Use(middleware.RealIP)
|
||||||
|
|
Loading…
Reference in a new issue