Merge pull request 'REST API, forms improvement' (#30) from rest into trunk

Reviewed-on: #30
This commit is contained in:
Daniel 2024-11-10 14:46:25 -05:00
commit 657c00e326
18 changed files with 576 additions and 276 deletions

View file

@ -1,6 +1,7 @@
package forms package forms
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -18,6 +19,11 @@ import (
var ( var (
ErrNotStruct = errors.New("destination is not a struct") ErrNotStruct = errors.New("destination is not a struct")
ErrNotPointer = errors.New("destination is not a pointer") ErrNotPointer = errors.New("destination is not a pointer")
ErrContentType = errors.New("bad content type")
)
const (
MaxMultipartMemory int64 = 1024 * 1024 // 1MB
) )
type options struct { type options struct {
@ -25,6 +31,8 @@ type options struct {
parseTimeIn *time.Location parseTimeIn *time.Location
parseLocal bool parseLocal bool
acceptBlank bool acceptBlank bool
maxMultipartMemory int64
defaultOmitEmpty bool
} }
type Option func(*options) type Option func(*options)
@ -53,6 +61,18 @@ func WithTag(t string) Option {
} }
} }
func WithMaxMultipartSize(s int64) Option {
return func(o *options) {
o.maxMultipartMemory = s
}
}
func WithOmitEmpty() Option {
return func(o *options) {
o.defaultOmitEmpty = true
}
}
func (o *options) Tag() string { func (o *options) Tag() string {
if o.tagOverride != nil { if o.tagOverride != nil {
return *o.tagOverride return *o.tagOverride
@ -147,17 +167,19 @@ func (o *options) parseDuration(s string) (v time.Duration, set bool, err error)
return return
} }
func (o *options) iterFields(r *http.Request, rv reflect.Value) error { var typeOfByteSlice = reflect.TypeOf([]byte(nil))
rt := rv.Type()
for i := 0; i < rv.NumField(); i++ { func (o *options) iterFields(r *http.Request, destStruct reflect.Value) error {
f := rv.Field(i) structType := destStruct.Type()
tf := rt.Field(i) for i := 0; i < destStruct.NumField(); i++ {
if !tf.IsExported() && !tf.Anonymous { destFieldVal := destStruct.Field(i)
fieldType := structType.Field(i)
if !fieldType.IsExported() && !fieldType.Anonymous {
continue continue
} }
if f.Kind() == reflect.Struct && tf.Anonymous { if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous {
err := o.iterFields(r, f) err := o.iterFields(r, destFieldVal)
if err != nil { if err != nil {
return err return err
} }
@ -165,51 +187,38 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
var tAr []string var tAr []string
var formField string var formField string
formTag, has := rt.Field(i).Tag.Lookup(o.Tag()) var omitEmpty bool
if o.defaultOmitEmpty {
omitEmpty = true
}
formTag, has := structType.Field(i).Tag.Lookup(o.Tag())
if has { if has {
tAr = strings.Split(formTag, ",") tAr = strings.Split(formTag, ",")
formField = tAr[0] formField = tAr[0]
for _, v := range tAr[1:] {
if v == "omitempty" {
omitEmpty = true
break
} }
}
}
if !has || formField == "-" { if !has || formField == "-" {
continue continue
} }
fi := f.Interface() destFieldIntf := destFieldVal.Interface()
switch v := fi.(type) { if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice {
case string, *string:
s := r.Form.Get(formField)
setVal(f, s != "" || o.acceptBlank, v, s)
case int, uint, *int, *uint:
ff := r.Form.Get(formField)
val, set, err := o.parseInt(ff)
if err != nil {
return err
}
setVal(f, set, v, val)
case float64:
ff := r.Form.Get(formField)
val, set, err := o.parseFloat64(ff)
if err != nil {
return err
}
setVal(f, set, v, val)
case bool, *bool:
ff := r.Form.Get(formField)
val, set, err := o.parseBool(ff)
if err != nil {
return err
}
setVal(f, set, v, val)
case []byte:
file, hdr, err := r.FormFile(formField) file, hdr, err := r.FormFile(formField)
if err != nil { if err != nil {
return fmt.Errorf("get form file: %w", err) return fmt.Errorf("get form file: %w", err)
} }
nameField, hasFilename := rt.Field(i).Tag.Lookup("filenameField") nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField")
if hasFilename { if hasFilename {
fnf := rv.FieldByName(nameField) fnf := destStruct.FieldByName(nameField)
if fnf == (reflect.Value{}) { if fnf == (reflect.Value{}) {
panic(fmt.Errorf("filenameField '%s' does not exist", nameField)) panic(fmt.Errorf("filenameField '%s' does not exist", nameField))
} }
@ -221,23 +230,52 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
return fmt.Errorf("file read: %w", err) return fmt.Errorf("file read: %w", err)
} }
f.SetBytes(audioBytes) destFieldVal.SetBytes(audioBytes)
continue
}
if !r.Form.Has(formField) && omitEmpty {
continue
}
ff := r.Form.Get(formField)
switch v := destFieldIntf.(type) {
case string, *string:
setVal(destFieldVal, ff != "" || o.acceptBlank, ff)
case int, uint, *int, *uint:
val, set, err := o.parseInt(ff)
if err != nil {
return err
}
setVal(destFieldVal, set, val)
case float64:
val, set, err := o.parseFloat64(ff)
if err != nil {
return err
}
setVal(destFieldVal, set, val)
case bool, *bool:
val, set, err := o.parseBool(ff)
if err != nil {
return err
}
setVal(destFieldVal, set, val)
case time.Time, *time.Time, jsontime.Time, *jsontime.Time: case time.Time, *time.Time, jsontime.Time, *jsontime.Time:
tval := r.Form.Get(formField) t, set, err := o.parseTime(ff)
t, set, err := o.parseTime(tval)
if err != nil { if err != nil {
return err return err
} }
setVal(f, set, v, t) setVal(destFieldVal, set, t)
case time.Duration, *time.Duration, jsontime.Duration, *jsontime.Duration: case time.Duration, *time.Duration, jsontime.Duration, *jsontime.Duration:
dval := r.Form.Get(formField) d, set, err := o.parseDuration(ff)
d, set, err := o.parseDuration(dval)
if err != nil { if err != nil {
return err return err
} }
setVal(f, set, v, d) setVal(destFieldVal, set, d)
case []int: case []int:
val := strings.Trim(r.Form.Get(formField), "[]") val := strings.Trim(ff, "[]")
if val == "" && o.acceptBlank { if val == "" && o.acceptBlank {
continue continue
} }
@ -249,7 +287,7 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
ar = append(ar, i) ar = append(ar, i)
} }
} }
f.Set(reflect.ValueOf(ar)) destFieldVal.Set(reflect.ValueOf(ar))
default: default:
panic(fmt.Errorf("unsupported type %T", v)) panic(fmt.Errorf("unsupported type %T", v))
} }
@ -258,48 +296,77 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
return nil return nil
} }
func setVal(setField reflect.Value, set bool, fv any, sv any) { func setVal(destFieldVal reflect.Value, set bool, src any) {
if !set { if !set {
return return
} }
rv := reflect.TypeOf(fv) destType := destFieldVal.Type()
svo := reflect.ValueOf(sv) srcVal := reflect.ValueOf(src)
if svo.CanConvert(rv) { if srcVal.Kind() == reflect.Ptr {
svo = svo.Convert(rv) srcVal = srcVal.Elem()
} }
if rv.Kind() == reflect.Ptr { if destType.Kind() == reflect.Ptr {
svo = svo.Addr() if !srcVal.CanAddr() {
if srcVal.CanConvert(destType.Elem()) {
srcVal = srcVal.Convert(destType.Elem())
}
copy := reflect.New(srcVal.Type())
copy.Elem().Set(srcVal)
srcVal = copy
}
} else if srcVal.CanConvert(destFieldVal.Type()) {
srcVal = srcVal.Convert(destFieldVal.Type())
} }
setField.Set(svo) destFieldVal.Set(srcVal)
} }
func Unmarshal(r *http.Request, dest any, opt ...Option) error { func Unmarshal(r *http.Request, dest any, opt ...Option) error {
o := options{} o := options{
maxMultipartMemory: MaxMultipartMemory,
}
for _, opt := range opt { for _, opt := range opt {
opt(&o) opt(&o)
} }
rv := reflect.ValueOf(dest) contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0]
if k := rv.Kind(); k == reflect.Ptr {
rv = rv.Elem() switch contentType {
} else { case "multipart/form-data":
return ErrNotPointer err := r.ParseMultipartForm(o.maxMultipartMemory)
if err != nil {
return fmt.Errorf("ParseForm: %w", err)
} }
if rv.Kind() != reflect.Struct { return o.unmarshalForm(r, dest)
return ErrNotStruct case "application/x-www-form-urlencoded":
}
if strings.HasPrefix(r.Header.Get("Content-Type"), "application/x-www-form-urlencoded") {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return fmt.Errorf("ParseForm: %w", err) return fmt.Errorf("ParseForm: %w", err)
} }
return o.unmarshalForm(r, dest)
case "application/json":
return json.NewDecoder(r.Body).Decode(dest)
} }
return o.iterFields(r, rv) return ErrContentType
}
func (o *options) unmarshalForm(r *http.Request, dest any) error {
destVal := reflect.ValueOf(dest)
if k := destVal.Kind(); k == reflect.Ptr {
destVal = destVal.Elem()
} else {
return ErrNotPointer
}
if destVal.Kind() != reflect.Struct {
return ErrNotStruct
}
return o.iterFields(r, destVal)
} }

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
"dynatron.me/x/stillbox/internal/common"
"dynatron.me/x/stillbox/internal/forms" "dynatron.me/x/stillbox/internal/forms"
"dynatron.me/x/stillbox/internal/jsontime" "dynatron.me/x/stillbox/internal/jsontime"
@ -54,6 +55,14 @@ type urlEncTestJT struct {
ScoreEnd jsontime.Time `json:"scoreEnd"` ScoreEnd jsontime.Time `json:"scoreEnd"`
} }
type ptrTestJT struct {
LookbackDays uint `form:"lookbackDays"`
HalfLife *jsontime.Duration `form:"halfLife"`
Recent *string `form:"recent"`
ScoreStart *jsontime.Time `form:"scoreStart"`
ScoreEnd jsontime.Time `form:"scoreEnd"`
}
var ( var (
UrlEncTest = urlEncTest{ UrlEncTest = urlEncTest{
LookbackDays: 7, LookbackDays: 7,
@ -69,6 +78,13 @@ var (
ScoreStart: jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC)), ScoreStart: jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC)),
} }
PtrTestJT = ptrTestJT{
LookbackDays: 7,
HalfLife: common.PtrTo(jsontime.Duration(30 * time.Minute)),
Recent: common.PtrTo("2h0m0s"),
ScoreStart: common.PtrTo(jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC))),
}
UrlEncTestJTLocal = urlEncTestJT{ UrlEncTestJTLocal = urlEncTestJT{
LookbackDays: 7, LookbackDays: 7,
HalfLife: jsontime.Duration(30 * time.Minute), HalfLife: jsontime.Duration(30 * time.Minute),
@ -122,7 +138,7 @@ func TestUnmarshal(t *testing.T) {
name string name string
r *http.Request r *http.Request
dest any dest any
compare any expect any
expectErr error expectErr error
opts []forms.Option opts []forms.Option
}{ }{
@ -130,21 +146,21 @@ func TestUnmarshal(t *testing.T) {
name: "base case", name: "base case",
r: makeRequest("call1.http"), r: makeRequest("call1.http"),
dest: &callUploadRequest{}, dest: &callUploadRequest{},
compare: &Call1, expect: &Call1,
opts: []forms.Option{forms.WithAcceptBlank()}, opts: []forms.Option{forms.WithAcceptBlank()},
}, },
{ {
name: "base case no accept blank", name: "base case no accept blank",
r: makeRequest("call1.http"), r: makeRequest("call1.http"),
dest: &callUploadRequest{}, dest: &callUploadRequest{},
compare: &Call1, expect: &Call1,
expectErr: errors.New(`parsebool(''): strconv.ParseBool: parsing "": invalid syntax`), expectErr: errors.New(`parsebool(''): strconv.ParseBool: parsing "": invalid syntax`),
}, },
{ {
name: "not a pointer", name: "not a pointer",
r: makeRequest("call1.http"), r: makeRequest("call1.http"),
dest: callUploadRequest{}, dest: callUploadRequest{},
compare: callUploadRequest{}, expect: callUploadRequest{},
expectErr: forms.ErrNotPointer, expectErr: forms.ErrNotPointer,
opts: []forms.Option{forms.WithAcceptBlank()}, opts: []forms.Option{forms.WithAcceptBlank()},
}, },
@ -152,7 +168,7 @@ func TestUnmarshal(t *testing.T) {
name: "not a struct", name: "not a struct",
r: makeRequest("call1.http"), r: makeRequest("call1.http"),
dest: &str, dest: &str,
compare: callUploadRequest{}, expect: callUploadRequest{},
expectErr: forms.ErrNotStruct, expectErr: forms.ErrNotStruct,
opts: []forms.Option{forms.WithAcceptBlank()}, opts: []forms.Option{forms.WithAcceptBlank()},
}, },
@ -160,21 +176,28 @@ func TestUnmarshal(t *testing.T) {
name: "url encoded", name: "url encoded",
r: makeRequest("urlenc.http"), r: makeRequest("urlenc.http"),
dest: &urlEncTest{}, dest: &urlEncTest{},
compare: &UrlEncTest, expect: &UrlEncTest,
expectErr: errors.New(`Could not find format for ""`), expectErr: errors.New(`Could not find format for ""`),
}, },
{ {
name: "url encoded accept blank", name: "url encoded accept blank",
r: makeRequest("urlenc.http"), r: makeRequest("urlenc.http"),
dest: &urlEncTest{}, dest: &urlEncTest{},
compare: &UrlEncTest, expect: &UrlEncTest,
opts: []forms.Option{forms.WithAcceptBlank()},
},
{
name: "url encoded accept blank pointer",
r: makeRequest("urlenc.http"),
dest: &ptrTestJT{},
expect: &PtrTestJT,
opts: []forms.Option{forms.WithAcceptBlank()}, opts: []forms.Option{forms.WithAcceptBlank()},
}, },
{ {
name: "url encoded jsontime", name: "url encoded jsontime",
r: makeRequest("urlenc.http"), r: makeRequest("urlenc.http"),
dest: &urlEncTestJT{}, dest: &urlEncTestJT{},
compare: &UrlEncTestJT, expect: &UrlEncTestJT,
expectErr: errors.New(`Could not find format for ""`), expectErr: errors.New(`Could not find format for ""`),
opts: []forms.Option{forms.WithTag("json")}, opts: []forms.Option{forms.WithTag("json")},
}, },
@ -182,21 +205,21 @@ func TestUnmarshal(t *testing.T) {
name: "url encoded jsontime with tz", name: "url encoded jsontime with tz",
r: makeRequest("urlenc.http"), r: makeRequest("urlenc.http"),
dest: &urlEncTestJT{}, dest: &urlEncTestJT{},
compare: &UrlEncTestJT, expect: &UrlEncTestJT,
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")}, opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")},
}, },
{ {
name: "url encoded jsontime with local", name: "url encoded jsontime with local",
r: makeRequest("urlenc.http"), r: makeRequest("urlenc.http"),
dest: &urlEncTestJT{}, dest: &urlEncTestJT{},
compare: &UrlEncTestJTLocal, expect: &UrlEncTestJTLocal,
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")}, opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")},
}, },
{ {
name: "sim real data", name: "sim real data",
r: makeRequest("urlenc2.http"), r: makeRequest("urlenc2.http"),
dest: &alerting.Simulation{}, dest: &alerting.Simulation{},
compare: realSim, expect: realSim,
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()}, opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()},
}, },
} }
@ -209,7 +232,7 @@ func TestUnmarshal(t *testing.T) {
assert.Contains(t, tc.expectErr.Error(), err.Error()) assert.Contains(t, tc.expectErr.Error(), err.Error())
} else { } else {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tc.compare, tc.dest) assert.Equal(t, tc.expect, tc.dest)
} }
}) })
} }

