Skip to content

Commit

Permalink
Add support for repeated field mask (#173)
Browse files Browse the repository at this point in the history
* Add support for repeated field mask
  • Loading branch information
Aliaksei Burau authored Sep 6, 2019
1 parent d292e91 commit dae90b8
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 44 deletions.
136 changes: 102 additions & 34 deletions gateway/field_presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,21 @@ import (
"google.golang.org/grpc/metadata"
)

const fieldPresenceMetaKey = "field-paths"
const (
fieldPresenceMetaKey = "field-paths"
pathsSeparator = "$"
bulkField = "objects"
)

// PresenceAnnotator will parse the JSON input and then add the paths to the
// NewPresenceAnnotator will parse the JSON input and then add the paths to the
// metadata to be pulled from context later
func NewPresenceAnnotator(methods ...string) func(context.Context, *http.Request) metadata.MD {
return func(ctx context.Context, req *http.Request) metadata.MD {
if req == nil {
return nil
}
validMethod := false
for _, m := range methods {
if req.Method == m {
validMethod = true
break
}
}
if !validMethod {

if !isValidMethod(req, methods...) {
return nil
}

Expand All @@ -49,43 +47,95 @@ func NewPresenceAnnotator(methods ...string) func(context.Context, *http.Request
return md
}

var paths []string
var root interface{}

if err := json.Unmarshal(body, &root); err != nil {
return nil
}

queue := []pathItem{{node: root}}
for len(queue) > 0 {
// dequeue an item
item := queue[0]
queue = queue[1:]
if m, ok := item.node.(map[string]interface{}); ok {
l := len(item.path)
// if the item is an object, then enqueue all of its children
for k, v := range m {
newPath := make([]string, l+1)
copy(newPath, item.path)
newPath[l] = generator.CamelCase(k)
queue = append(queue, pathItem{path: newPath, node: v})
}
roots := getRoots(root)
for _, r := range roots {
queue := []pathItem{{node: r}}

paths := []string{}
for len(queue) > 0 {
// dequeue an item
item := queue[0]
queue = queue[1:]

if len(m) == 0 && l > 0 {
// otherwise, it's a leaf node so print its path
if isLeaf(item) {
paths = append(paths, strings.Join(item.path, "."))
} else {
if m, ok := item.node.(map[string]interface{}); ok {
// if the item is an object, then enqueue all of its children
for k, v := range m {
newPath := extendPath(item.path, k, v)
queue = append(queue, pathItem{path: newPath, node: v})
}
}
}
} else if len(item.path) > 0 {
// otherwise, it's a leaf node so print its path
paths = append(paths, strings.Join(item.path, "."))
}

entry := strings.Join(paths, pathsSeparator)
if len(entry) == 0 {
continue
}

md[fieldPresenceMetaKey] = append(md[fieldPresenceMetaKey], strings.Join(paths, pathsSeparator))
}

md[fieldPresenceMetaKey] = paths
return md
}
}

func extendPath(parrent []string, key string, value interface{}) []string {
newPath := make([]string, len(parrent)+1)
copy(newPath, parrent)
newPath[len(newPath)-1] = generator.CamelCase(key)
return newPath
}

func isLeaf(item pathItem) bool {
if m, ok := item.node.(map[string]interface{}); ok {
if len(m) == 0 && len(item.path) > 0 {
return true
}
} else if len(item.path) > 0 {
return true
}

return false
}

func getRoots(root interface{}) []interface{} {
defaultRoot := []interface{}{root}
m, ok := root.(map[string]interface{})
if !ok {
return defaultRoot
}

bulk, ok := m[bulkField]
if !ok {
return defaultRoot
}

slice, ok := bulk.([]interface{})
if !ok {
return defaultRoot
}

return slice
}

func isValidMethod(req *http.Request, methods ...string) bool {
for _, m := range methods {
if req.Method == m {
return true
}
}

return false
}

// pathItem stores a in-progress deconstruction of a path for a fieldmask
type pathItem struct {
// the list of prior fields leading up to node
Expand All @@ -110,17 +160,19 @@ func PresenceClientInterceptor() grpc.UnaryClientInterceptor {
if !found {
return
}
fieldMask := &field_mask.FieldMask{Paths: paths}

// If a field with type *FieldMask exists, set the paths in it
fieldMask := fieldMaskFromPaths(paths)
// If a field with type *FieldMask or []*FieldMask exists, set the paths in it
t := reflect.ValueOf(req)
if t.Kind() != reflect.Interface && t.Kind() != reflect.Ptr {
return
}

t = t.Elem()
if t.Kind() != reflect.Struct { // only Structs can have their fields enumerated
return
}

for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Type() == reflect.TypeOf(fieldMask) && f.IsNil() {
Expand All @@ -130,3 +182,19 @@ func PresenceClientInterceptor() grpc.UnaryClientInterceptor {
return
}
}

func fieldMaskFromPaths(paths []string) interface{} {
if len(paths) == 0 {
return &field_mask.FieldMask{}
}

if len(paths) > 1 {
bulkFieldMasks := make([]*field_mask.FieldMask, len(paths))
for i, p := range paths {
bulkFieldMasks[i] = &field_mask.FieldMask{Paths: strings.Split(p, pathsSeparator)}
}
return bulkFieldMasks
}

return &field_mask.FieldMask{Paths: strings.Split(paths[0], pathsSeparator)}
}
38 changes: 28 additions & 10 deletions gateway/field_presence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gateway

import (
"context"
"fmt"
"io/ioutil"
"net/http"
"reflect"
Expand All @@ -20,8 +21,9 @@ func TestAnnotator(t *testing.T) {
``: metadata.MD{fieldPresenceMetaKey: nil},
`{}`: metadata.MD{fieldPresenceMetaKey: nil},
`{`: nil,
`{"one":{"two":"a", "three":[]}, "four": 5}`: {fieldPresenceMetaKey: []string{"Four", "One.Two", "One.Three"}},
`{"one": {}}`: {fieldPresenceMetaKey: []string{"One"}},
`{"objects":[{"one": {"two":"a", "three":[]}, "four": 5}, {"one":{"two":"a", "three":[]}, "four": 5}]}`: {fieldPresenceMetaKey: []string{"Four$One.Two$One.Three", "Four$One.Two$One.Three"}},
`{"one":{"two":"a", "three":[]}, "four": 5}`: {fieldPresenceMetaKey: []string{"Four$One.Two$One.Three"}},
`{"one": {}}`: {fieldPresenceMetaKey: []string{"One"}},
`{
"name": "atlas",
"burden": {
Expand All @@ -40,9 +42,7 @@ func TestAnnotator(t *testing.T) {
"mortals": []
}
}
}`: {fieldPresenceMetaKey: []string{"Name", "Burden.Duration", "Burden.Weight",
"Burden.Breaks", "Burden.Replacements.Hero.Name", "Burden.Replacements.Hero.Duration",
"Burden.Replacements.Hero.Lineage.Mother", "Burden.Replacements.Hero.Lineage.Father", "Burden.Replacements.Mortals"}},
}`: {fieldPresenceMetaKey: []string{"Name$Burden.Duration$Burden.Weight$Burden.Breaks$Burden.Replacements.Hero.Name$Burden.Replacements.Hero.Duration$Burden.Replacements.Hero.Lineage.Mother$Burden.Replacements.Hero.Lineage.Father$Burden.Replacements.Mortals"}},
} {
postReq := &http.Request{
Method: "POST",
Expand All @@ -54,15 +54,33 @@ func TestAnnotator(t *testing.T) {
continue
}
// Because the order of objects at the same depth is not guaranteed
sort.Strings(md[fieldPresenceMetaKey])
sort.Strings(expect[fieldPresenceMetaKey])
if !reflect.DeepEqual(md, expect) {
if !isEqualFieldMasks(md[fieldPresenceMetaKey], expect[fieldPresenceMetaKey]) {
t.Errorf("Did not produce expected metadata %+v, got %+v", expect, md)
}

}
}

func isEqualFieldMasks(s1 []string, s2 []string) bool {
if len(s1) != len(s2) {
fmt.Println("len(s1) != len(s2)", len(s1), len(s2))
return false
}

for i := 0; i < len(s1); i++ {
mask1, mask2 := strings.Split(s1[i], "$"), strings.Split(s2[i], "$")
sort.Strings(mask1)
sort.Strings(mask2)

if !reflect.DeepEqual(mask1, mask2) {
fmt.Println("!reflect.DeepEqual(mask1, mask2)")
return false
}
}

return true
}

type dummyReq struct {
SomeFieldMaskField *field_mask.FieldMask
}
Expand All @@ -80,7 +98,7 @@ func TestUnaryServerInterceptor(t *testing.T) {
interceptor := PresenceClientInterceptor()
md := runtime.ServerMetadata{
HeaderMD: metadata.MD{
fieldPresenceMetaKey: []string{"one.two.three", "one.four"},
fieldPresenceMetaKey: []string{"one.two.three$one.four"},
},
}
ctx := runtime.NewServerMetadataContext(context.Background(), md)
Expand Down Expand Up @@ -112,7 +130,7 @@ func TestUnaryServerInterceptor(t *testing.T) {
if req == nil {
t.Fatal("For some reason it deleted the request object")
}
got, want := req.SomeFieldMaskField, &field_mask.FieldMask{Paths: nil}
got, want := req.SomeFieldMaskField, &field_mask.FieldMask{}
if !reflect.DeepEqual(got, want) {
t.Errorf("Didn't properly set the fieldmask in the request.\ngot :%v\nwant:%v", got, want)
}
Expand Down

0 comments on commit dae90b8

Please sign in to comment.