Skip to content

Commit

Permalink
Fixes #3634: Updated ML procs for Azure OpenAI services (#3850) (#3863)…
Browse files Browse the repository at this point in the history
… (#3885)

* Fixes #3634: Updated ML procs for Azure OpenAI services

* Code clean

* added enpoint env vars

* Code clean part 2

* removed unused imports

---------

Co-authored-by: Andrea Santurbano <santand@gmail.com>
  • Loading branch information
vga91 and conker84 authored Jan 11, 2024
1 parent 08b98d7 commit ad8d265
Show file tree
Hide file tree
Showing 8 changed files with 334 additions and 60 deletions.
64 changes: 63 additions & 1 deletion docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,60 @@

NOTE: You need to acquire an https://platform.openai.com/account/api-keys[OpenAI API key^] to use these procedures. Using them will incur costs on your OpenAI account. You can set the api key globally by defining the `apoc.openai.key` configuration in `apoc.conf`



All the following procedures can have the following APOC config, i.e. in `apoc.conf` or via docker env variable
.Apoc configuration
|===
|key | description | default
| apoc.ml.openai.type | "AZURE" or "OPENAI", indicates whether the API is Azure or not | "OPENAI"
| apoc.ml.openai.url | the OpenAI endpoint base url | https://api.openai.com/v1
(or empty string if `apoc.ml.openai.type=AZURE`)
| apoc.ml.azure.api.version | in case of `apoc.ml.openai.type=AZURE`, indicates the `api-version` to be passed after the `?api-version=` url
|===


Moreover, they can have the following configuration keys, as the last parameter.
If present, they take precedence over the analogous APOC configs.

.Common configuration parameter

|===
| key | description
| apiType | analogous to `apoc.ml.openai.type` APOC config
| endpoint | analogous to `apoc.ml.openai.url` APOC config
| apiVersion | analogous to `apoc.ml.azure.api.version` APOC config
|===


Therefore, we can use the following procedures with the Open AI Services provided by Azure,
pointing to the correct endpoints https://learn.microsoft.com/it-it/azure/ai-services/openai/reference[as explained in the documentation].

That is, if we want to call an endpoint like https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/embeddings?api-version=my-api-version` for example,
by passing as a configuration parameter:
```
{endpoint: "https://my-resource.openai.azure.com/openai/deployments/my-deployment-id",
apiVersion: my-api-version,
apiType: 'AZURE'
}
```

The `/embeddings` portion will be added under-the-hood.
Similarly, if we use the `apoc.ml.openai.completion`, if we want to call an endpoint like `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/completions?api-version=my-api-version` for example,
we can write the same configuration parameter as above,
where the `/completions` portion will be added.

While using the `apoc.ml.openai.chat`, with the same configuration, the url portion `/chat/completions` will be added

Or else, we can write this `apoc.conf`:
```
apoc.ml.openai.url=https://my-resource.openai.azure.com/openai/deployments/my-deployment-id
apoc.ml.azure.api.version=my-api-version
apoc.ml.openai.type=AZURE
```



== Generate Embeddings API

This procedure `apoc.ml.openai.embedding` can take a list of text strings, and will return one row per string, with the embedding data as a 1536 element vector.
Expand All @@ -30,7 +84,15 @@ CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, {}) yield index, text, emb
|name | description
| texts | List of text strings
| apiKey | OpenAI API key
| configuration | optional map for entries like model and other request parameters
| configuration | optional map for entries like model and other request parameters.

We can also pass a custom `endpoint: <MyAndPointKey>` entry (it takes precedence over the `apoc.ml.openai.url` config).
The `<MyAndPointKey>` can be the complete andpoint (e.g. using Azure: `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/chat/completions?api-version=my-api-version`),
or with a `%s` (e.g. using Azure: `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/%s?api-version=my-api-version`) which will eventually be replaced with `embeddings`, `chat/completion` and `completion`
by using respectively the `apoc.ml.openai.embedding`, `apoc.ml.openai.chat` and `apoc.ml.openai.completion`.

Or an `authType: `AUTH_TYPE`, which can be `authType: "BEARER"` (default config.), to pass the apiKey via the header as an `Authorization: Bearer $apiKey`,
or `authType: "API_KEY"` to pass the apiKey as an `api-key: $apiKey` header entry.
|===


Expand Down
3 changes: 3 additions & 0 deletions extended/src/main/java/apoc/ExtendedApocConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ public class ExtendedApocConfig extends LifecycleAdapter
public static final String APOC_UUID_ENABLED_DB = "apoc.uuid.enabled.%s";
public static final String APOC_UUID_FORMAT = "apoc.uuid.format";
public static final String APOC_OPENAI_KEY = "apoc.openai.key";
public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url";
public static final String APOC_ML_OPENAI_TYPE = "apoc.ml.openai.type";
public static final String APOC_ML_OPENAI_AZURE_VERSION = "apoc.ml.azure.api.version";
public static final String APOC_AWS_KEY_ID = "apoc.aws.key.id";
public static final String APOC_AWS_SECRET_KEY = "apoc.aws.secret.key";
public enum UuidFormatType { hex, base64 }
Expand Down
43 changes: 29 additions & 14 deletions extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import apoc.ApocConfig;
import apoc.Extended;
import apoc.result.MapResult;
import apoc.util.JsonUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.neo4j.graphdb.security.URLAccessChecker;
Expand All @@ -11,21 +12,23 @@
import org.neo4j.procedure.Procedure;

import java.net.MalformedURLException;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Stream;

import apoc.result.MapResult;

import com.fasterxml.jackson.databind.ObjectMapper;

import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_TYPE;
import static apoc.ExtendedApocConfig.APOC_OPENAI_KEY;


@Extended
public class OpenAI {
public static final String API_TYPE_CONF_KEY = "apiType";
public static final String APIKEY_CONF_KEY = "apiKey";
public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String API_VERSION_CONF_KEY = "apiVersion";

@Context
public ApocConfig apocConfig;

Expand All @@ -47,22 +50,34 @@ public EmbeddingResult(long index, String text, List<Double> embedding) {
}

static Stream<Object> executeRequest(String apiKey, Map<String, Object> configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException {
apiKey = apocConfig.getString(APOC_OPENAI_KEY, apiKey);
apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey));
if (apiKey == null || apiKey.isBlank())
throw new IllegalArgumentException("API Key must not be empty");
String endpoint = System.getProperty(APOC_ML_OPENAI_URL,"https://api.openai.com/v1/");
Map<String, Object> headers = Map.of(
"Content-Type", "application/json",
"Authorization", "Bearer " + apiKey

String apiTypeString = (String) configuration.getOrDefault(API_TYPE_CONF_KEY,
apocConfig.getString(APOC_ML_OPENAI_TYPE, OpenAIRequestHandler.Type.OPENAI.name())
);
OpenAIRequestHandler apiType = OpenAIRequestHandler.Type.valueOf(apiTypeString.toUpperCase(Locale.ENGLISH))
.get();

String endpoint = apiType.getEndpoint(configuration, apocConfig);

final Map<String, Object> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
apiType.addApiKey(headers, apiKey);

var config = new HashMap<>(configuration);
// we remove these keys from config, since the json payload is calculated starting from the config map
Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY, APIKEY_CONF_KEY).forEach(config::remove);
config.putIfAbsent("model", model);
config.put(key, inputs);

String payload = new ObjectMapper().writeValueAsString(config);

var url = new URL(new URL(endpoint), path).toString();
String payload = JsonUtil.OBJECT_MAPPER.writeValueAsString(config);

// new URL(endpoint), path) can produce a wrong path, since endpoint can have for example embedding,
// eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model
// therefore is better to join the not-empty path pieces
var url = apiType.getFullUrl(path, configuration, apocConfig);
return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of(), urlAccessChecker);
}

