diff --git a/internal/forms/forms.go b/internal/forms/forms.go index f8a7f7a..8f0be96 100644 --- a/internal/forms/forms.go +++ b/internal/forms/forms.go @@ -1,19 +1,8 @@ package forms import ( - "encoding/json" "errors" - "fmt" - "io" - "net/http" - "reflect" - "strconv" - "strings" "time" - - "dynatron.me/x/stillbox/internal/jsontypes" - - "github.com/araddon/dateparse" ) var ( @@ -80,293 +69,3 @@ func (o *options) Tag() string { return "form" } - -func (o *options) parseTime(s string, dpo ...dateparse.ParserOption) (t time.Time, set bool, err error) { - if o.acceptBlank && s == "" { - set = false - return - } - - if iv, err := strconv.Atoi(s); err == nil { - return time.Unix(int64(iv), 0), true, nil - } - - switch { - case o.parseTimeIn != nil: - t, err = dateparse.ParseIn(s, o.parseTimeIn, dpo...) - case o.parseLocal: - t, err = dateparse.ParseLocal(s, dpo...) - default: - t, err = dateparse.ParseAny(s, dpo...) - } - - set = true - - return -} - -func (o *options) parseBool(s string) (v bool, set bool, err error) { - if o.acceptBlank && s == "" { - set = false - return - } - - set = true - - v, err = strconv.ParseBool(s) - if err != nil { - return v, set, fmt.Errorf("parsebool('%s'): %w", s, err) - } - - return -} - -func (o *options) parseInt(s string) (v int, set bool, err error) { - if o.acceptBlank && s == "" { - set = false - return - } - set = true - - v, err = strconv.Atoi(s) - if err != nil { - return v, set, fmt.Errorf("atoi('%s'): %w", s, err) - } - - return -} - -func (o *options) parseFloat64(s string) (v float64, set bool, err error) { - if o.acceptBlank && s == "" { - set = false - return - } - set = true - - v, err = strconv.ParseFloat(s, 64) - if err != nil { - return v, set, fmt.Errorf("ParseFloat('%s'): %w", s, err) - } - - return -} - -func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) { - if o.acceptBlank && s == "" { - set = false - return - } - - set = true - - v, err = time.ParseDuration(s) - if err != nil { - return v, set, fmt.Errorf("ParseDuration('%s'): %w", s, err) - } - - return -} - -var typeOfByteSlice = reflect.TypeOf([]byte(nil)) - -func (o *options) iterFields(r *http.Request, destStruct reflect.Value) error { - structType := destStruct.Type() - for i := 0; i < destStruct.NumField(); i++ { - destFieldVal := destStruct.Field(i) - fieldType := structType.Field(i) - if !fieldType.IsExported() && !fieldType.Anonymous { - continue - } - - if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous { - err := o.iterFields(r, destFieldVal) - if err != nil { - return err - } - } - - var tAr []string - var formField string - var omitEmpty bool - if o.defaultOmitEmpty { - omitEmpty = true - } - - formTag, has := structType.Field(i).Tag.Lookup(o.Tag()) - if has { - tAr = strings.Split(formTag, ",") - formField = tAr[0] - for _, v := range tAr[1:] { - if v == "omitempty" { - omitEmpty = true - break - } - } - } - - if !has || formField == "-" { - continue - } - - destFieldIntf := destFieldVal.Interface() - - if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice { - file, hdr, err := r.FormFile(formField) - if err != nil { - return fmt.Errorf("get form file: %w", err) - } - - nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField") - if hasFilename { - fnf := destStruct.FieldByName(nameField) - if fnf == (reflect.Value{}) { - panic(fmt.Errorf("filenameField '%s' does not exist", nameField)) - } - - fnf.SetString(hdr.Filename) - } - audioBytes, err := io.ReadAll(file) - if err != nil { - return fmt.Errorf("file read: %w", err) - } - - destFieldVal.SetBytes(audioBytes) - - continue - } - - if !r.Form.Has(formField) && omitEmpty { - continue - } - - ff := r.Form.Get(formField) - - switch v := destFieldIntf.(type) { - case string, *string: - setVal(destFieldVal, ff != "" || o.acceptBlank, ff) - case int, uint, *int, *uint: - val, set, err := o.parseInt(ff) - if err != nil { - return err - } - setVal(destFieldVal, set, val) - case float64: - val, set, err := o.parseFloat64(ff) - if err != nil { - return err - } - setVal(destFieldVal, set, val) - case bool, *bool: - val, set, err := o.parseBool(ff) - if err != nil { - return err - } - setVal(destFieldVal, set, val) - case time.Time, *time.Time, jsontypes.Time, *jsontypes.Time: - t, set, err := o.parseTime(ff) - if err != nil { - return err - } - setVal(destFieldVal, set, t) - case time.Duration, *time.Duration, jsontypes.Duration, *jsontypes.Duration: - d, set, err := o.parseDuration(ff) - if err != nil { - return err - } - setVal(destFieldVal, set, d) - case []int: - val := strings.Trim(ff, "[]") - if val == "" && o.acceptBlank { - continue - } - vals := strings.Split(val, ",") - ar := make([]int, 0, len(vals)) - for _, v := range vals { - i, err := strconv.Atoi(v) - if err == nil { - ar = append(ar, i) - } - } - destFieldVal.Set(reflect.ValueOf(ar)) - default: - panic(fmt.Errorf("unsupported type %T", v)) - } - } - - return nil -} - -func setVal(destFieldVal reflect.Value, set bool, src any) { - if !set { - return - } - - destType := destFieldVal.Type() - srcVal := reflect.ValueOf(src) - - if srcVal.Kind() == reflect.Ptr { - srcVal = srcVal.Elem() - } - - if destType.Kind() == reflect.Ptr { - if !srcVal.CanAddr() { - if srcVal.CanConvert(destType.Elem()) { - srcVal = srcVal.Convert(destType.Elem()) - } - copy := reflect.New(srcVal.Type()) - copy.Elem().Set(srcVal) - srcVal = copy - } - } else if srcVal.CanConvert(destFieldVal.Type()) { - srcVal = srcVal.Convert(destFieldVal.Type()) - } - - destFieldVal.Set(srcVal) -} - -func Unmarshal(r *http.Request, dest any, opt ...Option) error { - o := options{ - maxMultipartMemory: MaxMultipartMemory, - } - - for _, opt := range opt { - opt(&o) - } - - contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0] - - switch contentType { - case "multipart/form-data": - err := r.ParseMultipartForm(o.maxMultipartMemory) - if err != nil { - return fmt.Errorf("ParseForm: %w", err) - } - - return o.unmarshalForm(r, dest) - case "application/x-www-form-urlencoded": - err := r.ParseForm() - if err != nil { - return fmt.Errorf("ParseForm: %w", err) - } - return o.unmarshalForm(r, dest) - case "application/json": - return json.NewDecoder(r.Body).Decode(dest) - } - - return ErrContentType -} - -func (o *options) unmarshalForm(r *http.Request, dest any) error { - destVal := reflect.ValueOf(dest) - if k := destVal.Kind(); k == reflect.Ptr { - destVal = destVal.Elem() - } else { - return ErrNotPointer - } - - if destVal.Kind() != reflect.Struct { - return ErrNotStruct - } - - return o.iterFields(r, destVal) -} diff --git a/internal/forms/marshal.go b/internal/forms/marshal.go new file mode 100644 index 0000000..dedd9ad --- /dev/null +++ b/internal/forms/marshal.go @@ -0,0 +1,139 @@ +package forms + +import ( + "fmt" + "mime/multipart" + "reflect" + "strconv" + "strings" + "time" +) + +func Marshal(src any, dest *multipart.Writer, opts ...Option) error { + o := options{} + + for _, opt := range opts { + opt(&o) + } + + return o.marshalMultipartForm(src, dest) +} + +func (o *options) marshalMultipartForm(src any, dest *multipart.Writer) error { + srcVal := reflect.ValueOf(src) + if k := srcVal.Kind(); k == reflect.Ptr { + srcVal = srcVal.Elem() + } + + if srcVal.Kind() != reflect.Struct { + return ErrNotStruct + } + + return o.marIterFields(srcVal, dest) +} + +func (o *options) marIterFields(srcVal reflect.Value, dest *multipart.Writer) error { + structType := srcVal.Type() + for i := 0; i < structType.NumField(); i++ { + srcFieldVal := srcVal.Field(i) + fieldType := structType.Field(i) + if !fieldType.IsExported() && !fieldType.Anonymous { + continue + } + + if srcFieldVal.Kind() == reflect.Struct && fieldType.Anonymous { + continue + } + + var tAr []string + var formField string + var omitEmpty bool + if o.defaultOmitEmpty { + omitEmpty = true + } + + formTag, has := structType.Field(i).Tag.Lookup(o.Tag()) + if has { + tAr = strings.Split(formTag, ",") + formField = tAr[0] + for _, v := range tAr[1:] { + if v == "omitempty" { + omitEmpty = true + break + } + } + } + + if !has || formField == "-" { + continue + } + + if srcFieldVal.Kind() == reflect.Ptr { + srcFieldVal = srcFieldVal.Elem() + if srcFieldVal == (reflect.Value{}) || srcFieldVal.IsZero() { + continue + } + } + + srcFieldIntf := srcFieldVal.Interface() + if srcFieldVal.Kind() == reflect.Slice && srcFieldVal.Type() == typeOfByteSlice { + nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField") + fileName := "" + if hasFilename { + fnf := srcVal.FieldByName(nameField) + if fnf == (reflect.Value{}) { + panic(fmt.Errorf("filenameField '%s' does not exist", nameField)) + } + + fileName = fnf.String() + } + + fw, err := dest.CreateFormFile(formField, fileName) + if err != nil { + return fmt.Errorf("form marshal: createFormFile: %w", err) + } + + _, err = fw.Write(srcFieldVal.Bytes()) + if err != nil { + return fmt.Errorf("form marshal: write file: %w", err) + } + + continue + } + + if srcFieldVal.IsZero() && omitEmpty { + continue + } + + var val string + switch v := srcFieldIntf.(type) { + case []string: + if omitEmpty && len(v) == 0 { + continue + } + val = "[" + strings.Join(v, ",") + "]" + case []int: + if omitEmpty && len(v) == 0 { + continue + } + + sl := make([]string, len(v)) + for i := range v { + sl[i] = strconv.Itoa(v[i]) + } + + val = "[" + strings.Join(sl, ",") + "]" + case time.Time: + val = strconv.Itoa(int(v.Unix())) + default: + val = fmt.Sprint(v) + } + + err := dest.WriteField(formField, val) + if err != nil { + return fmt.Errorf("marshal field '%s': %w", formField, err) + } + } + + return nil +} diff --git a/internal/forms/marshal_test.go b/internal/forms/marshal_test.go new file mode 100644 index 0000000..6004218 --- /dev/null +++ b/internal/forms/marshal_test.go @@ -0,0 +1,121 @@ +package forms_test + +import ( + "bytes" + "fmt" + "math/rand" + "mime/multipart" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "dynatron.me/x/stillbox/internal/common" + "dynatron.me/x/stillbox/internal/forms" + "dynatron.me/x/stillbox/pkg/auth" + "dynatron.me/x/stillbox/pkg/calls" + "dynatron.me/x/stillbox/pkg/sources" + + "github.com/google/uuid" +) + +type hand func(w http.ResponseWriter, r *http.Request) + +func (h hand) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h(w, r) +} + +func call(url string, call *calls.Call) error { + var buf bytes.Buffer + body := multipart.NewWriter(&buf) + + err := forms.Marshal(call, body) + if err != nil { + return fmt.Errorf("relay form parse: %w", err) + } + body.Close() + + r, err := http.NewRequest(http.MethodPost, url, &buf) + if err != nil { + return fmt.Errorf("relay newrequest: %w", err) + } + + r.Header.Set("Content-Type", body.FormDataContentType()) + + resp, err := http.DefaultClient.Do(r) + if err != nil { + return fmt.Errorf("relay: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("relay: received HTTP %d", resp.StatusCode) + } + + return nil +} + +func TestMarshal(t *testing.T) { + uuid.SetRand(rand.New(rand.NewSource(1))) + + tests := []struct { + name string + submitter auth.UserID + apiKey string + call calls.Call + }{ + { + name: "base", + submitter: auth.UserID(1), + call: calls.Call{ + ID: uuid.UUID([16]byte{0x52, 0xfd, 0xfc, 0x07, 0x21, 0x82, 0x45, 0x4f, 0x96, 0x3f, 0x5f, 0x0f, 0x9a, 0x62, 0x1d, 0x72}), + Submitter: common.PtrTo(auth.UserID(1)), + System: 197, + Talkgroup: 10101, + DateTime: time.Date(2024, 11, 10, 23, 33, 02, 0, time.Local), + AudioName: "rightnow.mp3", + Audio: []byte{0xFF, 0xF3, 0x14, 0xC4, 0x00, 0x00, 0x00, 0x03, 0x48, 0x01, 0x40, 0x00, 0x00, 0x4C, 0x41, 0x4D, 0x45, 0x33, 0x2E, 0x39, 0x36, 0x2E, 0x31, 0x55}, + AudioType: "audio/mpeg", + Duration: calls.CallDuration(24000000), + TalkgroupLabel: common.PtrTo("Some TG"), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var serr error + var called bool + h := hand(func(w http.ResponseWriter, r *http.Request) { + called = true + serr = r.ParseMultipartForm(1024 * 1024 * 2) + if serr != nil { + t.Log("parsemultipart", serr) + return + } + + cur := new(sources.CallUploadRequest) + serr = forms.Unmarshal(r, cur, forms.WithAcceptBlank()) + cur.DontStore = true + if serr != nil { + t.Log("unmarshal", serr) + return + } + + assert.Equal(t, tc.apiKey, cur.Key) + + toC, tcerr := cur.ToCall(tc.submitter) + require.NoError(t, tcerr) + assert.Equal(t, &tc.call, toC) + }) + svr := httptest.NewServer(h) + + err := call(svr.URL, &tc.call) + assert.True(t, called) + assert.NoError(t, err) + assert.NoError(t, serr) + }) + } +} diff --git a/internal/forms/unmarshal.go b/internal/forms/unmarshal.go new file mode 100644 index 0000000..4b8f568 --- /dev/null +++ b/internal/forms/unmarshal.go @@ -0,0 +1,306 @@ +package forms + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "reflect" + "strconv" + "strings" + "time" + + "dynatron.me/x/stillbox/internal/jsontypes" + + "github.com/araddon/dateparse" +) + +func (o *options) parseTime(s string, dpo ...dateparse.ParserOption) (t time.Time, set bool, err error) { + if o.acceptBlank && s == "" { + set = false + return + } + + if iv, err := strconv.Atoi(s); err == nil { + return time.Unix(int64(iv), 0), true, nil + } + + switch { + case o.parseTimeIn != nil: + t, err = dateparse.ParseIn(s, o.parseTimeIn, dpo...) + case o.parseLocal: + t, err = dateparse.ParseLocal(s, dpo...) + default: + t, err = dateparse.ParseAny(s, dpo...) + } + + set = true + + return +} + +func (o *options) parseBool(s string) (v bool, set bool, err error) { + if o.acceptBlank && s == "" { + set = false + return + } + + set = true + + v, err = strconv.ParseBool(s) + if err != nil { + return v, set, fmt.Errorf("parsebool('%s'): %w", s, err) + } + + return +} + +func (o *options) parseInt(s string) (v int, set bool, err error) { + if o.acceptBlank && s == "" { + set = false + return + } + set = true + + v, err = strconv.Atoi(s) + if err != nil { + return v, set, fmt.Errorf("atoi('%s'): %w", s, err) + } + + return +} + +func (o *options) parseFloat64(s string) (v float64, set bool, err error) { + if o.acceptBlank && s == "" { + set = false + return + } + set = true + + v, err = strconv.ParseFloat(s, 64) + if err != nil { + return v, set, fmt.Errorf("ParseFloat('%s'): %w", s, err) + } + + return +} + +func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) { + if o.acceptBlank && s == "" { + set = false + return + } + + set = true + + v, err = time.ParseDuration(s) + if err != nil { + return v, set, fmt.Errorf("ParseDuration('%s'): %w", s, err) + } + + return +} + +var typeOfByteSlice = reflect.TypeOf([]byte(nil)) + +func (o *options) unmIterFields(r *http.Request, destStruct reflect.Value) error { + structType := destStruct.Type() + for i := 0; i < destStruct.NumField(); i++ { + destFieldVal := destStruct.Field(i) + fieldType := structType.Field(i) + if !fieldType.IsExported() && !fieldType.Anonymous { + continue + } + + if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous { + err := o.unmIterFields(r, destFieldVal) + if err != nil { + return err + } + } + + var tAr []string + var formField string + var omitEmpty bool + if o.defaultOmitEmpty { + omitEmpty = true + } + + formTag, has := structType.Field(i).Tag.Lookup(o.Tag()) + if has { + tAr = strings.Split(formTag, ",") + formField = tAr[0] + for _, v := range tAr[1:] { + if v == "omitempty" { + omitEmpty = true + break + } + } + } + + if !has || formField == "-" { + continue + } + + destFieldIntf := destFieldVal.Interface() + + if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice { + file, hdr, err := r.FormFile(formField) + if err != nil { + return fmt.Errorf("get form file: %w", err) + } + + nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField") + if hasFilename { + fnf := destStruct.FieldByName(nameField) + if fnf == (reflect.Value{}) { + panic(fmt.Errorf("filenameField '%s' does not exist", nameField)) + } + + fnf.SetString(hdr.Filename) + } + audioBytes, err := io.ReadAll(file) + if err != nil { + return fmt.Errorf("file read: %w", err) + } + + destFieldVal.SetBytes(audioBytes) + + continue + } + + if !r.Form.Has(formField) && omitEmpty { + continue + } + + ff := r.Form.Get(formField) + + switch v := destFieldIntf.(type) { + case string, *string: + setVal(destFieldVal, ff != "" || o.acceptBlank, ff) + case int, uint, *int, *uint: + val, set, err := o.parseInt(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case float64: + val, set, err := o.parseFloat64(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case bool, *bool: + val, set, err := o.parseBool(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, val) + case time.Time, *time.Time, jsontypes.Time, *jsontypes.Time: + t, set, err := o.parseTime(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, t) + case time.Duration, *time.Duration, jsontypes.Duration, *jsontypes.Duration: + d, set, err := o.parseDuration(ff) + if err != nil { + return err + } + setVal(destFieldVal, set, d) + case []int: + val := strings.Trim(ff, "[]") + if val == "" && o.acceptBlank { + continue + } + vals := strings.Split(val, ",") + ar := make([]int, 0, len(vals)) + for _, v := range vals { + i, err := strconv.Atoi(v) + if err == nil { + ar = append(ar, i) + } + } + destFieldVal.Set(reflect.ValueOf(ar)) + default: + panic(fmt.Errorf("unsupported type %T", v)) + } + } + + return nil +} + +func setVal(destFieldVal reflect.Value, set bool, src any) { + if !set { + return + } + + destType := destFieldVal.Type() + srcVal := reflect.ValueOf(src) + + if srcVal.Kind() == reflect.Ptr { + srcVal = srcVal.Elem() + } + + if destType.Kind() == reflect.Ptr { + if !srcVal.CanAddr() { + if srcVal.CanConvert(destType.Elem()) { + srcVal = srcVal.Convert(destType.Elem()) + } + copy := reflect.New(srcVal.Type()) + copy.Elem().Set(srcVal) + srcVal = copy + } + } else if srcVal.CanConvert(destFieldVal.Type()) { + srcVal = srcVal.Convert(destFieldVal.Type()) + } + + destFieldVal.Set(srcVal) +} + +func Unmarshal(r *http.Request, dest any, opts ...Option) error { + o := options{ + maxMultipartMemory: MaxMultipartMemory, + } + + for _, opt := range opts { + opt(&o) + } + + contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0] + + switch contentType { + case "multipart/form-data": + err := r.ParseMultipartForm(o.maxMultipartMemory) + if err != nil { + return fmt.Errorf("ParseForm: %w", err) + } + + return o.unmarshalForm(r, dest) + case "application/x-www-form-urlencoded": + err := r.ParseForm() + if err != nil { + return fmt.Errorf("ParseForm: %w", err) + } + return o.unmarshalForm(r, dest) + case "application/json": + return json.NewDecoder(r.Body).Decode(dest) + } + + return ErrContentType +} + +func (o *options) unmarshalForm(r *http.Request, dest any) error { + destVal := reflect.ValueOf(dest) + if k := destVal.Kind(); k == reflect.Ptr { + destVal = destVal.Elem() + } else { + return ErrNotPointer + } + + if destVal.Kind() != reflect.Struct { + return ErrNotStruct + } + + return o.unmIterFields(r, destVal) +} diff --git a/internal/forms/forms_test.go b/internal/forms/unmarshal_test.go similarity index 100% rename from internal/forms/forms_test.go rename to internal/forms/unmarshal_test.go