diff --git a/gorm/filtering.go b/gorm/filtering.go index 70c5b173..43c05e69 100644 --- a/gorm/filtering.go +++ b/gorm/filtering.go @@ -119,11 +119,21 @@ func LogicalOperatorToGorm(ctx context.Context, lop *query.LogicalOperator, obj // StringConditionToGorm returns GORM Plain SQL representation of the string condition. func StringConditionToGorm(ctx context.Context, c *query.StringCondition, obj interface{}, pb proto.Message) (string, []interface{}, map[string]struct{}, error) { - var assocToJoin map[string]struct{} - dbName, assoc, err := HandleFieldPath(ctx, c.FieldPath, obj) + var ( + assocToJoin map[string]struct{} + dbName, assoc string + err error + ) + + if IsJSONCondition(ctx, c.FieldPath, obj) { + dbName, assoc, err = HandleJSONFieldPath(ctx, c.FieldPath, obj, c.Value) + } else { + dbName, assoc, err = HandleFieldPath(ctx, c.FieldPath, obj) + } if err != nil { return "", nil, nil, err } + if assoc != "" { assocToJoin = make(map[string]struct{}) assocToJoin[assoc] = struct{}{} @@ -295,8 +305,16 @@ func NumberArrayConditionToGorm(ctx context.Context, c *query.NumberArrayConditi } func StringArrayConditionToGorm(ctx context.Context, c *query.StringArrayCondition, obj interface{}, pb proto.Message) (string, []interface{}, map[string]struct{}, error) { - var assocToJoin map[string]struct{} - dbName, assoc, err := HandleFieldPath(ctx, c.FieldPath, obj) + var ( + assocToJoin map[string]struct{} + dbName, assoc string + err error + ) + if IsJSONCondition(ctx, c.FieldPath, obj) { + dbName, assoc, err = HandleJSONFieldPath(ctx, c.FieldPath, obj, c.Values...) + } else { + dbName, assoc, err = HandleFieldPath(ctx, c.FieldPath, obj) + } if err != nil { return "", nil, nil, err } diff --git a/gorm/filtering_test.go b/gorm/filtering_test.go index 1f14961d..dca07a08 100644 --- a/gorm/filtering_test.go +++ b/gorm/filtering_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/jinzhu/gorm/dialects/postgres" "github.com/stretchr/testify/assert" "github.com/infobloxopen/atlas-app-toolkit/query" @@ -18,6 +19,7 @@ type Entity struct { NestedEntity NestedEntity Id string Ref *string + Tags *postgres.Jsonb } type EntityProto struct { @@ -258,6 +260,69 @@ func TestGormFiltering(t *testing.T) { nil, nil, }, + { + `tags == '{"Location": "Tacoma"}'`, + `(entities.tags = ?)`, + []interface{}{`{"Location": "Tacoma"}`}, + nil, + nil, + }, + { + `tags.Location == 'Tacoma'`, + `(entities.tags #>> '{Location}' = ?)`, + []interface{}{"Tacoma"}, + nil, + nil, + }, + { + `tags.Location == '{"City": "Tacoma"}'`, + `(entities.tags #> '{Location}' = ?)`, + []interface{}{`{"City": "Tacoma"}`}, + nil, + nil, + }, + { + `tags in ['{"Location": "Tacoma"}', '{"Location": "Minsk"}']`, + `(entities.tags IN (?, ?))`, + []interface{}{`{"Location": "Tacoma"}`, `{"Location": "Minsk"}`}, + nil, + nil, + }, + { + `tags.Location in ['Tacoma', 'Minsk']`, + `(entities.tags #>> '{Location}' IN (?, ?))`, + []interface{}{"Tacoma", "Minsk"}, + nil, + nil, + }, + { + `tags.Location in ['{"City": "Tacoma"}', '{"City": "Minsk"}']`, + `(entities.tags #> '{Location}' IN (?, ?))`, + []interface{}{`{"City": "Tacoma"}`, `{"City": "Minsk"}`}, + nil, + nil, + }, + { + `not(tags.Location == 'Tacoma')`, + `NOT(entities.tags #>> '{Location}' = ?)`, + []interface{}{"Tacoma"}, + nil, + nil, + }, + { + `not(tags.Location in ['Tacoma', 'Minsk'])`, + `(entities.tags #>> '{Location}' NOT IN (?, ?))`, + []interface{}{"Tacoma", "Minsk"}, + nil, + nil, + }, + { + `not(tags.Location in ['{"City": "Tacoma"}', '{"City": "Minsk"}'])`, + `(entities.tags #> '{Location}' NOT IN (?, ?))`, + []interface{}{`{"City": "Tacoma"}`, `{"City": "Minsk"}`}, + nil, + nil, + }, } for _, test := range tests { diff --git a/gorm/utilities.go b/gorm/utilities.go index fcab746f..be6de3de 100644 --- a/gorm/utilities.go +++ b/gorm/utilities.go @@ -10,6 +10,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/generator" jgorm "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/postgres" "github.com/jinzhu/inflection" "time" @@ -40,6 +41,65 @@ func HandleFieldPath(ctx context.Context, fieldPath []string, obj interface{}) ( return dbPath, "", nil } +//HandleJSONFiledPath translate field path to JSONB path for postgres jsonb +func HandleJSONFieldPath(ctx context.Context, fieldPath []string, obj interface{}, values ...string) (string, string, error) { + operator := "#>>" + if isRawJSON(values...) { + operator = "#>" + } + + dbPath, err := fieldPathToDBName(fieldPath[:1], obj) + if err != nil { + switch err.(type) { + case *EmptyFieldPathError: + return "", "", err + default: + dbPath = fieldPath[0] + } + } + + if len(fieldPath) == 1 { + return dbPath, "", nil + } + + return fmt.Sprintf("%s %s '{%s}'", dbPath, operator, strings.Join(fieldPath[1:], ",")), "", nil +} + +func isRawJSON(values ...string) bool { + if len(values) == 0 { + return false + } + + for _, v := range values { + //TODO: this is a very poor check to prevent unexpected errors from Database engine consider to make full validation + //TODO: also we need return an error if json invalid to prevent database error for json parsing + v = strings.TrimSpace(v) + if !strings.HasPrefix(v, "{") || !strings.HasSuffix(v, "}") { + return false + } + } + + return true +} + +//TODO: add supprt for embeded objects +func IsJSONCondition(ctx context.Context, fieldPath []string, obj interface{}) bool { + fieldName := generator.CamelCase(fieldPath[0]) + objType := indirectType(reflect.TypeOf(obj)) + field, ok := objType.FieldByName(fieldName) + if !ok { + return false + } + + fInterface := reflect.Zero(indirectType(field.Type)).Interface() + switch fInterface.(type) { + case postgres.Jsonb: + return true + } + + return false +} + func fieldPathToDBName(fieldPath []string, obj interface{}) (string, error) { objType := indirectType(reflect.TypeOf(obj)) pathLength := len(fieldPath)