From 03099f785e72d94b8aa7124d0d1bf0f3c4b39a3d Mon Sep 17 00:00:00 2001 From: Andrea Di Maio Date: Thu, 10 Oct 2024 18:54:09 +0200 Subject: [PATCH 1/4] [watsonx.ai] Add time_limit parameter --- .../watsonx/deployment/AllPropertiesTest.java | 1 + .../watsonx/deployment/DefaultPropertiesTest.java | 1 + .../langchain4j/watsonx/WatsonxChatModel.java | 1 + .../langchain4j/watsonx/WatsonxModel.java | 2 ++ .../watsonx/WatsonxStreamingChatModel.java | 1 + .../langchain4j/watsonx/bean/Parameters.java | 12 ++++++++++++ 6 files changed, 18 insertions(+) diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java index 9cba911d0..9179008c9 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java @@ -98,6 +98,7 @@ void handlerBeforeEach() { .randomSeed(2) .stopSequences(List.of("\n", "\n\n")) .temperature(1.5) + .timeLimit(60000L) .topK(90) .topP(0.5) .repetitionPenalty(2.0) diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java index da76262dd..0722e6ec5 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java @@ -58,6 +58,7 @@ void handlerBeforeEach() { .maxNewTokens(200) .decodingMethod("greedy") .temperature(1.0) + .timeLimit(10000L) .build(); @Inject diff --git 
a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java index 6a5de36e9..57e9d10f0 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java @@ -108,6 +108,7 @@ private Parameters createParameters() { .randomSeed(randomSeed) .stopSequences(stopSequences) .temperature(temperature) + .timeLimit(timeLimit) .topP(topP) .topK(topK) .repetitionPenalty(repetitionPenalty) diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java index e9d65c4db..66b39f5d3 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java @@ -37,6 +37,7 @@ public abstract class WatsonxModel { final Integer randomSeed; final List stopSequences; final Double temperature; + final Long timeLimit; final Double topP; final Integer topK; final Double repetitionPenalty; @@ -72,6 +73,7 @@ public WatsonxModel(Builder builder) { this.randomSeed = builder.randomSeed; this.stopSequences = builder.stopSequences; this.temperature = builder.temperature; + this.timeLimit = builder.timeout.toMillis(); this.topP = builder.topP; this.topK = builder.topK; this.repetitionPenalty = builder.repetitionPenalty; diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxStreamingChatModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxStreamingChatModel.java index 5db6cf940..8ab87a703 100644 --- 
a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxStreamingChatModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxStreamingChatModel.java @@ -43,6 +43,7 @@ public void generate(List messages, StreamingResponseHandler stopSequences; private final Double temperature; + private final Long timeLimit; private final Integer topK; private final Double topP; private final Double repetitionPenalty; @@ -28,6 +29,7 @@ private Parameters(Builder builder) { this.randomSeed = builder.randomSeed; this.stopSequences = builder.stopSequences; this.temperature = builder.temperature; + this.timeLimit = builder.timeLimit; this.topK = builder.topK; this.topP = builder.topP; this.repetitionPenalty = builder.repetitionPenalty; @@ -59,6 +61,10 @@ public Double getTemperature() { return temperature; } + public Long getTimeLimit() { + return timeLimit; + } + public Integer getRandomSeed() { return randomSeed; } @@ -96,6 +102,7 @@ public static class Builder { private Integer randomSeed; private List stopSequences; private Double temperature; + private Long timeLimit; private Integer topK; private Double topP; private Double repetitionPenalty; @@ -127,6 +134,11 @@ public Builder temperature(Double temperature) { return this; } + public Builder timeLimit(Long timeLimit) { + this.timeLimit = timeLimit; + return this; + } + public Builder randomSeed(Integer randomSeed) { this.randomSeed = randomSeed; return this; From f87f11221c08acce25d13de8678f7d738975244e Mon Sep 17 00:00:00 2001 From: Andrea Di Maio Date: Thu, 10 Oct 2024 22:01:01 +0200 Subject: [PATCH 2/4] [watsonx.ai] Refactor WatsonxChatModel/WatsonxStreamingChatModel to WatsonxGenerationModel --- .../multiple/MultipleChatProvidersTest.java | 4 +- ...tipleTokenCountEstimatorProvidersTest.java | 6 +- .../watsonx/deployment/WatsonxProcessor.java | 15 +- .../watsonx/deployment/AiChatServiceTest.java | 1 + 
.../watsonx/deployment/AllPropertiesTest.java | 14 - .../deployment/ChatMemoryPlaceholderTest.java | 84 ++-- .../deployment/DefaultPropertiesTest.java | 9 - .../PromptFormatterForceDefaultTest.java | 94 ---- .../deployment/PromptFormatterTest.java | 1 + .../deployment/ResponseSchemaOnTest.java | 1 + .../langchain4j/watsonx/WatsonxChatModel.java | 121 ------ .../watsonx/WatsonxEmbeddingModel.java | 96 ++++- .../watsonx/WatsonxGenerationModel.java | 401 ++++++++++++++++++ .../langchain4j/watsonx/WatsonxModel.java | 302 ------------- .../watsonx/WatsonxStreamingChatModel.java | 129 ------ ...erator.java => WatsonxTokenGenerator.java} | 4 +- .../langchain4j/watsonx/WatsonxUtils.java | 45 ++ .../filter/BearerTokenHeaderFactory.java | 6 +- .../watsonx/runtime/WatsonxRecorder.java | 34 +- .../DisabledModelsWatsonRecorderTest.java | 5 +- 20 files changed, 611 insertions(+), 761 deletions(-) delete mode 100644 model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterForceDefaultTest.java delete mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java create mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java delete mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java delete mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxStreamingChatModel.java rename model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/{TokenGenerator.java => WatsonxTokenGenerator.java} (94%) create mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxUtils.java diff --git a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java 
b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java index 8dd45a3db..fd0352868 100644 --- a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java @@ -13,7 +13,7 @@ import io.quarkiverse.langchain4j.huggingface.QuarkusHuggingFaceChatModel; import io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel; import io.quarkiverse.langchain4j.openshiftai.OpenshiftAiChatModel; -import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; +import io.quarkiverse.langchain4j.watsonx.WatsonxGenerationModel; import io.quarkus.arc.ClientProxy; import io.quarkus.test.junit.QuarkusTest; @@ -79,6 +79,6 @@ void sixthNamedModel() { @Test void seventhNamedModel() { - assertThat(ClientProxy.unwrap(seventhNamedModel)).isInstanceOf(WatsonxChatModel.class); + assertThat(ClientProxy.unwrap(seventhNamedModel)).isInstanceOf(WatsonxGenerationModel.class); } } diff --git a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleTokenCountEstimatorProvidersTest.java b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleTokenCountEstimatorProvidersTest.java index 444fdb5bf..bcb127713 100644 --- a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleTokenCountEstimatorProvidersTest.java +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleTokenCountEstimatorProvidersTest.java @@ -10,7 +10,7 @@ import dev.langchain4j.model.chat.TokenCountEstimator; import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiChatModel; -import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; +import io.quarkiverse.langchain4j.watsonx.WatsonxGenerationModel; import io.quarkus.arc.ClientProxy; import 
io.quarkus.test.junit.QuarkusTest; @@ -41,7 +41,7 @@ void azureOpenAiTest() { @Test void watsonxTest() { - assertThat(ClientProxy.unwrap(watsonxChat)).isInstanceOf(WatsonxChatModel.class); - assertThat(ClientProxy.unwrap(watsonxTokenizer)).isInstanceOf(WatsonxChatModel.class); + assertThat(ClientProxy.unwrap(watsonxChat)).isInstanceOf(WatsonxGenerationModel.class); + assertThat(ClientProxy.unwrap(watsonxTokenizer)).isInstanceOf(WatsonxGenerationModel.class); } } diff --git a/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java index 10f2fb2d7..f29726932 100644 --- a/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java +++ b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java @@ -154,13 +154,17 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon String configName = selected.getConfigName(); PromptFormatter promptFormatter = selected.getPromptFormatter(); - var chatModel = recorder.chatModel(runtimeConfig, fixedRuntimeConfig, configName, promptFormatter); + var chatLanguageModel = recorder.generationModel(runtimeConfig, fixedRuntimeConfig, configName, promptFormatter); + var streamingChatLanguageModel = recorder.generationStreamingModel(runtimeConfig, fixedRuntimeConfig, configName, + promptFormatter); + var chatBuilder = SyntheticBeanBuildItem .configure(CHAT_MODEL) .setRuntimeInit() .defaultBean() .scope(ApplicationScoped.class) - .supplier(chatModel); + .supplier(chatLanguageModel); + addQualifierIfNecessary(chatBuilder, configName); beanProducer.produce(chatBuilder.done()); @@ -169,7 +173,8 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon .setRuntimeInit() .defaultBean() 
.scope(ApplicationScoped.class) - .supplier(chatModel); + .supplier(chatLanguageModel); + addQualifierIfNecessary(tokenizerBuilder, configName); beanProducer.produce(tokenizerBuilder.done()); @@ -178,8 +183,8 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon .setRuntimeInit() .defaultBean() .scope(ApplicationScoped.class) - .supplier(recorder.streamingChatModel(runtimeConfig, fixedRuntimeConfig, configName, - promptFormatter)); + .supplier(streamingChatLanguageModel); + addQualifierIfNecessary(streamingBuilder, configName); beanProducer.produce(streamingBuilder.done()); } diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java index 998d5a0a0..dc596d716 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java @@ -60,6 +60,7 @@ void chat() throws Exception { .temperature(chatModelConfig.temperature()) .minNewTokens(chatModelConfig.minNewTokens()) .maxNewTokens(chatModelConfig.maxNewTokens()) + .timeLimit(10000L) .build(); TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, input, parameters); diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java index 9179008c9..87c867c5f 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java @@ -4,7 
+4,6 @@ import static org.awaitility.Awaitility.await; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; import java.time.Duration; import java.util.Date; @@ -26,15 +25,11 @@ import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.output.Response; -import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; -import io.quarkiverse.langchain4j.watsonx.WatsonxStreamingChatModel; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest; import io.quarkiverse.langchain4j.watsonx.bean.Parameters; import io.quarkiverse.langchain4j.watsonx.bean.Parameters.LengthPenalty; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; -import io.quarkiverse.langchain4j.watsonx.prompt.impl.NoopPromptFormatter; -import io.quarkus.arc.ClientProxy; import io.quarkus.test.QuarkusUnitTest; public class AllPropertiesTest extends WireMockAbstract { @@ -106,15 +101,6 @@ void handlerBeforeEach() { .includeStopSequence(false) .build(); - @Test - void prompt_formatter() { - var unwrapChatModel = (WatsonxChatModel) ClientProxy.unwrap(chatModel); - assertTrue(unwrapChatModel.getPromptFormatter() instanceof NoopPromptFormatter); - - var unwrapStreamingChatModel = (WatsonxStreamingChatModel) ClientProxy.unwrap(streamingChatModel); - assertTrue(unwrapStreamingChatModel.getPromptFormatter() instanceof NoopPromptFormatter); - } - @Test void check_config() throws Exception { var runtimeConfig = langchain4jWatsonConfig.defaultConfig(); diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatMemoryPlaceholderTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatMemoryPlaceholderTest.java index 
9f143359f..b30b36571 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatMemoryPlaceholderTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatMemoryPlaceholderTest.java @@ -97,26 +97,16 @@ public interface NoMemoryAiService { @Test void extract_dialogue_test() throws Exception { - LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig(); - ChatModelConfig chatModelConfig = watsonConfig.chatModel(); - String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String chatMemoryId = "userId"; - String projectId = watsonConfig.projectId(); - Parameters parameters = Parameters.builder() - .decodingMethod(chatModelConfig.decodingMethod()) - .temperature(chatModelConfig.temperature()) - .minNewTokens(chatModelConfig.minNewTokens()) - .maxNewTokens(chatModelConfig.maxNewTokens()) - .build(); var input = """ You are a helpful assistant Context: Hello"""; - var body = new TextGenerationRequest(modelId, projectId, input, parameters); + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) - .body(mapper.writeValueAsString(body)) + .body(mapper.writeValueAsString(createRequest(input))) .response(""" { "results": [ @@ -140,9 +130,9 @@ void extract_dialogue_test() throws Exception { Hello Hi! 
What is your name?"""; - body = new TextGenerationRequest(modelId, projectId, input, parameters); + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) - .body(mapper.writeValueAsString(body)) + .body(mapper.writeValueAsString(createRequest(input))) .response(""" { "results": [ @@ -162,26 +152,16 @@ void extract_dialogue_test() throws Exception { @Test void extract_dialogue_with_delimiter_test() throws Exception { - LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig(); - ChatModelConfig chatModelConfig = watsonConfig.chatModel(); - String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String chatMemoryId = "userId_with_delimiter"; - String projectId = watsonConfig.projectId(); - Parameters parameters = Parameters.builder() - .decodingMethod(chatModelConfig.decodingMethod()) - .temperature(chatModelConfig.temperature()) - .minNewTokens(chatModelConfig.minNewTokens()) - .maxNewTokens(chatModelConfig.maxNewTokens()) - .build(); var input = """ You are a helpful assistant Context: Hello"""; - var body = new TextGenerationRequest(modelId, projectId, input, parameters); + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) - .body(mapper.writeValueAsString(body)) + .body(mapper.writeValueAsString(createRequest(input))) .response(""" { "results": [ @@ -204,9 +184,9 @@ void extract_dialogue_with_delimiter_test() throws Exception { Hello Hi! 
What is your name?"""; - body = new TextGenerationRequest(modelId, projectId, input, parameters); + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) - .body(mapper.writeValueAsString(body)) + .body(mapper.writeValueAsString(createRequest(input))) .response(""" { "results": [ @@ -226,26 +206,16 @@ void extract_dialogue_with_delimiter_test() throws Exception { @Test void extract_dialogue_with_all_params_test() throws Exception { - LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig(); - ChatModelConfig chatModelConfig = watsonConfig.chatModel(); - String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String chatMemoryId = "userId_with_all_params"; - String projectId = watsonConfig.projectId(); - Parameters parameters = Parameters.builder() - .decodingMethod(chatModelConfig.decodingMethod()) - .temperature(chatModelConfig.temperature()) - .minNewTokens(chatModelConfig.minNewTokens()) - .maxNewTokens(chatModelConfig.maxNewTokens()) - .build(); var input = """ You are a helpful assistant Context: Hello"""; - var body = new TextGenerationRequest(modelId, projectId, input, parameters); + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) - .body(mapper.writeValueAsString(body)) + .body(mapper.writeValueAsString(createRequest(input))) .response(""" { "results": [ @@ -268,9 +238,9 @@ void extract_dialogue_with_all_params_test() throws Exception { Hello Hi! 
What is your name?"""; - body = new TextGenerationRequest(modelId, projectId, input, parameters); + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) - .body(mapper.writeValueAsString(body)) + .body(mapper.writeValueAsString(createRequest(input))) .response(""" { "results": [ @@ -290,17 +260,7 @@ void extract_dialogue_with_all_params_test() throws Exception { @Test void extract_dialogue_no_memory_test() throws Exception { - LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig(); - ChatModelConfig chatModelConfig = watsonConfig.chatModel(); - String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String chatMemoryId = "userId_with_all_params"; - String projectId = watsonConfig.projectId(); - Parameters parameters = Parameters.builder() - .decodingMethod(chatModelConfig.decodingMethod()) - .temperature(chatModelConfig.temperature()) - .minNewTokens(chatModelConfig.minNewTokens()) - .maxNewTokens(chatModelConfig.maxNewTokens()) - .build(); var input = """ Context: @@ -309,9 +269,9 @@ void extract_dialogue_no_memory_test() throws Exception { User: What is your name? 
Assistant: My name is AiBot Hello"""; - var body = new TextGenerationRequest(modelId, projectId, input, parameters); + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) - .body(mapper.writeValueAsString(body)) + .body(mapper.writeValueAsString(createRequest(input))) .response(""" { "results": [ @@ -327,4 +287,20 @@ void extract_dialogue_no_memory_test() throws Exception { noMemoryAiService.rephrase(chatMemoryStore.getMessages(chatMemoryId), "Hello"); } + + private TextGenerationRequest createRequest(String input) { + LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig(); + ChatModelConfig chatModelConfig = watsonConfig.chatModel(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); + String projectId = watsonConfig.projectId(); + Parameters parameters = Parameters.builder() + .decodingMethod(chatModelConfig.decodingMethod()) + .temperature(chatModelConfig.temperature()) + .minNewTokens(chatModelConfig.minNewTokens()) + .maxNewTokens(chatModelConfig.maxNewTokens()) + .timeLimit(10000L) + .build(); + + return new TextGenerationRequest(modelId, projectId, input, parameters); + } } diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java index 0722e6ec5..4638b2bfa 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java @@ -27,13 +27,10 @@ import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.output.Response; -import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; import 
io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest; import io.quarkiverse.langchain4j.watsonx.bean.Parameters; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; -import io.quarkiverse.langchain4j.watsonx.prompt.impl.NoopPromptFormatter; -import io.quarkus.arc.ClientProxy; import io.quarkus.test.QuarkusUnitTest; public class DefaultPropertiesTest extends WireMockAbstract { @@ -73,12 +70,6 @@ void handlerBeforeEach() { @Inject TokenCountEstimator tokenCountEstimator; - @Test - void prompt_formatter() { - var unwrapChatModel = (WatsonxChatModel) ClientProxy.unwrap(chatModel); - assertTrue(unwrapChatModel.getPromptFormatter() instanceof NoopPromptFormatter); - } - @Test void check_config() throws Exception { var runtimeConfig = langchain4jWatsonConfig.defaultConfig(); diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterForceDefaultTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterForceDefaultTest.java deleted file mode 100644 index 446d54168..000000000 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterForceDefaultTest.java +++ /dev/null @@ -1,94 +0,0 @@ -package io.quarkiverse.langchain4j.watsonx.deployment; - -import static org.junit.jupiter.api.Assertions.assertTrue; - -import jakarta.inject.Inject; -import jakarta.inject.Singleton; - -import org.jboss.shrinkwrap.api.ShrinkWrap; -import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; - -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.chat.StreamingChatLanguageModel; -import dev.langchain4j.service.SystemMessage; -import dev.langchain4j.service.UserMessage; -import 
io.quarkiverse.langchain4j.ModelName; -import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; -import io.quarkiverse.langchain4j.watsonx.WatsonxStreamingChatModel; -import io.quarkiverse.langchain4j.watsonx.prompt.impl.NoopPromptFormatter; -import io.quarkus.arc.ClientProxy; -import io.quarkus.test.QuarkusUnitTest; - -public class PromptFormatterForceDefaultTest { - - @RegisterExtension - static QuarkusUnitTest unitTest = new QuarkusUnitTest() - - .overrideRuntimeConfigKey("quarkus.langchain4j.model1.chat-model.provider", "watsonx") - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model1.base-url", WireMockUtil.URL_WATSONX_SERVER) - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model1.iam.base-url", WireMockUtil.URL_IAM_SERVER) - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model1.api-key", WireMockUtil.API_KEY) - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model1.project-id", WireMockUtil.PROJECT_ID) - .overrideConfigKey("quarkus.langchain4j.watsonx.model1.chat-model.prompt-formatter", "true") - .overrideRuntimeConfigKey("quarkus.langchain4j.model2.chat-model.provider", "watsonx") - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model2.base-url", WireMockUtil.URL_WATSONX_SERVER) - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model2.iam.base-url", WireMockUtil.URL_IAM_SERVER) - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model2.api-key", WireMockUtil.API_KEY) - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model2.project-id", WireMockUtil.PROJECT_ID) - .overrideConfigKey("quarkus.langchain4j.watsonx.model2.chat-model.prompt-formatter", "true") - .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) - .addClass(WireMockUtil.class)); - - @RegisterAiService(modelName = "model1") - @Singleton - interface AIServiceWithTokenInSystemMessage { - @SystemMessage("<|system|>This is a systemMessage") - @UserMessage("{text}") - String 
chat(String text); - } - - @RegisterAiService(modelName = "model2") - @Singleton - interface AIServiceWithTokenInUserMessage { - @SystemMessage("This is a systemMessage") - @UserMessage("<|system|>{text}") - String chat(String text); - } - - @Inject - @ModelName("model1") - ChatLanguageModel model1ChatModel; - - @Inject - @ModelName("model1") - StreamingChatLanguageModel model1StreamingChatModel; - - @Inject - @ModelName("model2") - ChatLanguageModel model2ChatModel; - - @Inject - @ModelName("model2") - StreamingChatLanguageModel model2StreamingChatModel; - - @Test - void prompt_formatter_model_1() { - var unwrapChatModel = (WatsonxChatModel) ClientProxy.unwrap(model1ChatModel); - assertTrue(unwrapChatModel.getPromptFormatter() instanceof NoopPromptFormatter); - - var unwrapStreamingChatModel = (WatsonxStreamingChatModel) ClientProxy.unwrap(model1StreamingChatModel); - assertTrue(unwrapStreamingChatModel.getPromptFormatter() instanceof NoopPromptFormatter); - } - - @Test - void prompt_formatter_model_2() { - var unwrapChatModel = (WatsonxChatModel) ClientProxy.unwrap(model2ChatModel); - assertTrue(unwrapChatModel.getPromptFormatter() instanceof NoopPromptFormatter); - - var unwrapStreamingChatModel = (WatsonxStreamingChatModel) ClientProxy.unwrap(model2StreamingChatModel); - assertTrue(unwrapStreamingChatModel.getPromptFormatter() instanceof NoopPromptFormatter); - } -} diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterTest.java index 99506d00d..dbdd4de9e 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterTest.java @@ -127,6 +127,7 @@ void tests() throws Exception { 
.temperature(chatModelConfig.temperature()) .minNewTokens(chatModelConfig.minNewTokens()) .maxNewTokens(chatModelConfig.maxNewTokens()) + .timeLimit(10000L) .build(); TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOnTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOnTest.java index 783976bd9..e0a890912 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOnTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOnTest.java @@ -266,6 +266,7 @@ private TextGenerationRequest from(List messages) { .temperature(config.chatModel().temperature()) .minNewTokens(config.chatModel().minNewTokens()) .maxNewTokens(config.chatModel().maxNewTokens()) + .timeLimit(10000L) .build(); var input = messages.stream() diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java deleted file mode 100644 index 57e9d10f0..000000000 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java +++ /dev/null @@ -1,121 +0,0 @@ -package io.quarkiverse.langchain4j.watsonx; - -import java.util.List; -import java.util.Objects; -import java.util.concurrent.Callable; - -import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.data.message.AiMessage; -import dev.langchain4j.data.message.ChatMessage; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.chat.TokenCountEstimator; -import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.output.TokenUsage; -import 
io.quarkiverse.langchain4j.watsonx.bean.Parameters; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters.LengthPenalty; -import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; -import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse; -import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse.Result; -import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; - -public class WatsonxChatModel extends WatsonxModel implements ChatLanguageModel, TokenCountEstimator { - - public WatsonxChatModel(WatsonxModel.Builder builder) { - super(builder); - } - - @Override - public Response generate(List messages) { - - Parameters parameters = createParameters(); - TextGenerationRequest request = new TextGenerationRequest(modelId, projectId, toInput(messages), parameters); - - Result result = retryOn(new Callable() { - @Override - public TextGenerationResponse call() throws Exception { - return client.chat(request, version); - } - }).results().get(0); - - var finishReason = toFinishReason(result.stopReason()); - var content = AiMessage.from(result.generatedText()); - var tokenUsage = new TokenUsage( - result.inputTokenCount(), - result.generatedTokenCount()); - - return Response.from(content, tokenUsage, finishReason); - } - - @Override - public Response generate(List messages, List toolSpecifications) { - var parameters = createParameters(); - var request = new TextGenerationRequest(modelId, projectId, toInput(messages, toolSpecifications), parameters); - - Result result = retryOn(new Callable() { - @Override - public TextGenerationResponse call() throws Exception { - return client.chat(request, version); - } - }).results().get(0); - - var finishReason = toFinishReason(result.stopReason()); - var tokenUsage = new TokenUsage( - result.inputTokenCount(), - result.generatedTokenCount()); - - AiMessage content; - - if (result.generatedText().startsWith(promptFormatter.toolExecution())) { - var tools = 
result.generatedText().replace(promptFormatter.toolExecution(), ""); - content = AiMessage.from(promptFormatter.toolExecutionRequestFormatter(tools)); - } else { - content = AiMessage.from(result.generatedText()); - } - - return Response.from(content, tokenUsage, finishReason); - } - - @Override - public Response generate(List messages, ToolSpecification toolSpecification) { - return generate(messages, List.of(toolSpecification)); - } - - @Override - public int estimateTokenCount(List messages) { - - var input = toInput(messages); - var request = new TokenizationRequest(modelId, input, projectId); - - return retryOn(new Callable() { - @Override - public Integer call() throws Exception { - return client.tokenization(request, version).result().tokenCount(); - } - }); - } - - private Parameters createParameters() { - LengthPenalty lengthPenalty = null; - if (Objects.nonNull(decayFactor) || Objects.nonNull(startIndex)) { - lengthPenalty = new LengthPenalty(decayFactor, startIndex); - } - - Parameters parameters = Parameters.builder() - .decodingMethod(decodingMethod) - .lengthPenalty(lengthPenalty) - .minNewTokens(minNewTokens) - .maxNewTokens(maxNewTokens) - .randomSeed(randomSeed) - .stopSequences(stopSequences) - .temperature(temperature) - .timeLimit(timeLimit) - .topP(topP) - .topK(topK) - .repetitionPenalty(repetitionPenalty) - .truncateInputTokens(truncateInputTokens) - .includeStopSequence(includeStopSequence) - .build(); - - return parameters; - } -} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java index fa4d37dae..5de1d1287 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java @@ -1,10 +1,17 @@ package 
io.quarkiverse.langchain4j.watsonx; +import static io.quarkiverse.langchain4j.watsonx.WatsonxUtils.retryOn; + +import java.net.URL; +import java.time.Duration; import java.util.List; import java.util.Objects; import java.util.concurrent.Callable; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import org.jboss.resteasy.reactive.client.api.LoggingScope; + import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; @@ -14,11 +21,34 @@ import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingResponse; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingResponse.Result; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; +import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; +import io.quarkiverse.langchain4j.watsonx.client.filter.BearerTokenHeaderFactory; +import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; + +public class WatsonxEmbeddingModel implements EmbeddingModel, TokenCountEstimator { + + private final String modelId, projectId, version; + private final WatsonxRestApi client; + + public WatsonxEmbeddingModel(Builder builder) { -public class WatsonxEmbeddingModel extends WatsonxModel implements EmbeddingModel, TokenCountEstimator { + QuarkusRestClientBuilder restClientBuilder = QuarkusRestClientBuilder.newBuilder() + .baseUrl(builder.url) + .clientHeadersFactory(new BearerTokenHeaderFactory(builder.tokenGenerator)) + .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS) + .readTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS); - public WatsonxEmbeddingModel(Builder config) { - super(config); + if (builder.logRequests || builder.logResponses) { + restClientBuilder.loggingScope(LoggingScope.REQUEST_RESPONSE); + restClientBuilder.clientLogger(new WatsonxRestApi.WatsonClientLogger( + builder.logRequests, + builder.logResponses)); + } + + this.client = 
restClientBuilder.build(WatsonxRestApi.class); + this.modelId = builder.modelId; + this.projectId = builder.projectId; + this.version = builder.version; } @Override @@ -61,4 +91,64 @@ public Integer call() throws Exception { } }); } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private String modelId; + private String version; + private String projectId; + private Duration timeout; + private boolean logResponses; + private boolean logRequests; + private URL url; + private WatsonxTokenGenerator tokenGenerator; + + public Builder modelId(String modelId) { + this.modelId = modelId; + return this; + } + + public Builder version(String version) { + this.version = version; + return this; + } + + public Builder projectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder logRequests(boolean logRequests) { + this.logRequests = logRequests; + return this; + } + + public Builder logResponses(boolean logResponses) { + this.logResponses = logResponses; + return this; + } + + public Builder url(URL url) { + this.url = url; + return this; + } + + public Builder tokenGenerator(WatsonxTokenGenerator tokenGenerator) { + this.tokenGenerator = tokenGenerator; + return this; + } + + public WatsonxEmbeddingModel build() { + return new WatsonxEmbeddingModel(this); + } + } } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java new file mode 100644 index 000000000..1122085b5 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java @@ -0,0 +1,401 @@ +package io.quarkiverse.langchain4j.watsonx; + +import static 
io.quarkiverse.langchain4j.watsonx.WatsonxUtils.retryOn; + +import java.net.URL; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.Callable; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import org.jboss.logging.Logger; +import org.jboss.resteasy.reactive.client.api.LoggingScope; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.TokenCountEstimator; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import io.quarkiverse.langchain4j.watsonx.bean.Parameters; +import io.quarkiverse.langchain4j.watsonx.bean.Parameters.LengthPenalty; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse.Result; +import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; +import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; +import io.quarkiverse.langchain4j.watsonx.client.filter.BearerTokenHeaderFactory; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; +import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; +import io.smallrye.mutiny.Context; + +public class WatsonxGenerationModel implements ChatLanguageModel, StreamingChatLanguageModel, TokenCountEstimator { + + private static final Logger log = Logger.getLogger(WatsonxGenerationModel.class); + + private final String modelId, projectId, version; + private final WatsonxRestApi client; + private final Parameters 
parameters; + private final PromptFormatter promptFormatter; + + public WatsonxGenerationModel(Builder builder) { + + QuarkusRestClientBuilder restClientBuilder = QuarkusRestClientBuilder.newBuilder() + .baseUrl(builder.url) + .clientHeadersFactory(new BearerTokenHeaderFactory(builder.tokenGenerator)) + .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS) + .readTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS); + + if (builder.logRequests || builder.logResponses) { + restClientBuilder.loggingScope(LoggingScope.REQUEST_RESPONSE); + restClientBuilder.clientLogger(new WatsonxRestApi.WatsonClientLogger( + builder.logRequests, + builder.logResponses)); + } + + this.client = restClientBuilder.build(WatsonxRestApi.class); + this.modelId = builder.modelId; + this.projectId = builder.projectId; + this.version = builder.version; + + if (builder.promptFormatter != null) { + this.promptFormatter = builder.promptFormatter; + } else { + this.promptFormatter = null; + } + + LengthPenalty lengthPenalty = null; + if (Objects.nonNull(builder.decayFactor) || Objects.nonNull(builder.startIndex)) { + lengthPenalty = new LengthPenalty(builder.decayFactor, builder.startIndex); + } + + this.parameters = Parameters.builder() + .decodingMethod(builder.decodingMethod) + .lengthPenalty(lengthPenalty) + .minNewTokens(builder.minNewTokens) + .maxNewTokens(builder.maxNewTokens) + .randomSeed(builder.randomSeed) + .stopSequences(builder.stopSequences) + .temperature(builder.temperature) + .timeLimit(builder.timeout.toMillis()) + .topP(builder.topP) + .topK(builder.topK) + .repetitionPenalty(builder.repetitionPenalty) + .truncateInputTokens(builder.truncateInputTokens) + .includeStopSequence(builder.includeStopSequence) + .build(); + } + + @Override + public Response generate(List messages) { + TextGenerationRequest request = new TextGenerationRequest(modelId, projectId, toInput(messages), parameters); + + Result result = retryOn(new Callable() { + @Override + public 
TextGenerationResponse call() throws Exception { + return client.chat(request, version); + } + }).results().get(0); + + var finishReason = toFinishReason(result.stopReason()); + var content = AiMessage.from(result.generatedText()); + var tokenUsage = new TokenUsage( + result.inputTokenCount(), + result.generatedTokenCount()); + + return Response.from(content, tokenUsage, finishReason); + } + + @Override + public Response generate(List messages, List toolSpecifications) { + TextGenerationRequest request = new TextGenerationRequest(modelId, projectId, toInput(messages, toolSpecifications), + parameters); + + Result result = retryOn(new Callable() { + @Override + public TextGenerationResponse call() throws Exception { + return client.chat(request, version); + } + }).results().get(0); + + var finishReason = toFinishReason(result.stopReason()); + var tokenUsage = new TokenUsage( + result.inputTokenCount(), + result.generatedTokenCount()); + + AiMessage content; + + if (result.generatedText().startsWith(promptFormatter.toolExecution())) { + var tools = result.generatedText().replace(promptFormatter.toolExecution(), ""); + content = AiMessage.from(promptFormatter.toolExecutionRequestFormatter(tools)); + } else { + content = AiMessage.from(result.generatedText()); + } + + return Response.from(content, tokenUsage, finishReason); + } + + @Override + public Response generate(List messages, ToolSpecification toolSpecification) { + return generate(messages, List.of(toolSpecification)); + } + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + TextGenerationRequest request = new TextGenerationRequest(modelId, projectId, toInput(messages), parameters); + Context context = Context.of("response", new ArrayList()); + + client.chatStreaming(request, version) + .subscribe() + .with(context, + new Consumer() { + @Override + @SuppressWarnings("unchecked") + public void accept(TextGenerationResponse response) { + try { + + if (response == null || 
response.results() == null || response.results().isEmpty()) + return; + + ((List) context.get("response")).add(response); + handler.onNext(response.results().get(0).generatedText()); + + } catch (Exception e) { + handler.onError(e); + } + } + }, + new Consumer() { + @Override + public void accept(Throwable error) { + handler.onError(error); + } + }, + new Runnable() { + @Override + @SuppressWarnings("unchecked") + public void run() { + var list = ((List) context.get("response")); + + int inputTokenCount = 0; + int outputTokenCount = 0; + String stopReason = null; + StringBuilder builder = new StringBuilder(); + + for (int i = 0; i < list.size(); i++) { + + TextGenerationResponse.Result response = list.get(i).results().get(0); + + if (i == 0) + inputTokenCount = response.inputTokenCount(); + + if (i == list.size() - 1) { + outputTokenCount = response.generatedTokenCount(); + stopReason = response.stopReason(); + } + + builder.append(response.generatedText()); + } + + AiMessage message = new AiMessage(builder.toString()); + TokenUsage tokenUsage = new TokenUsage(inputTokenCount, outputTokenCount); + FinishReason finishReason = toFinishReason(stopReason); + handler.onComplete(Response.from(message, tokenUsage, finishReason)); + } + }); + } + + @Override + public int estimateTokenCount(List messages) { + + var input = toInput(messages); + var request = new TokenizationRequest(modelId, input, projectId); + + return retryOn(new Callable() { + @Override + public Integer call() throws Exception { + return client.tokenization(request, version).result().tokenCount(); + } + }); + } + + public static Builder builder() { + return new Builder(); + } + + private String toInput(List messages) { + var prompt = promptFormatter.format(messages, List.of()); + log.debugf(""" + Formatted prompt: + ----------------- + %s + -----------------""", prompt); + return prompt; + } + + private String toInput(List messages, List tools) { + var prompt = promptFormatter.format(messages, tools); + 
log.debugf(""" + Formatted prompt: + ----------------- + %s + -----------------""", prompt); + return prompt; + } + + private FinishReason toFinishReason(String stopReason) { + return switch (stopReason) { + case "max_tokens", "token_limit" -> FinishReason.LENGTH; + case "eos_token", "stop_sequence" -> FinishReason.STOP; + case "not_finished", "cancelled", "time_limit", "error" -> FinishReason.OTHER; + default -> throw new IllegalArgumentException("%s not supported".formatted(stopReason)); + }; + } + + public static final class Builder { + + private String modelId; + private String version; + private String projectId; + private Duration timeout; + private String decodingMethod; + private Double decayFactor; + private Integer startIndex; + private Integer maxNewTokens; + private Integer minNewTokens; + private Integer randomSeed; + private List stopSequences; + private Double temperature; + private Integer topK; + private Double topP; + private Double repetitionPenalty; + private Integer truncateInputTokens; + private Boolean includeStopSequence; + private URL url; + public boolean logResponses; + public boolean logRequests; + private WatsonxTokenGenerator tokenGenerator; + private PromptFormatter promptFormatter; + + public Builder modelId(String modelId) { + this.modelId = modelId; + return this; + } + + public Builder version(String version) { + this.version = version; + return this; + } + + public Builder projectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder url(URL url) { + this.url = url; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder decodingMethod(String decodingMethod) { + this.decodingMethod = decodingMethod; + return this; + } + + public Builder decayFactor(Double decayFactor) { + this.decayFactor = decayFactor; + return this; + } + + public Builder startIndex(Integer startIndex) { + this.startIndex = startIndex; + return this; + 
} + + public Builder minNewTokens(Integer minNewTokens) { + this.minNewTokens = minNewTokens; + return this; + } + + public Builder maxNewTokens(Integer maxNewTokens) { + this.maxNewTokens = maxNewTokens; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder topK(Integer topK) { + this.topK = topK; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public Builder randomSeed(Integer randomSeed) { + this.randomSeed = randomSeed; + return this; + } + + public Builder repetitionPenalty(Double repetitionPenalty) { + this.repetitionPenalty = repetitionPenalty; + return this; + } + + public Builder stopSequences(List stopSequences) { + this.stopSequences = stopSequences; + return this; + } + + public Builder truncateInputTokens(Integer truncateInputTokens) { + this.truncateInputTokens = truncateInputTokens; + return this; + } + + public Builder includeStopSequence(Boolean includeStopSequence) { + this.includeStopSequence = includeStopSequence; + return this; + } + + public Builder tokenGenerator(WatsonxTokenGenerator tokenGenerator) { + this.tokenGenerator = tokenGenerator; + return this; + } + + public Builder promptFormatter(PromptFormatter promptFormatter) { + this.promptFormatter = promptFormatter; + return this; + } + + public Builder logRequests(boolean logRequests) { + this.logRequests = logRequests; + return this; + } + + public Builder logResponses(boolean logResponses) { + this.logResponses = logResponses; + return this; + } + + public WatsonxGenerationModel build() { + return new WatsonxGenerationModel(this); + } + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java deleted file mode 100644 index 66b39f5d3..000000000 --- 
a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java +++ /dev/null @@ -1,302 +0,0 @@ -package io.quarkiverse.langchain4j.watsonx; - -import java.net.URL; -import java.time.Duration; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.Callable; -import java.util.concurrent.TimeUnit; - -import jakarta.ws.rs.WebApplicationException; - -import org.jboss.logging.Logger; -import org.jboss.resteasy.reactive.client.api.LoggingScope; - -import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.data.message.ChatMessage; -import dev.langchain4j.model.output.FinishReason; -import io.quarkiverse.langchain4j.watsonx.bean.WatsonxError; -import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; -import io.quarkiverse.langchain4j.watsonx.client.filter.BearerTokenHeaderFactory; -import io.quarkiverse.langchain4j.watsonx.exception.WatsonxException; -import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; -import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; - -public abstract class WatsonxModel { - - private static final Logger log = Logger.getLogger(WatsonxModel.class); - - final String modelId; - final String version; - final String projectId; - final String decodingMethod; - final Double decayFactor; - final Integer startIndex; - final Integer maxNewTokens; - final Integer minNewTokens; - final Integer randomSeed; - final List stopSequences; - final Double temperature; - final Long timeLimit; - final Double topP; - final Integer topK; - final Double repetitionPenalty; - final Integer truncateInputTokens; - final Boolean includeStopSequence; - final WatsonxRestApi client; - final PromptFormatter promptFormatter; - - public WatsonxModel(Builder builder) { - - QuarkusRestClientBuilder restClientBuilder = QuarkusRestClientBuilder.newBuilder() - .baseUrl(builder.url) - .clientHeadersFactory(new BearerTokenHeaderFactory(builder.tokenGenerator)) - 
.connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS) - .readTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS); - - if (builder.logRequests || builder.logResponses) { - restClientBuilder.loggingScope(LoggingScope.REQUEST_RESPONSE); - restClientBuilder.clientLogger(new WatsonxRestApi.WatsonClientLogger( - builder.logRequests, - builder.logResponses)); - } - - this.client = restClientBuilder.build(WatsonxRestApi.class); - this.modelId = builder.modelId; - this.version = builder.version; - this.projectId = builder.projectId; - this.decodingMethod = builder.decodingMethod; - this.decayFactor = builder.decayFactor; - this.startIndex = builder.startIndex; - this.maxNewTokens = builder.maxNewTokens; - this.minNewTokens = builder.minNewTokens; - this.randomSeed = builder.randomSeed; - this.stopSequences = builder.stopSequences; - this.temperature = builder.temperature; - this.timeLimit = builder.timeout.toMillis(); - this.topP = builder.topP; - this.topK = builder.topK; - this.repetitionPenalty = builder.repetitionPenalty; - this.truncateInputTokens = builder.truncateInputTokens; - this.includeStopSequence = builder.includeStopSequence; - - if (builder.promptFormatter != null) { - this.promptFormatter = builder.promptFormatter; - } else { - this.promptFormatter = null; - } - } - - public PromptFormatter getPromptFormatter() { - return promptFormatter; - } - - public static Builder builder() { - return new Builder(); - } - - protected String toInput(List messages) { - var prompt = promptFormatter.format(messages, List.of()); - log.debugf(""" - Formatted prompt: - ----------------- - %s - -----------------""", prompt); - return prompt; - } - - protected String toInput(List messages, List tools) { - var prompt = promptFormatter.format(messages, tools); - log.debugf(""" - Formatted prompt: - ----------------- - %s - -----------------""", prompt); - return prompt; - } - - protected FinishReason toFinishReason(String stopReason) { - return switch (stopReason) { - 
case "max_tokens" -> FinishReason.LENGTH; - case "eos_token", "stop_sequence" -> FinishReason.STOP; - default -> throw new IllegalArgumentException("%s not supported".formatted(stopReason)); - }; - } - - protected static T retryOn(Callable action) { - int maxAttempts = 1; - for (int i = 0; i <= maxAttempts; i++) { - - try { - - return action.call(); - - } catch (WatsonxException e) { - - if (e.details() == null || e.details().errors() == null || e.details().errors().size() == 0) - throw e; - - Optional optional = Optional.empty(); - for (WatsonxError.Error error : e.details().errors()) { - if (WatsonxError.Code.AUTHENTICATION_TOKEN_EXPIRED.equals(error.code())) { - optional = Optional.of(error.code()); - break; - } - } - - if (!optional.isPresent()) - throw e; - - } catch (WebApplicationException e) { - throw e; - } catch (Exception e) { - throw new RuntimeException(e); - } - } - throw new RuntimeException("Failed after " + maxAttempts + " attempts"); - } - - public static final class Builder { - - private String modelId; - private String version; - private String projectId; - private Duration timeout; - private String decodingMethod; - private Double decayFactor; - private Integer startIndex; - private Integer maxNewTokens; - private Integer minNewTokens; - private Integer randomSeed; - private List stopSequences; - private Double temperature; - private Integer topK; - private Double topP; - private Double repetitionPenalty; - private Integer truncateInputTokens; - private Boolean includeStopSequence; - private URL url; - public boolean logResponses; - public boolean logRequests; - private TokenGenerator tokenGenerator; - private PromptFormatter promptFormatter; - - public Builder modelId(String modelId) { - this.modelId = modelId; - return this; - } - - public Builder version(String version) { - this.version = version; - return this; - } - - public Builder projectId(String projectId) { - this.projectId = projectId; - return this; - } - - public Builder url(URL 
url) { - this.url = url; - return this; - } - - public Builder timeout(Duration timeout) { - this.timeout = timeout; - return this; - } - - public Builder decodingMethod(String decodingMethod) { - this.decodingMethod = decodingMethod; - return this; - } - - public Builder decayFactor(Double decayFactor) { - this.decayFactor = decayFactor; - return this; - } - - public Builder startIndex(Integer startIndex) { - this.startIndex = startIndex; - return this; - } - - public Builder minNewTokens(Integer minNewTokens) { - this.minNewTokens = minNewTokens; - return this; - } - - public Builder maxNewTokens(Integer maxNewTokens) { - this.maxNewTokens = maxNewTokens; - return this; - } - - public Builder temperature(Double temperature) { - this.temperature = temperature; - return this; - } - - public Builder topK(Integer topK) { - this.topK = topK; - return this; - } - - public Builder topP(Double topP) { - this.topP = topP; - return this; - } - - public Builder randomSeed(Integer randomSeed) { - this.randomSeed = randomSeed; - return this; - } - - public Builder repetitionPenalty(Double repetitionPenalty) { - this.repetitionPenalty = repetitionPenalty; - return this; - } - - public Builder stopSequences(List stopSequences) { - this.stopSequences = stopSequences; - return this; - } - - public Builder truncateInputTokens(Integer truncateInputTokens) { - this.truncateInputTokens = truncateInputTokens; - return this; - } - - public Builder includeStopSequence(Boolean includeStopSequence) { - this.includeStopSequence = includeStopSequence; - return this; - } - - public Builder tokenGenerator(TokenGenerator tokenGenerator) { - this.tokenGenerator = tokenGenerator; - return this; - } - - public Builder promptFormatter(PromptFormatter promptFormatter) { - this.promptFormatter = promptFormatter; - return this; - } - - public Builder logRequests(boolean logRequests) { - this.logRequests = logRequests; - return this; - } - - public Builder logResponses(boolean logResponses) { - 
this.logResponses = logResponses; - return this; - } - - public T build(Class clazz) { - try { - return clazz.getConstructor(Builder.class).newInstance(this); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - } -} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxStreamingChatModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxStreamingChatModel.java deleted file mode 100644 index 8ab87a703..000000000 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxStreamingChatModel.java +++ /dev/null @@ -1,129 +0,0 @@ -package io.quarkiverse.langchain4j.watsonx; - -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.concurrent.Callable; -import java.util.function.Consumer; - -import dev.langchain4j.data.message.AiMessage; -import dev.langchain4j.data.message.ChatMessage; -import dev.langchain4j.model.StreamingResponseHandler; -import dev.langchain4j.model.chat.StreamingChatLanguageModel; -import dev.langchain4j.model.chat.TokenCountEstimator; -import dev.langchain4j.model.output.FinishReason; -import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.output.TokenUsage; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters.LengthPenalty; -import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; -import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse; -import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; -import io.smallrye.mutiny.Context; - -public class WatsonxStreamingChatModel extends WatsonxModel implements StreamingChatLanguageModel, TokenCountEstimator { - - public WatsonxStreamingChatModel(WatsonxModel.Builder config) { - super(config); - } - - @Override - public void generate(List messages, StreamingResponseHandler handler) { - - LengthPenalty 
lengthPenalty = null; - if (Objects.nonNull(decayFactor) || Objects.nonNull(startIndex)) { - lengthPenalty = new LengthPenalty(decayFactor, startIndex); - } - - Parameters parameters = Parameters.builder() - .decodingMethod(decodingMethod) - .lengthPenalty(lengthPenalty) - .minNewTokens(minNewTokens) - .maxNewTokens(maxNewTokens) - .randomSeed(randomSeed) - .stopSequences(stopSequences) - .temperature(temperature) - .timeLimit(timeLimit) - .topP(topP) - .topK(topK) - .repetitionPenalty(repetitionPenalty) - .truncateInputTokens(truncateInputTokens) - .includeStopSequence(includeStopSequence) - .build(); - - TextGenerationRequest request = new TextGenerationRequest(modelId, projectId, toInput(messages), parameters); - Context context = Context.of("response", new ArrayList()); - - client.chatStreaming(request, version) - .subscribe() - .with(context, - new Consumer() { - @Override - @SuppressWarnings("unchecked") - public void accept(TextGenerationResponse response) { - try { - - if (response == null || response.results() == null || response.results().isEmpty()) - return; - - ((List) context.get("response")).add(response); - handler.onNext(response.results().get(0).generatedText()); - - } catch (Exception e) { - handler.onError(e); - } - } - }, - new Consumer() { - @Override - public void accept(Throwable error) { - handler.onError(error); - } - }, - new Runnable() { - @Override - @SuppressWarnings("unchecked") - public void run() { - var list = ((List) context.get("response")); - - int inputTokenCount = 0; - int outputTokenCount = 0; - String stopReason = null; - StringBuilder builder = new StringBuilder(); - - for (int i = 0; i < list.size(); i++) { - - TextGenerationResponse.Result response = list.get(i).results().get(0); - - if (i == 0) - inputTokenCount = response.inputTokenCount(); - - if (i == list.size() - 1) { - outputTokenCount = response.generatedTokenCount(); - stopReason = response.stopReason(); - } - - builder.append(response.generatedText()); - } - - 
AiMessage message = new AiMessage(builder.toString()); - TokenUsage tokenUsage = new TokenUsage(inputTokenCount, outputTokenCount); - FinishReason finishReason = toFinishReason(stopReason); - handler.onComplete(Response.from(message, tokenUsage, finishReason)); - } - }); - } - - @Override - public int estimateTokenCount(List messages) { - - var input = toInput(messages); - var request = new TokenizationRequest(modelId, input, projectId); - return retryOn(new Callable() { - @Override - public Integer call() throws Exception { - return client.tokenization(request, version).result().tokenCount(); - } - }); - } -} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TokenGenerator.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxTokenGenerator.java similarity index 94% rename from model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TokenGenerator.java rename to model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxTokenGenerator.java index 332affeb5..30b3e49c8 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TokenGenerator.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxTokenGenerator.java @@ -13,7 +13,7 @@ import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; -public class TokenGenerator { +public class WatsonxTokenGenerator { private final static Semaphore lock = new Semaphore(1); private final IAMRestApi client; @@ -21,7 +21,7 @@ public class TokenGenerator { private final String grantType; private IdentityTokenResponse token; - public TokenGenerator(URL url, Duration timeout, String grantType, String apiKey) { + public WatsonxTokenGenerator(URL url, Duration timeout, String grantType, String apiKey) { this.client = QuarkusRestClientBuilder.newBuilder() .baseUrl(url) diff --git 
a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxUtils.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxUtils.java new file mode 100644 index 000000000..589c2830b --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxUtils.java @@ -0,0 +1,45 @@ +package io.quarkiverse.langchain4j.watsonx; + +import java.util.Optional; +import java.util.concurrent.Callable; + +import jakarta.ws.rs.WebApplicationException; + +import io.quarkiverse.langchain4j.watsonx.bean.WatsonxError; +import io.quarkiverse.langchain4j.watsonx.exception.WatsonxException; + +public class WatsonxUtils { + + public static T retryOn(Callable action) { + int maxAttempts = 1; + for (int i = 0; i <= maxAttempts; i++) { + + try { + + return action.call(); + + } catch (WatsonxException e) { + + if (e.details() == null || e.details().errors() == null || e.details().errors().size() == 0) + throw e; + + Optional optional = Optional.empty(); + for (WatsonxError.Error error : e.details().errors()) { + if (WatsonxError.Code.AUTHENTICATION_TOKEN_EXPIRED.equals(error.code())) { + optional = Optional.of(error.code()); + break; + } + } + + if (!optional.isPresent()) + throw e; + + } catch (WebApplicationException e) { + throw e; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + throw new RuntimeException("Failed after " + maxAttempts + " attempts"); + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/filter/BearerTokenHeaderFactory.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/filter/BearerTokenHeaderFactory.java index 2939c5576..6bb89268e 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/filter/BearerTokenHeaderFactory.java +++ 
b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/filter/BearerTokenHeaderFactory.java @@ -4,7 +4,7 @@ import jakarta.ws.rs.core.MultivaluedMap; -import io.quarkiverse.langchain4j.watsonx.TokenGenerator; +import io.quarkiverse.langchain4j.watsonx.WatsonxTokenGenerator; import io.quarkus.rest.client.reactive.ReactiveClientHeadersFactory; import io.smallrye.mutiny.Uni; @@ -13,9 +13,9 @@ */ public class BearerTokenHeaderFactory extends ReactiveClientHeadersFactory { - private TokenGenerator tokenGenerator; + private WatsonxTokenGenerator tokenGenerator; - public BearerTokenHeaderFactory(TokenGenerator tokenGenerator) { + public BearerTokenHeaderFactory(WatsonxTokenGenerator tokenGenerator) { this.tokenGenerator = tokenGenerator; } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java index 49c5a2162..137a2a553 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java @@ -20,11 +20,9 @@ import dev.langchain4j.model.embedding.DisabledEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; -import io.quarkiverse.langchain4j.watsonx.TokenGenerator; -import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; import io.quarkiverse.langchain4j.watsonx.WatsonxEmbeddingModel; -import io.quarkiverse.langchain4j.watsonx.WatsonxModel; -import io.quarkiverse.langchain4j.watsonx.WatsonxStreamingChatModel; +import io.quarkiverse.langchain4j.watsonx.WatsonxGenerationModel; +import io.quarkiverse.langchain4j.watsonx.WatsonxTokenGenerator; import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; import 
io.quarkiverse.langchain4j.watsonx.prompt.impl.NoopPromptFormatter; import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; @@ -43,10 +41,10 @@ public class WatsonxRecorder { private static final String DUMMY_URL = "https://dummy.ai/api"; private static final String DUMMY_API_KEY = "dummy"; private static final String DUMMY_PROJECT_ID = "dummy"; - private static final Map tokenGeneratorCache = new HashMap<>(); + private static final Map tokenGeneratorCache = new HashMap<>(); private static final ConfigValidationException.Problem[] EMPTY_PROBLEMS = new ConfigValidationException.Problem[0]; - public Supplier chatModel(LangChain4jWatsonxConfig runtimeConfig, + public Supplier generationModel(LangChain4jWatsonxConfig runtimeConfig, LangChain4jWatsonxFixedRuntimeConfig fixedRuntimeConfig, String configName, PromptFormatter promptFormatter) { @@ -65,7 +63,7 @@ public Supplier chatModel(LangChain4jWatsonxConfig runtimeCon return new Supplier<>() { @Override public ChatLanguageModel get() { - return builder.build(WatsonxChatModel.class); + return builder.build(); } }; @@ -79,7 +77,7 @@ public ChatLanguageModel get() { } } - public Supplier streamingChatModel(LangChain4jWatsonxConfig runtimeConfig, + public Supplier generationStreamingModel(LangChain4jWatsonxConfig runtimeConfig, LangChain4jWatsonxFixedRuntimeConfig fixedRuntimeConfig, String configName, PromptFormatter promptFormatter) { LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName); @@ -92,7 +90,7 @@ public Supplier streamingChatModel(LangChain4jWatson return new Supplier<>() { @Override public StreamingChatLanguageModel get() { - return builder.build(WatsonxStreamingChatModel.class); + return builder.build(); } }; @@ -117,7 +115,7 @@ public Supplier embeddingModel(LangChain4jWatsonxConfig runtimeC } String iamUrl = watsonConfig.iam().baseUrl().toExternalForm(); - TokenGenerator tokenGenerator = 
tokenGeneratorCache.computeIfAbsent(iamUrl, + WatsonxTokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl, createTokenGenerator(watsonConfig.iam(), watsonConfig.apiKey())); URL url; @@ -141,7 +139,7 @@ public Supplier embeddingModel(LangChain4jWatsonxConfig runtimeC return new Supplier<>() { @Override public WatsonxEmbeddingModel get() { - return builder.build(WatsonxEmbeddingModel.class); + return builder.build(); } }; @@ -155,18 +153,18 @@ public EmbeddingModel get() { } } - private Function createTokenGenerator(IAMConfig iamConfig, String apiKey) { - return new Function() { + private Function createTokenGenerator(IAMConfig iamConfig, String apiKey) { + return new Function() { @Override - public TokenGenerator apply(String iamUrl) { - return new TokenGenerator(iamConfig.baseUrl(), iamConfig.timeout().orElse(Duration.ofSeconds(10)), + public WatsonxTokenGenerator apply(String iamUrl) { + return new WatsonxTokenGenerator(iamConfig.baseUrl(), iamConfig.timeout().orElse(Duration.ofSeconds(10)), iamConfig.grantType(), apiKey); } }; } - private WatsonxModel.Builder generateChatBuilder( + private WatsonxGenerationModel.Builder generateChatBuilder( LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig, LangChain4jWatsonxFixedRuntimeConfig.WatsonConfig watsonFixedRuntimeConfig, String configName, PromptFormatter promptFormatter) { @@ -179,7 +177,7 @@ private WatsonxModel.Builder generateChatBuilder( } String iamUrl = watsonRuntimeConfig.iam().baseUrl().toExternalForm(); - TokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl, + WatsonxTokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl, createTokenGenerator(watsonRuntimeConfig.iam(), watsonRuntimeConfig.apiKey())); URL url; @@ -193,7 +191,7 @@ private WatsonxModel.Builder generateChatBuilder( Integer startIndex = chatModelConfig.lengthPenalty().startIndex().orElse(null); String promptJoiner = chatModelConfig.promptJoiner(); - return 
WatsonxChatModel.builder() + return WatsonxGenerationModel.builder() .tokenGenerator(tokenGenerator) .url(url) .timeout(watsonRuntimeConfig.timeout().orElse(Duration.ofSeconds(10))) diff --git a/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java b/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java index ef7f49426..ca9de010d 100644 --- a/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java +++ b/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java @@ -34,12 +34,13 @@ void setupMocks() { @Test void disabledChatModel() { assertThat(recorder - .chatModel(runtimeConfig, fixedRuntimeConfig, NamedConfigUtil.DEFAULT_NAME, null) + .generationModel(runtimeConfig, fixedRuntimeConfig, NamedConfigUtil.DEFAULT_NAME, null) .get()) .isNotNull() .isExactlyInstanceOf(DisabledChatLanguageModel.class); - assertThat(recorder.streamingChatModel(runtimeConfig, fixedRuntimeConfig, NamedConfigUtil.DEFAULT_NAME, null).get()) + assertThat( + recorder.generationStreamingModel(runtimeConfig, fixedRuntimeConfig, NamedConfigUtil.DEFAULT_NAME, null).get()) .isNotNull() .isExactlyInstanceOf(DisabledStreamingChatLanguageModel.class); From 70688363a23f7cc5a30e07b0f55bd2a9c32e65e2 Mon Sep 17 00:00:00 2001 From: Andrea Di Maio Date: Fri, 11 Oct 2024 09:22:33 +0200 Subject: [PATCH 3/4] [watsonx.ai] Enable /chat and /chat_stream APIs --- docs/modules/ROOT/pages/watsonx.adoc | 112 ++---- .../src/main/resources/application.properties | 6 + .../multiple/MultipleChatProvidersTest.java | 12 +- ...tipleTokenCountEstimatorProvidersTest.java | 6 +- .../watsonx/deployment/WatsonxProcessor.java | 128 +++--- .../WatsonxChatModelProviderBuildItem.java | 8 +- .../watsonx/deployment/AiChatServiceTest.java | 369 
+++++++++++++++-- .../deployment/AiGenerationServiceTest.java | 270 +++++++++++++ .../watsonx/deployment/CacheTokenTest.java | 15 +- .../deployment/ChatAllPropertiesTest.java | 167 ++++++++ .../deployment/ChatDefaultPropertiesTest.java | 154 +++++++ .../deployment/ChatMemoryPlaceholderTest.java | 21 +- ....java => GenerationAllPropertiesTest.java} | 17 +- ...a => GenerationDefaultPropertiesTest.java} | 17 +- .../watsonx/deployment/HttpErrorTest.java | 119 +++++- .../PromptFormatterExceptionTest.java | 7 +- .../deployment/PromptFormatterTest.java | 11 +- .../deployment/ResponseSchemaOnTest.java | 21 +- .../deployment/TokenCountEstimatorTest.java | 13 +- .../watsonx/deployment/WireMockUtil.java | 67 ++- .../langchain4j/watsonx/WatsonxChatModel.java | 380 ++++++++++++++++++ .../watsonx/WatsonxGenerationModel.java | 145 ++++--- .../langchain4j/watsonx/WatsonxUtils.java | 6 +- .../watsonx/bean/TextChatMessage.java | 258 ++++++++++++ .../watsonx/bean/TextChatParameters.java | 87 ++++ .../watsonx/bean/TextChatRequest.java | 12 + .../watsonx/bean/TextChatResponse.java | 32 ++ ...ers.java => TextGenerationParameters.java} | 8 +- .../watsonx/bean/TextGenerationRequest.java | 2 +- .../bean/TextStreamingChatResponse.java | 32 ++ .../watsonx/bean/WatsonxError.java | 16 +- .../watsonx/client/WatsonxRestApi.java | 18 +- .../watsonx/runtime/WatsonxRecorder.java | 126 +++++- .../runtime/config/ChatModelConfig.java | 61 ++- .../config/ChatModelFixedRuntimeConfig.java | 20 +- 35 files changed, 2418 insertions(+), 325 deletions(-) create mode 100644 model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiGenerationServiceTest.java create mode 100644 model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatAllPropertiesTest.java create mode 100644 model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatDefaultPropertiesTest.java rename 
model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/{AllPropertiesTest.java => GenerationAllPropertiesTest.java} (94%) rename model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/{DefaultPropertiesTest.java => GenerationDefaultPropertiesTest.java} (93%) create mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java create mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java create mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatParameters.java create mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatRequest.java create mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatResponse.java rename model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/{Parameters.java => TextGenerationParameters.java} (95%) create mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextStreamingChatResponse.java diff --git a/docs/modules/ROOT/pages/watsonx.adoc b/docs/modules/ROOT/pages/watsonx.adoc index 0a3249cd7..d837bb40d 100644 --- a/docs/modules/ROOT/pages/watsonx.adoc +++ b/docs/modules/ROOT/pages/watsonx.adoc @@ -62,111 +62,75 @@ quarkus.langchain4j.watsonx.api-key=hG-... NOTE: To determine the API key, go to https://cloud.ibm.com/iam/apikeys and generate it. -==== Writing prompts +==== Interacting with Models -When creating prompts using watsonx.ai, it's important to follow the guidelines of the model you choose. Depending on the model, some special instructions may be required to ensure the desired output. For best results, always refer to the documentation provided for each model to maximize the effectiveness of your prompts. 
+The `watsonx.ai` module provides two different modes for interacting with LLM models: `generation` and `chat`. These modes allow you to tailor the interaction based on the complexity of your use case and how much control you want to have over the prompt structure. -To simplify the process of prompt creation, you can use the `prompt-formatter` property to automatically handle the addition of tags to your prompts. This property allows you to avoid manually adding tags by letting the system handle the formatting based on the model's requirements. This functionality is particularly useful for models such as `ibm/granite-13b-chat-v2`, `meta-llama/llama-3-405b-instruct`, and other supported models, ensuring consistent and accurate prompt structures without additional effort. +You can select the interaction mode using the property `quarkus.langchain4j.watsonx.chat-model.mode`. -To enable this functionality, configure the `prompt-formatter` property in your `application.properties` file as follows: +* `generation`: In this mode, you must explicitly structure the prompts using the required model-specific tags. This provides full control over the format of the prompt, but requires in-depth knowledge of the model being used. For best results, always refer to the documentation provided for each model to maximize the effectiveness of your prompts. +* `chat`: This mode abstracts the complexity of tagging by automatically formatting prompts so you can focus on the content (*default value*). + +To choose one of these two modes, add the `chat-model.mode` property to your `application.properties` file: [source,properties,subs=attributes+] ---- -quarkus.langchain4j.watsonx.chat-model.prompt-formatter=true +quarkus.langchain4j.watsonx.chat-model.mode=chat // or 'generation' ----
This helps to maintain prompt clarity and improves interaction with the LLM by ensuring that prompts follow the required structure. If set to `false`, you'll need to manage the tags manually. +==== Chat Mode + +In `chat` mode, you can interact with models without having to manually manage the tags of a prompt. + +You might choose this mode if you are looking for dynamic interactions where the model can build on previous messages and provide more contextually relevant responses. This mode simplifies the interaction by automatically managing the necessary tags, allowing you to focus on the content of your prompts rather than formatting. -For example, if you choose to use `ibm/granite-13b-chat-v2` without using the `prompt-formatter`, you will need to manually add the `<|system|>`, `<|user|>` and `<|assistant|>` instructions: +Chat mode also supports the use of `tools`, allowing the model to perform specific actions or retrieve external data as part of its responses. This extends the capabilities of the model, allowing it to perform complex tasks dynamically and adapt to your needs. More information about tools is available on the xref:./agent-and-tools.adoc[Agent and Tools] page. [source,properties,subs=attributes+] ---- -quarkus.langchain4j.watsonx.api-key=hG-... 
-quarkus.langchain4j.watsonx.base-url=https://us-south.ml.cloud.ibm.com -quarkus.langchain4j.watsonx.chat-model.model-id=ibm/granite-13b-chat-v2 -quarkus.langchain4j.watsonx.chat-model.prompt-formatter=false +quarkus.langchain4j.watsonx.base-url=${BASE_URL} +quarkus.langchain4j.watsonx.api-key=${API_KEY} +quarkus.langchain4j.watsonx.project-id=${PROJECT_ID} +quarkus.langchain4j.watsonx.chat-model.model-id=mistralai/mistral-large +quarkus.langchain4j.watsonx.chat-model.mode=chat ---- [source,java] ---- @RegisterAiService -public interface LLMService { - - public record Result(Integer result) {} - - @SystemMessage(""" - <|system|> - You are a calculator and you must perform the mathematical operation - {response_schema} - """) - @UserMessage(""" - <|user|> - {firstNumber} + {secondNumber} - <|assistant|> - """) - public Result calculator(int firstNumber, int secondNumber); +public interface AiService { + @SystemMessage("You are a helpful assistant") + public String chat(@MemoryId String id, @UserMessage String message); } ---- -Enabling the `prompt-formatter` will result in: +NOTE: The availability of `chat` and `tools` is currently limited to certain models. Not all models support these features, so be sure to consult the documentation for the specific model you are using to confirm whether these features are available. + +==== Generation Mode + +In `generation` mode, you have complete control over the structure of your prompts by manually specifying tags for a specific model. This mode could be useful in scenarios where a single response is desired. [source,properties,subs=attributes+] ----
-quarkus.langchain4j.watsonx.base-url=https://us-south.ml.cloud.ibm.com -quarkus.langchain4j.watsonx.chat-model.model-id=ibm/granite-13b-chat-v2 -quarkus.langchain4j.watsonx.chat-model.prompt-formatter=true +quarkus.langchain4j.watsonx.base-url=${BASE_URL} +quarkus.langchain4j.watsonx.api-key=${API_KEY} +quarkus.langchain4j.watsonx.project-id=${PROJECT_ID} +quarkus.langchain4j.watsonx.chat-model.model-id=mistralai/mistral-large +quarkus.langchain4j.watsonx.chat-model.mode=generation ---- [source,java] ---- -@RegisterAiService -public interface LLMService { - - public record Result(Integer result) {} - - @SystemMessage(""" - You are a calculator and you must perform the mathematical operation - {response_schema} - """) +@RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) +public interface AiService { @UserMessage(""" - {firstNumber} + {secondNumber} - """) - public Result calculator(int firstNumber, int secondNumber); + [INST] You are a helpful assistant [/INST]\ + [INST] What is the capital of {capital}? [/INST]""") + public String askCapital(String capital); } ---- -The `prompt-formatter` supports the following models: - -* `mistralai/mistral-large` -* `mistralai/mixtral-8x7b-instruct-v01` -* `sdaia/allam-1-13b-instruct` -* `meta-llama/llama-3-405b-instruct` -* `meta-llama/llama-3-1-70b-instruct` -* `meta-llama/llama-3-1-8b-instruct` -* `meta-llama/llama-3-70b-instruct` -* `meta-llama/llama-3-8b-instruct` -* `ibm/granite-13b-chat-v2` -* `ibm/granite-13b-instruct-v2` -* `ibm/granite-7b-lab` -* `ibm/granite-20b-code-instruct` -* `ibm/granite-34b-code-instruct` -* `ibm/granite-3b-code-instruct` -* `ibm/granite-8b-code-instruct` - -==== Tool Execution with Prompt Formatter - -In addition to simplifying prompt creation, the `prompt-formatter` property also enables the execution of tools for specific models. 
Tools allow for dynamic interactions within the model, enabling the AI to perform specific actions or fetch data as part of its response. - -When the `prompt-formatter` is enabled and a supported model is selected, the prompt will be automatically formatted to use the tools. More information about tools is available in the xref:./agent-and-tools.adoc[Agent and Tools] page. - -Currently, the following model supports tool execution: - -* `mistralai/mistral-large` -* `meta-llama/llama-3-405b-instruct` -* `meta-llama/llama-3-1-70b-instruct` - -IMPORTANT: The `@SystemMessage` and `@UserMessage` annotations are joined by default with a new line. If you want to change this behavior, use the property `quarkus.langchain4j.watsonx.chat-model.prompt-joiner=`. By adjusting this property, you can define your preferred way of joining messages and ensure that the prompt structure meets your specific needs. This customization option is available only when the `prompt-formatter` property is set to `false`. When the `prompt-formatter` is enabled (set to `true`), the prompt formatting, including the addition of tags and message joining, is automatically handled. In this case, the `prompt-joiner` property will be ignored, and you will not have the ability to customize how messages are joined. +NOTE: The `@SystemMessage` and `@UserMessage` annotations are joined by default with a new line. If you want to change this behavior, use the property `quarkus.langchain4j.watsonx.chat-model.prompt-joiner=`. By adjusting this property, you can define your preferred way of joining messages and ensure that the prompt structure meets your specific needs. NOTE: Sometimes it may be useful to use the `quarkus.langchain4j.watsonx.chat-model.stop-sequences` property to prevent the LLM model from returning more results than desired. 
diff --git a/integration-tests/multiple-providers/src/main/resources/application.properties b/integration-tests/multiple-providers/src/main/resources/application.properties index 054e96f8a..08f91d7c9 100644 --- a/integration-tests/multiple-providers/src/main/resources/application.properties +++ b/integration-tests/multiple-providers/src/main/resources/application.properties @@ -33,6 +33,12 @@ quarkus.langchain4j.watsonx.c7.base-url=https://somecluster.somedomain.ai:443/ap quarkus.langchain4j.watsonx.c7.api-key=test8 quarkus.langchain4j.watsonx.c7.project-id=proj +quarkus.langchain4j.c8.chat-model.provider=watsonx +quarkus.langchain4j.watsonx.c8.base-url=https://somecluster.somedomain.ai:443/api +quarkus.langchain4j.watsonx.c8.api-key=test9 +quarkus.langchain4j.watsonx.c8.project-id=proj +quarkus.langchain4j.watsonx.c8.chat-model.mode=generation + quarkus.langchain4j.e1.embedding-model.provider=openai quarkus.langchain4j.openai.e1.api-key=test5 quarkus.langchain4j.e2.embedding-model.provider=ollama diff --git a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java index fd0352868..779c03ca2 100644 --- a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java @@ -13,6 +13,7 @@ import io.quarkiverse.langchain4j.huggingface.QuarkusHuggingFaceChatModel; import io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel; import io.quarkiverse.langchain4j.openshiftai.OpenshiftAiChatModel; +import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; import io.quarkiverse.langchain4j.watsonx.WatsonxGenerationModel; import io.quarkus.arc.ClientProxy; import io.quarkus.test.junit.QuarkusTest; @@ -47,6 +48,10 @@ public class MultipleChatProvidersTest { 
@ModelName("c7") ChatLanguageModel seventhNamedModel; + @Inject + @ModelName("c8") + ChatLanguageModel eighthNamedModel; + @Test void defaultModel() { assertThat(ClientProxy.unwrap(defaultModel)).isInstanceOf(OpenAiChatModel.class); @@ -79,6 +84,11 @@ void sixthNamedModel() { @Test void seventhNamedModel() { - assertThat(ClientProxy.unwrap(seventhNamedModel)).isInstanceOf(WatsonxGenerationModel.class); + assertThat(ClientProxy.unwrap(seventhNamedModel)).isInstanceOf(WatsonxChatModel.class); + } + + @Test + void eighthNamedModel() { + assertThat(ClientProxy.unwrap(eighthNamedModel)).isInstanceOf(WatsonxGenerationModel.class); } } diff --git a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleTokenCountEstimatorProvidersTest.java b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleTokenCountEstimatorProvidersTest.java index bcb127713..444fdb5bf 100644 --- a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleTokenCountEstimatorProvidersTest.java +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleTokenCountEstimatorProvidersTest.java @@ -10,7 +10,7 @@ import dev.langchain4j.model.chat.TokenCountEstimator; import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiChatModel; -import io.quarkiverse.langchain4j.watsonx.WatsonxGenerationModel; +import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; import io.quarkus.arc.ClientProxy; import io.quarkus.test.junit.QuarkusTest; @@ -41,7 +41,7 @@ void azureOpenAiTest() { @Test void watsonxTest() { - assertThat(ClientProxy.unwrap(watsonxChat)).isInstanceOf(WatsonxGenerationModel.class); - assertThat(ClientProxy.unwrap(watsonxTokenizer)).isInstanceOf(WatsonxGenerationModel.class); + assertThat(ClientProxy.unwrap(watsonxChat)).isInstanceOf(WatsonxChatModel.class); + 
assertThat(ClientProxy.unwrap(watsonxTokenizer)).isInstanceOf(WatsonxChatModel.class); } } diff --git a/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java index f29726932..7359e2949 100644 --- a/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java +++ b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java @@ -7,12 +7,15 @@ import static io.quarkiverse.langchain4j.deployment.TemplateUtil.getTemplateFromAnnotationInstance; import java.util.List; +import java.util.function.Supplier; import jakarta.enterprise.context.ApplicationScoped; import org.jboss.jandex.AnnotationInstance; import org.jboss.logging.Logger; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.deployment.LangChain4jDotNames; import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; @@ -60,7 +63,7 @@ public void providerCandidates(BuildProducer selectedChatItem, @@ -81,63 +84,80 @@ void createPromptFormatters( ? fixedRuntimeConfig.defaultConfig().chatModel().modelId() : fixedRuntimeConfig.namedConfig().get(configName).chatModel().modelId(); - boolean promptFormatterIsEnabled = NamedConfigUtil.isDefault(configName) - ? fixedRuntimeConfig.defaultConfig().chatModel().promptFormatter() - : fixedRuntimeConfig.namedConfig().get(configName).chatModel().promptFormatter(); + String mode = NamedConfigUtil.isDefault(configName) + ? 
fixedRuntimeConfig.defaultConfig().chatModel().mode() + : fixedRuntimeConfig.namedConfig().get(configName).chatModel().mode(); - PromptFormatter promptFormatter = null; + if (mode.equalsIgnoreCase("chat")) { - if (promptFormatterIsEnabled) { - promptFormatter = PromptFormatterMapper.get(modelId); - if (promptFormatter == null) { - log.warnf( - "The \"%s\" model does not have a PromptFormatter implementation, no tags are automatically generated.", - modelId); - } - } + chatModelBuilder.produce(new WatsonxChatModelProviderBuildItem(configName, mode, null)); - var registerAiService = annotationInstances.stream() - .filter(annotationInstance -> { - var modelName = annotationInstance.value("modelName"); - if (modelName == null) { - return configName.equals(NamedConfigUtil.DEFAULT_NAME); - } else { - return configName.equals(modelName.asString()); - } - }).findFirst(); + } else if (mode.equalsIgnoreCase("generation")) { - if (!registerAiService.isEmpty()) { + boolean promptFormatterIsEnabled = NamedConfigUtil.isDefault(configName) + ? 
fixedRuntimeConfig.defaultConfig().chatModel().promptFormatter() + : fixedRuntimeConfig.namedConfig().get(configName).chatModel().promptFormatter(); - var classInfo = registerAiService.get().target().asClass(); - var tools = classInfo.annotation(LangChain4jDotNames.REGISTER_AI_SERVICES).value("tools"); + PromptFormatter promptFormatter = null; - if (tools != null) { - if (!promptFormatterIsEnabled) - throw new RuntimeException("The prompt-formatter must be enabled to use the tool functionality"); - - if (!PromptFormatterMapper.toolIsSupported(modelId)) - throw new RuntimeException( - "The tool functionality is not supported for the model \"%s\"".formatted(modelId)); + if (promptFormatterIsEnabled) { + promptFormatter = PromptFormatterMapper.get(modelId); + if (promptFormatter == null) { + log.warnf( + "The \"%s\" model does not have a PromptFormatter implementation, no tags are automatically generated.", + modelId); + } } - if (promptFormatter != null) { - var systemMessage = getTemplateFromAnnotationInstance( - classInfo.annotation(LangChain4jDotNames.SYSTEM_MESSAGE)); - var userMessage = getTemplateFromAnnotationInstance(classInfo.annotation(LangChain4jDotNames.USER_MESSAGE)); - var tokenAlreadyExist = promptFormatter.tokens().stream() - .filter(token -> systemMessage.contains(token) || userMessage.contains(token)) - .findFirst(); + var registerAiService = annotationInstances.stream() + .filter(annotationInstance -> { + var modelName = annotationInstance.value("modelName"); + if (modelName == null) { + return configName.equals(NamedConfigUtil.DEFAULT_NAME); + } else { + return configName.equals(modelName.asString()); + } + }).findFirst(); - if (tokenAlreadyExist.isPresent()) { - log.warnf( - "The prompt in the AIService \"%s\" already contains one or more tags for the model \"%s\", the prompt-formatter option is disabled." 
- .formatted(classInfo.name().toString(), modelId)); - promptFormatter = null; + if (!registerAiService.isEmpty()) { + + var classInfo = registerAiService.get().target().asClass(); + var tools = classInfo.annotation(LangChain4jDotNames.REGISTER_AI_SERVICES).value("tools"); + + if (tools != null) { + if (!promptFormatterIsEnabled) + throw new RuntimeException("The prompt-formatter must be enabled to use the tool functionality"); + + if (!PromptFormatterMapper.toolIsSupported(modelId)) + throw new RuntimeException( + "The tool functionality is not supported for the model \"%s\"".formatted(modelId)); + } + + if (promptFormatter != null) { + var systemMessage = getTemplateFromAnnotationInstance( + classInfo.annotation(LangChain4jDotNames.SYSTEM_MESSAGE)); + var userMessage = getTemplateFromAnnotationInstance( + classInfo.annotation(LangChain4jDotNames.USER_MESSAGE)); + var tokenAlreadyExist = promptFormatter.tokens().stream() + .filter(token -> systemMessage.contains(token) || userMessage.contains(token)) + .findFirst(); + + if (tokenAlreadyExist.isPresent()) { + log.warnf( + "The prompt in the AIService \"%s\" already contains one or more tags for the model \"%s\", the prompt-formatter option is disabled." + .formatted(classInfo.name().toString(), modelId)); + promptFormatter = null; + } } } - } - chatModelBuilder.produce(new WatsonxChatModelProviderBuildItem(configName, promptFormatter)); + chatModelBuilder.produce(new WatsonxChatModelProviderBuildItem(configName, mode, promptFormatter)); + + } else { + throw new RuntimeException( + "The \"mode\" value for the model \"%s\" is not valid. 
Choose one between [\"chat\", \"generation\"]" + .formatted(mode, configName)); + } } } @@ -152,11 +172,19 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon for (var selected : selectedChatItem) { String configName = selected.getConfigName(); - PromptFormatter promptFormatter = selected.getPromptFormatter(); - var chatLanguageModel = recorder.generationModel(runtimeConfig, fixedRuntimeConfig, configName, promptFormatter); - var streamingChatLanguageModel = recorder.generationStreamingModel(runtimeConfig, fixedRuntimeConfig, configName, - promptFormatter); + Supplier chatLanguageModel; + Supplier streamingChatLanguageModel; + + if (selected.getMode().equals("chat")) { + chatLanguageModel = recorder.chatModel(runtimeConfig, fixedRuntimeConfig, configName); + streamingChatLanguageModel = recorder.streamingChatModel(runtimeConfig, fixedRuntimeConfig, configName); + } else { + PromptFormatter promptFormatter = selected.getPromptFormatter(); + chatLanguageModel = recorder.generationModel(runtimeConfig, fixedRuntimeConfig, configName, promptFormatter); + streamingChatLanguageModel = recorder.generationStreamingModel(runtimeConfig, fixedRuntimeConfig, configName, + promptFormatter); + } var chatBuilder = SyntheticBeanBuildItem .configure(CHAT_MODEL) diff --git a/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/items/WatsonxChatModelProviderBuildItem.java b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/items/WatsonxChatModelProviderBuildItem.java index a69131a40..fea796898 100644 --- a/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/items/WatsonxChatModelProviderBuildItem.java +++ b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/items/WatsonxChatModelProviderBuildItem.java @@ -6,10 +6,12 @@ public final class WatsonxChatModelProviderBuildItem extends 
MultiBuildItem { private final String configName; + private final String mode; private final PromptFormatter promptFormatter; - public WatsonxChatModelProviderBuildItem(String configName, PromptFormatter promptTemplate) { + public WatsonxChatModelProviderBuildItem(String configName, String mode, PromptFormatter promptTemplate) { this.configName = configName; + this.mode = mode; this.promptFormatter = promptTemplate; } @@ -17,6 +19,10 @@ public String getConfigName() { return configName; } + public String getMode() { + return mode; + } + public PromptFormatter getPromptFormatter() { return promptFormatter; } diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java index dc596d716..29eb0e59c 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java @@ -1,25 +1,47 @@ package io.quarkiverse.langchain4j.watsonx.deployment; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import java.util.Date; +import java.util.List; +import java.util.Map; import jakarta.inject.Inject; import jakarta.inject.Singleton; +import jakarta.ws.rs.core.MediaType; import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.spec.JavaArchive; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import com.github.tomakehurst.wiremock.stubbing.Scenario; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import 
dev.langchain4j.service.MemoryId; import dev.langchain4j.service.SystemMessage; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.store.memory.chat.ChatMemoryStore; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters; -import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageAssistant; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageSystem; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageTool; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageUser; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatParameterTools; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatParameterTools.TextChatParameterFunction; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatToolCall; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatToolCall.TextChatFunctionCall; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatParameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatRequest; import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; import io.quarkus.test.QuarkusUnitTest; +import io.smallrye.mutiny.Multi; public class AiChatServiceTest extends WireMockAbstract { @@ -29,51 +51,340 @@ public class AiChatServiceTest extends WireMockAbstract { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) - .setArchiveProducer(() -> 
ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(WireMockUtil.class, Calculator.class)); - @RegisterAiService - @Singleton - interface NewAIService { + @Override + void handlerBeforeEach() { + mockServers.mockIAMBuilder(200) + .grantType(langchain4jWatsonConfig.defaultConfig().iam().grantType()) + .response(WireMockUtil.BEARER_TOKEN, new Date()) + .build(); + } - @SystemMessage("This is a systemMessage") + @Singleton + @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) + @SystemMessage("This is a systemMessage") + interface AIService { @UserMessage("This is a userMessage {text}") String chat(String text); + + @UserMessage("This is a userMessage {text}") + Multi streaming(String text); } + @Singleton + @RegisterAiService(tools = Calculator.class) + @SystemMessage("This is a systemMessage") + interface AIServiceWithTool { + String chat(@MemoryId String memoryId, @UserMessage String text); + + Multi streaming(@MemoryId String memoryId, @UserMessage String text); + } + + @Inject + AIService aiService; + @Inject - NewAIService service; + AIServiceWithTool aiServiceWithTool; + + @Inject + ChatMemoryStore memory; + + @Singleton + static class Calculator { + + @Tool("Execute the sum of two numbers") + public int sum(int first, int second) { + return first + second; + } + } + + static List tools = List.of( + new TextChatParameterTools("function", new TextChatParameterFunction( + "sum", + "Execute the sum of two numbers", + Map. of( + "type", "object", + "properties", Map. of( + "first", Map. of("type", "integer"), + "second", Map. 
of("type", "integer")), + "required", List.of("first", "second"))))); @Test void chat() throws Exception { - LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig(); - ChatModelConfig chatModelConfig = watsonConfig.chatModel(); - String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); - String projectId = watsonConfig.projectId(); - String input = new StringBuilder() - .append("This is a systemMessage") - .append("\n") - .append("This is a userMessage Hello") - .toString(); - Parameters parameters = Parameters.builder() - .decodingMethod(chatModelConfig.decodingMethod()) - .temperature(chatModelConfig.temperature()) - .minNewTokens(chatModelConfig.minNewTokens()) - .maxNewTokens(chatModelConfig.maxNewTokens()) - .timeLimit(10000L) + var messages = List. of( + TextChatMessageSystem.of("This is a systemMessage"), + TextChatMessageUser.of("This is a userMessage Hello")); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + .body(mapper.writeValueAsString(generateChatRequest(messages, null))) + .response(WireMockUtil.RESPONSE_WATSONX_CHAT_API) .build(); - TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, input, parameters); + assertEquals("AI Response", aiService.chat("Hello")); + } - mockServers.mockIAMBuilder(200) - .response(WireMockUtil.BEARER_TOKEN, new Date()) + @Test + void chat_with_tool() throws Exception { + + var STARTED = List. 
of( + TextChatMessageSystem.of("This is a systemMessage"), + TextChatMessageUser.of("Execute the sum of 1 + 1")); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + .body(mapper.writeValueAsString(generateChatRequest(STARTED, tools))) + .scenario(Scenario.STARTED, "TOOL_CALL") + .response( + """ + { + "id": "chat-2e8d342d8ced41d89c0ff4efd32b3f9d", + "model_id": "mistralai/mistral-large", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "tool_calls": [ + { + "id": "chatcmpl-tool-3f621ce6ad9240da963d661215621711", + "type": "function", + "function": { + "name": "sum", + "arguments": "{\\\"first\\\":1, \\\"second\\\":1}" + } + } + ] + }, + "finish_reason": "tool_calls" + }], + "created": 1728808696, + "model_version": "2.0.0", + "created_at": "2024-10-13T08:38:16.960Z", + "usage": { + "completion_tokens": 25, + "prompt_tokens": 84, + "total_tokens": 109 + } + }""") .build(); + var TOOL_CALL = List. of( + TextChatMessageSystem.of("This is a systemMessage"), + TextChatMessageUser.of("Execute the sum of 1 + 1"), + TextChatMessageAssistant.of(List.of( + new TextChatToolCall(null, "chatcmpl-tool-3f621ce6ad9240da963d661215621711", "function", + new TextChatFunctionCall("sum", "{\"first\":1, \"second\":1}")))), + TextChatMessageTool.of("2", "chatcmpl-tool-3f621ce6ad9240da963d661215621711")); + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) - .body(mapper.writeValueAsString(body)) - .response(WireMockUtil.RESPONSE_WATSONX_CHAT_API) + .body(mapper.writeValueAsString(generateChatRequest(TOOL_CALL, tools))) + .scenario("TOOL_CALL", "AI_RESPONSE") + .response(""" + { + "id": "cmpl-15475d0dea9b4429a55843c77997f8a9", + "model_id": "mistralai/mistral-large", + "created": 1728806666, + "created_at": "2024-10-13T08:04:26.200Z", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "The result is 2" + }, + "finish_reason": "stop" + }], + "usage": { + "completion_tokens": 47, + 
"prompt_tokens": 59, + "total_tokens": 106 + } + }""") + .build(); + + var result = aiServiceWithTool.chat("no_streaming", "Execute the sum of 1 + 1"); + assertEquals("The result is 2", result); + + var messages = memory.getMessages("no_streaming"); + assertEquals("This is a systemMessage", messages.get(0).text()); + assertEquals("Execute the sum of 1 + 1", messages.get(1).text()); + assertEquals("The result is 2", messages.get(4).text()); + + if (messages.get(2) instanceof AiMessage aiMessage) { + assertTrue(aiMessage.hasToolExecutionRequests()); + assertEquals("{\"first\":1, \"second\":1}", aiMessage.toolExecutionRequests().get(0).arguments()); + } else { + fail("The third message is not of type AiMessage"); + } + + if (messages.get(3) instanceof ToolExecutionResultMessage toolResultMessage) { + assertEquals(2, Integer.parseInt(toolResultMessage.text())); + } else { + fail("The fourth message is not of type ToolExecutionResultMessage"); + } + } + + @Test + void streaming_chat() throws Exception { + + var messages = List. of( + TextChatMessageSystem.of("This is a systemMessage"), + TextChatMessageUser.of("This is a userMessage Hello")); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, 200) + .body(mapper.writeValueAsString(generateChatRequest(messages, null))) + .responseMediaType(MediaType.SERVER_SENT_EVENTS) + .response(WireMockUtil.RESPONSE_WATSONX_CHAT_STREAMING_API) + .build(); + + var result = aiService.streaming("Hello").collect().asList().await().indefinitely(); + assertEquals(List.of(" He", "llo"), result); + } + + @Test + void streaming_chat_with_tool() throws Exception { + + var STARTED = List. 
of( + TextChatMessageSystem.of("This is a systemMessage"), + TextChatMessageUser.of("Execute the sum of 1 + 1")); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, 200) + .body(mapper.writeValueAsString(generateChatRequest(STARTED, tools))) + .responseMediaType(MediaType.SERVER_SENT_EVENTS) + .scenario(Scenario.STARTED, "TOOL_CALL") + .response( + """ + id: 1 + event: message + data: {"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"role":"assistant"}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.490Z","usage":{"prompt_tokens":84,"total_tokens":84},"system":{"warnings":[{"message":"This model is a Non-IBM Product governed by a third-party license that may impose use restrictions and other obligations. By using this model you agree to its terms as identified in the following URL.","id":"disclaimer_warning","more_info":"https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx"},{"message":"The value of 'max_tokens' for this model was set to value 1024","id":"unspecified_max_token","additional_properties":{"limit":0,"new_value":1024,"parameter":"max_tokens","value":0}}]}} + + id: 2 + event: message + data: {"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"tool_calls":[{"index":0,"id":"chatcmpl-tool-7cf5dfd7c52441e59a7585243b22a86a","type":"function","function":{"name":"","arguments":""}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.546Z","usage":{"completion_tokens":4,"prompt_tokens":84,"total_tokens":88}} + + id: 3 + event: message + data: 
{"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"tool_calls":[{"index":0,"function":{"name":"sum","arguments":""}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.620Z","usage":{"completion_tokens":8,"prompt_tokens":84,"total_tokens":92}} + + id: 4 + event: message + data: {"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":"{\\\"first\\\": 1"}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.768Z","usage":{"completion_tokens":16,"prompt_tokens":84,"total_tokens":100}} + + id: 5 + event: message + data: {"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":""}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.786Z","usage":{"completion_tokens":17,"prompt_tokens":84,"total_tokens":101}} + + id: 6 + event: message + data: {"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":""}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.805Z","usage":{"completion_tokens":18,"prompt_tokens":84,"total_tokens":102}} + + id: 7 + event: message + data: {"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":""}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.823Z","usage":{"completion_tokens":19,"prompt_tokens":84,"total_tokens":103}} + + id: 8 + event: message + data: 
{"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":""}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.842Z","usage":{"completion_tokens":20,"prompt_tokens":84,"total_tokens":104}} + + id: 9 + event: message + data: {"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":""}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.861Z","usage":{"completion_tokens":21,"prompt_tokens":84,"total_tokens":105}} + + id: 10 + event: message + data: {"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":", \\\"second\\\": 1"}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.879Z","usage":{"completion_tokens":22,"prompt_tokens":84,"total_tokens":106}} + + id: 11 + event: message + data: {"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":""}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.897Z","usage":{"completion_tokens":23,"prompt_tokens":84,"total_tokens":107}} + + id: 12 + event: message + data: {"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":""}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.916Z","usage":{"completion_tokens":24,"prompt_tokens":84,"total_tokens":108}} + + id: 13 + event: message + data: 
{"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":"tool_calls","delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":"}"}}]}}],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.934Z","usage":{"completion_tokens":25,"prompt_tokens":84,"total_tokens":109}} + + id: 14 + event: message + data: {"id":"chat-188595e69470446fb1740c98acfdfe12","model_id":"mistralai/mistral-large","choices":[],"created":1728811250,"model_version":"2.0.0","created_at":"2024-10-13T09:20:50.935Z","usage":{"completion_tokens":25,"prompt_tokens":84,"total_tokens":109}} + + id: 15 + event: close + """) + .build(); + + var TOOL_CALL = List. of( + TextChatMessageSystem.of("This is a systemMessage"), + TextChatMessageUser.of("Execute the sum of 1 + 1"), + TextChatMessageAssistant.of(List.of( + new TextChatToolCall(null, "chatcmpl-tool-7cf5dfd7c52441e59a7585243b22a86a", "function", + new TextChatFunctionCall("sum", "{\"first\": 1, \"second\": 1}")))), + TextChatMessageTool.of("2", "chatcmpl-tool-7cf5dfd7c52441e59a7585243b22a86a")); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, 200) + .body(mapper.writeValueAsString(generateChatRequest(TOOL_CALL, tools))) + .responseMediaType(MediaType.SERVER_SENT_EVENTS) + .scenario("TOOL_CALL", "AI_RESPONSE") + .response( + """ + id: 1 + event: message + data: {"id":"chat-049e3ff7ff08416fb5c334d05af059da","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"role":"assistant"}}],"created":1728810714,"model_version":"2.0.0","created_at":"2024-10-13T09:11:55.072Z","usage":{"prompt_tokens":88,"total_tokens":88},"system":{"warnings":[{"message":"This model is a Non-IBM Product governed by a third-party license that may impose use restrictions and other obligations. 
By using this model you agree to its terms as identified in the following URL.","id":"disclaimer_warning","more_info":"https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx"},{"message":"The value of 'time_limit' for this model must be larger than 0 and not larger than 10m0s; it was set to 10m0s","id":"time_limit_out_of_range","additional_properties":{"limit":600000,"new_value":600000,"parameter":"time_limit","value":999000}}]}} + + id: 2 + event: message + data: {"id":"chat-049e3ff7ff08416fb5c334d05af059da","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"content":"The res"}}],"created":1728810714,"model_version":"2.0.0","created_at":"2024-10-13T09:11:55.073Z","usage":{"completion_tokens":1,"prompt_tokens":88,"total_tokens":89}} + + id: 3 + event: message + data: {"id":"chat-049e3ff7ff08416fb5c334d05af059da","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":"stop","delta":{"content":"ult is 2"}}],"created":1728810714,"model_version":"2.0.0","created_at":"2024-10-13T09:11:55.090Z","usage":{"completion_tokens":2,"prompt_tokens":88,"total_tokens":90}} + + id: 4 + event: message + data: {"id":"chat-049e3ff7ff08416fb5c334d05af059da","model_id":"mistralai/mistral-large","choices":[],"created":1728810714,"model_version":"2.0.0","created_at":"2024-10-13T09:11:55.715Z","usage":{"completion_tokens":36,"prompt_tokens":88,"total_tokens":124}} + + id: 5 + event: close + data: {} + """) + .build(); + + var result = aiServiceWithTool.streaming("streaming", "Execute the sum of 1 + 1").collect().asList().await() + .indefinitely(); + assertEquals(List.of("The res", "ult is 2"), result); + + var messages = memory.getMessages("streaming"); + assertEquals("This is a systemMessage", messages.get(0).text()); + assertEquals("Execute the sum of 1 + 1", messages.get(1).text()); + assertEquals("The result is 2", messages.get(4).text()); + + if (messages.get(2) instanceof AiMessage 
aiMessage) { + assertTrue(aiMessage.hasToolExecutionRequests()); + assertEquals("{\"first\": 1, \"second\": 1}", aiMessage.toolExecutionRequests().get(0).arguments()); + } else { + fail("The third message is not of type AiMessage"); + } + + if (messages.get(3) instanceof ToolExecutionResultMessage toolResultMessage) { + assertEquals(2, Integer.parseInt(toolResultMessage.text())); + } else { + fail("The fourth message is not of type ToolExecutionResultMessage"); + } + } + + private TextChatRequest generateChatRequest(List messages, List tools) { + LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig(); + ChatModelConfig chatModelConfig = watsonConfig.chatModel(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); + String projectId = watsonConfig.projectId(); + + TextChatParameters parameters = TextChatParameters.builder() + .temperature(chatModelConfig.temperature()) + .maxTokens(chatModelConfig.maxNewTokens()) + .timeLimit(WireMockUtil.DEFAULT_TIME_LIMIT) .build(); - assertEquals("AI Response", service.chat("Hello")); + return new TextChatRequest(modelId, projectId, messages, tools, parameters); } } diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiGenerationServiceTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiGenerationServiceTest.java new file mode 100644 index 000000000..ec23595aa --- /dev/null +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiGenerationServiceTest.java @@ -0,0 +1,270 @@ +package io.quarkiverse.langchain4j.watsonx.deployment; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.Date; +import java.util.List; + +import jakarta.inject.Inject; +import 
jakarta.inject.Singleton; +import jakarta.ws.rs.core.MediaType; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.github.tomakehurst.wiremock.stubbing.Scenario; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.UserMessage; +import dev.langchain4j.store.memory.chat.ChatMemoryStore; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; +import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; +import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; +import io.quarkus.test.QuarkusUnitTest; +import io.smallrye.mutiny.Multi; + +public class AiGenerationServiceTest extends WireMockAbstract { + + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.base-url", WireMockUtil.URL_WATSONX_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.model-id", "mistralai/mistral-large") + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-formatter", "true") + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "generation") + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(WireMockUtil.class, 
Calculator.class)); + + @Override + void handlerBeforeEach() { + mockServers.mockIAMBuilder(200) + .grantType(langchain4jWatsonConfig.defaultConfig().iam().grantType()) + .response(WireMockUtil.BEARER_TOKEN, new Date()) + .build(); + } + + @Singleton + @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) + @SystemMessage("This is a systemMessage") + interface AIService { + @UserMessage("This is a userMessage {text}") + String chat(String text); + + @UserMessage("This is a userMessage {text}") + Multi streaming(String text); + } + + @Singleton + @RegisterAiService(tools = Calculator.class) + @SystemMessage("This is a systemMessage") + interface AIServiceWithTool { + String chat(@MemoryId String memoryId, @UserMessage String text); + + Multi streaming(@MemoryId String memoryId, @UserMessage String text); + } + + @Inject + AIService aiService; + + @Inject + AIServiceWithTool aiServiceWithTool; + + @Inject + ChatMemoryStore memory; + + @Singleton + static class Calculator { + + @Tool("Execute the sum of two numbers") + public int sum(int first, int second) { + return first + second; + } + } + + static String TOOL_CALL = "[TOOL_CALLS] [{\\\"id\\\":\\\"1\\\",\\\"name\\\":\\\"sum\\\",\\\"arguments\\\":{\\\"first\\\":1,\\\"second\\\":1}}]"; + + @Test + void chat() throws Exception { + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) + .body(mapper.writeValueAsString(generateRequest())) + .response(WireMockUtil.RESPONSE_WATSONX_GENERATION_API) + .build(); + + assertEquals("AI Response", aiService.chat("Hello")); + } + + @Test + void chat_with_tool() throws Exception { + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) + .scenario(Scenario.STARTED, "TOOL_CALL") + .response(""" + { + "model_id": "mistralai/mistral-large", + "created_at": "2024-01-21T17:06:14.052Z", + "results": [ + { + "generated_text": "%s", + "generated_token_count": 5, + "input_token_count": 50, + 
"stop_reason": "eos_token", + "seed": 2123876088 + } + ] + }""".formatted(TOOL_CALL)) + .build(); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) + .scenario("TOOL_CALL", "AI_RESPONSE") + .response(""" + { + "model_id": "mistralai/mistral-large", + "created_at": "2024-01-21T17:06:14.052Z", + "results": [ + { + "generated_text": "The result is 2", + "generated_token_count": 5, + "input_token_count": 50, + "stop_reason": "eos_token", + "seed": 2123876088 + } + ] + }""") + .build(); + + var result = aiServiceWithTool.chat("no_streaming", "Execute the sum of 1 + 1"); + assertEquals("The result is 2", result); + + var messages = memory.getMessages("no_streaming"); + assertEquals("This is a systemMessage", messages.get(0).text()); + assertEquals("Execute the sum of 1 + 1", messages.get(1).text()); + assertEquals("The result is 2", messages.get(4).text()); + + if (messages.get(2) instanceof AiMessage aiMessage) { + assertTrue(aiMessage.hasToolExecutionRequests()); + assertEquals("{\"first\":1,\"second\":1}", aiMessage.toolExecutionRequests().get(0).arguments()); + } else { + fail("The third message is not of type AiMessage"); + } + + if (messages.get(3) instanceof ToolExecutionResultMessage toolResultMessage) { + assertEquals(2, Integer.parseInt(toolResultMessage.text())); + } else { + fail("The fourth message is not of type ToolExecutionResultMessage"); + } + } + + @Test + void streaming_chat() throws Exception { + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_STREAMING_API, 200) + .body(mapper.writeValueAsString(generateRequest())) + .responseMediaType(MediaType.SERVER_SENT_EVENTS) + .response(WireMockUtil.RESPONSE_WATSONX_GENERATION_STREAMING_API) + .build(); + + var result = aiService.streaming("Hello").collect().asList().await().indefinitely(); + assertEquals(List.of(". 
", "I'", "m ", "a beginner"), result); + } + + @Test + void streaming_chat_with_tool() throws Exception { + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_STREAMING_API, 200) + .responseMediaType(MediaType.SERVER_SENT_EVENTS) + .scenario(Scenario.STARTED, "TOOL_CALL") + .response( + """ + id: 1 + event: message + data: {} + + id: 2 + event: message + data: {"modelId":"mistralai/mistral-large","results":[{"generated_text":"","generated_token_count":0,"input_token_count":2,"stop_reason":"not_finished"}]} + + id: 3 + event: message + data: {"modelId":"mistralai/mistral-large","results":[{"generated_text":"%s","generated_token_count":0,"input_token_count":2,"stop_reason":"not_finished"}]} + + id: 4 + event: close""" + .formatted(TOOL_CALL)) + .build(); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_STREAMING_API, 200) + .responseMediaType(MediaType.SERVER_SENT_EVENTS) + .scenario("TOOL_CALL", "AI_RESPONSE") + .response( + """ + id: 1 + event: message + data: {} + + id: 2 + event: message + data: {"modelId":"mistralai/mistral-large","results":[{"generated_text":"","generated_token_count":0,"input_token_count":2,"stop_reason":"not_finished"}]} + + id: 3 + event: message + data: {"modelId":"mistralai/mistral-large","results":[{"generated_text":"The result is 2","generated_token_count":0,"input_token_count":2,"stop_reason":"not_finished"}]} + + id: 4 + event: close""") + .build(); + + var result = aiServiceWithTool.streaming("streaming", "Execute the sum of 1 + 1").collect().asList().await() + .indefinitely(); + assertEquals("The result is 2", result.get(0)); + + var messages = memory.getMessages("streaming"); + assertEquals("This is a systemMessage", messages.get(0).text()); + assertEquals("Execute the sum of 1 + 1", messages.get(1).text()); + assertEquals("The result is 2", messages.get(4).text()); + + if (messages.get(2) instanceof AiMessage aiMessage) { + assertTrue(aiMessage.hasToolExecutionRequests()); + 
assertEquals("{\"first\":1,\"second\":1}", aiMessage.toolExecutionRequests().get(0).arguments()); + } else { + fail("The third message is not of type AiMessage"); + } + + if (messages.get(3) instanceof ToolExecutionResultMessage toolResultMessage) { + assertEquals(2, Integer.parseInt(toolResultMessage.text())); + } else { + fail("The fourth message is not of type ToolExecutionResultMessage"); + } + } + + private TextGenerationRequest generateRequest() { + LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig(); + ChatModelConfig chatModelConfig = watsonConfig.chatModel(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); + String projectId = watsonConfig.projectId(); + String input = new StringBuilder() + .append("[INST] This is a systemMessage [/INST]") + .append("[INST] This is a userMessage Hello [/INST]") + .toString(); + TextGenerationParameters parameters = TextGenerationParameters.builder() + .decodingMethod(chatModelConfig.decodingMethod()) + .temperature(chatModelConfig.temperature()) + .minNewTokens(chatModelConfig.minNewTokens()) + .maxNewTokens(chatModelConfig.maxNewTokens()) + .timeLimit(WireMockUtil.DEFAULT_TIME_LIMIT) + .build(); + + return new TextGenerationRequest(modelId, projectId, input, parameters); + } +} diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/CacheTokenTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/CacheTokenTest.java index df97d042a..52b2f0458 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/CacheTokenTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/CacheTokenTest.java @@ -49,6 +49,7 @@ public class CacheTokenTest extends WireMockAbstract { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", 
WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "generation") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); @Inject @@ -81,14 +82,15 @@ void try_token_cache() throws InterruptedException { .build(); Stream.of( - Map.entry(WireMockUtil.URL_WATSONX_CHAT_API, WireMockUtil.RESPONSE_WATSONX_CHAT_API), + Map.entry(WireMockUtil.URL_WATSONX_GENERATION_API, WireMockUtil.RESPONSE_WATSONX_GENERATION_API), Map.entry(WireMockUtil.URL_WATSONX_EMBEDDING_API, WireMockUtil.RESPONSE_WATSONX_EMBEDDING_API), - Map.entry(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, WireMockUtil.RESPONSE_WATSONX_STREAMING_API), + Map.entry(WireMockUtil.URL_WATSONX_GENERATION_STREAMING_API, + WireMockUtil.RESPONSE_WATSONX_GENERATION_STREAMING_API), Map.entry(WireMockUtil.URL_WATSONX_TOKENIZER_API, WireMockUtil.RESPONSE_WATSONX_TOKENIZER_API)) .forEach(entry -> { mockServers.mockWatsonxBuilder(entry.getKey(), 200) .token("3secondstoken") - .responseMediaType(entry.getKey().equals(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API) + .responseMediaType(entry.getKey().equals(WireMockUtil.URL_WATSONX_GENERATION_STREAMING_API) ? 
MediaType.SERVER_SENT_EVENTS : MediaType.APPLICATION_JSON) .response(entry.getValue()) @@ -127,9 +129,10 @@ void try_token_retry() throws InterruptedException { .build(); Stream.of( - Map.entry(WireMockUtil.URL_WATSONX_CHAT_API, WireMockUtil.RESPONSE_WATSONX_CHAT_API), + Map.entry(WireMockUtil.URL_WATSONX_GENERATION_API, WireMockUtil.RESPONSE_WATSONX_GENERATION_API), Map.entry(WireMockUtil.URL_WATSONX_EMBEDDING_API, WireMockUtil.RESPONSE_WATSONX_EMBEDDING_API), - Map.entry(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, WireMockUtil.RESPONSE_WATSONX_STREAMING_API), + Map.entry(WireMockUtil.URL_WATSONX_GENERATION_STREAMING_API, + WireMockUtil.RESPONSE_WATSONX_GENERATION_STREAMING_API), Map.entry(WireMockUtil.URL_WATSONX_TOKENIZER_API, WireMockUtil.RESPONSE_WATSONX_TOKENIZER_API)) .forEach(entry -> { mockServers.mockWatsonxBuilder(entry.getKey(), 401) @@ -141,7 +144,7 @@ void try_token_retry() throws InterruptedException { mockServers.mockWatsonxBuilder(entry.getKey(), 200) .token("my_super_token") .scenario("retry", Scenario.STARTED) - .responseMediaType(entry.getKey().equals(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API) + .responseMediaType(entry.getKey().equals(WireMockUtil.URL_WATSONX_GENERATION_STREAMING_API) ? 
MediaType.SERVER_SENT_EVENTS : MediaType.APPLICATION_JSON) .response(entry.getValue()) diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatAllPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatAllPropertiesTest.java new file mode 100644 index 000000000..7b2b740da --- /dev/null +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatAllPropertiesTest.java @@ -0,0 +1,167 @@ +package io.quarkiverse.langchain4j.watsonx.deployment; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.time.Duration; +import java.util.Date; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import jakarta.inject.Inject; +import jakarta.ws.rs.core.MediaType; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.TokenCountEstimator; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageSystem; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageUser; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatParameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatRequest; +import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; +import io.quarkus.test.QuarkusUnitTest; + +public class ChatAllPropertiesTest extends WireMockAbstract { + + @RegisterExtension + static QuarkusUnitTest 
unitTest = new QuarkusUnitTest() + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.base-url", WireMockUtil.URL_WATSONX_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.timeout", "60s") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.log-requests", "true") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.log-responses", "true") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.version", "aaaa-mm-dd") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.timeout", "60s") + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "chat") + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.model-id", "my_super_model") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.max-new-tokens", "200") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.temperature", "1.5") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.top-p", "0.5") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.response-format", "new_format") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.model-id", "my_super_embedding_model") + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); + + @Override + void handlerBeforeEach() { + mockServers.mockIAMBuilder(200) + .grantType(langchain4jWatsonConfig.defaultConfig().iam().grantType()) + .response(WireMockUtil.BEARER_TOKEN, new Date()) + .build(); + } + + @Inject + ChatLanguageModel chatModel; + + @Inject + StreamingChatLanguageModel streamingChatModel; + + @Inject + TokenCountEstimator tokenCountEstimator; + + static TextChatParameters parameters = TextChatParameters.builder() + 
.maxTokens(200) + .temperature(1.5) + .timeLimit(60000L) + .topP(0.5) + .responseFormat("new_format") + .build(); + + @Test + void check_config() throws Exception { + var runtimeConfig = langchain4jWatsonConfig.defaultConfig(); + var fixedRuntimeConfig = langchain4jWatsonFixedRuntimeConfig.defaultConfig(); + assertEquals(WireMockUtil.URL_WATSONX_SERVER, runtimeConfig.baseUrl().toString()); + assertEquals(WireMockUtil.URL_IAM_SERVER, runtimeConfig.iam().baseUrl().toString()); + assertEquals(WireMockUtil.API_KEY, runtimeConfig.apiKey()); + assertEquals(WireMockUtil.PROJECT_ID, runtimeConfig.projectId()); + assertEquals(Duration.ofSeconds(60), runtimeConfig.timeout().get()); + assertEquals(Duration.ofSeconds(60), runtimeConfig.iam().timeout().get()); + assertEquals(true, runtimeConfig.logRequests().orElse(false)); + assertEquals(true, runtimeConfig.logResponses().orElse(false)); + assertEquals("aaaa-mm-dd", runtimeConfig.version()); + assertEquals("my_super_model", fixedRuntimeConfig.chatModel().modelId()); + assertEquals(200, runtimeConfig.chatModel().maxNewTokens()); + assertEquals(0.5, runtimeConfig.chatModel().topP().get()); + assertEquals("new_format", runtimeConfig.chatModel().responseFormat().orElse(null)); + } + + @Test + void check_chat_model_config() throws Exception { + var config = langchain4jWatsonConfig.defaultConfig(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); + String projectId = config.projectId(); + + var messages = List. 
of( + TextChatMessageSystem.of("SystemMessage"), + TextChatMessageUser.of("UserMessage")); + + TextChatRequest body = new TextChatRequest(modelId, projectId, messages, null, parameters); + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200, "aaaa-mm-dd") + .body(mapper.writeValueAsString(body)) + .response(WireMockUtil.RESPONSE_WATSONX_CHAT_API) + .build(); + + assertEquals("AI Response", chatModel.generate(dev.langchain4j.data.message.SystemMessage.from("SystemMessage"), + dev.langchain4j.data.message.UserMessage.from("UserMessage")).content().text()); + } + + @Test + void check_token_count_estimator() throws Exception { + var config = langchain4jWatsonConfig.defaultConfig(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); + String projectId = config.projectId(); + + var body = new TokenizationRequest(modelId, "test", projectId); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_TOKENIZER_API, 200, "aaaa-mm-dd") + .body(mapper.writeValueAsString(body)) + .response(WireMockUtil.RESPONSE_WATSONX_TOKENIZER_API.formatted(modelId)) + .build(); + + assertEquals(11, tokenCountEstimator.estimateTokenCount("test")); + } + + @Test + void check_chat_streaming_model_config() throws Exception { + var config = langchain4jWatsonConfig.defaultConfig(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); + String projectId = config.projectId(); + + var messagesToSend = List. 
of( + TextChatMessageSystem.of("SystemMessage"), + TextChatMessageUser.of("UserMessage")); + + TextChatRequest body = new TextChatRequest(modelId, projectId, messagesToSend, null, parameters); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, 200, "aaaa-mm-dd") + .body(mapper.writeValueAsString(body)) + .responseMediaType(MediaType.SERVER_SENT_EVENTS) + .response(WireMockUtil.RESPONSE_WATSONX_CHAT_STREAMING_API) + .build(); + + var messages = List.of( + dev.langchain4j.data.message.SystemMessage.from("SystemMessage"), + dev.langchain4j.data.message.UserMessage.from("UserMessage")); + + var streamingResponse = new AtomicReference(); + streamingChatModel.generate(messages, WireMockUtil.streamingResponseHandler(streamingResponse)); + + await().atMost(Duration.ofMinutes(1)) + .pollInterval(Duration.ofSeconds(2)) + .until(() -> streamingResponse.get() != null); + + assertThat(streamingResponse.get().text()) + .isNotNull() + .isEqualTo(" Hello"); + } +} diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatDefaultPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatDefaultPropertiesTest.java new file mode 100644 index 000000000..7f7853e87 --- /dev/null +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatDefaultPropertiesTest.java @@ -0,0 +1,154 @@ +package io.quarkiverse.langchain4j.watsonx.deployment; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.time.Duration; +import java.util.Date; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +import jakarta.inject.Inject; +import jakarta.ws.rs.core.MediaType; + +import 
org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.TokenCountEstimator; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageSystem; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageUser; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatParameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatRequest; +import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; +import io.quarkus.test.QuarkusUnitTest; + +public class ChatDefaultPropertiesTest extends WireMockAbstract { + + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.base-url", WireMockUtil.URL_WATSONX_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "chat") + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); + + @Override + void handlerBeforeEach() { + mockServers.mockIAMBuilder(200) + .response("my_super_token", new Date()) + .build(); + } + + static TextChatParameters parameters = TextChatParameters.builder() + .maxTokens(200) + .temperature(1.0) + .timeLimit(WireMockUtil.DEFAULT_TIME_LIMIT) + .build(); + + @Inject + ChatLanguageModel chatModel; + + @Inject + StreamingChatLanguageModel 
streamingChatModel; + + @Inject + TokenCountEstimator tokenCountEstimator; + + @Test + void check_config() throws Exception { + var runtimeConfig = langchain4jWatsonConfig.defaultConfig(); + var fixedRuntimeConfig = langchain4jWatsonFixedRuntimeConfig.defaultConfig(); + assertEquals(Optional.empty(), runtimeConfig.timeout()); + assertEquals(Optional.empty(), runtimeConfig.iam().timeout()); + assertEquals(false, runtimeConfig.logRequests().orElse(false)); + assertEquals(false, runtimeConfig.logResponses().orElse(false)); + assertEquals(WireMockUtil.VERSION, runtimeConfig.version()); + assertEquals(WireMockUtil.DEFAULT_CHAT_MODEL, fixedRuntimeConfig.chatModel().modelId()); + assertEquals(200, runtimeConfig.chatModel().maxNewTokens()); + assertEquals(1.0, runtimeConfig.chatModel().temperature()); + assertTrue(runtimeConfig.chatModel().topP().isEmpty()); + assertTrue(runtimeConfig.chatModel().responseFormat().isEmpty()); + assertEquals("urn:ibm:params:oauth:grant-type:apikey", runtimeConfig.iam().grantType()); + } + + @Test + void check_chat_model_config() throws Exception { + var config = langchain4jWatsonConfig.defaultConfig(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); + String projectId = config.projectId(); + + var messages = List. 
of( + TextChatMessageSystem.of("SystemMessage"), + TextChatMessageUser.of("UserMessage")); + + TextChatRequest body = new TextChatRequest(modelId, projectId, messages, null, parameters); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + .body(mapper.writeValueAsString(body)) + .response(WireMockUtil.RESPONSE_WATSONX_CHAT_API) + .build(); + + assertEquals("AI Response", chatModel.generate(dev.langchain4j.data.message.SystemMessage.from("SystemMessage"), + dev.langchain4j.data.message.UserMessage.from("UserMessage")).content().text()); + } + + @Test + void check_token_count_estimator() throws Exception { + var config = langchain4jWatsonConfig.defaultConfig(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); + String projectId = config.projectId(); + + var body = new TokenizationRequest(modelId, "test", projectId); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_TOKENIZER_API, 200) + .body(mapper.writeValueAsString(body)) + .response(WireMockUtil.RESPONSE_WATSONX_TOKENIZER_API.formatted(modelId)) + .build(); + + assertEquals(11, tokenCountEstimator.estimateTokenCount("test")); + } + + @Test + void check_chat_streaming_model_config() throws Exception { + var config = langchain4jWatsonConfig.defaultConfig(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); + String projectId = config.projectId(); + + var messagesToSend = List. 
of( + TextChatMessageSystem.of("SystemMessage"), + TextChatMessageUser.of("UserMessage")); + + TextChatRequest body = new TextChatRequest(modelId, projectId, messagesToSend, null, parameters); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, 200) + .body(mapper.writeValueAsString(body)) + .responseMediaType(MediaType.SERVER_SENT_EVENTS) + .response(WireMockUtil.RESPONSE_WATSONX_CHAT_STREAMING_API) + .build(); + + var messages = List.of( + dev.langchain4j.data.message.SystemMessage.from("SystemMessage"), + dev.langchain4j.data.message.UserMessage.from("UserMessage")); + + var streamingResponse = new AtomicReference(); + streamingChatModel.generate(messages, WireMockUtil.streamingResponseHandler(streamingResponse)); + + await().atMost(Duration.ofMinutes(1)) + .pollInterval(Duration.ofSeconds(2)) + .until(() -> streamingResponse.get() != null); + + assertThat(streamingResponse.get().text()) + .isNotNull() + .isEqualTo(" Hello"); + } +} diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatMemoryPlaceholderTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatMemoryPlaceholderTest.java index b30b36571..6a0ffecad 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatMemoryPlaceholderTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatMemoryPlaceholderTest.java @@ -17,7 +17,7 @@ import dev.langchain4j.service.UserMessage; import dev.langchain4j.store.memory.chat.ChatMemoryStore; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; 
import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; @@ -31,6 +31,7 @@ public class ChatMemoryPlaceholderTest extends WireMockAbstract { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "generation") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); @Override @@ -105,7 +106,7 @@ void extract_dialogue_test() throws Exception { Hello"""; - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(createRequest(input))) .response(""" { @@ -131,7 +132,7 @@ void extract_dialogue_test() throws Exception { Hi! What is your name?"""; - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(createRequest(input))) .response(""" { @@ -160,7 +161,7 @@ void extract_dialogue_with_delimiter_test() throws Exception { Hello"""; - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(createRequest(input))) .response(""" { @@ -185,7 +186,7 @@ void extract_dialogue_with_delimiter_test() throws Exception { Hi! 
What is your name?"""; - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(createRequest(input))) .response(""" { @@ -214,7 +215,7 @@ void extract_dialogue_with_all_params_test() throws Exception { Hello"""; - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(createRequest(input))) .response(""" { @@ -239,7 +240,7 @@ void extract_dialogue_with_all_params_test() throws Exception { Hi! What is your name?"""; - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(createRequest(input))) .response(""" { @@ -270,7 +271,7 @@ void extract_dialogue_no_memory_test() throws Exception { Assistant: My name is AiBot Hello"""; - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(createRequest(input))) .response(""" { @@ -293,12 +294,12 @@ private TextGenerationRequest createRequest(String input) { ChatModelConfig chatModelConfig = watsonConfig.chatModel(); String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String projectId = watsonConfig.projectId(); - Parameters parameters = Parameters.builder() + TextGenerationParameters parameters = TextGenerationParameters.builder() .decodingMethod(chatModelConfig.decodingMethod()) .temperature(chatModelConfig.temperature()) .minNewTokens(chatModelConfig.minNewTokens()) .maxNewTokens(chatModelConfig.maxNewTokens()) - .timeLimit(10000L) + .timeLimit(WireMockUtil.DEFAULT_TIME_LIMIT) .build(); return new TextGenerationRequest(modelId, projectId, input, parameters); diff --git 
a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java similarity index 94% rename from model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java rename to model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java index 87c867c5f..5c38a1e7a 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java @@ -26,13 +26,13 @@ import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.output.Response; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters.LengthPenalty; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters.LengthPenalty; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; import io.quarkus.test.QuarkusUnitTest; -public class AllPropertiesTest extends WireMockAbstract { +public class GenerationAllPropertiesTest extends WireMockAbstract { @RegisterExtension static QuarkusUnitTest unitTest = new QuarkusUnitTest() @@ -46,6 +46,7 @@ public class AllPropertiesTest extends WireMockAbstract { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.timeout", "60s") 
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.grant-type", "grantME") + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "generation") .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.model-id", "my_super_model") .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-formatter", "true") .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-joiner", "@") @@ -85,7 +86,7 @@ void handlerBeforeEach() { @Inject TokenCountEstimator tokenCountEstimator; - static Parameters parameters = Parameters.builder() + static TextGenerationParameters parameters = TextGenerationParameters.builder() .minNewTokens(10) .maxNewTokens(200) .decodingMethod("greedy") @@ -141,9 +142,9 @@ void check_chat_model_config() throws Exception { String projectId = config.projectId(); TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessage@UserMessage", parameters); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200, "aaaa-mm-dd") + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200, "aaaa-mm-dd") .body(mapper.writeValueAsString(body)) - .response(WireMockUtil.RESPONSE_WATSONX_CHAT_API) + .response(WireMockUtil.RESPONSE_WATSONX_GENERATION_API) .build(); assertEquals("AI Response", chatModel.generate(dev.langchain4j.data.message.SystemMessage.from("SystemMessage"), @@ -193,10 +194,10 @@ void check_chat_streaming_model_config() throws Exception { TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessage@UserMessage", parameters); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, 200, "aaaa-mm-dd") + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_STREAMING_API, 200, "aaaa-mm-dd") .body(mapper.writeValueAsString(body)) .responseMediaType(MediaType.SERVER_SENT_EVENTS) - .response(WireMockUtil.RESPONSE_WATSONX_STREAMING_API) + 
.response(WireMockUtil.RESPONSE_WATSONX_GENERATION_STREAMING_API) .build(); var messages = List.of( diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationDefaultPropertiesTest.java similarity index 93% rename from model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java rename to model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationDefaultPropertiesTest.java index 4638b2bfa..86742fc71 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationDefaultPropertiesTest.java @@ -28,12 +28,12 @@ import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.output.Response; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; import io.quarkus.test.QuarkusUnitTest; -public class DefaultPropertiesTest extends WireMockAbstract { +public class GenerationDefaultPropertiesTest extends WireMockAbstract { @RegisterExtension static QuarkusUnitTest unitTest = new QuarkusUnitTest() @@ -41,6 +41,7 @@ public class DefaultPropertiesTest extends WireMockAbstract { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) 
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "generation") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); @Override @@ -50,12 +51,12 @@ void handlerBeforeEach() { .build(); } - static Parameters parameters = Parameters.builder() + static TextGenerationParameters parameters = TextGenerationParameters.builder() .minNewTokens(0) .maxNewTokens(200) .decodingMethod("greedy") .temperature(1.0) - .timeLimit(10000L) + .timeLimit(WireMockUtil.DEFAULT_TIME_LIMIT) .build(); @Inject @@ -107,9 +108,9 @@ void check_chat_model_config() throws Exception { TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessage\nUserMessage", parameters); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(body)) - .response(WireMockUtil.RESPONSE_WATSONX_CHAT_API) + .response(WireMockUtil.RESPONSE_WATSONX_GENERATION_API) .build(); assertEquals("AI Response", chatModel.generate(dev.langchain4j.data.message.SystemMessage.from("SystemMessage"), @@ -159,10 +160,10 @@ void check_chat_streaming_model_config() throws Exception { TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessage\nUserMessage", parameters); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_STREAMING_API, 200) .body(mapper.writeValueAsString(body)) .responseMediaType(MediaType.SERVER_SENT_EVENTS) - .response(WireMockUtil.RESPONSE_WATSONX_STREAMING_API) + .response(WireMockUtil.RESPONSE_WATSONX_GENERATION_STREAMING_API) .build(); var messages = List.of( diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/HttpErrorTest.java 
b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/HttpErrorTest.java index 9e63c9867..35685ed81 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/HttpErrorTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/HttpErrorTest.java @@ -19,7 +19,6 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.model.chat.ChatLanguageModel; -import io.quarkiverse.langchain4j.watsonx.bean.WatsonxError; import io.quarkiverse.langchain4j.watsonx.exception.WatsonxException; import io.quarkus.test.QuarkusUnitTest; @@ -31,39 +30,137 @@ public class HttpErrorTest extends WireMockAbstract { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "generation") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); @Inject ChatLanguageModel chatModel; @Test - void error_404_model_not_supported() { + void not_registered_error() { mockServers.mockIAMBuilder(200) .response(WireMockUtil.BEARER_TOKEN, new Date()) .build(); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 404) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 500) .responseMediaType(MediaType.APPLICATION_JSON) .response(""" { "errors": [ { - "code": "model_not_supported", - "message": "Model 'meta-llama/llama-2-70b-chats' is not supported" + "code": "xxx", + "message": "yyyy" } ], "trace": "xxx", + "status_code": 500 + } + """) + .build(); + + WatsonxException ex = assertThrowsExactly(WatsonxException.class, () -> chatModel.generate("message")); + assertEquals(500, 
ex.details().statusCode()); + assertNotNull(ex.details().errors()); + assertEquals(1, ex.details().errors().size()); + assertEquals("xxx", ex.details().errors().get(0).code()); + assertEquals("yyyy", ex.details().errors().get(0).message()); + } + + @Test + void error_400_model_no_support_for_function() { + + mockServers.mockIAMBuilder(200) + .response(WireMockUtil.BEARER_TOKEN, new Date()) + .build(); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 404) + .responseMediaType(MediaType.APPLICATION_JSON) + .response(""" + { + "errors": [ + { + "code": "model_no_support_for_function", + "message": "Model 'ibm/granite-7b-lab' does not support function 'function_text_chat'", + "more_info": "https://cloud.ibm.com/apidocs/watsonx-ai" + } + ], + "trace": "xxx", + "status_code": 400 + } + """) + .build(); + + WatsonxException ex = assertThrowsExactly(WatsonxException.class, () -> chatModel.generate("message")); + assertEquals(400, ex.details().statusCode()); + assertNotNull(ex.details().errors()); + assertEquals(1, ex.details().errors().size()); + assertEquals("model_no_support_for_function", ex.details().errors().get(0).code()); + } + + @Test + void error_400_json_type_error() { + + mockServers.mockIAMBuilder(200) + .response(WireMockUtil.BEARER_TOKEN, new Date()) + .build(); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 400) + .response( + """ + { + "errors": [ + { + "code": "json_type_error", + "message": "Json field type error: response_format must be of type schemas.TextChatPropertyResponseFormat", + "more_info": "https://cloud.ibm.com/apidocs/watsonx-ai" + } + ], + "trace": "xxx", + "status_code": 400 + } + """) + .build(); + + WatsonxException ex = assertThrowsExactly(WatsonxException.class, () -> chatModel.generate("message")); + assertNotNull(ex.details()); + assertNotNull(ex.details().trace()); + assertEquals(400, ex.details().statusCode()); + assertNotNull(ex.details().errors()); + assertEquals(1, 
ex.details().errors().size()); + assertEquals("json_type_error", ex.details().errors().get(0).code()); + } + + @Test + void error_404_model_not_supported() { + + mockServers.mockIAMBuilder(200) + .response(WireMockUtil.BEARER_TOKEN, new Date()) + .build(); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 400) + .response(""" + { + "errors": [ + { + "code": "model_not_supported", + "message": "Model 'meta-llama/llama-3-1-70b-instructs' is not supported", + "more_info": "https://cloud.ibm.com/apidocs/watsonx-ai" + } + ], + "trace": "91c784e9f44da953ebafc25933809817", "status_code": 404 } """) .build(); WatsonxException ex = assertThrowsExactly(WatsonxException.class, () -> chatModel.generate("message")); + assertNotNull(ex.details()); + assertNotNull(ex.details().trace()); assertEquals(404, ex.details().statusCode()); assertNotNull(ex.details().errors()); assertEquals(1, ex.details().errors().size()); - assertEquals(WatsonxError.Code.MODEL_NOT_SUPPORTED, ex.details().errors().get(0).code()); + assertEquals("model_not_supported", ex.details().errors().get(0).code()); } @Test @@ -73,7 +170,7 @@ void error_400_json_validation_error() { .response(WireMockUtil.BEARER_TOKEN, new Date()) .build(); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 400) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 400) .response(""" { "errors": [ @@ -94,7 +191,7 @@ void error_400_json_validation_error() { assertEquals(400, ex.details().statusCode()); assertNotNull(ex.details().errors()); assertEquals(1, ex.details().errors().size()); - assertEquals(WatsonxError.Code.JSON_VALIDATION_ERROR, ex.details().errors().get(0).code()); + assertEquals("json_validation_error", ex.details().errors().get(0).code()); } @Test @@ -104,7 +201,7 @@ void error_400_invalid_request_entity() { .response(WireMockUtil.BEARER_TOKEN, new Date()) .build(); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 400) + 
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 400) .response(""" { "errors": [ @@ -125,7 +222,7 @@ void error_400_invalid_request_entity() { assertEquals(400, ex.details().statusCode()); assertNotNull(ex.details().errors()); assertEquals(1, ex.details().errors().size()); - assertEquals(WatsonxError.Code.INVALID_REQUEST_ENTITY, ex.details().errors().get(0).code()); + assertEquals("invalid_request_entity", ex.details().errors().get(0).code()); assertEquals("Missing either space_id or project_id or wml_instance_crn", ex.details().errors().get(0).message()); } @@ -136,7 +233,7 @@ void error_500() { .response(WireMockUtil.BEARER_TOKEN, new Date()) .build(); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 500) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 500) .response("{") .build(); diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterExceptionTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterExceptionTest.java index 314f9bc42..4efcba30e 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterExceptionTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterExceptionTest.java @@ -39,14 +39,14 @@ class ToolsModelNotSupported { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.model-id", - WireMockUtil.DEFAULT_CHAT_MODEL) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.model-id", "ibm/granite-7b-lab") + 
.overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "generation") .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-formatter", "true") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(AIService.class, Calculator.class)) .assertException(t -> { assertThat(t).isInstanceOf(RuntimeException.class) .hasMessage("The tool functionality is not supported for the model \"%s\"" - .formatted(WireMockUtil.DEFAULT_CHAT_MODEL)); + .formatted("ibm/granite-7b-lab")); }); @Test @@ -63,6 +63,7 @@ class ToolsPromptFormatterOff { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "generation") .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-formatter", "false") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(AIService.class, Calculator.class)) .assertException(t -> { diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterTest.java index dbdd4de9e..e1662c3ea 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/PromptFormatterTest.java @@ -17,7 +17,7 @@ import dev.langchain4j.service.UserMessage; import dev.langchain4j.service.V; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters; import 
io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; @@ -33,6 +33,7 @@ public class PromptFormatterTest extends WireMockAbstract { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.model-id", "mistralai/mistral-large") .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-formatter", "true") + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "generation") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) .addAsResource("messages/system.txt") .addAsResource("messages/user.txt") @@ -122,20 +123,20 @@ void tests() throws Exception { String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String projectId = watsonConfig.projectId(); - Parameters parameters = Parameters.builder() + TextGenerationParameters parameters = TextGenerationParameters.builder() .decodingMethod(chatModelConfig.decodingMethod()) .temperature(chatModelConfig.temperature()) .minNewTokens(chatModelConfig.minNewTokens()) .maxNewTokens(chatModelConfig.maxNewTokens()) - .timeLimit(10000L) + .timeLimit(WireMockUtil.DEFAULT_TIME_LIMIT) .build(); TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "[INST] You are a poet [/INST][INST] Generate a poem about dog [/INST]", parameters); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(body)) - .response(WireMockUtil.RESPONSE_WATSONX_CHAT_API) + .response(WireMockUtil.RESPONSE_WATSONX_GENERATION_API) .build(); assertEquals("AI Response", aiService.poem("dog")); diff --git 
a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOnTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOnTest.java index e0a890912..c250e80e2 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOnTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOnTest.java @@ -20,7 +20,7 @@ import dev.langchain4j.service.UserMessage; import dev.langchain4j.service.V; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkus.test.QuarkusUnitTest; @@ -32,6 +32,7 @@ public class ResponseSchemaOnTest extends WireMockAbstract { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "generation") .overrideConfigKey("quarkus.langchain4j.response-schema", "true") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); @@ -159,7 +160,7 @@ void no_bean_ai_service() throws Exception { dev.langchain4j.data.message.SystemMessage.from("You are a poet"), dev.langchain4j.data.message.UserMessage.from("Generate a poem about dog")); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(from(messages))) .response(RESPONSE) .build(); @@ -174,7 +175,7 @@ void 
bean_ai_service() throws Exception { dev.langchain4j.data.message.SystemMessage.from("You are a poet"), dev.langchain4j.data.message.UserMessage.from("Generate a poem about dog".concat(SCHEMA))); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(from(messages))) .response(RESPONSE) .build(); @@ -189,7 +190,7 @@ void schema_ai_service() throws Exception { dev.langchain4j.data.message.SystemMessage.from(SCHEMA.concat(" You are a poet")), dev.langchain4j.data.message.UserMessage.from("user", "Generate a poem about dog ".concat(SCHEMA))); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(from(messages))) .response(RESPONSE) .build(); @@ -204,7 +205,7 @@ void schema_system_message_ai_service() throws Exception { dev.langchain4j.data.message.SystemMessage.from(SCHEMA.concat(" You are a poet")), dev.langchain4j.data.message.UserMessage.from("Generate a poem about dog")); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(from(messages))) .response(RESPONSE) .build(); @@ -218,7 +219,7 @@ void on_method_ai_service() throws Exception { List messages = List.of( dev.langchain4j.data.message.UserMessage.from(SCHEMA.concat(" Generate a poem about dog"))); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(from(messages))) .response(RESPONSE) .build(); @@ -234,7 +235,7 @@ void structured_prompt_ai_service() throws Exception { List messages = List.of( dev.langchain4j.data.message.UserMessage.from(SCHEMA.concat("Generate a poem about dog"))); - 
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(from(messages))) .response(RESPONSE) .build(); @@ -249,7 +250,7 @@ void system_message_on_class_ai_service() throws Exception { dev.langchain4j.data.message.SystemMessage.from(SCHEMA.concat(" You are a poet")), dev.langchain4j.data.message.UserMessage.from("Generate a poem about dog")); - mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) .body(mapper.writeValueAsString(from(messages))) .response(RESPONSE) .build(); @@ -261,12 +262,12 @@ private TextGenerationRequest from(List messages) { var modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); var config = langchain4jWatsonConfig.defaultConfig(); - var parameters = Parameters.builder() + var parameters = TextGenerationParameters.builder() .decodingMethod(config.chatModel().decodingMethod()) .temperature(config.chatModel().temperature()) .minNewTokens(config.chatModel().minNewTokens()) .maxNewTokens(config.chatModel().maxNewTokens()) - .timeLimit(10000L) + .timeLimit(WireMockUtil.DEFAULT_TIME_LIMIT) .build(); var input = messages.stream() diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/TokenCountEstimatorTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/TokenCountEstimatorTest.java index 6aef84ce1..139f5f614 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/TokenCountEstimatorTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/TokenCountEstimatorTest.java @@ -15,11 +15,9 @@ import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; import 
dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.input.Prompt; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; -import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; import io.quarkus.test.QuarkusUnitTest; public class TokenCountEstimatorTest extends WireMockAbstract { @@ -30,6 +28,7 @@ public class TokenCountEstimatorTest extends WireMockAbstract { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.mode", "generation") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); @Override @@ -39,15 +38,9 @@ void handlerBeforeEach() { .build(); } - @Inject - ChatLanguageModel model; - @Inject TokenCountEstimator tokenization; - @Inject - LangChain4jWatsonxConfig langchain4jWatsonxConfig; - @Test void token_count_estimator_text() throws Exception { var input = mockServer(); @@ -77,7 +70,7 @@ void token_count_estimator_list() throws Exception { var modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); var input = "Write a tagline for an alumni\nassociation: Together we"; - var projectId = langchain4jWatsonxConfig.defaultConfig().projectId(); + var projectId = langchain4jWatsonConfig.defaultConfig().projectId(); var body = new TokenizationRequest(modelId, input, projectId); mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_TOKENIZER_API, 200) @@ -93,7 +86,7 @@ private String mockServer() throws Exception { var modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); var input = "Write a tagline for an alumni 
association: Together we"; - var projectId = langchain4jWatsonxConfig.defaultConfig().projectId(); + var projectId = langchain4jWatsonConfig.defaultConfig().projectId(); var body = new TokenizationRequest(modelId, input, projectId); mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_TOKENIZER_API, 200) diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/WireMockUtil.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/WireMockUtil.java index 3d17235d6..3eb73f06d 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/WireMockUtil.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/WireMockUtil.java @@ -22,10 +22,14 @@ public class WireMockUtil { + public static final Long DEFAULT_TIME_LIMIT = 10000l; + public static final int PORT_WATSONX_SERVER = 8089; public static final String URL_WATSONX_SERVER = "http://localhost:8089"; - public static final String URL_WATSONX_CHAT_API = "/ml/v1/text/generation?version=%s"; - public static final String URL_WATSONX_CHAT_STREAMING_API = "/ml/v1/text/generation_stream?version=%s"; + public static final String URL_WATSONX_CHAT_API = "/ml/v1/text/chat?version=%s"; + public static final String URL_WATSONX_CHAT_STREAMING_API = "/ml/v1/text/chat_stream?version=%s"; + public static final String URL_WATSONX_GENERATION_API = "/ml/v1/text/generation?version=%s"; + public static final String URL_WATSONX_GENERATION_STREAMING_API = "/ml/v1/text/generation_stream?version=%s"; public static final String URL_WATSONX_EMBEDDING_API = "/ml/v1/text/embeddings?version=%s"; public static final String URL_WATSONX_TOKENIZER_API = "/ml/v1/text/tokenization?version=%s"; @@ -38,7 +42,7 @@ public class WireMockUtil { public static final String PROJECT_ID = "123123321321"; public static final String GRANT_TYPE = 
"urn:ibm:params:oauth:grant-type:apikey"; public static final String VERSION = "2024-03-14"; - public static final String DEFAULT_CHAT_MODEL = "ibm/granite-13b-chat-v2"; + public static final String DEFAULT_CHAT_MODEL = "mistralai/mistral-large"; public static final String DEFAULT_EMBEDDING_MODEL = "ibm/slate-125m-english-rtrvr"; public static final String IAM_200_RESPONSE = """ { @@ -50,9 +54,9 @@ public class WireMockUtil { "scope": "ibm openid" } """; - public static String RESPONSE_WATSONX_CHAT_API = """ + public static String RESPONSE_WATSONX_GENERATION_API = """ { - "model_id": "ibm/granite-13b-chat-v2", + "model_id": "mistralai/mistral-large", "created_at": "2024-01-21T17:06:14.052Z", "results": [ { @@ -65,6 +69,26 @@ public class WireMockUtil { ] } """; + public static String RESPONSE_WATSONX_CHAT_API = """ + { + "id": "cmpl-15475d0dea9b4429a55843c77997f8a9", + "model_id": "mistralai/mistral-large", + "created": 1689958352, + "created_at": "2023-07-21T16:52:32.190Z", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "AI Response" + }, + "finish_reason": "stop" + }], + "usage": { + "completion_tokens": 47, + "prompt_tokens": 59, + "total_tokens": 106 + } + }"""; public static String RESPONSE_WATSONX_EMBEDDING_API = """ { @@ -82,30 +106,51 @@ public class WireMockUtil { "input_token_count": 10 } """; - public static String RESPONSE_WATSONX_STREAMING_API = """ + public static String RESPONSE_WATSONX_CHAT_STREAMING_API = """ + id: 1 + event: message + data: {"id":"chat-049e3ff7ff08416fb5c334d05af059da","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"role":"assistant"}}],"created":1728810714,"model_version":"2.0.0","created_at":"2024-10-13T09:11:55.072Z","usage":{"prompt_tokens":88,"total_tokens":88},"system":{"warnings":[{"message":"This model is a Non-IBM Product governed by a third-party license that may impose use restrictions and other obligations. 
By using this model you agree to its terms as identified in the following URL.","id":"disclaimer_warning","more_info":"https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx"},{"message":"The value of 'time_limit' for this model must be larger than 0 and not larger than 10m0s; it was set to 10m0s","id":"time_limit_out_of_range","additional_properties":{"limit":600000,"new_value":600000,"parameter":"time_limit","value":999000}}]}} + + id: 2 + event: message + data: {"id":"chat-049e3ff7ff08416fb5c334d05af059da","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":null,"delta":{"content":" He"}}],"created":1728810714,"model_version":"2.0.0","created_at":"2024-10-13T09:11:55.073Z","usage":{"completion_tokens":1,"prompt_tokens":88,"total_tokens":89}} + + id: 3 + event: message + data: {"id":"chat-049e3ff7ff08416fb5c334d05af059da","model_id":"mistralai/mistral-large","choices":[{"index":0,"finish_reason":"stop","delta":{"content":"llo"}}],"created":1728810714,"model_version":"2.0.0","created_at":"2024-10-13T09:11:55.090Z","usage":{"completion_tokens":2,"prompt_tokens":88,"total_tokens":90}} + + id: 4 + event: message + data: {"id":"chat-049e3ff7ff08416fb5c334d05af059da","model_id":"mistralai/mistral-large","choices":[],"created":1728810714,"model_version":"2.0.0","created_at":"2024-10-13T09:11:55.715Z","usage":{"completion_tokens":36,"prompt_tokens":88,"total_tokens":124}} + + id: 5 + event: close + data: {} + """; + public static String RESPONSE_WATSONX_GENERATION_STREAMING_API = """ id: 1 event: message data: {} id: 2 event: message - data: {"modelId":"ibm/granite-13b-chat-v2","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.162Z","results":[{"generated_text":"","generated_token_count":0,"input_token_count":2,"stop_reason":"not_finished"}]} + data: 
{"modelId":"mistralai/mistral-large","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.162Z","results":[{"generated_text":"","generated_token_count":0,"input_token_count":2,"stop_reason":"not_finished"}]} id: 3 event: message - data: {"model_id":"ibm/granite-13b-chat-v2","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.203Z","results":[{"generated_text":". ","generated_token_count":2,"input_token_count":0,"stop_reason":"not_finished"}]} + data: {"model_id":"mistralai/mistral-large","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.203Z","results":[{"generated_text":". ","generated_token_count":2,"input_token_count":0,"stop_reason":"not_finished"}]} id: 4 event: message - data: {"model_id":"ibm/granite-13b-chat-v2","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.223Z","results":[{"generated_text":"I'","generated_token_count":3,"input_token_count":0,"stop_reason":"not_finished"}]} + data: {"model_id":"mistralai/mistral-large","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.223Z","results":[{"generated_text":"I'","generated_token_count":3,"input_token_count":0,"stop_reason":"not_finished"}]} id: 5 event: message - data: {"model_id":"ibm/granite-13b-chat-v2","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.243Z","results":[{"generated_text":"m ","generated_token_count":4,"input_token_count":0,"stop_reason":"not_finished"}]} + data: {"model_id":"mistralai/mistral-large","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.243Z","results":[{"generated_text":"m ","generated_token_count":4,"input_token_count":0,"stop_reason":"not_finished"}]} id: 6 event: message - data: {"model_id":"ibm/granite-13b-chat-v2","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.262Z","results":[{"generated_text":"a beginner","generated_token_count":5,"input_token_count":0,"stop_reason":"max_tokens"}]} + data: 
{"model_id":"mistralai/mistral-large","model_version":"2.1.0","created_at":"2024-05-04T14:29:19.262Z","results":[{"generated_text":"a beginner","generated_token_count":5,"input_token_count":0,"stop_reason":"max_tokens"}]} id: 7 event: close diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java new file mode 100644 index 000000000..9f0411f97 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java @@ -0,0 +1,380 @@ +package io.quarkiverse.langchain4j.watsonx; + +import static io.quarkiverse.langchain4j.watsonx.WatsonxUtils.retryOn; + +import java.net.URL; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import org.jboss.resteasy.reactive.client.api.LoggingScope; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.TokenCountEstimator; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.StreamingToolFetcher; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatParameterTools; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatToolCall; 
+import io.quarkiverse.langchain4j.watsonx.bean.TextChatParameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatRequest; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatResponse; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatResponse.TextChatResultChoice; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatResponse.TextChatResultMessage; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatResponse.TextChatUsage; +import io.quarkiverse.langchain4j.watsonx.bean.TextStreamingChatResponse; +import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; +import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; +import io.quarkiverse.langchain4j.watsonx.client.filter.BearerTokenHeaderFactory; +import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; +import io.smallrye.mutiny.Context; +import io.smallrye.mutiny.infrastructure.Infrastructure; + +public class WatsonxChatModel implements ChatLanguageModel, StreamingChatLanguageModel, TokenCountEstimator { + + private static final String USAGE_CONTEXT = "USAGE"; + private static final String FINISH_REASON_CONTEXT = "FINISH_REASON"; + private static final String ROLE_CONTEXT = "ROLE"; + private static final String TOOLS_CONTEXT = "TOOLS"; + private static final String COMPLETE_MESSAGE_CONTEXT = "COMPLETE_MESSAGE"; + + private final String modelId, projectId, version; + private final WatsonxRestApi client; + private final TextChatParameters parameters; + + public WatsonxChatModel(Builder builder) { + + QuarkusRestClientBuilder restClientBuilder = QuarkusRestClientBuilder.newBuilder() + .baseUrl(builder.url) + .clientHeadersFactory(new BearerTokenHeaderFactory(builder.tokenGenerator)) + .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS) + .readTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS); + + if (builder.logRequests || builder.logResponses) { + restClientBuilder.loggingScope(LoggingScope.REQUEST_RESPONSE); + restClientBuilder.clientLogger(new 
WatsonxRestApi.WatsonClientLogger( + builder.logRequests, + builder.logResponses)); + } + + this.client = restClientBuilder.build(WatsonxRestApi.class); + this.modelId = builder.modelId; + this.projectId = builder.projectId; + this.version = builder.version; + + this.parameters = TextChatParameters.builder() + .maxTokens(builder.maxTokens) + .temperature(builder.temperature) + .topP(builder.topP) + .timeLimit(builder.timeout.toMillis()) + .responseFormat(builder.responseFormat) + .build(); + } + + @Override + public Response generate(List messages, List toolSpecifications) { + var convertedMessages = messages.stream().map(TextChatMessage::convert).toList(); + var tools = (toolSpecifications != null && toolSpecifications.size() > 0) + ? toolSpecifications.stream().map(TextChatParameterTools::of).toList() + : null; + + TextChatRequest request = new TextChatRequest(modelId, projectId, convertedMessages, tools, parameters); + + TextChatResponse response = retryOn(new Callable() { + @Override + public TextChatResponse call() throws Exception { + return client.chat(request, version); + } + }); + + TextChatResultChoice choice = response.choices().get(0); + TextChatResultMessage message = choice.message(); + TextChatUsage usage = response.usage(); + + AiMessage content; + if (message.toolCalls() != null && message.toolCalls().size() > 0) { + content = AiMessage.from(message.toolCalls().stream().map(TextChatToolCall::convert).toList()); + } else { + content = AiMessage.from(message.content().trim()); + } + + var finishReason = toFinishReason(choice.finishReason()); + var tokenUsage = new TokenUsage( + usage.promptTokens(), + usage.completionTokens(), + usage.totalTokens()); + + return Response.from(content, tokenUsage, finishReason); + } + + @Override + public void generate(List messages, List toolSpecifications, + StreamingResponseHandler handler) { + var convertedMessages = messages.stream().map(TextChatMessage::convert).toList(); + var tools = (toolSpecifications != null 
&& toolSpecifications.size() > 0) + ? toolSpecifications.stream().map(TextChatParameterTools::of).toList() + : null; + + TextChatRequest request = new TextChatRequest(modelId, projectId, convertedMessages, tools, parameters); + Context context = Context.empty(); + context.put(TOOLS_CONTEXT, new ArrayList()); + context.put(COMPLETE_MESSAGE_CONTEXT, new StringBuilder()); + + var mutiny = client.streamingChat(request, version); + if (tools != null) { + // Today Langchain4j doesn't allow to use the async operation with tools. + // One idea might be to give to the developer the possibility to use the VirtualThread. + mutiny.emitOn(Infrastructure.getDefaultWorkerPool()); + } + + mutiny.subscribe() + .with(context, + new Consumer() { + @Override + public void accept(TextStreamingChatResponse chunk) { + try { + + // Last message get the "usage" values + if (chunk.choices().size() == 0) { + context.put(USAGE_CONTEXT, chunk.usage()); + return; + } + + var message = chunk.choices().get(0); + + if (message.finishReason() != null) { + context.put(FINISH_REASON_CONTEXT, message.finishReason()); + } + + if (message.delta().role() != null) { + context.put(ROLE_CONTEXT, message.delta().role()); + } + + if (message.delta().toolCalls() != null) { + + StreamingToolFetcher toolFetcher; + + // During streaming there is only one element in the tool_calls, + // but the "index" field can be used to understand how many tools need to be executed. + var deltaTool = message.delta().toolCalls().get(0); + var index = deltaTool.index(); + + List tools = context.get(TOOLS_CONTEXT); + + // Check if there is an incomplete version of the TextChatToolCall object. + if ((index + 1) > tools.size()) { + // First occurrence of the object, create it. + toolFetcher = new StreamingToolFetcher(index); + tools.add(toolFetcher); + } else { + // Incomplete version is present, complete it. 
+ toolFetcher = tools.get(index); + } + + toolFetcher.setId(deltaTool.id()); + toolFetcher.setType(deltaTool.type()); + + if (deltaTool.function() != null) { + toolFetcher.setName(deltaTool.function().name()); + toolFetcher.appendArguments(deltaTool.function().arguments()); + } + } + + if (message.delta().content() != null) { + + StringBuilder stringBuilder = context.get(COMPLETE_MESSAGE_CONTEXT); + String token = message.delta().content(); + + if (token.isEmpty()) + return; + + stringBuilder.append(token); + handler.onNext(token); + } + + } catch (Exception e) { + handler.onError(e); + } + } + }, + new Consumer() { + @Override + public void accept(Throwable error) { + handler.onError(error); + } + }, + new Runnable() { + @Override + public void run() { + + TextStreamingChatResponse.TextChatUsage usage = context.get(USAGE_CONTEXT); + TokenUsage tokenUsage = new TokenUsage( + usage.promptTokens(), + usage.completionTokens(), + usage.totalTokens()); + + String finishReason = context.get(FINISH_REASON_CONTEXT); + FinishReason finishReasonObj = toFinishReason(finishReason); + + if (finishReason.equals("tool_calls")) { + + List tools = context.get(TOOLS_CONTEXT); + List toolExecutionRequests = tools.stream() + .map(StreamingToolFetcher::build).map(TextChatToolCall::convert).toList(); + + handler.onComplete( + Response.from(AiMessage.from(toolExecutionRequests), tokenUsage, finishReasonObj)); + + } else { + + StringBuilder message = context.get(COMPLETE_MESSAGE_CONTEXT); + handler.onComplete( + Response.from(AiMessage.from(message.toString()), tokenUsage, finishReasonObj)); + } + } + }); + } + + @Override + public int estimateTokenCount(List messages) { + var input = messages.stream().map(ChatMessage::text).collect(Collectors.joining()); + var request = new TokenizationRequest(modelId, input, projectId); + + return retryOn(new Callable() { + @Override + public Integer call() throws Exception { + return client.tokenization(request, version).result().tokenCount(); + } + 
}); + } + + @Override + public Response generate(List messages) { + return generate(messages, List.of()); + } + + @Override + public Response generate(List messages, ToolSpecification toolSpecification) { + return generate(messages, List.of(toolSpecification)); + } + + @Override + public void generate(List messages, ToolSpecification toolSpecification, + StreamingResponseHandler handler) { + generate(messages, List.of(toolSpecification), handler); + } + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + generate(messages, List.of(), handler); + } + + public static Builder builder() { + return new Builder(); + } + + private FinishReason toFinishReason(String reason) { + if (reason == null) + return FinishReason.OTHER; + + return switch (reason) { + case "length" -> FinishReason.LENGTH; + case "stop" -> FinishReason.STOP; + case "tool_calls" -> FinishReason.TOOL_EXECUTION; + case "time_limit", "cancelled", "error" -> FinishReason.OTHER; + default -> throw new IllegalArgumentException("%s not supported".formatted(reason)); + }; + } + + public static final class Builder { + + private String modelId; + private String version; + private String projectId; + private Duration timeout; + private Integer maxTokens; + private Double temperature; + private Double topP; + private String responseFormat; + private URL url; + public boolean logResponses; + public boolean logRequests; + private WatsonxTokenGenerator tokenGenerator; + + public Builder modelId(String modelId) { + this.modelId = modelId; + return this; + } + + public Builder version(String version) { + this.version = version; + return this; + } + + public Builder projectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder url(URL url) { + this.url = url; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + 
return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public Builder responseFormat(String responseFormat) { + this.responseFormat = responseFormat; + return this; + } + + public Builder tokenGenerator(WatsonxTokenGenerator tokenGenerator) { + this.tokenGenerator = tokenGenerator; + return this; + } + + public Builder logRequests(boolean logRequests) { + this.logRequests = logRequests; + return this; + } + + public Builder logResponses(boolean logResponses) { + this.logResponses = logResponses; + return this; + } + + public WatsonxChatModel build() { + return new WatsonxChatModel(this); + } + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java index 1122085b5..4561b86a0 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java @@ -24,8 +24,8 @@ import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters; -import io.quarkiverse.langchain4j.watsonx.bean.Parameters.LengthPenalty; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters.LengthPenalty; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse.Result; @@ -35,6 +35,7 @@ import 
io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Context; +import io.smallrye.mutiny.infrastructure.Infrastructure; public class WatsonxGenerationModel implements ChatLanguageModel, StreamingChatLanguageModel, TokenCountEstimator { @@ -42,7 +43,7 @@ public class WatsonxGenerationModel implements ChatLanguageModel, StreamingChatL private final String modelId, projectId, version; private final WatsonxRestApi client; - private final Parameters parameters; + private final TextGenerationParameters parameters; private final PromptFormatter promptFormatter; public WatsonxGenerationModel(Builder builder) { @@ -76,7 +77,7 @@ public WatsonxGenerationModel(Builder builder) { lengthPenalty = new LengthPenalty(builder.decayFactor, builder.startIndex); } - this.parameters = Parameters.builder() + this.parameters = TextGenerationParameters.builder() .decodingMethod(builder.decodingMethod) .lengthPenalty(lengthPenalty) .minNewTokens(builder.minNewTokens) @@ -93,35 +94,16 @@ public WatsonxGenerationModel(Builder builder) { .build(); } - @Override - public Response generate(List messages) { - TextGenerationRequest request = new TextGenerationRequest(modelId, projectId, toInput(messages), parameters); - - Result result = retryOn(new Callable() { - @Override - public TextGenerationResponse call() throws Exception { - return client.chat(request, version); - } - }).results().get(0); - - var finishReason = toFinishReason(result.stopReason()); - var content = AiMessage.from(result.generatedText()); - var tokenUsage = new TokenUsage( - result.inputTokenCount(), - result.generatedTokenCount()); - - return Response.from(content, tokenUsage, finishReason); - } - @Override public Response generate(List messages, List toolSpecifications) { TextGenerationRequest request = new TextGenerationRequest(modelId, projectId, toInput(messages, toolSpecifications), parameters); + boolean toolsEnabled = 
(toolSpecifications != null && toolSpecifications.size() > 0) ? true : false; Result result = retryOn(new Callable() { @Override public TextGenerationResponse call() throws Exception { - return client.chat(request, version); + return client.generation(request, version); } }).results().get(0); @@ -132,7 +114,7 @@ public TextGenerationResponse call() throws Exception { AiMessage content; - if (result.generatedText().startsWith(promptFormatter.toolExecution())) { + if (toolsEnabled && result.generatedText().startsWith(promptFormatter.toolExecution())) { var tools = result.generatedText().replace(promptFormatter.toolExecution(), ""); content = AiMessage.from(promptFormatter.toolExecutionRequestFormatter(tools)); } else { @@ -143,29 +125,57 @@ public TextGenerationResponse call() throws Exception { } @Override - public Response generate(List messages, ToolSpecification toolSpecification) { - return generate(messages, List.of(toolSpecification)); - } + public void generate(List messages, List toolSpecifications, + StreamingResponseHandler handler) { + TextGenerationRequest request = new TextGenerationRequest(modelId, projectId, toInput(messages, toolSpecifications), + parameters); - @Override - public void generate(List messages, StreamingResponseHandler handler) { - TextGenerationRequest request = new TextGenerationRequest(modelId, projectId, toInput(messages), parameters); - Context context = Context.of("response", new ArrayList()); + Context context = Context.empty(); + context.put("response", new ArrayList()); + context.put("toolExecution", false); + final boolean toolsEnabled = (toolSpecifications != null && toolSpecifications.size() > 0) ? true : false; - client.chatStreaming(request, version) - .subscribe() + var mutiny = client.generationStreaming(request, version); + if (toolsEnabled) { + // Today Langchain4j doesn't allow to use the async operation with tools. + // One idea might be to give to the developer the possibility to use the VirtualThread. 
+ mutiny.emitOn(Infrastructure.getDefaultWorkerPool()); + } + + mutiny.subscribe() .with(context, new Consumer() { @Override - @SuppressWarnings("unchecked") public void accept(TextGenerationResponse response) { try { if (response == null || response.results() == null || response.results().isEmpty()) return; - ((List) context.get("response")).add(response); - handler.onNext(response.results().get(0).generatedText()); + String chunk = response.results().get(0).generatedText(); + + if (chunk.isEmpty()) + return; + + boolean isToolExecutionState = context.get("toolExecution"); + List responses = context.get("response"); + responses.add(response); + + if (isToolExecutionState) { + // If we are in the tool execution state, the chunk is associated with the tool execution, + // which means that it must not be sent to the client. + } else { + + // Check if the chunk contains the "ToolExecution" tag. + if (toolsEnabled && chunk.startsWith(promptFormatter.toolExecution().trim())) { + // If true, enter in the ToolExecutionState. + context.put("toolExecution", true); + return; + } + + // Send the chunk to the client. 
+ handler.onNext(chunk); + } } catch (Exception e) { handler.onError(e); @@ -180,9 +190,9 @@ public void accept(Throwable error) { }, new Runnable() { @Override - @SuppressWarnings("unchecked") public void run() { - var list = ((List) context.get("response")); + List list = context.get("response"); + boolean isToolExecutionState = context.get("toolExecution"); int inputTokenCount = 0; int outputTokenCount = 0; @@ -204,18 +214,28 @@ public void run() { builder.append(response.generatedText()); } - AiMessage message = new AiMessage(builder.toString()); + AiMessage content; TokenUsage tokenUsage = new TokenUsage(inputTokenCount, outputTokenCount); FinishReason finishReason = toFinishReason(stopReason); - handler.onComplete(Response.from(message, tokenUsage, finishReason)); + + String message = builder.toString(); + + if (isToolExecutionState) { + context.put("toolExecution", false); + var tools = message.replace(promptFormatter.toolExecution(), ""); + content = AiMessage.from(promptFormatter.toolExecutionRequestFormatter(tools)); + } else { + content = AiMessage.from(message); + } + + handler.onComplete(Response.from(content, tokenUsage, finishReason)); } }); } @Override public int estimateTokenCount(List messages) { - - var input = toInput(messages); + var input = toInput(messages, null); var request = new TokenizationRequest(modelId, input, projectId); return retryOn(new Callable() { @@ -226,18 +246,29 @@ public Integer call() throws Exception { }); } - public static Builder builder() { - return new Builder(); + @Override + public Response generate(List messages) { + return generate(messages, List.of()); } - private String toInput(List messages) { - var prompt = promptFormatter.format(messages, List.of()); - log.debugf(""" - Formatted prompt: - ----------------- - %s - -----------------""", prompt); - return prompt; + @Override + public Response generate(List messages, ToolSpecification toolSpecification) { + return generate(messages, List.of(toolSpecification)); + 
} + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + generate(messages, List.of(), handler); + } + + @Override + public void generate(List messages, ToolSpecification toolSpecification, + StreamingResponseHandler handler) { + generate(messages, List.of(toolSpecification), handler); + } + + public static Builder builder() { + return new Builder(); } private String toInput(List messages, List tools) { @@ -250,12 +281,12 @@ private String toInput(List messages, List tools return prompt; } - private FinishReason toFinishReason(String stopReason) { - return switch (stopReason) { + private FinishReason toFinishReason(String reason) { + return switch (reason) { case "max_tokens", "token_limit" -> FinishReason.LENGTH; case "eos_token", "stop_sequence" -> FinishReason.STOP; case "not_finished", "cancelled", "time_limit", "error" -> FinishReason.OTHER; - default -> throw new IllegalArgumentException("%s not supported".formatted(stopReason)); + default -> throw new IllegalArgumentException("%s not supported".formatted(reason)); }; } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxUtils.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxUtils.java index 589c2830b..2635ac78f 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxUtils.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxUtils.java @@ -25,8 +25,10 @@ public static T retryOn(Callable action) { Optional optional = Optional.empty(); for (WatsonxError.Error error : e.details().errors()) { - if (WatsonxError.Code.AUTHENTICATION_TOKEN_EXPIRED.equals(error.code())) { - optional = Optional.of(error.code()); + + var c = error.codeToEnum(); + if (c.isPresent() && WatsonxError.Code.AUTHENTICATION_TOKEN_EXPIRED.equals(c.get())) { + optional = Optional.of(c.get()); break; } } diff --git 
a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java new file mode 100644 index 000000000..bd55ecdd5 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java @@ -0,0 +1,258 @@ +package io.quarkiverse.langchain4j.watsonx.bean; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.Content; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.TextContent; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.data.message.UserMessage; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageAssistant; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageSystem; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageTool; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageUser; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatToolCall.TextChatFunctionCall; + +public sealed interface TextChatMessage + permits TextChatMessageAssistant, TextChatMessageSystem, TextChatMessageUser, TextChatMessageTool { + /** + * Converts the {@link ChatMessage} into a {@link TextChatMessage}. 
+ * + * @param chatMessage the chat message to convert + * @return the converted {@link TextChatMessage} + */ + public static TextChatMessage convert(ChatMessage chatMessage) { + return switch (chatMessage.type()) { + case AI -> TextChatMessageAssistant.of(AiMessage.class.cast(chatMessage)); + case SYSTEM -> TextChatMessageSystem.of(SystemMessage.class.cast(chatMessage)); + case USER -> TextChatMessageUser.of(UserMessage.class.cast(chatMessage)); + case TOOL_EXECUTION_RESULT -> TextChatMessageTool.of(ToolExecutionResultMessage.class.cast(chatMessage)); + }; + } + + public record TextChatMessageAssistant(String role, String content, + List toolCalls) implements TextChatMessage { + + private static final String ROLE = "assistant"; + + /** + * Creates a {@link TextChatMessageAssistant} from a {@link AiMessage}. + * + * @param aiMessage the ai message to convert + * @return the created {@link TextChatMessageAssistant} + */ + public static TextChatMessageAssistant of(AiMessage aiMessage) { + if (!aiMessage.hasToolExecutionRequests()) { + return new TextChatMessageAssistant(ROLE, aiMessage.text(), null); + } + + // Mapping the tool execution requests + var toolCalls = aiMessage.toolExecutionRequests().stream() + .map(TextChatToolCall::of) + .toList(); + + return new TextChatMessageAssistant(ROLE, aiMessage.text(), toolCalls); + } + + /** + * Creates a {@link TextChatMessageAssistant}. + * + * @param message the content of the system message to be created. + * @return the created {@link TextChatMessageAssistant}. + */ + public static TextChatMessageAssistant of(String message) { + return new TextChatMessageAssistant(ROLE, message, null); + } + + /** + * Creates a {@link TextChatMessageAssistant}. + * + * @param toolCalls the tools to execute. + * @return the created {@link TextChatMessageAssistant}. 
+ */ + public static TextChatMessageAssistant of(List toolCalls) { + return new TextChatMessageAssistant(ROLE, null, toolCalls); + } + } + + public record TextChatMessageSystem(String role, String content) implements TextChatMessage { + private static final String ROLE = "system"; + + /** + * Creates a {@link TextChatMessageSystem} from a {@link SystemMessage}. + * + * @param systemMessage the system message to convert + * @return the created {@link TextChatMessageSystem}. + */ + public static TextChatMessageSystem of(SystemMessage systemMessage) { + return new TextChatMessageSystem(ROLE, systemMessage.text()); + } + + /** + * Creates a {@link TextChatMessageSystem}. + * + * @param message the content of the system message to be created. + * @return the created {@link TextChatMessageSystem}. + */ + public static TextChatMessageSystem of(String message) { + return new TextChatMessageSystem(ROLE, message); + } + } + + public record TextChatMessageUser(String role, List> content, String name) implements TextChatMessage { + + private static final String ROLE = "user"; + + /** + * Creates a {@link TextChatMessageUser} from a {@link UserMessage}. + * + * @param systemMessage the user message to convert + * @return the created {@link TextChatMessageUser} + */ + public static TextChatMessageUser of(UserMessage userMessage) { + var values = new ArrayList>(); + for (Content content : userMessage.contents()) { + switch (content.type()) { + case TEXT -> { + var textContent = TextContent.class.cast(content); + values.add(Map.of( + "type", "text", + "text", textContent.text())); + } + case AUDIO, IMAGE, PDF, TEXT_FILE, VIDEO -> + throw new UnsupportedOperationException("Unimplemented case: " + content.type()); + } + } + return new TextChatMessageUser(ROLE, values, userMessage.name()); + } + + /** + * Creates a {@link TextChatMessageUser}. + * + * @param message the content of the system message to be created. + * @return the created {@link TextChatMessageUser}. 
+ */ + public static TextChatMessageUser of(String message) { + return of(UserMessage.from(message)); + } + } + + public record TextChatMessageTool(String role, String content, String toolCallId) implements TextChatMessage { + + private static final String ROLE = "tool"; + + /** + * Creates a {@link TextChatMessageTool} from a {@link ToolExecutionResultMessage}. + * + * @param toolExecutionResultMessage the tool execution result message to convert + * @return the created {@link TextChatMessageTool} + */ + public static TextChatMessageTool of(ToolExecutionResultMessage toolExecutionResultMessage) { + return new TextChatMessageTool(ROLE, toolExecutionResultMessage.text(), toolExecutionResultMessage.id()); + } + + /** + * Creates a {@link TextChatMessageTool}. + * + * @param message the content of the message tool. + * @param toolCallId the unique identifier of the message tool. + * @return the created {@link TextChatMessageTool}. + */ + public static TextChatMessageTool of(String content, String toolCallId) { + return new TextChatMessageTool(ROLE, content, toolCallId); + } + } + + public record TextChatToolCall(Integer index, String id, String type, TextChatFunctionCall function) { + public record TextChatFunctionCall(String name, String arguments) { + } + + /** + * Creates a {@link TextChatToolCall} from a {@link ToolExecutionRequest}. + * + * @param toolExecutionRequest the tool execution request to convert + * @return the created {@link TextChatToolCall} + */ + public static TextChatToolCall of(ToolExecutionRequest toolExecutionRequest) { + return new TextChatToolCall(null, toolExecutionRequest.id(), "function", + new TextChatFunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments())); + } + + /** + * Converts a {@link TextChatToolCall} into a {@link ToolExecutionRequest}. 
+ * + * @return the converted {@link ToolExecutionRequest} + */ + public ToolExecutionRequest convert() { + return ToolExecutionRequest.builder() + .id(id) + .name(function.name) + .arguments(function.arguments) + .build(); + } + } + + public record TextChatParameterTools(String type, TextChatParameterFunction function) { + public record TextChatParameterFunction(String name, String description, Map parameters) { + } + + /** + * Creates a {@link TextChatParameterTools} from a {@link ToolSpecification}. + * + * @param toolExecutionRequest the tool specification to convert + * @return the created {@link TextChatParameterTools} + */ + public static TextChatParameterTools of(ToolSpecification toolSpecification) { + var parameters = new TextChatParameterFunction(toolSpecification.name(), toolSpecification.description(), Map.of( + "type", toolSpecification.parameters().type(), + "properties", toolSpecification.parameters().properties(), + "required", toolSpecification.parameters().required())); + return new TextChatParameterTools("function", parameters); + } + } + + /** + * The {@code StreamingToolFetcher} class is responsible for fetching a list of tools from a streaming API. 
+ */ + public class StreamingToolFetcher { + + private int index; + private StringBuilder arguments; + private String id, type, name; + + public StreamingToolFetcher(int index) { + this.index = index; + arguments = new StringBuilder(); + } + + public void setId(String id) { + if (id != null) + this.id = id; + } + + public void setType(String type) { + if (type != null) + this.type = type; + } + + public void setName(String name) { + if (name != null && !name.isBlank()) + this.name = name; + } + + public void appendArguments(String arguments) { + if (arguments != null) + this.arguments.append(arguments); + } + + public TextChatToolCall build() { + return new TextChatToolCall(index, id, type, new TextChatFunctionCall(name, arguments.toString())); + } + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatParameters.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatParameters.java new file mode 100644 index 000000000..ec86b0a07 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatParameters.java @@ -0,0 +1,87 @@ +package io.quarkiverse.langchain4j.watsonx.bean; + +public class TextChatParameters { + + public record TextChatResponseFormat(String type) { + }; + + private final Integer maxTokens; + private final Double temperature; + private final Double topP; + private final Long timeLimit; + private final TextChatResponseFormat responseFormat; + + public TextChatParameters(Builder builder) { + this.maxTokens = builder.maxTokens; + this.temperature = builder.temperature; + this.topP = builder.topP; + this.timeLimit = builder.timeLimit; + + if (builder.responseFormat != null) + this.responseFormat = new TextChatResponseFormat(builder.responseFormat); + else + this.responseFormat = null; + } + + public static Builder builder() { + return new Builder(); + } + + public Integer getMaxTokens() { + return 
maxTokens; + } + + public Double getTemperature() { + return temperature; + } + + public Double getTopP() { + return topP; + } + + public Long getTimeLimit() { + return timeLimit; + } + + public TextChatResponseFormat getResponseFormat() { + return responseFormat; + } + + public static class Builder { + + private Integer maxTokens; + private Double temperature; + private Double topP; + private Long timeLimit; + private String responseFormat; + + public Builder maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public Builder timeLimit(Long timeLimit) { + this.timeLimit = timeLimit; + return this; + } + + public Builder responseFormat(String responseFormat) { + this.responseFormat = responseFormat; + return this; + } + + public TextChatParameters build() { + return new TextChatParameters(this); + } + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatRequest.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatRequest.java new file mode 100644 index 000000000..373b746f9 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatRequest.java @@ -0,0 +1,12 @@ +package io.quarkiverse.langchain4j.watsonx.bean; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonUnwrapped; + +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatParameterTools; + +public record TextChatRequest(String modelId, String projectId, List messages, + List tools, + @JsonUnwrapped TextChatParameters parameters) { +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatResponse.java 
b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatResponse.java new file mode 100644 index 000000000..ecef5d314 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatResponse.java @@ -0,0 +1,32 @@ +package io.quarkiverse.langchain4j.watsonx.bean; + +import java.util.List; +import java.util.concurrent.TimeUnit; + +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatToolCall; + +public record TextChatResponse(String id, String modelId, List choices, Long created, + TextChatUsage usage) { + + @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) + public record TextChatResultChoice(Integer index, TextChatResultMessage message, String finishReason) { + } + + @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) + public record TextChatUsage(Integer completionTokens, Integer promptTokens, Integer totalTokens) { + } + + @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) + public record TextChatResultMessage(String role, String content, List toolCalls) { + } + + public Long created() { + if (created != null) + return TimeUnit.SECONDS.toMillis(created); + else + return null; + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/Parameters.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextGenerationParameters.java similarity index 95% rename from model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/Parameters.java rename to model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextGenerationParameters.java index 33722408d..40b3931a6 100644 --- 
a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/Parameters.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextGenerationParameters.java @@ -2,7 +2,7 @@ import java.util.List; -public class Parameters { +public class TextGenerationParameters { public record LengthPenalty(Double decayFactor, Integer startIndex) { }; @@ -21,7 +21,7 @@ public record LengthPenalty(Double decayFactor, Integer startIndex) { private final Integer truncateInputTokens; private final Boolean includeStopSequence; - private Parameters(Builder builder) { + private TextGenerationParameters(Builder builder) { this.decodingMethod = builder.decodingMethod; this.lengthPenalty = builder.lengthPenalty; this.minNewTokens = builder.minNewTokens; @@ -174,8 +174,8 @@ public Builder includeStopSequence(Boolean includeStopSequence) { return this; } - public Parameters build() { - return new Parameters(this); + public TextGenerationParameters build() { + return new TextGenerationParameters(this); } } } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextGenerationRequest.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextGenerationRequest.java index aacee2250..5cac045d7 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextGenerationRequest.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextGenerationRequest.java @@ -4,5 +4,5 @@ public record TextGenerationRequest( String modelId, String projectId, String input, - Parameters parameters) { + TextGenerationParameters parameters) { } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextStreamingChatResponse.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextStreamingChatResponse.java new 
file mode 100644 index 000000000..8d0272461 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextStreamingChatResponse.java @@ -0,0 +1,32 @@ +package io.quarkiverse.langchain4j.watsonx.bean; + +import java.util.List; +import java.util.concurrent.TimeUnit; + +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatToolCall; + +public record TextStreamingChatResponse(String id, String modelId, List choices, Long created, + TextChatUsage usage) { + + @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) + public record TextChatResultChoice(Integer index, TextChatResultMessage delta, String finishReason) { + } + + @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) + public record TextChatUsage(Integer completionTokens, Integer promptTokens, Integer totalTokens) { + } + + @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) + public record TextChatResultMessage(String role, String content, List toolCalls) { + } + + public Long created() { + if (created != null) + return TimeUnit.SECONDS.toMillis(created); + else + return null; + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/WatsonxError.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/WatsonxError.java index e97134fce..2910b7c72 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/WatsonxError.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/WatsonxError.java @@ -1,12 +1,20 @@ package io.quarkiverse.langchain4j.watsonx.bean; import java.util.List; +import java.util.Optional; import com.fasterxml.jackson.annotation.JsonProperty; public record WatsonxError(Integer statusCode, String trace, List errors) { 
- public static record Error(Code code, String message) { + public static record Error(String code, String message) { + public Optional codeToEnum() { + try { + return Optional.of(Code.valueOf(code.toUpperCase())); + } catch (Exception e) { + return Optional.empty(); + } + } } public static enum Code { @@ -14,9 +22,15 @@ public static enum Code { @JsonProperty("authorization_rejected") AUTHORIZATION_REJECTED, + @JsonProperty("json_type_error") + JSON_TYPE_ERROR, + @JsonProperty("model_not_supported") MODEL_NOT_SUPPORTED, + @JsonProperty("model_no_support_for_function") + MODEL_NO_SUPPORT_FOR_FUNCTION, + @JsonProperty("user_authorization_failed") USER_AUTHORIZATION_FAILED, diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/WatsonxRestApi.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/WatsonxRestApi.java index 25422e07a..6e1166711 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/WatsonxRestApi.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/WatsonxRestApi.java @@ -23,8 +23,11 @@ import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingResponse; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatRequest; +import io.quarkiverse.langchain4j.watsonx.bean.TextChatResponse; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse; +import io.quarkiverse.langchain4j.watsonx.bean.TextStreamingChatResponse; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationResponse; import io.quarkiverse.langchain4j.watsonx.bean.WatsonxError; @@ -50,13 +53,24 @@ public interface WatsonxRestApi { @POST 
@Path("text/generation") - TextGenerationResponse chat(TextGenerationRequest request, @QueryParam("version") String version) + TextGenerationResponse generation(TextGenerationRequest request, @QueryParam("version") String version) throws WatsonxException; @POST @Path("text/generation_stream") @RestStreamElementType(MediaType.APPLICATION_JSON) - Multi chatStreaming(TextGenerationRequest request, @QueryParam("version") String version); + Multi generationStreaming(TextGenerationRequest request, @QueryParam("version") String version); + + @POST + @Path("text/chat") + TextChatResponse chat(TextChatRequest request, @QueryParam("version") String version) + throws WatsonxException; + + @POST + @Path("text/chat_stream") + @RestStreamElementType(MediaType.APPLICATION_JSON) + Multi streamingChat(TextChatRequest request, @QueryParam("version") String version) + throws WatsonxException; @POST @Path("text/tokenization") diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java index 137a2a553..d2b63bc5f 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java @@ -20,6 +20,7 @@ import dev.langchain4j.model.embedding.DisabledEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; +import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; import io.quarkiverse.langchain4j.watsonx.WatsonxEmbeddingModel; import io.quarkiverse.langchain4j.watsonx.WatsonxGenerationModel; import io.quarkiverse.langchain4j.watsonx.WatsonxTokenGenerator; @@ -44,6 +45,64 @@ public class WatsonxRecorder { private static final Map tokenGeneratorCache = new HashMap<>(); private 
static final ConfigValidationException.Problem[] EMPTY_PROBLEMS = new ConfigValidationException.Problem[0]; + public Supplier chatModel(LangChain4jWatsonxConfig runtimeConfig, + LangChain4jWatsonxFixedRuntimeConfig fixedRuntimeConfig, String configName) { + + LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName); + LangChain4jWatsonxFixedRuntimeConfig.WatsonConfig watsonFixedRuntimeConfig = correspondingWatsonFixedRuntimeConfig( + fixedRuntimeConfig, configName); + + if (watsonRuntimeConfig.enableIntegration()) { + + var builder = chatBuilder(watsonRuntimeConfig, watsonFixedRuntimeConfig, configName); + return new Supplier<>() { + @Override + public ChatLanguageModel get() { + return builder.build(); + } + }; + + } else { + return new Supplier<>() { + + @Override + public ChatLanguageModel get() { + return new DisabledChatLanguageModel(); + } + + }; + } + } + + public Supplier streamingChatModel(LangChain4jWatsonxConfig runtimeConfig, + LangChain4jWatsonxFixedRuntimeConfig fixedRuntimeConfig, String configName) { + + LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName); + LangChain4jWatsonxFixedRuntimeConfig.WatsonConfig watsonFixedRuntimeConfig = correspondingWatsonFixedRuntimeConfig( + fixedRuntimeConfig, configName); + + if (watsonRuntimeConfig.enableIntegration()) { + + var builder = chatBuilder(watsonRuntimeConfig, watsonFixedRuntimeConfig, configName); + return new Supplier<>() { + @Override + public StreamingChatLanguageModel get() { + return builder.build(); + } + }; + + } else { + return new Supplier<>() { + + @Override + public StreamingChatLanguageModel get() { + return new DisabledStreamingChatLanguageModel(); + } + + }; + } + } + public Supplier generationModel(LangChain4jWatsonxConfig runtimeConfig, LangChain4jWatsonxFixedRuntimeConfig fixedRuntimeConfig, String configName, PromptFormatter promptFormatter) { @@ -59,7 
+118,7 @@ public Supplier generationModel(LangChain4jWatsonxConfig runt if (watsonRuntimeConfig.enableIntegration()) { - var builder = generateChatBuilder(watsonRuntimeConfig, watsonFixedRuntimeConfig, configName, promptFormatter); + var builder = generationBuilder(watsonRuntimeConfig, watsonFixedRuntimeConfig, configName, promptFormatter); return new Supplier<>() { @Override public ChatLanguageModel get() { @@ -69,10 +128,12 @@ public ChatLanguageModel get() { } else { return new Supplier<>() { + @Override public ChatLanguageModel get() { return new DisabledChatLanguageModel(); } + }; } } @@ -86,7 +147,7 @@ public Supplier generationStreamingModel(LangChain4j if (watsonRuntimeConfig.enableIntegration()) { - var builder = generateChatBuilder(watsonRuntimeConfig, watsonFixedRuntimeConfig, configName, promptFormatter); + var builder = generationBuilder(watsonRuntimeConfig, watsonFixedRuntimeConfig, configName, promptFormatter); return new Supplier<>() { @Override public StreamingChatLanguageModel get() { @@ -96,10 +157,12 @@ public StreamingChatLanguageModel get() { } else { return new Supplier<>() { + @Override public StreamingChatLanguageModel get() { return new DisabledStreamingChatLanguageModel(); } + }; } } @@ -145,10 +208,12 @@ public WatsonxEmbeddingModel get() { } else { return new Supplier<>() { + @Override public EmbeddingModel get() { return new DisabledEmbeddingModel(); } + }; } } @@ -164,7 +229,47 @@ public WatsonxTokenGenerator apply(String iamUrl) { }; } - private WatsonxGenerationModel.Builder generateChatBuilder( + private WatsonxChatModel.Builder chatBuilder( + LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig, + LangChain4jWatsonxFixedRuntimeConfig.WatsonConfig watsonFixedRuntimeConfig, + String configName) { + + ChatModelConfig chatModelConfig = watsonRuntimeConfig.chatModel(); + var configProblems = checkConfigurations(watsonRuntimeConfig, configName); + + if (!configProblems.isEmpty()) { + throw new 
ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS)); + } + + String iamUrl = watsonRuntimeConfig.iam().baseUrl().toExternalForm(); + WatsonxTokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl, + createTokenGenerator(watsonRuntimeConfig.iam(), watsonRuntimeConfig.apiKey())); + + URL url; + try { + url = new URL(watsonRuntimeConfig.baseUrl()); + } catch (Exception e) { + throw new RuntimeException(e); + } + + checkProperties(watsonRuntimeConfig, watsonFixedRuntimeConfig, configName); + + return WatsonxChatModel.builder() + .tokenGenerator(tokenGenerator) + .url(url) + .timeout(watsonRuntimeConfig.timeout().orElse(Duration.ofSeconds(10))) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), watsonRuntimeConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), watsonRuntimeConfig.logResponses())) + .version(watsonRuntimeConfig.version()) + .projectId(watsonRuntimeConfig.projectId()) + .modelId(watsonFixedRuntimeConfig.chatModel().modelId()) + .maxTokens(chatModelConfig.maxNewTokens()) + .temperature(chatModelConfig.temperature()) + .topP(firstOrDefault(null, chatModelConfig.topP())) + .responseFormat(chatModelConfig.responseFormat().orElse(null)); + } + + private WatsonxGenerationModel.Builder generationBuilder( + LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig, + LangChain4jWatsonxFixedRuntimeConfig.WatsonConfig watsonFixedRuntimeConfig, String configName, PromptFormatter promptFormatter) { @@ -258,6 +363,21 @@ private List checkConfigurations(LangChain4jW return configProblems; } + private void checkProperties(LangChain4jWatsonxConfig.WatsonConfig watsonxRuntimeConfig, + LangChain4jWatsonxFixedRuntimeConfig.WatsonConfig watsonxFixedRuntimeConfig, String configName) { + + final String message = "The property '%s' is being used in '%s' mode, but it is only applicable in the following mode(s): %s. 
This property will be ignored."; + final String chat = "chat"; + + String currentMode = watsonxFixedRuntimeConfig.chatModel().mode(); + + if (watsonxFixedRuntimeConfig.chatModel().mode().equals("generation")) { + if (watsonxRuntimeConfig.chatModel().responseFormat().isPresent()) { + log.warnf(message, configName.concat(".response-format"), currentMode, chat); + } + } + } + private ConfigValidationException.Problem createBaseURLConfigProblem(String configName) { return createConfigProblem("base-url", configName); } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java index 8292d6242..605af95ea 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java @@ -23,6 +23,8 @@ public interface ChatModelConfig { * parameters. *

* Allowable values: [sample,greedy] + *

+ * Applicable in modes: [generation] */ @WithDefault("greedy") String decodingMethod(); @@ -31,6 +33,8 @@ public interface ChatModelConfig { * It can be used to exponentially increase the likelihood of the text generation terminating once a specified number of * tokens * have been generated. + *

+ * Applicable in modes: [generation] */ LengthPenaltyConfig lengthPenalty(); @@ -42,7 +46,9 @@ public interface ChatModelConfig { * are a mix of full words and sub-words. Depending on the users plan, and on the model being used, there may be an enforced * maximum number of new tokens. *

- * Possible values: ≥ 0 + * Possible values: ≥ 0 * + *

+ * Applicable in modes: [chat,generation] */ @WithDefault("200") Integer maxNewTokens(); @@ -50,7 +56,9 @@ public interface ChatModelConfig { /** * If stop sequences are given, they are ignored until minimum tokens are generated. *

- * Possible values: ≥ 0 + * Possible values: ≥ 0 * + *

+ * Applicable in modes: [generation] */ @WithDefault("0") Integer minNewTokens(); @@ -58,7 +66,9 @@ public interface ChatModelConfig { /** * Random number generator seed to use in sampling mode for experimental repeatability. *

- * Possible values: ≥ 1 + * Possible values: ≥ 1 * + *

+ * Applicable in modes: [generation] */ Optional randomSeed(); @@ -67,7 +77,9 @@ public interface ChatModelConfig { * the * output. Stop sequences encountered prior to the minimum number of tokens being generated will be ignored. *

- * Possible values: 0 ≤ number of items ≤ 6, contains only unique items + * Possible values: 0 ≤ number of items ≤ 6, contains only unique items * + *

+ * Applicable in modes: [generation] */ Optional> stopSequences(); @@ -78,7 +90,9 @@ public interface ChatModelConfig { * probability * distribution, resulting in "more random" output. A value of 1.0 has no effect. *

- * Possible values: 0 ≤ value ≤ 2 + * Possible values: 0 ≤ value ≤ 2 * + *

+ * Applicable in modes: [chat,generation] */ @WithDefault("1.0") Double temperature(); @@ -89,7 +103,9 @@ public interface ChatModelConfig { * When decoding_strategy is set to sample, only the top_k most likely tokens are considered as * candidates for the next generated token. *

- * Possible values: 1 ≤ value ≤ 100 + * Possible values: 1 ≤ value ≤ 100 * + *

+ * Applicable in modes: [generation] */ Optional topK(); @@ -99,7 +115,9 @@ public interface ChatModelConfig { * that add up to at least top_p. Also known as nucleus sampling. A value of 1.0 is equivalent to * disabled. *

- * Possible values: 0 < value ≤ 1 + * Possible values: 0 < value ≤ 1 * + *

+ * Applicable in modes: [chat,generation] */ Optional topP(); @@ -107,7 +125,9 @@ public interface ChatModelConfig { * Represents the penalty for penalizing tokens that have already been generated or belong to the context. The value * 1.0 means that there is no penalty. *

- * Possible values: 1 ≤ value ≤ 2 + * Possible values: 1 ≤ value ≤ 2 * + *

+ * Applicable in modes: [generation] */ Optional repetitionPenalty(); @@ -122,25 +142,29 @@ public interface ChatModelConfig { * don't * truncate. *

- * Possible values: ≥ 0 + * Possible values: ≥ 0 * + *

+ * Applicable in modes: [generation] */ Optional truncateInputTokens(); /** * Pass false to omit matched stop sequences from the end of the output text. The default is true, - * meaning that the output will end with the stop sequence text when matched. + * meaning that the output will end with the stop sequence text when matched. * + *

+ * Applicable in modes: [generation] */ Optional includeStopSequence(); /** - * Whether chat model requests should be logged. + * Whether chat model requests should be logged. * */ @ConfigDocDefault("false") @WithDefault("${quarkus.langchain4j.watsonx.log-requests}") Optional logRequests(); /** - * Whether chat model responses should be logged. + * Whether chat model responses should be logged. * */ @ConfigDocDefault("false") @WithDefault("${quarkus.langchain4j.watsonx.log-responses}") @@ -149,11 +173,22 @@ public interface ChatModelConfig { /** * Delimiter used to concatenate the ChatMessage elements into a single string. By setting this property, you can define * your - * preferred way of concatenating messages to ensure that the prompt is structured in the correct way. + * preferred way of concatenating messages to ensure that the prompt is structured in the correct way. * + *

+ * Applicable in modes: [generation] */ @WithDefault("\n") String promptJoiner(); + /** + * Specifies the desired format for the model's output. + *

+ * Allowable values: [json_object] * + *

+ * Applicable in modes: [chat] + */ + Optional responseFormat(); + @ConfigGroup public interface LengthPenaltyConfig { diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelFixedRuntimeConfig.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelFixedRuntimeConfig.java index 274e1c0e2..7b62a769a 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelFixedRuntimeConfig.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelFixedRuntimeConfig.java @@ -14,17 +14,33 @@ public interface ChatModelFixedRuntimeConfig { * "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-api-model-ids.html?context=wx&audience=wdp#model-ids">click * here. */ - @WithDefault("ibm/granite-13b-chat-v2") + @WithDefault("mistralai/mistral-large") String modelId(); /** - * Configuration property that enables or disables the functionality of the prompt formatter. + * Configuration property that enables or disables the functionality of the prompt formatter for the `generation` mode. * *

    *
  • true: When enabled, prompts are automatically enriched with the specific tags defined by the model.
  • *
  • false: Prompts will not be enriched with the model's tags.
  • *
+ *

+ * Applicable in modes: [generation] */ @WithDefault("false") boolean promptFormatter(); + + /** + * Specifies the mode of interaction with the selected model. + *

+ * This property allows you to choose between two modes of operation: + *

    + *
  • chat: prompts are automatically enriched with the specific tags defined by the model
  • + *
  • generation: prompts require manual specification of tags
  • + *
+ * Allowable values: [chat, generation] + */ + @WithDefault("chat") + String mode(); + } From 6b3edf47a783d9bbbb47b5e3f031fa9bfb68fa73 Mon Sep 17 00:00:00 2001 From: Andrea Di Maio Date: Mon, 14 Oct 2024 18:37:51 +0200 Subject: [PATCH 4/4] [watsonx.ai] Add truncate_input_tokens property --- .../watsonx/deployment/AiEmbeddingTest.java | 4 ++-- .../deployment/GenerationAllPropertiesTest.java | 7 ++++++- .../GenerationDefaultPropertiesTest.java | 3 ++- .../watsonx/WatsonxEmbeddingModel.java | 15 ++++++++++++++- .../watsonx/bean/EmbeddingParameters.java | 4 ++++ .../watsonx/bean/EmbeddingRequest.java | 6 +++--- .../watsonx/runtime/WatsonxRecorder.java | 3 ++- .../runtime/config/EmbeddingModelConfig.java | 14 +++++++++++++- 8 files changed, 46 insertions(+), 10 deletions(-) create mode 100644 model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/EmbeddingParameters.java diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiEmbeddingTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiEmbeddingTest.java index 27134bcd8..26db85417 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiEmbeddingTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiEmbeddingTest.java @@ -83,7 +83,7 @@ void test_embed_list_of_three_textsegment() throws Exception { var input = "Embedding THIS!"; EmbeddingRequest request = new EmbeddingRequest(WireMockUtil.DEFAULT_EMBEDDING_MODEL, WireMockUtil.PROJECT_ID, - List.of(input, input, input)); + List.of(input, input, input), null); mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200) .body(mapper.writeValueAsString(request)) @@ -140,7 +140,7 @@ private List mockEmbeddingServer(String input) throws Exception { .build(); EmbeddingRequest request = new 
EmbeddingRequest(WireMockUtil.DEFAULT_EMBEDDING_MODEL, WireMockUtil.PROJECT_ID, - List.of(input)); + List.of(input), null); mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200) .body(mapper.writeValueAsString(request)) diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java index 5c38a1e7a..745545eec 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java @@ -25,6 +25,7 @@ import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.output.Response; +import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters.LengthPenalty; @@ -64,6 +65,7 @@ public class GenerationAllPropertiesTest extends WireMockAbstract { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.truncate-input-tokens", "0") .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.include-stop-sequence", "false") .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.model-id", "my_super_embedding_model") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.truncate-input-tokens", "10") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); @Override @@ -102,6 +104,8 @@ void handlerBeforeEach() { .includeStopSequence(false) .build(); + static EmbeddingParameters 
embeddingParameters = new EmbeddingParameters(10); + @Test void check_config() throws Exception { var runtimeConfig = langchain4jWatsonConfig.defaultConfig(); @@ -133,6 +137,7 @@ void check_config() throws Exception { assertEquals("@", runtimeConfig.chatModel().promptJoiner()); assertEquals(true, fixedRuntimeConfig.chatModel().promptFormatter()); assertEquals("my_super_embedding_model", runtimeConfig.embeddingModel().modelId()); + assertEquals(10, runtimeConfig.embeddingModel().truncateInputTokens().orElse(null)); } @Test @@ -158,7 +163,7 @@ void check_embedding_model() throws Exception { String projectId = config.projectId(); EmbeddingRequest request = new EmbeddingRequest(modelId, projectId, - List.of("Embedding THIS!")); + List.of("Embedding THIS!"), embeddingParameters); mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200, "aaaa-mm-dd") .body(mapper.writeValueAsString(request)) diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationDefaultPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationDefaultPropertiesTest.java index 86742fc71..194a14a49 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationDefaultPropertiesTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationDefaultPropertiesTest.java @@ -98,6 +98,7 @@ void check_config() throws Exception { assertTrue(runtimeConfig.chatModel().includeStopSequence().isEmpty()); assertEquals("urn:ibm:params:oauth:grant-type:apikey", runtimeConfig.iam().grantType()); assertEquals(WireMockUtil.DEFAULT_EMBEDDING_MODEL, runtimeConfig.embeddingModel().modelId()); + assertTrue(runtimeConfig.embeddingModel().truncateInputTokens().isEmpty()); } @Test @@ -124,7 +125,7 @@ void check_embedding_model() throws Exception { String projectId = 
config.projectId(); EmbeddingRequest request = new EmbeddingRequest(modelId, projectId, - List.of("Embedding THIS!")); + List.of("Embedding THIS!"), null); mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200) .body(mapper.writeValueAsString(request)) diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java index 5de1d1287..2331f05b9 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java @@ -17,6 +17,7 @@ import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.TokenCountEstimator; import dev.langchain4j.model.output.Response; +import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingResponse; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingResponse.Result; @@ -28,6 +29,7 @@ public class WatsonxEmbeddingModel implements EmbeddingModel, TokenCountEstimator { private final String modelId, projectId, version; + private final EmbeddingParameters parameters; private final WatsonxRestApi client; public WatsonxEmbeddingModel(Builder builder) { @@ -49,6 +51,11 @@ public WatsonxEmbeddingModel(Builder builder) { this.modelId = builder.modelId; this.projectId = builder.projectId; this.version = builder.version; + + if (builder.truncateInputTokens != null) + this.parameters = new EmbeddingParameters(builder.truncateInputTokens); + else + this.parameters = null; } @Override @@ -61,7 +68,7 @@ public Response> embedAll(List textSegments) { .map(TextSegment::text) .collect(Collectors.toList()); - EmbeddingRequest request = new 
EmbeddingRequest(modelId, projectId, inputs); + EmbeddingRequest request = new EmbeddingRequest(modelId, projectId, inputs, parameters); EmbeddingResponse result = retryOn(new Callable() { @Override public EmbeddingResponse call() throws Exception { @@ -102,6 +109,7 @@ public static final class Builder { private String version; private String projectId; private Duration timeout; + private Integer truncateInputTokens; private boolean logResponses; private boolean logRequests; private URL url; @@ -127,6 +135,11 @@ public Builder timeout(Duration timeout) { return this; } + public Builder truncateInputTokens(Integer truncateInputTokens) { + this.truncateInputTokens = truncateInputTokens; + return this; + } + public Builder logRequests(boolean logRequests) { this.logRequests = logRequests; return this; diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/EmbeddingParameters.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/EmbeddingParameters.java new file mode 100644 index 000000000..0e1d8924b --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/EmbeddingParameters.java @@ -0,0 +1,4 @@ +package io.quarkiverse.langchain4j.watsonx.bean; + +public record EmbeddingParameters(Integer truncateInputTokens) { +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/EmbeddingRequest.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/EmbeddingRequest.java index 7322f9c68..5d56e7e76 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/EmbeddingRequest.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/EmbeddingRequest.java @@ -2,9 +2,9 @@ import java.util.List; -public record EmbeddingRequest(String modelId, String projectId, List inputs) { +public record 
EmbeddingRequest(String modelId, String projectId, List inputs, EmbeddingParameters parameters) { - public EmbeddingRequest of(String modelId, String projectId, String input) { - return new EmbeddingRequest(modelId, projectId, List.of(input)); + public static EmbeddingRequest of(String modelId, String projectId, String input, EmbeddingParameters parameters) { + return new EmbeddingRequest(modelId, projectId, List.of(input), parameters); + } } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java index d2b63bc5f..56b9a0916 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java @@ -197,7 +197,8 @@ public Supplier embeddingModel(LangChain4jWatsonxConfig runtimeC .logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), watsonConfig.logResponses())) .version(watsonConfig.version()) .projectId(watsonConfig.projectId()) - .modelId(embeddingModelConfig.modelId()); + .modelId(embeddingModelConfig.modelId()) + .truncateInputTokens(embeddingModelConfig.truncateInputTokens().orElse(null)); return new Supplier<>() { @Override public EmbeddingModel get() { diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/EmbeddingModelConfig.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/EmbeddingModelConfig.java index d1dea97e4..0357c1086 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/EmbeddingModelConfig.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/EmbeddingModelConfig.java @@ -12,13 +12,25 @@ public interface
EmbeddingModelConfig { /** * Model id to use. * - * To view the complete model list, click * here. */ @WithDefault("ibm/slate-125m-english-rtrvr") String modelId(); + /** + * Represents the maximum number of input tokens accepted. This can be used to avoid requests failing due to input being + * longer + * than configured limits. If the text is truncated, then it truncates the end of the input (on the right), so the start of + * the + * input will remain the same. If this value exceeds the maximum sequence length (refer to the documentation to find this + * value + * for the model) then the call will fail if the total number of tokens exceeds the maximum sequence length. + */ + Optional truncateInputTokens(); + /** * Whether embedding model requests should be logged. */