From 759c274950a88dd0f9ab812a910507362871afbd Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Sun, 10 Nov 2024 14:40:50 -0500 Subject: [PATCH] Use chi render, improvements --- pkg/api/api.go | 88 +++++++++++++++++++++++++++---------------- pkg/api/talkgroups.go | 23 +++++------ 2 files changed, 68 insertions(+), 43 deletions(-) diff --git a/pkg/api/api.go b/pkg/api/api.go index 9910d9f..923f139 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -1,13 +1,13 @@ package api import ( - "encoding/json" "errors" "net/http" "dynatron.me/x/stillbox/pkg/talkgroups" "github.com/go-chi/chi/v5" + "github.com/go-chi/render" "github.com/go-viper/mapstructure/v2" "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" @@ -29,57 +29,81 @@ func New() API { func (a *api) Subrouter() http.Handler { r := chi.NewMux() - r.Mount("/talkgroup", new(talkgroupAPI).routes()) + r.Mount("/talkgroup", new(talkgroupAPI).Subrouter()) return r } type errResponse struct { - text string - code int + Err error `json:"-"` + Code int `json:"-"` + Error string `json:"error"` } -var statusMapping = map[error]errResponse{ - talkgroups.ErrNotFound: {talkgroups.ErrNotFound.Error(), http.StatusNotFound}, - pgx.ErrNoRows: {"no such record", http.StatusNotFound}, +func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error { + switch e.Code { + case http.StatusNotFound: + case http.StatusBadRequest: + default: + log.Error().Str("path", r.URL.Path).Err(e.Err).Int("code", e.Code).Str("msg", e.Error).Msg("request failed") + } + + render.Status(r, e.Code) + + return nil } -func httpCode(err error) (string, int) { +func badRequest(err error) render.Renderer { + return &errResponse{ + Err: err, + Code: http.StatusBadRequest, + Error: "Bad request", + } +} + +func recordNotFound(err error) render.Renderer { + return &errResponse{ + Err: err, + Code: http.StatusNotFound, + Error: "Record not found", + } +} + +func internalError(err error) render.Renderer { + return &errResponse{ + Err: err, + Code: http.StatusNotFound, + Error: "Internal server error", + } +} + +type errResponder func(error) render.Renderer + +var statusMapping = map[error]errResponder{ + talkgroups.ErrNotFound: recordNotFound, + pgx.ErrNoRows: recordNotFound, +} + +func autoError(err error) render.Renderer { c, ok := statusMapping[err] if ok { - return c.text, c.code + c(err) } for e, c := range statusMapping { // check if err wraps an error we know about if errors.Is(err, e) { - return c.text, c.code + return c(err) } } - return err.Error(), http.StatusInternalServerError + return internalError(err) } -func writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) { +func wErr(w http.ResponseWriter, r *http.Request, v render.Renderer) { + err := render.Render(w, r, v) if err != nil { - log.Error().Str("path", r.URL.Path).Err(err).Msg("request failed") - text, code := httpCode(err) - http.Error(w, text, code) - return + log.Error().Err(err).Msg("wErr render error") } - - w.Header().Set("Content-Type", "application/json") - enc := json.NewEncoder(w) - err = enc.Encode(data) - if err != nil { - log.Error().Str("path", r.URL.Path).Err(err).Msg("response marshal failed") - text, code := httpCode(err) - http.Error(w, text, code) - return - } -} - -func reqErr(w http.ResponseWriter, err error, code int) { - http.Error(w, err.Error(), code) } func decodeParams(d interface{}, r *http.Request) error { @@ -103,6 +127,6 @@ func decodeParams(d interface{}, r *http.Request) error { return dec.Decode(m) } -func badReq(w http.ResponseWriter, err error) { - reqErr(w, err, http.StatusBadRequest) +func respond(w http.ResponseWriter, r *http.Request, v interface{}) { + render.DefaultResponder(w, r, v) } diff --git a/pkg/api/talkgroups.go b/pkg/api/talkgroups.go index e699b78..ac74190 100644 --- a/pkg/api/talkgroups.go +++ b/pkg/api/talkgroups.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "net/http" "dynatron.me/x/stillbox/internal/forms" @@ -14,7 +13,7 @@ import ( type talkgroupAPI struct { } -func (tga *talkgroupAPI) routes() http.Handler { +func (tga *talkgroupAPI) Subrouter() http.Handler { r := chi.NewMux() r.Get("/{system:\\d+}/{id:\\d+}", tga.talkgroup) @@ -57,7 +56,7 @@ func (tga *talkgroupAPI) talkgroup(w http.ResponseWriter, r *http.Request) { err := decodeParams(&p, r) if err != nil { - badReq(w, err) + wErr(w, r, badRequest(err)) return } @@ -71,14 +70,19 @@ func (tga *talkgroupAPI) talkgroup(w http.ResponseWriter, r *http.Request) { res, err = tgs.TGs(ctx, nil) } - writeResponse(w, r, res, err) + if err != nil { + wErr(w, r, autoError(err)) + return + } + + respond(w, r, res) } func (tga *talkgroupAPI) putTalkgroup(w http.ResponseWriter, r *http.Request) { var id tgParams err := decodeParams(&id, r) if err != nil { - badReq(w, err) + wErr(w, r, badRequest(err)) return } @@ -89,19 +93,16 @@ func (tga *talkgroupAPI) putTalkgroup(w http.ResponseWriter, r *http.Request) { err = forms.Unmarshal(r, &input, forms.WithTag("json"), forms.WithAcceptBlank(), forms.WithOmitEmpty()) if err != nil { - writeResponse(w, r, nil, err) + wErr(w, r, badRequest(err)) return } input.ID = id.ToID().Pack() record, err := tgs.UpdateTG(ctx, input) if err != nil { - writeResponse(w, r, nil, err) + wErr(w, r, autoError(err)) return } - err = json.NewEncoder(w).Encode(record) - if err != nil { - writeResponse(w, r, nil, err) - } + respond(w, r, record) }