diff --git a/docs/modules/ROOT/pages/openai.adoc b/docs/modules/ROOT/pages/openai.adoc index edfcd1ee7..799159b50 100644 --- a/docs/modules/ROOT/pages/openai.adoc +++ b/docs/modules/ROOT/pages/openai.adoc @@ -60,11 +60,18 @@ When this extension is used, the following configuration properties are required [source, properties] ---- -quarkus.langchain4j.azure-openai.api-key= quarkus.langchain4j.azure-openai.resource-name= quarkus.langchain4j.azure-openai.deployment-name= + +# And one of the below depending on your scenario +quarkus.langchain4j.azure-openai.api-key= +quarkus.langchain4j.azure-openai.ad-token= ---- +In the case of Azure, the `api-key` and `ad-token` properties are mutually exclusive. The `api-key` property should be used when the Azure OpenAI service is configured to use API keys, while the `ad-token` property should be used when the Azure OpenAI service is configured to use Azure Active Directory tokens. + +In both cases, the key will be placed in the Authorization header when communicating with the Azure OpenAI service. + == Advanced usage `quarkus-langchain4j-openai` and `quarkus-langchain4j-azure-openai` extensions use a REST Client under the hood to make the REST calls required by LangChain4j. diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiChatModel.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiChatModel.java index 1c7eea884..753fc99ba 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiChatModel.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiChatModel.java @@ -57,6 +57,7 @@ public class AzureOpenAiChatModel implements ChatLanguageModel, TokenCountEstima public AzureOpenAiChatModel(String endpoint, String apiVersion, String apiKey, + String adToken, Tokenizer tokenizer, Double temperature, Double topP, @@ -74,7 +75,6 @@ public AzureOpenAiChatModel(String endpoint, this.client = ((QuarkusOpenAiClient.Builder) OpenAiClient.builder() .baseUrl(ensureNotBlank(endpoint, "endpoint")) - .azureApiKey(apiKey) .apiVersion(apiVersion) .callTimeout(timeout) .connectTimeout(timeout) @@ -84,6 +84,8 @@ public AzureOpenAiChatModel(String endpoint, .logRequests(logRequests) .logResponses(logResponses)) .userAgent(Consts.DEFAULT_USER_AGENT) + .azureAdToken(adToken) + .azureApiKey(apiKey) .build(); this.temperature = getOrDefault(temperature, 0.7); @@ -154,6 +156,7 @@ public static class Builder { private String endpoint; private String apiVersion; private String apiKey; + private String adToken; private Tokenizer tokenizer; private Double temperature; private Double topP; @@ -201,6 +204,11 @@ public Builder apiKey(String apiKey) { return this; } + public Builder adToken(String adToken) { + this.adToken = adToken; + return this; + } + public Builder tokenizer(Tokenizer tokenizer) { this.tokenizer = tokenizer; return this; @@ -265,6 +273,7 @@ public AzureOpenAiChatModel build() { return new AzureOpenAiChatModel(endpoint, apiVersion, apiKey, + adToken, tokenizer, temperature, topP, diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiEmbeddingModel.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiEmbeddingModel.java index f9c47f4cb..b4d8a5673 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiEmbeddingModel.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiEmbeddingModel.java @@ -52,6 +52,7 @@ public class AzureOpenAiEmbeddingModel implements EmbeddingModel, TokenCountEsti public AzureOpenAiEmbeddingModel(String endpoint, String apiVersion, String apiKey, + String adToken, Tokenizer tokenizer, Duration timeout, Integer maxRetries, @@ -63,7 +64,6 @@ public AzureOpenAiEmbeddingModel(String endpoint, this.client = ((QuarkusOpenAiClient.Builder) OpenAiClient.builder() .baseUrl(ensureNotBlank(endpoint, "endpoint")) - .azureApiKey(apiKey) .apiVersion(apiVersion) .callTimeout(timeout) .connectTimeout(timeout) @@ -73,6 +73,8 @@ public AzureOpenAiEmbeddingModel(String endpoint, .logRequests(logRequests) .logResponses(logResponses)) .userAgent(DEFAULT_USER_AGENT) + .azureAdToken(adToken) + .azureApiKey(apiKey) .build(); this.maxRetries = getOrDefault(maxRetries, 3); this.tokenizer = tokenizer; @@ -143,6 +145,7 @@ public static class Builder { private Proxy proxy; private Boolean logRequests; private Boolean logResponses; + private String adToken; /** * Sets the Azure OpenAI endpoint. This is a mandatory parameter. @@ -178,6 +181,11 @@ public Builder apiKey(String apiKey) { return this; } + public Builder adToken(String adToken) { + this.adToken = adToken; + return this; + } + public Builder tokenizer(Tokenizer tokenizer) { this.tokenizer = tokenizer; return this; @@ -212,6 +220,7 @@ public AzureOpenAiEmbeddingModel build() { return new AzureOpenAiEmbeddingModel(endpoint, apiVersion, apiKey, + adToken, tokenizer, timeout, maxRetries, diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiImageModel.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiImageModel.java index 87344be0d..ec502561d 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiImageModel.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiImageModel.java @@ -43,7 +43,8 @@ public class AzureOpenAiImageModel implements ImageModel { private final QuarkusOpenAiClient client; - public AzureOpenAiImageModel(String endpoint, String apiKey, String apiVersion, String modelName, String size, + public AzureOpenAiImageModel(String endpoint, String apiKey, String adToken, String apiVersion, String modelName, + String size, String quality, String style, Optional user, String responseFormat, Duration timeout, Integer maxRetries, Boolean logRequests, Boolean logResponses, Optional persistDirectory) { @@ -58,7 +59,6 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String apiVersion, this.client = QuarkusOpenAiClient.builder() .baseUrl(ensureNotBlank(endpoint, "endpoint")) - .azureApiKey(apiKey) .apiVersion(apiVersion) .callTimeout(timeout) .connectTimeout(timeout) @@ -67,6 +67,8 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String apiVersion, .logRequests(logRequests) .logResponses(logResponses) .userAgent(DEFAULT_USER_AGENT) + .azureAdToken(adToken) + .azureApiKey(apiKey) .build(); } @@ -144,6 +146,7 @@ public static Builder builder() { public static class Builder { private String endpoint; private String apiKey; + private String adToken; private String apiVersion; private String modelName; private String size; @@ -167,6 +170,11 @@ public Builder apiKey(String apiKey) { return this; } + public Builder adToken(String adToken) { + this.adToken = adToken; + return this; + } + public Builder apiVersion(String apiVersion) { this.apiVersion = apiVersion; return this; @@ -228,7 +236,7 @@ public Builder persistDirectory(Optional persistDirectory) { } public AzureOpenAiImageModel build() { - return new AzureOpenAiImageModel(endpoint, apiKey, apiVersion, modelName, size, quality, style, user, + return new AzureOpenAiImageModel(endpoint, apiKey, adToken, apiVersion, modelName, size, quality, style, user, responseFormat, timeout, maxRetries, logRequests, logResponses, persistDirectory); } diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java index 5ef49225d..dd01b286f 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java @@ -63,6 +63,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel public AzureOpenAiStreamingChatModel(String endpoint, String apiVersion, String apiKey, + String adToken, Tokenizer tokenizer, Double temperature, Double topP, @@ -79,7 +80,6 @@ public AzureOpenAiStreamingChatModel(String endpoint, this.client = ((QuarkusOpenAiClient.Builder) OpenAiClient.builder() .baseUrl(ensureNotBlank(endpoint, "endpoint")) - .azureApiKey(apiKey) .apiVersion(apiVersion) .callTimeout(timeout) .connectTimeout(timeout) @@ -89,6 +89,8 @@ public AzureOpenAiStreamingChatModel(String endpoint, .logRequests(logRequests) .logStreamingResponses(logResponses)) .userAgent(DEFAULT_USER_AGENT) + .azureAdToken(adToken) + .azureApiKey(apiKey) .build(); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; @@ -189,6 +191,7 @@ public static class Builder { private String endpoint; private String apiVersion; private String apiKey; + private String adToken; private Tokenizer tokenizer; private Double temperature; private Double topP; @@ -235,6 +238,11 @@ public Builder apiKey(String apiKey) { return this; } + public Builder adToken(String adToken) { + this.adToken = adToken; + return this; + } + public Builder tokenizer(Tokenizer tokenizer) { this.tokenizer = tokenizer; return this; @@ -294,6 +302,7 @@ public AzureOpenAiStreamingChatModel build() { return new AzureOpenAiStreamingChatModel(endpoint, apiVersion, apiKey, + adToken, tokenizer, temperature, topP, diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java index 67b02e098..4b4771148 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java @@ -34,7 +34,6 @@ @Recorder public class AzureOpenAiRecorder { - private static final String DUMMY_KEY = "dummy"; static final String AZURE_ENDPOINT_URL_PATTERN = "https://%s.openai.azure.com/openai/deployments/%s"; public static final Problem[] EMPTY_PROBLEMS = new Problem[0]; @@ -43,13 +42,15 @@ public Supplier chatModel(LangChain4jAzureOpenAiConfig runtim if (azureAiConfig.enableIntegration()) { ChatModelConfig chatModelConfig = azureAiConfig.chatModel(); - String apiKey = azureAiConfig.apiKey(); - if (DUMMY_KEY.equals(apiKey)) { - throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); - } + String apiKey = azureAiConfig.apiKey().orElse(null); + String adToken = azureAiConfig.adToken().orElse(null); + + throwIfApiKeysNotConfigured(apiKey, adToken, modelName); + var builder = AzureOpenAiChatModel.builder() .endpoint(getEndpoint(azureAiConfig, modelName)) .apiKey(apiKey) + .adToken(adToken) .apiVersion(azureAiConfig.apiVersion()) .timeout(azureAiConfig.timeout()) .maxRetries(azureAiConfig.maxRetries()) @@ -87,13 +88,15 @@ public Supplier streamingChatModel(LangChain4jAzureO if (azureAiConfig.enableIntegration()) { ChatModelConfig chatModelConfig = azureAiConfig.chatModel(); - String apiKey = azureAiConfig.apiKey(); - if (DUMMY_KEY.equals(apiKey)) { - throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); - } + String apiKey = azureAiConfig.apiKey().orElse(null); + String adToken = azureAiConfig.adToken().orElse(null); + + throwIfApiKeysNotConfigured(apiKey, adToken, modelName); + var builder = AzureOpenAiStreamingChatModel.builder() .endpoint(getEndpoint(azureAiConfig, modelName)) .apiKey(apiKey) + .adToken(adToken) .apiVersion(azureAiConfig.apiVersion()) .timeout(azureAiConfig.timeout()) .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), azureAiConfig.logRequests())) @@ -129,13 +132,15 @@ public Supplier embeddingModel(LangChain4jAzureOpenAiConfig runt if (azureAiConfig.enableIntegration()) { EmbeddingModelConfig embeddingModelConfig = azureAiConfig.embeddingModel(); - String apiKey = azureAiConfig.apiKey(); - if (DUMMY_KEY.equals(apiKey)) { - throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); + String apiKey = azureAiConfig.apiKey().orElse(null); + String adToken = azureAiConfig.adToken().orElse(null); + if (apiKey == null && adToken == null) { + throw new ConfigValidationException(createKeyMisconfigurationProblem(modelName)); } var builder = AzureOpenAiEmbeddingModel.builder() .endpoint(getEndpoint(azureAiConfig, modelName)) .apiKey(apiKey) + .adToken(apiKey) .apiVersion(azureAiConfig.apiVersion()) .timeout(azureAiConfig.timeout()) .maxRetries(azureAiConfig.maxRetries()) @@ -162,16 +167,15 @@ public Supplier imageModel(LangChain4jAzureOpenAiConfig runtimeConfi LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, modelName); if (azureAiConfig.enableIntegration()) { - var apiKey = azureAiConfig.apiKey(); - - if (DUMMY_KEY.equals(apiKey)) { - throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); - } + var apiKey = azureAiConfig.apiKey().orElse(null); + String adToken = azureAiConfig.adToken().orElse(null); + throwIfApiKeysNotConfigured(apiKey, adToken, modelName); var imageModelConfig = azureAiConfig.imageModel(); var builder = AzureOpenAiImageModel.builder() .endpoint(getEndpoint(azureAiConfig, modelName)) .apiKey(apiKey) + .adToken(adToken) .apiVersion(azureAiConfig.apiVersion()) .timeout(azureAiConfig.timeout()) .maxRetries(azureAiConfig.maxRetries()) @@ -261,12 +265,20 @@ private LangChain4jAzureOpenAiConfig.AzureAiConfig correspondingAzureOpenAiConfi return azureAiConfig; } - private ConfigValidationException.Problem[] createApiKeyConfigProblem(String modelName) { - return createConfigProblems("api-key", modelName); + private void throwIfApiKeysNotConfigured(String apiKey, String adToken, String modelName) { + if ((apiKey != null) == (adToken != null)) { + throw new ConfigValidationException(createKeyMisconfigurationProblem(modelName)); + } } - private ConfigValidationException.Problem[] createConfigProblems(String key, String modelName) { - return new ConfigValidationException.Problem[] { createConfigProblem(key, modelName) }; + private ConfigValidationException.Problem[] createKeyMisconfigurationProblem(String modelName) { + return new ConfigValidationException.Problem[] { + new ConfigValidationException.Problem( + String.format( + "SRCFG00014: Exactly of the configuration properties must be present: quarkus.langchain4j.azure-openai%s%s or quarkus.langchain4j.azure-openai%s%s", + NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), "api-key", + NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), "ad-token")) + }; } private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) { diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/LangChain4jAzureOpenAiConfig.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/LangChain4jAzureOpenAiConfig.java index 8492db2fe..3e8143258 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/LangChain4jAzureOpenAiConfig.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/LangChain4jAzureOpenAiConfig.java @@ -68,6 +68,13 @@ interface AzureAiConfig { */ Optional endpoint(); + /** + * The Azure AD token to use for this operation. + * If present, then the requests towards OpenAI will include this in the Authorization header. + * Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.api-key}. + */ + Optional adToken(); + /** * The API version to use for this operation. This follows the YYYY-MM-DD format */ @@ -77,8 +84,7 @@ interface AzureAiConfig { /** * Azure OpenAI API key */ - @WithDefault("dummy") // TODO: this should be optional but Smallrye Config doesn't like it.. - String apiKey(); + Optional apiKey(); /** * Timeout for OpenAI calls diff --git a/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java b/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java index 1c8b4e15c..a93647e13 100644 --- a/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java +++ b/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java @@ -142,14 +142,19 @@ public Optional endpoint() { return Optional.empty(); } + @Override + public Optional adToken() { + return Optional.empty(); + } + @Override public String apiVersion() { return null; } @Override - public String apiKey() { - return "my-key"; + public Optional apiKey() { + return Optional.of("my-key"); } @Override diff --git a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java index 2f3a99199..f8c063360 100644 --- a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java +++ b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java @@ -438,9 +438,20 @@ class ApiMetadata { @HeaderParam("OpenAI-Organization") public final String organizationId; - private ApiMetadata(String openaiApiKey, String azureApiKey, - String apiVersion, String organizationId) { - this.authorization = (openaiApiKey != null) ? "Bearer " + openaiApiKey : null; + private ApiMetadata( + String openaiApiKey, + String azureApiKey, + String azureAdToken, + String apiVersion, + String organizationId) { + if (azureAdToken != null) { + this.authorization = "Bearer " + azureAdToken; + } else if (openaiApiKey != null) { + this.authorization = "Bearer " + openaiApiKey; + } else { + this.authorization = null; + } + this.azureApiKey = azureApiKey; this.apiVersion = apiVersion; this.organizationId = organizationId; @@ -452,20 +463,26 @@ public static ApiMetadata.Builder builder() { public static class Builder { private String azureApiKey; + private String azureAdToken; private String openAiApiKey; private String apiVersion; private String organizationId; public ApiMetadata build() { - if ((azureApiKey != null) && (openAiApiKey != null)) { - return new ApiMetadata(openAiApiKey, azureApiKey, apiVersion, organizationId); + if (azureAdToken != null) { + return new ApiMetadata(null, null, azureAdToken, apiVersion, organizationId); } else if (azureApiKey != null) { - return new ApiMetadata(null, azureApiKey, apiVersion, organizationId); + return new ApiMetadata(null, azureApiKey, null, apiVersion, organizationId); } else if (openAiApiKey != null) { - return new ApiMetadata(openAiApiKey, null, apiVersion, organizationId); + return new ApiMetadata(openAiApiKey, null, null, apiVersion, organizationId); } - return new ApiMetadata(null, null, apiVersion, organizationId); + return new ApiMetadata(null, null, null, apiVersion, organizationId); + } + + public ApiMetadata.Builder azureAdToken(String azureAdToken) { + this.azureAdToken = azureAdToken; + return this; } public ApiMetadata.Builder azureApiKey(String azureApiKey) { diff --git a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java index ff75a9c97..e87b46335 100644 --- a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java +++ b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java @@ -52,6 +52,7 @@ public class QuarkusOpenAiClient extends OpenAiClient { private final String azureApiKey; + private final String azureAdToken; private final String openaiApiKey; private final String apiVersion; private final String organizationId; @@ -77,6 +78,7 @@ private QuarkusOpenAiClient(Builder builder) { this.openaiApiKey = builder.openAiApiKey; this.apiVersion = builder.apiVersion; this.organizationId = builder.organizationId; + this.azureAdToken = builder.azureAdToken; // cache the client the builder could be called with the same parameters from multiple models this.restApi = cache.compute(builder, new BiFunction() { @Override @@ -129,6 +131,7 @@ public CompletionResponse execute() { CompletionRequest.builder().from(request).stream(null).build(), OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -187,6 +190,7 @@ public ChatCompletionResponse execute() { ChatCompletionRequest.builder().from(request).stream(null).build(), OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -244,6 +248,7 @@ public String execute() { .blockingChatCompletion(request, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -315,6 +320,7 @@ public EmbeddingResponse execute() { return restApi.blockingEmbedding(request, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -352,6 +358,7 @@ public List execute() { return restApi.blockingEmbedding(request, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -388,6 +395,7 @@ public ModerationResponse execute() { return restApi.blockingModeration(request, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -426,6 +434,7 @@ public ModerationResult execute() { return restApi.blockingModeration(request, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -462,6 +471,7 @@ public GenerateImagesResponse execute() { return restApi.blockingImagesGenerations(generateImagesRequest, OpenAiRestApi.ApiMetadata.builder() .azureApiKey(azureApiKey) + .azureAdToken(azureAdToken) .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) @@ -504,12 +514,30 @@ public Builder get() { public static class Builder extends OpenAiClient.Builder { private String userAgent; + private String azureAdToken; public Builder userAgent(String userAgent) { this.userAgent = userAgent; return this; } + public Builder azureAdToken(String azureAdToken) { + this.azureAdToken = azureAdToken; + return this; + } + + @Override + public Builder openAiApiKey(String openAiApiKey) { + this.openAiApiKey = openAiApiKey; + return this; + } + + @Override + public Builder azureApiKey(String azureApiKey) { + this.azureApiKey = azureApiKey; + return this; + } + @Override public QuarkusOpenAiClient build() { return new QuarkusOpenAiClient(this); @@ -535,6 +563,7 @@ public boolean equals(Object o) { && Objects.equals(readTimeout, builder.readTimeout) && Objects.equals(writeTimeout, builder.writeTimeout) && Objects.equals(proxy, builder.proxy) + && Objects.equals(azureAdToken, builder.azureAdToken) && Objects.equals(userAgent, builder.userAgent); } @@ -542,7 +571,7 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(baseUrl, apiVersion, openAiApiKey, azureApiKey, organizationId, callTimeout, connectTimeout, readTimeout, - writeTimeout, proxy, logRequests, logResponses, logStreamingResponses, userAgent); + writeTimeout, proxy, logRequests, logResponses, logStreamingResponses, userAgent, azureAdToken); } }