From ad8d265bdf2263f40cc5b740a2a0bfc9b30a9361 Mon Sep 17 00:00:00 2001 From: Giuseppe Villani Date: Thu, 11 Jan 2024 16:44:18 +0100 Subject: [PATCH] Fixes #3634: Updated ML procs for Azure OpenAI services (#3850) (#3863) (#3885) * Fixes #3634: Updated ML procs for Azure OpenAI services * Code clean * added enpoint env vars * Code clean part 2 * removed unused imports --------- Co-authored-by: Andrea Santurbano --- .../modules/ROOT/pages/ml/openai.adoc | 64 ++++++++++++- .../main/java/apoc/ExtendedApocConfig.java | 3 + extended/src/main/java/apoc/ml/OpenAI.java | 43 ++++++--- .../java/apoc/ml/OpenAIRequestHandler.java | 87 +++++++++++++++++ .../src/test/java/apoc/ml/OpenAIAzureIT.java | 94 +++++++++++++++++++ extended/src/test/java/apoc/ml/OpenAIIT.java | 52 ++-------- .../src/test/java/apoc/ml/OpenAITest.java | 3 +- .../java/apoc/ml/OpenAITestResultUtils.java | 48 ++++++++++ 8 files changed, 334 insertions(+), 60 deletions(-) create mode 100644 extended/src/main/java/apoc/ml/OpenAIRequestHandler.java create mode 100644 extended/src/test/java/apoc/ml/OpenAIAzureIT.java create mode 100644 extended/src/test/java/apoc/ml/OpenAITestResultUtils.java diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc index de45ad97de..b0c34f0052 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -4,6 +4,60 @@ NOTE: You need to acquire an https://platform.openai.com/account/api-keys[OpenAI API key^] to use these procedures. Using them will incur costs on your OpenAI account. You can set the api key globally by defining the `apoc.openai.key` configuration in `apoc.conf` + + +All the following procedures can have the following APOC config, i.e. in `apoc.conf` or via docker env variable +.Apoc configuration +|=== +|key | description | default +| apoc.ml.openai.type | "AZURE" or "OPENAI", indicates whether the API is Azure or not | "OPENAI" +| apoc.ml.openai.url | the OpenAI endpoint base url | https://api.openai.com/v1 + (or empty string if `apoc.ml.openai.type=AZURE`) +| apoc.ml.azure.api.version | in case of `apoc.ml.openai.type=AZURE`, indicates the `api-version` to be passed after the `?api-version=` url +|=== + + +Moreover, they can have the following configuration keys, as the last parameter. +If present, they take precedence over the analogous APOC configs. + +.Common configuration parameter + +|=== +| key | description +| apiType | analogous to `apoc.ml.openai.type` APOC config +| endpoint | analogous to `apoc.ml.openai.url` APOC config +| apiVersion | analogous to `apoc.ml.azure.api.version` APOC config +|=== + + +Therefore, we can use the following procedures with the Open AI Services provided by Azure, +pointing to the correct endpoints https://learn.microsoft.com/it-it/azure/ai-services/openai/reference[as explained in the documentation]. + +That is, if we want to call an endpoint like https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/embeddings?api-version=my-api-version` for example, +by passing as a configuration parameter: +``` + {endpoint: "https://my-resource.openai.azure.com/openai/deployments/my-deployment-id", + apiVersion: my-api-version, + apiType: 'AZURE' +} +``` + +The `/embeddings` portion will be added under-the-hood. +Similarly, if we use the `apoc.ml.openai.completion`, if we want to call an endpoint like `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/completions?api-version=my-api-version` for example, +we can write the same configuration parameter as above, +where the `/completions` portion will be added. + +While using the `apoc.ml.openai.chat`, with the same configuration, the url portion `/chat/completions` will be added + +Or else, we can write this `apoc.conf`: +``` +apoc.ml.openai.url=https://my-resource.openai.azure.com/openai/deployments/my-deployment-id +apoc.ml.azure.api.version=my-api-version +apoc.ml.openai.type=AZURE +``` + + + == Generate Embeddings API This procedure `apoc.ml.openai.embedding` can take a list of text strings, and will return one row per string, with the embedding data as a 1536 element vector. @@ -30,7 +84,15 @@ CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, {}) yield index, text, emb |name | description | texts | List of text strings | apiKey | OpenAI API key -| configuration | optional map for entries like model and other request parameters +| configuration | optional map for entries like model and other request parameters. + + We can also pass a custom `endpoint: ` entry (it takes precedence over the `apoc.ml.openai.url` config). + The `` can be the complete andpoint (e.g. using Azure: `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/chat/completions?api-version=my-api-version`), + or with a `%s` (e.g. using Azure: `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/%s?api-version=my-api-version`) which will eventually be replaced with `embeddings`, `chat/completion` and `completion` + by using respectively the `apoc.ml.openai.embedding`, `apoc.ml.openai.chat` and `apoc.ml.openai.completion`. + + Or an `authType: `AUTH_TYPE`, which can be `authType: "BEARER"` (default config.), to pass the apiKey via the header as an `Authorization: Bearer $apiKey`, + or `authType: "API_KEY"` to pass the apiKey as an `api-key: $apiKey` header entry. |=== diff --git a/extended/src/main/java/apoc/ExtendedApocConfig.java b/extended/src/main/java/apoc/ExtendedApocConfig.java index 0008694725..9840b5df75 100644 --- a/extended/src/main/java/apoc/ExtendedApocConfig.java +++ b/extended/src/main/java/apoc/ExtendedApocConfig.java @@ -30,6 +30,9 @@ public class ExtendedApocConfig extends LifecycleAdapter public static final String APOC_UUID_ENABLED_DB = "apoc.uuid.enabled.%s"; public static final String APOC_UUID_FORMAT = "apoc.uuid.format"; public static final String APOC_OPENAI_KEY = "apoc.openai.key"; + public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url"; + public static final String APOC_ML_OPENAI_TYPE = "apoc.ml.openai.type"; + public static final String APOC_ML_OPENAI_AZURE_VERSION = "apoc.ml.azure.api.version"; public static final String APOC_AWS_KEY_ID = "apoc.aws.key.id"; public static final String APOC_AWS_SECRET_KEY = "apoc.aws.secret.key"; public enum UuidFormatType { hex, base64 } diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index acc43c2ac8..3a019981ad 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -2,6 +2,7 @@ import apoc.ApocConfig; import apoc.Extended; +import apoc.result.MapResult; import apoc.util.JsonUtil; import com.fasterxml.jackson.core.JsonProcessingException; import org.neo4j.graphdb.security.URLAccessChecker; @@ -11,21 +12,23 @@ import org.neo4j.procedure.Procedure; import java.net.MalformedURLException; -import java.net.URL; import java.util.HashMap; -import java.util.Map; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.stream.Stream; -import apoc.result.MapResult; - -import com.fasterxml.jackson.databind.ObjectMapper; - +import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_TYPE; import static apoc.ExtendedApocConfig.APOC_OPENAI_KEY; @Extended public class OpenAI { + public static final String API_TYPE_CONF_KEY = "apiType"; + public static final String APIKEY_CONF_KEY = "apiKey"; + public static final String ENDPOINT_CONF_KEY = "endpoint"; + public static final String API_VERSION_CONF_KEY = "apiVersion"; + @Context public ApocConfig apocConfig; @@ -47,22 +50,34 @@ public EmbeddingResult(long index, String text, List embedding) { } static Stream executeRequest(String apiKey, Map configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException { - apiKey = apocConfig.getString(APOC_OPENAI_KEY, apiKey); + apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey)); if (apiKey == null || apiKey.isBlank()) throw new IllegalArgumentException("API Key must not be empty"); - String endpoint = System.getProperty(APOC_ML_OPENAI_URL,"https://api.openai.com/v1/"); - Map headers = Map.of( - "Content-Type", "application/json", - "Authorization", "Bearer " + apiKey + + String apiTypeString = (String) configuration.getOrDefault(API_TYPE_CONF_KEY, + apocConfig.getString(APOC_ML_OPENAI_TYPE, OpenAIRequestHandler.Type.OPENAI.name()) ); + OpenAIRequestHandler apiType = OpenAIRequestHandler.Type.valueOf(apiTypeString.toUpperCase(Locale.ENGLISH)) + .get(); + + String endpoint = apiType.getEndpoint(configuration, apocConfig); + + final Map headers = new HashMap<>(); + headers.put("Content-Type", "application/json"); + apiType.addApiKey(headers, apiKey); var config = new HashMap<>(configuration); + // we remove these keys from config, since the json payload is calculated starting from the config map + Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY, APIKEY_CONF_KEY).forEach(config::remove); config.putIfAbsent("model", model); config.put(key, inputs); - String payload = new ObjectMapper().writeValueAsString(config); - - var url = new URL(new URL(endpoint), path).toString(); + String payload = JsonUtil.OBJECT_MAPPER.writeValueAsString(config); + + // new URL(endpoint), path) can produce a wrong path, since endpoint can have for example embedding, + // eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model + // therefore is better to join the not-empty path pieces + var url = apiType.getFullUrl(path, configuration, apocConfig); return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of(), urlAccessChecker); } diff --git a/extended/src/main/java/apoc/ml/OpenAIRequestHandler.java b/extended/src/main/java/apoc/ml/OpenAIRequestHandler.java new file mode 100644 index 0000000000..339c9cc106 --- /dev/null +++ b/extended/src/main/java/apoc/ml/OpenAIRequestHandler.java @@ -0,0 +1,87 @@ +package apoc.ml; + + +import apoc.ApocConfig; +import org.apache.commons.lang3.StringUtils; + +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_AZURE_VERSION; +import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_URL; +import static apoc.ml.OpenAI.API_VERSION_CONF_KEY; +import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY; + +abstract class OpenAIRequestHandler { + + private final String defaultUrl; + + public OpenAIRequestHandler(String defaultUrl) { + this.defaultUrl = defaultUrl; + } + + public String getDefaultUrl() { + return defaultUrl; + } + public abstract String getApiVersion(Map configuration, ApocConfig apocConfig); + public abstract void addApiKey(Map headers, String apiKey); + + public String getEndpoint(Map procConfig, ApocConfig apocConfig) { + return (String) procConfig.getOrDefault(ENDPOINT_CONF_KEY, + apocConfig.getString(APOC_ML_OPENAI_URL, System.getProperty(APOC_ML_OPENAI_URL, getDefaultUrl()))); + } + + public String getFullUrl(String method, Map procConfig, ApocConfig apocConfig) { + return Stream.of(getEndpoint(procConfig, apocConfig), method, getApiVersion(procConfig, apocConfig)) + .filter(StringUtils::isNotBlank) + .collect(Collectors.joining("/")); + } + + enum Type { + AZURE(new Azure(null)), OPENAI(new OpenAi("https://api.openai.com/v1")); + + private final OpenAIRequestHandler handler; + Type(OpenAIRequestHandler handler) { + this.handler = handler; + } + + public OpenAIRequestHandler get() { + return handler; + } + } + + static class Azure extends OpenAIRequestHandler { + + public Azure(String defaultUrl) { + super(defaultUrl); + } + + @Override + public String getApiVersion(Map configuration, ApocConfig apocConfig) { + return "?api-version=" + configuration.getOrDefault(API_VERSION_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_AZURE_VERSION)); + } + + @Override + public void addApiKey(Map headers, String apiKey) { + headers.put("api-key", apiKey); + } + } + + static class OpenAi extends OpenAIRequestHandler { + + public OpenAi(String defaultUrl) { + super(defaultUrl); + } + + @Override + public String getApiVersion(Map configuration, ApocConfig apocConfig) { + return ""; + } + + @Override + public void addApiKey(Map headers, String apiKey) { + headers.put("Authorization", "Bearer " + apiKey); + } + } +} diff --git a/extended/src/test/java/apoc/ml/OpenAIAzureIT.java b/extended/src/test/java/apoc/ml/OpenAIAzureIT.java new file mode 100644 index 0000000000..506cb901fe --- /dev/null +++ b/extended/src/test/java/apoc/ml/OpenAIAzureIT.java @@ -0,0 +1,94 @@ +package apoc.ml; + +import apoc.util.TestUtil; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Ignore; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +import java.util.Map; +import java.util.stream.Stream; + +import static apoc.ml.OpenAI.API_TYPE_CONF_KEY; +import static apoc.ml.OpenAI.API_VERSION_CONF_KEY; +import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY; +import static apoc.ml.OpenAITestResultUtils.assertChatCompletion; +import static apoc.ml.OpenAITestResultUtils.assertCompletion; +import static apoc.util.TestUtil.testCall; +import static org.junit.Assume.assumeNotNull; + +public class OpenAIAzureIT { + // In Azure, the endpoints can be different + private static String OPENAI_EMBEDDING_URL; + private static String OPENAI_CHAT_URL; + private static String OPENAI_COMPLETION_URL; + + private static String OPENAI_AZURE_API_VERSION; + + private static String OPENAI_KEY; + + @ClassRule + public static DbmsRule db = new ImpermanentDbmsRule(); + + @BeforeClass + public static void setUp() throws Exception { + OPENAI_KEY = System.getenv("OPENAI_KEY"); + // Azure OpenAI base URLs + OPENAI_EMBEDDING_URL = System.getenv("OPENAI_EMBEDDING_URL"); + OPENAI_CHAT_URL = System.getenv("OPENAI_CHAT_URL"); + OPENAI_COMPLETION_URL = System.getenv("OPENAI_COMPLETION_URL"); + + // Azure OpenAI query url (`//?api-version=`) + OPENAI_AZURE_API_VERSION = System.getenv("OPENAI_AZURE_API_VERSION"); + + Stream.of(OPENAI_EMBEDDING_URL, + OPENAI_CHAT_URL, + OPENAI_COMPLETION_URL, + OPENAI_AZURE_API_VERSION, + OPENAI_KEY) + .forEach(key -> assumeNotNull("No " + key + " environment configured", key)); + + + TestUtil.registerProcedure(db, OpenAI.class); + } + + @Test + public void embedding() { + testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, $conf)", + getParams(OPENAI_EMBEDDING_URL), + OpenAITestResultUtils::assertEmbeddings); + } + + + @Test + @Ignore("It returns wrong answers sometimes") + public void completion() { + testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, $conf)", + getParams(OPENAI_CHAT_URL), + (row) -> assertCompletion(row, "gpt-35-turbo")); + } + + @Test + public void chatCompletion() { + testCall(db, """ + CALL apoc.ml.openai.chat([ + {role:"system", content:"Only answer with a single word"}, + {role:"user", content:"What planet do humans live on?"} + ], $apiKey, $conf) + """, getParams(OPENAI_COMPLETION_URL), + (row) -> assertChatCompletion(row, "gpt-35-turbo")); + } + + private static Map getParams(String url) { + return Map.of("apiKey", OPENAI_KEY, + "conf", Map.of(ENDPOINT_CONF_KEY, url, + API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.AZURE.name(), + API_VERSION_CONF_KEY, OPENAI_AZURE_API_VERSION, + // on Azure is available only "gpt-35-turbo" + "model", "gpt-35-turbo" + ) + ); + } +} \ No newline at end of file diff --git a/extended/src/test/java/apoc/ml/OpenAIIT.java b/extended/src/test/java/apoc/ml/OpenAIIT.java index 59deba1863..68c543d7d1 100644 --- a/extended/src/test/java/apoc/ml/OpenAIIT.java +++ b/extended/src/test/java/apoc/ml/OpenAIIT.java @@ -8,12 +8,11 @@ import org.neo4j.test.rule.DbmsRule; import org.neo4j.test.rule.ImpermanentDbmsRule; -import java.util.List; import java.util.Map; +import static apoc.ml.OpenAITestResultUtils.assertChatCompletion; +import static apoc.ml.OpenAITestResultUtils.assertCompletion; import static apoc.util.TestUtil.testCall; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; public class OpenAIIT { @@ -34,34 +33,15 @@ public void setUp() throws Exception { @Test public void getEmbedding() { - testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey)", Map.of("apiKey",openaiKey),(row) -> { - System.out.println("row = " + row); - assertEquals(0L, row.get("index")); - assertEquals("Some Text", row.get("text")); - var embedding = (List) row.get("embedding"); - assertEquals(1536, embedding.size()); - assertEquals(true, embedding.stream().allMatch(d -> d instanceof Double)); - }); + testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey)", Map.of("apiKey",openaiKey), + OpenAITestResultUtils::assertEmbeddings); } @Test public void completion() { testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey)", - Map.of("apiKey",openaiKey),(row) -> { - System.out.println("row = " + row); - var result = (Map)row.get("value"); - assertEquals(true, result.get("created") instanceof Number); - assertEquals(true, result.containsKey("choices")); - var finishReason = (String)((List) result.get("choices")).get(0).get("finish_reason"); - assertEquals(true, finishReason.matches("stop|length")); - String text = (String) ((List) result.get("choices")).get(0).get("text"); - assertEquals(true, text != null && !text.isBlank()); - assertEquals(true, text.toLowerCase().contains("blue")); - assertEquals(true, result.containsKey("usage")); - assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number); - assertEquals("text-davinci-003", result.get("model")); - assertEquals("text_completion", result.get("object")); - }); + Map.of("apiKey", openaiKey), + (row) -> assertCompletion(row, "text-davinci-003")); } @Test @@ -71,24 +51,8 @@ public void chatCompletion() { {role:"system", content:"Only answer with a single word"}, {role:"user", content:"What planet do humans live on?"} ], $apiKey) -""", Map.of("apiKey",openaiKey), (row) -> { - System.out.println("row = " + row); - var result = (Map)row.get("value"); - assertEquals(true, result.get("created") instanceof Number); - assertEquals(true, result.containsKey("choices")); - - Map message = ((List>) result.get("choices")).get(0).get("message"); - assertEquals("assistant", message.get("role")); - // assertEquals("stop", message.get("finish_reason")); - String text = (String) message.get("content"); - assertEquals(true, text != null && !text.isBlank()); - - - assertEquals(true, result.containsKey("usage")); - assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number); - assertTrue(result.get("model").toString().startsWith("gpt-3.5-turbo")); - assertEquals("chat.completion", result.get("object")); - }); +""", Map.of("apiKey",openaiKey), + (row) -> assertChatCompletion(row, "gpt-3.5-turbo")); /* { diff --git a/extended/src/test/java/apoc/ml/OpenAITest.java b/extended/src/test/java/apoc/ml/OpenAITest.java index 87dc84b100..b036c5dda0 100644 --- a/extended/src/test/java/apoc/ml/OpenAITest.java +++ b/extended/src/test/java/apoc/ml/OpenAITest.java @@ -13,6 +13,7 @@ import static apoc.ApocConfig.APOC_IMPORT_FILE_ENABLED; import static apoc.ApocConfig.apocConfig; +import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_URL; import static apoc.util.TestUtil.getUrlFileName; import static apoc.util.TestUtil.testCall; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -32,7 +33,7 @@ public void setUp() throws Exception { // openaiKey = System.getenv("OPENAI_KEY"); // Assume.assumeNotNull("No OPENAI_KEY environment configured", openaiKey); var path = Paths.get(getUrlFileName("embeddings").toURI()).getParent().toUri(); - System.setProperty(OpenAI.APOC_ML_OPENAI_URL, path.toString()); + System.setProperty(APOC_ML_OPENAI_URL, path.toString()); apocConfig().setProperty(APOC_IMPORT_FILE_ENABLED, true); TestUtil.registerProcedure(db, OpenAI.class); } diff --git a/extended/src/test/java/apoc/ml/OpenAITestResultUtils.java b/extended/src/test/java/apoc/ml/OpenAITestResultUtils.java new file mode 100644 index 0000000000..5e93e1c5a7 --- /dev/null +++ b/extended/src/test/java/apoc/ml/OpenAITestResultUtils.java @@ -0,0 +1,48 @@ +package apoc.ml; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OpenAITestResultUtils { + public static void assertEmbeddings(Map row) { + assertEquals(0L, row.get("index")); + assertEquals("Some Text", row.get("text")); + var embedding = (List) row.get("embedding"); + assertEquals(1536, embedding.size()); + } + + public static void assertCompletion(Map row, String expectedModel) { + var result = (Map) row.get("value"); + assertTrue(result.get("created") instanceof Number); + assertTrue(result.containsKey("choices")); + var finishReason = (String)((List) result.get("choices")).get(0).get("finish_reason"); + assertTrue(finishReason.matches("stop|length")); + String text = (String) ((List) result.get("choices")).get(0).get("text"); + System.out.println("OpenAI text response for assertCompletion = " + text); + assertTrue(text != null && !text.isBlank()); + assertTrue(text.toLowerCase().contains("blue")); + assertTrue(result.containsKey("usage")); + assertTrue(((Map) result.get("usage")).get("prompt_tokens") instanceof Number); + assertEquals(expectedModel, result.get("model")); + assertEquals("text_completion", result.get("object")); + } + + public static void assertChatCompletion(Map row, String modelId) { + var result = (Map) row.get("value"); + assertTrue(result.get("created") instanceof Number); + assertTrue(result.containsKey("choices")); + + Map message = ((List>) result.get("choices")).get(0).get("message"); + assertEquals("assistant", message.get("role")); + String text = (String) message.get("content"); + assertTrue(text != null && !text.isBlank()); + + assertTrue(result.containsKey("usage")); + assertTrue(((Map) result.get("usage")).get("prompt_tokens") instanceof Number); + assertEquals("chat.completion", result.get("object")); + assertTrue(result.get("model").toString().startsWith(modelId)); + } +}