View file

@ -44,8 +44,7 @@ func (a *Alert) ToAddAlertParams() database.AddAlertParams {
} }
} }
// makeAlert creates a notification for later rendering by the template. // Make creates an alert for later rendering or storage.
// It takes a talkgroup Score as input.
func Make(ctx context.Context, store talkgroups.Store, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) { func Make(ctx context.Context, store talkgroups.Store, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) {
d := Alert{ d := Alert{
ID: uuid.New(), ID: uuid.New(),

View file

@ -2,7 +2,6 @@ package alerting
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -114,24 +113,15 @@ func (s *Simulation) Simulate(ctx context.Context) (trending.Scores[talkgroups.I
func (as *alerter) simulateHandler(w http.ResponseWriter, r *http.Request) { func (as *alerter) simulateHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
s := new(Simulation) s := new(Simulation)
switch r.Header.Get("Content-Type") {
case "application/json":
err := json.NewDecoder(r.Body).Decode(s)
if err != nil {
err = fmt.Errorf("simulate decode: %w", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
default:
err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime()) err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime())
if err != nil { if err != nil {
err = fmt.Errorf("simulate unmarshal: %w", err) err = fmt.Errorf("simulate unmarshal: %w", err)
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
}
err := s.verify() err = s.verify()
if err != nil { if err != nil {
err = fmt.Errorf("simulation profile verify: %w", err) err = fmt.Errorf("simulation profile verify: %w", err)
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)

View file

@ -1,127 +0,0 @@
package api
import (
"encoding/json"
"errors"
"net/http"
"dynatron.me/x/stillbox/pkg/talkgroups"
"github.com/go-chi/chi/v5"
"github.com/go-viper/mapstructure/v2"
"github.com/jackc/pgx/v5"
"github.com/rs/zerolog/log"
)
type API interface {
Subrouter() http.Handler
}
type api struct {
tgs talkgroups.Store
}
func New(tgs talkgroups.Store) API {
s := &api{
tgs: tgs,
}
return s
}
func (a *api) Subrouter() http.Handler {
r := chi.NewMux()
r.Get("/talkgroup/{system:\\d+}/{id:\\d+}", a.talkgroup)
r.Get("/talkgroup/{system:\\d+}/", a.talkgroup)
r.Get("/talkgroup/", a.talkgroup)
return r
}
var statusMapping = map[error]int{
talkgroups.ErrNotFound: http.StatusNotFound,
pgx.ErrNoRows: http.StatusNotFound,
}
func httpCode(err error) int {
c, ok := statusMapping[err]
if ok {
return c
}
for e, c := range statusMapping { // check if err wraps an error we know about
if errors.Is(err, e) {
return c
}
}
return http.StatusInternalServerError
}
func (a *api) writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) {
if err != nil {
log.Error().Str("path", r.URL.Path).Err(err).Msg("request failed")
http.Error(w, err.Error(), httpCode(err))
return
}
w.Header().Set("Content-Type", "application/json")
enc := json.NewEncoder(w)
err = enc.Encode(data)
if err != nil {
log.Error().Str("path", r.URL.Path).Err(err).Msg("response marshal failed")
http.Error(w, err.Error(), httpCode(err))
return
}
}
func decodeParams(d interface{}, r *http.Request) error {
params := chi.RouteContext(r.Context()).URLParams
m := make(map[string]string, len(params.Keys))
for i, k := range params.Keys {
m[k] = params.Values[i]
}
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Metadata: nil,
Result: d,
TagName: "param",
WeaklyTypedInput: true,
})
if err != nil {
return err
}
return dec.Decode(m)
}
func (a *api) badReq(w http.ResponseWriter, err error) {
http.Error(w, err.Error(), http.StatusBadRequest)
}
func (a *api) talkgroup(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
p := struct {
System *int `param:"system"`
ID *int `param:"id"`
}{}
err := decodeParams(&p, r)
if err != nil {
a.badReq(w, err)
return
}
var res interface{}
switch {
case p.System != nil && p.ID != nil:
res, err = a.tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID))
case p.System != nil:
res, err = a.tgs.SystemTGs(ctx, int32(*p.System))
default:
res, err = a.tgs.TGs(ctx, nil)
}
a.writeResponse(w, r, res, err)
}

View file

@ -5,7 +5,7 @@
</head> </head>
<body> <body>
<div> <div>
<form action="/login" method="POST"> <form action="/api/login" method="POST">
<label for="username">Username: </label> <label for="username">Username: </label>
<input type="text" name="username" /> <input type="text" name="username" />
<label for="password">Password: </label> <label for="password">Password: </label>

View file

@ -56,13 +56,13 @@ func NewClient(ctx context.Context, conf config.DB) (*DB, error) {
return db, nil return db, nil
} }
type DBCtxKey string type dBCtxKey string
const DBCTXKeyValue DBCtxKey = "dbctx" const DBCtxKey dBCtxKey = "dbctx"
// FromCtx returns the database handle from the provided Context. // FromCtx returns the database handle from the provided Context.
func FromCtx(ctx context.Context) *DB { func FromCtx(ctx context.Context) *DB {
c, ok := ctx.Value(DBCTXKeyValue).(*DB) c, ok := ctx.Value(DBCtxKey).(*DB)
if !ok { if !ok {
panic("no DB in context") panic("no DB in context")
} }
@ -72,7 +72,7 @@ func FromCtx(ctx context.Context) *DB {
// CtxWithDB returns a Context with the provided database handle. // CtxWithDB returns a Context with the provided database handle.
func CtxWithDB(ctx context.Context, conn *DB) context.Context { func CtxWithDB(ctx context.Context, conn *DB) context.Context {
return context.WithValue(ctx, DBCTXKeyValue, conn) return context.WithValue(ctx, DBCtxKey, conn)
} }
// IsNoRows is a convenience function that returns whether a returned error is a database // IsNoRows is a convenience function that returns whether a returned error is a database

View file

@ -9,3 +9,6 @@ func (g GetTalkgroupsWithLearnedRow) GetLearned() bool { retur
func (g GetTalkgroupsWithLearnedBySystemRow) GetTalkgroup() Talkgroup { return g.Talkgroup } func (g GetTalkgroupsWithLearnedBySystemRow) GetTalkgroup() Talkgroup { return g.Talkgroup }
func (g GetTalkgroupsWithLearnedBySystemRow) GetSystem() System { return g.System } func (g GetTalkgroupsWithLearnedBySystemRow) GetSystem() System { return g.System }
func (g GetTalkgroupsWithLearnedBySystemRow) GetLearned() bool { return g.Learned } func (g GetTalkgroupsWithLearnedBySystemRow) GetLearned() bool { return g.Learned }
func (g Talkgroup) GetTalkgroup() Talkgroup { return g }
func (g Talkgroup) GetSystem() System { return System{ID: int(g.SystemID)} }
func (g Talkgroup) GetLearned() bool { return false }

View file

@ -39,6 +39,7 @@ type Querier interface {
SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error
SetTalkgroupTags(ctx context.Context, sys int, tg int, tags []string) error SetTalkgroupTags(ctx context.Context, sys int, tg int, tags []string) error
UpdatePassword(ctx context.Context, username string, password string) error UpdatePassword(ctx context.Context, username string, password string) error
UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error)
} }
var _ Querier = (*Queries)(nil) var _ Querier = (*Queries)(nil)

View file

@ -472,3 +472,63 @@ func (q *Queries) SetTalkgroupTags(ctx context.Context, sys int, tg int, tags []
_, err := q.db.Exec(ctx, setTalkgroupTags, sys, tg, tags) _, err := q.db.Exec(ctx, setTalkgroupTags, sys, tg, tags)
return err return err
} }
const updateTalkgroup = `-- name: UpdateTalkgroup :one
UPDATE talkgroups
SET
name = COALESCE($1, name),
alpha_tag = COALESCE($2, alpha_tag),
tg_group = COALESCE($3, tg_group),
frequency = COALESCE($4, frequency),
metadata = COALESCE($5, metadata),
tags = COALESCE($6, tags),
alert = COALESCE($7, alert),
alert_config = COALESCE($8, alert_config),
weight = COALESCE($9, weight)
WHERE id = $10
RETURNING id, system_id, tgid, name, alpha_tag, tg_group, frequency, metadata, tags, alert, alert_config, weight
`
type UpdateTalkgroupParams struct {
Name *string `json:"name"`
AlphaTag *string `json:"alpha_tag"`
TgGroup *string `json:"tg_group"`
Frequency *int32 `json:"frequency"`
Metadata []byte `json:"metadata"`
Tags []string `json:"tags"`
Alert *bool `json:"alert"`
AlertConfig []byte `json:"alert_config"`
Weight *float32 `json:"weight"`
ID int64 `json:"id"`
}
func (q *Queries) UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) {
row := q.db.QueryRow(ctx, updateTalkgroup,
arg.Name,
arg.AlphaTag,
arg.TgGroup,
arg.Frequency,
arg.Metadata,
arg.Tags,
arg.Alert,
arg.AlertConfig,
arg.Weight,
arg.ID,
)
var i Talkgroup
err := row.Scan(
&i.ID,
&i.SystemID,
&i.Tgid,
&i.Name,
&i.AlphaTag,
&i.TgGroup,
&i.Frequency,
&i.Metadata,
&i.Tags,
&i.Alert,
&i.AlertConfig,
&i.Weight,
)
return i, err
}

132
pkg/rest/api.go Normal file
View file

@ -0,0 +1,132 @@
package rest
import (
"errors"
"net/http"
"dynatron.me/x/stillbox/pkg/talkgroups"
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
"github.com/go-viper/mapstructure/v2"
"github.com/jackc/pgx/v5"
"github.com/rs/zerolog/log"
)
type API interface {
Subrouter() http.Handler
}
type api struct {
}
func New() API {
s := new(api)
return s
}
func (a *api) Subrouter() http.Handler {
r := chi.NewMux()
r.Mount("/talkgroup", new(talkgroupAPI).Subrouter())
return r
}
type errResponse struct {
Err error `json:"-"`
Code int `json:"-"`
Error string `json:"error"`
}
func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error {
switch e.Code {
case http.StatusNotFound:
case http.StatusBadRequest:
default:
log.Error().Str("path", r.URL.Path).Err(e.Err).Int("code", e.Code).Str("msg", e.Error).Msg("request failed")
}
render.Status(r, e.Code)
return nil
}
func badRequest(err error) render.Renderer {
return &errResponse{
Err: err,
Code: http.StatusBadRequest,
Error: "Bad request",
}
}
func recordNotFound(err error) render.Renderer {
return &errResponse{
Err: err,
Code: http.StatusNotFound,
Error: "Record not found",
}
}
func internalError(err error) render.Renderer {
return &errResponse{
Err: err,
Code: http.StatusNotFound,
Error: "Internal server error",
}
}
type errResponder func(error) render.Renderer
var statusMapping = map[error]errResponder{
talkgroups.ErrNotFound: recordNotFound,
pgx.ErrNoRows: recordNotFound,
}
func autoError(err error) render.Renderer {
c, ok := statusMapping[err]
if ok {
c(err)
}
for e, c := range statusMapping { // check if err wraps an error we know about
if errors.Is(err, e) {
return c(err)
}
}
return internalError(err)
}
func wErr(w http.ResponseWriter, r *http.Request, v render.Renderer) {
err := render.Render(w, r, v)
if err != nil {
log.Error().Err(err).Msg("wErr render error")
}
}
func decodeParams(d interface{}, r *http.Request) error {
params := chi.RouteContext(r.Context()).URLParams
m := make(map[string]string, len(params.Keys))
for i, k := range params.Keys {
m[k] = params.Values[i]
}
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Metadata: nil,
Result: d,
TagName: "param",
WeaklyTypedInput: true,
})
if err != nil {
return err
}
return dec.Decode(m)
}
func respond(w http.ResponseWriter, r *http.Request, v interface{}) {
render.DefaultResponder(w, r, v)
}

108
pkg/rest/talkgroups.go Normal file
View file

@ -0,0 +1,108 @@
package rest
import (
"net/http"
"dynatron.me/x/stillbox/internal/forms"
"dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/talkgroups"
"github.com/go-chi/chi/v5"
)
type talkgroupAPI struct {
}
func (tga *talkgroupAPI) Subrouter() http.Handler {
r := chi.NewMux()
r.Get("/{system:\\d+}/{id:\\d+}", tga.get)
r.Put("/{system:\\d+}/{id:\\d+}", tga.put)
r.Get("/{system:\\d+}/", tga.get)
r.Get("/", tga.get)
return r
}
type tgParams struct {
System *int `param:"system"`
ID *int `param:"id"`
}
func (t tgParams) haveBoth() bool {
return t.System != nil && t.ID != nil
}
func (t tgParams) ToID() talkgroups.ID {
nilOr := func(i *int) uint32 {
if i == nil {
return 0
}
return uint32(*i)
}
return talkgroups.ID{
System: nilOr(t.System),
Talkgroup: nilOr(t.ID),
}
}
func (tga *talkgroupAPI) get(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
tgs := talkgroups.StoreFrom(ctx)
var p tgParams
err := decodeParams(&p, r)
if err != nil {
wErr(w, r, badRequest(err))
return
}
var res interface{}
switch {
case p.System != nil && p.ID != nil:
res, err = tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID))
case p.System != nil:
res, err = tgs.SystemTGs(ctx, int32(*p.System))
default:
res, err = tgs.TGs(ctx, nil)
}
if err != nil {
wErr(w, r, autoError(err))
return
}
respond(w, r, res)
}
func (tga *talkgroupAPI) put(w http.ResponseWriter, r *http.Request) {
var id tgParams
err := decodeParams(&id, r)
if err != nil {
wErr(w, r, badRequest(err))
return
}
ctx := r.Context()
tgs := talkgroups.StoreFrom(ctx)
input := database.UpdateTalkgroupParams{}
err = forms.Unmarshal(r, &input, forms.WithTag("json"), forms.WithAcceptBlank(), forms.WithOmitEmpty())
if err != nil {
wErr(w, r, badRequest(err))
return
}
input.ID = id.ToID().Pack()
record, err := tgs.UpdateTG(ctx, input)
if err != nil {
wErr(w, r, autoError(err))
return
}
respond(w, r, record)
}

