Skip to content

Commit

Permalink
Merge pull request #467 from csotiriou/feature/azure-ad-token-support
Browse files Browse the repository at this point in the history
Add support for Azure AD Tokens
  • Loading branch information
geoand authored Apr 23, 2024
2 parents 0611146 + c33d8c3 commit 98bb594
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 41 deletions.
9 changes: 8 additions & 1 deletion docs/modules/ROOT/pages/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,18 @@ When this extension is used, the following configuration properties are required

[source, properties]
----
quarkus.langchain4j.azure-openai.api-key=
quarkus.langchain4j.azure-openai.resource-name=
quarkus.langchain4j.azure-openai.deployment-name=
# And one of the below depending on your scenario
quarkus.langchain4j.azure-openai.api-key=
quarkus.langchain4j.azure-openai.ad-token=
----

In the case of Azure, the `api-key` and `ad-token` properties are mutually exclusive. The `api-key` property should be used when the Azure OpenAI service is configured to use API keys, while the `ad-token` property should be used when the Azure OpenAI service is configured to use Azure Active Directory tokens.

In both cases, the key will be placed in the Authorization header when communicating with the Azure OpenAI service.

== Advanced usage

`quarkus-langchain4j-openai` and `quarkus-langchain4j-azure-openai` extensions use a REST Client under the hood to make the REST calls required by LangChain4j.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public class AzureOpenAiChatModel implements ChatLanguageModel, TokenCountEstima
public AzureOpenAiChatModel(String endpoint,
String apiVersion,
String apiKey,
String adToken,
Tokenizer tokenizer,
Double temperature,
Double topP,
Expand All @@ -74,7 +75,6 @@ public AzureOpenAiChatModel(String endpoint,

this.client = ((QuarkusOpenAiClient.Builder) OpenAiClient.builder()
.baseUrl(ensureNotBlank(endpoint, "endpoint"))
.azureApiKey(apiKey)
.apiVersion(apiVersion)
.callTimeout(timeout)
.connectTimeout(timeout)
Expand All @@ -84,6 +84,8 @@ public AzureOpenAiChatModel(String endpoint,
.logRequests(logRequests)
.logResponses(logResponses))
.userAgent(Consts.DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.build();

this.temperature = getOrDefault(temperature, 0.7);
Expand Down Expand Up @@ -154,6 +156,7 @@ public static class Builder {
private String endpoint;
private String apiVersion;
private String apiKey;
private String adToken;
private Tokenizer tokenizer;
private Double temperature;
private Double topP;
Expand Down Expand Up @@ -201,6 +204,11 @@ public Builder apiKey(String apiKey) {
return this;
}

public Builder adToken(String adToken) {
this.adToken = adToken;
return this;
}

public Builder tokenizer(Tokenizer tokenizer) {
this.tokenizer = tokenizer;
return this;
Expand Down Expand Up @@ -265,6 +273,7 @@ public AzureOpenAiChatModel build() {
return new AzureOpenAiChatModel(endpoint,
apiVersion,
apiKey,
adToken,
tokenizer,
temperature,
topP,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class AzureOpenAiEmbeddingModel implements EmbeddingModel, TokenCountEsti
public AzureOpenAiEmbeddingModel(String endpoint,
String apiVersion,
String apiKey,
String adToken,
Tokenizer tokenizer,
Duration timeout,
Integer maxRetries,
Expand All @@ -63,7 +64,6 @@ public AzureOpenAiEmbeddingModel(String endpoint,

this.client = ((QuarkusOpenAiClient.Builder) OpenAiClient.builder()
.baseUrl(ensureNotBlank(endpoint, "endpoint"))
.azureApiKey(apiKey)
.apiVersion(apiVersion)
.callTimeout(timeout)
.connectTimeout(timeout)
Expand All @@ -73,6 +73,8 @@ public AzureOpenAiEmbeddingModel(String endpoint,
.logRequests(logRequests)
.logResponses(logResponses))
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.build();
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = tokenizer;
Expand Down Expand Up @@ -143,6 +145,7 @@ public static class Builder {
private Proxy proxy;
private Boolean logRequests;
private Boolean logResponses;
private String adToken;

/**
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
Expand Down Expand Up @@ -178,6 +181,11 @@ public Builder apiKey(String apiKey) {
return this;
}

public Builder adToken(String adToken) {
this.adToken = adToken;
return this;
}

public Builder tokenizer(Tokenizer tokenizer) {
this.tokenizer = tokenizer;
return this;
Expand Down Expand Up @@ -212,6 +220,7 @@ public AzureOpenAiEmbeddingModel build() {
return new AzureOpenAiEmbeddingModel(endpoint,
apiVersion,
apiKey,
adToken,
tokenizer,
timeout,
maxRetries,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ public class AzureOpenAiImageModel implements ImageModel {

private final QuarkusOpenAiClient client;

public AzureOpenAiImageModel(String endpoint, String apiKey, String apiVersion, String modelName, String size,
public AzureOpenAiImageModel(String endpoint, String apiKey, String adToken, String apiVersion, String modelName,
String size,
String quality, String style, Optional<String> user, String responseFormat, Duration timeout,
Integer maxRetries, Boolean logRequests, Boolean logResponses,
Optional<Path> persistDirectory) {
Expand All @@ -58,7 +59,6 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String apiVersion,

this.client = QuarkusOpenAiClient.builder()
.baseUrl(ensureNotBlank(endpoint, "endpoint"))
.azureApiKey(apiKey)
.apiVersion(apiVersion)
.callTimeout(timeout)
.connectTimeout(timeout)
Expand All @@ -67,6 +67,8 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String apiVersion,
.logRequests(logRequests)
.logResponses(logResponses)
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.build();
}

Expand Down Expand Up @@ -144,6 +146,7 @@ public static Builder builder() {
public static class Builder {
private String endpoint;
private String apiKey;
private String adToken;
private String apiVersion;
private String modelName;
private String size;
Expand All @@ -167,6 +170,11 @@ public Builder apiKey(String apiKey) {
return this;
}

public Builder adToken(String adToken) {
this.adToken = adToken;
return this;
}

public Builder apiVersion(String apiVersion) {
this.apiVersion = apiVersion;
return this;
Expand Down Expand Up @@ -228,7 +236,7 @@ public Builder persistDirectory(Optional<Path> persistDirectory) {
}

public AzureOpenAiImageModel build() {
return new AzureOpenAiImageModel(endpoint, apiKey, apiVersion, modelName, size, quality, style, user,
return new AzureOpenAiImageModel(endpoint, apiKey, adToken, apiVersion, modelName, size, quality, style, user,
responseFormat, timeout, maxRetries, logRequests, logResponses,
persistDirectory);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
public AzureOpenAiStreamingChatModel(String endpoint,
String apiVersion,
String apiKey,
String adToken,
Tokenizer tokenizer,
Double temperature,
Double topP,
Expand All @@ -79,7 +80,6 @@ public AzureOpenAiStreamingChatModel(String endpoint,

this.client = ((QuarkusOpenAiClient.Builder) OpenAiClient.builder()
.baseUrl(ensureNotBlank(endpoint, "endpoint"))
.azureApiKey(apiKey)
.apiVersion(apiVersion)
.callTimeout(timeout)
.connectTimeout(timeout)
Expand All @@ -89,6 +89,8 @@ public AzureOpenAiStreamingChatModel(String endpoint,
.logRequests(logRequests)
.logStreamingResponses(logResponses))
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.build();
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
Expand Down Expand Up @@ -189,6 +191,7 @@ public static class Builder {
private String endpoint;
private String apiVersion;
private String apiKey;
private String adToken;
private Tokenizer tokenizer;
private Double temperature;
private Double topP;
Expand Down Expand Up @@ -235,6 +238,11 @@ public Builder apiKey(String apiKey) {
return this;
}

public Builder adToken(String adToken) {
this.adToken = adToken;
return this;
}

public Builder tokenizer(Tokenizer tokenizer) {
this.tokenizer = tokenizer;
return this;
Expand Down Expand Up @@ -294,6 +302,7 @@ public AzureOpenAiStreamingChatModel build() {
return new AzureOpenAiStreamingChatModel(endpoint,
apiVersion,
apiKey,
adToken,
tokenizer,
temperature,
topP,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
@Recorder
public class AzureOpenAiRecorder {

private static final String DUMMY_KEY = "dummy";
static final String AZURE_ENDPOINT_URL_PATTERN = "https://%s.openai.azure.com/openai/deployments/%s";
public static final Problem[] EMPTY_PROBLEMS = new Problem[0];

Expand All @@ -43,13 +42,15 @@ public Supplier<ChatLanguageModel> chatModel(LangChain4jAzureOpenAiConfig runtim

if (azureAiConfig.enableIntegration()) {
ChatModelConfig chatModelConfig = azureAiConfig.chatModel();
String apiKey = azureAiConfig.apiKey();
if (DUMMY_KEY.equals(apiKey)) {
throw new ConfigValidationException(createApiKeyConfigProblem(modelName));
}
String apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);

throwIfApiKeysNotConfigured(apiKey, adToken, modelName);

var builder = AzureOpenAiChatModel.builder()
.endpoint(getEndpoint(azureAiConfig, modelName))
.apiKey(apiKey)
.adToken(adToken)
.apiVersion(azureAiConfig.apiVersion())
.timeout(azureAiConfig.timeout())
.maxRetries(azureAiConfig.maxRetries())
Expand Down Expand Up @@ -87,13 +88,15 @@ public Supplier<StreamingChatLanguageModel> streamingChatModel(LangChain4jAzureO

if (azureAiConfig.enableIntegration()) {
ChatModelConfig chatModelConfig = azureAiConfig.chatModel();
String apiKey = azureAiConfig.apiKey();
if (DUMMY_KEY.equals(apiKey)) {
throw new ConfigValidationException(createApiKeyConfigProblem(modelName));
}
String apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);

throwIfApiKeysNotConfigured(apiKey, adToken, modelName);

var builder = AzureOpenAiStreamingChatModel.builder()
.endpoint(getEndpoint(azureAiConfig, modelName))
.apiKey(apiKey)
.adToken(adToken)
.apiVersion(azureAiConfig.apiVersion())
.timeout(azureAiConfig.timeout())
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), azureAiConfig.logRequests()))
Expand Down Expand Up @@ -129,13 +132,15 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jAzureOpenAiConfig runt

if (azureAiConfig.enableIntegration()) {
EmbeddingModelConfig embeddingModelConfig = azureAiConfig.embeddingModel();
String apiKey = azureAiConfig.apiKey();
if (DUMMY_KEY.equals(apiKey)) {
throw new ConfigValidationException(createApiKeyConfigProblem(modelName));
String apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);
if (apiKey == null && adToken == null) {
throw new ConfigValidationException(createKeyMisconfigurationProblem(modelName));
}
var builder = AzureOpenAiEmbeddingModel.builder()
.endpoint(getEndpoint(azureAiConfig, modelName))
.apiKey(apiKey)
.adToken(apiKey)
.apiVersion(azureAiConfig.apiVersion())
.timeout(azureAiConfig.timeout())
.maxRetries(azureAiConfig.maxRetries())
Expand All @@ -162,16 +167,15 @@ public Supplier<ImageModel> imageModel(LangChain4jAzureOpenAiConfig runtimeConfi
LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, modelName);

if (azureAiConfig.enableIntegration()) {
var apiKey = azureAiConfig.apiKey();

if (DUMMY_KEY.equals(apiKey)) {
throw new ConfigValidationException(createApiKeyConfigProblem(modelName));
}
var apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);
throwIfApiKeysNotConfigured(apiKey, adToken, modelName);

var imageModelConfig = azureAiConfig.imageModel();
var builder = AzureOpenAiImageModel.builder()
.endpoint(getEndpoint(azureAiConfig, modelName))
.apiKey(apiKey)
.adToken(adToken)
.apiVersion(azureAiConfig.apiVersion())
.timeout(azureAiConfig.timeout())
.maxRetries(azureAiConfig.maxRetries())
Expand Down Expand Up @@ -261,12 +265,20 @@ private LangChain4jAzureOpenAiConfig.AzureAiConfig correspondingAzureOpenAiConfi
return azureAiConfig;
}

private ConfigValidationException.Problem[] createApiKeyConfigProblem(String modelName) {
return createConfigProblems("api-key", modelName);
private void throwIfApiKeysNotConfigured(String apiKey, String adToken, String modelName) {
if ((apiKey != null) == (adToken != null)) {
throw new ConfigValidationException(createKeyMisconfigurationProblem(modelName));
}
}

private ConfigValidationException.Problem[] createConfigProblems(String key, String modelName) {
return new ConfigValidationException.Problem[] { createConfigProblem(key, modelName) };
private ConfigValidationException.Problem[] createKeyMisconfigurationProblem(String modelName) {
return new ConfigValidationException.Problem[] {
new ConfigValidationException.Problem(
String.format(
"SRCFG00014: Exactly of the configuration properties must be present: quarkus.langchain4j.azure-openai%s%s or quarkus.langchain4j.azure-openai%s%s",
NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), "api-key",
NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), "ad-token"))
};
}

private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ interface AzureAiConfig {
*/
Optional<String> endpoint();

/**
* The Azure AD token to use for this operation.
* If present, then the requests towards OpenAI will include this in the Authorization header.
* Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.api-key}.
*/
Optional<String> adToken();

/**
* The API version to use for this operation. This follows the YYYY-MM-DD format
*/
Expand All @@ -77,8 +84,7 @@ interface AzureAiConfig {
/**
* Azure OpenAI API key
*/
@WithDefault("dummy") // TODO: this should be optional but Smallrye Config doesn't like it..
String apiKey();
Optional<String> apiKey();

/**
* Timeout for OpenAI calls
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,19 @@ public Optional<String> endpoint() {
return Optional.empty();
}

@Override
public Optional<String> adToken() {
return Optional.empty();
}

@Override
public String apiVersion() {
return null;
}

@Override
public String apiKey() {
return "my-key";
public Optional<String> apiKey() {
return Optional.of("my-key");
}

@Override
Expand Down
Loading

0 comments on commit 98bb594

Please sign in to comment.