From c33d8c31d2c3ffeccea40b1c12e5af268ee4d480 Mon Sep 17 00:00:00 2001 From: Christos Sotiriou <> Date: Tue, 23 Apr 2024 13:33:12 +0300 Subject: [PATCH] adding support for Azure AD Token Making the apiKey optional + better handling of misconfigurations further ironing out passing arguments for the OpenAI default implementation OpenAI rest client needs the api token passed, and there had to be an overriding in the builder to allow instantiation without passing the api key (since we may pass other token properties) Renamed configuration property, proper formatting, and doc update proper formatting renamed "azure-ad-token" to "ad-token" for the sake of brevity also updated documentation forgotten text --- docs/modules/ROOT/pages/openai.adoc | 9 +++- .../azure/openai/AzureOpenAiChatModel.java | 11 +++- .../openai/AzureOpenAiEmbeddingModel.java | 11 +++- .../azure/openai/AzureOpenAiImageModel.java | 14 +++-- .../openai/AzureOpenAiStreamingChatModel.java | 11 +++- .../openai/runtime/AzureOpenAiRecorder.java | 54 +++++++++++-------- .../config/LangChain4jAzureOpenAiConfig.java | 10 +++- .../AzureOpenAiRecorderEndpointTests.java | 9 +++- .../langchain4j/openai/OpenAiRestApi.java | 33 +++++++++--- .../openai/QuarkusOpenAiClient.java | 31 ++++++++++- 10 files changed, 152 insertions(+), 41 deletions(-) diff --git a/docs/modules/ROOT/pages/openai.adoc b/docs/modules/ROOT/pages/openai.adoc index edfcd1ee7..799159b50 100644 --- a/docs/modules/ROOT/pages/openai.adoc +++ b/docs/modules/ROOT/pages/openai.adoc @@ -60,11 +60,18 @@ When this extension is used, the following configuration properties are required [source, properties] ---- -quarkus.langchain4j.azure-openai.api-key= quarkus.langchain4j.azure-openai.resource-name= quarkus.langchain4j.azure-openai.deployment-name= + +# And one of the below depending on your scenario +quarkus.langchain4j.azure-openai.api-key= +quarkus.langchain4j.azure-openai.ad-token= ---- +In the case of Azure, the `api-key` and `ad-token` properties are mutually exclusive. The `api-key` property should be used when the Azure OpenAI service is configured to use API keys, while the `ad-token` property should be used when the Azure OpenAI service is configured to use Azure Active Directory tokens. + +In both cases, the key will be placed in the Authorization header when communicating with the Azure OpenAI service. + == Advanced usage `quarkus-langchain4j-openai` and `quarkus-langchain4j-azure-openai` extensions use a REST Client under the hood to make the REST calls required by LangChain4j. diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiChatModel.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiChatModel.java index 3cbee33db..0cfd97667 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiChatModel.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiChatModel.java @@ -56,6 +56,7 @@ public class AzureOpenAiChatModel implements ChatLanguageModel, TokenCountEstima public AzureOpenAiChatModel(String endpoint, String apiVersion, String apiKey, + String adToken, Tokenizer tokenizer, Double temperature, Double topP, @@ -72,7 +73,6 @@ public AzureOpenAiChatModel(String endpoint, this.client = ((QuarkusOpenAiClient.Builder) OpenAiClient.builder() .baseUrl(ensureNotBlank(endpoint, "endpoint")) - .azureApiKey(apiKey) .apiVersion(apiVersion) .callTimeout(timeout) .connectTimeout(timeout) @@ -82,6 +82,8 @@ public AzureOpenAiChatModel(String endpoint, .logRequests(logRequests) .logResponses(logResponses)) .userAgent(Consts.DEFAULT_USER_AGENT) + .azureAdToken(adToken) + .azureApiKey(apiKey) .build(); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; @@ -149,6 +151,7 @@ public static class Builder { private String endpoint; private String apiVersion; private String apiKey; + private String adToken; private Tokenizer tokenizer; private Double temperature; private Double topP; @@ -195,6 +198,11 @@ public Builder apiKey(String apiKey) { return this; } + public Builder adToken(String adToken) { + this.adToken = adToken; + return this; + } + public Builder tokenizer(Tokenizer tokenizer) { this.tokenizer = tokenizer; return this; @@ -254,6 +262,7 @@ public AzureOpenAiChatModel build() { return new AzureOpenAiChatModel(endpoint, apiVersion, apiKey, + adToken, tokenizer, temperature, topP, diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiEmbeddingModel.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiEmbeddingModel.java index f9c47f4cb..b4d8a5673 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiEmbeddingModel.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiEmbeddingModel.java @@ -52,6 +52,7 @@ public class AzureOpenAiEmbeddingModel implements EmbeddingModel, TokenCountEsti public AzureOpenAiEmbeddingModel(String endpoint, String apiVersion, String apiKey, + String adToken, Tokenizer tokenizer, Duration timeout, Integer maxRetries, @@ -63,7 +64,6 @@ public AzureOpenAiEmbeddingModel(String endpoint, this.client = ((QuarkusOpenAiClient.Builder) OpenAiClient.builder() .baseUrl(ensureNotBlank(endpoint, "endpoint")) - .azureApiKey(apiKey) .apiVersion(apiVersion) .callTimeout(timeout) .connectTimeout(timeout) @@ -73,6 +73,8 @@ public AzureOpenAiEmbeddingModel(String endpoint, .logRequests(logRequests) .logResponses(logResponses)) .userAgent(DEFAULT_USER_AGENT) + .azureAdToken(adToken) + .azureApiKey(apiKey) .build(); this.maxRetries = getOrDefault(maxRetries, 3); this.tokenizer = tokenizer; @@ -143,6 +145,7 @@ public static class Builder { private Proxy proxy; private Boolean logRequests; private Boolean logResponses; + private String adToken; /** * Sets the Azure OpenAI endpoint. This is a mandatory parameter. @@ -178,6 +181,11 @@ public Builder apiKey(String apiKey) { return this; } + public Builder adToken(String adToken) { + this.adToken = adToken; + return this; + } + public Builder tokenizer(Tokenizer tokenizer) { this.tokenizer = tokenizer; return this; @@ -212,6 +220,7 @@ public AzureOpenAiEmbeddingModel build() { return new AzureOpenAiEmbeddingModel(endpoint, apiVersion, apiKey, + adToken, tokenizer, timeout, maxRetries, diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiImageModel.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiImageModel.java index 87344be0d..ec502561d 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiImageModel.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiImageModel.java @@ -43,7 +43,8 @@ public class AzureOpenAiImageModel implements ImageModel { private final QuarkusOpenAiClient client; - public AzureOpenAiImageModel(String endpoint, String apiKey, String apiVersion, String modelName, String size, + public AzureOpenAiImageModel(String endpoint, String apiKey, String adToken, String apiVersion, String modelName, + String size, String quality, String style, Optional user, String responseFormat, Duration timeout, Integer maxRetries, Boolean logRequests, Boolean logResponses, Optional persistDirectory) { @@ -58,7 +59,6 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String apiVersion, this.client = QuarkusOpenAiClient.builder() .baseUrl(ensureNotBlank(endpoint, "endpoint")) - .azureApiKey(apiKey) .apiVersion(apiVersion) .callTimeout(timeout) .connectTimeout(timeout) @@ -67,6 +67,8 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String apiVersion, .logRequests(logRequests) .logResponses(logResponses) .userAgent(DEFAULT_USER_AGENT) + .azureAdToken(adToken) + .azureApiKey(apiKey) .build(); } @@ -144,6 +146,7 @@ public static Builder builder() { public static class Builder { private String endpoint; private String apiKey; + private String adToken; private String apiVersion; private String modelName; private String size; @@ -167,6 +170,11 @@ public Builder apiKey(String apiKey) { return this; } + public Builder adToken(String adToken) { + this.adToken = adToken; + return this; + } + public Builder apiVersion(String apiVersion) { this.apiVersion = apiVersion; return this; @@ -228,7 +236,7 @@ public Builder persistDirectory(Optional persistDirectory) { } public AzureOpenAiImageModel build() { - return new AzureOpenAiImageModel(endpoint, apiKey, apiVersion, modelName, size, quality, style, user, + return new AzureOpenAiImageModel(endpoint, apiKey, adToken, apiVersion, modelName, size, quality, style, user, responseFormat, timeout, maxRetries, logRequests, logResponses, persistDirectory); } diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java index 07e0cc920..586d0acde 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java @@ -62,6 +62,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel public AzureOpenAiStreamingChatModel(String endpoint, String apiVersion, String apiKey, + String adToken, Tokenizer tokenizer, Double temperature, Double topP, @@ -77,7 +78,6 @@ public AzureOpenAiStreamingChatModel(String endpoint, this.client = ((QuarkusOpenAiClient.Builder) OpenAiClient.builder() .baseUrl(ensureNotBlank(endpoint, "endpoint")) - .azureApiKey(apiKey) .apiVersion(apiVersion) .callTimeout(timeout) .connectTimeout(timeout) @@ -87,6 +87,8 @@ public AzureOpenAiStreamingChatModel(String endpoint, .logRequests(logRequests) .logStreamingResponses(logResponses)) .userAgent(DEFAULT_USER_AGENT) + .azureAdToken(adToken) + .azureApiKey(apiKey) .build(); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; @@ -185,6 +187,7 @@ public static class Builder { private String endpoint; private String apiVersion; private String apiKey; + private String adToken; private Tokenizer tokenizer; private Double temperature; private Double topP; @@ -230,6 +233,11 @@ public Builder apiKey(String apiKey) { return this; } + public Builder adToken(String adToken) { + this.adToken = adToken; + return this; + } + public Builder tokenizer(Tokenizer tokenizer) { this.tokenizer = tokenizer; return this; @@ -284,6 +292,7 @@ public AzureOpenAiStreamingChatModel build() { return new AzureOpenAiStreamingChatModel(endpoint, apiVersion, apiKey, + adToken, tokenizer, temperature, topP, diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java index ff86cfe10..e52786ff7 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java @@ -34,7 +34,6 @@ @Recorder public class AzureOpenAiRecorder { - private static final String DUMMY_KEY = "dummy"; static final String AZURE_ENDPOINT_URL_PATTERN = "https://%s.openai.azure.com/openai/deployments/%s"; public static final Problem[] EMPTY_PROBLEMS = new Problem[0]; @@ -43,13 +42,15 @@ public Supplier chatModel(LangChain4jAzureOpenAiConfig runtim if (azureAiConfig.enableIntegration()) { ChatModelConfig chatModelConfig = azureAiConfig.chatModel(); - String apiKey = azureAiConfig.apiKey(); - if (DUMMY_KEY.equals(apiKey)) { - throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); - } + String apiKey = azureAiConfig.apiKey().orElse(null); + String adToken = azureAiConfig.adToken().orElse(null); + + throwIfApiKeysNotConfigured(apiKey, adToken, modelName); + var builder = AzureOpenAiChatModel.builder() .endpoint(getEndpoint(azureAiConfig, modelName)) .apiKey(apiKey) + .adToken(adToken) .apiVersion(azureAiConfig.apiVersion()) .timeout(azureAiConfig.timeout()) .maxRetries(azureAiConfig.maxRetries()) @@ -86,13 +87,15 @@ public Supplier streamingChatModel(LangChain4jAzureO if (azureAiConfig.enableIntegration()) { ChatModelConfig chatModelConfig = azureAiConfig.chatModel(); - String apiKey = azureAiConfig.apiKey(); - if (DUMMY_KEY.equals(apiKey)) { - throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); - } + String apiKey = azureAiConfig.apiKey().orElse(null); + String adToken = azureAiConfig.adToken().orElse(null); + + throwIfApiKeysNotConfigured(apiKey, adToken, modelName); + var builder = AzureOpenAiStreamingChatModel.builder() .endpoint(getEndpoint(azureAiConfig, modelName)) .apiKey(apiKey) + .adToken(adToken) .apiVersion(azureAiConfig.apiVersion()) .timeout(azureAiConfig.timeout()) .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), azureAiConfig.logRequests())) @@ -127,13 +130,15 @@ public Supplier embeddingModel(LangChain4jAzureOpenAiConfig runt if (azureAiConfig.enableIntegration()) { EmbeddingModelConfig embeddingModelConfig = azureAiConfig.embeddingModel(); - String apiKey = azureAiConfig.apiKey(); - if (DUMMY_KEY.equals(apiKey)) { - throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); + String apiKey = azureAiConfig.apiKey().orElse(null); + String adToken = azureAiConfig.adToken().orElse(null); + if (apiKey == null && adToken == null) { + throw new ConfigValidationException(createKeyMisconfigurationProblem(modelName)); } var builder = AzureOpenAiEmbeddingModel.builder() .endpoint(getEndpoint(azureAiConfig, modelName)) .apiKey(apiKey) + .adToken(apiKey) .apiVersion(azureAiConfig.apiVersion()) .timeout(azureAiConfig.timeout()) .maxRetries(azureAiConfig.maxRetries()) @@ -160,16 +165,15 @@ public Supplier imageModel(LangChain4jAzureOpenAiConfig runtimeConfi LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, modelName); if (azureAiConfig.enableIntegration()) { - var apiKey = azureAiConfig.apiKey(); - - if (DUMMY_KEY.equals(apiKey)) { - throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); - } + var apiKey = azureAiConfig.apiKey().orElse(null); + String adToken = azureAiConfig.adToken().orElse(null); + throwIfApiKeysNotConfigured(apiKey, adToken, modelName); var imageModelConfig = azureAiConfig.imageModel(); var builder = AzureOpenAiImageModel.builder() .endpoint(getEndpoint(azureAiConfig, modelName)) .apiKey(apiKey) + .adToken(adToken) .apiVersion(azureAiConfig.apiVersion()) .timeout(azureAiConfig.timeout()) .maxRetries(azureAiConfig.maxRetries()) @@ -259,12 +263,20 @@ private LangChain4jAzureOpenAiConfig.AzureAiConfig correspondingAzureOpenAiConfi return azureAiConfig; } - private ConfigValidationException.Problem[] createApiKeyConfigProblem(String modelName) { - return createConfigProblems("api-key", modelName); + private void throwIfApiKeysNotConfigured(String apiKey, String adToken, String modelName) { + if ((apiKey != null) == (adToken != null)) { + throw new ConfigValidationException(createKeyMisconfigurationProblem(modelName)); + } } - private ConfigValidationException.Problem[] createConfigProblems(String key, String modelName) { - return new ConfigValidationException.Problem[] { createConfigProblem(key, modelName) }; + private ConfigValidationException.Problem[] createKeyMisconfigurationProblem(String modelName) { + return new ConfigValidationException.Problem[] { + new ConfigValidationException.Problem( + String.format( + "SRCFG00014: Exactly of the configuration properties must be present: quarkus.langchain4j.azure-openai%s%s or quarkus.langchain4j.azure-openai%s%s", + NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), "api-key", + NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), "ad-token")) + }; } private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) { diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/LangChain4jAzureOpenAiConfig.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/LangChain4jAzureOpenAiConfig.java index 8492db2fe..3e8143258 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/LangChain4jAzureOpenAiConfig.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/LangChain4jAzureOpenAiConfig.java @@ -68,6 +68,13 @@ interface AzureAiConfig { */ Optional endpoint(); + /** + * The Azure AD token to use for this operation. + * If present, then the requests towards OpenAI will include this in the Authorization header. + * Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.api-key}. + */ + Optional adToken(); + /** * The API version to use for this operation. This follows the YYYY-MM-DD format */ @@ -77,8 +84,7 @@ interface AzureAiConfig { /** * Azure OpenAI API key */ - @WithDefault("dummy") // TODO: this should be optional but Smallrye Config doesn't like it.. - String apiKey(); + Optional apiKey(); /** * Timeout for OpenAI calls diff --git a/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java b/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java index 07e952d14..08bb36c33 100644 --- a/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java +++ b/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java @@ -142,14 +142,19 @@ public Optional endpoint() { return Optional.empty(); } + @Override + public Optional adToken() { + return Optional.empty(); + } + @Override public String apiVersion() { return null; } @Override - public String apiKey() { - return "my-key"; + public Optional apiKey() { + return Optional.of("my-key"); } @Override diff --git a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java index 84e26966d..d0284a03a 100644 --- a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java +++ b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java @@ -423,9 +423,20 @@ class ApiMetadata { @HeaderParam("OpenAI-Organization") public final String organizationId; - private ApiMetadata(String openaiApiKey, String azureApiKey, - String apiVersion, String organizationId) { - this.authorization = (openaiApiKey != null) ? "Bearer " + openaiApiKey : null; + private ApiMetadata( + String openaiApiKey, + String azureApiKey, + String azureAdToken, + String apiVersion, + String organizationId) { + if (azureAdToken != null) { + this.authorization = "Bearer " + azureAdToken; + } else if (openaiApiKey != null) { + this.authorization = "Bearer " + openaiApiKey; + } else { + this.authorization = null; + } + this.azureApiKey = azureApiKey; this.apiVersion = apiVersion; this.organizationId = organizationId; @@ -437,20 +448,26 @@ public static ApiMetadata.Builder builder() { public static class Builder { private String azureApiKey; + private String azureAdToken; private String openAiApiKey; private String apiVersion; private String organizationId; public ApiMetadata build() { - if ((azureApiKey != null) && (openAiApiKey != null)) { - return new ApiMetadata(openAiApiKey, azureApiKey, apiVersion, organizationId); + if (azureAdToken != null) { + return new ApiMetadata(null, null, azureAdToken, apiVersion, organizationId); } else if (azureApiKey != null) { - return new ApiMetadata(null, azureApiKey, apiVersion, organizationId); + return new ApiMetadata(null, azureApiKey, null, apiVersion, organizationId); } else if (openAiApiKey != null) { - return new ApiMetadata(openAiApiKey, null, apiVersion, organizationId); + return new ApiMetadata(openAiApiKey, null, null, apiVersion, organizationId); } - return new ApiMetadata(null, null, apiVersion, organizationId); + return new ApiMetadata(null, null, null, apiVersion, organizationId); + } + + public ApiMetadata.Builder azureAdToken(String azureAdToken) { + this.azureAdToken = azureAdToken; + return this; } public ApiMetadata.Builder azureApiKey(String azureApiKey) { diff --git a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java index ff75a9c97..e87b46335 100644 --- a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java +++ b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java @@ -52,6 +52,7 @@ public class QuarkusOpenAiClient extends OpenAiClient { private final String azureApiKey; + private final String azureAdToken; private final String openaiApiKey; private final String apiVersion; private final String organizationId; @@ -77,6 +78,7 @@ private QuarkusOpenAiClient(Builder builder) { this.openaiApiKey = builder.openAiApiKey; this.apiVersion = builder.apiVersion; this.organizationId = builder.organizationId; + this.azureAdToken = builder.azureAdToken; // cache the client the builder could be called with the same parameters from multiple models this.restApi = cache.compute(builder, new BiFunction() { @Override @@ -129,6 +131,7 @@ public CompletionResponse execute() { CompletionRequest.builder().from(request).stream(null).build(), OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -187,6 +190,7 @@ public ChatCompletionResponse execute() { ChatCompletionRequest.builder().from(request).stream(null).build(), OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -244,6 +248,7 @@ public String execute() { .blockingChatCompletion(request, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -315,6 +320,7 @@ public EmbeddingResponse execute() { return restApi.blockingEmbedding(request, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -352,6 +358,7 @@ public List execute() { return restApi.blockingEmbedding(request, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -388,6 +395,7 @@ public ModerationResponse execute() { return restApi.blockingModeration(request, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -426,6 +434,7 @@ public ModerationResult execute() { return restApi.blockingModeration(request, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -462,6 +471,7 @@ public GenerateImagesResponse execute() { return restApi.blockingImagesGenerations(generateImagesRequest, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -504,12 +514,30 @@ public Builder get() { public static class Builder extends OpenAiClient.Builder { private String userAgent; + private String azureAdToken; public Builder userAgent(String userAgent) { this.userAgent = userAgent; return this; } + public Builder azureAdToken(String azureAdToken) { + this.azureAdToken = azureAdToken; + return this; + } + + @Override + public Builder openAiApiKey(String openAiApiKey) { + this.openAiApiKey = openAiApiKey; + return this; + } + + @Override + public Builder azureApiKey(String azureApiKey) { + this.azureApiKey = azureApiKey; + return this; + } + @Override public QuarkusOpenAiClient build() { return new QuarkusOpenAiClient(this); @@ -535,6 +563,7 @@ public boolean equals(Object o) { && Objects.equals(readTimeout, builder.readTimeout) && Objects.equals(writeTimeout, builder.writeTimeout) && Objects.equals(proxy, builder.proxy) + && Objects.equals(azureAdToken, builder.azureAdToken) && Objects.equals(userAgent, builder.userAgent); } @@ -542,7 +571,7 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(baseUrl, apiVersion, openAiApiKey, azureApiKey, organizationId, callTimeout, connectTimeout, readTimeout, - writeTimeout, proxy, logRequests, logResponses, logStreamingResponses, userAgent); + writeTimeout, proxy, logRequests, logResponses, logStreamingResponses, userAgent, azureAdToken); } }