From e5f8d03c0a1dd9cc571d648cd610305139078de5 Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Wed, 2 Nov 2022 16:37:28 -0700 Subject: [PATCH] feat: apply omitEmptyFlag to empty structs --- encode_map.go | 2 +- msgpack_test.go | 22 ++++++++++++++++++++++ types.go | 16 +++++++++++----- 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/encode_map.go b/encode_map.go index ba4c61b..d99c165 100644 --- a/encode_map.go +++ b/encode_map.go @@ -148,7 +148,7 @@ func encodeStructValue(e *Encoder, strct reflect.Value) error { if e.flags&arrayEncodedStructsFlag != 0 || structFields.AsArray { return encodeStructValueAsArray(e, strct, structFields.List) } - fields := structFields.OmitEmpty(strct, e.flags&omitEmptyFlag != 0) + fields := structFields.OmitEmpty(e, strct) if err := e.EncodeMapLen(len(fields)); err != nil { return err diff --git a/msgpack_test.go b/msgpack_test.go index 1b77027..da133d0 100644 --- a/msgpack_test.go +++ b/msgpack_test.go @@ -393,6 +393,28 @@ func TestSetOmitEmpty(t *testing.T) { err = dec.Decode(&t2) require.Nil(t, err) require.Nil(t, t2.Exported) + + type Nested struct { + Foo string + Bar string + } + type Item struct { + X Nested + Y *Nested + } + i := Item{} + buf.Reset() + err = enc.Encode(i) + require.Nil(t, err) + require.NotContains(t, buf.Bytes(), byte('X')) + require.NotContains(t, buf.Bytes(), byte('Y')) + + i = Item{Y: &Nested{}} + buf.Reset() + err = enc.Encode(i) + require.Nil(t, err) + require.NotContains(t, buf.Bytes(), byte('X')) + require.Contains(t, buf.Bytes(), byte('Y')) } type NullInt struct { diff --git a/types.go b/types.go index 69aca61..51851fe 100644 --- a/types.go +++ b/types.go @@ -97,12 +97,13 @@ type field struct { decoder decoderFunc } -func (f *field) Omit(strct reflect.Value, forced bool) bool { +func (f *field) Omit(e *Encoder, strct reflect.Value) bool { v, ok := fieldByIndex(strct, f.index) if !ok { return true } - return (f.omitEmpty || forced) && isEmptyValue(v) + forced := e.flags&omitEmptyFlag != 0 + return (f.omitEmpty || forced) && e.isEmptyValue(v) } func (f *field) EncodeValue(e *Encoder, strct reflect.Value) error { @@ -152,7 +153,8 @@ func (fs *fields) warnIfFieldExists(name string) { } } -func (fs *fields) OmitEmpty(strct reflect.Value, forced bool) []*field { +func (fs *fields) OmitEmpty(e *Encoder, strct reflect.Value) []*field { + forced := e.flags&omitEmptyFlag != 0 if !fs.hasOmitEmpty && !forced { return fs.List } @@ -160,7 +162,7 @@ func (fs *fields) OmitEmpty(strct reflect.Value, forced bool) []*field { fields := make([]*field, 0, len(fs.List)) for _, f := range fs.List { - if !f.Omit(strct, forced) { + if !f.Omit(e, strct) { fields = append(fields, f) } } @@ -317,7 +319,7 @@ type isZeroer interface { IsZero() bool } -func isEmptyValue(v reflect.Value) bool { +func (e *Encoder) isEmptyValue(v reflect.Value) bool { kind := v.Kind() for kind == reflect.Interface { @@ -335,6 +337,10 @@ func isEmptyValue(v reflect.Value) bool { switch kind { case reflect.Array, reflect.Map, reflect.Slice, reflect.String: return v.Len() == 0 + case reflect.Struct: + structFields := structs.Fields(v.Type(), e.structTag) + fields := structFields.OmitEmpty(e, v) + return len(fields) == 0 case reflect.Bool: return !v.Bool() case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: