-
Notifications
You must be signed in to change notification settings - Fork 494
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
… (#3885) * Fixes #3634: Updated ML procs for Azure OpenAI services * Code clean * added enpoint env vars * Code clean part 2 * removed unused imports --------- Co-authored-by: Andrea Santurbano <santand@gmail.com>
- Loading branch information
Showing
8 changed files
with
334 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package apoc.ml; | ||
|
||
|
||
import apoc.ApocConfig; | ||
import org.apache.commons.lang3.StringUtils; | ||
|
||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
import java.util.stream.Stream; | ||
|
||
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_AZURE_VERSION; | ||
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_URL; | ||
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY; | ||
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY; | ||
|
||
abstract class OpenAIRequestHandler { | ||
|
||
private final String defaultUrl; | ||
|
||
public OpenAIRequestHandler(String defaultUrl) { | ||
this.defaultUrl = defaultUrl; | ||
} | ||
|
||
public String getDefaultUrl() { | ||
return defaultUrl; | ||
} | ||
public abstract String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig); | ||
public abstract void addApiKey(Map<String, Object> headers, String apiKey); | ||
|
||
public String getEndpoint(Map<String, Object> procConfig, ApocConfig apocConfig) { | ||
return (String) procConfig.getOrDefault(ENDPOINT_CONF_KEY, | ||
apocConfig.getString(APOC_ML_OPENAI_URL, System.getProperty(APOC_ML_OPENAI_URL, getDefaultUrl()))); | ||
} | ||
|
||
public String getFullUrl(String method, Map<String, Object> procConfig, ApocConfig apocConfig) { | ||
return Stream.of(getEndpoint(procConfig, apocConfig), method, getApiVersion(procConfig, apocConfig)) | ||
.filter(StringUtils::isNotBlank) | ||
.collect(Collectors.joining("/")); | ||
} | ||
|
||
enum Type { | ||
AZURE(new Azure(null)), OPENAI(new OpenAi("https://api.openai.com/v1")); | ||
|
||
private final OpenAIRequestHandler handler; | ||
Type(OpenAIRequestHandler handler) { | ||
this.handler = handler; | ||
} | ||
|
||
public OpenAIRequestHandler get() { | ||
return handler; | ||
} | ||
} | ||
|
||
static class Azure extends OpenAIRequestHandler { | ||
|
||
public Azure(String defaultUrl) { | ||
super(defaultUrl); | ||
} | ||
|
||
@Override | ||
public String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig) { | ||
return "?api-version=" + configuration.getOrDefault(API_VERSION_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_AZURE_VERSION)); | ||
} | ||
|
||
@Override | ||
public void addApiKey(Map<String, Object> headers, String apiKey) { | ||
headers.put("api-key", apiKey); | ||
} | ||
} | ||
|
||
static class OpenAi extends OpenAIRequestHandler { | ||
|
||
public OpenAi(String defaultUrl) { | ||
super(defaultUrl); | ||
} | ||
|
||
@Override | ||
public String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig) { | ||
return ""; | ||
} | ||
|
||
@Override | ||
public void addApiKey(Map<String, Object> headers, String apiKey) { | ||
headers.put("Authorization", "Bearer " + apiKey); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
package apoc.ml; | ||
|
||
import apoc.util.TestUtil; | ||
import org.junit.BeforeClass; | ||
import org.junit.ClassRule; | ||
import org.junit.Ignore; | ||
import org.junit.Test; | ||
import org.neo4j.test.rule.DbmsRule; | ||
import org.neo4j.test.rule.ImpermanentDbmsRule; | ||
|
||
import java.util.Map; | ||
import java.util.stream.Stream; | ||
|
||
import static apoc.ml.OpenAI.API_TYPE_CONF_KEY; | ||
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY; | ||
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY; | ||
import static apoc.ml.OpenAITestResultUtils.assertChatCompletion; | ||
import static apoc.ml.OpenAITestResultUtils.assertCompletion; | ||
import static apoc.util.TestUtil.testCall; | ||
import static org.junit.Assume.assumeNotNull; | ||
|
||
public class OpenAIAzureIT { | ||
// In Azure, the endpoints can be different | ||
private static String OPENAI_EMBEDDING_URL; | ||
private static String OPENAI_CHAT_URL; | ||
private static String OPENAI_COMPLETION_URL; | ||
|
||
private static String OPENAI_AZURE_API_VERSION; | ||
|
||
private static String OPENAI_KEY; | ||
|
||
@ClassRule | ||
public static DbmsRule db = new ImpermanentDbmsRule(); | ||
|
||
@BeforeClass | ||
public static void setUp() throws Exception { | ||
OPENAI_KEY = System.getenv("OPENAI_KEY"); | ||
// Azure OpenAI base URLs | ||
OPENAI_EMBEDDING_URL = System.getenv("OPENAI_EMBEDDING_URL"); | ||
OPENAI_CHAT_URL = System.getenv("OPENAI_CHAT_URL"); | ||
OPENAI_COMPLETION_URL = System.getenv("OPENAI_COMPLETION_URL"); | ||
|
||
// Azure OpenAI query url (`<baseURL>/<type>/?api-version=<OPENAI_AZURE_API_VERSION>`) | ||
OPENAI_AZURE_API_VERSION = System.getenv("OPENAI_AZURE_API_VERSION"); | ||
|
||
Stream.of(OPENAI_EMBEDDING_URL, | ||
OPENAI_CHAT_URL, | ||
OPENAI_COMPLETION_URL, | ||
OPENAI_AZURE_API_VERSION, | ||
OPENAI_KEY) | ||
.forEach(key -> assumeNotNull("No " + key + " environment configured", key)); | ||
|
||
|
||
TestUtil.registerProcedure(db, OpenAI.class); | ||
} | ||
|
||
@Test | ||
public void embedding() { | ||
testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, $conf)", | ||
getParams(OPENAI_EMBEDDING_URL), | ||
OpenAITestResultUtils::assertEmbeddings); | ||
} | ||
|
||
|
||
@Test | ||
@Ignore("It returns wrong answers sometimes") | ||
public void completion() { | ||
testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, $conf)", | ||
getParams(OPENAI_CHAT_URL), | ||
(row) -> assertCompletion(row, "gpt-35-turbo")); | ||
} | ||
|
||
@Test | ||
public void chatCompletion() { | ||
testCall(db, """ | ||
CALL apoc.ml.openai.chat([ | ||
{role:"system", content:"Only answer with a single word"}, | ||
{role:"user", content:"What planet do humans live on?"} | ||
], $apiKey, $conf) | ||
""", getParams(OPENAI_COMPLETION_URL), | ||
(row) -> assertChatCompletion(row, "gpt-35-turbo")); | ||
} | ||
|
||
private static Map<String, Object> getParams(String url) { | ||
return Map.of("apiKey", OPENAI_KEY, | ||
"conf", Map.of(ENDPOINT_CONF_KEY, url, | ||
API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.AZURE.name(), | ||
API_VERSION_CONF_KEY, OPENAI_AZURE_API_VERSION, | ||
// on Azure is available only "gpt-35-turbo" | ||
"model", "gpt-35-turbo" | ||
) | ||
); | ||
} | ||
} |
Oops, something went wrong.