diff --git a/internal/forms/forms.go b/internal/forms/forms.go index facb68c..bc28cc4 100644 --- a/internal/forms/forms.go +++ b/internal/forms/forms.go @@ -1,6 +1,7 @@ package forms import ( + "encoding/json" "errors" "fmt" "io" @@ -16,15 +17,22 @@ import ( ) var ( - ErrNotStruct = errors.New("destination is not a struct") - ErrNotPointer = errors.New("destination is not a pointer") + ErrNotStruct = errors.New("destination is not a struct") + ErrNotPointer = errors.New("destination is not a pointer") + ErrContentType = errors.New("bad content type") +) + +const ( + MaxMultipartMemory int64 = 1024 * 1024 // 1MB ) type options struct { - tagOverride *string - parseTimeIn *time.Location - parseLocal bool - acceptBlank bool + tagOverride *string + parseTimeIn *time.Location + parseLocal bool + acceptBlank bool + maxMultipartMemory int64 + defaultOmitEmpty bool } type Option func(*options) @@ -53,6 +61,18 @@ func WithTag(t string) Option { } } +func WithMaxMultipartSize(s int64) Option { + return func(o *options) { + o.maxMultipartMemory = s + } +} + +func WithOmitEmpty() Option { + return func(o *options) { + o.defaultOmitEmpty = true + } +} + func (o *options) Tag() string { if o.tagOverride != nil { return *o.tagOverride @@ -147,17 +167,19 @@ func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) return } -func (o *options) iterFields(r *http.Request, rv reflect.Value) error { - rt := rv.Type() - for i := 0; i < rv.NumField(); i++ { - f := rv.Field(i) - tf := rt.Field(i) - if !tf.IsExported() && !tf.Anonymous { +var typeOfByteSlice = reflect.TypeOf([]byte(nil)) + +func (o *options) iterFields(r *http.Request, destStruct reflect.Value) error { + structType := destStruct.Type() + for i := 0; i < destStruct.NumField(); i++ { + destFieldVal := destStruct.Field(i) + fieldType := structType.Field(i) + if !fieldType.IsExported() && !fieldType.Anonymous { continue } - if f.Kind() == reflect.Struct && tf.Anonymous { - err := o.iterFields(r, f) + if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous { + err := o.iterFields(r, destFieldVal) if err != nil { return err } @@ -165,51 +187,38 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { var tAr []string var formField string - formTag, has := rt.Field(i).Tag.Lookup(o.Tag()) + var omitEmpty bool + if o.defaultOmitEmpty { + omitEmpty = true + } + + formTag, has := structType.Field(i).Tag.Lookup(o.Tag()) if has { tAr = strings.Split(formTag, ",") formField = tAr[0] + for _, v := range tAr[1:] { + if v == "omitempty" { + omitEmpty = true + break + } + } } + if !has || formField == "-" { continue } - fi := f.Interface() + destFieldIntf := destFieldVal.Interface() - switch v := fi.(type) { - case string, *string: - s := r.Form.Get(formField) - setVal(f, s != "" || o.acceptBlank, v, s) - case int, uint, *int, *uint: - ff := r.Form.Get(formField) - val, set, err := o.parseInt(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case float64: - ff := r.Form.Get(formField) - val, set, err := o.parseFloat64(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case bool, *bool: - ff := r.Form.Get(formField) - val, set, err := o.parseBool(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case []byte: + if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice { file, hdr, err := r.FormFile(formField) if err != nil { return fmt.Errorf("get form file: %w", err) } - nameField, hasFilename := rt.Field(i).Tag.Lookup("filenameField") + nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField") if hasFilename { - fnf := rv.FieldByName(nameField) + fnf := destStruct.FieldByName(nameField) if fnf == (reflect.Value{}) { panic(fmt.Errorf("filenameField '%s' does not exist", nameField)) } @@ -221,23 +230,52 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { return fmt.Errorf("file read: %w", err) } - f.SetBytes(audioBytes) + destFieldVal.SetBytes(audioBytes) + + continue + } + + if !r.Form.Has(formField) && omitEmpty { + continue + } + + ff := r.Form.Get(formField) + + switch v := destFieldIntf.(type) { + case string, *string: + setVal(destFieldVal, ff != "" || o.acceptBlank, ff) + case int, uint, *int, *uint: + val, set, err := o.parseInt(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case float64: + val, set, err := o.parseFloat64(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case bool, *bool: + val, set, err := o.parseBool(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) case time.Time, *time.Time, jsontime.Time, *jsontime.Time: - tval := r.Form.Get(formField) - t, set, err := o.parseTime(tval) + t, set, err := o.parseTime(ff) if err != nil { return err } - setVal(f, set, v, t) + setVal(destFieldVal, set, t) case time.Duration, *time.Duration, jsontime.Duration, *jsontime.Duration: - dval := r.Form.Get(formField) - d, set, err := o.parseDuration(dval) + d, set, err := o.parseDuration(ff) if err != nil { return err } - setVal(f, set, v, d) + setVal(destFieldVal, set, d) case []int: - val := strings.Trim(r.Form.Get(formField), "[]") + val := strings.Trim(ff, "[]") if val == "" && o.acceptBlank { continue } @@ -249,7 +287,7 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { ar = append(ar, i) } } - f.Set(reflect.ValueOf(ar)) + destFieldVal.Set(reflect.ValueOf(ar)) default: panic(fmt.Errorf("unsupported type %T", v)) } @@ -258,48 +296,77 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { return nil } -func setVal(setField reflect.Value, set bool, fv any, sv any) { +func setVal(destFieldVal reflect.Value, set bool, src any) { if !set { return } - rv := reflect.TypeOf(fv) - svo := reflect.ValueOf(sv) + destType := destFieldVal.Type() + srcVal := reflect.ValueOf(src) - if svo.CanConvert(rv) { - svo = svo.Convert(rv) + if srcVal.Kind() == reflect.Ptr { + srcVal = srcVal.Elem() } - if rv.Kind() == reflect.Ptr { - svo = svo.Addr() + if destType.Kind() == reflect.Ptr { + if !srcVal.CanAddr() { + if srcVal.CanConvert(destType.Elem()) { + srcVal = srcVal.Convert(destType.Elem()) + } + copy := reflect.New(srcVal.Type()) + copy.Elem().Set(srcVal) + srcVal = copy + } + } else if srcVal.CanConvert(destFieldVal.Type()) { + srcVal = srcVal.Convert(destFieldVal.Type()) } - setField.Set(svo) + destFieldVal.Set(srcVal) } func Unmarshal(r *http.Request, dest any, opt ...Option) error { - o := options{} + o := options{ + maxMultipartMemory: MaxMultipartMemory, + } + for _, opt := range opt { opt(&o) } - rv := reflect.ValueOf(dest) - if k := rv.Kind(); k == reflect.Ptr { - rv = rv.Elem() - } else { - return ErrNotPointer - } + contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0] - if rv.Kind() != reflect.Struct { - return ErrNotStruct - } + switch contentType { + case "multipart/form-data": + err := r.ParseMultipartForm(o.maxMultipartMemory) + if err != nil { + return fmt.Errorf("ParseForm: %w", err) + } - if strings.HasPrefix(r.Header.Get("Content-Type"), "application/x-www-form-urlencoded") { + return o.unmarshalForm(r, dest) + case "application/x-www-form-urlencoded": err := r.ParseForm() if err != nil { return fmt.Errorf("ParseForm: %w", err) } + return o.unmarshalForm(r, dest) + case "application/json": + return json.NewDecoder(r.Body).Decode(dest) } - return o.iterFields(r, rv) + return ErrContentType +} + +func (o *options) unmarshalForm(r *http.Request, dest any) error { + destVal := reflect.ValueOf(dest) + if k := destVal.Kind(); k == reflect.Ptr { + destVal = destVal.Elem() + } else { + return ErrNotPointer + } + + if destVal.Kind() != reflect.Struct { + return ErrNotStruct + } + + return o.iterFields(r, destVal) } diff --git a/internal/forms/forms_test.go b/internal/forms/forms_test.go index c75b839..5f73dca 100644 --- a/internal/forms/forms_test.go +++ b/internal/forms/forms_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/forms" "dynatron.me/x/stillbox/internal/jsontime" @@ -54,6 +55,14 @@ type urlEncTestJT struct { ScoreEnd jsontime.Time `json:"scoreEnd"` } +type ptrTestJT struct { + LookbackDays uint `form:"lookbackDays"` + HalfLife *jsontime.Duration `form:"halfLife"` + Recent *string `form:"recent"` + ScoreStart *jsontime.Time `form:"scoreStart"` + ScoreEnd jsontime.Time `form:"scoreEnd"` +} + var ( UrlEncTest = urlEncTest{ LookbackDays: 7, @@ -69,6 +78,13 @@ var ( ScoreStart: jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC)), } + PtrTestJT = ptrTestJT{ + LookbackDays: 7, + HalfLife: common.PtrTo(jsontime.Duration(30 * time.Minute)), + Recent: common.PtrTo("2h0m0s"), + ScoreStart: common.PtrTo(jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC))), + } + UrlEncTestJTLocal = urlEncTestJT{ LookbackDays: 7, HalfLife: jsontime.Duration(30 * time.Minute), @@ -122,29 +138,29 @@ func TestUnmarshal(t *testing.T) { name string r *http.Request dest any - compare any + expect any expectErr error opts []forms.Option }{ { - name: "base case", - r: makeRequest("call1.http"), - dest: &callUploadRequest{}, - compare: &Call1, - opts: []forms.Option{forms.WithAcceptBlank()}, + name: "base case", + r: makeRequest("call1.http"), + dest: &callUploadRequest{}, + expect: &Call1, + opts: []forms.Option{forms.WithAcceptBlank()}, }, { name: "base case no accept blank", r: makeRequest("call1.http"), dest: &callUploadRequest{}, - compare: &Call1, + expect: &Call1, expectErr: errors.New(`parsebool(''): strconv.ParseBool: parsing "": invalid syntax`), }, { name: "not a pointer", r: makeRequest("call1.http"), dest: callUploadRequest{}, - compare: callUploadRequest{}, + expect: callUploadRequest{}, expectErr: forms.ErrNotPointer, opts: []forms.Option{forms.WithAcceptBlank()}, }, @@ -152,7 +168,7 @@ func TestUnmarshal(t *testing.T) { name: "not a struct", r: makeRequest("call1.http"), dest: &str, - compare: callUploadRequest{}, + expect: callUploadRequest{}, expectErr: forms.ErrNotStruct, opts: []forms.Option{forms.WithAcceptBlank()}, }, @@ -160,44 +176,51 @@ func TestUnmarshal(t *testing.T) { name: "url encoded", r: makeRequest("urlenc.http"), dest: &urlEncTest{}, - compare: &UrlEncTest, + expect: &UrlEncTest, expectErr: errors.New(`Could not find format for ""`), }, { - name: "url encoded accept blank", - r: makeRequest("urlenc.http"), - dest: &urlEncTest{}, - compare: &UrlEncTest, - opts: []forms.Option{forms.WithAcceptBlank()}, + name: "url encoded accept blank", + r: makeRequest("urlenc.http"), + dest: &urlEncTest{}, + expect: &UrlEncTest, + opts: []forms.Option{forms.WithAcceptBlank()}, + }, + { + name: "url encoded accept blank pointer", + r: makeRequest("urlenc.http"), + dest: &ptrTestJT{}, + expect: &PtrTestJT, + opts: []forms.Option{forms.WithAcceptBlank()}, }, { name: "url encoded jsontime", r: makeRequest("urlenc.http"), dest: &urlEncTestJT{}, - compare: &UrlEncTestJT, + expect: &UrlEncTestJT, expectErr: errors.New(`Could not find format for ""`), opts: []forms.Option{forms.WithTag("json")}, }, { - name: "url encoded jsontime with tz", - r: makeRequest("urlenc.http"), - dest: &urlEncTestJT{}, - compare: &UrlEncTestJT, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")}, + name: "url encoded jsontime with tz", + r: makeRequest("urlenc.http"), + dest: &urlEncTestJT{}, + expect: &UrlEncTestJT, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")}, }, { - name: "url encoded jsontime with local", - r: makeRequest("urlenc.http"), - dest: &urlEncTestJT{}, - compare: &UrlEncTestJTLocal, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")}, + name: "url encoded jsontime with local", + r: makeRequest("urlenc.http"), + dest: &urlEncTestJT{}, + expect: &UrlEncTestJTLocal, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")}, }, { - name: "sim real data", - r: makeRequest("urlenc2.http"), - dest: &alerting.Simulation{}, - compare: realSim, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()}, + name: "sim real data", + r: makeRequest("urlenc2.http"), + dest: &alerting.Simulation{}, + expect: realSim, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()}, }, } @@ -209,7 +232,7 @@ func TestUnmarshal(t *testing.T) { assert.Contains(t, tc.expectErr.Error(), err.Error()) } else { require.NoError(t, err) - assert.Equal(t, tc.compare, tc.dest) + assert.Equal(t, tc.expect, tc.dest) } }) } diff --git a/pkg/alerting/alert/alert.go b/pkg/alerting/alert/alert.go index 52d1bf0..6f3ad84 100644 --- a/pkg/alerting/alert/alert.go +++ b/pkg/alerting/alert/alert.go @@ -44,8 +44,7 @@ func (a *Alert) ToAddAlertParams() database.AddAlertParams { } } -// makeAlert creates a notification for later rendering by the template. -// It takes a talkgroup Score as input. +// Make creates an alert for later rendering or storage. func Make(ctx context.Context, store talkgroups.Store, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) { d := Alert{ ID: uuid.New(), diff --git a/pkg/alerting/simulate.go b/pkg/alerting/simulate.go index 6d646fa..0ebefb2 100644 --- a/pkg/alerting/simulate.go +++ b/pkg/alerting/simulate.go @@ -2,7 +2,6 @@ package alerting import ( "context" - "encoding/json" "errors" "fmt" "net/http" @@ -114,24 +113,15 @@ func (s *Simulation) Simulate(ctx context.Context) (trending.Scores[talkgroups.I func (as *alerter) simulateHandler(w http.ResponseWriter, r *http.Request) { ctx := r.Context() s := new(Simulation) - switch r.Header.Get("Content-Type") { - case "application/json": - err := json.NewDecoder(r.Body).Decode(s) - if err != nil { - err = fmt.Errorf("simulate decode: %w", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - default: - err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime()) - if err != nil { - err = fmt.Errorf("simulate unmarshal: %w", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } + + err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime()) + if err != nil { + err = fmt.Errorf("simulate unmarshal: %w", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return } - err := s.verify() + err = s.verify() if err != nil { err = fmt.Errorf("simulation profile verify: %w", err) http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/api/api.go b/pkg/api/api.go deleted file mode 100644 index 19ab3ef..0000000 --- a/pkg/api/api.go +++ /dev/null @@ -1,127 +0,0 @@ -package api - -import ( - "encoding/json" - "errors" - "net/http" - - "dynatron.me/x/stillbox/pkg/talkgroups" - - "github.com/go-chi/chi/v5" - "github.com/go-viper/mapstructure/v2" - "github.com/jackc/pgx/v5" - "github.com/rs/zerolog/log" -) - -type API interface { - Subrouter() http.Handler -} - -type api struct { - tgs talkgroups.Store -} - -func New(tgs talkgroups.Store) API { - s := &api{ - tgs: tgs, - } - - return s -} - -func (a *api) Subrouter() http.Handler { - r := chi.NewMux() - - r.Get("/talkgroup/{system:\\d+}/{id:\\d+}", a.talkgroup) - r.Get("/talkgroup/{system:\\d+}/", a.talkgroup) - r.Get("/talkgroup/", a.talkgroup) - return r -} - -var statusMapping = map[error]int{ - talkgroups.ErrNotFound: http.StatusNotFound, - pgx.ErrNoRows: http.StatusNotFound, -} - -func httpCode(err error) int { - c, ok := statusMapping[err] - if ok { - return c - } - - for e, c := range statusMapping { // check if err wraps an error we know about - if errors.Is(err, e) { - return c - } - } - - return http.StatusInternalServerError -} - -func (a *api) writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) { - if err != nil { - log.Error().Str("path", r.URL.Path).Err(err).Msg("request failed") - http.Error(w, err.Error(), httpCode(err)) - return - } - - w.Header().Set("Content-Type", "application/json") - enc := json.NewEncoder(w) - err = enc.Encode(data) - if err != nil { - log.Error().Str("path", r.URL.Path).Err(err).Msg("response marshal failed") - http.Error(w, err.Error(), httpCode(err)) - return - } -} - -func decodeParams(d interface{}, r *http.Request) error { - params := chi.RouteContext(r.Context()).URLParams - m := make(map[string]string, len(params.Keys)) - - for i, k := range params.Keys { - m[k] = params.Values[i] - } - - dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - Metadata: nil, - Result: d, - TagName: "param", - WeaklyTypedInput: true, - }) - if err != nil { - return err - } - - return dec.Decode(m) -} - -func (a *api) badReq(w http.ResponseWriter, err error) { - http.Error(w, err.Error(), http.StatusBadRequest) -} - -func (a *api) talkgroup(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - p := struct { - System *int `param:"system"` - ID *int `param:"id"` - }{} - - err := decodeParams(&p, r) - if err != nil { - a.badReq(w, err) - return - } - - var res interface{} - switch { - case p.System != nil && p.ID != nil: - res, err = a.tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID)) - case p.System != nil: - res, err = a.tgs.SystemTGs(ctx, int32(*p.System)) - default: - res, err = a.tgs.TGs(ctx, nil) - } - - a.writeResponse(w, r, res, err) -} diff --git a/pkg/auth/login.html b/pkg/auth/login.html index eeb9783..9c99bfc 100644 --- a/pkg/auth/login.html +++ b/pkg/auth/login.html @@ -5,7 +5,7 @@