View file

@ -4,14 +4,9 @@
package server package server
import ( import (
"net/http/pprof" "github.com/go-chi/chi/v5/middleware"
) )
func (s *Server) installPprof() { func (s *Server) installPprof() {
r := s.r s.r.Mount("/debug", middleware.Profiler())
r.HandleFunc("/debug/pprof/", pprof.Index)
r.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
r.HandleFunc("/debug/pprof/profile", pprof.Profile)
r.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
r.HandleFunc("/debug/pprof/trace", pprof.Trace)
} }

View file

@ -9,6 +9,7 @@ import (
"dynatron.me/x/stillbox/internal/version" "dynatron.me/x/stillbox/internal/version"
"dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/config"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/talkgroups"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/httprate" "github.com/go-chi/httprate"
@ -26,7 +27,8 @@ func (s *Server) setupRoutes() {
} }
r := s.r r := s.r
r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db)) r.Use(middleware.WithValue(database.DBCtxKey, s.db))
r.Use(middleware.WithValue(talkgroups.StoreCtxKey, s.tgs))
s.installPprof() s.installPprof()
@ -36,7 +38,7 @@ func (s *Server) setupRoutes() {
s.nex.PrivateRoutes(r) s.nex.PrivateRoutes(r)
s.auth.PrivateRoutes(r) s.auth.PrivateRoutes(r)
s.alerter.PrivateRoutes(r) s.alerter.PrivateRoutes(r)
r.Mount("/api", s.api.Subrouter()) r.Mount("/api", s.rest.Subrouter())
}) })
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {

View file

@ -7,12 +7,12 @@ import (
"time" "time"
"dynatron.me/x/stillbox/pkg/alerting" "dynatron.me/x/stillbox/pkg/alerting"
"dynatron.me/x/stillbox/pkg/api"
"dynatron.me/x/stillbox/pkg/auth" "dynatron.me/x/stillbox/pkg/auth"
"dynatron.me/x/stillbox/pkg/config" "dynatron.me/x/stillbox/pkg/config"
"dynatron.me/x/stillbox/pkg/database" "dynatron.me/x/stillbox/pkg/database"
"dynatron.me/x/stillbox/pkg/nexus" "dynatron.me/x/stillbox/pkg/nexus"
"dynatron.me/x/stillbox/pkg/notify" "dynatron.me/x/stillbox/pkg/notify"
"dynatron.me/x/stillbox/pkg/rest"
"dynatron.me/x/stillbox/pkg/sinks" "dynatron.me/x/stillbox/pkg/sinks"
"dynatron.me/x/stillbox/pkg/sources" "dynatron.me/x/stillbox/pkg/sources"
"dynatron.me/x/stillbox/pkg/talkgroups" "dynatron.me/x/stillbox/pkg/talkgroups"
@ -37,7 +37,7 @@ type Server struct {
notifier notify.Notifier notifier notify.Notifier
hup chan os.Signal hup chan os.Signal
tgs talkgroups.Store tgs talkgroups.Store
api api.API rest rest.API
} }
func New(ctx context.Context, cfg *config.Config) (*Server, error) { func New(ctx context.Context, cfg *config.Config) (*Server, error) {
@ -61,7 +61,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) {
} }
tgCache := talkgroups.NewCache() tgCache := talkgroups.NewCache()
api := api.New(tgCache) api := rest.New()
srv := &Server{ srv := &Server{
auth: authenticator, auth: authenticator,
@ -73,7 +73,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) {
alerter: alerting.New(cfg.Alerting, tgCache, alerting.WithNotifier(notifier)), alerter: alerting.New(cfg.Alerting, tgCache, alerting.WithNotifier(notifier)),
notifier: notifier, notifier: notifier,
tgs: tgCache, tgs: tgCache,
api: api, rest: api,
} }
srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db), true) srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db), true)

