api restructure, support pointers in forms
Tests pass Support pointer fields
This commit is contained in:
parent
e4e819ee90
commit
e3a7313806
9 changed files with 273 additions and 168 deletions
|
@ -1,6 +1,7 @@
|
|||
package forms
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -18,6 +19,11 @@ import (
|
|||
var (
|
||||
ErrNotStruct = errors.New("destination is not a struct")
|
||||
ErrNotPointer = errors.New("destination is not a pointer")
|
||||
ErrContentType = errors.New("bad content type")
|
||||
)
|
||||
|
||||
const (
|
||||
MaxMultipartMemory int64 = 1024 * 1024 // 1MB
|
||||
)
|
||||
|
||||
type options struct {
|
||||
|
@ -25,6 +31,8 @@ type options struct {
|
|||
parseTimeIn *time.Location
|
||||
parseLocal bool
|
||||
acceptBlank bool
|
||||
maxMultipartMemory int64
|
||||
defaultOmitEmpty bool
|
||||
}
|
||||
|
||||
type Option func(*options)
|
||||
|
@ -53,6 +61,18 @@ func WithTag(t string) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithMaxMultipartSize(s int64) Option {
|
||||
return func(o *options) {
|
||||
o.maxMultipartMemory = s
|
||||
}
|
||||
}
|
||||
|
||||
func WithOmitEmpty() Option {
|
||||
return func(o *options) {
|
||||
o.defaultOmitEmpty = true
|
||||
}
|
||||
}
|
||||
|
||||
func (o *options) Tag() string {
|
||||
if o.tagOverride != nil {
|
||||
return *o.tagOverride
|
||||
|
@ -147,17 +167,19 @@ func (o *options) parseDuration(s string) (v time.Duration, set bool, err error)
|
|||
return
|
||||
}
|
||||
|
||||
func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
|
||||
rt := rv.Type()
|
||||
for i := 0; i < rv.NumField(); i++ {
|
||||
f := rv.Field(i)
|
||||
tf := rt.Field(i)
|
||||
if !tf.IsExported() && !tf.Anonymous {
|
||||
var typeOfByteSlice = reflect.TypeOf([]byte(nil))
|
||||
|
||||
func (o *options) iterFields(r *http.Request, destStruct reflect.Value) error {
|
||||
structType := destStruct.Type()
|
||||
for i := 0; i < destStruct.NumField(); i++ {
|
||||
destFieldVal := destStruct.Field(i)
|
||||
fieldType := structType.Field(i)
|
||||
if !fieldType.IsExported() && !fieldType.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
if f.Kind() == reflect.Struct && tf.Anonymous {
|
||||
err := o.iterFields(r, f)
|
||||
if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous {
|
||||
err := o.iterFields(r, destFieldVal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -165,51 +187,38 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
|
|||
|
||||
var tAr []string
|
||||
var formField string
|
||||
formTag, has := rt.Field(i).Tag.Lookup(o.Tag())
|
||||
var omitEmpty bool
|
||||
if o.defaultOmitEmpty {
|
||||
omitEmpty = true
|
||||
}
|
||||
|
||||
formTag, has := structType.Field(i).Tag.Lookup(o.Tag())
|
||||
if has {
|
||||
tAr = strings.Split(formTag, ",")
|
||||
formField = tAr[0]
|
||||
for _, v := range tAr[1:] {
|
||||
if v == "omitempty" {
|
||||
omitEmpty = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !has || formField == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
fi := f.Interface()
|
||||
destFieldIntf := destFieldVal.Interface()
|
||||
|
||||
switch v := fi.(type) {
|
||||
case string, *string:
|
||||
s := r.Form.Get(formField)
|
||||
setVal(f, s != "" || o.acceptBlank, v, s)
|
||||
case int, uint, *int, *uint:
|
||||
ff := r.Form.Get(formField)
|
||||
val, set, err := o.parseInt(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(f, set, v, val)
|
||||
case float64:
|
||||
ff := r.Form.Get(formField)
|
||||
val, set, err := o.parseFloat64(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(f, set, v, val)
|
||||
case bool, *bool:
|
||||
ff := r.Form.Get(formField)
|
||||
val, set, err := o.parseBool(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(f, set, v, val)
|
||||
case []byte:
|
||||
if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice {
|
||||
file, hdr, err := r.FormFile(formField)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get form file: %w", err)
|
||||
}
|
||||
|
||||
nameField, hasFilename := rt.Field(i).Tag.Lookup("filenameField")
|
||||
nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField")
|
||||
if hasFilename {
|
||||
fnf := rv.FieldByName(nameField)
|
||||
fnf := destStruct.FieldByName(nameField)
|
||||
if fnf == (reflect.Value{}) {
|
||||
panic(fmt.Errorf("filenameField '%s' does not exist", nameField))
|
||||
}
|
||||
|
@ -221,23 +230,52 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
|
|||
return fmt.Errorf("file read: %w", err)
|
||||
}
|
||||
|
||||
f.SetBytes(audioBytes)
|
||||
destFieldVal.SetBytes(audioBytes)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !r.Form.Has(formField) && omitEmpty {
|
||||
continue
|
||||
}
|
||||
|
||||
ff := r.Form.Get(formField)
|
||||
|
||||
switch v := destFieldIntf.(type) {
|
||||
case string, *string:
|
||||
setVal(destFieldVal, ff != "" || o.acceptBlank, ff)
|
||||
case int, uint, *int, *uint:
|
||||
val, set, err := o.parseInt(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, val)
|
||||
case float64:
|
||||
val, set, err := o.parseFloat64(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, val)
|
||||
case bool, *bool:
|
||||
val, set, err := o.parseBool(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, val)
|
||||
case time.Time, *time.Time, jsontime.Time, *jsontime.Time:
|
||||
tval := r.Form.Get(formField)
|
||||
t, set, err := o.parseTime(tval)
|
||||
t, set, err := o.parseTime(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(f, set, v, t)
|
||||
setVal(destFieldVal, set, t)
|
||||
case time.Duration, *time.Duration, jsontime.Duration, *jsontime.Duration:
|
||||
dval := r.Form.Get(formField)
|
||||
d, set, err := o.parseDuration(dval)
|
||||
d, set, err := o.parseDuration(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(f, set, v, d)
|
||||
setVal(destFieldVal, set, d)
|
||||
case []int:
|
||||
val := strings.Trim(r.Form.Get(formField), "[]")
|
||||
val := strings.Trim(ff, "[]")
|
||||
if val == "" && o.acceptBlank {
|
||||
continue
|
||||
}
|
||||
|
@ -249,7 +287,7 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
|
|||
ar = append(ar, i)
|
||||
}
|
||||
}
|
||||
f.Set(reflect.ValueOf(ar))
|
||||
destFieldVal.Set(reflect.ValueOf(ar))
|
||||
default:
|
||||
panic(fmt.Errorf("unsupported type %T", v))
|
||||
}
|
||||
|
@ -258,48 +296,77 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func setVal(setField reflect.Value, set bool, fv any, sv any) {
|
||||
func setVal(destFieldVal reflect.Value, set bool, src any) {
|
||||
if !set {
|
||||
return
|
||||
}
|
||||
|
||||
rv := reflect.TypeOf(fv)
|
||||
svo := reflect.ValueOf(sv)
|
||||
destType := destFieldVal.Type()
|
||||
srcVal := reflect.ValueOf(src)
|
||||
|
||||
if svo.CanConvert(rv) {
|
||||
svo = svo.Convert(rv)
|
||||
if srcVal.Kind() == reflect.Ptr {
|
||||
srcVal = srcVal.Elem()
|
||||
}
|
||||
|
||||
if rv.Kind() == reflect.Ptr {
|
||||
svo = svo.Addr()
|
||||
if destType.Kind() == reflect.Ptr {
|
||||
if !srcVal.CanAddr() {
|
||||
if srcVal.CanConvert(destType.Elem()) {
|
||||
srcVal = srcVal.Convert(destType.Elem())
|
||||
}
|
||||
copy := reflect.New(srcVal.Type())
|
||||
copy.Elem().Set(srcVal)
|
||||
srcVal = copy
|
||||
}
|
||||
} else if srcVal.CanConvert(destFieldVal.Type()) {
|
||||
srcVal = srcVal.Convert(destFieldVal.Type())
|
||||
}
|
||||
|
||||
setField.Set(svo)
|
||||
destFieldVal.Set(srcVal)
|
||||
}
|
||||
|
||||
func Unmarshal(r *http.Request, dest any, opt ...Option) error {
|
||||
o := options{}
|
||||
o := options{
|
||||
maxMultipartMemory: MaxMultipartMemory,
|
||||
}
|
||||
|
||||
for _, opt := range opt {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(dest)
|
||||
if k := rv.Kind(); k == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
} else {
|
||||
return ErrNotPointer
|
||||
contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0]
|
||||
|
||||
switch contentType {
|
||||
case "multipart/form-data":
|
||||
err := r.ParseMultipartForm(o.maxMultipartMemory)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ParseForm: %w", err)
|
||||
}
|
||||
|
||||
if rv.Kind() != reflect.Struct {
|
||||
return ErrNotStruct
|
||||
}
|
||||
|
||||
if strings.HasPrefix(r.Header.Get("Content-Type"), "application/x-www-form-urlencoded") {
|
||||
return o.unmarshalForm(r, dest)
|
||||
case "application/x-www-form-urlencoded":
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ParseForm: %w", err)
|
||||
}
|
||||
return o.unmarshalForm(r, dest)
|
||||
case "application/json":
|
||||
return json.NewDecoder(r.Body).Decode(dest)
|
||||
}
|
||||
|
||||
return o.iterFields(r, rv)
|
||||
return ErrContentType
|
||||
}
|
||||
|
||||
func (o *options) unmarshalForm(r *http.Request, dest any) error {
|
||||
destVal := reflect.ValueOf(dest)
|
||||
if k := destVal.Kind(); k == reflect.Ptr {
|
||||
destVal = destVal.Elem()
|
||||
} else {
|
||||
return ErrNotPointer
|
||||
}
|
||||
|
||||
if destVal.Kind() != reflect.Struct {
|
||||
return ErrNotStruct
|
||||
}
|
||||
|
||||
return o.iterFields(r, destVal)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"dynatron.me/x/stillbox/internal/common"
|
||||
"dynatron.me/x/stillbox/internal/forms"
|
||||
"dynatron.me/x/stillbox/internal/jsontime"
|
||||
|
||||
|
@ -54,6 +55,14 @@ type urlEncTestJT struct {
|
|||
ScoreEnd jsontime.Time `json:"scoreEnd"`
|
||||
}
|
||||
|
||||
type ptrTestJT struct {
|
||||
LookbackDays uint `form:"lookbackDays"`
|
||||
HalfLife *jsontime.Duration `form:"halfLife"`
|
||||
Recent *string `form:"recent"`
|
||||
ScoreStart *jsontime.Time `form:"scoreStart"`
|
||||
ScoreEnd jsontime.Time `form:"scoreEnd"`
|
||||
}
|
||||
|
||||
var (
|
||||
UrlEncTest = urlEncTest{
|
||||
LookbackDays: 7,
|
||||
|
@ -69,6 +78,13 @@ var (
|
|||
ScoreStart: jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC)),
|
||||
}
|
||||
|
||||
PtrTestJT = ptrTestJT{
|
||||
LookbackDays: 7,
|
||||
HalfLife: common.PtrTo(jsontime.Duration(30 * time.Minute)),
|
||||
Recent: common.PtrTo("2h0m0s"),
|
||||
ScoreStart: common.PtrTo(jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC))),
|
||||
}
|
||||
|
||||
UrlEncTestJTLocal = urlEncTestJT{
|
||||
LookbackDays: 7,
|
||||
HalfLife: jsontime.Duration(30 * time.Minute),
|
||||
|
@ -122,7 +138,7 @@ func TestUnmarshal(t *testing.T) {
|
|||
name string
|
||||
r *http.Request
|
||||
dest any
|
||||
compare any
|
||||
expect any
|
||||
expectErr error
|
||||
opts []forms.Option
|
||||
}{
|
||||
|
@ -130,21 +146,21 @@ func TestUnmarshal(t *testing.T) {
|
|||
name: "base case",
|
||||
r: makeRequest("call1.http"),
|
||||
dest: &callUploadRequest{},
|
||||
compare: &Call1,
|
||||
expect: &Call1,
|
||||
opts: []forms.Option{forms.WithAcceptBlank()},
|
||||
},
|
||||
{
|
||||
name: "base case no accept blank",
|
||||
r: makeRequest("call1.http"),
|
||||
dest: &callUploadRequest{},
|
||||
compare: &Call1,
|
||||
expect: &Call1,
|
||||
expectErr: errors.New(`parsebool(''): strconv.ParseBool: parsing "": invalid syntax`),
|
||||
},
|
||||
{
|
||||
name: "not a pointer",
|
||||
r: makeRequest("call1.http"),
|
||||
dest: callUploadRequest{},
|
||||
compare: callUploadRequest{},
|
||||
expect: callUploadRequest{},
|
||||
expectErr: forms.ErrNotPointer,
|
||||
opts: []forms.Option{forms.WithAcceptBlank()},
|
||||
},
|
||||
|
@ -152,7 +168,7 @@ func TestUnmarshal(t *testing.T) {
|
|||
name: "not a struct",
|
||||
r: makeRequest("call1.http"),
|
||||
dest: &str,
|
||||
compare: callUploadRequest{},
|
||||
expect: callUploadRequest{},
|
||||
expectErr: forms.ErrNotStruct,
|
||||
opts: []forms.Option{forms.WithAcceptBlank()},
|
||||
},
|
||||
|
@ -160,21 +176,28 @@ func TestUnmarshal(t *testing.T) {
|
|||
name: "url encoded",
|
||||
r: makeRequest("urlenc.http"),
|
||||
dest: &urlEncTest{},
|
||||
compare: &UrlEncTest,
|
||||
expect: &UrlEncTest,
|
||||
expectErr: errors.New(`Could not find format for ""`),
|
||||
},
|
||||
{
|
||||
name: "url encoded accept blank",
|
||||
r: makeRequest("urlenc.http"),
|
||||
dest: &urlEncTest{},
|
||||
compare: &UrlEncTest,
|
||||
expect: &UrlEncTest,
|
||||
opts: []forms.Option{forms.WithAcceptBlank()},
|
||||
},
|
||||
{
|
||||
name: "url encoded accept blank pointer",
|
||||
r: makeRequest("urlenc.http"),
|
||||
dest: &ptrTestJT{},
|
||||
expect: &PtrTestJT,
|
||||
opts: []forms.Option{forms.WithAcceptBlank()},
|
||||
},
|
||||
{
|
||||
name: "url encoded jsontime",
|
||||
r: makeRequest("urlenc.http"),
|
||||
dest: &urlEncTestJT{},
|
||||
compare: &UrlEncTestJT,
|
||||
expect: &UrlEncTestJT,
|
||||
expectErr: errors.New(`Could not find format for ""`),
|
||||
opts: []forms.Option{forms.WithTag("json")},
|
||||
},
|
||||
|
@ -182,21 +205,21 @@ func TestUnmarshal(t *testing.T) {
|
|||
name: "url encoded jsontime with tz",
|
||||
r: makeRequest("urlenc.http"),
|
||||
dest: &urlEncTestJT{},
|
||||
compare: &UrlEncTestJT,
|
||||
expect: &UrlEncTestJT,
|
||||
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")},
|
||||
},
|
||||
{
|
||||
name: "url encoded jsontime with local",
|
||||
r: makeRequest("urlenc.http"),
|
||||
dest: &urlEncTestJT{},
|
||||
compare: &UrlEncTestJTLocal,
|
||||
expect: &UrlEncTestJTLocal,
|
||||
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")},
|
||||
},
|
||||
{
|
||||
name: "sim real data",
|
||||
r: makeRequest("urlenc2.http"),
|
||||
dest: &alerting.Simulation{},
|
||||
compare: realSim,
|
||||
expect: realSim,
|
||||
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()},
|
||||
},
|
||||
}
|
||||
|
@ -209,7 +232,7 @@ func TestUnmarshal(t *testing.T) {
|
|||
assert.Contains(t, tc.expectErr.Error(), err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.compare, tc.dest)
|
||||
assert.Equal(t, tc.expect, tc.dest)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -44,8 +44,7 @@ func (a *Alert) ToAddAlertParams() database.AddAlertParams {
|
|||
}
|
||||
}
|
||||
|
||||
// makeAlert creates a notification for later rendering by the template.
|
||||
// It takes a talkgroup Score as input.
|
||||
// Make creates an alert for later rendering or storage.
|
||||
func Make(ctx context.Context, store talkgroups.Store, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) {
|
||||
d := Alert{
|
||||
ID: uuid.New(),
|
||||
|
|
|
@ -2,7 +2,6 @@ package alerting
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -114,24 +113,15 @@ func (s *Simulation) Simulate(ctx context.Context) (trending.Scores[talkgroups.I
|
|||
func (as *alerter) simulateHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
s := new(Simulation)
|
||||
switch r.Header.Get("Content-Type") {
|
||||
case "application/json":
|
||||
err := json.NewDecoder(r.Body).Decode(s)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("simulate decode: %w", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
default:
|
||||
|
||||
err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime())
|
||||
if err != nil {
|
||||
err = fmt.Errorf("simulate unmarshal: %w", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err := s.verify()
|
||||
err = s.verify()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("simulation profile verify: %w", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
|
|
|
@ -18,13 +18,10 @@ type API interface {
|
|||
}
|
||||
|
||||
type api struct {
|
||||
tgs talkgroups.Store
|
||||
}
|
||||
|
||||
func New(tgs talkgroups.Store) API {
|
||||
s := &api{
|
||||
tgs: tgs,
|
||||
}
|
||||
func New() API {
|
||||
s := new(api)
|
||||
|
||||
return s
|
||||
}
|
||||
|
@ -32,9 +29,8 @@ func New(tgs talkgroups.Store) API {
|
|||
func (a *api) Subrouter() http.Handler {
|
||||
r := chi.NewMux()
|
||||
|
||||
r.Get("/talkgroup/{system:\\d+}/{id:\\d+}", a.talkgroup)
|
||||
r.Get("/talkgroup/{system:\\d+}/", a.talkgroup)
|
||||
r.Get("/talkgroup/", a.talkgroup)
|
||||
r.Mount("/talkgroup", new(talkgroupAPI).routes())
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
|
@ -58,7 +54,7 @@ func httpCode(err error) int {
|
|||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
func (a *api) writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) {
|
||||
func writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) {
|
||||
if err != nil {
|
||||
log.Error().Str("path", r.URL.Path).Err(err).Msg("request failed")
|
||||
http.Error(w, err.Error(), httpCode(err))
|
||||
|
@ -96,32 +92,6 @@ func decodeParams(d interface{}, r *http.Request) error {
|
|||
return dec.Decode(m)
|
||||
}
|
||||
|
||||
func (a *api) badReq(w http.ResponseWriter, err error) {
|
||||
func badReq(w http.ResponseWriter, err error) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func (a *api) talkgroup(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
p := struct {
|
||||
System *int `param:"system"`
|
||||
ID *int `param:"id"`
|
||||
}{}
|
||||
|
||||
err := decodeParams(&p, r)
|
||||
if err != nil {
|
||||
a.badReq(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var res interface{}
|
||||
switch {
|
||||
case p.System != nil && p.ID != nil:
|
||||
res, err = a.tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID))
|
||||
case p.System != nil:
|
||||
res, err = a.tgs.SystemTGs(ctx, int32(*p.System))
|
||||
default:
|
||||
res, err = a.tgs.TGs(ctx, nil)
|
||||
}
|
||||
|
||||
a.writeResponse(w, r, res, err)
|
||||
}
|
||||
|
|
54
pkg/api/talkgroups.go
Normal file
54
pkg/api/talkgroups.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"dynatron.me/x/stillbox/internal/forms"
|
||||
"dynatron.me/x/stillbox/pkg/talkgroups"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type talkgroupAPI struct {
|
||||
}
|
||||
|
||||
func (tga *talkgroupAPI) routes() http.Handler {
|
||||
r := chi.NewMux()
|
||||
|
||||
r.Get("/{system:\\d+}/{id:\\d+}", tga.talkgroup)
|
||||
r.Put("/{system:\\d+}/{id:\\d+}", tga.putTalkgroup)
|
||||
r.Get("/{system:\\d+}/", tga.talkgroup)
|
||||
r.Get("/", tga.talkgroup)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (tga *talkgroupAPI) talkgroup(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
tgs := talkgroups.StoreFrom(ctx)
|
||||
p := struct {
|
||||
System *int `param:"system"`
|
||||
ID *int `param:"id"`
|
||||
}{}
|
||||
|
||||
err := decodeParams(&p, r)
|
||||
if err != nil {
|
||||
badReq(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var res interface{}
|
||||
switch {
|
||||
case p.System != nil && p.ID != nil:
|
||||
res, err = tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID))
|
||||
case p.System != nil:
|
||||
res, err = tgs.SystemTGs(ctx, int32(*p.System))
|
||||
default:
|
||||
res, err = tgs.TGs(ctx, nil)
|
||||
}
|
||||
|
||||
writeResponse(w, r, res, err)
|
||||
}
|
||||
|
||||
func (tga *talkgroupAPI) putTalkgroup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
|
@ -58,11 +58,11 @@ func NewClient(ctx context.Context, conf config.DB) (*DB, error) {
|
|||
|
||||
type DBCtxKey string
|
||||
|
||||
const DBCTXKeyValue DBCtxKey = "dbctx"
|
||||
const DBCtxKeyValue DBCtxKey = "dbctx"
|
||||
|
||||
// FromCtx returns the database handle from the provided Context.
|
||||
func FromCtx(ctx context.Context) *DB {
|
||||
c, ok := ctx.Value(DBCTXKeyValue).(*DB)
|
||||
c, ok := ctx.Value(DBCtxKeyValue).(*DB)
|
||||
if !ok {
|
||||
panic("no DB in context")
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ func FromCtx(ctx context.Context) *DB {
|
|||
|
||||
// CtxWithDB returns a Context with the provided database handle.
|
||||
func CtxWithDB(ctx context.Context, conn *DB) context.Context {
|
||||
return context.WithValue(ctx, DBCTXKeyValue, conn)
|
||||
return context.WithValue(ctx, DBCtxKeyValue, conn)
|
||||
}
|
||||
|
||||
// IsNoRows is a convenience function that returns whether a returned error is a database
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"dynatron.me/x/stillbox/internal/version"
|
||||
"dynatron.me/x/stillbox/pkg/config"
|
||||
"dynatron.me/x/stillbox/pkg/database"
|
||||
"dynatron.me/x/stillbox/pkg/talkgroups"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/httprate"
|
||||
|
@ -26,7 +27,8 @@ func (s *Server) setupRoutes() {
|
|||
}
|
||||
|
||||
r := s.r
|
||||
r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db))
|
||||
r.Use(middleware.WithValue(database.DBCtxKeyValue, s.db))
|
||||
r.Use(middleware.WithValue(talkgroups.StoreCtxKeyValue, s.tgs))
|
||||
|
||||
s.installPprof()
|
||||
|
||||
|
|
|
@ -49,16 +49,16 @@ type Store interface {
|
|||
HUP(*config.Config)
|
||||
}
|
||||
|
||||
type CtxStoreKeyT string
|
||||
type CtxStoreKey string
|
||||
|
||||
const CtxStoreKey CtxStoreKeyT = "store"
|
||||
const StoreCtxKeyValue CtxStoreKey = "store"
|
||||
|
||||
func CtxWithStore(ctx context.Context, s Store) context.Context {
|
||||
return context.WithValue(ctx, CtxStoreKey, s)
|
||||
return context.WithValue(ctx, StoreCtxKeyValue, s)
|
||||
}
|
||||
|
||||
func StoreFrom(ctx context.Context) Store {
|
||||
s, ok := ctx.Value(CtxStoreKey).(Store)
|
||||
s, ok := ctx.Value(StoreCtxKeyValue).(Store)
|
||||
if !ok {
|
||||
return NewCache()
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue