diff --git a/internal/forms/forms.go b/internal/forms/forms.go index facb68c..bc28cc4 100644 --- a/internal/forms/forms.go +++ b/internal/forms/forms.go @@ -1,6 +1,7 @@ package forms import ( + "encoding/json" "errors" "fmt" "io" @@ -16,15 +17,22 @@ import ( ) var ( - ErrNotStruct = errors.New("destination is not a struct") - ErrNotPointer = errors.New("destination is not a pointer") + ErrNotStruct = errors.New("destination is not a struct") + ErrNotPointer = errors.New("destination is not a pointer") + ErrContentType = errors.New("bad content type") +) + +const ( + MaxMultipartMemory int64 = 1024 * 1024 // 1MB ) type options struct { - tagOverride *string - parseTimeIn *time.Location - parseLocal bool - acceptBlank bool + tagOverride *string + parseTimeIn *time.Location + parseLocal bool + acceptBlank bool + maxMultipartMemory int64 + defaultOmitEmpty bool } type Option func(*options) @@ -53,6 +61,18 @@ func WithTag(t string) Option { } } +func WithMaxMultipartSize(s int64) Option { + return func(o *options) { + o.maxMultipartMemory = s + } +} + +func WithOmitEmpty() Option { + return func(o *options) { + o.defaultOmitEmpty = true + } +} + func (o *options) Tag() string { if o.tagOverride != nil { return *o.tagOverride @@ -147,17 +167,19 @@ func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) return } -func (o *options) iterFields(r *http.Request, rv reflect.Value) error { - rt := rv.Type() - for i := 0; i < rv.NumField(); i++ { - f := rv.Field(i) - tf := rt.Field(i) - if !tf.IsExported() && !tf.Anonymous { +var typeOfByteSlice = reflect.TypeOf([]byte(nil)) + +func (o *options) iterFields(r *http.Request, destStruct reflect.Value) error { + structType := destStruct.Type() + for i := 0; i < destStruct.NumField(); i++ { + destFieldVal := destStruct.Field(i) + fieldType := structType.Field(i) + if !fieldType.IsExported() && !fieldType.Anonymous { continue } - if f.Kind() == reflect.Struct && tf.Anonymous { - err := o.iterFields(r, f) + if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous { + err := o.iterFields(r, destFieldVal) if err != nil { return err } @@ -165,51 +187,38 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { var tAr []string var formField string - formTag, has := rt.Field(i).Tag.Lookup(o.Tag()) + var omitEmpty bool + if o.defaultOmitEmpty { + omitEmpty = true + } + + formTag, has := structType.Field(i).Tag.Lookup(o.Tag()) if has { tAr = strings.Split(formTag, ",") formField = tAr[0] + for _, v := range tAr[1:] { + if v == "omitempty" { + omitEmpty = true + break + } + } } + if !has || formField == "-" { continue } - fi := f.Interface() + destFieldIntf := destFieldVal.Interface() - switch v := fi.(type) { - case string, *string: - s := r.Form.Get(formField) - setVal(f, s != "" || o.acceptBlank, v, s) - case int, uint, *int, *uint: - ff := r.Form.Get(formField) - val, set, err := o.parseInt(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case float64: - ff := r.Form.Get(formField) - val, set, err := o.parseFloat64(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case bool, *bool: - ff := r.Form.Get(formField) - val, set, err := o.parseBool(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case []byte: + if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice { file, hdr, err := r.FormFile(formField) if err != nil { return fmt.Errorf("get form file: %w", err) } - nameField, hasFilename := rt.Field(i).Tag.Lookup("filenameField") + nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField") if hasFilename { - fnf := rv.FieldByName(nameField) + fnf := destStruct.FieldByName(nameField) if fnf == (reflect.Value{}) { panic(fmt.Errorf("filenameField '%s' does not exist", nameField)) } @@ -221,23 +230,52 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { return fmt.Errorf("file read: %w", err) } - f.SetBytes(audioBytes) + destFieldVal.SetBytes(audioBytes) + + continue + } + + if !r.Form.Has(formField) && omitEmpty { + continue + } + + ff := r.Form.Get(formField) + + switch v := destFieldIntf.(type) { + case string, *string: + setVal(destFieldVal, ff != "" || o.acceptBlank, ff) + case int, uint, *int, *uint: + val, set, err := o.parseInt(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case float64: + val, set, err := o.parseFloat64(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case bool, *bool: + val, set, err := o.parseBool(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) case time.Time, *time.Time, jsontime.Time, *jsontime.Time: - tval := r.Form.Get(formField) - t, set, err := o.parseTime(tval) + t, set, err := o.parseTime(ff) if err != nil { return err } - setVal(f, set, v, t) + setVal(destFieldVal, set, t) case time.Duration, *time.Duration, jsontime.Duration, *jsontime.Duration: - dval := r.Form.Get(formField) - d, set, err := o.parseDuration(dval) + d, set, err := o.parseDuration(ff) if err != nil { return err } - setVal(f, set, v, d) + setVal(destFieldVal, set, d) case []int: - val := strings.Trim(r.Form.Get(formField), "[]") + val := strings.Trim(ff, "[]") if val == "" && o.acceptBlank { continue } @@ -249,7 +287,7 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { ar = append(ar, i) } } - f.Set(reflect.ValueOf(ar)) + destFieldVal.Set(reflect.ValueOf(ar)) default: panic(fmt.Errorf("unsupported type %T", v)) } @@ -258,48 +296,77 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { return nil } -func setVal(setField reflect.Value, set bool, fv any, sv any) { +func setVal(destFieldVal reflect.Value, set bool, src any) { if !set { return } - rv := reflect.TypeOf(fv) - svo := reflect.ValueOf(sv) + destType := destFieldVal.Type() + srcVal := reflect.ValueOf(src) - if svo.CanConvert(rv) { - svo = svo.Convert(rv) + if srcVal.Kind() == reflect.Ptr { + srcVal = srcVal.Elem() } - if rv.Kind() == reflect.Ptr { - svo = svo.Addr() + if destType.Kind() == reflect.Ptr { + if !srcVal.CanAddr() { + if srcVal.CanConvert(destType.Elem()) { + srcVal = srcVal.Convert(destType.Elem()) + } + copy := reflect.New(srcVal.Type()) + copy.Elem().Set(srcVal) + srcVal = copy + } + } else if srcVal.CanConvert(destFieldVal.Type()) { + srcVal = srcVal.Convert(destFieldVal.Type()) } - setField.Set(svo) + destFieldVal.Set(srcVal) } func Unmarshal(r *http.Request, dest any, opt ...Option) error { - o := options{} + o := options{ + maxMultipartMemory: MaxMultipartMemory, + } + for _, opt := range opt { opt(&o) } - rv := reflect.ValueOf(dest) - if k := rv.Kind(); k == reflect.Ptr { - rv = rv.Elem() - } else { - return ErrNotPointer - } + contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0] - if rv.Kind() != reflect.Struct { - return ErrNotStruct - } + switch contentType { + case "multipart/form-data": + err := r.ParseMultipartForm(o.maxMultipartMemory) + if err != nil { + return fmt.Errorf("ParseForm: %w", err) + } - if strings.HasPrefix(r.Header.Get("Content-Type"), "application/x-www-form-urlencoded") { + return o.unmarshalForm(r, dest) + case "application/x-www-form-urlencoded": err := r.ParseForm() if err != nil { return fmt.Errorf("ParseForm: %w", err) } + return o.unmarshalForm(r, dest) + case "application/json": + return json.NewDecoder(r.Body).Decode(dest) } - return o.iterFields(r, rv) + return ErrContentType +} + +func (o *options) unmarshalForm(r *http.Request, dest any) error { + destVal := reflect.ValueOf(dest) + if k := destVal.Kind(); k == reflect.Ptr { + destVal = destVal.Elem() + } else { + return ErrNotPointer + } + + if destVal.Kind() != reflect.Struct { + return ErrNotStruct + } + + return o.iterFields(r, destVal) } diff --git a/internal/forms/forms_test.go b/internal/forms/forms_test.go index c75b839..5f73dca 100644 --- a/internal/forms/forms_test.go +++ b/internal/forms/forms_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/forms" "dynatron.me/x/stillbox/internal/jsontime" @@ -54,6 +55,14 @@ type urlEncTestJT struct { ScoreEnd jsontime.Time `json:"scoreEnd"` } +type ptrTestJT struct { + LookbackDays uint `form:"lookbackDays"` + HalfLife *jsontime.Duration `form:"halfLife"` + Recent *string `form:"recent"` + ScoreStart *jsontime.Time `form:"scoreStart"` + ScoreEnd jsontime.Time `form:"scoreEnd"` +} + var ( UrlEncTest = urlEncTest{ LookbackDays: 7, @@ -69,6 +78,13 @@ var ( ScoreStart: jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC)), } + PtrTestJT = ptrTestJT{ + LookbackDays: 7, + HalfLife: common.PtrTo(jsontime.Duration(30 * time.Minute)), + Recent: common.PtrTo("2h0m0s"), + ScoreStart: common.PtrTo(jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC))), + } + UrlEncTestJTLocal = urlEncTestJT{ LookbackDays: 7, HalfLife: jsontime.Duration(30 * time.Minute), @@ -122,29 +138,29 @@ func TestUnmarshal(t *testing.T) { name string r *http.Request dest any - compare any + expect any expectErr error opts []forms.Option }{ { - name: "base case", - r: makeRequest("call1.http"), - dest: &callUploadRequest{}, - compare: &Call1, - opts: []forms.Option{forms.WithAcceptBlank()}, + name: "base case", + r: makeRequest("call1.http"), + dest: &callUploadRequest{}, + expect: &Call1, + opts: []forms.Option{forms.WithAcceptBlank()}, }, { name: "base case no accept blank", r: makeRequest("call1.http"), dest: &callUploadRequest{}, - compare: &Call1, + expect: &Call1, expectErr: errors.New(`parsebool(''): strconv.ParseBool: parsing "": invalid syntax`), }, { name: "not a pointer", r: makeRequest("call1.http"), dest: callUploadRequest{}, - compare: callUploadRequest{}, + expect: callUploadRequest{}, expectErr: forms.ErrNotPointer, opts: []forms.Option{forms.WithAcceptBlank()}, }, @@ -152,7 +168,7 @@ func TestUnmarshal(t *testing.T) { name: "not a struct", r: makeRequest("call1.http"), dest: &str, - compare: callUploadRequest{}, + expect: callUploadRequest{}, expectErr: forms.ErrNotStruct, opts: []forms.Option{forms.WithAcceptBlank()}, }, @@ -160,44 +176,51 @@ func TestUnmarshal(t *testing.T) { name: "url encoded", r: makeRequest("urlenc.http"), dest: &urlEncTest{}, - compare: &UrlEncTest, + expect: &UrlEncTest, expectErr: errors.New(`Could not find format for ""`), }, { - name: "url encoded accept blank", - r: makeRequest("urlenc.http"), - dest: &urlEncTest{}, - compare: &UrlEncTest, - opts: []forms.Option{forms.WithAcceptBlank()}, + name: "url encoded accept blank", + r: makeRequest("urlenc.http"), + dest: &urlEncTest{}, + expect: &UrlEncTest, + opts: []forms.Option{forms.WithAcceptBlank()}, + }, + { + name: "url encoded accept blank pointer", + r: makeRequest("urlenc.http"), + dest: &ptrTestJT{}, + expect: &PtrTestJT, + opts: []forms.Option{forms.WithAcceptBlank()}, }, { name: "url encoded jsontime", r: makeRequest("urlenc.http"), dest: &urlEncTestJT{}, - compare: &UrlEncTestJT, + expect: &UrlEncTestJT, expectErr: errors.New(`Could not find format for ""`), opts: []forms.Option{forms.WithTag("json")}, }, { - name: "url encoded jsontime with tz", - r: makeRequest("urlenc.http"), - dest: &urlEncTestJT{}, - compare: &UrlEncTestJT, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")}, + name: "url encoded jsontime with tz", + r: makeRequest("urlenc.http"), + dest: &urlEncTestJT{}, + expect: &UrlEncTestJT, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")}, }, { - name: "url encoded jsontime with local", - r: makeRequest("urlenc.http"), - dest: &urlEncTestJT{}, - compare: &UrlEncTestJTLocal, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")}, + name: "url encoded jsontime with local", + r: makeRequest("urlenc.http"), + dest: &urlEncTestJT{}, + expect: &UrlEncTestJTLocal, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")}, }, { - name: "sim real data", - r: makeRequest("urlenc2.http"), - dest: &alerting.Simulation{}, - compare: realSim, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()}, + name: "sim real data", + r: makeRequest("urlenc2.http"), + dest: &alerting.Simulation{}, + expect: realSim, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()}, }, } @@ -209,7 +232,7 @@ func TestUnmarshal(t *testing.T) { assert.Contains(t, tc.expectErr.Error(), err.Error()) } else { require.NoError(t, err) - assert.Equal(t, tc.compare, tc.dest) + assert.Equal(t, tc.expect, tc.dest) } }) } diff --git a/pkg/alerting/alert/alert.go b/pkg/alerting/alert/alert.go index 52d1bf0..6f3ad84 100644 --- a/pkg/alerting/alert/alert.go +++ b/pkg/alerting/alert/alert.go @@ -44,8 +44,7 @@ func (a *Alert) ToAddAlertParams() database.AddAlertParams { } } -// makeAlert creates a notification for later rendering by the template. -// It takes a talkgroup Score as input. +// Make creates an alert for later rendering or storage. func Make(ctx context.Context, store talkgroups.Store, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) { d := Alert{ ID: uuid.New(), diff --git a/pkg/alerting/simulate.go b/pkg/alerting/simulate.go index 6d646fa..0ebefb2 100644 --- a/pkg/alerting/simulate.go +++ b/pkg/alerting/simulate.go @@ -2,7 +2,6 @@ package alerting import ( "context" - "encoding/json" "errors" "fmt" "net/http" @@ -114,24 +113,15 @@ func (s *Simulation) Simulate(ctx context.Context) (trending.Scores[talkgroups.I func (as *alerter) simulateHandler(w http.ResponseWriter, r *http.Request) { ctx := r.Context() s := new(Simulation) - switch r.Header.Get("Content-Type") { - case "application/json": - err := json.NewDecoder(r.Body).Decode(s) - if err != nil { - err = fmt.Errorf("simulate decode: %w", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - default: - err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime()) - if err != nil { - err = fmt.Errorf("simulate unmarshal: %w", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } + + err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime()) + if err != nil { + err = fmt.Errorf("simulate unmarshal: %w", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return } - err := s.verify() + err = s.verify() if err != nil { err = fmt.Errorf("simulation profile verify: %w", err) http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/api/api.go b/pkg/api/api.go index 19ab3ef..580da85 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -18,13 +18,10 @@ type API interface { } type api struct { - tgs talkgroups.Store } -func New(tgs talkgroups.Store) API { - s := &api{ - tgs: tgs, - } +func New() API { + s := new(api) return s } @@ -32,9 +29,8 @@ func New(tgs talkgroups.Store) API { func (a *api) Subrouter() http.Handler { r := chi.NewMux() - r.Get("/talkgroup/{system:\\d+}/{id:\\d+}", a.talkgroup) - r.Get("/talkgroup/{system:\\d+}/", a.talkgroup) - r.Get("/talkgroup/", a.talkgroup) + r.Mount("/talkgroup", new(talkgroupAPI).routes()) + return r } @@ -58,7 +54,7 @@ func httpCode(err error) int { return http.StatusInternalServerError } -func (a *api) writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) { +func writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) { if err != nil { log.Error().Str("path", r.URL.Path).Err(err).Msg("request failed") http.Error(w, err.Error(), httpCode(err)) @@ -96,32 +92,6 @@ func decodeParams(d interface{}, r *http.Request) error { return dec.Decode(m) } -func (a *api) badReq(w http.ResponseWriter, err error) { +func badReq(w http.ResponseWriter, err error) { http.Error(w, err.Error(), http.StatusBadRequest) } - -func (a *api) talkgroup(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - p := struct { - System *int `param:"system"` - ID *int `param:"id"` - }{} - - err := decodeParams(&p, r) - if err != nil { - a.badReq(w, err) - return - } - - var res interface{} - switch { - case p.System != nil && p.ID != nil: - res, err = a.tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID)) - case p.System != nil: - res, err = a.tgs.SystemTGs(ctx, int32(*p.System)) - default: - res, err = a.tgs.TGs(ctx, nil) - } - - a.writeResponse(w, r, res, err) -} diff --git a/pkg/api/talkgroups.go b/pkg/api/talkgroups.go new file mode 100644 index 0000000..9051539 --- /dev/null +++ b/pkg/api/talkgroups.go @@ -0,0 +1,54 @@ +package api + +import ( + "net/http" + + "dynatron.me/x/stillbox/internal/forms" + "dynatron.me/x/stillbox/pkg/talkgroups" + + "github.com/go-chi/chi/v5" +) + +type talkgroupAPI struct { +} + +func (tga *talkgroupAPI) routes() http.Handler { + r := chi.NewMux() + + r.Get("/{system:\\d+}/{id:\\d+}", tga.talkgroup) + r.Put("/{system:\\d+}/{id:\\d+}", tga.putTalkgroup) + r.Get("/{system:\\d+}/", tga.talkgroup) + r.Get("/", tga.talkgroup) + + return r +} + +func (tga *talkgroupAPI) talkgroup(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + tgs := talkgroups.StoreFrom(ctx) + p := struct { + System *int `param:"system"` + ID *int `param:"id"` + }{} + + err := decodeParams(&p, r) + if err != nil { + badReq(w, err) + return + } + + var res interface{} + switch { + case p.System != nil && p.ID != nil: + res, err = tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID)) + case p.System != nil: + res, err = tgs.SystemTGs(ctx, int32(*p.System)) + default: + res, err = tgs.TGs(ctx, nil) + } + + writeResponse(w, r, res, err) +} + +func (tga *talkgroupAPI) putTalkgroup(w http.ResponseWriter, r *http.Request) { +} diff --git a/pkg/database/database.go b/pkg/database/database.go index 37794ff..f37b135 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -58,11 +58,11 @@ func NewClient(ctx context.Context, conf config.DB) (*DB, error) { type DBCtxKey string -const DBCTXKeyValue DBCtxKey = "dbctx" +const DBCtxKeyValue DBCtxKey = "dbctx" // FromCtx returns the database handle from the provided Context. func FromCtx(ctx context.Context) *DB { - c, ok := ctx.Value(DBCTXKeyValue).(*DB) + c, ok := ctx.Value(DBCtxKeyValue).(*DB) if !ok { panic("no DB in context") } @@ -72,7 +72,7 @@ func FromCtx(ctx context.Context) *DB { // CtxWithDB returns a Context with the provided database handle. func CtxWithDB(ctx context.Context, conn *DB) context.Context { - return context.WithValue(ctx, DBCTXKeyValue, conn) + return context.WithValue(ctx, DBCtxKeyValue, conn) } // IsNoRows is a convenience function that returns whether a returned error is a database diff --git a/pkg/server/routes.go b/pkg/server/routes.go index bcc7c3b..9b66b88 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -9,6 +9,7 @@ import ( "dynatron.me/x/stillbox/internal/version" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/talkgroups" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/httprate" @@ -26,7 +27,8 @@ func (s *Server) setupRoutes() { } r := s.r - r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db)) + r.Use(middleware.WithValue(database.DBCtxKeyValue, s.db)) + r.Use(middleware.WithValue(talkgroups.StoreCtxKeyValue, s.tgs)) s.installPprof() diff --git a/pkg/talkgroups/cache.go b/pkg/talkgroups/cache.go index 85ac5d0..14c2237 100644 --- a/pkg/talkgroups/cache.go +++ b/pkg/talkgroups/cache.go @@ -49,16 +49,16 @@ type Store interface { HUP(*config.Config) } -type CtxStoreKeyT string +type CtxStoreKey string -const CtxStoreKey CtxStoreKeyT = "store" +const StoreCtxKeyValue CtxStoreKey = "store" func CtxWithStore(ctx context.Context, s Store) context.Context { - return context.WithValue(ctx, CtxStoreKey, s) + return context.WithValue(ctx, StoreCtxKeyValue, s) } func StoreFrom(ctx context.Context) Store { - s, ok := ctx.Value(CtxStoreKey).(Store) + s, ok := ctx.Value(StoreCtxKeyValue).(Store) if !ok { return NewCache() }