Compare commits
9 commits
e4e819ee90
...
657c00e326
Author | SHA1 | Date | |
---|---|---|---|
657c00e326 | |||
fb3fb4eeab | |||
759c274950 | |||
cecbeb78fe | |||
e97c9ced0e | |||
6e1640e4b4 | |||
f76db949e0 | |||
9046e346b1 | |||
e3a7313806 |
18 changed files with 576 additions and 276 deletions
|
@ -1,6 +1,7 @@
|
||||||
package forms
|
package forms
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -16,15 +17,22 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrNotStruct = errors.New("destination is not a struct")
|
ErrNotStruct = errors.New("destination is not a struct")
|
||||||
ErrNotPointer = errors.New("destination is not a pointer")
|
ErrNotPointer = errors.New("destination is not a pointer")
|
||||||
|
ErrContentType = errors.New("bad content type")
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MaxMultipartMemory int64 = 1024 * 1024 // 1MB
|
||||||
)
|
)
|
||||||
|
|
||||||
type options struct {
|
type options struct {
|
||||||
tagOverride *string
|
tagOverride *string
|
||||||
parseTimeIn *time.Location
|
parseTimeIn *time.Location
|
||||||
parseLocal bool
|
parseLocal bool
|
||||||
acceptBlank bool
|
acceptBlank bool
|
||||||
|
maxMultipartMemory int64
|
||||||
|
defaultOmitEmpty bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Option func(*options)
|
type Option func(*options)
|
||||||
|
@ -53,6 +61,18 @@ func WithTag(t string) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithMaxMultipartSize(s int64) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.maxMultipartMemory = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithOmitEmpty() Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.defaultOmitEmpty = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (o *options) Tag() string {
|
func (o *options) Tag() string {
|
||||||
if o.tagOverride != nil {
|
if o.tagOverride != nil {
|
||||||
return *o.tagOverride
|
return *o.tagOverride
|
||||||
|
@ -147,17 +167,19 @@ func (o *options) parseDuration(s string) (v time.Duration, set bool, err error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
|
var typeOfByteSlice = reflect.TypeOf([]byte(nil))
|
||||||
rt := rv.Type()
|
|
||||||
for i := 0; i < rv.NumField(); i++ {
|
func (o *options) iterFields(r *http.Request, destStruct reflect.Value) error {
|
||||||
f := rv.Field(i)
|
structType := destStruct.Type()
|
||||||
tf := rt.Field(i)
|
for i := 0; i < destStruct.NumField(); i++ {
|
||||||
if !tf.IsExported() && !tf.Anonymous {
|
destFieldVal := destStruct.Field(i)
|
||||||
|
fieldType := structType.Field(i)
|
||||||
|
if !fieldType.IsExported() && !fieldType.Anonymous {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.Kind() == reflect.Struct && tf.Anonymous {
|
if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous {
|
||||||
err := o.iterFields(r, f)
|
err := o.iterFields(r, destFieldVal)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -165,51 +187,38 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
|
||||||
|
|
||||||
var tAr []string
|
var tAr []string
|
||||||
var formField string
|
var formField string
|
||||||
formTag, has := rt.Field(i).Tag.Lookup(o.Tag())
|
var omitEmpty bool
|
||||||
|
if o.defaultOmitEmpty {
|
||||||
|
omitEmpty = true
|
||||||
|
}
|
||||||
|
|
||||||
|
formTag, has := structType.Field(i).Tag.Lookup(o.Tag())
|
||||||
if has {
|
if has {
|
||||||
tAr = strings.Split(formTag, ",")
|
tAr = strings.Split(formTag, ",")
|
||||||
formField = tAr[0]
|
formField = tAr[0]
|
||||||
|
for _, v := range tAr[1:] {
|
||||||
|
if v == "omitempty" {
|
||||||
|
omitEmpty = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !has || formField == "-" {
|
if !has || formField == "-" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
fi := f.Interface()
|
destFieldIntf := destFieldVal.Interface()
|
||||||
|
|
||||||
switch v := fi.(type) {
|
if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice {
|
||||||
case string, *string:
|
|
||||||
s := r.Form.Get(formField)
|
|
||||||
setVal(f, s != "" || o.acceptBlank, v, s)
|
|
||||||
case int, uint, *int, *uint:
|
|
||||||
ff := r.Form.Get(formField)
|
|
||||||
val, set, err := o.parseInt(ff)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
setVal(f, set, v, val)
|
|
||||||
case float64:
|
|
||||||
ff := r.Form.Get(formField)
|
|
||||||
val, set, err := o.parseFloat64(ff)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
setVal(f, set, v, val)
|
|
||||||
case bool, *bool:
|
|
||||||
ff := r.Form.Get(formField)
|
|
||||||
val, set, err := o.parseBool(ff)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
setVal(f, set, v, val)
|
|
||||||
case []byte:
|
|
||||||
file, hdr, err := r.FormFile(formField)
|
file, hdr, err := r.FormFile(formField)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get form file: %w", err)
|
return fmt.Errorf("get form file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
nameField, hasFilename := rt.Field(i).Tag.Lookup("filenameField")
|
nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField")
|
||||||
if hasFilename {
|
if hasFilename {
|
||||||
fnf := rv.FieldByName(nameField)
|
fnf := destStruct.FieldByName(nameField)
|
||||||
if fnf == (reflect.Value{}) {
|
if fnf == (reflect.Value{}) {
|
||||||
panic(fmt.Errorf("filenameField '%s' does not exist", nameField))
|
panic(fmt.Errorf("filenameField '%s' does not exist", nameField))
|
||||||
}
|
}
|
||||||
|
@ -221,23 +230,52 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
|
||||||
return fmt.Errorf("file read: %w", err)
|
return fmt.Errorf("file read: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.SetBytes(audioBytes)
|
destFieldVal.SetBytes(audioBytes)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !r.Form.Has(formField) && omitEmpty {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ff := r.Form.Get(formField)
|
||||||
|
|
||||||
|
switch v := destFieldIntf.(type) {
|
||||||
|
case string, *string:
|
||||||
|
setVal(destFieldVal, ff != "" || o.acceptBlank, ff)
|
||||||
|
case int, uint, *int, *uint:
|
||||||
|
val, set, err := o.parseInt(ff)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
setVal(destFieldVal, set, val)
|
||||||
|
case float64:
|
||||||
|
val, set, err := o.parseFloat64(ff)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
setVal(destFieldVal, set, val)
|
||||||
|
case bool, *bool:
|
||||||
|
val, set, err := o.parseBool(ff)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
setVal(destFieldVal, set, val)
|
||||||
case time.Time, *time.Time, jsontime.Time, *jsontime.Time:
|
case time.Time, *time.Time, jsontime.Time, *jsontime.Time:
|
||||||
tval := r.Form.Get(formField)
|
t, set, err := o.parseTime(ff)
|
||||||
t, set, err := o.parseTime(tval)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
setVal(f, set, v, t)
|
setVal(destFieldVal, set, t)
|
||||||
case time.Duration, *time.Duration, jsontime.Duration, *jsontime.Duration:
|
case time.Duration, *time.Duration, jsontime.Duration, *jsontime.Duration:
|
||||||
dval := r.Form.Get(formField)
|
d, set, err := o.parseDuration(ff)
|
||||||
d, set, err := o.parseDuration(dval)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
setVal(f, set, v, d)
|
setVal(destFieldVal, set, d)
|
||||||
case []int:
|
case []int:
|
||||||
val := strings.Trim(r.Form.Get(formField), "[]")
|
val := strings.Trim(ff, "[]")
|
||||||
if val == "" && o.acceptBlank {
|
if val == "" && o.acceptBlank {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -249,7 +287,7 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
|
||||||
ar = append(ar, i)
|
ar = append(ar, i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.Set(reflect.ValueOf(ar))
|
destFieldVal.Set(reflect.ValueOf(ar))
|
||||||
default:
|
default:
|
||||||
panic(fmt.Errorf("unsupported type %T", v))
|
panic(fmt.Errorf("unsupported type %T", v))
|
||||||
}
|
}
|
||||||
|
@ -258,48 +296,77 @@ func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setVal(setField reflect.Value, set bool, fv any, sv any) {
|
func setVal(destFieldVal reflect.Value, set bool, src any) {
|
||||||
if !set {
|
if !set {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rv := reflect.TypeOf(fv)
|
destType := destFieldVal.Type()
|
||||||
svo := reflect.ValueOf(sv)
|
srcVal := reflect.ValueOf(src)
|
||||||
|
|
||||||
if svo.CanConvert(rv) {
|
if srcVal.Kind() == reflect.Ptr {
|
||||||
svo = svo.Convert(rv)
|
srcVal = srcVal.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if rv.Kind() == reflect.Ptr {
|
if destType.Kind() == reflect.Ptr {
|
||||||
svo = svo.Addr()
|
if !srcVal.CanAddr() {
|
||||||
|
if srcVal.CanConvert(destType.Elem()) {
|
||||||
|
srcVal = srcVal.Convert(destType.Elem())
|
||||||
|
}
|
||||||
|
copy := reflect.New(srcVal.Type())
|
||||||
|
copy.Elem().Set(srcVal)
|
||||||
|
srcVal = copy
|
||||||
|
}
|
||||||
|
} else if srcVal.CanConvert(destFieldVal.Type()) {
|
||||||
|
srcVal = srcVal.Convert(destFieldVal.Type())
|
||||||
}
|
}
|
||||||
|
|
||||||
setField.Set(svo)
|
destFieldVal.Set(srcVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Unmarshal(r *http.Request, dest any, opt ...Option) error {
|
func Unmarshal(r *http.Request, dest any, opt ...Option) error {
|
||||||
o := options{}
|
o := options{
|
||||||
|
maxMultipartMemory: MaxMultipartMemory,
|
||||||
|
}
|
||||||
|
|
||||||
for _, opt := range opt {
|
for _, opt := range opt {
|
||||||
opt(&o)
|
opt(&o)
|
||||||
}
|
}
|
||||||
|
|
||||||
rv := reflect.ValueOf(dest)
|
contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0]
|
||||||
if k := rv.Kind(); k == reflect.Ptr {
|
|
||||||
rv = rv.Elem()
|
|
||||||
} else {
|
|
||||||
return ErrNotPointer
|
|
||||||
}
|
|
||||||
|
|
||||||
if rv.Kind() != reflect.Struct {
|
switch contentType {
|
||||||
return ErrNotStruct
|
case "multipart/form-data":
|
||||||
}
|
err := r.ParseMultipartForm(o.maxMultipartMemory)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ParseForm: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(r.Header.Get("Content-Type"), "application/x-www-form-urlencoded") {
|
return o.unmarshalForm(r, dest)
|
||||||
|
case "application/x-www-form-urlencoded":
|
||||||
err := r.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("ParseForm: %w", err)
|
return fmt.Errorf("ParseForm: %w", err)
|
||||||
}
|
}
|
||||||
|
return o.unmarshalForm(r, dest)
|
||||||
|
case "application/json":
|
||||||
|
return json.NewDecoder(r.Body).Decode(dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
return o.iterFields(r, rv)
|
return ErrContentType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *options) unmarshalForm(r *http.Request, dest any) error {
|
||||||
|
destVal := reflect.ValueOf(dest)
|
||||||
|
if k := destVal.Kind(); k == reflect.Ptr {
|
||||||
|
destVal = destVal.Elem()
|
||||||
|
} else {
|
||||||
|
return ErrNotPointer
|
||||||
|
}
|
||||||
|
|
||||||
|
if destVal.Kind() != reflect.Struct {
|
||||||
|
return ErrNotStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
return o.iterFields(r, destVal)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"dynatron.me/x/stillbox/internal/common"
|
||||||
"dynatron.me/x/stillbox/internal/forms"
|
"dynatron.me/x/stillbox/internal/forms"
|
||||||
"dynatron.me/x/stillbox/internal/jsontime"
|
"dynatron.me/x/stillbox/internal/jsontime"
|
||||||
|
|
||||||
|
@ -54,6 +55,14 @@ type urlEncTestJT struct {
|
||||||
ScoreEnd jsontime.Time `json:"scoreEnd"`
|
ScoreEnd jsontime.Time `json:"scoreEnd"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ptrTestJT struct {
|
||||||
|
LookbackDays uint `form:"lookbackDays"`
|
||||||
|
HalfLife *jsontime.Duration `form:"halfLife"`
|
||||||
|
Recent *string `form:"recent"`
|
||||||
|
ScoreStart *jsontime.Time `form:"scoreStart"`
|
||||||
|
ScoreEnd jsontime.Time `form:"scoreEnd"`
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
UrlEncTest = urlEncTest{
|
UrlEncTest = urlEncTest{
|
||||||
LookbackDays: 7,
|
LookbackDays: 7,
|
||||||
|
@ -69,6 +78,13 @@ var (
|
||||||
ScoreStart: jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC)),
|
ScoreStart: jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PtrTestJT = ptrTestJT{
|
||||||
|
LookbackDays: 7,
|
||||||
|
HalfLife: common.PtrTo(jsontime.Duration(30 * time.Minute)),
|
||||||
|
Recent: common.PtrTo("2h0m0s"),
|
||||||
|
ScoreStart: common.PtrTo(jsontime.Time(time.Date(2024, time.October, 28, 9, 25, 0, 0, time.UTC))),
|
||||||
|
}
|
||||||
|
|
||||||
UrlEncTestJTLocal = urlEncTestJT{
|
UrlEncTestJTLocal = urlEncTestJT{
|
||||||
LookbackDays: 7,
|
LookbackDays: 7,
|
||||||
HalfLife: jsontime.Duration(30 * time.Minute),
|
HalfLife: jsontime.Duration(30 * time.Minute),
|
||||||
|
@ -122,29 +138,29 @@ func TestUnmarshal(t *testing.T) {
|
||||||
name string
|
name string
|
||||||
r *http.Request
|
r *http.Request
|
||||||
dest any
|
dest any
|
||||||
compare any
|
expect any
|
||||||
expectErr error
|
expectErr error
|
||||||
opts []forms.Option
|
opts []forms.Option
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "base case",
|
name: "base case",
|
||||||
r: makeRequest("call1.http"),
|
r: makeRequest("call1.http"),
|
||||||
dest: &callUploadRequest{},
|
dest: &callUploadRequest{},
|
||||||
compare: &Call1,
|
expect: &Call1,
|
||||||
opts: []forms.Option{forms.WithAcceptBlank()},
|
opts: []forms.Option{forms.WithAcceptBlank()},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "base case no accept blank",
|
name: "base case no accept blank",
|
||||||
r: makeRequest("call1.http"),
|
r: makeRequest("call1.http"),
|
||||||
dest: &callUploadRequest{},
|
dest: &callUploadRequest{},
|
||||||
compare: &Call1,
|
expect: &Call1,
|
||||||
expectErr: errors.New(`parsebool(''): strconv.ParseBool: parsing "": invalid syntax`),
|
expectErr: errors.New(`parsebool(''): strconv.ParseBool: parsing "": invalid syntax`),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "not a pointer",
|
name: "not a pointer",
|
||||||
r: makeRequest("call1.http"),
|
r: makeRequest("call1.http"),
|
||||||
dest: callUploadRequest{},
|
dest: callUploadRequest{},
|
||||||
compare: callUploadRequest{},
|
expect: callUploadRequest{},
|
||||||
expectErr: forms.ErrNotPointer,
|
expectErr: forms.ErrNotPointer,
|
||||||
opts: []forms.Option{forms.WithAcceptBlank()},
|
opts: []forms.Option{forms.WithAcceptBlank()},
|
||||||
},
|
},
|
||||||
|
@ -152,7 +168,7 @@ func TestUnmarshal(t *testing.T) {
|
||||||
name: "not a struct",
|
name: "not a struct",
|
||||||
r: makeRequest("call1.http"),
|
r: makeRequest("call1.http"),
|
||||||
dest: &str,
|
dest: &str,
|
||||||
compare: callUploadRequest{},
|
expect: callUploadRequest{},
|
||||||
expectErr: forms.ErrNotStruct,
|
expectErr: forms.ErrNotStruct,
|
||||||
opts: []forms.Option{forms.WithAcceptBlank()},
|
opts: []forms.Option{forms.WithAcceptBlank()},
|
||||||
},
|
},
|
||||||
|
@ -160,44 +176,51 @@ func TestUnmarshal(t *testing.T) {
|
||||||
name: "url encoded",
|
name: "url encoded",
|
||||||
r: makeRequest("urlenc.http"),
|
r: makeRequest("urlenc.http"),
|
||||||
dest: &urlEncTest{},
|
dest: &urlEncTest{},
|
||||||
compare: &UrlEncTest,
|
expect: &UrlEncTest,
|
||||||
expectErr: errors.New(`Could not find format for ""`),
|
expectErr: errors.New(`Could not find format for ""`),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "url encoded accept blank",
|
name: "url encoded accept blank",
|
||||||
r: makeRequest("urlenc.http"),
|
r: makeRequest("urlenc.http"),
|
||||||
dest: &urlEncTest{},
|
dest: &urlEncTest{},
|
||||||
compare: &UrlEncTest,
|
expect: &UrlEncTest,
|
||||||
opts: []forms.Option{forms.WithAcceptBlank()},
|
opts: []forms.Option{forms.WithAcceptBlank()},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "url encoded accept blank pointer",
|
||||||
|
r: makeRequest("urlenc.http"),
|
||||||
|
dest: &ptrTestJT{},
|
||||||
|
expect: &PtrTestJT,
|
||||||
|
opts: []forms.Option{forms.WithAcceptBlank()},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "url encoded jsontime",
|
name: "url encoded jsontime",
|
||||||
r: makeRequest("urlenc.http"),
|
r: makeRequest("urlenc.http"),
|
||||||
dest: &urlEncTestJT{},
|
dest: &urlEncTestJT{},
|
||||||
compare: &UrlEncTestJT,
|
expect: &UrlEncTestJT,
|
||||||
expectErr: errors.New(`Could not find format for ""`),
|
expectErr: errors.New(`Could not find format for ""`),
|
||||||
opts: []forms.Option{forms.WithTag("json")},
|
opts: []forms.Option{forms.WithTag("json")},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "url encoded jsontime with tz",
|
name: "url encoded jsontime with tz",
|
||||||
r: makeRequest("urlenc.http"),
|
r: makeRequest("urlenc.http"),
|
||||||
dest: &urlEncTestJT{},
|
dest: &urlEncTestJT{},
|
||||||
compare: &UrlEncTestJT,
|
expect: &UrlEncTestJT,
|
||||||
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")},
|
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseTimeInTZ(time.UTC), forms.WithTag("json")},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "url encoded jsontime with local",
|
name: "url encoded jsontime with local",
|
||||||
r: makeRequest("urlenc.http"),
|
r: makeRequest("urlenc.http"),
|
||||||
dest: &urlEncTestJT{},
|
dest: &urlEncTestJT{},
|
||||||
compare: &UrlEncTestJTLocal,
|
expect: &UrlEncTestJTLocal,
|
||||||
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")},
|
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime(), forms.WithTag("json")},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "sim real data",
|
name: "sim real data",
|
||||||
r: makeRequest("urlenc2.http"),
|
r: makeRequest("urlenc2.http"),
|
||||||
dest: &alerting.Simulation{},
|
dest: &alerting.Simulation{},
|
||||||
compare: realSim,
|
expect: realSim,
|
||||||
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()},
|
opts: []forms.Option{forms.WithAcceptBlank(), forms.WithParseLocalTime()},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,7 +232,7 @@ func TestUnmarshal(t *testing.T) {
|
||||||
assert.Contains(t, tc.expectErr.Error(), err.Error())
|
assert.Contains(t, tc.expectErr.Error(), err.Error())
|
||||||
} else {
|
} else {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, tc.compare, tc.dest)
|
assert.Equal(t, tc.expect, tc.dest)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,8 +44,7 @@ func (a *Alert) ToAddAlertParams() database.AddAlertParams {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeAlert creates a notification for later rendering by the template.
|
// Make creates an alert for later rendering or storage.
|
||||||
// It takes a talkgroup Score as input.
|
|
||||||
func Make(ctx context.Context, store talkgroups.Store, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) {
|
func Make(ctx context.Context, store talkgroups.Store, score trending.Score[talkgroups.ID], origScore float64) (Alert, error) {
|
||||||
d := Alert{
|
d := Alert{
|
||||||
ID: uuid.New(),
|
ID: uuid.New(),
|
||||||
|
|
|
@ -2,7 +2,6 @@ package alerting
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -114,24 +113,15 @@ func (s *Simulation) Simulate(ctx context.Context) (trending.Scores[talkgroups.I
|
||||||
func (as *alerter) simulateHandler(w http.ResponseWriter, r *http.Request) {
|
func (as *alerter) simulateHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
s := new(Simulation)
|
s := new(Simulation)
|
||||||
switch r.Header.Get("Content-Type") {
|
|
||||||
case "application/json":
|
err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime())
|
||||||
err := json.NewDecoder(r.Body).Decode(s)
|
if err != nil {
|
||||||
if err != nil {
|
err = fmt.Errorf("simulate unmarshal: %w", err)
|
||||||
err = fmt.Errorf("simulate decode: %w", err)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
err := forms.Unmarshal(r, s, forms.WithAcceptBlank(), forms.WithParseLocalTime())
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("simulate unmarshal: %w", err)
|
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := s.verify()
|
err = s.verify()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("simulation profile verify: %w", err)
|
err = fmt.Errorf("simulation profile verify: %w", err)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
|
127
pkg/api/api.go
127
pkg/api/api.go
|
@ -1,127 +0,0 @@
|
||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"dynatron.me/x/stillbox/pkg/talkgroups"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/go-viper/mapstructure/v2"
|
|
||||||
"github.com/jackc/pgx/v5"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
type API interface {
|
|
||||||
Subrouter() http.Handler
|
|
||||||
}
|
|
||||||
|
|
||||||
type api struct {
|
|
||||||
tgs talkgroups.Store
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(tgs talkgroups.Store) API {
|
|
||||||
s := &api{
|
|
||||||
tgs: tgs,
|
|
||||||
}
|
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *api) Subrouter() http.Handler {
|
|
||||||
r := chi.NewMux()
|
|
||||||
|
|
||||||
r.Get("/talkgroup/{system:\\d+}/{id:\\d+}", a.talkgroup)
|
|
||||||
r.Get("/talkgroup/{system:\\d+}/", a.talkgroup)
|
|
||||||
r.Get("/talkgroup/", a.talkgroup)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
var statusMapping = map[error]int{
|
|
||||||
talkgroups.ErrNotFound: http.StatusNotFound,
|
|
||||||
pgx.ErrNoRows: http.StatusNotFound,
|
|
||||||
}
|
|
||||||
|
|
||||||
func httpCode(err error) int {
|
|
||||||
c, ok := statusMapping[err]
|
|
||||||
if ok {
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
for e, c := range statusMapping { // check if err wraps an error we know about
|
|
||||||
if errors.Is(err, e) {
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return http.StatusInternalServerError
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *api) writeResponse(w http.ResponseWriter, r *http.Request, data interface{}, err error) {
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Str("path", r.URL.Path).Err(err).Msg("request failed")
|
|
||||||
http.Error(w, err.Error(), httpCode(err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
enc := json.NewEncoder(w)
|
|
||||||
err = enc.Encode(data)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Str("path", r.URL.Path).Err(err).Msg("response marshal failed")
|
|
||||||
http.Error(w, err.Error(), httpCode(err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeParams(d interface{}, r *http.Request) error {
|
|
||||||
params := chi.RouteContext(r.Context()).URLParams
|
|
||||||
m := make(map[string]string, len(params.Keys))
|
|
||||||
|
|
||||||
for i, k := range params.Keys {
|
|
||||||
m[k] = params.Values[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
|
||||||
Metadata: nil,
|
|
||||||
Result: d,
|
|
||||||
TagName: "param",
|
|
||||||
WeaklyTypedInput: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return dec.Decode(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *api) badReq(w http.ResponseWriter, err error) {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *api) talkgroup(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := r.Context()
|
|
||||||
p := struct {
|
|
||||||
System *int `param:"system"`
|
|
||||||
ID *int `param:"id"`
|
|
||||||
}{}
|
|
||||||
|
|
||||||
err := decodeParams(&p, r)
|
|
||||||
if err != nil {
|
|
||||||
a.badReq(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var res interface{}
|
|
||||||
switch {
|
|
||||||
case p.System != nil && p.ID != nil:
|
|
||||||
res, err = a.tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID))
|
|
||||||
case p.System != nil:
|
|
||||||
res, err = a.tgs.SystemTGs(ctx, int32(*p.System))
|
|
||||||
default:
|
|
||||||
res, err = a.tgs.TGs(ctx, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
a.writeResponse(w, r, res, err)
|
|
||||||
}
|
|
|
@ -5,7 +5,7 @@
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div>
|
<div>
|
||||||
<form action="/login" method="POST">
|
<form action="/api/login" method="POST">
|
||||||
<label for="username">Username: </label>
|
<label for="username">Username: </label>
|
||||||
<input type="text" name="username" />
|
<input type="text" name="username" />
|
||||||
<label for="password">Password: </label>
|
<label for="password">Password: </label>
|
||||||
|
|
|
@ -56,13 +56,13 @@ func NewClient(ctx context.Context, conf config.DB) (*DB, error) {
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type DBCtxKey string
|
type dBCtxKey string
|
||||||
|
|
||||||
const DBCTXKeyValue DBCtxKey = "dbctx"
|
const DBCtxKey dBCtxKey = "dbctx"
|
||||||
|
|
||||||
// FromCtx returns the database handle from the provided Context.
|
// FromCtx returns the database handle from the provided Context.
|
||||||
func FromCtx(ctx context.Context) *DB {
|
func FromCtx(ctx context.Context) *DB {
|
||||||
c, ok := ctx.Value(DBCTXKeyValue).(*DB)
|
c, ok := ctx.Value(DBCtxKey).(*DB)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("no DB in context")
|
panic("no DB in context")
|
||||||
}
|
}
|
||||||
|
@ -72,7 +72,7 @@ func FromCtx(ctx context.Context) *DB {
|
||||||
|
|
||||||
// CtxWithDB returns a Context with the provided database handle.
|
// CtxWithDB returns a Context with the provided database handle.
|
||||||
func CtxWithDB(ctx context.Context, conn *DB) context.Context {
|
func CtxWithDB(ctx context.Context, conn *DB) context.Context {
|
||||||
return context.WithValue(ctx, DBCTXKeyValue, conn)
|
return context.WithValue(ctx, DBCtxKey, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsNoRows is a convenience function that returns whether a returned error is a database
|
// IsNoRows is a convenience function that returns whether a returned error is a database
|
||||||
|
|
|
@ -9,3 +9,6 @@ func (g GetTalkgroupsWithLearnedRow) GetLearned() bool { retur
|
||||||
func (g GetTalkgroupsWithLearnedBySystemRow) GetTalkgroup() Talkgroup { return g.Talkgroup }
|
func (g GetTalkgroupsWithLearnedBySystemRow) GetTalkgroup() Talkgroup { return g.Talkgroup }
|
||||||
func (g GetTalkgroupsWithLearnedBySystemRow) GetSystem() System { return g.System }
|
func (g GetTalkgroupsWithLearnedBySystemRow) GetSystem() System { return g.System }
|
||||||
func (g GetTalkgroupsWithLearnedBySystemRow) GetLearned() bool { return g.Learned }
|
func (g GetTalkgroupsWithLearnedBySystemRow) GetLearned() bool { return g.Learned }
|
||||||
|
func (g Talkgroup) GetTalkgroup() Talkgroup { return g }
|
||||||
|
func (g Talkgroup) GetSystem() System { return System{ID: int(g.SystemID)} }
|
||||||
|
func (g Talkgroup) GetLearned() bool { return false }
|
||||||
|
|
|
@ -39,6 +39,7 @@ type Querier interface {
|
||||||
SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error
|
SetCallTranscript(ctx context.Context, iD uuid.UUID, transcript *string) error
|
||||||
SetTalkgroupTags(ctx context.Context, sys int, tg int, tags []string) error
|
SetTalkgroupTags(ctx context.Context, sys int, tg int, tags []string) error
|
||||||
UpdatePassword(ctx context.Context, username string, password string) error
|
UpdatePassword(ctx context.Context, username string, password string) error
|
||||||
|
UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Querier = (*Queries)(nil)
|
var _ Querier = (*Queries)(nil)
|
||||||
|
|
|
@ -472,3 +472,63 @@ func (q *Queries) SetTalkgroupTags(ctx context.Context, sys int, tg int, tags []
|
||||||
_, err := q.db.Exec(ctx, setTalkgroupTags, sys, tg, tags)
|
_, err := q.db.Exec(ctx, setTalkgroupTags, sys, tg, tags)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const updateTalkgroup = `-- name: UpdateTalkgroup :one
|
||||||
|
UPDATE talkgroups
|
||||||
|
SET
|
||||||
|
name = COALESCE($1, name),
|
||||||
|
alpha_tag = COALESCE($2, alpha_tag),
|
||||||
|
tg_group = COALESCE($3, tg_group),
|
||||||
|
frequency = COALESCE($4, frequency),
|
||||||
|
metadata = COALESCE($5, metadata),
|
||||||
|
tags = COALESCE($6, tags),
|
||||||
|
alert = COALESCE($7, alert),
|
||||||
|
alert_config = COALESCE($8, alert_config),
|
||||||
|
weight = COALESCE($9, weight)
|
||||||
|
WHERE id = $10
|
||||||
|
RETURNING id, system_id, tgid, name, alpha_tag, tg_group, frequency, metadata, tags, alert, alert_config, weight
|
||||||
|
`
|
||||||
|
|
||||||
|
type UpdateTalkgroupParams struct {
|
||||||
|
Name *string `json:"name"`
|
||||||
|
AlphaTag *string `json:"alpha_tag"`
|
||||||
|
TgGroup *string `json:"tg_group"`
|
||||||
|
Frequency *int32 `json:"frequency"`
|
||||||
|
Metadata []byte `json:"metadata"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
Alert *bool `json:"alert"`
|
||||||
|
AlertConfig []byte `json:"alert_config"`
|
||||||
|
Weight *float32 `json:"weight"`
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) UpdateTalkgroup(ctx context.Context, arg UpdateTalkgroupParams) (Talkgroup, error) {
|
||||||
|
row := q.db.QueryRow(ctx, updateTalkgroup,
|
||||||
|
arg.Name,
|
||||||
|
arg.AlphaTag,
|
||||||
|
arg.TgGroup,
|
||||||
|
arg.Frequency,
|
||||||
|
arg.Metadata,
|
||||||
|
arg.Tags,
|
||||||
|
arg.Alert,
|
||||||
|
arg.AlertConfig,
|
||||||
|
arg.Weight,
|
||||||
|
arg.ID,
|
||||||
|
)
|
||||||
|
var i Talkgroup
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.SystemID,
|
||||||
|
&i.Tgid,
|
||||||
|
&i.Name,
|
||||||
|
&i.AlphaTag,
|
||||||
|
&i.TgGroup,
|
||||||
|
&i.Frequency,
|
||||||
|
&i.Metadata,
|
||||||
|
&i.Tags,
|
||||||
|
&i.Alert,
|
||||||
|
&i.AlertConfig,
|
||||||
|
&i.Weight,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
132
pkg/rest/api.go
Normal file
132
pkg/rest/api.go
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"dynatron.me/x/stillbox/pkg/talkgroups"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/go-chi/render"
|
||||||
|
"github.com/go-viper/mapstructure/v2"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type API interface {
|
||||||
|
Subrouter() http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
type api struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() API {
|
||||||
|
s := new(api)
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *api) Subrouter() http.Handler {
|
||||||
|
r := chi.NewMux()
|
||||||
|
|
||||||
|
r.Mount("/talkgroup", new(talkgroupAPI).Subrouter())
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
type errResponse struct {
|
||||||
|
Err error `json:"-"`
|
||||||
|
Code int `json:"-"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
switch e.Code {
|
||||||
|
case http.StatusNotFound:
|
||||||
|
case http.StatusBadRequest:
|
||||||
|
default:
|
||||||
|
log.Error().Str("path", r.URL.Path).Err(e.Err).Int("code", e.Code).Str("msg", e.Error).Msg("request failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
render.Status(r, e.Code)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func badRequest(err error) render.Renderer {
|
||||||
|
return &errResponse{
|
||||||
|
Err: err,
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
Error: "Bad request",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordNotFound(err error) render.Renderer {
|
||||||
|
return &errResponse{
|
||||||
|
Err: err,
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
Error: "Record not found",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func internalError(err error) render.Renderer {
|
||||||
|
return &errResponse{
|
||||||
|
Err: err,
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
Error: "Internal server error",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type errResponder func(error) render.Renderer
|
||||||
|
|
||||||
|
var statusMapping = map[error]errResponder{
|
||||||
|
talkgroups.ErrNotFound: recordNotFound,
|
||||||
|
pgx.ErrNoRows: recordNotFound,
|
||||||
|
}
|
||||||
|
|
||||||
|
func autoError(err error) render.Renderer {
|
||||||
|
c, ok := statusMapping[err]
|
||||||
|
if ok {
|
||||||
|
c(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for e, c := range statusMapping { // check if err wraps an error we know about
|
||||||
|
if errors.Is(err, e) {
|
||||||
|
return c(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return internalError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wErr(w http.ResponseWriter, r *http.Request, v render.Renderer) {
|
||||||
|
err := render.Render(w, r, v)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("wErr render error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeParams(d interface{}, r *http.Request) error {
|
||||||
|
params := chi.RouteContext(r.Context()).URLParams
|
||||||
|
m := make(map[string]string, len(params.Keys))
|
||||||
|
|
||||||
|
for i, k := range params.Keys {
|
||||||
|
m[k] = params.Values[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||||
|
Metadata: nil,
|
||||||
|
Result: d,
|
||||||
|
TagName: "param",
|
||||||
|
WeaklyTypedInput: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dec.Decode(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func respond(w http.ResponseWriter, r *http.Request, v interface{}) {
|
||||||
|
render.DefaultResponder(w, r, v)
|
||||||
|
}
|
108
pkg/rest/talkgroups.go
Normal file
108
pkg/rest/talkgroups.go
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"dynatron.me/x/stillbox/internal/forms"
|
||||||
|
"dynatron.me/x/stillbox/pkg/database"
|
||||||
|
"dynatron.me/x/stillbox/pkg/talkgroups"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
type talkgroupAPI struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tga *talkgroupAPI) Subrouter() http.Handler {
|
||||||
|
r := chi.NewMux()
|
||||||
|
|
||||||
|
r.Get("/{system:\\d+}/{id:\\d+}", tga.get)
|
||||||
|
r.Put("/{system:\\d+}/{id:\\d+}", tga.put)
|
||||||
|
r.Get("/{system:\\d+}/", tga.get)
|
||||||
|
r.Get("/", tga.get)
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
type tgParams struct {
|
||||||
|
System *int `param:"system"`
|
||||||
|
ID *int `param:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t tgParams) haveBoth() bool {
|
||||||
|
return t.System != nil && t.ID != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t tgParams) ToID() talkgroups.ID {
|
||||||
|
nilOr := func(i *int) uint32 {
|
||||||
|
if i == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint32(*i)
|
||||||
|
}
|
||||||
|
|
||||||
|
return talkgroups.ID{
|
||||||
|
System: nilOr(t.System),
|
||||||
|
Talkgroup: nilOr(t.ID),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tga *talkgroupAPI) get(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
tgs := talkgroups.StoreFrom(ctx)
|
||||||
|
|
||||||
|
var p tgParams
|
||||||
|
|
||||||
|
err := decodeParams(&p, r)
|
||||||
|
if err != nil {
|
||||||
|
wErr(w, r, badRequest(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var res interface{}
|
||||||
|
switch {
|
||||||
|
case p.System != nil && p.ID != nil:
|
||||||
|
res, err = tgs.TG(ctx, talkgroups.TG(*p.System, *p.ID))
|
||||||
|
case p.System != nil:
|
||||||
|
res, err = tgs.SystemTGs(ctx, int32(*p.System))
|
||||||
|
default:
|
||||||
|
res, err = tgs.TGs(ctx, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
wErr(w, r, autoError(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respond(w, r, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tga *talkgroupAPI) put(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var id tgParams
|
||||||
|
err := decodeParams(&id, r)
|
||||||
|
if err != nil {
|
||||||
|
wErr(w, r, badRequest(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
tgs := talkgroups.StoreFrom(ctx)
|
||||||
|
|
||||||
|
input := database.UpdateTalkgroupParams{}
|
||||||
|
|
||||||
|
err = forms.Unmarshal(r, &input, forms.WithTag("json"), forms.WithAcceptBlank(), forms.WithOmitEmpty())
|
||||||
|
if err != nil {
|
||||||
|
wErr(w, r, badRequest(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
input.ID = id.ToID().Pack()
|
||||||
|
|
||||||
|
record, err := tgs.UpdateTG(ctx, input)
|
||||||
|
if err != nil {
|
||||||
|
wErr(w, r, autoError(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respond(w, r, record)
|
||||||
|
}
|
|
@ -4,14 +4,9 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/pprof"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) installPprof() {
|
func (s *Server) installPprof() {
|
||||||
r := s.r
|
s.r.Mount("/debug", middleware.Profiler())
|
||||||
r.HandleFunc("/debug/pprof/", pprof.Index)
|
|
||||||
r.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
|
||||||
r.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
|
||||||
r.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
|
||||||
r.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"dynatron.me/x/stillbox/internal/version"
|
"dynatron.me/x/stillbox/internal/version"
|
||||||
"dynatron.me/x/stillbox/pkg/config"
|
"dynatron.me/x/stillbox/pkg/config"
|
||||||
"dynatron.me/x/stillbox/pkg/database"
|
"dynatron.me/x/stillbox/pkg/database"
|
||||||
|
"dynatron.me/x/stillbox/pkg/talkgroups"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
"github.com/go-chi/httprate"
|
"github.com/go-chi/httprate"
|
||||||
|
@ -26,7 +27,8 @@ func (s *Server) setupRoutes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
r := s.r
|
r := s.r
|
||||||
r.Use(middleware.WithValue(database.DBCTXKeyValue, s.db))
|
r.Use(middleware.WithValue(database.DBCtxKey, s.db))
|
||||||
|
r.Use(middleware.WithValue(talkgroups.StoreCtxKey, s.tgs))
|
||||||
|
|
||||||
s.installPprof()
|
s.installPprof()
|
||||||
|
|
||||||
|
@ -36,7 +38,7 @@ func (s *Server) setupRoutes() {
|
||||||
s.nex.PrivateRoutes(r)
|
s.nex.PrivateRoutes(r)
|
||||||
s.auth.PrivateRoutes(r)
|
s.auth.PrivateRoutes(r)
|
||||||
s.alerter.PrivateRoutes(r)
|
s.alerter.PrivateRoutes(r)
|
||||||
r.Mount("/api", s.api.Subrouter())
|
r.Mount("/api", s.rest.Subrouter())
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
|
|
|
@ -7,12 +7,12 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dynatron.me/x/stillbox/pkg/alerting"
|
"dynatron.me/x/stillbox/pkg/alerting"
|
||||||
"dynatron.me/x/stillbox/pkg/api"
|
|
||||||
"dynatron.me/x/stillbox/pkg/auth"
|
"dynatron.me/x/stillbox/pkg/auth"
|
||||||
"dynatron.me/x/stillbox/pkg/config"
|
"dynatron.me/x/stillbox/pkg/config"
|
||||||
"dynatron.me/x/stillbox/pkg/database"
|
"dynatron.me/x/stillbox/pkg/database"
|
||||||
"dynatron.me/x/stillbox/pkg/nexus"
|
"dynatron.me/x/stillbox/pkg/nexus"
|
||||||
"dynatron.me/x/stillbox/pkg/notify"
|
"dynatron.me/x/stillbox/pkg/notify"
|
||||||
|
"dynatron.me/x/stillbox/pkg/rest"
|
||||||
"dynatron.me/x/stillbox/pkg/sinks"
|
"dynatron.me/x/stillbox/pkg/sinks"
|
||||||
"dynatron.me/x/stillbox/pkg/sources"
|
"dynatron.me/x/stillbox/pkg/sources"
|
||||||
"dynatron.me/x/stillbox/pkg/talkgroups"
|
"dynatron.me/x/stillbox/pkg/talkgroups"
|
||||||
|
@ -37,7 +37,7 @@ type Server struct {
|
||||||
notifier notify.Notifier
|
notifier notify.Notifier
|
||||||
hup chan os.Signal
|
hup chan os.Signal
|
||||||
tgs talkgroups.Store
|
tgs talkgroups.Store
|
||||||
api api.API
|
rest rest.API
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(ctx context.Context, cfg *config.Config) (*Server, error) {
|
func New(ctx context.Context, cfg *config.Config) (*Server, error) {
|
||||||
|
@ -61,7 +61,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
tgCache := talkgroups.NewCache()
|
tgCache := talkgroups.NewCache()
|
||||||
api := api.New(tgCache)
|
api := rest.New()
|
||||||
|
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
auth: authenticator,
|
auth: authenticator,
|
||||||
|
@ -73,7 +73,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) {
|
||||||
alerter: alerting.New(cfg.Alerting, tgCache, alerting.WithNotifier(notifier)),
|
alerter: alerting.New(cfg.Alerting, tgCache, alerting.WithNotifier(notifier)),
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
tgs: tgCache,
|
tgs: tgCache,
|
||||||
api: api,
|
rest: api,
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db), true)
|
srv.sinks.Register("database", sinks.NewDatabaseSink(srv.db), true)
|
||||||
|
|
|
@ -17,7 +17,15 @@ import (
|
||||||
|
|
||||||
type tgMap map[ID]*Talkgroup
|
type tgMap map[ID]*Talkgroup
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotFound = errors.New("talkgroup not found")
|
||||||
|
ErrNoSuchSystem = errors.New("no such system")
|
||||||
|
)
|
||||||
|
|
||||||
type Store interface {
|
type Store interface {
|
||||||
|
// UpdateTG updates a talkgroup record.
|
||||||
|
UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*Talkgroup, error)
|
||||||
|
|
||||||
// TG retrieves a Talkgroup from the Store.
|
// TG retrieves a Talkgroup from the Store.
|
||||||
TG(ctx context.Context, tg ID) (*Talkgroup, error)
|
TG(ctx context.Context, tg ID) (*Talkgroup, error)
|
||||||
|
|
||||||
|
@ -49,16 +57,16 @@ type Store interface {
|
||||||
HUP(*config.Config)
|
HUP(*config.Config)
|
||||||
}
|
}
|
||||||
|
|
||||||
type CtxStoreKeyT string
|
type storeCtxKey string
|
||||||
|
|
||||||
const CtxStoreKey CtxStoreKeyT = "store"
|
const StoreCtxKey storeCtxKey = "store"
|
||||||
|
|
||||||
func CtxWithStore(ctx context.Context, s Store) context.Context {
|
func CtxWithStore(ctx context.Context, s Store) context.Context {
|
||||||
return context.WithValue(ctx, CtxStoreKey, s)
|
return context.WithValue(ctx, StoreCtxKey, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
func StoreFrom(ctx context.Context) Store {
|
func StoreFrom(ctx context.Context) Store {
|
||||||
s, ok := ctx.Value(CtxStoreKey).(Store)
|
s, ok := ctx.Value(StoreCtxKey).(Store)
|
||||||
if !ok {
|
if !ok {
|
||||||
return NewCache()
|
return NewCache()
|
||||||
}
|
}
|
||||||
|
@ -213,8 +221,6 @@ func (t *cache) Load(ctx context.Context, tgs []int64) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrNotFound = errors.New("talkgroup not found")
|
|
||||||
|
|
||||||
func (t *cache) Weight(ctx context.Context, id ID, tm time.Time) float64 {
|
func (t *cache) Weight(ctx context.Context, id ID, tm time.Time) float64 {
|
||||||
tg, err := t.TG(ctx, id)
|
tg, err := t.TG(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -290,3 +296,23 @@ func (t *cache) SystemName(ctx context.Context, id int) (name string, has bool)
|
||||||
|
|
||||||
return n, has
|
return n, has
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *cache) UpdateTG(ctx context.Context, input database.UpdateTalkgroupParams) (*Talkgroup, error) {
|
||||||
|
sysName, has := t.SystemName(ctx, int(Unpack(input.ID).System))
|
||||||
|
if !has {
|
||||||
|
return nil, ErrNoSuchSystem
|
||||||
|
}
|
||||||
|
|
||||||
|
tg, err := database.FromCtx(ctx).UpdateTalkgroup(ctx, input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
record := &Talkgroup{
|
||||||
|
Talkgroup: tg,
|
||||||
|
System: database.System{ID: int(tg.SystemID), Name: sysName},
|
||||||
|
}
|
||||||
|
t.add(record)
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
|
@ -49,6 +49,13 @@ func (t ID) Pack() int64 {
|
||||||
return int64((int64(t.System) << 32) | int64(t.Talkgroup))
|
return int64((int64(t.System) << 32) | int64(t.Talkgroup))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Unpack(id int64) ID {
|
||||||
|
return ID{
|
||||||
|
System: uint32(id >> 32),
|
||||||
|
Talkgroup: uint32(id & 0xffffffff),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t ID) String() string {
|
func (t ID) String() string {
|
||||||
return fmt.Sprintf("%d:%d", t.System, t.Talkgroup)
|
return fmt.Sprintf("%d:%d", t.System, t.Talkgroup)
|
||||||
|
|
||||||
|
|
|
@ -104,6 +104,20 @@ FROM talkgroups_learned tgl
|
||||||
JOIN systems sys ON tgl.system_id = sys.id
|
JOIN systems sys ON tgl.system_id = sys.id
|
||||||
WHERE ignored IS NOT TRUE;
|
WHERE ignored IS NOT TRUE;
|
||||||
|
|
||||||
|
|
||||||
-- name: GetSystemName :one
|
-- name: GetSystemName :one
|
||||||
SELECT name FROM systems WHERE id = sqlc.arg(system_id);
|
SELECT name FROM systems WHERE id = sqlc.arg(system_id);
|
||||||
|
|
||||||
|
-- name: UpdateTalkgroup :one
|
||||||
|
UPDATE talkgroups
|
||||||
|
SET
|
||||||
|
name = COALESCE(sqlc.narg('name'), name),
|
||||||
|
alpha_tag = COALESCE(sqlc.narg('alpha_tag'), alpha_tag),
|
||||||
|
tg_group = COALESCE(sqlc.narg('tg_group'), tg_group),
|
||||||
|
frequency = COALESCE(sqlc.narg('frequency'), frequency),
|
||||||
|
metadata = COALESCE(sqlc.narg('metadata'), metadata),
|
||||||
|
tags = COALESCE(sqlc.narg('tags'), tags),
|
||||||
|
alert = COALESCE(sqlc.narg('alert'), alert),
|
||||||
|
alert_config = COALESCE(sqlc.narg('alert_config'), alert_config),
|
||||||
|
weight = COALESCE(sqlc.narg('weight'), weight)
|
||||||
|
WHERE id = @id
|
||||||
|
RETURNING *;
|
||||||
|
|
Loading…
Reference in a new issue