diff --git a/internal/forms/forms.go b/internal/forms/forms.go index facb68c..bc28cc4 100644 --- a/internal/forms/forms.go +++ b/internal/forms/forms.go @@ -1,6 +1,7 @@ package forms import ( + "encoding/json" "errors" "fmt" "io" @@ -16,15 +17,22 @@ import ( ) var ( - ErrNotStruct = errors.New("destination is not a struct") - ErrNotPointer = errors.New("destination is not a pointer") + ErrNotStruct = errors.New("destination is not a struct") + ErrNotPointer = errors.New("destination is not a pointer") + ErrContentType = errors.New("bad content type") +) + +const ( + MaxMultipartMemory int64 = 1024 * 1024 // 1MB ) type options struct { - tagOverride *string - parseTimeIn *time.Location - parseLocal bool - acceptBlank bool + tagOverride *string + parseTimeIn *time.Location + parseLocal bool + acceptBlank bool + maxMultipartMemory int64 + defaultOmitEmpty bool } type Option func(*options) @@ -53,6 +61,18 @@ func WithTag(t string) Option { } } +func WithMaxMultipartSize(s int64) Option { + return func(o *options) { + o.maxMultipartMemory = s + } +} + +func WithOmitEmpty() Option { + return func(o *options) { + o.defaultOmitEmpty = true + } +} + func (o *options) Tag() string { if o.tagOverride != nil { return *o.tagOverride @@ -147,17 +167,19 @@ func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) return } -func (o *options) iterFields(r *http.Request, rv reflect.Value) error { - rt := rv.Type() - for i := 0; i < rv.NumField(); i++ { - f := rv.Field(i) - tf := rt.Field(i) - if !tf.IsExported() && !tf.Anonymous { +var typeOfByteSlice = reflect.TypeOf([]byte(nil)) + +func (o *options) iterFields(r *http.Request, destStruct reflect.Value) error { + structType := destStruct.Type() + for i := 0; i < destStruct.NumField(); i++ { + destFieldVal := destStruct.Field(i) + fieldType := structType.Field(i) + if !fieldType.IsExported() && !fieldType.Anonymous { continue } - if f.Kind() == reflect.Struct && tf.Anonymous { - err := o.iterFields(r, f) + if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous { + err := o.iterFields(r, destFieldVal) if err != nil { return err } @@ -165,51 +187,38 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { var tAr []string var formField string - formTag, has := rt.Field(i).Tag.Lookup(o.Tag()) + var omitEmpty bool + if o.defaultOmitEmpty { + omitEmpty = true + } + + formTag, has := structType.Field(i).Tag.Lookup(o.Tag()) if has { tAr = strings.Split(formTag, ",") formField = tAr[0] + for _, v := range tAr[1:] { + if v == "omitempty" { + omitEmpty = true + break + } + } } + if !has || formField == "-" { continue } - fi := f.Interface() + destFieldIntf := destFieldVal.Interface() - switch v := fi.(type) { - case string, *string: - s := r.Form.Get(formField) - setVal(f, s != "" || o.acceptBlank, v, s) - case int, uint, *int, *uint: - ff := r.Form.Get(formField) - val, set, err := o.parseInt(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case float64: - ff := r.Form.Get(formField) - val, set, err := o.parseFloat64(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case bool, *bool: - ff := r.Form.Get(formField) - val, set, err := o.parseBool(ff) - if err != nil { - return err - } - setVal(f, set, v, val) - case []byte: + if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice { file, hdr, err := r.FormFile(formField) if err != nil { return fmt.Errorf("get form file: %w", err) } - nameField, hasFilename := rt.Field(i).Tag.Lookup("filenameField") + nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField") if hasFilename { - fnf := rv.FieldByName(nameField) + fnf := destStruct.FieldByName(nameField) if fnf == (reflect.Value{}) { panic(fmt.Errorf("filenameField '%s' does not exist", nameField)) } @@ -221,23 +230,52 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { return fmt.Errorf("file read: %w", err) } - f.SetBytes(audioBytes) + destFieldVal.SetBytes(audioBytes) + + continue + } + + if !r.Form.Has(formField) && omitEmpty { + continue + } + + ff := r.Form.Get(formField) + + switch v := destFieldIntf.(type) { + case string, *string: + setVal(destFieldVal, ff != "" || o.acceptBlank, ff) + case int, uint, *int, *uint: + val, set, err := o.parseInt(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case float64: + val, set, err := o.parseFloat64(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case bool, *bool: + val, set, err := o.parseBool(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) case time.Time, *time.Time, jsontime.Time, *jsontime.Time: - tval := r.Form.Get(formField) - t, set, err := o.parseTime(tval) + t, set, err := o.parseTime(ff) if err != nil { return err } - setVal(f, set, v, t) + setVal(destFieldVal, set, t) case time.Duration, *time.Duration, jsontime.Duration, *jsontime.Duration: - dval := r.Form.Get(formField) - d, set, err := o.parseDuration(dval) + d, set, err := o.parseDuration(ff) if err != nil { return err } - setVal(f, set, v, d) + setVal(destFieldVal, set, d) case []int: - val := strings.Trim(r.Form.Get(formField), "[]") + val := strings.Trim(ff, "[]") if val == "" && o.acceptBlank { continue } @@ -249,7 +287,7 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { ar = append(ar, i) } } - f.Set(reflect.ValueOf(ar)) + destFieldVal.Set(reflect.ValueOf(ar)) default: panic(fmt.Errorf("unsupported type %T", v)) } @@ -258,48 +296,77 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error { return nil } -func setVal(setField reflect.Value, set bool, fv any, sv any) { +func setVal(destFieldVal reflect.Value, set bool, src any) { if !set { return } - rv := reflect.TypeOf(fv) - svo := reflect.ValueOf(sv) + destType := destFieldVal.Type() + srcVal := reflect.ValueOf(src) - if svo.CanConvert(rv) { - svo = svo.Convert(rv) + if srcVal.Kind() == reflect.Ptr { + srcVal = srcVal.Elem() } - if rv.Kind() == reflect.Ptr { - svo = svo.Addr() + if destType.Kind() == reflect.Ptr { + if !srcVal.CanAddr() { + if srcVal.CanConvert(destType.Elem()) { + srcVal = srcVal.Convert(destType.Elem()) + } + copy := reflect.New(srcVal.Type()) + copy.Elem().Set(srcVal) + srcVal = copy + } + } else if srcVal.CanConvert(destFieldVal.Type()) { + srcVal = srcVal.Convert(destFieldVal.Type()) } - setField.Set(svo) + destFieldVal.Set(srcVal) } func Unmarshal(r *http.Request, dest any, opt ...Option) error { - o := options{} + o := options{ + maxMultipartMemory: MaxMultipartMemory, + } + for _, opt := range opt { opt(&o) } - rv := reflect.ValueOf(dest) - if k := rv.Kind(); k == reflect.Ptr { - rv = rv.Elem() - } else { - return ErrNotPointer - } + contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0] - if rv.Kind() != reflect.Struct { - return ErrNotStruct - } + switch contentType { + case "multipart/form-data": + err := r.ParseMultipartForm(o.maxMultipartMemory) + if err != nil { + return fmt.Errorf("ParseForm: %w", err) + } - if strings.HasPrefix(r.Header.Get("Content-Type"), "application/x-www-form-urlencoded") { + return o.unmarshalForm(r, dest) + case "application/x-www-form-urlencoded": err := r.ParseForm() if err != nil { return fmt.Errorf("ParseForm: %w", err) } + return o.unmarshalForm(r, dest) + case "application/json": + return json.NewDecoder(r.Body).Decode(dest) } - return o.iterFields(r, rv) + return ErrContentType +} + +func (o *options) unmarshalForm(r *http.Request, dest any) error { + destVal := reflect.ValueOf(dest) + if k := destVal.Kind(); k == reflect.Ptr { + destVal = destVal.Elem() + } else { + return ErrNotPointer + } + + if destVal.Kind() != reflect.Struct { + return ErrNotStruct + } + + return o.iterFields(r, destVal) } diff --git a/internal/forms/forms_test.go b/internal/forms/forms_test.go index c75b839..5f73dca 100644 --- a/internal/forms/forms_test.go +++ b/internal/forms/forms_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "dynatron.me/x/stillbox/internal/common" "dynatron.me/x/stillbox/internal/forms" "dynatron.me/x/stillbox/internal/jsontime" @@ -54,6 +55,14 @@ type urlEncTestJT struct { ScoreEnd jsontime.Time `json:"scoreEnd"` } +type ptrTestJT struct { + LookbackDays uint `form:"lookbackDays"` + HalfLife *jsontime.Duration `form:"halfLife"` + Recent *string `form:"recent"` + ScoreStart *jsontime.Time `form:"scoreStart"` + ScoreEnd jsontime.Time `form:"scoreEnd"` +} + var ( UrlEncTest = urlEncTest{ LookbackDays: 7, @@ -69,6 +78,13 @@ var ( ScoreStart: jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC)), } + PtrTestJT = ptrTestJT{ + LookbackDays: 7, + HalfLife: common.PtrTo(jsontime.Duration(30 * time.Minute)), + Recent: common.PtrTo("2h0m0s"), + ScoreStart: common.PtrTo(jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC))), + } + UrlEncTestJTLocal = urlEncTestJT{ LookbackDays: 7, HalfLife: jsontime.Duration(30 * time.Minute), @@ -122,29 +138,29 @@ func TestUnmarshal(t *testing.T) { name string r *http.Request dest any - compare any + expect any expectErr error opts []forms.Option }{ { - name: "base case", - r: makeRequest("call1.http"), - dest: &callUploadRequest{}, - compare: &Call1, - opts: []forms.Option{forms.WithAcceptBlank()}, + name: "base case", + r: makeRequest("call1.http"), + dest: &callUploadRequest{}, + expect: &Call1, + opts: []forms.Option{forms.WithAcceptBlank()}, }, { name: "base case no accept blank", r: makeRequest("call1.http"), dest: &callUploadRequest{}, - compare: &Call1, + expect: &Call1, expectErr: errors.New(`parsebool(''): strconv.ParseBool: parsing "": invalid syntax`), }, { name: "not a pointer", r: makeRequest("call1.http"), dest: callUploadRequest{}, - compare: callUploadRequest{}, + expect: callUploadRequest{}, expectErr: forms.ErrNotPointer, opts: []forms.Option{forms.WithAcceptBlank()}, }, @@ -152,7 +168,7 @@ func TestUnmarshal(t *testing.T) { name: "not a struct", r: makeRequest("call1.http"), dest: &str, - compare: callUploadRequest{}, + expect: callUploadRequest{}, expectErr: forms.ErrNotStruct, opts: []forms.Option{forms.WithAcceptBlank()}, }, @@ -160,44 +176,51 @@ func TestUnmarshal(t *testing.T) { name: "url encoded", r: makeRequest("urlenc.http"), dest: &urlEncTest{}, - compare: &UrlEncTest, + expect: &UrlEncTest, expectErr: errors.New(`Could not find format for ""`), }, { - name: "url encoded accept blank", - r: makeRequest("urlenc.http"), - dest: &urlEncTest{}, - compare: &UrlEncTest, - opts: []forms.Option{forms.WithAcceptBlank()}, + name: "url encoded accept blank", + r: makeRequest("urlenc.http"), + dest: &urlEncTest{}, + expect: &UrlEncTest, + opts: []forms.Option{forms.WithAcceptBlank()}, + }, + { + name: "url encoded accept blank pointer", + r: makeRequest("urlenc.http"), + dest: &ptrTestJT{}, + expect: &PtrTestJT, + opts: []forms.Option{forms.WithAcceptBlank()}, }, { name: "url encoded jsontime", r: makeRequest("urlenc.http"), dest: &urlEncTestJT{}, - compare: &UrlEncTestJT, + expect: &UrlEncTestJT, expectErr: errors.New(`Could not find format for ""`), opts: []forms.Option{forms.WithTag("json")}, }, { - name: "url encoded jsontime with tz", - r: makeRequest("urlenc.http"), - dest: &urlEncTestJT{}, - compare: &UrlEncTestJT, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")}, + name: "url encoded jsontime with tz", + r: makeRequest("urlenc.http"), + dest: &urlEncTestJT{}, + expect: &UrlEncTestJT, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")}, }, { - name: "url encoded jsontime with local", - r: makeRequest("urlenc.http"), - dest: &urlEncTestJT{}, - compare: &UrlEncTestJTLocal, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")}, + name: "url encoded jsontime with local", + r: makeRequest("urlenc.http"), + dest: &urlEncTestJT{}, + expect: &UrlEncTestJTLocal, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")}, }, { - name: "sim real data", - r: makeRequest("urlenc2.http"), - dest: &alerting.Simulation{}, - compare: realSim, - opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()}, + name: "sim real data", + r: makeRequest("urlenc2.http"), + dest: &alerting.Simulation{}, + expect: realSim, + opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()}, }, } @@ -209,7 +232,7 @@ func TestUnmarshal(t *testing.T) { assert.Contains(t, tc.expectErr.Error(), err.Error()) } else { require.NoError(t, err) - assert.Equal(t, tc.compare, tc.dest) + assert.Equal(t, tc.expect, tc.dest) } }) } diff --git a/pkg/alerting/alert/alert.go b/pkg/alerting/alert/alert.go index 52d1bf0..6f3ad84 100644 --- a/pkg/alerting/alert/alert.go +++ b/pkg/alerting/alert/alert.go @@ -44,8 +44,7 @@ func (a *Alert) ToAddAlertParams() database.AddAlertParams { } } -// makeAlert creates a notification for later rendering by the template. -// It takes a talkgroup Score as input. +// Make creates an alert for later rendering or storage. func Make(ctx context.Context, store talkgroups.Store, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) { d := Alert{ ID: uuid.New(), diff --git a/pkg/alerting/simulate.go b/pkg/alerting/simulate.go index 6d646fa..0ebefb2 100644 --- a/pkg/alerting/simulate.go +++ b/pkg/alerting/simulate.go @@ -2,7 +2,6 @@ package alerting import ( "context" - "encoding/json" "errors" "fmt" "net/http" @@ -114,24 +113,15 @@ func (s *Simulation) Simulate(ctx context.Context) (trending.Scores[talkgroups.I func (as *alerter) simulateHandler(w http.ResponseWriter, r *http.Request) { ctx := r.Context() s := new(Simulation) - switch r.Header.Get("Content-Type") { - case "application/json": - err := json.NewDecoder(r.Body).Decode(s) - if err != nil { - err = fmt.Errorf("simulate decode: %w", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - default: - err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime()) - if err != nil { - err = fmt.Errorf("simulate unmarshal: %w", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } + + err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime()) + if err != nil { + err = fmt.Errorf("simulate unmarshal: %w", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return } - err := s.verify() + err = s.verify() if err != nil { err = fmt.Errorf("simulation profile verify: %w", err) http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/api/api.go b/pkg/api/api.go deleted file mode 100644 index 19ab3ef..0000000 --- a/pkg/api/api.go +++ /dev/null @@ -1,127 +0,0 @@ -package api - -import ( - "encoding/json" - "errors" - "net/http" - - "dynatron.me/x/stillbox/pkg/talkgroups" - - "github.com/go-chi/chi/v5" - "github.com/go-viper/mapstructure/v2" - "github.com/jackc/pgx/v5" - "github.com/rs/zerolog/log" -) - -type API interface { - Subrouter() http.Handler -} - -type api struct { - tgs talkgroups.Store -} - -func New(tgs talkgroups.Store) API { - s := &api{ - tgs: tgs, - } - - return s -} - -func (a *api) Subrouter() http.Handler { - r := chi.NewMux() - - r.Get("/talkgroup/{system:\\d+}/{id:\\d+}", a.talkgroup) - r.Get("/talkgroup/{system:\\d+}/", a.talkgroup) - r.Get("/talkgroup/", a.talkgroup) - return r -} - -var statusMapping = map[error]int{ - talkgroups.ErrNotFound: http.StatusNotFound, - pgx.ErrNoRows: http.StatusNotFound, -} - -func httpCode(err error) int { - c, ok := statusMapping[err] - if ok { - return c - } - - for e, c := range statusMapping { // check if err wraps an error we know about - if errors.Is(err, e) { - return c - } - } - - return http.StatusInternalServerError -} - -func (a *api) writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) { - if err != nil { - log.Error().Str("path", r.URL.Path).Err(err).Msg("request failed") - http.Error(w, err.Error(), httpCode(err)) - return - } - - w.Header().Set("Content-Type", "application/json") - enc := json.NewEncoder(w) - err = enc.Encode(data) - if err != nil { - log.Error().Str("path", r.URL.Path).Err(err).Msg("response marshal failed") - http.Error(w, err.Error(), httpCode(err)) - return - } -} - -func decodeParams(d interface{}, r *http.Request) error { - params := chi.RouteContext(r.Context()).URLParams - m := make(map[string]string, len(params.Keys)) - - for i, k := range params.Keys { - m[k] = params.Values[i] - } - - dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - Metadata: nil, - Result: d, - TagName: "param", - WeaklyTypedInput: true, - }) - if err != nil { - return err - } - - return dec.Decode(m) -} - -func (a *api) badReq(w http.ResponseWriter, err error) { - http.Error(w, err.Error(), http.StatusBadRequest) -} - -func (a *api) talkgroup(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - p := struct { - System *int `param:"system"` - ID *int `param:"id"` - }{} - - err := decodeParams(&p, r) - if err != nil { - a.badReq(w, err) - return - } - - var res interface{} - switch { - case p.System != nil && p.ID != nil: - res, err = a.tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID)) - case p.System != nil: - res, err = a.tgs.SystemTGs(ctx, int32(*p.System)) - default: - res, err = a.tgs.TGs(ctx, nil) - } - - a.writeResponse(w, r, res, err) -} diff --git a/pkg/auth/login.html b/pkg/auth/login.html index eeb9783..9c99bfc 100644 --- a/pkg/auth/login.html +++ b/pkg/auth/login.html @@ -5,7 +5,7 @@
-
+ diff --git a/pkg/database/database.go b/pkg/database/database.go index 37794ff..9f1c69b 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -56,13 +56,13 @@ func NewClient(ctx context.Context, conf config.DB) (*DB, error) { return db, nil } -type DBCtxKey string +type dBCtxKey string -const DBCTXKeyValue DBCtxKey = "dbctx" +const DBCtxKey dBCtxKey = "dbctx" // FromCtx returns the database handle from the provided Context. func FromCtx(ctx context.Context) *DB { - c, ok := ctx.Value(DBCTXKeyValue).(*DB) + c, ok := ctx.Value(DBCtxKey).(*DB) if !ok { panic("no DB in context") } @@ -72,7 +72,7 @@ func FromCtx(ctx context.Context) *DB { // CtxWithDB returns a Context with the provided database handle. func CtxWithDB(ctx context.Context, conn *DB) context.Context { - return context.WithValue(ctx, DBCTXKeyValue, conn) + return context.WithValue(ctx, DBCtxKey, conn) } // IsNoRows is a convenience function that returns whether a returned error is a database diff --git a/pkg/database/extend.go b/pkg/database/extend.go index 3a96cd1..5f165a0 100644 --- a/pkg/database/extend.go +++ b/pkg/database/extend.go @@ -9,3 +9,6 @@ func (g GetTalkgroupsWithLearnedRow) GetLearned() bool { retur func (g GetTalkgroupsWithLearnedBySystemRow) GetTalkgroup() Talkgroup { return g.Talkgroup } func (g GetTalkgroupsWithLearnedBySystemRow) GetSystem() System { return g.System } func (g GetTalkgroupsWithLearnedBySystemRow) GetLearned() bool { return g.Learned } +func (g Talkgroup) GetTalkgroup() Talkgroup { return g } +func (g Talkgroup) GetSystem() System { return System{ID: int(g.SystemID)} } +func (g Talkgroup) GetLearned() bool { return false } diff --git a/pkg/database/querier.go b/pkg/database/querier.go index ba1ace1..6a814c3 100644 --- a/pkg/database/querier.go +++ b/pkg/database/querier.go @@ -39,6 +39,7 @@ type Querier interface { SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error SetTalkgroupTags(ctx context.Context, sys int, tg int, tags []string) error UpdatePassword(ctx context.Context, username string, password string) error + UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) } var _ Querier = (*Queries)(nil) diff --git a/pkg/database/talkgroups.sql.go b/pkg/database/talkgroups.sql.go index b32deb3..81d224f 100644 --- a/pkg/database/talkgroups.sql.go +++ b/pkg/database/talkgroups.sql.go @@ -472,3 +472,63 @@ func (q *Queries) SetTalkgroupTags(ctx context.Context, sys int, tg int, tags [] _, err := q.db.Exec(ctx, setTalkgroupTags, sys, tg, tags) return err } + +const updateTalkgroup = `-- name: UpdateTalkgroup :one +UPDATE talkgroups +SET + name = COALESCE($1, name), + alpha_tag = COALESCE($2, alpha_tag), + tg_group = COALESCE($3, tg_group), + frequency = COALESCE($4, frequency), + metadata = COALESCE($5, metadata), + tags = COALESCE($6, tags), + alert = COALESCE($7, alert), + alert_config = COALESCE($8, alert_config), + weight = COALESCE($9, weight) +WHERE id = $10 +RETURNING id, system_id, tgid, name, alpha_tag, tg_group, frequency, metadata, tags, alert, alert_config, weight +` + +type UpdateTalkgroupParams struct { + Name *string `json:"name"` + AlphaTag *string `json:"alpha_tag"` + TgGroup *string `json:"tg_group"` + Frequency *int32 `json:"frequency"` + Metadata []byte `json:"metadata"` + Tags []string `json:"tags"` + Alert *bool `json:"alert"` + AlertConfig []byte `json:"alert_config"` + Weight *float32 `json:"weight"` + ID int64 `json:"id"` +} + +func (q *Queries) UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) { + row := q.db.QueryRow(ctx, updateTalkgroup, + arg.Name, + arg.AlphaTag, + arg.TgGroup, + arg.Frequency, + arg.Metadata, + arg.Tags, + arg.Alert, + arg.AlertConfig, + arg.Weight, + arg.ID, + ) + var i Talkgroup + err := row.Scan( + &i.ID, + &i.SystemID, + &i.Tgid, + &i.Name, + &i.AlphaTag, + &i.TgGroup, + &i.Frequency, + &i.Metadata, + &i.Tags, + &i.Alert, + &i.AlertConfig, + &i.Weight, + ) + return i, err +} diff --git a/pkg/rest/api.go b/pkg/rest/api.go new file mode 100644 index 0000000..0c08daa --- /dev/null +++ b/pkg/rest/api.go @@ -0,0 +1,132 @@ +package rest + +import ( + "errors" + "net/http" + + "dynatron.me/x/stillbox/pkg/talkgroups" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + "github.com/go-viper/mapstructure/v2" + "github.com/jackc/pgx/v5" + "github.com/rs/zerolog/log" +) + +type API interface { + Subrouter() http.Handler +} + +type api struct { +} + +func New() API { + s := new(api) + + return s +} + +func (a *api) Subrouter() http.Handler { + r := chi.NewMux() + + r.Mount("/talkgroup", new(talkgroupAPI).Subrouter()) + + return r +} + +type errResponse struct { + Err error `json:"-"` + Code int `json:"-"` + Error string `json:"error"` +} + +func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error { + switch e.Code { + case http.StatusNotFound: + case http.StatusBadRequest: + default: + log.Error().Str("path", r.URL.Path).Err(e.Err).Int("code", e.Code).Str("msg", e.Error).Msg("request failed") + } + + render.Status(r, e.Code) + + return nil +} + +func badRequest(err error) render.Renderer { + return &errResponse{ + Err: err, + Code: http.StatusBadRequest, + Error: "Bad request", + } +} + +func recordNotFound(err error) render.Renderer { + return &errResponse{ + Err: err, + Code: http.StatusNotFound, + Error: "Record not found", + } +} + +func internalError(err error) render.Renderer { + return &errResponse{ + Err: err, + Code: http.StatusNotFound, + Error: "Internal server error", + } +} + +type errResponder func(error) render.Renderer + +var statusMapping = map[error]errResponder{ + talkgroups.ErrNotFound: recordNotFound, + pgx.ErrNoRows: recordNotFound, +} + +func autoError(err error) render.Renderer { + c, ok := statusMapping[err] + if ok { + c(err) + } + + for e, c := range statusMapping { // check if err wraps an error we know about + if errors.Is(err, e) { + return c(err) + } + } + + return internalError(err) +} + +func wErr(w http.ResponseWriter, r *http.Request, v render.Renderer) { + err := render.Render(w, r, v) + if err != nil { + log.Error().Err(err).Msg("wErr render error") + } +} + +func decodeParams(d interface{}, r *http.Request) error { + params := chi.RouteContext(r.Context()).URLParams + m := make(map[string]string, len(params.Keys)) + + for i, k := range params.Keys { + m[k] = params.Values[i] + } + + dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + Metadata: nil, + Result: d, + TagName: "param", + WeaklyTypedInput: true, + }) + if err != nil { + return err + } + + return dec.Decode(m) +} + +func respond(w http.ResponseWriter, r *http.Request, v interface{}) { + render.DefaultResponder(w, r, v) +} diff --git a/pkg/rest/talkgroups.go b/pkg/rest/talkgroups.go new file mode 100644 index 0000000..0520af7 --- /dev/null +++ b/pkg/rest/talkgroups.go @@ -0,0 +1,108 @@ +package rest + +import ( + "net/http" + + "dynatron.me/x/stillbox/internal/forms" + "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/talkgroups" + + "github.com/go-chi/chi/v5" +) + +type talkgroupAPI struct { +} + +func (tga *talkgroupAPI) Subrouter() http.Handler { + r := chi.NewMux() + + r.Get("/{system:\\d+}/{id:\\d+}", tga.get) + r.Put("/{system:\\d+}/{id:\\d+}", tga.put) + r.Get("/{system:\\d+}/", tga.get) + r.Get("/", tga.get) + + return r +} + +type tgParams struct { + System *int `param:"system"` + ID *int `param:"id"` +} + +func (t tgParams) haveBoth() bool { + return t.System != nil && t.ID != nil +} + +func (t tgParams) ToID() talkgroups.ID { + nilOr := func(i *int) uint32 { + if i == nil { + return 0 + } + + return uint32(*i) + } + + return talkgroups.ID{ + System: nilOr(t.System), + Talkgroup: nilOr(t.ID), + } +} + +func (tga *talkgroupAPI) get(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + tgs := talkgroups.StoreFrom(ctx) + + var p tgParams + + err := decodeParams(&p, r) + if err != nil { + wErr(w, r, badRequest(err)) + return + } + + var res interface{} + switch { + case p.System != nil && p.ID != nil: + res, err = tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID)) + case p.System != nil: + res, err = tgs.SystemTGs(ctx, int32(*p.System)) + default: + res, err = tgs.TGs(ctx, nil) + } + + if err != nil { + wErr(w, r, autoError(err)) + return + } + + respond(w, r, res) +} + +func (tga *talkgroupAPI) put(w http.ResponseWriter, r *http.Request) { + var id tgParams + err := decodeParams(&id, r) + if err != nil { + wErr(w, r, badRequest(err)) + return + } + + ctx := r.Context() + tgs := talkgroups.StoreFrom(ctx) + + input := database.UpdateTalkgroupParams{} + + err = forms.Unmarshal(r, &input, forms.WithTag("json"), forms.WithAcceptBlank(), forms.WithOmitEmpty()) + if err != nil { + wErr(w, r, badRequest(err)) + return + } + input.ID = id.ToID().Pack() + + record, err := tgs.UpdateTG(ctx, input) + if err != nil { + wErr(w, r, autoError(err)) + return + } + + respond(w, r, record) +} diff --git a/pkg/server/profile.go b/pkg/server/profile.go index eb5787e..dcb9d6f 100644 --- a/pkg/server/profile.go +++ b/pkg/server/profile.go @@ -4,14 +4,9 @@ package server import ( - "net/http/pprof" + "github.com/go-chi/chi/v5/middleware" ) func (s *Server) installPprof() { - r := s.r - r.HandleFunc("/debug/pprof/", pprof.Index) - r.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - r.HandleFunc("/debug/pprof/profile", pprof.Profile) - r.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - r.HandleFunc("/debug/pprof/trace", pprof.Trace) + s.r.Mount("/debug", middleware.Profiler()) } diff --git a/pkg/server/routes.go b/pkg/server/routes.go index bcc7c3b..b296d6b 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -9,6 +9,7 @@ import ( "dynatron.me/x/stillbox/internal/version" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" + "dynatron.me/x/stillbox/pkg/talkgroups" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/httprate" @@ -26,7 +27,8 @@ func (s *Server) setupRoutes() { } r := s.r - r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db)) + r.Use(middleware.WithValue(database.DBCtxKey, s.db)) + r.Use(middleware.WithValue(talkgroups.StoreCtxKey, s.tgs)) s.installPprof() @@ -36,7 +38,7 @@ func (s *Server) setupRoutes() { s.nex.PrivateRoutes(r) s.auth.PrivateRoutes(r) s.alerter.PrivateRoutes(r) - r.Mount("/api", s.api.Subrouter()) + r.Mount("/api", s.rest.Subrouter()) }) r.Group(func(r chi.Router) { diff --git a/pkg/server/server.go b/pkg/server/server.go index ed11996..ebec903 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -7,12 +7,12 @@ import ( "time" "dynatron.me/x/stillbox/pkg/alerting" - "dynatron.me/x/stillbox/pkg/api" "dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/nexus" "dynatron.me/x/stillbox/pkg/notify" + "dynatron.me/x/stillbox/pkg/rest" "dynatron.me/x/stillbox/pkg/sinks" "dynatron.me/x/stillbox/pkg/sources" "dynatron.me/x/stillbox/pkg/talkgroups" @@ -37,7 +37,7 @@ type Server struct { notifier notify.Notifier hup chan os.Signal tgs talkgroups.Store - api api.API + rest rest.API } func New(ctx context.Context, cfg *config.Config) (*Server, error) { @@ -61,7 +61,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) { } tgCache := talkgroups.NewCache() - api := api.New(tgCache) + api := rest.New() srv := &Server{ auth: authenticator, @@ -73,7 +73,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) { alerter: alerting.New(cfg.Alerting, tgCache, alerting.WithNotifier(notifier)), notifier: notifier, tgs: tgCache, - api: api, + rest: api, } srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db), true) diff --git a/pkg/talkgroups/cache.go b/pkg/talkgroups/cache.go index 85ac5d0..16b8ac6 100644 --- a/pkg/talkgroups/cache.go +++ b/pkg/talkgroups/cache.go @@ -17,7 +17,15 @@ import ( type tgMap map[ID]*Talkgroup +var ( + ErrNotFound = errors.New("talkgroup not found") + ErrNoSuchSystem = errors.New("no such system") +) + type Store interface { + // UpdateTG updates a talkgroup record. + UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*Talkgroup, error) + // TG retrieves a Talkgroup from the Store. TG(ctx context.Context, tg ID) (*Talkgroup, error) @@ -49,16 +57,16 @@ type Store interface { HUP(*config.Config) } -type CtxStoreKeyT string +type storeCtxKey string -const CtxStoreKey CtxStoreKeyT = "store" +const StoreCtxKey storeCtxKey = "store" func CtxWithStore(ctx context.Context, s Store) context.Context { - return context.WithValue(ctx, CtxStoreKey, s) + return context.WithValue(ctx, StoreCtxKey, s) } func StoreFrom(ctx context.Context) Store { - s, ok := ctx.Value(CtxStoreKey).(Store) + s, ok := ctx.Value(StoreCtxKey).(Store) if !ok { return NewCache() } @@ -213,8 +221,6 @@ func (t *cache) Load(ctx context.Context, tgs []int64) error { return nil } -var ErrNotFound = errors.New("talkgroup not found") - func (t *cache) Weight(ctx context.Context, id ID, tm time.Time) float64 { tg, err := t.TG(ctx, id) if err != nil { @@ -290,3 +296,23 @@ func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool) return n, has } + +func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*Talkgroup, error) { + sysName, has := t.SystemName(ctx, int(Unpack(input.ID).System)) + if !has { + return nil, ErrNoSuchSystem + } + + tg, err := database.FromCtx(ctx).UpdateTalkgroup(ctx, input) + if err != nil { + return nil, err + } + + record := &Talkgroup{ + Talkgroup: tg, + System: database.System{ID: int(tg.SystemID), Name: sysName}, + } + t.add(record) + + return record, nil +} diff --git a/pkg/talkgroups/talkgroup.go b/pkg/talkgroups/talkgroup.go index 288e488..e3cca71 100644 --- a/pkg/talkgroups/talkgroup.go +++ b/pkg/talkgroups/talkgroup.go @@ -49,6 +49,13 @@ func (t ID) Pack() int64 { return int64((int64(t.System) << 32) | int64(t.Talkgroup)) } +func Unpack(id int64) ID { + return ID{ + System: uint32(id >> 32), + Talkgroup: uint32(id & 0xffffffff), + } +} + func (t ID) String() string { return fmt.Sprintf("%d:%d", t.System, t.Talkgroup) diff --git a/sql/postgres/queries/talkgroups.sql b/sql/postgres/queries/talkgroups.sql index d201d0e..732a848 100644 --- a/sql/postgres/queries/talkgroups.sql +++ b/sql/postgres/queries/talkgroups.sql @@ -104,6 +104,20 @@ FROM talkgroups_learned tgl JOIN systems sys ON tgl.system_id = sys.id WHERE ignored IS NOT TRUE; - -- name: GetSystemName :one SELECT name FROM systems WHERE id = sqlc.arg(system_id); + +-- name: UpdateTalkgroup :one +UPDATE talkgroups +SET + name = COALESCE(sqlc.narg('name'), name), + alpha_tag = COALESCE(sqlc.narg('alpha_tag'), alpha_tag), + tg_group = COALESCE(sqlc.narg('tg_group'), tg_group), + frequency = COALESCE(sqlc.narg('frequency'), frequency), + metadata = COALESCE(sqlc.narg('metadata'), metadata), + tags = COALESCE(sqlc.narg('tags'), tags), + alert = COALESCE(sqlc.narg('alert'), alert), + alert_config = COALESCE(sqlc.narg('alert_config'), alert_config), + weight = COALESCE(sqlc.narg('weight'), weight) +WHERE id = @id +RETURNING *;