View file

@ -17,7 +17,15 @@ import (
type tgMap map[ID]*Talkgroup type tgMap map[ID]*Talkgroup
var (
ErrNotFound = errors.New("talkgroup not found")
ErrNoSuchSystem = errors.New("no such system")
)
type Store interface { type Store interface {
// UpdateTG updates a talkgroup record.
UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*Talkgroup, error)
// TG retrieves a Talkgroup from the Store. // TG retrieves a Talkgroup from the Store.
TG(ctx context.Context, tg ID) (*Talkgroup, error) TG(ctx context.Context, tg ID) (*Talkgroup, error)
@ -49,16 +57,16 @@ type Store interface {
HUP(*config.Config) HUP(*config.Config)
} }
type CtxStoreKeyT string type storeCtxKey string
const CtxStoreKey CtxStoreKeyT = "store" const StoreCtxKey storeCtxKey = "store"
func CtxWithStore(ctx context.Context, s Store) context.Context { func CtxWithStore(ctx context.Context, s Store) context.Context {
return context.WithValue(ctx, CtxStoreKey, s) return context.WithValue(ctx, StoreCtxKey, s)
} }
func StoreFrom(ctx context.Context) Store { func StoreFrom(ctx context.Context) Store {
s, ok := ctx.Value(CtxStoreKey).(Store) s, ok := ctx.Value(StoreCtxKey).(Store)
if !ok { if !ok {
return NewCache() return NewCache()
} }
@ -213,8 +221,6 @@ func (t *cache) Load(ctx context.Context, tgs []int64) error {
return nil return nil
} }
var ErrNotFound = errors.New("talkgroup not found")
func (t *cache) Weight(ctx context.Context, id ID, tm time.Time) float64 { func (t *cache) Weight(ctx context.Context, id ID, tm time.Time) float64 {
tg, err := t.TG(ctx, id) tg, err := t.TG(ctx, id)
if err != nil { if err != nil {
@ -290,3 +296,23 @@ func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool)
return n, has return n, has
} }
func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*Talkgroup, error) {
sysName, has := t.SystemName(ctx, int(Unpack(input.ID).System))
if !has {
return nil, ErrNoSuchSystem
}
tg, err := database.FromCtx(ctx).UpdateTalkgroup(ctx, input)
if err != nil {
return nil, err
}
record := &Talkgroup{
Talkgroup: tg,
System: database.System{ID: int(tg.SystemID), Name: sysName},
}
t.add(record)
return record, nil
}

View file

@ -49,6 +49,13 @@ func (t ID) Pack() int64 {
return int64((int64(t.System) << 32) | int64(t.Talkgroup)) return int64((int64(t.System) << 32) | int64(t.Talkgroup))
} }
func Unpack(id int64) ID {
return ID{
System: uint32(id >> 32),
Talkgroup: uint32(id & 0xffffffff),
}
}
func (t ID) String() string { func (t ID) String() string {
return fmt.Sprintf("%d:%d", t.System, t.Talkgroup) return fmt.Sprintf("%d:%d", t.System, t.Talkgroup)

View file

@ -104,6 +104,20 @@ FROM talkgroups_learned tgl
JOIN systems sys ON tgl.system_id = sys.id JOIN systems sys ON tgl.system_id = sys.id
WHERE ignored IS NOT TRUE; WHERE ignored IS NOT TRUE;
-- name: GetSystemName :one -- name: GetSystemName :one
SELECT name FROM systems WHERE id = sqlc.arg(system_id); SELECT name FROM systems WHERE id = sqlc.arg(system_id);
-- name: UpdateTalkgroup :one
UPDATE talkgroups
SET
name = COALESCE(sqlc.narg('name'), name),
alpha_tag = COALESCE(sqlc.narg('alpha_tag'), alpha_tag),
tg_group = COALESCE(sqlc.narg('tg_group'), tg_group),
frequency = COALESCE(sqlc.narg('frequency'), frequency),
metadata = COALESCE(sqlc.narg('metadata'), metadata),
tags = COALESCE(sqlc.narg('tags'), tags),
alert = COALESCE(sqlc.narg('alert'), alert),
alert_config = COALESCE(sqlc.narg('alert_config'), alert_config),
weight = COALESCE(sqlc.narg('weight'), weight)
WHERE id = @id
RETURNING *;