Skip to content

Commit

Permalink
Add a WithMapVisitor option. (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
mennanov committed Sep 21, 2022
1 parent 35e0dbd commit 57ce6c2
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
21 changes: 21 additions & 0 deletions copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,19 @@ type options struct {

// CopyListSize can control the number of elements copied from src depending on src's Value
CopyListSize func(src *reflect.Value) int

// MapVisitor is called for every filtered field in structToMap.
//
// It is called before copying the data from source to destination allowing custom processing.
// If the visitor function returns true the visited field is skipped.
MapVisitor mapVisitor
}

// mapVisitor is called for every filtered field in structToMap.
type mapVisitor func(
filter FieldFilter, src interface{}, dst map[string]interface{},
srcFieldName, dstFieldName string, srcFieldValue reflect.Value) (skipToNext bool)

// Option function modifies the given options.
type Option func(*options)

Expand All @@ -259,6 +270,13 @@ func WithCopyListSize(f func(src *reflect.Value) int) Option {
}
}

// WithMapVisitor sets the fields visitor function for StructToMap.
func WithMapVisitor(visitor mapVisitor) Option {
return func(o *options) {
o.MapVisitor = visitor
}
}

func newDefaultOptions() *options {
// set default CopyListSize is func which return src.Len()
return &options{CopyListSize: func(src *reflect.Value) int { return src.Len() }}
Expand Down Expand Up @@ -306,6 +324,9 @@ func structToMap(filter FieldFilter, src interface{}, dst map[string]interface{}
}

dstName := dstKey(userOptions.DstTag, srcType.Field(i))
if userOptions.MapVisitor != nil && userOptions.MapVisitor(filter, src, dst, fieldName, dstName, srcField) {
continue
}

switch srcField.Kind() {
case reflect.Ptr, reflect.Interface:
Expand Down
67 changes: 67 additions & 0 deletions copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -1790,6 +1791,72 @@ func TestStructToMap_CopyIntArray_WithMaxCopyListSize(t *testing.T) {
}, dst)
}

func TestStructToMap_CopyStructWithPrivateFields_WithMapVisitor(t *testing.T) {
type A struct {
Time time.Time
Other int
}
unixTime := time.Unix(10, 10)
src := &A{Time: unixTime}
dst := map[string]interface{}{}
mask := fieldmask_utils.MaskFromString("Time")
err := fieldmask_utils.StructToMap(mask, src, dst, fieldmask_utils.WithMapVisitor(
func(_ fieldmask_utils.FieldFilter, _ interface{}, dst map[string]interface{},
srcFieldName, dstFieldName string, srcFieldValue reflect.Value) (skipToNext bool) {
if srcFieldName == "Time" {
dst[dstFieldName] = srcFieldValue.Interface()
skipToNext = true
}
return
}))
require.NoError(t, err)
assert.Equal(t, map[string]interface{}{
"Time": unixTime,
}, dst)
}

func TestStructToMap_MapVisitorVisitsOnlyFilteredFields(t *testing.T) {
type A struct {
Field1 int
Field2 string
Field3 int
}
src := &A{Field1: 42, Field2: "hello", Field3: 44}
dst := map[string]interface{}{}
mask := fieldmask_utils.MaskFromString("Field1, Field2")
var visitedFields []string
err := fieldmask_utils.StructToMap(mask, src, dst, fieldmask_utils.WithMapVisitor(
func(_ fieldmask_utils.FieldFilter, _ interface{}, _ map[string]interface{},
srcFieldName, _ string, _ reflect.Value) (skipToNext bool) {
visitedFields = append(visitedFields, srcFieldName)
return
}))
require.NoError(t, err)
assert.Equal(t, visitedFields, []string{"Field1", "Field2"})
}

func TestStructToMap_WithMapVisitor_SkipsToNextField(t *testing.T) {
type A struct {
Field1 int
Field2 string
Field3 int
}
src := &A{Field1: 42, Field2: "hello", Field3: 44}
dst := map[string]interface{}{}
mask := fieldmask_utils.MaskFromString("Field1, Field2")
err := fieldmask_utils.StructToMap(mask, src, dst, fieldmask_utils.WithMapVisitor(
func(_ fieldmask_utils.FieldFilter, _ interface{}, _ map[string]interface{},
srcFieldName, dstFieldName string, _ reflect.Value) (skipToNext bool) {
if srcFieldName == "Field1" {
dst[dstFieldName] = 33
skipToNext = true
}
return
}))
require.NoError(t, err)
assert.Equal(t, map[string]interface{}{"Field1": 33, "Field2": "hello"}, dst)
}

func TestStructToStruct_CopySlice_WithDiffentItemKind(t *testing.T) {
type A struct {
Field1 []int
Expand Down

0 comments on commit 57ce6c2

Please sign in to comment.