From fd29b1968478a8f587d70c8535c05de944180bb0 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Tue, 22 Oct 2024 14:07:17 +0300 Subject: [PATCH] Properly support TokenStream return type in AiService --- .../langchain4j/deployment/AiServicesProcessor.java | 5 +++++ .../langchain4j/deployment/LangChain4jDotNames.java | 2 ++ 2 files changed, 7 insertions(+) diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index c2b25faf3..637b71c79 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -510,6 +510,11 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, // currently in one class either streaming or blocking model are supported, but not both // if we want to support it, the injectStreamingChatModelBean needs to be recorded per injection point for (MethodInfo method : declarativeAiServiceClassInfo.methods()) { + if (LangChain4jDotNames.TOKEN_STREAM.equals(method.returnType().name())) { + injectStreamingChatModelBean = true; + continue; + } + if (!DotNames.MULTI.equals(method.returnType().name())) { continue; } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java index 6e3cb312c..6a6eaae35 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java @@ -22,6 +22,7 @@ import dev.langchain4j.service.Moderate; import dev.langchain4j.service.Result; import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.TokenStream; import dev.langchain4j.service.UserMessage; import dev.langchain4j.service.UserName; import dev.langchain4j.web.search.WebSearchEngine; @@ -45,6 +46,7 @@ public class LangChain4jDotNames { public static final DotName IMAGE_MODEL = DotName.createSimple(ImageModel.class); public static final DotName TOKEN_COUNT_ESTIMATOR = DotName.createSimple(TokenCountEstimator.class); public static final DotName CHAT_MESSAGE = DotName.createSimple(ChatMessage.class); + public static final DotName TOKEN_STREAM = DotName.createSimple(TokenStream.class); public static final DotName OUTPUT_GUARDRAILS = DotName.createSimple(OutputGuardrails.class); public static final DotName INPUT_GUARDRAILS = DotName.createSimple(InputGuardrails.class); static final DotName AI_SERVICES = DotName.createSimple(AiServices.class);