Skip to content

Commit

Permalink
Merge pull request #824 from quarkiverse/#819
Browse files Browse the repository at this point in the history
Support LangChain4j's Result return type
  • Loading branch information
geoand authored Aug 26, 2024
2 parents 310e40e + 0cd6f69 commit 10d9382
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.Result;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.output.ServiceOutputParser;
import dev.langchain4j.service.tool.ToolExecutor;
Expand Down Expand Up @@ -134,7 +135,7 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
boolean needsMemorySeed = needsMemorySeed(context, memoryId); // we need to know figure this out before we add the system and user message

Type returnType = methodCreateInfo.getReturnType();
AugmentationResult augmentationResult;
AugmentationResult augmentationResult = null;
if (context.retrievalAugmentor != null) {
List<ChatMessage> chatMemory = context.hasChatMemory()
? context.chatMemory(memoryId).messages()
Expand Down Expand Up @@ -276,7 +277,17 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
chatMemory.commit();

response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());
return SERVICE_OUTPUT_PARSER.parse(response, returnType);
if (isResult(returnType)) {
var parsedResponse = SERVICE_OUTPUT_PARSER.parse(response, resultTypeParam((ParameterizedType) returnType));
return Result.builder()
.content(parsedResponse)
.tokenUsage(tokenUsageAccumulator)
.sources(augmentationResult == null ? null : augmentationResult.contents())
.finishReason(response.finishReason())
.build();
} else {
return SERVICE_OUTPUT_PARSER.parse(response, returnType);
}
}

private static boolean needsMemorySeed(QuarkusAiServiceContext context, Object memoryId) {
Expand Down Expand Up @@ -343,6 +354,17 @@ private static boolean isMulti(Type returnType) {
return isTypeOf(returnType, Multi.class);
}

private static boolean isResult(Type returnType) {
return isTypeOf(returnType, Result.class);
}

private static Type resultTypeParam(ParameterizedType returnType) {
if (!isTypeOf(returnType, Result.class)) {
throw new IllegalStateException("Can only be called with Result<T> type");
}
return returnType.getActualTypeArguments()[0];
}

private static boolean isTypeOf(Type type, Class<?> clazz) {
if (type instanceof Class<?>) {
return type.equals(clazz);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static dev.langchain4j.data.message.ChatMessageType.AI;
import static dev.langchain4j.data.message.ChatMessageType.SYSTEM;
import static dev.langchain4j.data.message.ChatMessageType.USER;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static java.time.Month.JULY;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down Expand Up @@ -44,11 +45,13 @@
import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiModerationModel;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.Moderate;
import dev.langchain4j.service.ModerationException;
import dev.langchain4j.service.Result;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;
Expand Down Expand Up @@ -901,6 +904,38 @@ public void deleteMessages(Object memoryId) {
tuple(USER, secondsMessageFromSecondUser), tuple(AI, secondAiMessageToSecondUser));
}

interface AssistantReturningResult {

Result<String> chat(String userMessage);
}

@Test
void should_return_result() throws IOException {
setChatCompletionMessageContent("Berlin is the capital of Germany");

// given
AssistantReturningResult assistant = AiServices.create(AssistantReturningResult.class, createChatModel());

String userMessage = "What is the capital of Germany?";

// when
Result<String> result = assistant.chat(userMessage);

// then
assertThat(result.content()).containsIgnoringCase("Berlin");

TokenUsage tokenUsage = result.tokenUsage();
assertThat(tokenUsage).isNotNull();
assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

assertThat(result.sources()).isNull();

assertSingleRequestMessage(getRequestAsMap(), "What is the capital of Germany?");
}

static class Calculator {

private final Runnable after;
Expand Down

0 comments on commit 10d9382

Please sign in to comment.