diff --git a/gateway/field_presence.go b/gateway/field_presence.go index 8992fb89..e531cd98 100644 --- a/gateway/field_presence.go +++ b/gateway/field_presence.go @@ -15,23 +15,21 @@ import ( "google.golang.org/grpc/metadata" ) -const fieldPresenceMetaKey = "field-paths" +const ( + fieldPresenceMetaKey = "field-paths" + pathsSeparator = "$" + bulkField = "objects" +) -// PresenceAnnotator will parse the JSON input and then add the paths to the +// NewPresenceAnnotator will parse the JSON input and then add the paths to the // metadata to be pulled from context later func NewPresenceAnnotator(methods ...string) func(context.Context, *http.Request) metadata.MD { return func(ctx context.Context, req *http.Request) metadata.MD { if req == nil { return nil } - validMethod := false - for _, m := range methods { - if req.Method == m { - validMethod = true - break - } - } - if !validMethod { + + if !isValidMethod(req, methods...) { return nil } @@ -49,43 +47,95 @@ func NewPresenceAnnotator(methods ...string) func(context.Context, *http.Request return md } - var paths []string var root interface{} - if err := json.Unmarshal(body, &root); err != nil { return nil } - queue := []pathItem{{node: root}} - for len(queue) > 0 { - // dequeue an item - item := queue[0] - queue = queue[1:] - if m, ok := item.node.(map[string]interface{}); ok { - l := len(item.path) - // if the item is an object, then enqueue all of its children - for k, v := range m { - newPath := make([]string, l+1) - copy(newPath, item.path) - newPath[l] = generator.CamelCase(k) - queue = append(queue, pathItem{path: newPath, node: v}) - } + roots := getRoots(root) + for _, r := range roots { + queue := []pathItem{{node: r}} + + paths := []string{} + for len(queue) > 0 { + // dequeue an item + item := queue[0] + queue = queue[1:] - if len(m) == 0 && l > 0 { - // otherwise, it's a leaf node so print its path + if isLeaf(item) { paths = append(paths, strings.Join(item.path, ".")) + } else { + if m, ok := item.node.(map[string]interface{}); ok { + // if the item is an object, then enqueue all of its children + for k, v := range m { + newPath := extendPath(item.path, k, v) + queue = append(queue, pathItem{path: newPath, node: v}) + } + } } - } else if len(item.path) > 0 { - // otherwise, it's a leaf node so print its path - paths = append(paths, strings.Join(item.path, ".")) } + + entry := strings.Join(paths, pathsSeparator) + if len(entry) == 0 { + continue + } + + md[fieldPresenceMetaKey] = append(md[fieldPresenceMetaKey], strings.Join(paths, pathsSeparator)) } - md[fieldPresenceMetaKey] = paths return md } } +func extendPath(parrent []string, key string, value interface{}) []string { + newPath := make([]string, len(parrent)+1) + copy(newPath, parrent) + newPath[len(newPath)-1] = generator.CamelCase(key) + return newPath +} + +func isLeaf(item pathItem) bool { + if m, ok := item.node.(map[string]interface{}); ok { + if len(m) == 0 && len(item.path) > 0 { + return true + } + } else if len(item.path) > 0 { + return true + } + + return false +} + +func getRoots(root interface{}) []interface{} { + defaultRoot := []interface{}{root} + m, ok := root.(map[string]interface{}) + if !ok { + return defaultRoot + } + + bulk, ok := m[bulkField] + if !ok { + return defaultRoot + } + + slice, ok := bulk.([]interface{}) + if !ok { + return defaultRoot + } + + return slice +} + +func isValidMethod(req *http.Request, methods ...string) bool { + for _, m := range methods { + if req.Method == m { + return true + } + } + + return false +} + // pathItem stores a in-progress deconstruction of a path for a fieldmask type pathItem struct { // the list of prior fields leading up to node @@ -110,17 +160,19 @@ func PresenceClientInterceptor() grpc.UnaryClientInterceptor { if !found { return } - fieldMask := &field_mask.FieldMask{Paths: paths} - // If a field with type *FieldMask exists, set the paths in it + fieldMask := fieldMaskFromPaths(paths) + // If a field with type *FieldMask or []*FieldMask exists, set the paths in it t := reflect.ValueOf(req) if t.Kind() != reflect.Interface && t.Kind() != reflect.Ptr { return } + t = t.Elem() if t.Kind() != reflect.Struct { // only Structs can have their fields enumerated return } + for i := 0; i < t.NumField(); i++ { f := t.Field(i) if f.Type() == reflect.TypeOf(fieldMask) && f.IsNil() { @@ -130,3 +182,19 @@ func PresenceClientInterceptor() grpc.UnaryClientInterceptor { return } } + +func fieldMaskFromPaths(paths []string) interface{} { + if len(paths) == 0 { + return &field_mask.FieldMask{} + } + + if len(paths) > 1 { + bulkFieldMasks := make([]*field_mask.FieldMask, len(paths)) + for i, p := range paths { + bulkFieldMasks[i] = &field_mask.FieldMask{Paths: strings.Split(p, pathsSeparator)} + } + return bulkFieldMasks + } + + return &field_mask.FieldMask{Paths: strings.Split(paths[0], pathsSeparator)} +} diff --git a/gateway/field_presence_test.go b/gateway/field_presence_test.go index 3de86951..96de41dd 100644 --- a/gateway/field_presence_test.go +++ b/gateway/field_presence_test.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "fmt" "io/ioutil" "net/http" "reflect" @@ -20,8 +21,9 @@ func TestAnnotator(t *testing.T) { ``: metadata.MD{fieldPresenceMetaKey: nil}, `{}`: metadata.MD{fieldPresenceMetaKey: nil}, `{`: nil, - `{"one":{"two":"a", "three":[]}, "four": 5}`: {fieldPresenceMetaKey: []string{"Four", "One.Two", "One.Three"}}, - `{"one": {}}`: {fieldPresenceMetaKey: []string{"One"}}, + `{"objects":[{"one": {"two":"a", "three":[]}, "four": 5}, {"one":{"two":"a", "three":[]}, "four": 5}]}`: {fieldPresenceMetaKey: []string{"Four$One.Two$One.Three", "Four$One.Two$One.Three"}}, + `{"one":{"two":"a", "three":[]}, "four": 5}`: {fieldPresenceMetaKey: []string{"Four$One.Two$One.Three"}}, + `{"one": {}}`: {fieldPresenceMetaKey: []string{"One"}}, `{ "name": "atlas", "burden": { @@ -40,9 +42,7 @@ func TestAnnotator(t *testing.T) { "mortals": [] } } -}`: {fieldPresenceMetaKey: []string{"Name", "Burden.Duration", "Burden.Weight", - "Burden.Breaks", "Burden.Replacements.Hero.Name", "Burden.Replacements.Hero.Duration", - "Burden.Replacements.Hero.Lineage.Mother", "Burden.Replacements.Hero.Lineage.Father", "Burden.Replacements.Mortals"}}, +}`: {fieldPresenceMetaKey: []string{"Name$Burden.Duration$Burden.Weight$Burden.Breaks$Burden.Replacements.Hero.Name$Burden.Replacements.Hero.Duration$Burden.Replacements.Hero.Lineage.Mother$Burden.Replacements.Hero.Lineage.Father$Burden.Replacements.Mortals"}}, } { postReq := &http.Request{ Method: "POST", @@ -54,15 +54,33 @@ func TestAnnotator(t *testing.T) { continue } // Because the order of objects at the same depth is not guaranteed - sort.Strings(md[fieldPresenceMetaKey]) - sort.Strings(expect[fieldPresenceMetaKey]) - if !reflect.DeepEqual(md, expect) { + if !isEqualFieldMasks(md[fieldPresenceMetaKey], expect[fieldPresenceMetaKey]) { t.Errorf("Did not produce expected metadata %+v, got %+v", expect, md) } } } +func isEqualFieldMasks(s1 []string, s2 []string) bool { + if len(s1) != len(s2) { + fmt.Println("len(s1) != len(s2)", len(s1), len(s2)) + return false + } + + for i := 0; i < len(s1); i++ { + mask1, mask2 := strings.Split(s1[i], "$"), strings.Split(s2[i], "$") + sort.Strings(mask1) + sort.Strings(mask2) + + if !reflect.DeepEqual(mask1, mask2) { + fmt.Println("!reflect.DeepEqual(mask1, mask2)") + return false + } + } + + return true +} + type dummyReq struct { SomeFieldMaskField *field_mask.FieldMask } @@ -80,7 +98,7 @@ func TestUnaryServerInterceptor(t *testing.T) { interceptor := PresenceClientInterceptor() md := runtime.ServerMetadata{ HeaderMD: metadata.MD{ - fieldPresenceMetaKey: []string{"one.two.three", "one.four"}, + fieldPresenceMetaKey: []string{"one.two.three$one.four"}, }, } ctx := runtime.NewServerMetadataContext(context.Background(), md) @@ -112,7 +130,7 @@ func TestUnaryServerInterceptor(t *testing.T) { if req == nil { t.Fatal("For some reason it deleted the request object") } - got, want := req.SomeFieldMaskField, &field_mask.FieldMask{Paths: nil} + got, want := req.SomeFieldMaskField, &field_mask.FieldMask{} if !reflect.DeepEqual(got, want) { t.Errorf("Didn't properly set the fieldmask in the request.\ngot :%v\nwant:%v", got, want) }