Expand Down
87 changes: 87 additions & 0 deletions extended/src/main/java/apoc/ml/OpenAIRequestHandler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package apoc.ml;


import apoc.ApocConfig;
import org.apache.commons.lang3.StringUtils;

import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_AZURE_VERSION;
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_URL;
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY;
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY;

abstract class OpenAIRequestHandler {

private final String defaultUrl;

public OpenAIRequestHandler(String defaultUrl) {
this.defaultUrl = defaultUrl;
}

public String getDefaultUrl() {
return defaultUrl;
}
public abstract String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig);
public abstract void addApiKey(Map<String, Object> headers, String apiKey);

public String getEndpoint(Map<String, Object> procConfig, ApocConfig apocConfig) {
return (String) procConfig.getOrDefault(ENDPOINT_CONF_KEY,
apocConfig.getString(APOC_ML_OPENAI_URL, System.getProperty(APOC_ML_OPENAI_URL, getDefaultUrl())));
}

public String getFullUrl(String method, Map<String, Object> procConfig, ApocConfig apocConfig) {
return Stream.of(getEndpoint(procConfig, apocConfig), method, getApiVersion(procConfig, apocConfig))
.filter(StringUtils::isNotBlank)
.collect(Collectors.joining("/"));
}

enum Type {
AZURE(new Azure(null)), OPENAI(new OpenAi("https://api.openai.com/v1"));

private final OpenAIRequestHandler handler;
Type(OpenAIRequestHandler handler) {
this.handler = handler;
}

public OpenAIRequestHandler get() {
return handler;
}
}

