From e3a731380610994a5235ba21787a27654bd653ec Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Sat, 9 Nov 2024 13:14:31 -0500 Subject: [PATCH 1/8] api restructure, support pointers in forms Tests pass Support pointer fields --- internal/forms/forms.go | 213 +++++++++++++++++++++++------------ internal/forms/forms_test.go | 87 ++++++++------ pkg/alerting/alert/alert.go | 3 +- pkg/alerting/simulate.go | 24 ++-- pkg/api/api.go | 42 +------ pkg/api/talkgroups.go | 54 +++++++++ pkg/database/database.go | 6 +- pkg/server/routes.go | 4 +- pkg/talkgroups/cache.go | 8 +- 9 files changed, 273 insertions(+), 168 deletions(-) create mode 100644 pkg/api/talkgroups.go diff --git a/internal/forms/forms.go b/internal/forms/forms.go index facb68c..bc28cc4 100644 --- a/internal/forms/forms.go +++ b/internal/forms/forms.go @@ -1,6 +1,7 @@ package forms import ( + "encoding/json" "errors" "fmt" "io" @@ -16,15 +17,22 @@ import ( ) var ( - ErrNotStruct = errors.New("destination is not a struct") - ErrNotPointer = errors.New("destination is not a pointer") + ErrNotStruct = errors.New("destination is not a struct") + ErrNotPointer = errors.New("destination is not a pointer") + ErrContentType = errors.New("bad content type") +) + +const ( + MaxMultipartMemory int64 = 1024 * 1024 // 1MB ) type options struct { - tagOverride *string - parseTimeIn *time.Location - parseLocal bool - acceptBlank bool + tagOverride *string + parseTimeIn *time.Location + parseLocal bool + acceptBlank bool + maxMultipartMemory int64 + defaultOmitEmpty bool } type Option func(*options) @@ -53,6 +61,18 @@ func WithTag(t string) Option { } } +func WithMaxMultipartSize(s int64) Option { + return func(o *options) { + o.maxMultipartMemory = s + } +} + +func WithOmitEmpty() Option { + return func(o *options) { + o.defaultOmitEmpty = true + } +} + func (o *options) Tag() string { if o.tagOverride != nil { return *o.tagOverride @@ -147,17 +167,19 @@ func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) return } -func (o *options) iterFields(r *http.Request, rv reflect.Value) error { - rt := rv.Type() - for i := 0; i < rv.NumField(); i++ { - f := rv.Field(i) - tf := rt.Field(i) - if !tf.IsExported() && !tf.Anonymous { +var typeOfByteSlice = reflect.TypeOf([]byte(nil)) + +func (o *options) iterFields(r *http.Request, destStruct reflect.Value) error { + structType := destStruct.Type() + for i := 0; i < destStruct.NumField(); i++ { + destFieldVal := destStruct.Field(i) + fieldType := structType.Field(i) + if !fieldType.IsExported() && !fieldType.Anonymous { continue } - if f.Kind() == reflect.Struct && tf.Anonymous { - err := o.iterFields(r, f) + if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous { + err := o.iterFields(r, destFieldVal) if err != nil { return err } @@ -165,51 +187,38 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { var tAr []string var formField string - formTag, has := rt.Field(i).Tag.Lookup(o.Tag()) + var omitEmpty bool + if o.defaultOmitEmpty { + omitEmpty = true + } + + formTag, has := structType.Field(i).Tag.Lookup(o.Tag()) if has { tAr = strings.Split(formTag, ",") formField = tAr[0] + for _, v := range tAr[1:] { + if v == "omitempty" { + omitEmpty = true + break + } + } } + if !has || formField == "-" { continue } - fi := f.Interface() + destFieldIntf := destFieldVal.Interface() - switch v := fi.(type) { - case string, *string: - s := r.Form.Get(formField) - setVal(f, s != "" || o.acceptBlank, v, s) - case int, uint, *int, *uint: - ff := r.Form.Get(formField) - val, set, err := o.parseInt(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case float64: - ff := r.Form.Get(formField) - val, set, err := o.parseFloat64(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case bool, *bool: - ff := r.Form.Get(formField) - val, set, err := o.parseBool(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case []byte: + if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice { file, hdr, err := r.FormFile(formField) if err != nil { return fmt.Errorf("get form file: %w", err) } - nameField, hasFilename := rt.Field(i).Tag.Lookup("filenameField") + nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField") if hasFilename { - fnf := rv.FieldByName(nameField) + fnf := destStruct.FieldByName(nameField) if fnf == (reflect.Value{}) { panic(fmt.Errorf("filenameField '%s' does not exist", nameField)) } @@ -221,23 +230,52 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { return fmt.Errorf("file read: %w", err) } - f.SetBytes(audioBytes) + destFieldVal.SetBytes(audioBytes) + + continue + } + + if !r.Form.Has(formField) && omitEmpty { + continue + } + + ff := r.Form.Get(formField) + + switch v := destFieldIntf.(type) { + case string, *string: + setVal(destFieldVal, ff != "" || o.acceptBlank, ff) + case int, uint, *int, *uint: + val, set, err := o.parseInt(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case float64: + val, set, err := o.parseFloat64(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case bool, *bool: + val, set, err := o.parseBool(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) case time.Time, *time.Time, jsontime.Time, *jsontime.Time: - tval := r.Form.Get(formField) - t, set, err := o.parseTime(tval) + t, set, err := o.parseTime(ff) if err != nil { return err } - setVal(f, set, v, t) + setVal(destFieldVal, set, t) case time.Duration, *time.Duration, jsontime.Duration, *jsontime.Duration: - dval := r.Form.Get(formField) - d, set, err := o.parseDuration(dval) + d, set, err := o.parseDuration(ff) if err != nil { return err } - setVal(f, set, v, d) + setVal(destFieldVal, set, d) case []int: - val := strings.Trim(r.Form.Get(formField), "[]") + val := strings.Trim(ff, "[]") if val == "" && o.acceptBlank { continue } @@ -249,7 +287,7 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { ar = append(ar, i) } } - f.Set(reflect.ValueOf(ar)) + destFieldVal.Set(reflect.ValueOf(ar)) default: panic(fmt.Errorf("unsupported type %T", v)) } @@ -258,48 +296,77 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { return nil } -func setVal(setField reflect.Value, set bool, fv any, sv any) { +func setVal(destFieldVal reflect.Value, set bool, src any) { if !set { return } - rv := reflect.TypeOf(fv) - svo := reflect.ValueOf(sv) + destType := destFieldVal.Type() + srcVal := reflect.ValueOf(src) - if svo.CanConvert(rv) { - svo = svo.Convert(rv) + if srcVal.Kind() == reflect.Ptr { + srcVal = srcVal.Elem() } - if rv.Kind() == reflect.Ptr { - svo = svo.Addr() + if destType.Kind() == reflect.Ptr { + if !srcVal.CanAddr() { + if srcVal.CanConvert(destType.Elem()) { + srcVal = srcVal.Convert(destType.Elem()) + } + copy := reflect.New(srcVal.Type()) + copy.Elem().Set(srcVal) + srcVal = copy + } + } else if srcVal.CanConvert(destFieldVal.Type()) { + srcVal = srcVal.Convert(destFieldVal.Type()) } - setField.Set(svo) + destFieldVal.Set(srcVal) } func Unmarshal(r *http.Request, dest any, opt ...Option) error { - o := options{} + o := options{ + maxMultipartMemory: MaxMultipartMemory, + } + for _, opt := range opt { opt(&o) } - rv := reflect.ValueOf(dest) - if k := rv.Kind(); k == reflect.Ptr { - rv = rv.Elem() - } else { - return ErrNotPointer - } + contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0] - if rv.Kind() != reflect.Struct { - return ErrNotStruct - } + switch contentType { + case "multipart/form-data": + err := r.ParseMultipartForm(o.maxMultipartMemory) + if err != nil { + return fmt.Errorf("ParseForm: %w", err) + } - if strings.HasPrefix(r.Header.Get("Content-Type"), "application/x-www-form-urlencoded") { + return o.unmarshalForm(r, dest) + case "application/x-www-form-urlencoded": err := r.ParseForm() if err != nil { return fmt.Errorf("ParseForm: %w", err) } + return o.unmarshalForm(r, dest) + case "application/json": + return json.NewDecoder(r.Body).Decode(dest) } - return o.iterFields(r, rv) + return ErrContentType +} + +func (o *options) unmarshalForm(r *http.Request, dest any) error { + destVal := reflect.ValueOf(dest) + if k := destVal.Kind(); k == reflect.Ptr { + destVal = destVal.Elem() + } else { + return ErrNotPointer + } + + if destVal.Kind() != reflect.Struct { + return ErrNotStruct + } + + return o.iterFields(r, destVal) } diff --git a/internal/forms/forms_test.go b/internal/forms/forms_test.go index c75b839..5f73dca 100644 --- a/internal/forms/forms_test.go +++ b/internal/forms/forms_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/forms" "dynatron.me/x/stillbox/internal/jsontime" @@ -54,6 +55,14 @@ type urlEncTestJT struct { ScoreEnd jsontime.Time `json:"scoreEnd"` } +type ptrTestJT struct { + LookbackDays uint `form:"lookbackDays"` + HalfLife *jsontime.Duration `form:"halfLife"` + Recent *string `form:"recent"` + ScoreStart *jsontime.Time `form:"scoreStart"` + ScoreEnd jsontime.Time `form:"scoreEnd"` +} + var ( UrlEncTest = urlEncTest{ LookbackDays: 7, @@ -69,6 +78,13 @@ var ( ScoreStart: jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC)), } + PtrTestJT = ptrTestJT{ + LookbackDays: 7, + HalfLife: common.PtrTo(jsontime.Duration(30 * time.Minute)), + Recent: common.PtrTo("2h0m0s"), + ScoreStart: common.PtrTo(jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC))), + } + UrlEncTestJTLocal = urlEncTestJT{ LookbackDays: 7, HalfLife: jsontime.Duration(30 * time.Minute), @@ -122,29 +138,29 @@ func TestUnmarshal(t *testing.T) { name string r *http.Request dest any - compare any + expect any expectErr error opts []forms.Option }{ { - name: "base case", - r: makeRequest("call1.http"), - dest: &callUploadRequest{}, - compare: &Call1, - opts: []forms.Option{forms.WithAcceptBlank()}, + name: "base case", + r: makeRequest("call1.http"), + dest: &callUploadRequest{}, + expect: &Call1, + opts: []forms.Option{forms.WithAcceptBlank()}, }, { name: "base case no accept blank", r: makeRequest("call1.http"), dest: &callUploadRequest{}, - compare: &Call1, + expect: &Call1, expectErr: errors.New(`parsebool(''): strconv.ParseBool: parsing "": invalid syntax`), }, { name: "not a pointer", r: makeRequest("call1.http"), dest: callUploadRequest{}, - compare: callUploadRequest{}, + expect: callUploadRequest{}, expectErr: forms.ErrNotPointer, opts: []forms.Option{forms.WithAcceptBlank()}, }, @@ -152,7 +168,7 @@ func TestUnmarshal(t *testing.T) { name: "not a struct", r: makeRequest("call1.http"), dest: &str, - compare: callUploadRequest{}, + expect: callUploadRequest{}, expectErr: forms.ErrNotStruct, opts: []forms.Option{forms.WithAcceptBlank()}, }, @@ -160,44 +176,51 @@ func TestUnmarshal(t *testing.T) { name: "url encoded", r: makeRequest("urlenc.http"), dest: &urlEncTest{}, - compare: &UrlEncTest, + expect: &UrlEncTest, expectErr: errors.New(`Could not find format for ""`), }, { - name: "url encoded accept blank", - r: makeRequest("urlenc.http"), - dest: &urlEncTest{}, - compare: &UrlEncTest, - opts: []forms.Option{forms.WithAcceptBlank()}, + name: "url encoded accept blank", + r: makeRequest("urlenc.http"), + dest: &urlEncTest{}, + expect: &UrlEncTest, + opts: []forms.Option{forms.WithAcceptBlank()}, + }, + { + name: "url encoded accept blank pointer", + r: makeRequest("urlenc.http"), + dest: &ptrTestJT{}, + expect: &PtrTestJT, + opts: []forms.Option{forms.WithAcceptBlank()}, }, { name: "url encoded jsontime", r: makeRequest("urlenc.http"), dest: &urlEncTestJT{}, - compare: &UrlEncTestJT, + expect: &UrlEncTestJT, expectErr: errors.New(`Could not find format for ""`), opts: []forms.Option{forms.WithTag("json")}, }, { - name: "url encoded jsontime with tz", - r: makeRequest("urlenc.http"), - dest: &urlEncTestJT{}, - compare: &UrlEncTestJT, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")}, + name: "url encoded jsontime with tz", + r: makeRequest("urlenc.http"), + dest: &urlEncTestJT{}, + expect: &UrlEncTestJT, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")}, }, { - name: "url encoded jsontime with local", - r: makeRequest("urlenc.http"), - dest: &urlEncTestJT{}, - compare: &UrlEncTestJTLocal, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")}, + name: "url encoded jsontime with local", + r: makeRequest("urlenc.http"), + dest: &urlEncTestJT{}, + expect: &UrlEncTestJTLocal, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")}, }, { - name: "sim real data", - r: makeRequest("urlenc2.http"), - dest: &alerting.Simulation{}, - compare: realSim, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()}, + name: "sim real data", + r: makeRequest("urlenc2.http"), + dest: &alerting.Simulation{}, + expect: realSim, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()}, }, } @@ -209,7 +232,7 @@ func TestUnmarshal(t *testing.T) { assert.Contains(t, tc.expectErr.Error(), err.Error()) } else { require.NoError(t, err) - assert.Equal(t, tc.compare, tc.dest) + assert.Equal(t, tc.expect, tc.dest) } }) } diff --git a/pkg/alerting/alert/alert.go b/pkg/alerting/alert/alert.go index 52d1bf0..6f3ad84 100644 --- a/pkg/alerting/alert/alert.go +++ b/pkg/alerting/alert/alert.go @@ -44,8 +44,7 @@ func (a *Alert) ToAddAlertParams() database.AddAlertParams { } } -// makeAlert creates a notification for later rendering by the template. -// It takes a talkgroup Score as input. +// Make creates an alert for later rendering or storage. func Make(ctx context.Context, store talkgroups.Store, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) { d := Alert{ ID: uuid.New(), diff --git a/pkg/alerting/simulate.go b/pkg/alerting/simulate.go index 6d646fa..0ebefb2 100644 --- a/pkg/alerting/simulate.go +++ b/pkg/alerting/simulate.go @@ -2,7 +2,6 @@ package alerting import ( "context" - "encoding/json" "errors" "fmt" "net/http" @@ -114,24 +113,15 @@ func (s *Simulation) Simulate(ctx context.Context) (trending.Scores[talkgroups.I func (as *alerter) simulateHandler(w http.ResponseWriter, r *http.Request) { ctx := r.Context() s := new(Simulation) - switch r.Header.Get("Content-Type") { - case "application/json": - err := json.NewDecoder(r.Body).Decode(s) - if err != nil { - err = fmt.Errorf("simulate decode: %w", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - default: - err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime()) - if err != nil { - err = fmt.Errorf("simulate unmarshal: %w", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } + + err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime()) + if err != nil { + err = fmt.Errorf("simulate unmarshal: %w", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return } - err := s.verify() + err = s.verify() if err != nil { err = fmt.Errorf("simulation profile verify: %w", err) http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/api/api.go b/pkg/api/api.go index 19ab3ef..580da85 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -18,13 +18,10 @@ type API interface { } type api struct { - tgs talkgroups.Store } -func New(tgs talkgroups.Store) API { - s := &api{ - tgs: tgs, - } +func New() API { + s := new(api) return s } @@ -32,9 +29,8 @@ func New(tgs talkgroups.Store) API { func (a *api) Subrouter() http.Handler { r := chi.NewMux() - r.Get("/talkgroup/{system:\\d+}/{id:\\d+}", a.talkgroup) - r.Get("/talkgroup/{system:\\d+}/", a.talkgroup) - r.Get("/talkgroup/", a.talkgroup) + r.Mount("/talkgroup", new(talkgroupAPI).routes()) + return r } @@ -58,7 +54,7 @@ func httpCode(err error) int { return http.StatusInternalServerError } -func (a *api) writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) { +func writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) { if err != nil { log.Error().Str("path", r.URL.Path).Err(err).Msg("request failed") http.Error(w, err.Error(), httpCode(err)) @@ -96,32 +92,6 @@ func decodeParams(d interface{}, r *http.Request) error { return dec.Decode(m) } -func (a *api) badReq(w http.ResponseWriter, err error) { +func badReq(w http.ResponseWriter, err error) { http.Error(w, err.Error(), http.StatusBadRequest) } - -func (a *api) talkgroup(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - p := struct { - System *int `param:"system"` - ID *int `param:"id"` - }{} - - err := decodeParams(&p, r) - if err != nil { - a.badReq(w, err) - return - } - - var res interface{} - switch { - case p.System != nil && p.ID != nil: - res, err = a.tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID)) - case p.System != nil: - res, err = a.tgs.SystemTGs(ctx, int32(*p.System)) - default: - res, err = a.tgs.TGs(ctx, nil) - } - - a.writeResponse(w, r, res, err) -} diff --git a/pkg/api/talkgroups.go b/pkg/api/talkgroups.go new file mode 100644 index 0000000..9051539 --- /dev/null +++ b/pkg/api/talkgroups.go @@ -0,0 +1,54 @@ +package api + +import ( + "net/http" + + "dynatron.me/x/stillbox/internal/forms" + "dynatron.me/x/stillbox/pkg/talkgroups" + + "github.com/go-chi/chi/v5" +) + +type talkgroupAPI struct { +} + +func (tga *talkgroupAPI) routes() http.Handler { + r := chi.NewMux() + + r.Get("/{system:\\d+}/{id:\\d+}", tga.talkgroup) + r.Put("/{system:\\d+}/{id:\\d+}", tga.putTalkgroup) + r.Get("/{system:\\d+}/", tga.talkgroup) + r.Get("/", tga.talkgroup) + + return r +} + +func (tga *talkgroupAPI) talkgroup(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + tgs := talkgroups.StoreFrom(ctx) + p := struct { + System *int `param:"system"` + ID *int `param:"id"` + }{} + + err := decodeParams(&p, r) + if err != nil { + badReq(w, err) + return + } + + var res interface{} + switch { + case p.System != nil && p.ID != nil: + res, err = tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID)) + case p.System != nil: + res, err = tgs.SystemTGs(ctx, int32(*p.System)) + default: + res, err = tgs.TGs(ctx, nil) + } + + writeResponse(w, r, res, err) +} + +func (tga *talkgroupAPI) putTalkgroup(w http.ResponseWriter, r *http.Request) { +} diff --git a/pkg/database/database.go b/pkg/database/database.go index 37794ff..f37b135 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -58,11 +58,11 @@ func NewClient(ctx context.Context, conf config.DB) (*DB, error) { type DBCtxKey string -const DBCTXKeyValue DBCtxKey = "dbctx" +const DBCtxKeyValue DBCtxKey = "dbctx" // FromCtx returns the database handle from the provided Context. func FromCtx(ctx context.Context) *DB { - c, ok := ctx.Value(DBCTXKeyValue).(*DB) + c, ok := ctx.Value(DBCtxKeyValue).(*DB) if !ok { panic("no DB in context") } @@ -72,7 +72,7 @@ func FromCtx(ctx context.Context) *DB { // CtxWithDB returns a Context with the provided database handle. func CtxWithDB(ctx context.Context, conn *DB) context.Context { - return context.WithValue(ctx, DBCTXKeyValue, conn) + return context.WithValue(ctx, DBCtxKeyValue, conn) } // IsNoRows is a convenience function that returns whether a returned error is a database diff --git a/pkg/server/routes.go b/pkg/server/routes.go index bcc7c3b..9b66b88 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -9,6 +9,7 @@ import ( "dynatron.me/x/stillbox/internal/version" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/talkgroups" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/httprate" @@ -26,7 +27,8 @@ func (s *Server) setupRoutes() { } r := s.r - r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db)) + r.Use(middleware.WithValue(database.DBCtxKeyValue, s.db)) + r.Use(middleware.WithValue(talkgroups.StoreCtxKeyValue, s.tgs)) s.installPprof() diff --git a/pkg/talkgroups/cache.go b/pkg/talkgroups/cache.go index 85ac5d0..14c2237 100644 --- a/pkg/talkgroups/cache.go +++ b/pkg/talkgroups/cache.go @@ -49,16 +49,16 @@ type Store interface { HUP(*config.Config) } -type CtxStoreKeyT string +type CtxStoreKey string -const CtxStoreKey CtxStoreKeyT = "store" +const StoreCtxKeyValue CtxStoreKey = "store" func CtxWithStore(ctx context.Context, s Store) context.Context { - return context.WithValue(ctx, CtxStoreKey, s) + return context.WithValue(ctx, StoreCtxKeyValue, s) } func StoreFrom(ctx context.Context) Store { - s, ok := ctx.Value(CtxStoreKey).(Store) + s, ok := ctx.Value(StoreCtxKeyValue).(Store) if !ok { return NewCache() } From 9046e346b101c90a883cf5e7ddd1a0769f2470cf Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Sat, 9 Nov 2024 21:04:16 -0500 Subject: [PATCH 2/8] API structure improvements, convenience funcs --- pkg/api/api.go | 6 +++- pkg/api/talkgroups.go | 70 ++++++++++++++++++++++++++++++++++++++++--- pkg/server/server.go | 2 +- 3 files changed, 72 insertions(+), 6 deletions(-) diff --git a/pkg/api/api.go b/pkg/api/api.go index 580da85..4e22860 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -71,6 +71,10 @@ func writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err } } +func reqErr(w http.ResponseWriter, err error, code int) { + http.Error(w, err.Error(), code) +} + func decodeParams(d interface{}, r *http.Request) error { params := chi.RouteContext(r.Context()).URLParams m := make(map[string]string, len(params.Keys)) @@ -93,5 +97,5 @@ func decodeParams(d interface{}, r *http.Request) error { } func badReq(w http.ResponseWriter, err error) { - http.Error(w, err.Error(), http.StatusBadRequest) + reqErr(w, err, http.StatusBadRequest) } diff --git a/pkg/api/talkgroups.go b/pkg/api/talkgroups.go index 9051539..e16f8f3 100644 --- a/pkg/api/talkgroups.go +++ b/pkg/api/talkgroups.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "net/http" "dynatron.me/x/stillbox/internal/forms" @@ -23,13 +24,35 @@ func (tga *talkgroupAPI) routes() http.Handler { return r } +type tgParams struct { + System *int `param:"system"` + ID *int `param:"id"` +} + +func (t tgParams) haveBoth() bool { + return t.System != nil && t.ID != nil +} + +func (t tgParams) ToID() talkgroups.ID { + nilOr := func(i *int) uint32 { + if i == nil { + return 0 + } + + return uint32(*i) + } + + return talkgroups.ID{ + System: nilOr(t.System), + Talkgroup: nilOr(t.ID), + } +} + func (tga *talkgroupAPI) talkgroup(w http.ResponseWriter, r *http.Request) { ctx := r.Context() tgs := talkgroups.StoreFrom(ctx) - p := struct { - System *int `param:"system"` - ID *int `param:"id"` - }{} + + var p tgParams err := decodeParams(&p, r) if err != nil { @@ -51,4 +74,43 @@ func (tga *talkgroupAPI) talkgroup(w http.ResponseWriter, r *http.Request) { } func (tga *talkgroupAPI) putTalkgroup(w http.ResponseWriter, r *http.Request) { + var id tgParams + err := decodeParams(&id, r) + if err != nil { + badReq(w, err) + return + } + /* + ctx := r.Context() + tgs := talkgroups.StoreFrom(ctx) + + tg, err := tgs.TG(ctx, id.ToID()) + switch err { + case nil: + case talkgroups.ErrNotFound: + reqErr(w, err, http.StatusNotFound) + return + default: + reqErr(w, err, http.StatusInternalServerError) + } + */ + + input := struct { + Name *string `form:"name"` + AlphaTag *string `form:"alpha_tag"` + TgGroup *string `form:"tg_group"` + Frequency *int32 `form:"frequency"` + Metadata []byte `form:"metadata"` + Tags []string `form:"tags"` + Alert *bool `form:"alert"` + AlertConfig []byte `form:"alert_config"` + Weight *float32 `form:"weight"` + }{} + + err = forms.Unmarshal(r, &input, forms.WithAcceptBlank(), forms.WithOmitEmpty()) + if err != nil { + reqErr(w, err, http.StatusBadRequest) + return + } + fmt.Fprintf(w, "%+v\n", input) } diff --git a/pkg/server/server.go b/pkg/server/server.go index ed11996..366bd6c 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -61,7 +61,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) { } tgCache := talkgroups.NewCache() - api := api.New(tgCache) + api := api.New() srv := &Server{ auth: authenticator, From f76db949e009d886cca8fe61aca5f034d64dfa72 Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Sun, 10 Nov 2024 07:29:30 -0500 Subject: [PATCH 3/8] fix path --- pkg/auth/login.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/auth/login.html b/pkg/auth/login.html index eeb9783..9c99bfc 100644 --- a/pkg/auth/login.html +++ b/pkg/auth/login.html @@ -5,7 +5,7 @@
-
+ From 6e1640e4b4b6169d9b25ad2ebc87513f1601eb18 Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Sun, 10 Nov 2024 10:13:38 -0500 Subject: [PATCH 4/8] PUT talkgroup --- pkg/api/api.go | 25 +++++++----- pkg/api/talkgroups.go | 47 +++++++++------------- pkg/database/extend.go | 3 ++ pkg/database/querier.go | 1 + pkg/database/talkgroups.sql.go | 60 +++++++++++++++++++++++++++++ pkg/talkgroups/cache.go | 30 ++++++++++++++- pkg/talkgroups/talkgroup.go | 7 ++++ sql/postgres/queries/talkgroups.sql | 16 +++++++- 8 files changed, 149 insertions(+), 40 deletions(-) diff --git a/pkg/api/api.go b/pkg/api/api.go index 4e22860..9910d9f 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -34,30 +34,36 @@ func (a *api) Subrouter() http.Handler { return r } -var statusMapping = map[error]int{ - talkgroups.ErrNotFound: http.StatusNotFound, - pgx.ErrNoRows: http.StatusNotFound, +type errResponse struct { + text string + code int } -func httpCode(err error) int { +var statusMapping = map[error]errResponse{ + talkgroups.ErrNotFound: {talkgroups.ErrNotFound.Error(), http.StatusNotFound}, + pgx.ErrNoRows: {"no such record", http.StatusNotFound}, +} + +func httpCode(err error) (string, int) { c, ok := statusMapping[err] if ok { - return c + return c.text, c.code } for e, c := range statusMapping { // check if err wraps an error we know about if errors.Is(err, e) { - return c + return c.text, c.code } } - return http.StatusInternalServerError + return err.Error(), http.StatusInternalServerError } func writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) { if err != nil { log.Error().Str("path", r.URL.Path).Err(err).Msg("request failed") - http.Error(w, err.Error(), httpCode(err)) + text, code := httpCode(err) + http.Error(w, text, code) return } @@ -66,7 +72,8 @@ func writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err err = enc.Encode(data) if err != nil { log.Error().Str("path", r.URL.Path).Err(err).Msg("response marshal failed") - http.Error(w, err.Error(), httpCode(err)) + text, code := httpCode(err) + http.Error(w, text, code) return } } diff --git a/pkg/api/talkgroups.go b/pkg/api/talkgroups.go index e16f8f3..e699b78 100644 --- a/pkg/api/talkgroups.go +++ b/pkg/api/talkgroups.go @@ -1,10 +1,11 @@ package api import ( - "fmt" + "encoding/json" "net/http" "dynatron.me/x/stillbox/internal/forms" + "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/talkgroups" "github.com/go-chi/chi/v5" @@ -80,37 +81,27 @@ func (tga *talkgroupAPI) putTalkgroup(w http.ResponseWriter, r *http.Request) { badReq(w, err) return } - /* - ctx := r.Context() - tgs := talkgroups.StoreFrom(ctx) - tg, err := tgs.TG(ctx, id.ToID()) - switch err { - case nil: - case talkgroups.ErrNotFound: - reqErr(w, err, http.StatusNotFound) - return - default: - reqErr(w, err, http.StatusInternalServerError) - } - */ + ctx := r.Context() + tgs := talkgroups.StoreFrom(ctx) - input := struct { - Name *string `form:"name"` - AlphaTag *string `form:"alpha_tag"` - TgGroup *string `form:"tg_group"` - Frequency *int32 `form:"frequency"` - Metadata []byte `form:"metadata"` - Tags []string `form:"tags"` - Alert *bool `form:"alert"` - AlertConfig []byte `form:"alert_config"` - Weight *float32 `form:"weight"` - }{} + input := database.UpdateTalkgroupParams{} - err = forms.Unmarshal(r, &input, forms.WithAcceptBlank(), forms.WithOmitEmpty()) + err = forms.Unmarshal(r, &input, forms.WithTag("json"), forms.WithAcceptBlank(), forms.WithOmitEmpty()) if err != nil { - reqErr(w, err, http.StatusBadRequest) + writeResponse(w, r, nil, err) return } - fmt.Fprintf(w, "%+v\n", input) + input.ID = id.ToID().Pack() + + record, err := tgs.UpdateTG(ctx, input) + if err != nil { + writeResponse(w, r, nil, err) + return + } + + err = json.NewEncoder(w).Encode(record) + if err != nil { + writeResponse(w, r, nil, err) + } } diff --git a/pkg/database/extend.go b/pkg/database/extend.go index 3a96cd1..5f165a0 100644 --- a/pkg/database/extend.go +++ b/pkg/database/extend.go @@ -9,3 +9,6 @@ func (g GetTalkgroupsWithLearnedRow) GetLearned() bool { retur func (g GetTalkgroupsWithLearnedBySystemRow) GetTalkgroup() Talkgroup { return g.Talkgroup } func (g GetTalkgroupsWithLearnedBySystemRow) GetSystem() System { return g.System } func (g GetTalkgroupsWithLearnedBySystemRow) GetLearned() bool { return g.Learned } +func (g Talkgroup) GetTalkgroup() Talkgroup { return g } +func (g Talkgroup) GetSystem() System { return System{ID: int(g.SystemID)} } +func (g Talkgroup) GetLearned() bool { return false } diff --git a/pkg/database/querier.go b/pkg/database/querier.go index ba1ace1..6a814c3 100644 --- a/pkg/database/querier.go +++ b/pkg/database/querier.go @@ -39,6 +39,7 @@ type Querier interface { SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error SetTalkgroupTags(ctx context.Context, sys int, tg int, tags []string) error UpdatePassword(ctx context.Context, username string, password string) error + UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) } var _ Querier = (*Queries)(nil) diff --git a/pkg/database/talkgroups.sql.go b/pkg/database/talkgroups.sql.go index b32deb3..81d224f 100644 --- a/pkg/database/talkgroups.sql.go +++ b/pkg/database/talkgroups.sql.go @@ -472,3 +472,63 @@ func (q *Queries) SetTalkgroupTags(ctx context.Context, sys int, tg int, tags [] _, err := q.db.Exec(ctx, setTalkgroupTags, sys, tg, tags) return err } + +const updateTalkgroup = `-- name: UpdateTalkgroup :one +UPDATE talkgroups +SET + name = COALESCE($1, name), + alpha_tag = COALESCE($2, alpha_tag), + tg_group = COALESCE($3, tg_group), + frequency = COALESCE($4, frequency), + metadata = COALESCE($5, metadata), + tags = COALESCE($6, tags), + alert = COALESCE($7, alert), + alert_config = COALESCE($8, alert_config), + weight = COALESCE($9, weight) +WHERE id = $10 +RETURNING id, system_id, tgid, name, alpha_tag, tg_group, frequency, metadata, tags, alert, alert_config, weight +` + +type UpdateTalkgroupParams struct { + Name *string `json:"name"` + AlphaTag *string `json:"alpha_tag"` + TgGroup *string `json:"tg_group"` + Frequency *int32 `json:"frequency"` + Metadata []byte `json:"metadata"` + Tags []string `json:"tags"` + Alert *bool `json:"alert"` + AlertConfig []byte `json:"alert_config"` + Weight *float32 `json:"weight"` + ID int64 `json:"id"` +} + +func (q *Queries) UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) { + row := q.db.QueryRow(ctx, updateTalkgroup, + arg.Name, + arg.AlphaTag, + arg.TgGroup, + arg.Frequency, + arg.Metadata, + arg.Tags, + arg.Alert, + arg.AlertConfig, + arg.Weight, + arg.ID, + ) + var i Talkgroup + err := row.Scan( + &i.ID, + &i.SystemID, + &i.Tgid, + &i.Name, + &i.AlphaTag, + &i.TgGroup, + &i.Frequency, + &i.Metadata, + &i.Tags, + &i.Alert, + &i.AlertConfig, + &i.Weight, + ) + return i, err +} diff --git a/pkg/talkgroups/cache.go b/pkg/talkgroups/cache.go index 14c2237..7ada361 100644 --- a/pkg/talkgroups/cache.go +++ b/pkg/talkgroups/cache.go @@ -17,7 +17,15 @@ import ( type tgMap map[ID]*Talkgroup +var ( + ErrNotFound = errors.New("talkgroup not found") + ErrNoSuchSystem = errors.New("no such system") +) + type Store interface { + // UpdateTG updates a talkgroup record. + UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*Talkgroup, error) + // TG retrieves a Talkgroup from the Store. TG(ctx context.Context, tg ID) (*Talkgroup, error) @@ -213,8 +221,6 @@ func (t *cache) Load(ctx context.Context, tgs []int64) error { return nil } -var ErrNotFound = errors.New("talkgroup not found") - func (t *cache) Weight(ctx context.Context, id ID, tm time.Time) float64 { tg, err := t.TG(ctx, id) if err != nil { @@ -290,3 +296,23 @@ func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) return n, has } + +func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*Talkgroup, error) { + sysName, has := t.SystemName(ctx, int(Unpack(input.ID).System)) + if !has { + return nil, ErrNoSuchSystem + } + + tg, err := database.FromCtx(ctx).UpdateTalkgroup(ctx, input) + if err != nil { + return nil, err + } + + record := &Talkgroup{ + Talkgroup: tg, + System: database.System{ID: int(tg.SystemID), Name: sysName}, + } + t.add(record) + + return record, nil +} diff --git a/pkg/talkgroups/talkgroup.go b/pkg/talkgroups/talkgroup.go index 288e488..e3cca71 100644 --- a/pkg/talkgroups/talkgroup.go +++ b/pkg/talkgroups/talkgroup.go @@ -49,6 +49,13 @@ func (t ID) Pack() int64 { return int64((int64(t.System) << 32) | int64(t.Talkgroup)) } +func Unpack(id int64) ID { + return ID{ + System: uint32(id >> 32), + Talkgroup: uint32(id & 0xffffffff), + } +} + func (t ID) String() string { return fmt.Sprintf("%d:%d", t.System, t.Talkgroup) diff --git a/sql/postgres/queries/talkgroups.sql b/sql/postgres/queries/talkgroups.sql index d201d0e..732a848 100644 --- a/sql/postgres/queries/talkgroups.sql +++ b/sql/postgres/queries/talkgroups.sql @@ -104,6 +104,20 @@ FROM talkgroups_learned tgl JOIN systems sys ON tgl.system_id = sys.id WHERE ignored IS NOT TRUE; - -- name: GetSystemName :one SELECT name FROM systems WHERE id = sqlc.arg(system_id); + +-- name: UpdateTalkgroup :one +UPDATE talkgroups +SET + name = COALESCE(sqlc.narg('name'), name), + alpha_tag = COALESCE(sqlc.narg('alpha_tag'), alpha_tag), + tg_group = COALESCE(sqlc.narg('tg_group'), tg_group), + frequency = COALESCE(sqlc.narg('frequency'), frequency), + metadata = COALESCE(sqlc.narg('metadata'), metadata), + tags = COALESCE(sqlc.narg('tags'), tags), + alert = COALESCE(sqlc.narg('alert'), alert), + alert_config = COALESCE(sqlc.narg('alert_config'), alert_config), + weight = COALESCE(sqlc.narg('weight'), weight) +WHERE id = @id +RETURNING *; From e97c9ced0ebad57b2d90b1165cab92c855a5a63b Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Sun, 10 Nov 2024 10:17:02 -0500 Subject: [PATCH 5/8] use profiler middleware --- pkg/server/profile.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/pkg/server/profile.go b/pkg/server/profile.go index eb5787e..dcb9d6f 100644 --- a/pkg/server/profile.go +++ b/pkg/server/profile.go @@ -4,14 +4,9 @@ package server import ( - "net/http/pprof" + "github.com/go-chi/chi/v5/middleware" ) func (s *Server) installPprof() { - r := s.r - r.HandleFunc("/debug/pprof/", pprof.Index) - r.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - r.HandleFunc("/debug/pprof/profile", pprof.Profile) - r.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - r.HandleFunc("/debug/pprof/trace", pprof.Trace) + s.r.Mount("/debug", middleware.Profiler()) } From cecbeb78fe3d7ae0fd65b648764c0e98327d252c Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Sun, 10 Nov 2024 10:28:04 -0500 Subject: [PATCH 6/8] Rename ctx keys --- pkg/database/database.go | 8 ++++---- pkg/server/routes.go | 4 ++-- pkg/talkgroups/cache.go | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pkg/database/database.go b/pkg/database/database.go index f37b135..9f1c69b 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -56,13 +56,13 @@ func NewClient(ctx context.Context, conf config.DB) (*DB, error) { return db, nil } -type DBCtxKey string +type dBCtxKey string -const DBCtxKeyValue DBCtxKey = "dbctx" +const DBCtxKey dBCtxKey = "dbctx" // FromCtx returns the database handle from the provided Context. func FromCtx(ctx context.Context) *DB { - c, ok := ctx.Value(DBCtxKeyValue).(*DB) + c, ok := ctx.Value(DBCtxKey).(*DB) if !ok { panic("no DB in context") } @@ -72,7 +72,7 @@ func FromCtx(ctx context.Context) *DB { // CtxWithDB returns a Context with the provided database handle. func CtxWithDB(ctx context.Context, conn *DB) context.Context { - return context.WithValue(ctx, DBCtxKeyValue, conn) + return context.WithValue(ctx, DBCtxKey, conn) } // IsNoRows is a convenience function that returns whether a returned error is a database diff --git a/pkg/server/routes.go b/pkg/server/routes.go index 9b66b88..aa3a8a1 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -27,8 +27,8 @@ func (s *Server) setupRoutes() { } r := s.r - r.Use(middleware.WithValue(database.DBCtxKeyValue, s.db)) - r.Use(middleware.WithValue(talkgroups.StoreCtxKeyValue, s.tgs)) + r.Use(middleware.WithValue(database.DBCtxKey, s.db)) + r.Use(middleware.WithValue(talkgroups.StoreCtxKey, s.tgs)) s.installPprof() diff --git a/pkg/talkgroups/cache.go b/pkg/talkgroups/cache.go index 7ada361..16b8ac6 100644 --- a/pkg/talkgroups/cache.go +++ b/pkg/talkgroups/cache.go @@ -57,16 +57,16 @@ type Store interface { HUP(*config.Config) } -type CtxStoreKey string +type storeCtxKey string -const StoreCtxKeyValue CtxStoreKey = "store" +const StoreCtxKey storeCtxKey = "store" func CtxWithStore(ctx context.Context, s Store) context.Context { - return context.WithValue(ctx, StoreCtxKeyValue, s) + return context.WithValue(ctx, StoreCtxKey, s) } func StoreFrom(ctx context.Context) Store { - s, ok := ctx.Value(StoreCtxKeyValue).(Store) + s, ok := ctx.Value(StoreCtxKey).(Store) if !ok { return NewCache() } From 759c274950a88dd0f9ab812a910507362871afbd Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Sun, 10 Nov 2024 14:40:50 -0500 Subject: [PATCH 7/8] Use chi render, improvements --- pkg/api/api.go | 88 +++++++++++++++++++++++++++---------------- pkg/api/talkgroups.go | 23 +++++------ 2 files changed, 68 insertions(+), 43 deletions(-) diff --git a/pkg/api/api.go b/pkg/api/api.go index 9910d9f..923f139 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -1,13 +1,13 @@ package api import ( - "encoding/json" "errors" "net/http" "dynatron.me/x/stillbox/pkg/talkgroups" "github.com/go-chi/chi/v5" + "github.com/go-chi/render" "github.com/go-viper/mapstructure/v2" "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" @@ -29,57 +29,81 @@ func New() API { func (a *api) Subrouter() http.Handler { r := chi.NewMux() - r.Mount("/talkgroup", new(talkgroupAPI).routes()) + r.Mount("/talkgroup", new(talkgroupAPI).Subrouter()) return r } type errResponse struct { - text string - code int + Err error `json:"-"` + Code int `json:"-"` + Error string `json:"error"` } -var statusMapping = map[error]errResponse{ - talkgroups.ErrNotFound: {talkgroups.ErrNotFound.Error(), http.StatusNotFound}, - pgx.ErrNoRows: {"no such record", http.StatusNotFound}, +func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error { + switch e.Code { + case http.StatusNotFound: + case http.StatusBadRequest: + default: + log.Error().Str("path", r.URL.Path).Err(e.Err).Int("code", e.Code).Str("msg", e.Error).Msg("request failed") + } + + render.Status(r, e.Code) + + return nil } -func httpCode(err error) (string, int) { +func badRequest(err error) render.Renderer { + return &errResponse{ + Err: err, + Code: http.StatusBadRequest, + Error: "Bad request", + } +} + +func recordNotFound(err error) render.Renderer { + return &errResponse{ + Err: err, + Code: http.StatusNotFound, + Error: "Record not found", + } +} + +func internalError(err error) render.Renderer { + return &errResponse{ + Err: err, + Code: http.StatusNotFound, + Error: "Internal server error", + } +} + +type errResponder func(error) render.Renderer + +var statusMapping = map[error]errResponder{ + talkgroups.ErrNotFound: recordNotFound, + pgx.ErrNoRows: recordNotFound, +} + +func autoError(err error) render.Renderer { c, ok := statusMapping[err] if ok { - return c.text, c.code + c(err) } for e, c := range statusMapping { // check if err wraps an error we know about if errors.Is(err, e) { - return c.text, c.code + return c(err) } } - return err.Error(), http.StatusInternalServerError + return internalError(err) } -func writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) { +func wErr(w http.ResponseWriter, r *http.Request, v render.Renderer) { + err := render.Render(w, r, v) if err != nil { - log.Error().Str("path", r.URL.Path).Err(err).Msg("request failed") - text, code := httpCode(err) - http.Error(w, text, code) - return + log.Error().Err(err).Msg("wErr render error") } - - w.Header().Set("Content-Type", "application/json") - enc := json.NewEncoder(w) - err = enc.Encode(data) - if err != nil { - log.Error().Str("path", r.URL.Path).Err(err).Msg("response marshal failed") - text, code := httpCode(err) - http.Error(w, text, code) - return - } -} - -func reqErr(w http.ResponseWriter, err error, code int) { - http.Error(w, err.Error(), code) } func decodeParams(d interface{}, r *http.Request) error { @@ -103,6 +127,6 @@ func decodeParams(d interface{}, r *http.Request) error { return dec.Decode(m) } -func badReq(w http.ResponseWriter, err error) { - reqErr(w, err, http.StatusBadRequest) +func respond(w http.ResponseWriter, r *http.Request, v interface{}) { + render.DefaultResponder(w, r, v) } diff --git a/pkg/api/talkgroups.go b/pkg/api/talkgroups.go index e699b78..ac74190 100644 --- a/pkg/api/talkgroups.go +++ b/pkg/api/talkgroups.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "net/http" "dynatron.me/x/stillbox/internal/forms" @@ -14,7 +13,7 @@ import ( type talkgroupAPI struct { } -func (tga *talkgroupAPI) routes() http.Handler { +func (tga *talkgroupAPI) Subrouter() http.Handler { r := chi.NewMux() r.Get("/{system:\\d+}/{id:\\d+}", tga.talkgroup) @@ -57,7 +56,7 @@ func (tga *talkgroupAPI) talkgroup(w http.ResponseWriter, r *http.Request) { err := decodeParams(&p, r) if err != nil { - badReq(w, err) + wErr(w, r, badRequest(err)) return } @@ -71,14 +70,19 @@ func (tga *talkgroupAPI) talkgroup(w http.ResponseWriter, r *http.Request) { res, err = tgs.TGs(ctx, nil) } - writeResponse(w, r, res, err) + if err != nil { + wErr(w, r, autoError(err)) + return + } + + respond(w, r, res) } func (tga *talkgroupAPI) putTalkgroup(w http.ResponseWriter, r *http.Request) { var id tgParams err := decodeParams(&id, r) if err != nil { - badReq(w, err) + wErr(w, r, badRequest(err)) return } @@ -89,19 +93,16 @@ func (tga *talkgroupAPI) putTalkgroup(w http.ResponseWriter, r *http.Request) { err = forms.Unmarshal(r, &input, forms.WithTag("json"), forms.WithAcceptBlank(), forms.WithOmitEmpty()) if err != nil { - writeResponse(w, r, nil, err) + wErr(w, r, badRequest(err)) return } input.ID = id.ToID().Pack() record, err := tgs.UpdateTG(ctx, input) if err != nil { - writeResponse(w, r, nil, err) + wErr(w, r, autoError(err)) return } - err = json.NewEncoder(w).Encode(record) - if err != nil { - writeResponse(w, r, nil, err) - } + respond(w, r, record) } From fb3fb4eeab3693a2711808696d8e1091dfa019a2 Mon Sep 17 00:00:00 2001 From: Daniel Ponte Date: Sun, 10 Nov 2024 14:44:52 -0500 Subject: [PATCH 8/8] Rename api to rest --- pkg/{api => rest}/api.go | 2 +- pkg/{api => rest}/talkgroups.go | 14 +++++++------- pkg/server/routes.go | 2 +- pkg/server/server.go | 8 ++++---- 4 files changed, 13 insertions(+), 13 deletions(-) rename pkg/{api => rest}/api.go (99%) rename pkg/{api => rest}/talkgroups.go (83%) diff --git a/pkg/api/api.go b/pkg/rest/api.go similarity index 99% rename from pkg/api/api.go rename to pkg/rest/api.go index 923f139..0c08daa 100644 --- a/pkg/api/api.go +++ b/pkg/rest/api.go @@ -1,4 +1,4 @@ -package api +package rest import ( "errors" diff --git a/pkg/api/talkgroups.go b/pkg/rest/talkgroups.go similarity index 83% rename from pkg/api/talkgroups.go rename to pkg/rest/talkgroups.go index ac74190..0520af7 100644 --- a/pkg/api/talkgroups.go +++ b/pkg/rest/talkgroups.go @@ -1,4 +1,4 @@ -package api +package rest import ( "net/http" @@ -16,10 +16,10 @@ type talkgroupAPI struct { func (tga *talkgroupAPI) Subrouter() http.Handler { r := chi.NewMux() - r.Get("/{system:\\d+}/{id:\\d+}", tga.talkgroup) - r.Put("/{system:\\d+}/{id:\\d+}", tga.putTalkgroup) - r.Get("/{system:\\d+}/", tga.talkgroup) - r.Get("/", tga.talkgroup) + r.Get("/{system:\\d+}/{id:\\d+}", tga.get) + r.Put("/{system:\\d+}/{id:\\d+}", tga.put) + r.Get("/{system:\\d+}/", tga.get) + r.Get("/", tga.get) return r } @@ -48,7 +48,7 @@ func (t tgParams) ToID() talkgroups.ID { } } -func (tga *talkgroupAPI) talkgroup(w http.ResponseWriter, r *http.Request) { +func (tga *talkgroupAPI) get(w http.ResponseWriter, r *http.Request) { ctx := r.Context() tgs := talkgroups.StoreFrom(ctx) @@ -78,7 +78,7 @@ func (tga *talkgroupAPI) talkgroup(w http.ResponseWriter, r *http.Request) { respond(w, r, res) } -func (tga *talkgroupAPI) putTalkgroup(w http.ResponseWriter, r *http.Request) { +func (tga *talkgroupAPI) put(w http.ResponseWriter, r *http.Request) { var id tgParams err := decodeParams(&id, r) if err != nil { diff --git a/pkg/server/routes.go b/pkg/server/routes.go index aa3a8a1..b296d6b 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -38,7 +38,7 @@ func (s *Server) setupRoutes() { s.nex.PrivateRoutes(r) s.auth.PrivateRoutes(r) s.alerter.PrivateRoutes(r) - r.Mount("/api", s.api.Subrouter()) + r.Mount("/api", s.rest.Subrouter()) }) r.Group(func(r chi.Router) { diff --git a/pkg/server/server.go b/pkg/server/server.go index 366bd6c..ebec903 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -7,12 +7,12 @@ import ( "time" "dynatron.me/x/stillbox/pkg/alerting" - "dynatron.me/x/stillbox/pkg/api" "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/nexus" "dynatron.me/x/stillbox/pkg/notify" + "dynatron.me/x/stillbox/pkg/rest" "dynatron.me/x/stillbox/pkg/sinks" "dynatron.me/x/stillbox/pkg/sources" "dynatron.me/x/stillbox/pkg/talkgroups" @@ -37,7 +37,7 @@ type Server struct { notifier notify.Notifier hup chan os.Signal tgs talkgroups.Store - api api.API + rest rest.API } func New(ctx context.Context, cfg *config.Config) (*Server, error) { @@ -61,7 +61,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) { } tgCache := talkgroups.NewCache() - api := api.New() + api := rest.New() srv := &Server{ auth: authenticator, @@ -73,7 +73,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) { alerter: alerting.New(cfg.Alerting, tgCache, alerting.WithNotifier(notifier)), notifier: notifier, tgs: tgCache, - api: api, + rest: api, } srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db), true)