Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
clun authored Jul 2, 2024
2 parents 79384e5 + 234141b commit 6175b44
Show file tree
Hide file tree
Showing 13 changed files with 213 additions and 19 deletions.
9 changes: 8 additions & 1 deletion docs/docs/tutorials/5-ai-services.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,13 @@ Friend friend = AiServices.create(Friend.class, model);

String answer = friend.chat("Hello"); // Hey! What's up?
```
In this example, we have added the `@SystemMessage` annotation with a system prompt we want to use.

In this example, we have added the `@SystemMessage` annotation with a system prompt template we want to use.
This will be converted into a `SystemMessage` behind the scenes and sent to the LLM along with the `UserMessage`.

`@SystemMessage` can also load a prompt template from resources:
`@SystemMessage(fromResource = "my-prompt-template.txt")`

### System Message Provider
System messages can also be defined dynamically with the system message provider:
```java
Expand Down Expand Up @@ -147,6 +151,9 @@ String answer = friend.chat("Hello"); // Hey! What's shakin'?
We have replaced the `@SystemMessage` annotation with `@UserMessage`
and specified a prompt template with the variable `it` to refer to the only method argument.

`@UserMessage` can also load a prompt template from resources:
`@UserMessage(fromResource = "my-prompt-template.txt")`

Additionally, it's possible to annotate the `String userMessage` with `@V`
and assign a custom name to the prompt template variable:
```java
Expand Down
40 changes: 39 additions & 1 deletion docs/docs/tutorials/6-tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ Please note that tools/function calling is not the same as [JSON mode](/tutorial

## 2 levels of abstraction

LangChain4j provides two levels of abstraction for using tools.
LangChain4j provides two levels of abstraction for using tools:
- Low-level, using the `ChatLanguageModel` API
- High-level, using [AI Services](/tutorials/ai-services) and `@Tool`-annotated Java methods

### Low level Tool API

Expand Down Expand Up @@ -294,6 +296,42 @@ The value provided to the AI Service method will be automatically passed to the
This feature is useful if you have multiple users and/or multiple chats/memories per user
and wish to distinguish between them inside the `@Tool` method.

### Configuring Tools Programmatically

When using AI Services, tools can also be configured programmatically.
This approach offers a lot of flexibility, as tools can now be loaded
from external sources such as databases and configuration files.

Tool names, descriptions, parameter names, and descriptions
can all be configured using `ToolSpecification`:
```java
ToolSpecification toolSpecification = ToolSpecification.builder()
.name("get_booking_details")
.description("Returns booking details")
.addParameter("bookingNumber", type("string"), description("Booking number in B-12345 format"))
.build();
```

For each `ToolSpecification`, one needs to provide a `ToolExecutor` implementation
that will be handling tool execution requests generated by the LLM:
```java
ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> {
Map<String, Object> arguments = fromJson(toolExecutionRequest.arguments());
String bookingNumber = arguments.get("bookingNumber").toString();
Booking booking = getBooking(bookingNumber);
return booking.toString();
};
```

Once we have one or multiple (`ToolSpecification`, `ToolExecutor`) pairs,
we can specify them when creating an AI Service:
```java
Assistant assistant = AiServices.builder(Assistant.class)
.chatLanguageModel(chatLanguageModel)
.tools(singletonMap(toolSpecification, toolExecutor))
.build();
```

## Related Tutorials

- [Great guide on Tools](https://www.youtube.com/watch?v=cjI_6Siry-s)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
import static java.lang.annotation.RetentionPolicy.RUNTIME;

/**
* Java methods annotated with @Tool are considered tools that language model can use.
* When using OpenAI models, <a href="https://platform.openai.com/docs/guides/function-calling">function calling</a>
* Java methods annotated with {@code @Tool} are considered tools/functions that language model can execute/call.
* Tool/function calling LLM capability (e.g., see <a href="https://platform.openai.com/docs/guides/function-calling">OpenAI function calling documentation</a>)
* is used under the hood.
* A low-level {@link ToolSpecification} will be automatically created from the method signature
* (e.g. method name, method parameters (names and types), @Tool and @P annotations, etc.)
* and will be sent to the LLM.
* If LLM decides to call the tool, the arguments are automatically parsed and injected as method arguments.
*/
@Retention(RUNTIME)
@Target({METHOD})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import static dev.langchain4j.internal.Utils.quoted;

/**
* Represents a request to execute a tool.
* Represents an LLM-generated request to execute a tool.
*/
public class ToolExecutionRequest {
private final String id;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
package dev.langchain4j.agent.tool;

import dev.langchain4j.service.MemoryId;

/**
* A low-level executor/handler of a {@link ToolExecutionRequest}.
*/
public interface ToolExecutor {

/**
* Executes a tool requests.
*
* @param toolExecutionRequest The tool execution request. Contains tool name and arguments.
* @param memoryId The ID of the chat memory. See {@link MemoryId} for more details.
* @return The result of the tool execution.
*/
String execute(ToolExecutionRequest toolExecutionRequest, Object memoryId);
}
41 changes: 36 additions & 5 deletions langchain4j/src/main/java/dev/langchain4j/service/AiServices.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dev.langchain4j.agent.tool.DefaultToolExecutor;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
Expand Down Expand Up @@ -37,6 +38,7 @@
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.util.Arrays.asList;
import static java.util.stream.Collectors.toList;

/**
Expand Down Expand Up @@ -294,7 +296,6 @@ public AiServices<T> moderationModel(ModerationModel moderationModel) {

/**
* Configures the tools that the LLM can use.
* A {@link ChatMemory} that can hold at least 3 messages is required for the tools to work properly.
*
* @param objectsWithTools One or more objects whose methods are annotated with {@link Tool}.
* All these tools (methods annotated with {@link Tool}) will be accessible to the LLM.
Expand All @@ -303,12 +304,11 @@ public AiServices<T> moderationModel(ModerationModel moderationModel) {
* @see Tool
*/
public AiServices<T> tools(Object... objectsWithTools) {
return tools(Arrays.asList(objectsWithTools));
return tools(asList(objectsWithTools));
}

/**
* Configures the tools that the LLM can use.
* A {@link ChatMemory} that can hold at least 3 messages is required for the tools to work properly.
*
* @param objectsWithTools A list of objects whose methods are annotated with {@link Tool}.
* All these tools (methods annotated with {@link Tool}) are accessible to the LLM.
Expand All @@ -318,8 +318,13 @@ public AiServices<T> tools(Object... objectsWithTools) {
*/
public AiServices<T> tools(List<Object> objectsWithTools) { // TODO Collection?
// TODO validate uniqueness of tool names
context.toolSpecifications = new ArrayList<>();
context.toolExecutors = new HashMap<>();

if (context.toolSpecifications == null) {
context.toolSpecifications = new ArrayList<>();
}
if (context.toolExecutors == null) {
context.toolExecutors = new HashMap<>();
}

for (Object objectWithTool : objectsWithTools) {
if (objectWithTool instanceof Class) {
Expand All @@ -338,6 +343,32 @@ public AiServices<T> tools(List<Object> objectsWithTools) { // TODO Collection?
return this;
}

/**
* Configures the tools that the LLM can use.
*
* @param tools A map of {@link ToolSpecification} to {@link ToolExecutor} entries.
* This method of configuring tools is useful when tools must be configured programmatically.
* Otherwise, it is recommended to use the {@link Tool}-annotated java methods
* and configure tools with the {@link #tools(Object...)} and {@link #tools(List)} methods.
* @return builder
*/
public AiServices<T> tools(Map<ToolSpecification, ToolExecutor> tools) {

if (context.toolSpecifications == null) {
context.toolSpecifications = new ArrayList<>();
}
if (context.toolExecutors == null) {
context.toolExecutors = new HashMap<>();
}

tools.forEach((toolSpecification, toolExecutor) -> {
context.toolSpecifications.add(toolSpecification);
context.toolExecutors.put(toolSpecification.name(), toolExecutor);
});

return this;
}

/**
* Deprecated. Use {@link #contentRetriever(ContentRetriever)}
* (e.g. {@link EmbeddingStoreContentRetriever}) instead.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,12 @@ private static String getTemplate(Method method, String type, String resource, S
return messageTemplate;
}

private static String getResourceText(Class<?> clazz, String name) {
return getText(clazz.getResourceAsStream(name));
private static String getResourceText(Class<?> clazz, String resource) {
InputStream inputStream = clazz.getResourceAsStream(resource);
if (inputStream == null) {
inputStream = clazz.getResourceAsStream("/" + resource);
}
return getText(inputStream);
}

private static String getText(InputStream inputStream) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ void should_load_documents_including_unknown_document_types() {
"miles-of-smiles-terms-of-use.txt",
"test-file.banana",
"test-file-iso-8859-1.txt",
"test-file-utf8.txt"
"test-file-utf8.txt",
"chefs-prompt-based-on-ingredients-in-root.txt"
);

// when-then
Expand Down Expand Up @@ -169,6 +170,8 @@ void should_recursively_load_documents() {
"test-file-utf8.txt",
"chefs-prompt-based-on-ingredients.txt",
"chefs-prompt-system-message.txt",
"chefs-prompt-based-on-ingredients-in-root.txt",
"chefs-prompt-based-on-ingredients-in-subdirectory.txt",
"test-file-2.banana"
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,12 @@ interface Chef {
@UserMessage(fromResource = "chefs-prompt-based-on-ingredients.txt")
Recipe createRecipeFromUsingResource(String... ingredients);

@UserMessage(fromResource = "chefs-prompt-based-on-ingredients-in-root.txt")
Recipe createRecipeFromUsingResourceInRoot(String... ingredients);

@UserMessage(fromResource = "subdirectory/chefs-prompt-based-on-ingredients-in-subdirectory.txt")
Recipe createRecipeFromUsingResourceInSubdirectory(String... ingredients);

Recipe createRecipeFrom(CreateRecipePrompt prompt);

@SystemMessage("You are very {{character}} chef")
Expand Down Expand Up @@ -355,6 +361,52 @@ void test_create_recipe_from_list_of_ingredients_using_resource() {
"}")));
}

@Test
void test_create_recipe_from_list_of_ingredients_using_resource_in_root() {

Chef chef = AiServices.create(Chef.class, chatLanguageModel);

Recipe recipe = chef.createRecipeFromUsingResourceInRoot("cucumber", "tomato", "feta", "onion", "olives");
System.out.println(recipe);

assertThat(recipe.title).isNotBlank();
assertThat(recipe.description).isNotBlank();
assertThat(recipe.steps).isNotEmpty();
assertThat(recipe.preparationTimeMinutes).isPositive();

verify(chatLanguageModel).generate(singletonList(userMessage(
"Create recipe using only [cucumber, tomato, feta, onion, olives]\n" +
"You must answer strictly in the following JSON format: {\n" +
"\"title\": (type: string),\n" +
"\"description\": (type: string),\n" +
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
"\"preparationTimeMinutes\": (type: integer)\n" +
"}")));
}

@Test
void test_create_recipe_from_list_of_ingredients_using_resource_in_subdirectory() {

Chef chef = AiServices.create(Chef.class, chatLanguageModel);

Recipe recipe = chef.createRecipeFromUsingResourceInSubdirectory("cucumber", "tomato", "feta", "onion", "olives");
System.out.println(recipe);

assertThat(recipe.title).isNotBlank();
assertThat(recipe.description).isNotBlank();
assertThat(recipe.steps).isNotEmpty();
assertThat(recipe.preparationTimeMinutes).isPositive();

verify(chatLanguageModel).generate(singletonList(userMessage(
"Create recipe using only [cucumber, tomato, feta, onion, olives]\n" +
"You must answer strictly in the following JSON format: {\n" +
"\"title\": (type: string),\n" +
"\"description\": (type: string),\n" +
"\"steps\": (each step should be described in 4 words, steps should rhyme; type: array of string),\n" +
"\"preparationTimeMinutes\": (type: integer)\n" +
"}")));
}

interface BadChef {
String CHEFS_PROMPT_DOES_NOT_EXIST_TXT = "chefs-prompt-does-not-exist.txt";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,10 @@ void should_use_static_metadata_filter(ChatLanguageModel model) {
.build();

// when
String answer1 = assistant.answer("Which animal?");
String answer = assistant.answer("Which animal is mentioned?");

// then
assertThat(answer1).containsIgnoringCase("dog");
assertThat(answer).containsIgnoringCase("dog");
}

@ParameterizedTest
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package dev.langchain4j.service;

import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.agent.tool.*;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
Expand All @@ -26,6 +26,7 @@

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.description;
Expand All @@ -38,7 +39,9 @@
import static dev.langchain4j.service.AiServicesWithToolsIT.TransactionService.EXPECTED_SPECIFICATION;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.MapEntry.entry;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -646,4 +649,42 @@ void should_use_tool_with_pojo(ChatLanguageModel chatLanguageModel) {

assertThat(response.content().text()).contains("Amar", "Akbar", "Antony");
}

@ParameterizedTest
@MethodSource("models")
void should_use_programmatically_configured_tools(ChatLanguageModel chatLanguageModel) {

// given
ToolSpecification toolSpecification = ToolSpecification.builder()
.name("get_booking_details")
.description("Returns booking details")
.addParameter("bookingNumber", type("string"))
.build();

ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> {
Map<String, Object> arguments = toMap(toolExecutionRequest.arguments());
assertThat(arguments).containsExactly(entry("bookingNumber", "123-456"));
return "Booking period: from 1 July 2027 to 10 July 2027";
};

Assistant assistant = AiServices.builder(Assistant.class)
.chatLanguageModel(chatLanguageModel)
.tools(singletonMap(toolSpecification, toolExecutor))
.build();

// when
Response<AiMessage> response = assistant.chat("When does my booking 123-456 starts?");

// then
assertThat(response.content().text()).contains("2027");
}

private static Map<String, Object> toMap(String arguments) {
try {
return new ObjectMapper().readValue(arguments, new TypeReference<Map<String, Object>>() {
});
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}
Loading

0 comments on commit 6175b44

Please sign in to comment.