305 lines
5.7 KiB
Go
305 lines
5.7 KiB
Go
package forms
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"dynatron.me/x/stillbox/internal/jsontime"
|
|
|
|
"github.com/araddon/dateparse"
|
|
)
|
|
|
|
// Sentinel errors returned by Unmarshal when the destination has the wrong shape.
var (
	// ErrNotStruct is returned when the destination pointer does not point to a struct.
	ErrNotStruct = errors.New("destination is not a struct")

	// ErrNotPointer is returned when the destination is not a pointer.
	ErrNotPointer = errors.New("destination is not a pointer")
)
|
|
|
|
type options struct {
|
|
tagOverride *string
|
|
parseTimeIn *time.Location
|
|
parseLocal bool
|
|
acceptBlank bool
|
|
}
|
|
|
|
// Option mutates the unmarshalling options; Unmarshal applies each Option it
// is given, in order, before reading any form values.
type Option func(*options)
|
|
|
|
func WithParseTimeInTZ(l *time.Location) Option {
|
|
return func(o *options) {
|
|
o.parseTimeIn = l
|
|
}
|
|
}
|
|
|
|
func WithParseLocalTime() Option {
|
|
return func(o *options) {
|
|
o.parseLocal = true
|
|
}
|
|
}
|
|
|
|
func WithAcceptBlank() Option {
|
|
return func(o *options) {
|
|
o.acceptBlank = true
|
|
}
|
|
}
|
|
|
|
func WithTag(t string) Option {
|
|
return func(o *options) {
|
|
o.tagOverride = &t
|
|
}
|
|
}
|
|
|
|
func (o *options) Tag() string {
|
|
if o.tagOverride != nil {
|
|
return *o.tagOverride
|
|
}
|
|
|
|
return "form"
|
|
}
|
|
|
|
func (o *options) parseTime(s string, dpo ...dateparse.ParserOption) (t time.Time, set bool, err error) {
|
|
if o.acceptBlank && s == "" {
|
|
set = false
|
|
return
|
|
}
|
|
|
|
if iv, err := strconv.Atoi(s); err == nil {
|
|
return time.Unix(int64(iv), 0), true, nil
|
|
}
|
|
|
|
switch {
|
|
case o.parseTimeIn != nil:
|
|
t, err = dateparse.ParseIn(s, o.parseTimeIn, dpo...)
|
|
case o.parseLocal:
|
|
t, err = dateparse.ParseLocal(s, dpo...)
|
|
default:
|
|
t, err = dateparse.ParseAny(s, dpo...)
|
|
}
|
|
|
|
set = true
|
|
|
|
return
|
|
}
|
|
|
|
func (o *options) parseBool(s string) (v bool, set bool, err error) {
|
|
if o.acceptBlank && s == "" {
|
|
set = false
|
|
return
|
|
}
|
|
|
|
set = true
|
|
|
|
v, err = strconv.ParseBool(s)
|
|
if err != nil {
|
|
return v, set, fmt.Errorf("parsebool('%s'): %w", s, err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (o *options) parseInt(s string) (v int, set bool, err error) {
|
|
if o.acceptBlank && s == "" {
|
|
set = false
|
|
return
|
|
}
|
|
set = true
|
|
|
|
v, err = strconv.Atoi(s)
|
|
if err != nil {
|
|
return v, set, fmt.Errorf("atoi('%s'): %w", s, err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (o *options) parseFloat64(s string) (v float64, set bool, err error) {
|
|
if o.acceptBlank && s == "" {
|
|
set = false
|
|
return
|
|
}
|
|
set = true
|
|
|
|
v, err = strconv.ParseFloat(s, 64)
|
|
if err != nil {
|
|
return v, set, fmt.Errorf("ParseFloat('%s'): %w", s, err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) {
|
|
if o.acceptBlank && s == "" {
|
|
set = false
|
|
return
|
|
}
|
|
|
|
set = true
|
|
|
|
v, err = time.ParseDuration(s)
|
|
if err != nil {
|
|
return v, set, fmt.Errorf("ParseDuration('%s'): %w", s, err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
|
|
rt := rv.Type()
|
|
for i := 0; i < rv.NumField(); i++ {
|
|
f := rv.Field(i)
|
|
tf := rt.Field(i)
|
|
if !tf.IsExported() && !tf.Anonymous {
|
|
continue
|
|
}
|
|
|
|
if f.Kind() == reflect.Struct && tf.Anonymous {
|
|
err := o.iterFields(r, f)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
var tAr []string
|
|
var formField string
|
|
formTag, has := rt.Field(i).Tag.Lookup(o.Tag())
|
|
if has {
|
|
tAr = strings.Split(formTag, ",")
|
|
formField = tAr[0]
|
|
}
|
|
if !has || formField == "-" {
|
|
continue
|
|
}
|
|
|
|
fi := f.Interface()
|
|
|
|
switch v := fi.(type) {
|
|
case string, *string:
|
|
s := r.Form.Get(formField)
|
|
setVal(f, s != "" || o.acceptBlank, v, s)
|
|
case int, uint, *int, *uint:
|
|
ff := r.Form.Get(formField)
|
|
val, set, err := o.parseInt(ff)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
setVal(f, set, v, val)
|
|
case float64:
|
|
ff := r.Form.Get(formField)
|
|
val, set, err := o.parseFloat64(ff)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
setVal(f, set, v, val)
|
|
case bool, *bool:
|
|
ff := r.Form.Get(formField)
|
|
val, set, err := o.parseBool(ff)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
setVal(f, set, v, val)
|
|
case []byte:
|
|
file, hdr, err := r.FormFile(formField)
|
|
if err != nil {
|
|
return fmt.Errorf("get form file: %w", err)
|
|
}
|
|
|
|
nameField, hasFilename := rt.Field(i).Tag.Lookup("filenameField")
|
|
if hasFilename {
|
|
fnf := rv.FieldByName(nameField)
|
|
if fnf == (reflect.Value{}) {
|
|
panic(fmt.Errorf("filenameField '%s' does not exist", nameField))
|
|
}
|
|
|
|
fnf.SetString(hdr.Filename)
|
|
}
|
|
audioBytes, err := io.ReadAll(file)
|
|
if err != nil {
|
|
return fmt.Errorf("file read: %w", err)
|
|
}
|
|
|
|
f.SetBytes(audioBytes)
|
|
case time.Time, *time.Time, jsontime.Time, *jsontime.Time:
|
|
tval := r.Form.Get(formField)
|
|
t, set, err := o.parseTime(tval)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
setVal(f, set, v, t)
|
|
case time.Duration, *time.Duration, jsontime.Duration, *jsontime.Duration:
|
|
dval := r.Form.Get(formField)
|
|
d, set, err := o.parseDuration(dval)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
setVal(f, set, v, d)
|
|
case []int:
|
|
val := strings.Trim(r.Form.Get(formField), "[]")
|
|
if val == "" && o.acceptBlank {
|
|
continue
|
|
}
|
|
vals := strings.Split(val, ",")
|
|
ar := make([]int, 0, len(vals))
|
|
for _, v := range vals {
|
|
i, err := strconv.Atoi(v)
|
|
if err == nil {
|
|
ar = append(ar, i)
|
|
}
|
|
}
|
|
f.Set(reflect.ValueOf(ar))
|
|
default:
|
|
panic(fmt.Errorf("unsupported type %T", v))
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// setVal stores sv into the struct field setField, unless set is false, in
// which case the field is left untouched.
//
// fv is the field's current value and is used only for its dynamic type.
// When that type is a pointer, a new value of the pointed-to type is
// allocated, sv is converted to it if necessary, and the field is pointed at
// the new value. Otherwise sv is converted to the field type where a
// conversion exists and assigned directly.
//
// It panics, via reflect, if sv can be neither assigned nor converted to the
// destination type.
func setVal(setField reflect.Value, set bool, fv any, sv any) {
	if !set {
		return
	}

	target := reflect.TypeOf(fv)
	val := reflect.ValueOf(sv)

	if target.Kind() == reflect.Ptr {
		// reflect.ValueOf(sv) is not addressable, so calling Addr on it
		// panics; allocate fresh storage for the pointer to refer to.
		elem := target.Elem()
		if val.Type() != elem && val.CanConvert(elem) {
			val = val.Convert(elem)
		}

		ptr := reflect.New(elem)
		ptr.Elem().Set(val)
		setField.Set(ptr)

		return
	}

	if val.CanConvert(target) {
		val = val.Convert(target)
	}

	setField.Set(val)
}
|
|
|
|
func Unmarshal(r *http.Request, dest any, opt ...Option) error {
|
|
o := options{}
|
|
for _, opt := range opt {
|
|
opt(&o)
|
|
}
|
|
|
|
rv := reflect.ValueOf(dest)
|
|
if k := rv.Kind(); k == reflect.Ptr {
|
|
rv = rv.Elem()
|
|
} else {
|
|
return ErrNotPointer
|
|
}
|
|
|
|
if rv.Kind() != reflect.Struct {
|
|
return ErrNotStruct
|
|
}
|
|
|
|
if strings.HasPrefix(r.Header.Get("Content-Type"), "application/x-www-form-urlencoded") {
|
|
err := r.ParseForm()
|
|
if err != nil {
|
|
return fmt.Errorf("ParseForm: %w", err)
|
|
}
|
|
}
|
|
|
|
return o.iterFields(r, rv)
|
|
}
|