stillbox/internal/forms/forms.go

338 lines
6.5 KiB
Go

package forms
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"dynatron.me/x/stillbox/internal/jsontime"
"github.com/araddon/dateparse"
)
// Errors returned by Unmarshal.
var (
	// ErrNotStruct is returned when a form destination does not point to a struct.
	ErrNotStruct = errors.New("destination is not a struct")
	// ErrNotPointer is returned when a form destination is not a pointer.
	ErrNotPointer = errors.New("destination is not a pointer")
	// ErrContentType is returned when the request Content-Type is not one of
	// the media types Unmarshal understands.
	ErrContentType = errors.New("bad content type")
)
const (
	// MaxMultipartMemory is the default maxMemory passed to
	// http.Request.ParseMultipartForm; override it with WithMaxMultipartSize.
	MaxMultipartMemory int64 = 1 << 20 // 1 MiB
)
type options struct {
tagOverride *string
parseTimeIn *time.Location
maxMultipartMemory int64
parseLocal bool
acceptBlank bool
}
// Option customizes how Unmarshal maps a request onto a destination.
type Option func(*options)
// WithParseTimeInTZ makes non-numeric time values be parsed with
// dateparse.ParseIn in location l. It takes precedence over
// WithParseLocalTime.
func WithParseTimeInTZ(l *time.Location) Option {
	return func(opts *options) {
		opts.parseTimeIn = l
	}
}
// WithParseLocalTime makes non-numeric time values be parsed with
// dateparse.ParseLocal, unless WithParseTimeInTZ supplied a location.
func WithParseLocalTime() Option {
	return func(opts *options) {
		opts.parseLocal = true
	}
}
// WithAcceptBlank relaxes the handling of empty form values. Empty numeric,
// boolean, time, duration and []int values are skipped, leaving the
// destination untouched, instead of failing to parse. Empty strings, by
// contrast, are assigned to string fields, which are otherwise left
// untouched when the value is empty.
func WithAcceptBlank() Option {
	return func(opts *options) {
		opts.acceptBlank = true
	}
}
// WithTag sets the struct tag key used to look up form field names,
// replacing the default "form".
func WithTag(t string) Option {
	return func(opts *options) {
		opts.tagOverride = &t
	}
}
// WithMaxMultipartSize sets the maxMemory argument passed to
// http.Request.ParseMultipartForm (default MaxMultipartMemory).
func WithMaxMultipartSize(s int64) Option {
	return func(opts *options) {
		opts.maxMultipartMemory = s
	}
}
// Tag returns the struct tag key consulted when mapping form fields: the
// value supplied via WithTag, or "form" when none was given.
func (o *options) Tag() string {
	tag := "form"
	if o.tagOverride != nil {
		tag = *o.tagOverride
	}
	return tag
}
// parseTime parses s as a point in time.
//
// A string that parses as a base-10 integer (optionally signed) is treated as
// Unix seconds. Anything else goes to dateparse:
//   - ParseIn with o.parseTimeIn when that is set;
//   - otherwise ParseLocal when o.parseLocal is set;
//   - otherwise ParseAny.
//
// dpo is forwarded to dateparse.
//
// set is false only when s is empty and blanks are accepted; the caller
// should then leave its destination untouched. Parse failures are wrapped
// with the offending input, matching the other parse helpers.
func (o *options) parseTime(s string, dpo ...dateparse.ParserOption) (t time.Time, set bool, err error) {
	if o.acceptBlank && s == "" {
		return time.Time{}, false, nil
	}

	// ParseInt with bitSize 64 (rather than Atoi) keeps large timestamps from
	// overflowing on platforms where int is 32 bits.
	if secs, perr := strconv.ParseInt(s, 10, 64); perr == nil {
		return time.Unix(secs, 0), true, nil
	}

	switch {
	case o.parseTimeIn != nil:
		t, err = dateparse.ParseIn(s, o.parseTimeIn, dpo...)
	case o.parseLocal:
		t, err = dateparse.ParseLocal(s, dpo...)
	default:
		t, err = dateparse.ParseAny(s, dpo...)
	}
	if err != nil {
		return t, true, fmt.Errorf("parsetime('%s'): %w", s, err)
	}
	return t, true, nil
}
// parseBool converts s with strconv.ParseBool.
//
// An empty s with blanks accepted yields set == false and no error. Otherwise
// set is true, and a conversion failure is returned wrapped with the input.
func (o *options) parseBool(s string) (v bool, set bool, err error) {
	if s == "" && o.acceptBlank {
		return false, false, nil
	}
	b, convErr := strconv.ParseBool(s)
	if convErr != nil {
		return b, true, fmt.Errorf("parsebool('%s'): %w", s, convErr)
	}
	return b, true, nil
}
// parseInt converts s with strconv.Atoi.
//
// An empty s with blanks accepted yields set == false and no error. Otherwise
// set is true. On failure, the value Atoi produced is returned together with
// the error, wrapped with the input.
func (o *options) parseInt(s string) (v int, set bool, err error) {
	if s == "" && o.acceptBlank {
		return 0, false, nil
	}
	n, convErr := strconv.Atoi(s)
	if convErr != nil {
		return n, true, fmt.Errorf("atoi('%s'): %w", s, convErr)
	}
	return n, true, nil
}
// parseFloat64 converts s with strconv.ParseFloat at 64-bit precision.
//
// An empty s with blanks accepted yields set == false and no error. Otherwise
// set is true. On failure, the value ParseFloat produced is returned together
// with the error, wrapped with the input.
func (o *options) parseFloat64(s string) (v float64, set bool, err error) {
	if s == "" && o.acceptBlank {
		return 0, false, nil
	}
	f, convErr := strconv.ParseFloat(s, 64)
	if convErr != nil {
		return f, true, fmt.Errorf("ParseFloat('%s'): %w", s, convErr)
	}
	return f, true, nil
}
// parseDuration converts s with time.ParseDuration (e.g. "1h30m").
//
// An empty s with blanks accepted yields set == false and no error. Otherwise
// set is true, and a conversion failure is returned wrapped with the input.
func (o *options) parseDuration(s string) (v time.Duration, set bool, err error) {
	if s == "" && o.acceptBlank {
		return 0, false, nil
	}
	d, convErr := time.ParseDuration(s)
	if convErr != nil {
		return d, true, fmt.Errorf("ParseDuration('%s'): %w", s, convErr)
	}
	return d, true, nil
}
// iterFields populates the fields of the struct value rv from the parsed
// request form of r.
//
// Only fields carrying the tag key returned by o.Tag() are filled. The tag's
// first comma-separated element is the form key, and "-" skips the field.
// Embedded (anonymous) struct fields are walked recursively so that their
// promoted fields are filled as well. rv must be settable, e.g. obtained via
// reflect.ValueOf(&s).Elem().
//
// Parse failures are returned as errors. An unsupported field type, or a
// filenameField tag naming a non-existent field, is treated as a programming
// error and panics.
func (o *options) iterFields(r *http.Request, rv reflect.Value) error {
	rt := rv.Type()
	for i := 0; i < rv.NumField(); i++ {
		f := rv.Field(i)
		tf := rt.Field(i)

		// Skip unexported fields, but still descend into unexported embedded
		// structs: their exported, promoted fields remain settable.
		if !tf.IsExported() && !tf.Anonymous {
			continue
		}

		if f.Kind() == reflect.Struct && tf.Anonymous {
			if err := o.iterFields(r, f); err != nil {
				return err
			}
			// Deliberately fall through: an embedded struct such as
			// time.Time may itself carry a form tag.
		}

		formTag, has := tf.Tag.Lookup(o.Tag())
		if !has {
			continue
		}
		formField, _, _ := strings.Cut(formTag, ",")
		if formField == "-" {
			continue
		}

		switch v := f.Interface().(type) {
		case string, *string:
			s := r.Form.Get(formField)
			// Empty strings are only assigned when blanks are accepted.
			setVal(f, s != "" || o.acceptBlank, v, s)
		case int, uint, *int, *uint:
			val, set, err := o.parseInt(r.Form.Get(formField))
			if err != nil {
				return err
			}
			setVal(f, set, v, val)
		case float64:
			val, set, err := o.parseFloat64(r.Form.Get(formField))
			if err != nil {
				return err
			}
			setVal(f, set, v, val)
		case bool, *bool:
			val, set, err := o.parseBool(r.Form.Get(formField))
			if err != nil {
				return err
			}
			setVal(f, set, v, val)
		case []byte:
			if err := o.setFormFile(r, rv, tf, f, formField); err != nil {
				return err
			}
		case time.Time, *time.Time, jsontime.Time, *jsontime.Time:
			t, set, err := o.parseTime(r.Form.Get(formField))
			if err != nil {
				return err
			}
			setVal(f, set, v, t)
		case time.Duration, *time.Duration, jsontime.Duration, *jsontime.Duration:
			d, set, err := o.parseDuration(r.Form.Get(formField))
			if err != nil {
				return err
			}
			setVal(f, set, v, d)
		case []int:
			// Accept both "1,2,3" and "[1, 2, 3]".
			raw := strings.Trim(r.Form.Get(formField), "[] \t")
			if raw == "" && o.acceptBlank {
				continue
			}
			f.Set(reflect.ValueOf(splitIntList(raw)))
		default:
			panic(fmt.Errorf("unsupported type %T", v))
		}
	}
	return nil
}

