Skip to content

Commit

Permalink
feat: support struct bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
zjregee committed Aug 26, 2024
1 parent 9352546 commit de68153
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 4 deletions.
113 changes: 109 additions & 4 deletions ahttp/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@ import (
"encoding/json"
"net/http"
"reflect"
"strconv"
"strings"
)

type DefaultBinder struct{}

func (b *DefaultBinder) BindHeaders(i interface{}, c *Context) error {
if err := b.bindData(i, c.Request().Header); err != nil {
if err := b.bindData(i, c.Request().Header, "header"); err != nil {
return err
}
return nil
}

func (b *DefaultBinder) BindQueryParams(i interface{}, c *Context) error {
if err := b.bindData(i, c.QueryParams()); err != nil {
if err := b.bindData(i, c.QueryParams(), "query"); err != nil {
return err
}
return nil
Expand Down Expand Up @@ -45,7 +46,7 @@ func (b *DefaultBinder) BindBody(i interface{}, c *Context) error {
if err != nil {
return ErrBadRequest.SetInternal(err)
}
if err := b.bindData(i, params); err != nil {
if err := b.bindData(i, params, "param"); err != nil {
return ErrBadRequest.SetInternal(err)
}
return nil
Expand All @@ -62,7 +63,7 @@ func (b *DefaultBinder) Bind(i interface{}, c *Context) error {
}
}

func (b *DefaultBinder) bindData(dest interface{}, data map[string][]string) error {
func (b *DefaultBinder) bindData(dest interface{}, data map[string][]string, tag string) error {
if dest == nil || len(data) == 0 {
return nil
}
Expand All @@ -85,6 +86,110 @@ func (b *DefaultBinder) bindData(dest interface{}, data map[string][]string) err
val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v))
}
}
return nil
}
if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Struct {
return ErrUnsupportedMediaType
}
typ = typ.Elem()
val = val.Elem()
for i := 0; i < typ.NumField(); i++ {
typField := typ.Field(i)
inputFieldName := typField.Tag.Get(tag)
if inputFieldName == "" {
return ErrUnsupportedMediaType
}
inputFiledValue, ok := data[inputFieldName]
if !ok {
return ErrUnsupportedMediaType
}
valField := val.Field(i)
valFieldKind := valField.Kind()
if err := setWithProperType(valFieldKind, inputFiledValue[0], valField); err != nil {
return err
}
}
return nil
}

func setWithProperType(valueKind reflect.Kind, val string, structFiled reflect.Value) error {
switch valueKind {
case reflect.Ptr:
return setWithProperType(structFiled.Elem().Kind(), val, structFiled.Elem())
case reflect.Int:
return setIntFiled(val, 0, structFiled)
case reflect.Int8:
return setIntFiled(val, 8, structFiled)
case reflect.Int16:
return setIntFiled(val, 16, structFiled)
case reflect.Int32:
return setIntFiled(val, 32, structFiled)
case reflect.Int64:
return setIntFiled(val, 64, structFiled)
case reflect.Uint:
return setUintFiled(val, 0, structFiled)
case reflect.Uint8:
return setUintFiled(val, 8, structFiled)
case reflect.Uint16:
return setUintFiled(val, 16, structFiled)
case reflect.Uint32:
return setUintFiled(val, 32, structFiled)
case reflect.Uint64:
return setUintFiled(val, 64, structFiled)
case reflect.Bool:
return setBoolFiled(val, structFiled)
case reflect.Float32:
return setFloatFiled(val, 32, structFiled)
case reflect.Float64:
return setFloatFiled(val, 64, structFiled)
case reflect.String:
structFiled.SetString(val)
default:
return ErrUnsupportedMediaType
}
return nil
}

func setIntFiled(value string, bitSize int, filed reflect.Value) error {
if value == "" {
value = "0"
}
intVal, err := strconv.ParseInt(value, 10, bitSize)
if err == nil {
filed.SetInt(intVal)
}
return err
}

func setUintFiled(value string, bitSize int, filed reflect.Value) error {
if value == "" {
value = "0"
}
uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err == nil {
filed.SetUint(uintVal)
}
return err
}

func setBoolFiled(value string, filed reflect.Value) error {
if value == "" {
value = "false"
}
boolVal, err := strconv.ParseBool(value)
if err == nil {
filed.SetBool(boolVal)
}
return err
}

func setFloatFiled(value string, bitSize int, filed reflect.Value) error {
if value == "" {
value = "0.0"
}
floatVal, err := strconv.ParseFloat(value, bitSize)
if err == nil {
filed.SetFloat(floatVal)
}
return err
}
30 changes: 30 additions & 0 deletions ahttp/binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,33 @@ func TestBinderBindBodyForm(t *testing.T) {
assert.Equal(t, "test", bindData["name"])
assert.Equal(t, "18", bindData["age"])
}

func TestBinderBindStruct(t *testing.T) {
c := newTestContextWithJson()
b := &DefaultBinder{}

type testQueryData struct {
Name string `query:"name"`
Age int `query:"age"`
}
bindQueryData := &testQueryData{}
_ = b.BindQueryParams(bindQueryData, c)
assert.Equal(t, "test", bindQueryData.Name)
assert.Equal(t, 18, bindQueryData.Age)

type testFormData struct {
Name string `form:"name"`
Age int `form:"age"`
}
bindFormData := &testFormData{}
_ = b.BindBody(bindFormData, c)
assert.Equal(t, "test", bindFormData.Name)
assert.Equal(t, 18, bindFormData.Age)

type testHeaderData struct {
ContentType string `header:"Content-Type"`
}
bindHeaderData := &testHeaderData{}
_ = b.BindHeaders(bindHeaderData, c)
assert.Equal(t, MIMEApplicationJSON, bindHeaderData.ContentType)
}

0 comments on commit de68153

Please sign in to comment.