static class Azure extends OpenAIRequestHandler {

public Azure(String defaultUrl) {
super(defaultUrl);
}

@Override
public String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig) {
return "?api-version=" + configuration.getOrDefault(API_VERSION_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_AZURE_VERSION));
}

@Override
public void addApiKey(Map<String, Object> headers, String apiKey) {
headers.put("api-key", apiKey);
}
}

static class OpenAi extends OpenAIRequestHandler {

public OpenAi(String defaultUrl) {
super(defaultUrl);
}

@Override
public String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig) {
return "";
}

@Override
public void addApiKey(Map<String, Object> headers, String apiKey) {
headers.put("Authorization", "Bearer " + apiKey);
}
}
}
94 changes: 94 additions & 0 deletions extended/src/test/java/apoc/ml/OpenAIAzureIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package apoc.ml;

import apoc.util.TestUtil;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Ignore;
import org.junit.Test;
import org.neo4j.test.rule.DbmsRule;
import org.neo4j.test.rule.ImpermanentDbmsRule;

import java.util.Map;
import java.util.stream.Stream;

import static apoc.ml.OpenAI.API_TYPE_CONF_KEY;
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY;
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY;
import static apoc.ml.OpenAITestResultUtils.assertChatCompletion;
import static apoc.ml.OpenAITestResultUtils.assertCompletion;
import static apoc.util.TestUtil.testCall;
import static org.junit.Assume.assumeNotNull;

public class OpenAIAzureIT {
// In Azure, the endpoints can be different
private static String OPENAI_EMBEDDING_URL;
private static String OPENAI_CHAT_URL;
private static String OPENAI_COMPLETION_URL;

private static String OPENAI_AZURE_API_VERSION;

private static String OPENAI_KEY;

@ClassRule
public static DbmsRule db = new ImpermanentDbmsRule();

@BeforeClass
public static void setUp() throws Exception {
OPENAI_KEY = System.getenv("OPENAI_KEY");
// Azure OpenAI base URLs
OPENAI_EMBEDDING_URL = System.getenv("OPENAI_EMBEDDING_URL");
OPENAI_CHAT_URL = System.getenv("OPENAI_CHAT_URL");
OPENAI_COMPLETION_URL = System.getenv("OPENAI_COMPLETION_URL");

// Azure OpenAI query url (`<baseURL>/<type>/?api-version=<OPENAI_AZURE_API_VERSION>`)
OPENAI_AZURE_API_VERSION = System.getenv("OPENAI_AZURE_API_VERSION");

Stream.of(OPENAI_EMBEDDING_URL,
OPENAI_CHAT_URL,
OPENAI_COMPLETION_URL,
OPENAI_AZURE_API_VERSION,
OPENAI_KEY)
.forEach(key -> assumeNotNull("No " + key + " environment configured", key));


TestUtil.registerProcedure(db, OpenAI.class);
}

@Test
public void embedding() {
testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, $conf)",
getParams(OPENAI_EMBEDDING_URL),
OpenAITestResultUtils::assertEmbeddings);
}


@Test
@Ignore("It returns wrong answers sometimes")
public void completion() {
testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, $conf)",
getParams(OPENAI_CHAT_URL),
(row) -> assertCompletion(row, "gpt-35-turbo"));
}

@Test
public void chatCompletion() {
testCall(db, """
CALL apoc.ml.openai.chat([
{role:"system", content:"Only answer with a single word"},
{role:"user", content:"What planet do humans live on?"}
], $apiKey, $conf)
""", getParams(OPENAI_COMPLETION_URL),
(row) -> assertChatCompletion(row, "gpt-35-turbo"));
}

private static Map<String, Object> getParams(String url) {
return Map.of("apiKey", OPENAI_KEY,
"conf", Map.of(ENDPOINT_CONF_KEY, url,
API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.AZURE.name(),
API_VERSION_CONF_KEY, OPENAI_AZURE_API_VERSION,
// on Azure is available only "gpt-35-turbo"
"model", "gpt-35-turbo"
)
);
}
}
Loading

0 comments on commit ad8d265

Please sign in to comment.