diff --git a/.gitignore b/.gitignore index 8e9a60e..610f57c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ config.yaml -config.test.yaml +config.*.yaml +!config.sample.yaml /*.sql client/calls/ !client/calls/.gitkeep diff --git a/config.sample.yaml b/config.sample.yaml index 5a71064..6384f65 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -41,3 +41,11 @@ notify: # {{ end -}} config: webhookURL: "http://somewhere" +# configure upstream relays here +relay: + # `url` is the root of the instance +# - url: 'http://some.host:3051/' +# apiKey: aaaabbbb-cccc-dddd-eeee-ffff11112222 + # `required` specifies whether we should report failure (i.e. HTTP 500 for rdio-http) to the source + # if the relay call submission fails +# required: false diff --git a/internal/forms/forms.go b/internal/forms/forms.go index f8a7f7a..8f0be96 100644 --- a/internal/forms/forms.go +++ b/internal/forms/forms.go @@ -1,19 +1,8 @@ package forms import ( - "encoding/json" "errors" - "fmt" - "io" - "net/http" - "reflect" - "strconv" - "strings" "time" - - "dynatron.me/x/stillbox/internal/jsontypes" - - "github.com/araddon/dateparse" ) var ( @@ -80,293 +69,3 @@ func (o *options) Tag() string { return "form" } - -func (o *options) parseTime(s string, dpo ...dateparse.ParserOption) (t time.Time, set bool, err error) { - if o.acceptBlank && s == "" { - set = false - return - } - - if iv, err := strconv.Atoi(s); err == nil { - return time.Unix(int64(iv), 0), true, nil - } - - switch { - case o.parseTimeIn != nil: - t, err = dateparse.ParseIn(s, o.parseTimeIn, dpo...) - case o.parseLocal: - t, err = dateparse.ParseLocal(s, dpo...) - default: - t, err = dateparse.ParseAny(s, dpo...) - } - - set = true - - return -} - -func (o *options) parseBool(s string) (v bool, set bool, err error) { - if o.acceptBlank && s == "" { - set = false - return - } - - set = true - - v, err = strconv.ParseBool(s) - if err != nil { - return v, set, fmt.Errorf("parsebool('%s'): %w", s, err) - } - - return -} - -func (o *options) parseInt(s string) (v int, set bool, err error) { - if o.acceptBlank && s == "" { - set = false - return - } - set = true - - v, err = strconv.Atoi(s) - if err != nil { - return v, set, fmt.Errorf("atoi('%s'): %w", s, err) - } - - return -} - -func (o *options) parseFloat64(s string) (v float64, set bool, err error) { - if o.acceptBlank && s == "" { - set = false - return - } - set = true - - v, err = strconv.ParseFloat(s, 64) - if err != nil { - return v, set, fmt.Errorf("ParseFloat('%s'): %w", s, err) - } - - return -} - -func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) { - if o.acceptBlank && s == "" { - set = false - return - } - - set = true - - v, err = time.ParseDuration(s) - if err != nil { - return v, set, fmt.Errorf("ParseDuration('%s'): %w", s, err) - } - - return -} - -var typeOfByteSlice = reflect.TypeOf([]byte(nil)) - -func (o *options) iterFields(r *http.Request, destStruct reflect.Value) error { - structType := destStruct.Type() - for i := 0; i < destStruct.NumField(); i++ { - destFieldVal := destStruct.Field(i) - fieldType := structType.Field(i) - if !fieldType.IsExported() && !fieldType.Anonymous { - continue - } - - if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous { - err := o.iterFields(r, destFieldVal) - if err != nil { - return err - } - } - - var tAr []string - var formField string - var omitEmpty bool - if o.defaultOmitEmpty { - omitEmpty = true - } - - formTag, has := structType.Field(i).Tag.Lookup(o.Tag()) - if has { - tAr = strings.Split(formTag, ",") - formField = tAr[0] - for _, v := range tAr[1:] { - if v == "omitempty" { - omitEmpty = true - break - } - } - } - - if !has || formField == "-" { - continue - } - - destFieldIntf := destFieldVal.Interface() - - if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice { - file, hdr, err := r.FormFile(formField) - if err != nil { - return fmt.Errorf("get form file: %w", err) - } - - nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField") - if hasFilename { - fnf := destStruct.FieldByName(nameField) - if fnf == (reflect.Value{}) { - panic(fmt.Errorf("filenameField '%s' does not exist", nameField)) - } - - fnf.SetString(hdr.Filename) - } - audioBytes, err := io.ReadAll(file) - if err != nil { - return fmt.Errorf("file read: %w", err) - } - - destFieldVal.SetBytes(audioBytes) - - continue - } - - if !r.Form.Has(formField) && omitEmpty { - continue - } - - ff := r.Form.Get(formField) - - switch v := destFieldIntf.(type) { - case string, *string: - setVal(destFieldVal, ff != "" || o.acceptBlank, ff) - case int, uint, *int, *uint: - val, set, err := o.parseInt(ff) - if err != nil { - return err - } - setVal(destFieldVal, set, val) - case float64: - val, set, err := o.parseFloat64(ff) - if err != nil { - return err - } - setVal(destFieldVal, set, val) - case bool, *bool: - val, set, err := o.parseBool(ff) - if err != nil { - return err - } - setVal(destFieldVal, set, val) - case time.Time, *time.Time, jsontypes.Time, *jsontypes.Time: - t, set, err := o.parseTime(ff) - if err != nil { - return err - } - setVal(destFieldVal, set, t) - case time.Duration, *time.Duration, jsontypes.Duration, *jsontypes.Duration: - d, set, err := o.parseDuration(ff) - if err != nil { - return err - } - setVal(destFieldVal, set, d) - case []int: - val := strings.Trim(ff, "[]") - if val == "" && o.acceptBlank { - continue - } - vals := strings.Split(val, ",") - ar := make([]int, 0, len(vals)) - for _, v := range vals { - i, err := strconv.Atoi(v) - if err == nil { - ar = append(ar, i) - } - } - destFieldVal.Set(reflect.ValueOf(ar)) - default: - panic(fmt.Errorf("unsupported type %T", v)) - } - } - - return nil -} - -func setVal(destFieldVal reflect.Value, set bool, src any) { - if !set { - return - } - - destType := destFieldVal.Type() - srcVal := reflect.ValueOf(src) - - if srcVal.Kind() == reflect.Ptr { - srcVal = srcVal.Elem() - } - - if destType.Kind() == reflect.Ptr { - if !srcVal.CanAddr() { - if srcVal.CanConvert(destType.Elem()) { - srcVal = srcVal.Convert(destType.Elem()) - } - copy := reflect.New(srcVal.Type()) - copy.Elem().Set(srcVal) - srcVal = copy - } - } else if srcVal.CanConvert(destFieldVal.Type()) { - srcVal = srcVal.Convert(destFieldVal.Type()) - } - - destFieldVal.Set(srcVal) -} - -func Unmarshal(r *http.Request, dest any, opt ...Option) error { - o := options{ - maxMultipartMemory: MaxMultipartMemory, - } - - for _, opt := range opt { - opt(&o) - } - - contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0] - - switch contentType { - case "multipart/form-data": - err := r.ParseMultipartForm(o.maxMultipartMemory) - if err != nil { - return fmt.Errorf("ParseForm: %w", err) - } - - return o.unmarshalForm(r, dest) - case "application/x-www-form-urlencoded": - err := r.ParseForm() - if err != nil { - return fmt.Errorf("ParseForm: %w", err) - } - return o.unmarshalForm(r, dest) - case "application/json": - return json.NewDecoder(r.Body).Decode(dest) - } - - return ErrContentType -} - -func (o *options) unmarshalForm(r *http.Request, dest any) error { - destVal := reflect.ValueOf(dest) - if k := destVal.Kind(); k == reflect.Ptr { - destVal = destVal.Elem() - } else { - return ErrNotPointer - } - - if destVal.Kind() != reflect.Struct { - return ErrNotStruct - } - - return o.iterFields(r, destVal) -} diff --git a/internal/forms/marshal.go b/internal/forms/marshal.go new file mode 100644 index 0000000..dedd9ad --- /dev/null +++ b/internal/forms/marshal.go @@ -0,0 +1,139 @@ +package forms + +import ( + "fmt" + "mime/multipart" + "reflect" + "strconv" + "strings" + "time" +) + +func Marshal(src any, dest *multipart.Writer, opts ...Option) error { + o := options{} + + for _, opt := range opts { + opt(&o) + } + + return o.marshalMultipartForm(src, dest) +} + +func (o *options) marshalMultipartForm(src any, dest *multipart.Writer) error { + srcVal := reflect.ValueOf(src) + if k := srcVal.Kind(); k == reflect.Ptr { + srcVal = srcVal.Elem() + } + + if srcVal.Kind() != reflect.Struct { + return ErrNotStruct + } + + return o.marIterFields(srcVal, dest) +} + +func (o *options) marIterFields(srcVal reflect.Value, dest *multipart.Writer) error { + structType := srcVal.Type() + for i := 0; i < structType.NumField(); i++ { + srcFieldVal := srcVal.Field(i) + fieldType := structType.Field(i) + if !fieldType.IsExported() && !fieldType.Anonymous { + continue + } + + if srcFieldVal.Kind() == reflect.Struct && fieldType.Anonymous { + continue + } + + var tAr []string + var formField string + var omitEmpty bool + if o.defaultOmitEmpty { + omitEmpty = true + } + + formTag, has := structType.Field(i).Tag.Lookup(o.Tag()) + if has { + tAr = strings.Split(formTag, ",") + formField = tAr[0] + for _, v := range tAr[1:] { + if v == "omitempty" { + omitEmpty = true + break + } + } + } + + if !has || formField == "-" { + continue + } + + if srcFieldVal.Kind() == reflect.Ptr { + srcFieldVal = srcFieldVal.Elem() + if srcFieldVal == (reflect.Value{}) || srcFieldVal.IsZero() { + continue + } + } + + srcFieldIntf := srcFieldVal.Interface() + if srcFieldVal.Kind() == reflect.Slice && srcFieldVal.Type() == typeOfByteSlice { + nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField") + fileName := "" + if hasFilename { + fnf := srcVal.FieldByName(nameField) + if fnf == (reflect.Value{}) { + panic(fmt.Errorf("filenameField '%s' does not exist", nameField)) + } + + fileName = fnf.String() + } + + fw, err := dest.CreateFormFile(formField, fileName) + if err != nil { + return fmt.Errorf("form marshal: createFormFile: %w", err) + } + + _, err = fw.Write(srcFieldVal.Bytes()) + if err != nil { + return fmt.Errorf("form marshal: write file: %w", err) + } + + continue + } + + if srcFieldVal.IsZero() && omitEmpty { + continue + } + + var val string + switch v := srcFieldIntf.(type) { + case []string: + if omitEmpty && len(v) == 0 { + continue + } + val = "[" + strings.Join(v, ",") + "]" + case []int: + if omitEmpty && len(v) == 0 { + continue + } + + sl := make([]string, len(v)) + for i := range v { + sl[i] = strconv.Itoa(v[i]) + } + + val = "[" + strings.Join(sl, ",") + "]" + case time.Time: + val = strconv.Itoa(int(v.Unix())) + default: + val = fmt.Sprint(v) + } + + err := dest.WriteField(formField, val) + if err != nil { + return fmt.Errorf("marshal field '%s': %w", formField, err) + } + } + + return nil +} diff --git a/internal/forms/marshal_test.go b/internal/forms/marshal_test.go new file mode 100644 index 0000000..6004218 --- /dev/null +++ b/internal/forms/marshal_test.go @@ -0,0 +1,121 @@ +package forms_test + +import ( + "bytes" + "fmt" + "math/rand" + "mime/multipart" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "dynatron.me/x/stillbox/internal/common" + "dynatron.me/x/stillbox/internal/forms" + "dynatron.me/x/stillbox/pkg/auth" + "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/sources" + + "github.com/google/uuid" +) + +type hand func(w http.ResponseWriter, r *http.Request) + +func (h hand) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h(w, r) +} + +func call(url string, call *calls.Call) error { + var buf bytes.Buffer + body := multipart.NewWriter(&buf) + + err := forms.Marshal(call, body) + if err != nil { + return fmt.Errorf("relay form parse: %w", err) + } + body.Close() + + r, err := http.NewRequest(http.MethodPost, url, &buf) + if err != nil { + return fmt.Errorf("relay newrequest: %w", err) + } + + r.Header.Set("Content-Type", body.FormDataContentType()) + + resp, err := http.DefaultClient.Do(r) + if err != nil { + return fmt.Errorf("relay: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("relay: received HTTP %d", resp.StatusCode) + } + + return nil +} + +func TestMarshal(t *testing.T) { + uuid.SetRand(rand.New(rand.NewSource(1))) + + tests := []struct { + name string + submitter auth.UserID + apiKey string + call calls.Call + }{ + { + name: "base", + submitter: auth.UserID(1), + call: calls.Call{ + ID: uuid.UUID([16]byte{0x52, 0xfd, 0xfc, 0x07, 0x21, 0x82, 0x45, 0x4f, 0x96, 0x3f, 0x5f, 0x0f, 0x9a, 0x62, 0x1d, 0x72}), + Submitter: common.PtrTo(auth.UserID(1)), + System: 197, + Talkgroup: 10101, + DateTime: time.Date(2024, 11, 10, 23, 33, 02, 0, time.Local), + AudioName: "rightnow.mp3", + Audio: []byte{0xFF, 0xF3, 0x14, 0xC4, 0x00, 0x00, 0x00, 0x03, 0x48, 0x01, 0x40, 0x00, 0x00, 0x4C, 0x41, 0x4D, 0x45, 0x33, 0x2E, 0x39, 0x36, 0x2E, 0x31, 0x55}, + AudioType: "audio/mpeg", + Duration: calls.CallDuration(24000000), + TalkgroupLabel: common.PtrTo("Some TG"), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var serr error + var called bool + h := hand(func(w http.ResponseWriter, r *http.Request) { + called = true + serr = r.ParseMultipartForm(1024 * 1024 * 2) + if serr != nil { + t.Log("parsemultipart", serr) + return + } + + cur := new(sources.CallUploadRequest) + serr = forms.Unmarshal(r, cur, forms.WithAcceptBlank()) + cur.DontStore = true + if serr != nil { + t.Log("unmarshal", serr) + return + } + + assert.Equal(t, tc.apiKey, cur.Key) + + toC, tcerr := cur.ToCall(tc.submitter) + require.NoError(t, tcerr) + assert.Equal(t, &tc.call, toC) + }) + svr := httptest.NewServer(h) + + err := call(svr.URL, &tc.call) + assert.True(t, called) + assert.NoError(t, err) + assert.NoError(t, serr) + }) + } +} diff --git a/internal/forms/unmarshal.go b/internal/forms/unmarshal.go new file mode 100644 index 0000000..4b8f568 --- /dev/null +++ b/internal/forms/unmarshal.go @@ -0,0 +1,306 @@ +package forms + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "reflect" + "strconv" + "strings" + "time" + + "dynatron.me/x/stillbox/internal/jsontypes" + + "github.com/araddon/dateparse" +) + +func (o *options) parseTime(s string, dpo ...dateparse.ParserOption) (t time.Time, set bool, err error) { + if o.acceptBlank && s == "" { + set = false + return + } + + if iv, err := strconv.Atoi(s); err == nil { + return time.Unix(int64(iv), 0), true, nil + } + + switch { + case o.parseTimeIn != nil: + t, err = dateparse.ParseIn(s, o.parseTimeIn, dpo...) + case o.parseLocal: + t, err = dateparse.ParseLocal(s, dpo...) + default: + t, err = dateparse.ParseAny(s, dpo...) + } + + set = true + + return +} + +func (o *options) parseBool(s string) (v bool, set bool, err error) { + if o.acceptBlank && s == "" { + set = false + return + } + + set = true + + v, err = strconv.ParseBool(s) + if err != nil { + return v, set, fmt.Errorf("parsebool('%s'): %w", s, err) + } + + return +} + +func (o *options) parseInt(s string) (v int, set bool, err error) { + if o.acceptBlank && s == "" { + set = false + return + } + set = true + + v, err = strconv.Atoi(s) + if err != nil { + return v, set, fmt.Errorf("atoi('%s'): %w", s, err) + } + + return +} + +func (o *options) parseFloat64(s string) (v float64, set bool, err error) { + if o.acceptBlank && s == "" { + set = false + return + } + set = true + + v, err = strconv.ParseFloat(s, 64) + if err != nil { + return v, set, fmt.Errorf("ParseFloat('%s'): %w", s, err) + } + + return +} + +func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) { + if o.acceptBlank && s == "" { + set = false + return + } + + set = true + + v, err = time.ParseDuration(s) + if err != nil { + return v, set, fmt.Errorf("ParseDuration('%s'): %w", s, err) + } + + return +} + +var typeOfByteSlice = reflect.TypeOf([]byte(nil)) + +func (o *options) unmIterFields(r *http.Request, destStruct reflect.Value) error { + structType := destStruct.Type() + for i := 0; i < destStruct.NumField(); i++ { + destFieldVal := destStruct.Field(i) + fieldType := structType.Field(i) + if !fieldType.IsExported() && !fieldType.Anonymous { + continue + } + + if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous { + err := o.unmIterFields(r, destFieldVal) + if err != nil { + return err + } + } + + var tAr []string + var formField string + var omitEmpty bool + if o.defaultOmitEmpty { + omitEmpty = true + } + + formTag, has := structType.Field(i).Tag.Lookup(o.Tag()) + if has { + tAr = strings.Split(formTag, ",") + formField = tAr[0] + for _, v := range tAr[1:] { + if v == "omitempty" { + omitEmpty = true + break + } + } + } + + if !has || formField == "-" { + continue + } + + destFieldIntf := destFieldVal.Interface() + + if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice { + file, hdr, err := r.FormFile(formField) + if err != nil { + return fmt.Errorf("get form file: %w", err) + } + + nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField") + if hasFilename { + fnf := destStruct.FieldByName(nameField) + if fnf == (reflect.Value{}) { + panic(fmt.Errorf("filenameField '%s' does not exist", nameField)) + } + + fnf.SetString(hdr.Filename) + } + audioBytes, err := io.ReadAll(file) + if err != nil { + return fmt.Errorf("file read: %w", err) + } + + destFieldVal.SetBytes(audioBytes) + + continue + } + + if !r.Form.Has(formField) && omitEmpty { + continue + } + + ff := r.Form.Get(formField) + + switch v := destFieldIntf.(type) { + case string, *string: + setVal(destFieldVal, ff != "" || o.acceptBlank, ff) + case int, uint, *int, *uint: + val, set, err := o.parseInt(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case float64: + val, set, err := o.parseFloat64(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case bool, *bool: + val, set, err := o.parseBool(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case time.Time, *time.Time, jsontypes.Time, *jsontypes.Time: + t, set, err := o.parseTime(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, t) + case time.Duration, *time.Duration, jsontypes.Duration, *jsontypes.Duration: + d, set, err := o.parseDuration(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, d) + case []int: + val := strings.Trim(ff, "[]") + if val == "" && o.acceptBlank { + continue + } + vals := strings.Split(val, ",") + ar := make([]int, 0, len(vals)) + for _, v := range vals { + i, err := strconv.Atoi(v) + if err == nil { + ar = append(ar, i) + } + } + destFieldVal.Set(reflect.ValueOf(ar)) + default: + panic(fmt.Errorf("unsupported type %T", v)) + } + } + + return nil +} + +func setVal(destFieldVal reflect.Value, set bool, src any) { + if !set { + return + } + + destType := destFieldVal.Type() + srcVal := reflect.ValueOf(src) + + if srcVal.Kind() == reflect.Ptr { + srcVal = srcVal.Elem() + } + + if destType.Kind() == reflect.Ptr { + if !srcVal.CanAddr() { + if srcVal.CanConvert(destType.Elem()) { + srcVal = srcVal.Convert(destType.Elem()) + } + copy := reflect.New(srcVal.Type()) + copy.Elem().Set(srcVal) + srcVal = copy + } + } else if srcVal.CanConvert(destFieldVal.Type()) { + srcVal = srcVal.Convert(destFieldVal.Type()) + } + + destFieldVal.Set(srcVal) +} + +func Unmarshal(r *http.Request, dest any, opts ...Option) error { + o := options{ + maxMultipartMemory: MaxMultipartMemory, + } + + for _, opt := range opts { + opt(&o) + } + + contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0] + + switch contentType { + case "multipart/form-data": + err := r.ParseMultipartForm(o.maxMultipartMemory) + if err != nil { + return fmt.Errorf("ParseForm: %w", err) + } + + return o.unmarshalForm(r, dest) + case "application/x-www-form-urlencoded": + err := r.ParseForm() + if err != nil { + return fmt.Errorf("ParseForm: %w", err) + } + return o.unmarshalForm(r, dest) + case "application/json": + return json.NewDecoder(r.Body).Decode(dest) + } + + return ErrContentType +} + +func (o *options) unmarshalForm(r *http.Request, dest any) error { + destVal := reflect.ValueOf(dest) + if k := destVal.Kind(); k == reflect.Ptr { + destVal = destVal.Elem() + } else { + return ErrNotPointer + } + + if destVal.Kind() != reflect.Struct { + return ErrNotStruct + } + + return o.unmIterFields(r, destVal) +} diff --git a/internal/forms/forms_test.go b/internal/forms/unmarshal_test.go similarity index 100% rename from internal/forms/forms_test.go rename to internal/forms/unmarshal_test.go diff --git a/pkg/calls/call.go b/pkg/calls/call.go index 07b633d..6359940 100644 --- a/pkg/calls/call.go +++ b/pkg/calls/call.go @@ -33,26 +33,26 @@ func (d CallDuration) Seconds() int32 { } type Call struct { - ID uuid.UUID - Audio []byte - AudioName string - AudioType string - Duration CallDuration - DateTime time.Time - Frequencies []int - Frequency int - Patches []int - Source int - Sources []int - System int - Submitter *auth.UserID - SystemLabel string - Talkgroup int - TalkgroupGroup *string - TalkgroupLabel *string - TGAlphaTag *string + ID uuid.UUID `form:"-"` + Audio []byte `form:"audio" filenameField:"AudioName"` + AudioName string `form:"audioName"` + AudioType string `form:"audioType"` + Duration CallDuration `form:"-"` + DateTime time.Time `form:"dateTime"` + Frequencies []int `form:"frequencies"` + Frequency int `form:"frequency"` + Patches []int `form:"patches"` + Source int `form:"source"` + Sources []int `form:"sources"` + System int `form:"system"` + Submitter *auth.UserID `form:"-"` + SystemLabel string `form:"systemLabel"` + Talkgroup int `form:"talkgroup"` + TalkgroupGroup *string `form:"talkgroupGroup"` + TalkgroupLabel *string `form:"talkgroupLabel"` + TGAlphaTag *string `form:"talkgroupTag"` // not 1:1 - shouldStore bool + shouldStore bool `form:"-"` } func (c *Call) String() string { diff --git a/pkg/config/config.go b/pkg/config/config.go index fb2c589..c9f3074 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -22,6 +22,7 @@ type Config struct { Public bool `yaml:"public"` RateLimit RateLimit `yaml:"rateLimit"` Notify Notify `yaml:"notify"` + Relay []Relay `yaml:"relay"` configPath string } @@ -63,6 +64,12 @@ type Alerting struct { Renotify *jsontypes.Duration `yaml:"renotify,omitempty" form:"renotify,omitempty"` } +type Relay struct { + URL string `yaml:"url"` + APIKey string `yaml:"apiKey"` + Required bool `yaml:"required"` +} + type Notify []NotifyService type NotifyService struct { @@ -72,17 +79,6 @@ type NotifyService struct { Config map[string]interface{} `yaml:"config" json:"config"` } -func (n *NotifyService) GetS(k, defaultVal string) string { - if v, has := n.Config[k]; has { - if v, isString := v.(string); isString { - return v - } - log.Error().Str("configKey", k).Str("provider", n.Provider).Str("default", defaultVal).Msg("notify config value is not a string! using default") - } - - return defaultVal -} - func (rl *RateLimit) Verify() bool { if rl.Enable { if rl.Requests > 0 && rl.Over > 0 { diff --git a/pkg/server/server.go b/pkg/server/server.go index a5ff3bd..4357d12 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -31,6 +31,7 @@ type Server struct { r *chi.Mux sources sources.Sources sinks sinks.Sinks + relayer *sinks.RelayManager nex *nexus.Nexus logger *Logger alerter alerting.Alerter @@ -73,6 +74,7 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) { alerter: alerting.New(cfg.Alerting, tgCache, alerting.WithNotifier(notifier)), notifier: notifier, tgs: tgCache, + sinks: sinks.NewSinkManager(), rest: api, } @@ -85,6 +87,13 @@ func New(ctx context.Context, cfg *config.Config) (*Server, error) { srv.sources.Register("rdio-http", sources.NewRdioHTTP(authenticator, srv)) + relayer, err := sinks.NewRelayManager(srv.sinks, cfg.Relay) + if err != nil { + return nil, err + } + + srv.relayer = relayer + r.Use(middleware.RequestID) r.Use(middleware.RealIP) r.Use(RequestLogger()) diff --git a/pkg/sinks/relay.go b/pkg/sinks/relay.go new file mode 100644 index 0000000..1f3bb97 --- /dev/null +++ b/pkg/sinks/relay.go @@ -0,0 +1,118 @@ +package sinks + +import ( + "bytes" + "context" + "fmt" + "mime/multipart" + "net/http" + "net/url" + + "dynatron.me/x/stillbox/internal/forms" + "dynatron.me/x/stillbox/internal/version" + "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/config" +) + +type RelayManager struct { + xp *http.Transport + client *http.Client + + relays []*Relay +} + +type Relay struct { + config.Relay + mgr *RelayManager + Name string + + url *url.URL +} + +func NewRelayManager(s Sinks, cfgs []config.Relay) (*RelayManager, error) { + xp := http.DefaultTransport.(*http.Transport).Clone() + xp.MaxIdleConnsPerHost = 10 + + client := &http.Client{ + Transport: xp, + } + + rm := &RelayManager{ + xp: xp, + client: client, + relays: make([]*Relay, 0, len(cfgs)), + } + + for i, cfg := range cfgs { + rs, err := rm.newRelay(cfg) + if err != nil { + return nil, err + } + + rm.relays = append(rm.relays, rs) + + sinkName := fmt.Sprintf("relay%d:%s", i, rs.url.Host) + s.Register(sinkName, rs, cfg.Required) + } + + return rm, nil +} + +func (rs *RelayManager) newRelay(cfg config.Relay) (*Relay, error) { + u, err := url.Parse(cfg.URL) + if err != nil { + return nil, err + } + + if u.Path != "" && u.Path != "/" { + return nil, fmt.Errorf("relay path in %s must be instance root", cfg.URL) + } + + u = u.JoinPath("/api/call-upload") + + return &Relay{ + Relay: cfg, + url: u, + mgr: rs, + }, nil +} + +func (s *Relay) Call(ctx context.Context, call *calls.Call) error { + var buf bytes.Buffer + body := multipart.NewWriter(&buf) + + err := forms.Marshal(call, body) + if err != nil { + return fmt.Errorf("relay form parse: %w", err) + } + + err = body.WriteField("key", s.APIKey) + if err != nil { + return fmt.Errorf("relay set API key: %w", err) + } + + body.Close() + + r, err := http.NewRequestWithContext(ctx, http.MethodPost, s.url.String(), &buf) + if err != nil { + return fmt.Errorf("relay newrequest: %w", err) + } + + r.Header.Set("Content-Type", body.FormDataContentType()) + r.Header.Set("User-Agent", version.HttpString("call-relay")) + + resp, err := s.mgr.client.Do(r) + if err != nil { + return fmt.Errorf("relay %s: %w", s.Name, err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("relay %s: received HTTP %d", s.Name, resp.StatusCode) + } + + return nil +} + +func (s *Relay) SinkType() string { + return "relay" +} diff --git a/pkg/sinks/relay_test.go b/pkg/sinks/relay_test.go new file mode 100644 index 0000000..90cfda0 --- /dev/null +++ b/pkg/sinks/relay_test.go @@ -0,0 +1,107 @@ +package sinks + +import ( + "context" + "math/rand" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "dynatron.me/x/stillbox/internal/common" + "dynatron.me/x/stillbox/internal/forms" + "dynatron.me/x/stillbox/pkg/auth" + "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/config" + "dynatron.me/x/stillbox/pkg/sources" + + "github.com/google/uuid" +) + +type hand func(w http.ResponseWriter, r *http.Request) + +func (h hand) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h(w, r) +} + +func TestRelay(t *testing.T) { + uuid.SetRand(rand.New(rand.NewSource(1))) + + tests := []struct { + name string + submitter auth.UserID + apiKey string + call calls.Call + }{ + { + name: "base", + submitter: auth.UserID(1), + call: calls.Call{ + ID: uuid.UUID([16]byte{0x52, 0xfd, 0xfc, 0x07, 0x21, 0x82, 0x45, 0x4f, 0x96, 0x3f, 0x5f, 0x0f, 0x9a, 0x62, 0x1d, 0x72}), + Submitter: common.PtrTo(auth.UserID(1)), + System: 197, + Talkgroup: 10101, + DateTime: time.Date(2024, 11, 10, 23, 33, 02, 0, time.Local), + AudioName: "rightnow.mp3", + Audio: []byte{0xFF, 0xF3, 0x14, 0xC4, 0x00, 0x00, 0x00, 0x03, 0x48, 0x01, 0x40, 0x00, 0x00, 0x4C, 0x41, 0x4D, 0x45, 0x33, 0x2E, 0x39, 0x36, 0x2E, 0x31, 0x55}, + AudioType: "audio/mpeg", + Duration: calls.CallDuration(24000000), + TalkgroupLabel: common.PtrTo("Some TG"), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var serr error + var called bool + h := hand(func(w http.ResponseWriter, r *http.Request) { + called = true + assert.Equal(t, "/api/call-upload", r.URL.Path) + serr = r.ParseMultipartForm(1024 * 1024 * 2) + if serr != nil { + t.Log("parsemultipart", serr) + return + } + + cur := new(sources.CallUploadRequest) + serr = forms.Unmarshal(r, cur, forms.WithAcceptBlank()) + cur.DontStore = true + if serr != nil { + t.Log("unmarshal", serr) + return + } + + assert.Equal(t, tc.apiKey, cur.Key) + + toC, tcerr := cur.ToCall(tc.submitter) + require.NoError(t, tcerr) + assert.Equal(t, &tc.call, toC) + }) + svr := httptest.NewServer(h) + + cfg := config.Relay{ + URL: svr.URL, + APIKey: tc.apiKey, + } + ns := &nullSinks{} + + rm, err := NewRelayManager(ns, []config.Relay{cfg}) + require.NoError(t, err) + err = rm.relays[0].Call(context.Background(), &tc.call) + assert.True(t, called) + assert.NoError(t, err) + assert.NoError(t, serr) + }) + } +} + +type nullSinks struct{} + +func (*nullSinks) Register(name string, toAdd Sink, required bool) {} +func (*nullSinks) Unregister(name string) {} +func (*nullSinks) Shutdown() {} +func (*nullSinks) EmitCall(ctx context.Context, call *calls.Call) error { return nil } diff --git a/pkg/sinks/sinks.go b/pkg/sinks/sinks.go index 3c8afd9..8259f30 100644 --- a/pkg/sinks/sinks.go +++ b/pkg/sinks/sinks.go @@ -24,36 +24,55 @@ type sinkInstance struct { Required bool } -type Sinks struct { - sync.RWMutex - sinks []sinkInstance +type Sinks interface { + Register(name string, toAdd Sink, required bool) + Unregister(name string) + Shutdown() + EmitCall(ctx context.Context, call *calls.Call) error } -func (s *Sinks) Register(name string, toAdd Sink, required bool) { +type sinks struct { + sync.RWMutex + sinks map[string]sinkInstance +} + +func NewSinkManager() Sinks { + return &sinks{ + sinks: make(map[string]sinkInstance), + } +} + +func (s *sinks) Register(name string, toAdd Sink, required bool) { s.Lock() defer s.Unlock() - s.sinks = append(s.sinks, sinkInstance{ + s.sinks[name] = sinkInstance{ Name: name, Sink: toAdd, Required: required, - }) + } } -func (s *Sinks) Shutdown() { +func (s *sinks) Unregister(name string) { s.Lock() defer s.Unlock() - s.sinks = nil + delete(s.sinks, name) } -func (s *Sinks) EmitCall(ctx context.Context, call *calls.Call) error { +func (s *sinks) Shutdown() { + s.Lock() + defer s.Unlock() + + clear(s.sinks) +} + +func (s *sinks) EmitCall(ctx context.Context, call *calls.Call) error { s.Lock() defer s.Unlock() g, ctx := errgroup.WithContext(ctx) - for i := range s.sinks { - sink := s.sinks[i] + for _, sink := range s.sinks { g.Go(sink.callEmitter(ctx, call)) } diff --git a/pkg/sources/http.go b/pkg/sources/http.go index 32d1a28..7c30347 100644 --- a/pkg/sources/http.go +++ b/pkg/sources/http.go @@ -36,7 +36,7 @@ func (h *RdioHTTP) InstallPublicRoutes(r chi.Router) { r.Post("/api/call-upload", h.routeCallUpload) } -type callUploadRequest struct { +type CallUploadRequest struct { Audio []byte `form:"audio" filenameField:"AudioName"` AudioName string AudioType string `form:"audioType"` @@ -56,7 +56,7 @@ type callUploadRequest struct { DontStore bool `form:"dontStore"` } -func (car *callUploadRequest) mimeType() string { +func (car *CallUploadRequest) mimeType() string { // this is super naïve fn := car.AudioName switch { @@ -71,7 +71,7 @@ func (car *callUploadRequest) mimeType() string { return "" } -func (car *callUploadRequest) toCall(submitter auth.UserID) (*calls.Call, error) { +func (car *CallUploadRequest) ToCall(submitter auth.UserID) (*calls.Call, error) { return calls.Make(&calls.Call{ Submitter: &submitter, System: car.System, @@ -111,14 +111,14 @@ func (h *RdioHTTP) routeCallUpload(w http.ResponseWriter, r *http.Request) { return } - cur := new(callUploadRequest) + cur := new(CallUploadRequest) err = forms.Unmarshal(r, cur, forms.WithAcceptBlank()) if err != nil { http.Error(w, "cannot bind upload "+err.Error(), http.StatusExpectationFailed) return } - call, err := cur.toCall(*submitter) + call, err := cur.ToCall(*submitter) if err != nil { log.Error().Err(err).Msg("toCall failed") http.Error(w, err.Error(), http.StatusBadRequest)