From 38e87594e8dffb012b80d9c47c053030bfcae621 Mon Sep 17 00:00:00 2001 From: Matt Johnson-Pint Date: Fri, 25 Oct 2024 11:43:47 -0700 Subject: [PATCH 1/6] Use GraphQL terminology --- runtime/graphql/schemagen/filters.go | 4 +- runtime/graphql/schemagen/schemagen.go | 68 +++++++++++++------------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/runtime/graphql/schemagen/filters.go b/runtime/graphql/schemagen/filters.go index 26a36ac9..a4f88a59 100644 --- a/runtime/graphql/schemagen/filters.go +++ b/runtime/graphql/schemagen/filters.go @@ -11,7 +11,7 @@ package schemagen import "github.com/hypermodeinc/modus/runtime/manifestdata" -func getFnFilter() func(*FunctionSignature) bool { +func getFieldFilter() func(*FieldDefinition) bool { embedders := make(map[string]bool) for _, collection := range manifestdata.GetManifest().Collections { for _, searchMethod := range collection.SearchMethods { @@ -19,7 +19,7 @@ func getFnFilter() func(*FunctionSignature) bool { } } - return func(f *FunctionSignature) bool { + return func(f *FieldDefinition) bool { return !embedders[f.Name] } } diff --git a/runtime/graphql/schemagen/schemagen.go b/runtime/graphql/schemagen/schemagen.go index 0b0ef212..c20b6174 100644 --- a/runtime/graphql/schemagen/schemagen.go +++ b/runtime/graphql/schemagen/schemagen.go @@ -42,20 +42,20 @@ func GetGraphQLSchema(ctx context.Context, md *metadata.Metadata) (*GraphQLSchem inputTypeDefs, errors := transformTypes(md.Types, lti, true) resultTypeDefs, errs := transformTypes(md.Types, lti, false) errors = append(errors, errs...) - functions, errs := transformFunctions(md.FnExports, inputTypeDefs, resultTypeDefs, lti) + fields, errs := transformFunctions(md.FnExports, inputTypeDefs, resultTypeDefs, lti) errors = append(errors, errs...) if len(errors) > 0 { return nil, fmt.Errorf("failed to generate schema: %+v", errors) } - functions = filterFunctions(functions) + fields = filterFields(fields) scalarTypes := extractCustomScalarTypes(inputTypeDefs, resultTypeDefs) - inputTypes := filterTypes(utils.MapValues(inputTypeDefs), functions, true) - resultTypes := filterTypes(utils.MapValues(resultTypeDefs), functions, false) + inputTypes := filterTypes(utils.MapValues(inputTypeDefs), fields, true) + resultTypes := filterTypes(utils.MapValues(resultTypeDefs), fields, false) buf := bytes.Buffer{} - writeSchema(&buf, functions, scalarTypes, inputTypes, resultTypes) + writeSchema(&buf, fields, scalarTypes, inputTypes, resultTypes) mapTypes := make([]string, 0, len(resultTypeDefs)) for _, t := range resultTypeDefs { @@ -119,9 +119,9 @@ func transformTypes(types metadata.TypeMap, lti langsupport.LanguageTypeInfo, fo return typeDefs, errors } -type FunctionSignature struct { +type FieldDefinition struct { Name string - Parameters []*ParameterSignature + Arguments []*ArgumentDefinition ReturnType string } @@ -136,14 +136,14 @@ type NameTypePair struct { Type string } -type ParameterSignature struct { +type ArgumentDefinition struct { Name string Type string Default *any } -func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTypeDefs map[string]*TypeDefinition, lti langsupport.LanguageTypeInfo) ([]*FunctionSignature, []*TransformError) { - output := make([]*FunctionSignature, len(functions)) +func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTypeDefs map[string]*TypeDefinition, lti langsupport.LanguageTypeInfo) ([]*FieldDefinition, []*TransformError) { + fields := make([]*FieldDefinition, len(functions)) errors := make([]*TransformError, 0) i := 0 @@ -152,7 +152,7 @@ func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTyp for _, name := range fnNames { f := functions[name] - params, err := convertParameters(f.Parameters, lti, inputTypeDefs) + args, err := convertParameters(f.Parameters, lti, inputTypeDefs) if err != nil { errors = append(errors, &TransformError{f, err}) continue @@ -164,23 +164,23 @@ func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTyp continue } - output[i] = &FunctionSignature{ + fields[i] = &FieldDefinition{ Name: f.Name, - Parameters: params, + Arguments: args, ReturnType: returnType, } i++ } - return output, errors + return fields, errors } -func filterFunctions(functions []*FunctionSignature) []*FunctionSignature { - fnFilter := getFnFilter() - results := make([]*FunctionSignature, 0, len(functions)) - for _, f := range functions { - if fnFilter(f) { +func filterFields(fields []*FieldDefinition) []*FieldDefinition { + filter := getFieldFilter() + results := make([]*FieldDefinition, 0, len(fields)) + for _, f := range fields { + if filter(f) { results = append(results, f) } } @@ -188,8 +188,8 @@ func filterFunctions(functions []*FunctionSignature) []*FunctionSignature { return results } -func filterTypes(types []*TypeDefinition, functions []*FunctionSignature, forInput bool) []*TypeDefinition { - // Filter out types that are not used by any function. +func filterTypes(types []*TypeDefinition, fields []*FieldDefinition, forInput bool) []*TypeDefinition { + // Filter out types that are not used by any field. // Also then recursively filter out types that are not used by any type. // Make a map of all types @@ -199,11 +199,11 @@ func filterTypes(types []*TypeDefinition, functions []*FunctionSignature, forInp typeMap[name] = t } - // Get all types used by functions, including subtypes + // Get all types used by fields, including subtypes usedTypes := make(map[string]bool) - for _, f := range functions { + for _, f := range fields { if forInput { - for _, p := range f.Parameters { + for _, p := range f.Arguments { addUsedTypes(p.Type, typeMap, usedTypes) } } else { @@ -263,13 +263,13 @@ func getBaseType(name string) string { return name } -func writeSchema(buf *bytes.Buffer, functions []*FunctionSignature, scalarTypes []string, inputTypeDefs, resultTypeDefs []*TypeDefinition) { +func writeSchema(buf *bytes.Buffer, fields []*FieldDefinition, scalarTypes []string, inputTypeDefs, resultTypeDefs []*TypeDefinition) { // write header buf.WriteString("# Modus GraphQL Schema (auto-generated)\n\n") // sort everything - slices.SortFunc(functions, func(a, b *FunctionSignature) int { + slices.SortFunc(fields, func(a, b *FieldDefinition) int { return cmp.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)) }) slices.SortFunc(scalarTypes, func(a, b string) int { @@ -282,14 +282,14 @@ func writeSchema(buf *bytes.Buffer, functions []*FunctionSignature, scalarTypes return cmp.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)) }) - // write query functions + // write query object buf.WriteString("type Query {\n") - for _, f := range functions { + for _, f := range fields { buf.WriteString(" ") buf.WriteString(f.Name) - if len(f.Parameters) > 0 { + if len(f.Arguments) > 0 { buf.WriteByte('(') - for i, p := range f.Parameters { + for i, p := range f.Arguments { if i > 0 { buf.WriteString(", ") } @@ -358,12 +358,12 @@ func writeSchema(buf *bytes.Buffer, functions []*FunctionSignature, scalarTypes buf.WriteByte('\n') } -func convertParameters(parameters []*metadata.Parameter, lti langsupport.LanguageTypeInfo, typeDefs map[string]*TypeDefinition) ([]*ParameterSignature, error) { +func convertParameters(parameters []*metadata.Parameter, lti langsupport.LanguageTypeInfo, typeDefs map[string]*TypeDefinition) ([]*ArgumentDefinition, error) { if len(parameters) == 0 { return nil, nil } - output := make([]*ParameterSignature, len(parameters)) + args := make([]*ArgumentDefinition, len(parameters)) for i, p := range parameters { t, err := convertType(p.Type, lti, typeDefs, false, true) @@ -371,13 +371,13 @@ func convertParameters(parameters []*metadata.Parameter, lti langsupport.Languag return nil, err } - output[i] = &ParameterSignature{ + args[i] = &ArgumentDefinition{ Name: p.Name, Type: t, Default: p.Default, } } - return output, nil + return args, nil } func convertResults(results []*metadata.Result, lti langsupport.LanguageTypeInfo, typeDefs map[string]*TypeDefinition) (string, error) { From 2438d37f68917db00ef2a297c1fba95dabf0f0c0 Mon Sep 17 00:00:00 2001 From: Matt Johnson-Pint Date: Fri, 25 Oct 2024 14:19:59 -0700 Subject: [PATCH 2/6] Apply conventions to schema generation --- runtime/graphql/schemagen/conventions.go | 62 +++++++++ runtime/graphql/schemagen/filters.go | 14 +- runtime/graphql/schemagen/schemagen.go | 124 +++++++++++------- .../graphql/schemagen/schemagen_as_test.go | 15 ++- .../graphql/schemagen/schemagen_go_test.go | 15 ++- 5 files changed, 163 insertions(+), 67 deletions(-) create mode 100644 runtime/graphql/schemagen/conventions.go diff --git a/runtime/graphql/schemagen/conventions.go b/runtime/graphql/schemagen/conventions.go new file mode 100644 index 00000000..8d6c5ed5 --- /dev/null +++ b/runtime/graphql/schemagen/conventions.go @@ -0,0 +1,62 @@ +/* + * Copyright 2024 Hypermode Inc. + * Licensed under the terms of the Apache License, Version 2.0 + * See the LICENSE file that accompanied this code for further details. + * + * SPDX-FileCopyrightText: 2024 Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package schemagen + +import "strings" + +// prefixes that are used to identify query fields, and will be trimmed from the field name +var queryTrimPrefixes = []string{"get", "list"} + +// prefixes that are used to identify mutation fields +var mutationPrefixes = []string{ + "mutate", + "post", "patch", "put", "delete", + "add", "update", "insert", "upsert", + "create", "edit", "save", "remove", "alter", "modify", +} + +func isMutation(fnName string) bool { + prefix := getPrefix(fnName, mutationPrefixes) + if prefix == "" { + return false + } + + // embedders are not mutations + embedders := getEmbedderFields() + return !embedders[fnName] +} + +func getFieldName(fnName string) string { + prefix := getPrefix(fnName, queryTrimPrefixes) + fieldName := strings.TrimPrefix(fnName, prefix) + return strings.ToLower(fieldName[:1]) + fieldName[1:] +} + +func getPrefix(fnName string, prefixes []string) string { + for _, prefix := range prefixes { + // check for exact match + fnNameLowered := strings.ToLower(fnName) + if fnNameLowered == prefix { + return prefix + } + + // check for a prefix, but only if the prefix is NOT followed by a lowercase letter + // for example, we want to match "addPost" but not "additionalPosts" + prefixLen := len(prefix) + if len(fnName) > prefixLen && strings.HasPrefix(fnNameLowered, prefix) { + c := fnName[prefixLen] + if c < 'a' || c > 'z' { + return prefix + } + } + } + + return "" +} diff --git a/runtime/graphql/schemagen/filters.go b/runtime/graphql/schemagen/filters.go index a4f88a59..11bd04aa 100644 --- a/runtime/graphql/schemagen/filters.go +++ b/runtime/graphql/schemagen/filters.go @@ -12,14 +12,18 @@ package schemagen import "github.com/hypermodeinc/modus/runtime/manifestdata" func getFieldFilter() func(*FieldDefinition) bool { + embedders := getEmbedderFields() + return func(f *FieldDefinition) bool { + return !embedders[f.Name] + } +} + +func getEmbedderFields() map[string]bool { embedders := make(map[string]bool) for _, collection := range manifestdata.GetManifest().Collections { for _, searchMethod := range collection.SearchMethods { - embedders[searchMethod.Embedder] = true + embedders[getFieldName(searchMethod.Embedder)] = true } } - - return func(f *FieldDefinition) bool { - return !embedders[f.Name] - } + return embedders } diff --git a/runtime/graphql/schemagen/schemagen.go b/runtime/graphql/schemagen/schemagen.go index c20b6174..3502aa8a 100644 --- a/runtime/graphql/schemagen/schemagen.go +++ b/runtime/graphql/schemagen/schemagen.go @@ -42,20 +42,23 @@ func GetGraphQLSchema(ctx context.Context, md *metadata.Metadata) (*GraphQLSchem inputTypeDefs, errors := transformTypes(md.Types, lti, true) resultTypeDefs, errs := transformTypes(md.Types, lti, false) errors = append(errors, errs...) - fields, errs := transformFunctions(md.FnExports, inputTypeDefs, resultTypeDefs, lti) + queryFields, mutationFields, errs := transformFunctions(md.FnExports, inputTypeDefs, resultTypeDefs, lti) errors = append(errors, errs...) if len(errors) > 0 { return nil, fmt.Errorf("failed to generate schema: %+v", errors) } - fields = filterFields(fields) + queryFields = filterFields(queryFields) + mutationFields = filterFields(mutationFields) + allFields := append(queryFields, mutationFields...) + scalarTypes := extractCustomScalarTypes(inputTypeDefs, resultTypeDefs) - inputTypes := filterTypes(utils.MapValues(inputTypeDefs), fields, true) - resultTypes := filterTypes(utils.MapValues(resultTypeDefs), fields, false) + inputTypes := filterTypes(utils.MapValues(inputTypeDefs), allFields, true) + resultTypes := filterTypes(utils.MapValues(resultTypeDefs), allFields, false) buf := bytes.Buffer{} - writeSchema(&buf, fields, scalarTypes, inputTypes, resultTypes) + writeSchema(&buf, queryFields, mutationFields, scalarTypes, inputTypes, resultTypes) mapTypes := make([]string, 0, len(resultTypeDefs)) for _, t := range resultTypeDefs { @@ -142,11 +145,11 @@ type ArgumentDefinition struct { Default *any } -func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTypeDefs map[string]*TypeDefinition, lti langsupport.LanguageTypeInfo) ([]*FieldDefinition, []*TransformError) { - fields := make([]*FieldDefinition, len(functions)) +func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTypeDefs map[string]*TypeDefinition, lti langsupport.LanguageTypeInfo) ([]*FieldDefinition, []*FieldDefinition, []*TransformError) { + queryFields := make([]*FieldDefinition, 0, len(functions)) + mutationFields := make([]*FieldDefinition, 0, len(functions)) errors := make([]*TransformError, 0) - i := 0 fnNames := utils.MapKeys(functions) sort.Strings(fnNames) for _, name := range fnNames { @@ -164,16 +167,20 @@ func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTyp continue } - fields[i] = &FieldDefinition{ - Name: f.Name, + field := &FieldDefinition{ + Name: getFieldName(f.Name), Arguments: args, ReturnType: returnType, } - i++ + if isMutation(f.Name) { + mutationFields = append(mutationFields, field) + } else { + queryFields = append(queryFields, field) + } } - return fields, errors + return queryFields, mutationFields, errors } func filterFields(fields []*FieldDefinition) []*FieldDefinition { @@ -263,13 +270,16 @@ func getBaseType(name string) string { return name } -func writeSchema(buf *bytes.Buffer, fields []*FieldDefinition, scalarTypes []string, inputTypeDefs, resultTypeDefs []*TypeDefinition) { +func writeSchema(buf *bytes.Buffer, queryFields []*FieldDefinition, mutationFields []*FieldDefinition, scalarTypes []string, inputTypeDefs, resultTypeDefs []*TypeDefinition) { // write header - buf.WriteString("# Modus GraphQL Schema (auto-generated)\n\n") + buf.WriteString("# Modus GraphQL Schema (auto-generated)\n") // sort everything - slices.SortFunc(fields, func(a, b *FieldDefinition) int { + slices.SortFunc(queryFields, func(a, b *FieldDefinition) int { + return cmp.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)) + }) + slices.SortFunc(mutationFields, func(a, b *FieldDefinition) int { return cmp.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)) }) slices.SortFunc(scalarTypes, func(a, b string) int { @@ -283,49 +293,38 @@ func writeSchema(buf *bytes.Buffer, fields []*FieldDefinition, scalarTypes []str }) // write query object - buf.WriteString("type Query {\n") - for _, f := range fields { - buf.WriteString(" ") - buf.WriteString(f.Name) - if len(f.Arguments) > 0 { - buf.WriteByte('(') - for i, p := range f.Arguments { - if i > 0 { - buf.WriteString(", ") - } - buf.WriteString(p.Name) - buf.WriteString(": ") - buf.WriteString(p.Type) - if p.Default != nil { - val, err := utils.JsonSerialize(*p.Default) - if err == nil { - buf.WriteString(" = ") - buf.Write(val) - } - } - } - buf.WriteByte(')') + if len(queryFields) > 0 { + buf.WriteByte('\n') + buf.WriteString("type Query {\n") + for _, field := range queryFields { + writeField(buf, field) } - buf.WriteString(": ") - buf.WriteString(f.ReturnType) + buf.WriteString("}\n") + } + + // write mutation object + if len(mutationFields) > 0 { buf.WriteByte('\n') + buf.WriteString("type Mutation {\n") + for _, field := range mutationFields { + writeField(buf, field) + } + buf.WriteString("}\n") } - buf.WriteByte('}') // write scalars - for i, scalar := range scalarTypes { - if i == 0 { + if len(scalarTypes) > 0 { + buf.WriteByte('\n') + for _, scalar := range scalarTypes { + buf.WriteString("scalar ") + buf.WriteString(scalar) buf.WriteByte('\n') } - - buf.WriteByte('\n') - buf.WriteString("scalar ") - buf.WriteString(scalar) } // write input types for _, t := range inputTypeDefs { - buf.WriteString("\n\n") + buf.WriteByte('\n') buf.WriteString("input ") buf.WriteString(t.Name) buf.WriteString(" {\n") @@ -336,12 +335,12 @@ func writeSchema(buf *bytes.Buffer, fields []*FieldDefinition, scalarTypes []str buf.WriteString(f.Type) buf.WriteByte('\n') } - buf.WriteByte('}') + buf.WriteString("}\n") } // write result types for _, t := range resultTypeDefs { - buf.WriteString("\n\n") + buf.WriteByte('\n') buf.WriteString("type ") buf.WriteString(t.Name) buf.WriteString(" {\n") @@ -352,9 +351,34 @@ func writeSchema(buf *bytes.Buffer, fields []*FieldDefinition, scalarTypes []str buf.WriteString(f.Type) buf.WriteByte('\n') } - buf.WriteByte('}') + buf.WriteString("}\n") } +} +func writeField(buf *bytes.Buffer, field *FieldDefinition) { + buf.WriteString(" ") + buf.WriteString(field.Name) + if len(field.Arguments) > 0 { + buf.WriteByte('(') + for i, p := range field.Arguments { + if i > 0 { + buf.WriteString(", ") + } + buf.WriteString(p.Name) + buf.WriteString(": ") + buf.WriteString(p.Type) + if p.Default != nil { + val, err := utils.JsonSerialize(*p.Default) + if err == nil { + buf.WriteString(" = ") + buf.Write(val) + } + } + } + buf.WriteByte(')') + } + buf.WriteString(": ") + buf.WriteString(field.ReturnType) buf.WriteByte('\n') } diff --git a/runtime/graphql/schemagen/schemagen_as_test.go b/runtime/graphql/schemagen/schemagen_as_test.go index e9a28f20..9c372f89 100644 --- a/runtime/graphql/schemagen/schemagen_as_test.go +++ b/runtime/graphql/schemagen/schemagen_as_test.go @@ -86,7 +86,7 @@ func Test_GetGraphQLSchema_AssemblyScript(t *testing.T) { md.FnExports.AddFunction("getPerson"). WithResult("assembly/test/Person") - md.FnExports.AddFunction("getPeople"). + md.FnExports.AddFunction("listPeople"). WithResult("~lib/array/Array") md.FnExports.AddFunction("addPerson"). @@ -177,13 +177,11 @@ func Test_GetGraphQLSchema_AssemblyScript(t *testing.T) { # Modus GraphQL Schema (auto-generated) type Query { - add(a: Int!, b: Int!): Int! - addPerson(person: PersonInput!): Void currentTime: Timestamp! doNothing: Void - getPeople: [Person!]! - getPerson: Person! - getProductMap: [StringProductPair!]! + people: [Person!]! + person: Person! + productMap: [StringProductPair!]! sayHello(name: String!): String! testDefaultArrayParams(a: [Int!]!, b: [Int!]! = [], c: [Int!]! = [1,2,3], d: [Int!], e: [Int!] = null, f: [Int!] = [], g: [Int!] = [1,2,3]): Void testDefaultIntParams(a: Int!, b: Int! = 0, c: Int! = 1): Void @@ -195,6 +193,11 @@ type Query { transform(items: [StringStringPairInput!]!): [StringStringPair!]! } +type Mutation { + add(a: Int!, b: Int!): Int! + addPerson(person: PersonInput!): Void +} + scalar Timestamp scalar Void diff --git a/runtime/graphql/schemagen/schemagen_go_test.go b/runtime/graphql/schemagen/schemagen_go_test.go index 82281ac9..a44d112e 100644 --- a/runtime/graphql/schemagen/schemagen_go_test.go +++ b/runtime/graphql/schemagen/schemagen_go_test.go @@ -85,7 +85,7 @@ func Test_GetGraphQLSchema_Go(t *testing.T) { md.FnExports.AddFunction("getPerson"). WithResult("testdata.Person") - md.FnExports.AddFunction("getPeople"). + md.FnExports.AddFunction("listPeople"). WithResult("[]testdata.Person") md.FnExports.AddFunction("addPerson"). @@ -205,13 +205,11 @@ func Test_GetGraphQLSchema_Go(t *testing.T) { # Modus GraphQL Schema (auto-generated) type Query { - add(a: Int!, b: Int!): Int! - addPerson(person: PersonInput!): Void currentTime: Timestamp! doNothing: Void - getPeople: [Person!] - getPerson: Person! - getProductMap: [StringProductPair!] + people: [Person!] + person: Person! + productMap: [StringProductPair!] sayHello(name: String!): String! testDefaultArrayParams(a: [Int!], b: [Int!] = [], c: [Int!] = [1,2,3], d: [Int!], e: [Int!] = null, f: [Int!] = [], g: [Int!] = [1,2,3]): Void testDefaultIntParams(a: Int!, b: Int! = 0, c: Int! = 1): Void @@ -226,6 +224,11 @@ type Query { transform(items: [StringStringPairInput!]): [StringStringPair!] } +type Mutation { + add(a: Int!, b: Int!): Int! + addPerson(person: PersonInput!): Void +} + scalar Timestamp scalar Void From 70f29dd322c06a2c4d2051df8d3f22658b597645 Mon Sep 17 00:00:00 2001 From: Matt Johnson-Pint Date: Fri, 25 Oct 2024 16:33:19 -0700 Subject: [PATCH 3/6] Support mutations --- runtime/graphql/datasource/planner.go | 29 ++++++++++------ runtime/graphql/engine/engine.go | 48 ++++++++++++++++++--------- runtime/graphql/graphql.go | 8 ++--- 3 files changed, 54 insertions(+), 31 deletions(-) diff --git a/runtime/graphql/datasource/planner.go b/runtime/graphql/datasource/planner.go index 41c2c75c..954ad89b 100644 --- a/runtime/graphql/datasource/planner.go +++ b/runtime/graphql/datasource/planner.go @@ -71,7 +71,7 @@ func (p *HypDSPlanner) DownstreamResponseFieldAlias(downstreamFieldRef int) (ali func (p *HypDSPlanner) DataSourcePlanningBehavior() plan.DataSourcePlanningBehavior { return plan.DataSourcePlanningBehavior{ - // This needs to be true, so we can distinguish results for multiple function calls in the same query. + // This needs to be true, so we can distinguish results for multiple function calls in the same operation. // Example: // query SayHello { // a: sayHello(name: "Sam") @@ -101,12 +101,12 @@ func (p *HypDSPlanner) EnterDocument(operation, definition *ast.Document) { func (p *HypDSPlanner) EnterField(ref int) { - // Capture information about every field in the query. + // Capture information about every field in the operation. f := p.captureField(ref) p.fields[ref] = *f - // If the field is enclosed by a root node, then it represents the function we want to call. - if p.enclosingTypeIsRootNode() { + // Capture only the fields that represent function calls. + if p.currentNodeIsFunctionCall() { // Save the field for the function. p.template.function = f @@ -126,7 +126,7 @@ func (p *HypDSPlanner) LeaveDocument(operation, definition *ast.Document) { } func (p *HypDSPlanner) stitchFields(f *fieldInfo) { - if len(f.fieldRefs) == 0 { + if f == nil || len(f.fieldRefs) == 0 { return } @@ -138,14 +138,21 @@ func (p *HypDSPlanner) stitchFields(f *fieldInfo) { } } -func (p *HypDSPlanner) enclosingTypeIsRootNode() bool { +func (p *HypDSPlanner) currentNodeIsFunctionCall() bool { + if p.visitor.Walker.CurrentKind != ast.NodeKindField { + return false + } + enclosingTypeDef := p.visitor.Walker.EnclosingTypeDefinition - for _, node := range p.visitor.Operation.RootNodes { - if node.Ref == enclosingTypeDef.Ref { - return true - } + if enclosingTypeDef.Kind != ast.NodeKindObjectTypeDefinition { + return false } - return false + + // TODO: This works, but it's a hack. We should find a better way to determine if the field is a function call. + // The previous approach of root node testing worked for queries, but not for mutations. + // The enclosing type name should not be relevant. + enclosingTypeName := p.visitor.Definition.ObjectTypeDefinitionNameString(enclosingTypeDef.Ref) + return enclosingTypeName == "Query" || enclosingTypeName == "Mutation" } func (p *HypDSPlanner) captureField(ref int) *fieldInfo { diff --git a/runtime/graphql/engine/engine.go b/runtime/graphql/engine/engine.go index 084029a5..64d4bed0 100644 --- a/runtime/graphql/engine/engine.go +++ b/runtime/graphql/engine/engine.go @@ -106,24 +106,25 @@ func getDatasourceConfig(ctx context.Context, schema *gql.Schema, cfg *datasourc defer span.Finish() queryTypeName := schema.QueryTypeName() - queryFieldNames := getAllQueryFields(ctx, schema) + queryFieldNames := getTypeFields(ctx, schema, queryTypeName) + + mutationTypeName := schema.MutationTypeName() + mutationFieldNames := getTypeFields(ctx, schema, mutationTypeName) + rootNodes := []plan.TypeField{ { TypeName: queryTypeName, FieldNames: queryFieldNames, }, + { + TypeName: mutationTypeName, + FieldNames: mutationFieldNames, + }, } - var childNodes []plan.TypeField - for _, f := range queryFieldNames { - fields := schema.GetAllNestedFieldChildrenFromTypeField(queryTypeName, f, gql.NewSkipReservedNamesFunc()) - for _, field := range fields { - childNodes = append(childNodes, plan.TypeField{ - TypeName: field.TypeName, - FieldNames: field.FieldNames, - }) - } - } + childNodes := []plan.TypeField{} + childNodes = append(childNodes, getChildNodes(queryFieldNames, schema, queryTypeName)...) + childNodes = append(childNodes, getChildNodes(mutationFieldNames, schema, mutationTypeName)...) return plan.NewDataSourceConfiguration( datasource.DataSourceName, @@ -133,6 +134,24 @@ func getDatasourceConfig(ctx context.Context, schema *gql.Schema, cfg *datasourc ) } +func getChildNodes(fieldNames []string, schema *gql.Schema, typeName string) []plan.TypeField { + var foundFields = make(map[string]bool) + var childNodes []plan.TypeField + for _, fieldName := range fieldNames { + fields := schema.GetAllNestedFieldChildrenFromTypeField(typeName, fieldName, gql.NewSkipReservedNamesFunc()) + for _, field := range fields { + if !foundFields[field.TypeName] { + foundFields[field.TypeName] = true + childNodes = append(childNodes, plan.TypeField{ + TypeName: field.TypeName, + FieldNames: field.FieldNames, + }) + } + } + } + return childNodes +} + func makeEngine(ctx context.Context, schema *gql.Schema, datasourceConfig plan.DataSourceConfiguration[datasource.HypDSConfig]) (*engine.ExecutionEngine, error) { span, ctx := utils.NewSentrySpanForCurrentFunc(ctx) defer span.Finish() @@ -150,17 +169,14 @@ func makeEngine(ctx context.Context, schema *gql.Schema, datasourceConfig plan.D return engine.NewExecutionEngine(ctx, adapter, engineConfig, resolverOptions) } -func getAllQueryFields(ctx context.Context, s *gql.Schema) []string { +func getTypeFields(ctx context.Context, s *gql.Schema, typeName string) []string { span, _ := utils.NewSentrySpanForCurrentFunc(ctx) defer span.Finish() doc := s.Document() - queryTypeName := s.QueryTypeName() - fields := make([]string, 0) for _, objectType := range doc.ObjectTypeDefinitions { - typeName := doc.Input.ByteSliceString(objectType.Name) - if typeName == queryTypeName { + if doc.Input.ByteSliceString(objectType.Name) == typeName { for _, fieldRef := range objectType.FieldsDefinition.Refs { field := doc.FieldDefinitions[fieldRef] fieldName := doc.Input.ByteSliceString(field.Name) diff --git a/runtime/graphql/graphql.go b/runtime/graphql/graphql.go index 864294aa..70ed1d98 100644 --- a/runtime/graphql/graphql.go +++ b/runtime/graphql/graphql.go @@ -100,14 +100,14 @@ func handleGraphQLRequest(w http.ResponseWriter, r *http.Request) { options = append(options, eng.WithRequestTraceOptions(traceOpts)) } - // Execute the GraphQL query + // Execute the GraphQL operation resultWriter := gql.NewEngineResultWriter() if err := engine.Execute(ctx, &gqlRequest, &resultWriter, options...); err != nil { if report, ok := err.(operationreport.Report); ok { if len(report.InternalErrors) > 0 { // Log internal errors, but don't return them to the client - msg := "Failed to execute GraphQL query." + msg := "Failed to execute GraphQL operation." logger.Err(ctx, err).Msg(msg) http.Error(w, msg, http.StatusInternalServerError) return @@ -124,10 +124,10 @@ func handleGraphQLRequest(w http.ResponseWriter, r *http.Request) { // cleanup empty arrays from error message before logging errMsg := strings.Replace(err.Error(), ", locations: []", "", 1) errMsg = strings.Replace(errMsg, ", path: []", "", 1) - logger.Warn(ctx).Str("error", errMsg).Msg("Failed to execute GraphQL query.") + logger.Warn(ctx).Str("error", errMsg).Msg("Failed to execute GraphQL operation.") } } else { - msg := "Failed to execute GraphQL query." + msg := "Failed to execute GraphQL operation." logger.Err(ctx, err).Msg(msg) http.Error(w, fmt.Sprintf("%s\n%v", msg, err), http.StatusInternalServerError) } From 38f12be3edda589d7907cb454813834933c0a115 Mon Sep 17 00:00:00 2001 From: Matt Johnson-Pint Date: Fri, 25 Oct 2024 17:23:41 -0700 Subject: [PATCH 4/6] Map field names back to function names --- runtime/graphql/datasource/configuration.go | 5 +-- runtime/graphql/datasource/planner.go | 25 ++++++++------ runtime/graphql/datasource/source.go | 17 +++++----- runtime/graphql/engine/engine.go | 5 +-- runtime/graphql/schemagen/schemagen.go | 36 +++++++++++++-------- 5 files changed, 52 insertions(+), 36 deletions(-) diff --git a/runtime/graphql/datasource/configuration.go b/runtime/graphql/datasource/configuration.go index 6ac10c68..4bd5e9e2 100644 --- a/runtime/graphql/datasource/configuration.go +++ b/runtime/graphql/datasource/configuration.go @@ -14,6 +14,7 @@ import ( ) type HypDSConfig struct { - WasmHost wasmhost.WasmHost - MapTypes []string + WasmHost wasmhost.WasmHost + FieldsToFunctions map[string]string + MapTypes []string } diff --git a/runtime/graphql/datasource/planner.go b/runtime/graphql/datasource/planner.go index 954ad89b..04dbdd6c 100644 --- a/runtime/graphql/datasource/planner.go +++ b/runtime/graphql/datasource/planner.go @@ -31,8 +31,9 @@ type HypDSPlanner struct { variables resolve.Variables fields map[int]fieldInfo template struct { - function *fieldInfo - data []byte + fieldInfo *fieldInfo + functionName string + data []byte } } @@ -108,12 +109,10 @@ func (p *HypDSPlanner) EnterField(ref int) { // Capture only the fields that represent function calls. if p.currentNodeIsFunctionCall() { - // Save the field for the function. - p.template.function = f + p.template.fieldInfo = f + p.template.functionName = p.config.FieldsToFunctions[f.Name] - // Also capture the input data for the function. - err := p.captureInputData(ref) - if err != nil { + if err := p.captureInputData(ref); err != nil { logger.Err(p.ctx, err).Msg("Error capturing input data.") return } @@ -122,7 +121,7 @@ func (p *HypDSPlanner) EnterField(ref int) { func (p *HypDSPlanner) LeaveDocument(operation, definition *ast.Document) { // Stitch the captured fields together to form a tree. - p.stitchFields(p.template.function) + p.stitchFields(p.template.fieldInfo) } func (p *HypDSPlanner) stitchFields(f *fieldInfo) { @@ -224,7 +223,13 @@ func (p *HypDSPlanner) captureInputData(fieldRef int) error { } func (p *HypDSPlanner) ConfigureFetch() resolve.FetchConfiguration { - fnJson, err := utils.JsonSerialize(p.template.function) + fieldInfoJson, err := utils.JsonSerialize(p.template.fieldInfo) + if err != nil { + logger.Error(p.ctx).Err(err).Msg("Error serializing json while configuring graphql fetch.") + return resolve.FetchConfiguration{} + } + + functionNameJson, err := utils.JsonSerialize(p.template.functionName) if err != nil { logger.Error(p.ctx).Err(err).Msg("Error serializing json while configuring graphql fetch.") return resolve.FetchConfiguration{} @@ -233,7 +238,7 @@ func (p *HypDSPlanner) ConfigureFetch() resolve.FetchConfiguration { // Note: we have to build the rest of the template manually, because the data field may // contain placeholders for variables, such as $$0$$ which are not valid in JSON. // They are replaced with the actual values by the time Load is called. - inputTemplate := fmt.Sprintf(`{"fn":%s,"data":%s}`, fnJson, p.template.data) + inputTemplate := fmt.Sprintf(`{"field":%s,"function":%s,"data":%s}`, fieldInfoJson, functionNameJson, p.template.data) return resolve.FetchConfiguration{ Input: inputTemplate, diff --git a/runtime/graphql/datasource/source.go b/runtime/graphql/datasource/source.go index 783955e4..2183ff20 100644 --- a/runtime/graphql/datasource/source.go +++ b/runtime/graphql/datasource/source.go @@ -28,8 +28,9 @@ import ( const DataSourceName = "ModusDataSource" type callInfo struct { - Function fieldInfo `json:"fn"` - Parameters map[string]any `json:"data"` + FieldInfo fieldInfo `json:"field"` + FunctionName string `json:"function"` + Parameters map[string]any `json:"data"` } type ModusDataSource struct { @@ -65,7 +66,7 @@ func (*ModusDataSource) LoadWithFiles(ctx context.Context, input []byte, files [ func (ds *ModusDataSource) callFunction(ctx context.Context, callInfo *callInfo) (any, []resolve.GraphQLError, error) { // Get the function info - fnInfo, err := ds.WasmHost.GetFunctionInfo(callInfo.Function.Name) + fnInfo, err := ds.WasmHost.GetFunctionInfo(callInfo.FunctionName) if err != nil { return nil, nil, err } @@ -79,7 +80,7 @@ func (ds *ModusDataSource) callFunction(ctx context.Context, callInfo *callInfo) // Store the execution info into the function output map. outputMap := ctx.Value(utils.FunctionOutputContextKey).(map[string]wasmhost.ExecutionInfo) - outputMap[callInfo.Function.AliasOrName()] = execInfo + outputMap[callInfo.FieldInfo.AliasOrName()] = execInfo // Transform messages (and error lines in the output buffers) to GraphQL errors. messages := append(execInfo.Messages(), utils.TransformConsoleOutput(execInfo.Buffers())...) @@ -107,7 +108,7 @@ func (ds *ModusDataSource) callFunction(ctx context.Context, callInfo *callInfo) func writeGraphQLResponse(ctx context.Context, out *bytes.Buffer, result any, gqlErrors []resolve.GraphQLError, fnErr error, ci *callInfo) error { - fieldName := ci.Function.AliasOrName() + fieldName := ci.FieldInfo.AliasOrName() // Include the function error if fnErr != nil { @@ -139,7 +140,7 @@ func writeGraphQLResponse(ctx context.Context, out *bytes.Buffer, result any, gq msg := fmt.Sprintf("Function completed successfully, but the result contains a %v value that cannot be serialized to JSON.", err.Value) logger.Warn(ctx). Bool("user_visible", true). - Str("function", ci.Function.Name). + Str("function", ci.FunctionName). Str("result", fmt.Sprintf("%+v", result)). Msg(msg) fmt.Fprintf(out, `{"errors":[{"message":"%s","path":["%s"],"extensions":{"level":"error"}}]}`, msg, fieldName) @@ -149,7 +150,7 @@ func writeGraphQLResponse(ctx context.Context, out *bytes.Buffer, result any, gq } // Transform the data - if r, err := transformValue(jsonResult, &ci.Function); err != nil { + if r, err := transformValue(jsonResult, &ci.FieldInfo); err != nil { return err } else { jsonData = r @@ -395,7 +396,7 @@ func transformErrors(messages []utils.LogMessage, ci *callInfo) []resolve.GraphQ if msg.IsError() { errors = append(errors, resolve.GraphQLError{ Message: msg.Message, - Path: []any{ci.Function.AliasOrName()}, + Path: []any{ci.FieldInfo.AliasOrName()}, Extensions: map[string]interface{}{ "level": msg.Level, }, diff --git a/runtime/graphql/engine/engine.go b/runtime/graphql/engine/engine.go index 64d4bed0..ea6367a2 100644 --- a/runtime/graphql/engine/engine.go +++ b/runtime/graphql/engine/engine.go @@ -94,8 +94,9 @@ func generateSchema(ctx context.Context, md *metadata.Metadata) (*gql.Schema, *d } cfg := &datasource.HypDSConfig{ - WasmHost: wasmhost.GetWasmHost(ctx), - MapTypes: generated.MapTypes, + WasmHost: wasmhost.GetWasmHost(ctx), + FieldsToFunctions: generated.FieldsToFunctions, + MapTypes: generated.MapTypes, } return schema, cfg, nil diff --git a/runtime/graphql/schemagen/schemagen.go b/runtime/graphql/schemagen/schemagen.go index 3502aa8a..38cc23c9 100644 --- a/runtime/graphql/schemagen/schemagen.go +++ b/runtime/graphql/schemagen/schemagen.go @@ -25,8 +25,9 @@ import ( ) type GraphQLSchema struct { - Schema string - MapTypes []string + Schema string + FieldsToFunctions map[string]string + MapTypes []string } func GetGraphQLSchema(ctx context.Context, md *metadata.Metadata) (*GraphQLSchema, error) { @@ -42,7 +43,7 @@ func GetGraphQLSchema(ctx context.Context, md *metadata.Metadata) (*GraphQLSchem inputTypeDefs, errors := transformTypes(md.Types, lti, true) resultTypeDefs, errs := transformTypes(md.Types, lti, false) errors = append(errors, errs...) - queryFields, mutationFields, errs := transformFunctions(md.FnExports, inputTypeDefs, resultTypeDefs, lti) + fieldsToFunctions, queryFields, mutationFields, errs := transformFunctions(md.FnExports, inputTypeDefs, resultTypeDefs, lti) errors = append(errors, errs...) if len(errors) > 0 { @@ -68,8 +69,9 @@ func GetGraphQLSchema(ctx context.Context, md *metadata.Metadata) (*GraphQLSchem } return &GraphQLSchema{ - Schema: buf.String(), - MapTypes: mapTypes, + Schema: buf.String(), + FieldsToFunctions: fieldsToFunctions, + MapTypes: mapTypes, }, nil } @@ -145,7 +147,10 @@ type ArgumentDefinition struct { Default *any } -func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTypeDefs map[string]*TypeDefinition, lti langsupport.LanguageTypeInfo) ([]*FieldDefinition, []*FieldDefinition, []*TransformError) { +// TODO: refactor for readability + +func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTypeDefs map[string]*TypeDefinition, lti langsupport.LanguageTypeInfo) (map[string]string, []*FieldDefinition, []*FieldDefinition, []*TransformError) { + fieldsToFunctions := make(map[string]string, len(functions)) queryFields := make([]*FieldDefinition, 0, len(functions)) mutationFields := make([]*FieldDefinition, 0, len(functions)) errors := make([]*TransformError, 0) @@ -153,34 +158,37 @@ func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTyp fnNames := utils.MapKeys(functions) sort.Strings(fnNames) for _, name := range fnNames { - f := functions[name] + fn := functions[name] - args, err := convertParameters(f.Parameters, lti, inputTypeDefs) + args, err := convertParameters(fn.Parameters, lti, inputTypeDefs) if err != nil { - errors = append(errors, &TransformError{f, err}) + errors = append(errors, &TransformError{fn, err}) continue } - returnType, err := convertResults(f.Results, lti, resultTypeDefs) + returnType, err := convertResults(fn.Results, lti, resultTypeDefs) if err != nil { - errors = append(errors, &TransformError{f, err}) + errors = append(errors, &TransformError{fn, err}) continue } + fieldName := getFieldName(fn.Name) + fieldsToFunctions[fieldName] = fn.Name + field := &FieldDefinition{ - Name: getFieldName(f.Name), + Name: fieldName, Arguments: args, ReturnType: returnType, } - if isMutation(f.Name) { + if isMutation(fn.Name) { mutationFields = append(mutationFields, field) } else { queryFields = append(queryFields, field) } } - return queryFields, mutationFields, errors + return fieldsToFunctions, queryFields, mutationFields, errors } func filterFields(fields []*FieldDefinition) []*FieldDefinition { From 0961528a4daaa33a7a44572fa29ce285c483a951 Mon Sep 17 00:00:00 2001 From: Matt Johnson-Pint Date: Fri, 25 Oct 2024 17:25:05 -0700 Subject: [PATCH 5/6] Update CHANGELOG.md --- CHANGELOG.md | 3 ++- cspell.json | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cff66a38..3be89200 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,7 +59,8 @@ In previous releases, the name "Hypermode" was used for all three._ - Fix runtime shutdown issues with `modus dev` [#508](https://github.com/hypermodeinc/modus/pull/508) - Monitored manifest and env files for changes [#509](https://github.com/hypermodeinc/modus/pull/509) - Log bad GraphQL requests in dev [#510](https://github.com/hypermodeinc/modus/pull/510) -- Add jwks endpoint key support to auth [#511](https://github.com/hypermodeinc/modus/pull/511) +- Add JWKS endpoint key support to auth [#511](https://github.com/hypermodeinc/modus/pull/511) +- Use conventions to support GraphQL mutations and adjust query names [#513](https://github.com/hypermodeinc/modus/pull/513) ## 2024-10-02 - Version 0.12.7 diff --git a/cspell.json b/cspell.json index 6b3b6102..c84a6b43 100644 --- a/cspell.json +++ b/cspell.json @@ -88,6 +88,7 @@ "jsonlogs", "jsonparser", "jsonschema", + "JWKS", "langsupport", "ldflags", "legacymodels", From a54157763f6afba35f49a2244e474484f3253b53 Mon Sep 17 00:00:00 2001 From: Matt Johnson-Pint Date: Fri, 25 Oct 2024 17:40:10 -0700 Subject: [PATCH 6/6] Update integration test --- .../postgresql_integration_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/runtime/integration_tests/postgresql_integration_test.go b/runtime/integration_tests/postgresql_integration_test.go index e6d2599c..7623f10b 100644 --- a/runtime/integration_tests/postgresql_integration_test.go +++ b/runtime/integration_tests/postgresql_integration_test.go @@ -154,7 +154,7 @@ func TestPostgresqlNoConnection(t *testing.T) { // wait here to make sure the plugin is loaded time.Sleep(waitRefreshPluginInterval) - query := "{ getAllPeople { id name age } }" + query := "{ allPeople { id name age } }" response, err := runGraphqlQuery(graphQLRequest{Query: query}) assert.Nil(t, response) assert.NotNil(t, err) @@ -185,7 +185,7 @@ func TestPostgresqlNoHost(t *testing.T) { time.Sleep(waitRefreshPluginInterval) // when host name does not exist - query := "{ getAllPeople { id name age } }" + query := "{ allPeople { id name age } }" response, err := runGraphqlQuery(graphQLRequest{Query: query}) assert.Nil(t, response) assert.NotNil(t, err) @@ -216,7 +216,7 @@ func TestPostgresqlNoPostgresqlHost(t *testing.T) { time.Sleep(waitRefreshPluginInterval) // when host name has the wrong host type - query := "{ getAllPeople { id name age } }" + query := "{ allPeople { id name age } }" response, err := runGraphqlQuery(graphQLRequest{Query: query}) assert.Nil(t, response) assert.NotNil(t, err) @@ -247,7 +247,7 @@ func TestPostgresqlWrongConnString(t *testing.T) { time.Sleep(waitRefreshPluginInterval) // when connection string is wrong - query := "{ getAllPeople { id name age } }" + query := "{ allPeople { id name age } }" response, err := runGraphqlQuery(graphQLRequest{Query: query}) assert.Nil(t, response) assert.NotNil(t, err) @@ -277,7 +277,7 @@ func TestPostgresqlNoConnString(t *testing.T) { time.Sleep(waitRefreshPluginInterval) // when host name has no connection string - query := "{ getAllPeople { id name age } }" + query := "{ allPeople { id name age } }" response, err := runGraphqlQuery(graphQLRequest{Query: query}) assert.Nil(t, response) assert.NotNil(t, err) @@ -400,7 +400,7 @@ func (ps *postgresqlSuite) TearDownSuite() { func (ps *postgresqlSuite) TestPostgresqlBasicOps() { query := ` -query AddPerson { +mutation AddPerson { addPerson(name: "test", age: 21) { id name @@ -414,7 +414,7 @@ query AddPerson { func (ps *postgresqlSuite) TestPostgresqlWrongTypeInsert() { // try inserting data with wrong type, column: age query := ` -query AddPerson { +mutation AddPerson { addPerson(name: "test", age: "abc") { id name