Skip to content

Commit

Permalink
Rewrite StructToMap (major update) (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
mennanov committed Feb 6, 2023
1 parent fc158cc commit f9573a9
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 126 deletions.
257 changes: 153 additions & 104 deletions copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
package fieldmask_utils

import (
"reflect"
"strings"

"github.com/pkg/errors"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"reflect"
"strings"
)

// StructToStruct copies `src` struct to `dst` struct using the given FieldFilter.
Expand Down Expand Up @@ -250,8 +249,13 @@ type options struct {

// mapVisitor is called for every filtered field in structToMap.
type mapVisitor func(
filter FieldFilter, src interface{}, dst map[string]interface{},
srcFieldName, dstFieldName string, srcFieldValue reflect.Value) (skipToNext bool)
filter FieldFilter, src, dst reflect.Value,
srcFieldName, dstFieldName string, srcFieldValue reflect.Value) MapVisitorResult

type MapVisitorResult struct {
SkipToNext bool
UpdatedDst *reflect.Value
}

// Option function modifies the given options.
type Option func(*options)
Expand Down Expand Up @@ -305,130 +309,175 @@ func StructToMap(filter FieldFilter, src interface{}, dst map[string]interface{}
for _, o := range userOpts {
o(opts)
}
return structToMap(filter, src, dst, opts)
_, err := structToMap(filter, reflect.ValueOf(src), reflect.ValueOf(dst), opts)
return err
}

func structToMap(filter FieldFilter, src interface{}, dst map[string]interface{}, userOptions *options) error {
srcVal := indirect(reflect.ValueOf(src))
srcType := srcVal.Type()
for i := 0; i < srcVal.NumField(); i++ {
fieldName := srcType.Field(i).Name
subFilter, ok := filter.Filter(fieldName)
if !ok {
// Skip this field.
continue
}
srcField := srcVal.FieldByName(fieldName)
if !srcField.CanInterface() {
continue
}

dstName := dstKey(userOptions.DstTag, srcType.Field(i))
if userOptions.MapVisitor != nil && userOptions.MapVisitor(filter, src, dst, fieldName, dstName, srcField) {
continue
func structToMap(filter FieldFilter, src, dst reflect.Value, userOptions *options) (reflect.Value, error) {
switch src.Kind() {
case reflect.Struct:
if dst.Kind() != reflect.Map {
return dst, errors.Errorf("incompatible destination kind: %s, expected map", dst.Kind())
}

switch srcField.Kind() {
case reflect.Ptr, reflect.Interface:
if srcField.IsNil() {
dst[dstName] = nil
srcType := src.Type()
for i := 0; i < src.NumField(); i++ {
fieldName := srcType.Field(i).Name
if !isExported(srcType.Field(i)) {
// Unexported fields can not be copied.
continue
}

var newValue map[string]interface{}
existingValue, ok := dst[dstName]
if ok {
newValue = existingValue.(map[string]interface{})
} else {
newValue = make(map[string]interface{})
}
if err := structToMap(subFilter, srcField.Interface(), newValue, userOptions); err != nil {
return err
subFilter, ok := filter.Filter(fieldName)
if !ok {
// Skip this field.
continue
}
dst[dstName] = newValue

case reflect.Array, reflect.Slice:
// Check if it is a slice of primitive values.
itemKind := srcField.Type().Elem().Kind()
srcLen := userOptions.CopyListSize(&srcField)
if itemKind != reflect.Ptr && itemKind != reflect.Struct && itemKind != reflect.Interface {
// Handle this array/slice as a regular non-nested data structure: copy it entirely to dst.
if srcLen < srcField.Len() {
dst[dstName] = srcField.Slice(0, srcLen).Interface()
srcField := indirect(src.FieldByName(fieldName))
dstName := dstKey(userOptions.DstTag, srcType.Field(i))
mapValue := indirect(dst.MapIndex(reflect.ValueOf(dstName)))
if !mapValue.IsValid() {
if srcField.IsValid() {
mapValue = newValue(srcField.Type())
} else {
dst[dstName] = srcField.Interface()
dstMap := dst.Interface().(map[string]interface{})
dstMap[dstName] = nil
continue
}
continue
}
var newValue []map[string]interface{}
if srcField.Kind() == reflect.Slice && !srcField.IsNil() {
// If the source slice is not nil then the dst should not be nil either even if the src slice is empty.
newValue = make([]map[string]interface{}, 0, srcField.Len())
}
existingValue, ok := dst[dstName]
if ok {
v := reflect.ValueOf(existingValue)
if v.Kind() == reflect.Array {
// Convert the array to a slice.
for i := 0; i < v.Len(); i++ {
itemInterface := v.Index(i).Interface()
item, k := itemInterface.(map[string]interface{})
if !k {
return errors.Errorf("unexpected dst type %T, expected map[string]interface{}", itemInterface)
}
newValue = append(newValue, item)
}
} else {
newValue, ok = existingValue.([]map[string]interface{})
if !ok {
return errors.Errorf("unexpected dst type %T, expected []map[string]interface{}", newValue)
if userOptions.MapVisitor != nil {
result := userOptions.MapVisitor(filter, src, mapValue, fieldName, dstName, srcField)
if result.UpdatedDst != nil {
mapValue = *result.UpdatedDst

}
if result.SkipToNext {
if result.UpdatedDst != nil {
dst.SetMapIndex(reflect.ValueOf(dstName), mapValue)
}
continue
}
}
if isPrimitive(mapValue.Kind()) {
dst.SetMapIndex(reflect.ValueOf(dstName), srcField)
continue
}
var err error
if mapValue, err = structToMap(subFilter, srcField, mapValue, userOptions); err != nil {
return dst, err
}
dst.SetMapIndex(reflect.ValueOf(dstName), mapValue)
}

// Iterate over items of the slice/array.
dstLen := len(newValue)
if dstLen < srcLen {
// Grow the dst slice to match the src len.
for i := 0; i < srcLen-dstLen; i++ {
newValue = append(newValue, make(map[string]interface{}))
}
dstLen = srcLen
case reflect.Ptr:
if src.IsNil() {
reflect.ValueOf(dst).Set(reflect.ValueOf(nil))
break
}
var err error
if dst, err = structToMap(filter, indirect(src), dst, userOptions); err != nil {
return dst, err
}

case reflect.Interface:
if src.IsNil() {
reflect.ValueOf(dst).Set(reflect.ValueOf(nil))
break
}

var err error
if dst, err = structToMap(filter, indirect(src), dst, userOptions); err != nil {
return dst, err
}

case reflect.Array, reflect.Slice:
if dstKind := dst.Kind(); dstKind != reflect.Slice && dstKind != reflect.Array {
return dst, errors.Errorf("incompatible destination kind: %s, expected slice", dst.Kind())
}
itemType := src.Type().Elem()
desiredDstLen := userOptions.CopyListSize(&src)
itemKind := itemType.Kind()
if isPrimitive(itemKind) {
// Handle this array/slice as a regular non-nested data structure: copy it entirely to dst.
if desiredDstLen < src.Len() {
dst = src.Slice(0, desiredDstLen)
} else {
dst = src
}
for i := 0; i < srcLen; i++ {
subValue := srcField.Index(i)
if err := structToMap(subFilter, subValue.Interface(), newValue[i], userOptions); err != nil {
return err
} else {
if dst.Kind() == reflect.Array {
// Convert the array to a slice.
sliceDst := newValue(src.Type())
for i := 0; i < dst.Len(); i++ {
sliceDst = reflect.Append(sliceDst, dst.Index(i))
}
dst = sliceDst
}
// Truncate the dst to the length of src.
newValue = newValue[:srcLen]
dst[dstName] = newValue

case reflect.Struct:
var newValue map[string]interface{}
existingValue, ok := dst[dstName]
if ok {
newValue = existingValue.(map[string]interface{})
} else {
newValue = make(map[string]interface{})
var err error
for i := 0; i < desiredDstLen; i++ {
itemExists := false
var subDst reflect.Value
if i < dst.Len() {
subDst = dst.Index(i)
itemExists = true
} else {
subDst = newValue(itemType)
}
if subDst, err = structToMap(filter, src.Index(i), subDst, userOptions); err != nil {
return subDst, err
}
if !itemExists {
dst = reflect.Append(dst, subDst)
}
}
if err := structToMap(subFilter, srcField.Interface(), newValue, userOptions); err != nil {
return err
if desiredDstLen < dst.Len() {
dst = dst.Slice(0, desiredDstLen)
}
dst[dstName] = newValue

default:
// Set a value on a map.
dst[dstName] = srcField.Interface()
}

case reflect.Invalid:
dst.Set(reflect.ValueOf(nil))

default:
dst.Set(src)
}
return nil
return dst, nil
}

func indirect(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Ptr {
for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
v = v.Elem()
}
return v
}

// isPrimitive checks whether the given kind is simple enough so that it can be copied directly without recursion.
func isPrimitive(kind reflect.Kind) bool {
return kind != reflect.Ptr &&
kind != reflect.Struct &&
kind != reflect.Interface &&
kind != reflect.Slice &&
kind != reflect.Array &&
kind != reflect.Map
}

// newValue creates a new value given its type.
func newValue(t reflect.Type) reflect.Value {
switch t.Kind() {
case reflect.Struct:
return reflect.MakeMap(reflect.TypeOf(map[string]interface{}{}))

case reflect.Array, reflect.Slice:
return reflect.MakeSlice(reflect.SliceOf(newValue(t.Elem()).Type()), 0, 0)

case reflect.Ptr:
return newValue(t.Elem())

default:
return reflect.New(t).Elem()
}
}

// isExported is a backport of reflect.StructField.IsExported() for the older versions of golang (<1.17).
func isExported(f reflect.StructField) bool {
return f.PkgPath == ""
}
6 changes: 3 additions & 3 deletions copy_proto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func TestStructToMap_Success(t *testing.T) {
mask := fieldmask_utils.MaskFromString(
"Id,Avatar{OriginalUrl},Tags,Images,Permissions,Friends{Images{ResizedUrl}}")
err := fieldmask_utils.StructToMap(mask, testUserFull, userDst)
assert.Nil(t, err)
require.Nil(t, err)
expected := map[string]interface{}{
"Id": testUserFull.Id,
"Avatar": map[string]interface{}{
Expand Down Expand Up @@ -311,11 +311,11 @@ func TestStructToMap_PartialProtoSuccess(t *testing.T) {
mask := fieldmask_utils.MaskFromString(
"Id,Avatar{OriginalUrl},Images,Username,Permissions,Name{MaleName}")
err := fieldmask_utils.StructToMap(mask, testUserPartial, userDst)
assert.Nil(t, err)
require.Nil(t, err)
expected := map[string]interface{}{
"Id": testUserPartial.Id,
"Avatar": nil,
"Images": []map[string]interface{}(nil),
"Images": []map[string]interface{}{},
"Username": testUserPartial.Username,
"Permissions": []testproto.Permission(nil),
"Name": nil,
Expand Down
Loading

0 comments on commit f9573a9

Please sign in to comment.