// setFormFile reads the uploaded file stored under formField into the []byte
// field f.
//
// If the field's struct tag has a filenameField entry, the upload's filename
// from its multipart header is stored into that sibling string field of rv.
// A missing upload is skipped when blanks are accepted.
func (o *options) setFormFile(r *http.Request, rv reflect.Value, tf reflect.StructField, f reflect.Value, formField string) error {
	file, hdr, err := r.FormFile(formField)
	if err != nil {
		if o.acceptBlank && errors.Is(err, http.ErrMissingFile) {
			return nil
		}
		return fmt.Errorf("get form file: %w", err)
	}
	// The file is only read, so a Close error carries no useful information.
	defer file.Close()

	if nameField, ok := tf.Tag.Lookup("filenameField"); ok {
		fnf := rv.FieldByName(nameField)
		if !fnf.IsValid() {
			panic(fmt.Errorf("filenameField '%s' does not exist", nameField))
		}
		fnf.SetString(hdr.Filename)
	}

	data, err := io.ReadAll(file)
	if err != nil {
		return fmt.Errorf("file read: %w", err)
	}
	f.SetBytes(data)
	return nil
}

// splitIntList parses a comma-separated list of integers.
//
// Whitespace around each element is ignored. Elements that are not valid
// integers are skipped, so the result may be shorter than the list, and an
// empty input yields an empty, non-nil slice.
func splitIntList(s string) []int {
	parts := strings.Split(s, ",")
	out := make([]int, 0, len(parts))
	for _, p := range parts {
		n, err := strconv.Atoi(strings.TrimSpace(p))
		if err != nil {
			continue
		}
		out = append(out, n)
	}
	return out
}
// setVal stores src into setField when set is true. When set is false, the
// field is left untouched.
//
// dest is the field's current value (from reflect.Value.Interface) and is used
// only for its dynamic type. src is converted to that type when possible, e.g.
// int -> uint or time.Time -> a named time type.
//
// When dest is a pointer type and src is not already assignable to it, src is
// converted to the pointed-to type and stored through a freshly allocated
// pointer. Pointer fields such as *string or *int therefore receive a non-nil
// value.
//
// Like reflect.Value.Set, it panics if the resulting value is not assignable
// to the field.
func setVal(setField reflect.Value, set bool, dest any, src any) {
	if !set {
		return
	}

	destType := reflect.TypeOf(dest)
	srcVal := reflect.ValueOf(src)

	// Values produced by reflect.ValueOf are never addressable, so a pointer
	// destination must be satisfied by allocating a new element.
	if destType.Kind() == reflect.Ptr && !srcVal.Type().AssignableTo(destType) {
		elemType := destType.Elem()
		if srcVal.CanConvert(elemType) {
			srcVal = srcVal.Convert(elemType)
		}
		ptr := reflect.New(elemType)
		ptr.Elem().Set(srcVal)
		setField.Set(ptr)
		return
	}

	if srcVal.CanConvert(destType) {
		srcVal = srcVal.Convert(destType)
	}
	setField.Set(srcVal)
}
// Unmarshal decodes the body of r into dest.
//
// The decoder is chosen from the media type of the request's Content-Type
// header. It is compared case-insensitively, ignoring surrounding whitespace
// and parameters such as charset or boundary:
//   - multipart/form-data and application/x-www-form-urlencoded: the form is
//     parsed with net/http and mapped onto dest's tagged fields. dest must be
//     a pointer to a struct (ErrNotPointer / ErrNotStruct otherwise), and opt
//     customizes the mapping.
//   - application/json: the body is decoded with encoding/json, and opt is
//     ignored.
//
// Any other content type yields ErrContentType.
func Unmarshal(r *http.Request, dest any, opt ...Option) error {
	o := options{
		maxMultipartMemory: MaxMultipartMemory,
	}
	for _, apply := range opt {
		apply(&o)
	}

	// Media types are case-insensitive (RFC 9110 §8.3.1).
	mediaType, _, _ := strings.Cut(r.Header.Get("Content-Type"), ";")
	mediaType = strings.ToLower(strings.TrimSpace(mediaType))

	switch mediaType {
	case "multipart/form-data":
		if err := r.ParseMultipartForm(o.maxMultipartMemory); err != nil {
			return fmt.Errorf("ParseMultipartForm: %w", err)
		}
		return o.unmarshalForm(r, dest)
	case "application/x-www-form-urlencoded":
		if err := r.ParseForm(); err != nil {
			return fmt.Errorf("ParseForm: %w", err)
		}
		return o.unmarshalForm(r, dest)
	case "application/json":
		return json.NewDecoder(r.Body).Decode(dest)
	}
	return ErrContentType
}
// unmarshalForm maps the request's form values onto the struct that dest
// points to. Callers in this file parse the form before calling it.
//
// It returns ErrNotPointer when dest is not a pointer and ErrNotStruct when
// the pointed-to value is not a struct (including a nil pointer).
func (o *options) unmarshalForm(r *http.Request, dest any) error {
	ptr := reflect.ValueOf(dest)
	if ptr.Kind() != reflect.Ptr {
		return ErrNotPointer
	}

	target := ptr.Elem()
	if target.Kind() != reflect.Struct {
		return ErrNotStruct
	}

	return o.iterFields(r, target)
}