diff --git a/dgraph/dgraph.go b/dgraph/dgraph.go index ef909c5a..c23ca105 100644 --- a/dgraph/dgraph.go +++ b/dgraph/dgraph.go @@ -11,7 +11,6 @@ import ( "io" "log" "net/http" - "strings" "time" "github.com/dgraph-io/gqlparser/ast" @@ -21,25 +20,47 @@ import ( var DgraphUrl *string -func ExecuteDQL(ctx context.Context, stmt string, isMutation bool) ([]byte, error) { - reqBody := strings.NewReader(stmt) +func executePostRequestWithVars(stmt string, vars map[string]string, endpoint string) (*http.Response, error) { + payload := map[string]any{ + "query": stmt, + "variables": vars, + } + + // Convert payload to JSON + jsonPayload, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("error marshaling payload: %w", err) + } + + // Create the HTTP request + req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonPayload)) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + + // Perform the request + return httpClient.Do(req) + +} +func ExecuteDQL(ctx context.Context, stmt string, vars map[string]string, isMutation bool) ([]byte, error) { host := *DgraphUrl - var endpoint, contentType string + var endpoint string if isMutation { endpoint = "/mutate?commitNow=true" - contentType = "application/rdf" } else { endpoint = "/query" - contentType = "application/dql" } - resp, err := http.Post(host+endpoint, contentType, reqBody) + resp, err := executePostRequestWithVars(stmt, vars, host+endpoint) if err != nil { return nil, fmt.Errorf("error posting DQL statement: %w", err) } - defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("DQL operation failed with status code %d", resp.StatusCode) } @@ -52,14 +73,14 @@ func ExecuteDQL(ctx context.Context, stmt string, isMutation bool) ([]byte, erro return respBody, nil } -func ExecuteGQL(ctx context.Context, stmt string) ([]byte, error) { - reqBody := strings.NewReader(stmt) - resp, err := http.Post(fmt.Sprintf("%s/graphql", *DgraphUrl), "application/graphql", reqBody) +func ExecuteGQL(ctx context.Context, stmt string, vars map[string]string) ([]byte, error) { + // Perform the request + resp, err := executePostRequestWithVars(stmt, vars, fmt.Sprintf("%s/graphql", *DgraphUrl)) if err != nil { return nil, fmt.Errorf("error posting GraphQL statement: %w", err) } - defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("GraphQL operation failed with status code %d", resp.StatusCode) } @@ -86,7 +107,7 @@ var schemaQuery = "{node(func:has(dgraph.graphql.schema)){dgraph.graphql.schema} func GetGQLSchema(ctx context.Context) (string, error) { - r, err := ExecuteDQL(ctx, schemaQuery, false) + r, err := ExecuteDQL(ctx, schemaQuery, nil, false) if err != nil { return "", fmt.Errorf("error getting GraphQL schema from Dgraph: %w", err) } @@ -154,7 +175,7 @@ func (schema functionSchema) FunctionArgs() ast.ArgumentDefinitionList { if err == nil { var list ast.ArgumentDefinitionList var argName string - for _, val := range v.([]interface{}) { + for _, val := range v.([]any) { argName = val.(string) fld := schema.ObjectDef.Fields.ForName(argName) if fld == nil { @@ -237,9 +258,9 @@ func GetModelEndpoint(mid string) (string, error) { } }` - payload := map[string]interface{}{ + payload := map[string]any{ "query": query, - "variables": map[string]string{"id": mid}, + "variables": map[string]any{"id": mid}, } // Convert payload to JSON diff --git a/functions/hostfns.go b/functions/hostfns.go index 193f89b2..e72bf1b6 100644 --- a/functions/hostfns.go +++ b/functions/hostfns.go @@ -42,7 +42,7 @@ func InstantiateHostFunctions(ctx context.Context, runtime wazero.Runtime) error return nil } -func hostExecuteDQL(ctx context.Context, mod wasm.Module, pStmt uint32, isMutation uint32) uint32 { +func hostExecuteDQL(ctx context.Context, mod wasm.Module, pStmt uint32, pVars uint32, isMutation uint32) uint32 { mem := mod.Memory() stmt, err := readString(mem, pStmt) if err != nil { @@ -50,7 +50,19 @@ func hostExecuteDQL(ctx context.Context, mod wasm.Module, pStmt uint32, isMutati return 0 } - r, err := dgraph.ExecuteDQL(ctx, stmt, isMutation != 0) + sVars, err := readString(mem, pVars) + if err != nil { + log.Println("error reading DQL variables string from wasm memory:", err) + return 0 + } + + vars := make(map[string]string) + if err := json.Unmarshal([]byte(sVars), &vars); err != nil { + log.Println("error unmarshaling GraphQL variables:", err) + return 0 + } + + r, err := dgraph.ExecuteDQL(ctx, stmt, vars, isMutation != 0) if err != nil { log.Println("error executing DQL statement:", err) return 0 @@ -59,15 +71,27 @@ func hostExecuteDQL(ctx context.Context, mod wasm.Module, pStmt uint32, isMutati return writeString(ctx, mod, string(r)) } -func hostExecuteGQL(ctx context.Context, mod wasm.Module, pStmt uint32) uint32 { +func hostExecuteGQL(ctx context.Context, mod wasm.Module, pStmt uint32, pVars uint32) uint32 { mem := mod.Memory() stmt, err := readString(mem, pStmt) if err != nil { - log.Println("error reading GraphQL string from wasm memory:", err) + log.Println("error reading GraphQL query string from wasm memory:", err) + return 0 + } + + sVars, err := readString(mem, pVars) + if err != nil { + log.Println("error reading GraphQL variables string from wasm memory:", err) + return 0 + } + + vars := make(map[string]string) + if err := json.Unmarshal([]byte(sVars), &vars); err != nil { + log.Println("error unmarshaling GraphQL variables:", err) return 0 } - r, err := dgraph.ExecuteGQL(ctx, stmt) + r, err := dgraph.ExecuteGQL(ctx, stmt, vars) if err != nil { log.Println("error executing GraphQL operation:", err) return 0