add marshal
This commit is contained in:
parent
4eddd7d1df
commit
a318242de2
5 changed files with 566 additions and 301 deletions
|
@ -1,19 +1,8 @@
|
|||
package forms
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"dynatron.me/x/stillbox/internal/jsontypes"
|
||||
|
||||
"github.com/araddon/dateparse"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -80,293 +69,3 @@ func (o *options) Tag() string {
|
|||
|
||||
return "form"
|
||||
}
|
||||
|
||||
func (o *options) parseTime(s string, dpo ...dateparse.ParserOption) (t time.Time, set bool, err error) {
|
||||
if o.acceptBlank && s == "" {
|
||||
set = false
|
||||
return
|
||||
}
|
||||
|
||||
if iv, err := strconv.Atoi(s); err == nil {
|
||||
return time.Unix(int64(iv), 0), true, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case o.parseTimeIn != nil:
|
||||
t, err = dateparse.ParseIn(s, o.parseTimeIn, dpo...)
|
||||
case o.parseLocal:
|
||||
t, err = dateparse.ParseLocal(s, dpo...)
|
||||
default:
|
||||
t, err = dateparse.ParseAny(s, dpo...)
|
||||
}
|
||||
|
||||
set = true
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (o *options) parseBool(s string) (v bool, set bool, err error) {
|
||||
if o.acceptBlank && s == "" {
|
||||
set = false
|
||||
return
|
||||
}
|
||||
|
||||
set = true
|
||||
|
||||
v, err = strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return v, set, fmt.Errorf("parsebool('%s'): %w", s, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (o *options) parseInt(s string) (v int, set bool, err error) {
|
||||
if o.acceptBlank && s == "" {
|
||||
set = false
|
||||
return
|
||||
}
|
||||
set = true
|
||||
|
||||
v, err = strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return v, set, fmt.Errorf("atoi('%s'): %w", s, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (o *options) parseFloat64(s string) (v float64, set bool, err error) {
|
||||
if o.acceptBlank && s == "" {
|
||||
set = false
|
||||
return
|
||||
}
|
||||
set = true
|
||||
|
||||
v, err = strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return v, set, fmt.Errorf("ParseFloat('%s'): %w", s, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) {
|
||||
if o.acceptBlank && s == "" {
|
||||
set = false
|
||||
return
|
||||
}
|
||||
|
||||
set = true
|
||||
|
||||
v, err = time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return v, set, fmt.Errorf("ParseDuration('%s'): %w", s, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var typeOfByteSlice = reflect.TypeOf([]byte(nil))
|
||||
|
||||
func (o *options) iterFields(r *http.Request, destStruct reflect.Value) error {
|
||||
structType := destStruct.Type()
|
||||
for i := 0; i < destStruct.NumField(); i++ {
|
||||
destFieldVal := destStruct.Field(i)
|
||||
fieldType := structType.Field(i)
|
||||
if !fieldType.IsExported() && !fieldType.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous {
|
||||
err := o.iterFields(r, destFieldVal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var tAr []string
|
||||
var formField string
|
||||
var omitEmpty bool
|
||||
if o.defaultOmitEmpty {
|
||||
omitEmpty = true
|
||||
}
|
||||
|
||||
formTag, has := structType.Field(i).Tag.Lookup(o.Tag())
|
||||
if has {
|
||||
tAr = strings.Split(formTag, ",")
|
||||
formField = tAr[0]
|
||||
for _, v := range tAr[1:] {
|
||||
if v == "omitempty" {
|
||||
omitEmpty = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !has || formField == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
destFieldIntf := destFieldVal.Interface()
|
||||
|
||||
if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice {
|
||||
file, hdr, err := r.FormFile(formField)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get form file: %w", err)
|
||||
}
|
||||
|
||||
nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField")
|
||||
if hasFilename {
|
||||
fnf := destStruct.FieldByName(nameField)
|
||||
if fnf == (reflect.Value{}) {
|
||||
panic(fmt.Errorf("filenameField '%s' does not exist", nameField))
|
||||
}
|
||||
|
||||
fnf.SetString(hdr.Filename)
|
||||
}
|
||||
audioBytes, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("file read: %w", err)
|
||||
}
|
||||
|
||||
destFieldVal.SetBytes(audioBytes)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !r.Form.Has(formField) && omitEmpty {
|
||||
continue
|
||||
}
|
||||
|
||||
ff := r.Form.Get(formField)
|
||||
|
||||
switch v := destFieldIntf.(type) {
|
||||
case string, *string:
|
||||
setVal(destFieldVal, ff != "" || o.acceptBlank, ff)
|
||||
case int, uint, *int, *uint:
|
||||
val, set, err := o.parseInt(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, val)
|
||||
case float64:
|
||||
val, set, err := o.parseFloat64(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, val)
|
||||
case bool, *bool:
|
||||
val, set, err := o.parseBool(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, val)
|
||||
case time.Time, *time.Time, jsontypes.Time, *jsontypes.Time:
|
||||
t, set, err := o.parseTime(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, t)
|
||||
case time.Duration, *time.Duration, jsontypes.Duration, *jsontypes.Duration:
|
||||
d, set, err := o.parseDuration(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, d)
|
||||
case []int:
|
||||
val := strings.Trim(ff, "[]")
|
||||
if val == "" && o.acceptBlank {
|
||||
continue
|
||||
}
|
||||
vals := strings.Split(val, ",")
|
||||
ar := make([]int, 0, len(vals))
|
||||
for _, v := range vals {
|
||||
i, err := strconv.Atoi(v)
|
||||
if err == nil {
|
||||
ar = append(ar, i)
|
||||
}
|
||||
}
|
||||
destFieldVal.Set(reflect.ValueOf(ar))
|
||||
default:
|
||||
panic(fmt.Errorf("unsupported type %T", v))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setVal(destFieldVal reflect.Value, set bool, src any) {
|
||||
if !set {
|
||||
return
|
||||
}
|
||||
|
||||
destType := destFieldVal.Type()
|
||||
srcVal := reflect.ValueOf(src)
|
||||
|
||||
if srcVal.Kind() == reflect.Ptr {
|
||||
srcVal = srcVal.Elem()
|
||||
}
|
||||
|
||||
if destType.Kind() == reflect.Ptr {
|
||||
if !srcVal.CanAddr() {
|
||||
if srcVal.CanConvert(destType.Elem()) {
|
||||
srcVal = srcVal.Convert(destType.Elem())
|
||||
}
|
||||
copy := reflect.New(srcVal.Type())
|
||||
copy.Elem().Set(srcVal)
|
||||
srcVal = copy
|
||||
}
|
||||
} else if srcVal.CanConvert(destFieldVal.Type()) {
|
||||
srcVal = srcVal.Convert(destFieldVal.Type())
|
||||
}
|
||||
|
||||
destFieldVal.Set(srcVal)
|
||||
}
|
||||
|
||||
func Unmarshal(r *http.Request, dest any, opt ...Option) error {
|
||||
o := options{
|
||||
maxMultipartMemory: MaxMultipartMemory,
|
||||
}
|
||||
|
||||
for _, opt := range opt {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0]
|
||||
|
||||
switch contentType {
|
||||
case "multipart/form-data":
|
||||
err := r.ParseMultipartForm(o.maxMultipartMemory)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ParseForm: %w", err)
|
||||
}
|
||||
|
||||
return o.unmarshalForm(r, dest)
|
||||
case "application/x-www-form-urlencoded":
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ParseForm: %w", err)
|
||||
}
|
||||
return o.unmarshalForm(r, dest)
|
||||
case "application/json":
|
||||
return json.NewDecoder(r.Body).Decode(dest)
|
||||
}
|
||||
|
||||
return ErrContentType
|
||||
}
|
||||
|
||||
func (o *options) unmarshalForm(r *http.Request, dest any) error {
|
||||
destVal := reflect.ValueOf(dest)
|
||||
if k := destVal.Kind(); k == reflect.Ptr {
|
||||
destVal = destVal.Elem()
|
||||
} else {
|
||||
return ErrNotPointer
|
||||
}
|
||||
|
||||
if destVal.Kind() != reflect.Struct {
|
||||
return ErrNotStruct
|
||||
}
|
||||
|
||||
return o.iterFields(r, destVal)
|
||||
}
|
||||
|
|
139
internal/forms/marshal.go
Normal file
139
internal/forms/marshal.go
Normal file
|
@ -0,0 +1,139 @@
|
|||
package forms
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Marshal(src any, dest *multipart.Writer, opts ...Option) error {
|
||||
o := options{}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
return o.marshalMultipartForm(src, dest)
|
||||
}
|
||||
|
||||
func (o *options) marshalMultipartForm(src any, dest *multipart.Writer) error {
|
||||
srcVal := reflect.ValueOf(src)
|
||||
if k := srcVal.Kind(); k == reflect.Ptr {
|
||||
srcVal = srcVal.Elem()
|
||||
}
|
||||
|
||||
if srcVal.Kind() != reflect.Struct {
|
||||
return ErrNotStruct
|
||||
}
|
||||
|
||||
return o.marIterFields(srcVal, dest)
|
||||
}
|
||||
|
||||
func (o *options) marIterFields(srcVal reflect.Value, dest *multipart.Writer) error {
|
||||
structType := srcVal.Type()
|
||||
for i := 0; i < structType.NumField(); i++ {
|
||||
srcFieldVal := srcVal.Field(i)
|
||||
fieldType := structType.Field(i)
|
||||
if !fieldType.IsExported() && !fieldType.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
if srcFieldVal.Kind() == reflect.Struct && fieldType.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
var tAr []string
|
||||
var formField string
|
||||
var omitEmpty bool
|
||||
if o.defaultOmitEmpty {
|
||||
omitEmpty = true
|
||||
}
|
||||
|
||||
formTag, has := structType.Field(i).Tag.Lookup(o.Tag())
|
||||
if has {
|
||||
tAr = strings.Split(formTag, ",")
|
||||
formField = tAr[0]
|
||||
for _, v := range tAr[1:] {
|
||||
if v == "omitempty" {
|
||||
omitEmpty = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !has || formField == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
if srcFieldVal.Kind() == reflect.Ptr {
|
||||
srcFieldVal = srcFieldVal.Elem()
|
||||
if srcFieldVal == (reflect.Value{}) || srcFieldVal.IsZero() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
srcFieldIntf := srcFieldVal.Interface()
|
||||
if srcFieldVal.Kind() == reflect.Slice && srcFieldVal.Type() == typeOfByteSlice {
|
||||
nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField")
|
||||
fileName := ""
|
||||
if hasFilename {
|
||||
fnf := srcVal.FieldByName(nameField)
|
||||
if fnf == (reflect.Value{}) {
|
||||
panic(fmt.Errorf("filenameField '%s' does not exist", nameField))
|
||||
}
|
||||
|
||||
fileName = fnf.String()
|
||||
}
|
||||
|
||||
fw, err := dest.CreateFormFile(formField, fileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("form marshal: createFormFile: %w", err)
|
||||
}
|
||||
|
||||
_, err = fw.Write(srcFieldVal.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("form marshal: write file: %w", err)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if srcFieldVal.IsZero() && omitEmpty {
|
||||
continue
|
||||
}
|
||||
|
||||
var val string
|
||||
switch v := srcFieldIntf.(type) {
|
||||
case []string:
|
||||
if omitEmpty && len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
val = "[" + strings.Join(v, ",") + "]"
|
||||
case []int:
|
||||
if omitEmpty && len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
sl := make([]string, len(v))
|
||||
for i := range v {
|
||||
sl[i] = strconv.Itoa(v[i])
|
||||
}
|
||||
|
||||
val = "[" + strings.Join(sl, ",") + "]"
|
||||
case time.Time:
|
||||
val = strconv.Itoa(int(v.Unix()))
|
||||
default:
|
||||
val = fmt.Sprint(v)
|
||||
}
|
||||
|
||||
err := dest.WriteField(formField, val)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal field '%s': %w", formField, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
121
internal/forms/marshal_test.go
Normal file
121
internal/forms/marshal_test.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
package forms_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"dynatron.me/x/stillbox/internal/common"
|
||||
"dynatron.me/x/stillbox/internal/forms"
|
||||
"dynatron.me/x/stillbox/pkg/auth"
|
||||
"dynatron.me/x/stillbox/pkg/calls"
|
||||
"dynatron.me/x/stillbox/pkg/sources"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type hand func(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
func (h hand) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h(w, r)
|
||||
}
|
||||
|
||||
func call(url string, call *calls.Call) error {
|
||||
var buf bytes.Buffer
|
||||
body := multipart.NewWriter(&buf)
|
||||
|
||||
err := forms.Marshal(call, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("relay form parse: %w", err)
|
||||
}
|
||||
body.Close()
|
||||
|
||||
r, err := http.NewRequest(http.MethodPost, url, &buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("relay newrequest: %w", err)
|
||||
}
|
||||
|
||||
r.Header.Set("Content-Type", body.FormDataContentType())
|
||||
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("relay: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("relay: received HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMarshal(t *testing.T) {
|
||||
uuid.SetRand(rand.New(rand.NewSource(1)))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
submitter auth.UserID
|
||||
apiKey string
|
||||
call calls.Call
|
||||
}{
|
||||
{
|
||||
name: "base",
|
||||
submitter: auth.UserID(1),
|
||||
call: calls.Call{
|
||||
ID: uuid.UUID([16]byte{0x52, 0xfd, 0xfc, 0x07, 0x21, 0x82, 0x45, 0x4f, 0x96, 0x3f, 0x5f, 0x0f, 0x9a, 0x62, 0x1d, 0x72}),
|
||||
Submitter: common.PtrTo(auth.UserID(1)),
|
||||
System: 197,
|
||||
Talkgroup: 10101,
|
||||
DateTime: time.Date(2024, 11, 10, 23, 33, 02, 0, time.Local),
|
||||
AudioName: "rightnow.mp3",
|
||||
Audio: []byte{0xFF, 0xF3, 0x14, 0xC4, 0x00, 0x00, 0x00, 0x03, 0x48, 0x01, 0x40, 0x00, 0x00, 0x4C, 0x41, 0x4D, 0x45, 0x33, 0x2E, 0x39, 0x36, 0x2E, 0x31, 0x55},
|
||||
AudioType: "audio/mpeg",
|
||||
Duration: calls.CallDuration(24000000),
|
||||
TalkgroupLabel: common.PtrTo("Some TG"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var serr error
|
||||
var called bool
|
||||
h := hand(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
serr = r.ParseMultipartForm(1024 * 1024 * 2)
|
||||
if serr != nil {
|
||||
t.Log("parsemultipart", serr)
|
||||
return
|
||||
}
|
||||
|
||||
cur := new(sources.CallUploadRequest)
|
||||
serr = forms.Unmarshal(r, cur, forms.WithAcceptBlank())
|
||||
cur.DontStore = true
|
||||
if serr != nil {
|
||||
t.Log("unmarshal", serr)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.apiKey, cur.Key)
|
||||
|
||||
toC, tcerr := cur.ToCall(tc.submitter)
|
||||
require.NoError(t, tcerr)
|
||||
assert.Equal(t, &tc.call, toC)
|
||||
})
|
||||
svr := httptest.NewServer(h)
|
||||
|
||||
err := call(svr.URL, &tc.call)
|
||||
assert.True(t, called)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, serr)
|
||||
})
|
||||
}
|
||||
}
|
306
internal/forms/unmarshal.go
Normal file
306
internal/forms/unmarshal.go
Normal file
|
@ -0,0 +1,306 @@
|
|||
package forms
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"dynatron.me/x/stillbox/internal/jsontypes"
|
||||
|
||||
"github.com/araddon/dateparse"
|
||||
)
|
||||
|
||||
func (o *options) parseTime(s string, dpo ...dateparse.ParserOption) (t time.Time, set bool, err error) {
|
||||
if o.acceptBlank && s == "" {
|
||||
set = false
|
||||
return
|
||||
}
|
||||
|
||||
if iv, err := strconv.Atoi(s); err == nil {
|
||||
return time.Unix(int64(iv), 0), true, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case o.parseTimeIn != nil:
|
||||
t, err = dateparse.ParseIn(s, o.parseTimeIn, dpo...)
|
||||
case o.parseLocal:
|
||||
t, err = dateparse.ParseLocal(s, dpo...)
|
||||
default:
|
||||
t, err = dateparse.ParseAny(s, dpo...)
|
||||
}
|
||||
|
||||
set = true
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (o *options) parseBool(s string) (v bool, set bool, err error) {
|
||||
if o.acceptBlank && s == "" {
|
||||
set = false
|
||||
return
|
||||
}
|
||||
|
||||
set = true
|
||||
|
||||
v, err = strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return v, set, fmt.Errorf("parsebool('%s'): %w", s, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (o *options) parseInt(s string) (v int, set bool, err error) {
|
||||
if o.acceptBlank && s == "" {
|
||||
set = false
|
||||
return
|
||||
}
|
||||
set = true
|
||||
|
||||
v, err = strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return v, set, fmt.Errorf("atoi('%s'): %w", s, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (o *options) parseFloat64(s string) (v float64, set bool, err error) {
|
||||
if o.acceptBlank && s == "" {
|
||||
set = false
|
||||
return
|
||||
}
|
||||
set = true
|
||||
|
||||
v, err = strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return v, set, fmt.Errorf("ParseFloat('%s'): %w", s, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) {
|
||||
if o.acceptBlank && s == "" {
|
||||
set = false
|
||||
return
|
||||
}
|
||||
|
||||
set = true
|
||||
|
||||
v, err = time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return v, set, fmt.Errorf("ParseDuration('%s'): %w", s, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var typeOfByteSlice = reflect.TypeOf([]byte(nil))
|
||||
|
||||
func (o *options) unmIterFields(r *http.Request, destStruct reflect.Value) error {
|
||||
structType := destStruct.Type()
|
||||
for i := 0; i < destStruct.NumField(); i++ {
|
||||
destFieldVal := destStruct.Field(i)
|
||||
fieldType := structType.Field(i)
|
||||
if !fieldType.IsExported() && !fieldType.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
if destFieldVal.Kind() == reflect.Struct && fieldType.Anonymous {
|
||||
err := o.unmIterFields(r, destFieldVal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var tAr []string
|
||||
var formField string
|
||||
var omitEmpty bool
|
||||
if o.defaultOmitEmpty {
|
||||
omitEmpty = true
|
||||
}
|
||||
|
||||
formTag, has := structType.Field(i).Tag.Lookup(o.Tag())
|
||||
if has {
|
||||
tAr = strings.Split(formTag, ",")
|
||||
formField = tAr[0]
|
||||
for _, v := range tAr[1:] {
|
||||
if v == "omitempty" {
|
||||
omitEmpty = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !has || formField == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
destFieldIntf := destFieldVal.Interface()
|
||||
|
||||
if destFieldVal.Kind() == reflect.Slice && destFieldVal.Type() == typeOfByteSlice {
|
||||
file, hdr, err := r.FormFile(formField)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get form file: %w", err)
|
||||
}
|
||||
|
||||
nameField, hasFilename := structType.Field(i).Tag.Lookup("filenameField")
|
||||
if hasFilename {
|
||||
fnf := destStruct.FieldByName(nameField)
|
||||
if fnf == (reflect.Value{}) {
|
||||
panic(fmt.Errorf("filenameField '%s' does not exist", nameField))
|
||||
}
|
||||
|
||||
fnf.SetString(hdr.Filename)
|
||||
}
|
||||
audioBytes, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("file read: %w", err)
|
||||
}
|
||||
|
||||
destFieldVal.SetBytes(audioBytes)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !r.Form.Has(formField) && omitEmpty {
|
||||
continue
|
||||
}
|
||||
|
||||
ff := r.Form.Get(formField)
|
||||
|
||||
switch v := destFieldIntf.(type) {
|
||||
case string, *string:
|
||||
setVal(destFieldVal, ff != "" || o.acceptBlank, ff)
|
||||
case int, uint, *int, *uint:
|
||||
val, set, err := o.parseInt(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, val)
|
||||
case float64:
|
||||
val, set, err := o.parseFloat64(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, val)
|
||||
case bool, *bool:
|
||||
val, set, err := o.parseBool(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, val)
|
||||
case time.Time, *time.Time, jsontypes.Time, *jsontypes.Time:
|
||||
t, set, err := o.parseTime(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, t)
|
||||
case time.Duration, *time.Duration, jsontypes.Duration, *jsontypes.Duration:
|
||||
d, set, err := o.parseDuration(ff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVal(destFieldVal, set, d)
|
||||
case []int:
|
||||
val := strings.Trim(ff, "[]")
|
||||
if val == "" && o.acceptBlank {
|
||||
continue
|
||||
}
|
||||
vals := strings.Split(val, ",")
|
||||
ar := make([]int, 0, len(vals))
|
||||
for _, v := range vals {
|
||||
i, err := strconv.Atoi(v)
|
||||
if err == nil {
|
||||
ar = append(ar, i)
|
||||
}
|
||||
}
|
||||
destFieldVal.Set(reflect.ValueOf(ar))
|
||||
default:
|
||||
panic(fmt.Errorf("unsupported type %T", v))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setVal(destFieldVal reflect.Value, set bool, src any) {
|
||||
if !set {
|
||||
return
|
||||
}
|
||||
|
||||
destType := destFieldVal.Type()
|
||||
srcVal := reflect.ValueOf(src)
|
||||
|
||||
if srcVal.Kind() == reflect.Ptr {
|
||||
srcVal = srcVal.Elem()
|
||||
}
|
||||
|
||||
if destType.Kind() == reflect.Ptr {
|
||||
if !srcVal.CanAddr() {
|
||||
if srcVal.CanConvert(destType.Elem()) {
|
||||
srcVal = srcVal.Convert(destType.Elem())
|
||||
}
|
||||
copy := reflect.New(srcVal.Type())
|
||||
copy.Elem().Set(srcVal)
|
||||
srcVal = copy
|
||||
}
|
||||
} else if srcVal.CanConvert(destFieldVal.Type()) {
|
||||
srcVal = srcVal.Convert(destFieldVal.Type())
|
||||
}
|
||||
|
||||
destFieldVal.Set(srcVal)
|
||||
}
|
||||
|
||||
func Unmarshal(r *http.Request, dest any, opts ...Option) error {
|
||||
o := options{
|
||||
maxMultipartMemory: MaxMultipartMemory,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
contentType := strings.Split(r.Header.Get("Content-Type"), ";")[0]
|
||||
|
||||
switch contentType {
|
||||
case "multipart/form-data":
|
||||
err := r.ParseMultipartForm(o.maxMultipartMemory)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ParseForm: %w", err)
|
||||
}
|
||||
|
||||
return o.unmarshalForm(r, dest)
|
||||
case "application/x-www-form-urlencoded":
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ParseForm: %w", err)
|
||||
}
|
||||
return o.unmarshalForm(r, dest)
|
||||
case "application/json":
|
||||
return json.NewDecoder(r.Body).Decode(dest)
|
||||
}
|
||||
|
||||
return ErrContentType
|
||||
}
|
||||
|
||||
func (o *options) unmarshalForm(r *http.Request, dest any) error {
|
||||
destVal := reflect.ValueOf(dest)
|
||||
if k := destVal.Kind(); k == reflect.Ptr {
|
||||
destVal = destVal.Elem()
|
||||
} else {
|
||||
return ErrNotPointer
|
||||
}
|
||||
|
||||
if destVal.Kind() != reflect.Struct {
|
||||
return ErrNotStruct
|
||||
}
|
||||
|
||||
return o.unmIterFields(r, destVal)
|
||||
}
|
Loading…
Reference in a new issue