diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index faa63cf75e..a2f2e3b237 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,21 +1,45 @@ + + + -## Context - + + + + + +## Issue + + ## Change - + + -## Checklist -Before submitting this PR, please check the following points: +## General checklist + +- [ ] There are no breaking changes - [ ] I have added unit and integration tests for my change -- [ ] All unit and integration tests in the module I have added/changed are green -- [ ] All unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules are green +- [ ] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green +- [ ] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green + - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) -- [ ] I have added my new module in the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml) (only when a new module is added) +- [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable) + +## Checklist for adding new model integration + +- [ ] I have added my new module in the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml) + ## Checklist for adding new embedding store integration -- [ ] I have added a {NameOfIntegration}EmbeddingStoreIT that extends from either EmbeddingStoreIT or EmbeddingStoreWithFilteringIT + +- [ ] I have added a `{NameOfIntegration}EmbeddingStoreIT` that extends from either `EmbeddingStoreIT` or `EmbeddingStoreWithFilteringIT` +- [ ] I have added my new module in the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml) + + +## Checklist for changing existing embedding store integration + +- [ ] I have manually verified that the `{NameOfIntegration}EmbeddingStore` works correctly with the data persisted using the latest released version of LangChain4j diff --git a/.github/workflows/nightly.yaml b/.github/workflows/nightly.yaml new file mode 100644 index 0000000000..84d4db4776 --- /dev/null +++ b/.github/workflows/nightly.yaml @@ -0,0 +1,77 @@ +name: Nightly Build + +on: + schedule: + - cron: '0 0 * * *' # daily at midnight UTC + workflow_dispatch: + +jobs: + java_build: + strategy: + matrix: + java_version: [ 8, 11, 17, 21 ] + include: + - java_version: '8' + included_modules: '-pl !langchain4j-local-ai,!langchain4j-milvus,!code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-cassandra,!langchain4j-infinispan,!langchain4j-neo4j,!langchain4j-opensearch,!langchain4j-azure-ai-search' + - java_version: '11' + included_modules: '-pl !langchain4j-local-ai,!langchain4j-milvus,!code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-infinispan,!langchain4j-neo4j' + - java_version: '17' + included_modules: '-pl 
!langchain4j-local-ai,!langchain4j-milvus,!code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot' + - java_version: '21' + included_modules: '-pl !langchain4j-local-ai,!langchain4j-milvus' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up JDK ${{ matrix.java_version }} + uses: actions/setup-java@v4 + with: + java-version: ${{ matrix.java_version }} + distribution: 'temurin' + cache: 'maven' + + - name: Authenticate to Google Cloud + # Needed for langchain4j-vertex-ai and langchain4j-vertex-ai-gemini modules + uses: 'google-github-actions/auth@v2' + with: + project_id: ${{ secrets.GCP_PROJECT_ID }} + credentials_json: ${{ secrets.GCP_CREDENTIALS_JSON }} + + - name: Setup Testcontainers Cloud Client + # Needed for langchain4j-ollama and other modules using testcontainers + uses: atomicjar/testcontainers-cloud-setup-action@v1 + with: + token: ${{ secrets.TC_CLOUD_TOKEN }} + + - name: Build with JDK ${{ matrix.java_version }} + run: mvn -B -U --fail-at-end ${{ matrix.included_modules }} verify + env: + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }} + AZURE_OPENAI_KEY: ${{ secrets.AZURE_OPENAI_KEY }} + AZURE_SEARCH_ENDPOINT: ${{ secrets.AZURE_SEARCH_ENDPOINT }} + AZURE_SEARCH_KEY: ${{ secrets.AZURE_SEARCH_KEY }} + COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} + ELASTICSEARCH_CLOUD_API_KEY: ${{ secrets.ELASTICSEARCH_CLOUD_API_KEY }} + ELASTICSEARCH_CLOUD_URL: ${{ secrets.ELASTICSEARCH_CLOUD_URL }} + GCP_CREDENTIALS_JSON: ${{ secrets.GCP_CREDENTIALS_JSON }} + GCP_LOCATION: ${{ secrets.GCP_LOCATION }} + GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} + GCP_VERTEXAI_ENDPOINT: ${{ secrets.GCP_VERTEXAI_ENDPOINT }} + HF_API_KEY: ${{ secrets.HF_API_KEY }} + JINA_API_KEY: ${{ secrets.JINA_API_KEY }} + MILVUS_API_KEY: ${{ secrets.MILVUS_API_KEY }} + MILVUS_URI: ${{ secrets.MILVUS_URI }} + MISTRAL_AI_API_KEY: ${{ secrets.MISTRAL_AI_API_KEY }} + MONGODB_ATLAS_USERNAME: ${{ secrets.MONGODB_ATLAS_USERNAME }} + MONGODB_ATLAS_PASSWORD: ${{ secrets.MONGODB_ATLAS_PASSWORD }} + MONGODB_ATLAS_HOST: ${{ secrets.MONGODB_ATLAS_HOST }} + NOMIC_API_KEY: ${{ secrets.NOMIC_API_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }} + PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + WEAVIATE_API_KEY: ${{ secrets.WEAVIATE_API_KEY }} + WEAVIATE_HOST: ${{ secrets.WEAVIATE_HOST }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index a28d56d098..94760d7fd9 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -35,9 +35,11 @@ jobs: token: ${{ secrets.TC_CLOUD_TOKEN }} - name: release - run: mvn -B -U --fail-at-end -DskipTests -DskipAnthropicITs -DskipLocalAiITs -DskipMilvusITs -DskipMongoDbAtlasITs -DskipOllamaITs -DskipVearchITs -DskipVertexAiGeminiITs -pl !langchain4j-core,!langchain4j-parent -Psign clean deploy + run: mvn -B -U --fail-at-end -DskipTests -DskipITs -DskipAnthropicITs -DskipLocalAiITs -DskipMilvusITs -DskipMongoDbAtlasITs -DskipOllamaITs -DskipVearchITs -DskipVertexAiGeminiITs -pl !langchain4j-core,!langchain4j-parent -Psign clean deploy env: ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }} AZURE_OPENAI_KEY: ${{ secrets.AZURE_OPENAI_KEY }} AZURE_SEARCH_ENDPOINT: ${{ secrets.AZURE_SEARCH_ENDPOINT }} @@ -50,6 +52,7 @@ jobs: GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} GCP_VERTEXAI_ENDPOINT: ${{ secrets.GCP_VERTEXAI_ENDPOINT }} HF_API_KEY: ${{ secrets.HF_API_KEY }} + JINA_API_KEY: ${{ secrets.JINA_API_KEY }} MILVUS_API_KEY: ${{ secrets.MILVUS_API_KEY }} MILVUS_URI: ${{ secrets.MILVUS_URI }} MISTRAL_AI_API_KEY: ${{ secrets.MISTRAL_AI_API_KEY }} @@ -60,6 +63,7 @@ jobs: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }} PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} WEAVIATE_API_KEY: ${{ secrets.WEAVIATE_API_KEY }} WEAVIATE_HOST: ${{ secrets.WEAVIATE_HOST }} GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} diff --git a/.gitignore b/.gitignore index b891694bce..2937c1f51f 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,4 @@ build/ ### .env files contain local environment variables ### .env +langchain4j-core/target_test-classes/ diff --git a/.idea/icon.png b/.idea/icon.png deleted file mode 100644 index 8a39a1059a..0000000000 Binary files a/.idea/icon.png and /dev/null differ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 93bbc6311d..597c895e96 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,8 +1,12 @@ Thank you for investing your time and effort in contributing to our project, we appreciate it a lot! 🤗 -# General Guidelines -- If you want to contribute a bug fix or a new feature that isn't listed in the [issues](https://github.com/langchain4j/langchain4j/issues) yet, please open a new issue for it and link it to your PR. +# Current situation (25 April 2024) +There are over 60 open PRs. Please help us by reviewing them first, before opening new ones. 🙏 + + +# General guidelines +- If you want to contribute a bug fix or a new feature that isn't listed in the [issues](https://github.com/langchain4j/langchain4j/issues) yet, please open a new issue for it. We will prioritize it shortly. - Follow [Google's Best Practices for Java Libraries](https://jlbp.dev/) - Keep the code compatible with Java 8. We plan to increase the baseline to Java 17 a bit later. - Avoid adding new dependencies as much as possible. If absolutely necessary, try to use the same libraries which are already used in the project. @@ -13,42 +17,54 @@ Thank you for investing your time and effort in contributing to our project, we - Follow existing code style present in the project. - Large features should be discussed with maintainers before implementation. Please ping @langchain4j in the comments on the issue. + # Priorities All [issues](https://github.com/langchain4j/langchain4j/issues) are prioritized by maintainers. There are 4 priorities: [P1](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP1), [P2](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP2), [P3](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP3) and [P4](https://github.com/langchain4j/langchain4j/issues?q=is%3Aissue+is%3Aopen+label%3AP4). Please start with the higher priorities. PRs will be reviewed in order of priority, with bugs being a higher priority than new features. -Please note that we do not have the capacity to review all PRs immediately. +Please note that we do not have the capacity to review PRs immediately. We ask for your patience. 
We are doing our best to review your PR as quickly as possible. + # Opening an issue - Please fill in all sections of the issue template. -# Opening a PR -- Link an [issue](https://github.com/langchain4j/langchain4j/issues) to your PR. If there is no issue yet, open one. + +# Opening a draft PR +- Please open the PR as a draft initially. Once it is reviewed and approved, we will then ask you to finalize it (see section below). - Fill in all the sections of the PR template. -- Make sure you've added tests. -- Make sure you've added documentation where required. -- For new big features, make sure you've added an example in the [examples repository](https://github.com/langchain4j/langchain4j-examples) (as a separate PR, linked to the main one). - Please make it easier to review your PR: - Keep changes as small as possible. - Do not combine refactoring with changes in a single PR. - Avoid reformatting existing code. + +# Finalizing the draft PR +- Add [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) (if required). +- Add an example to the [examples repository](https://github.com/langchain4j/langchain4j-examples) (if required). +- [Mark a PR as ready for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/changing-the-stage-of-a-pull-request#marking-a-pull-request-as-ready-for-review) + + # Guidelines on adding a new model integration - [Integration with Anthropic](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-anthropic) is a good example. - Use the official SDK if available. -- If the official SDK is not available, use Retrofit and Gson to implement the client. +- If the official SDK is not available, use Retrofit and Jackson to implement the client. - Document the new integration [here](https://github.com/langchain4j/langchain4j/blob/main/README.md), [here](https://github.com/langchain4j/langchain4j/tree/main/docs/docs/integrations/language-models) and [here](https://github.com/langchain4j/langchain4j/blob/main/docs/docs/integrations/language-models/index.md). - Add an example to the [examples repository](https://github.com/langchain4j/langchain4j-examples), similar to [this](https://github.com/langchain4j/langchain4j-examples/tree/main/anthropic-examples). - Add a new module to the appropriate section of the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml). - It would be great if you could add a [Spring Boot starter](https://github.com/langchain4j/langchain4j-spring). + # Guidelines on adding a new embedding store integration - [Integration with Chroma](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-chroma) is a good example. - Use the official SDK if available. -- If the official SDK is not available, use Retrofit and Gson to implement the client. -- `{IntegrationName}EmbeddingStoreIT` should extend from `EmbeddingStoreWithFilteringIT` or `EmbeddingStoreIT` and pass all tests. +- If the official SDK is not available, use Retrofit and Jackson to implement the client. +- Add a `{IntegrationName}EmbeddingStoreIT`. It should extend from `EmbeddingStoreWithFilteringIT` or `EmbeddingStoreIT` and pass all tests. 
- Document the new integration [here](https://github.com/langchain4j/langchain4j/blob/main/README.md), [here](https://github.com/langchain4j/langchain4j/tree/main/docs/docs/integrations/embedding-stores) and [here](https://github.com/langchain4j/langchain4j/blob/main/docs/docs/integrations/embedding-stores/index.md). - Add an example to the [examples repository](https://github.com/langchain4j/langchain4j-examples), similar to [this](https://github.com/langchain4j/langchain4j-examples/tree/main/chroma-example). - Add a new module to the appropriate section of the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml). -- It would be great if you could add a [Spring Boot starter](https://github.com/langchain4j/langchain4j-spring). +- It would be great if you could add a [Spring Boot starter](https://github.com/langchain4j/langchain4j-spring). (after + + +# Guidelines on changing an existing embedding store integration +- Ensure that your changes are backwards compatible. `Embedding`s and `TextSegment`s persisted with the latest released version of LangChain4j should still work. diff --git a/README.md b/README.md index b98a551780..3b0da2beb7 100644 --- a/README.md +++ b/README.md @@ -3,25 +3,26 @@ [![](https://img.shields.io/twitter/follow/langchain4j)](https://twitter.com/intent/follow?screen_name=langchain4j) [![](https://dcbadge.vercel.app/api/server/JzTFvyjG6R?compact=true&style=flat)](https://discord.gg/JzTFvyjG6R) + ## Introduction Welcome! -The goal of LangChain4j is to simplify integrating AI/LLM capabilities into Java applications. +The goal of LangChain4j is to simplify integrating LLMs into Java applications. Here's how: 1. **Unified APIs:** - LLM providers (like OpenAI or Google Vertex AI) and embedding (vector) stores (such as Pinecone or Vespa) + LLM providers (like OpenAI or Google Vertex AI) and embedding (vector) stores (such as Pinecone or Milvus) use proprietary APIs. LangChain4j offers a unified API to avoid the need for learning and implementing specific APIs for each of them. - To experiment with a different LLM or embedding store, you can easily switch between them without the need to rewrite your code. - LangChain4j currently supports over 10 popular LLM providers and more than 15 embedding stores. - Think of it as a Hibernate, but for LLMs and embedding stores. + To experiment with different LLMs or embedding stores, you can easily switch between them without the need to rewrite your code. + LangChain4j currently supports [15+ popular LLM providers](https://docs.langchain4j.dev/integrations/language-models/) + and [15+ embedding stores](https://docs.langchain4j.dev/integrations/embedding-stores/). 2. **Comprehensive Toolbox:** During the past year, the community has been building numerous LLM-powered applications, - identifying common patterns, abstractions, and techniques. LangChain4j has refined these into practical code. - Our toolbox includes tools ranging from low-level prompt templating, memory management, and output parsing - to high-level patterns like Agents and RAGs. - For each pattern and abstraction, we provide an interface along with multiple ready-to-use implementations based on proven techniques. + identifying common abstractions, patterns, and techniques. LangChain4j has refined these into practical code. + Our toolbox includes tools ranging from low-level prompt templating, chat memory management, and output parsing + to high-level patterns like AI Services and RAG. 
+ For each abstraction, we provide an interface along with multiple ready-to-use implementations based on common techniques. Whether you're building a chatbot or developing a RAG with a complete pipeline from data ingestion to retrieval, LangChain4j offers a wide variety of options. 3. **Numerous Examples:** @@ -36,304 +37,37 @@ LlamaIndex, and the broader community, spiced up with a touch of our own innovat We actively monitor community developments, aiming to quickly incorporate new techniques and integrations, ensuring you stay up-to-date. -The library is under active development. While some features from the Python version of LangChain -are still being worked on, the core functionality is in place, allowing you to start building LLM-powered apps now! - -For easier integration, LangChain4j also includes integration with -Quarkus ([extension](https://quarkus.io/extensions/io.quarkiverse.langchain4j/quarkus-langchain4j-core)) -and Spring Boot ([starters](https://github.com/langchain4j/langchain4j-spring)). - -## Code Examples +The library is under active development. While some features are still being worked on, +the core functionality is in place, allowing you to start building LLM-powered apps now! -Please see examples of how LangChain4j can be used in [langchain4j-examples](https://github.com/langchain4j/langchain4j-examples) repo: - -- [Examples in plain Java](https://github.com/langchain4j/langchain4j-examples/tree/main/other-examples/src/main/java) -- [Examples with Quarkus](https://github.com/quarkiverse/quarkus-langchain4j/tree/main/samples) (uses [quarkus-langchain4j](https://github.com/quarkiverse/quarkus-langchain4j) dependency) -- [Example with Spring Boot](https://github.com/langchain4j/langchain4j-examples/tree/main/spring-boot-example/src/main/java/dev/langchain4j/example) ## Documentation Documentation can be found [here](https://docs.langchain4j.dev). -## Tutorials -Tutorials can be found [here](https://docs.langchain4j.dev/tutorials). - -## Useful Materials -[Useful Materials](https://docs.langchain4j.dev/useful-materials) - -## Library Structure -LangChain4j features a modular design, comprising: -- The `langchain4j-core` module, which defines core abstractions (such as `ChatLanguageModel` and `EmbeddingStore`) and their APIs. -- The main `langchain4j` module, containing useful tools like `ChatMemory`, `OutputParser` as well as a high-level features like `AiServices`. -- A wide array of `langchain4j-{integration}` modules, each providing integration with various LLM providers and embedding stores into LangChain4j. - You can use the `langchain4j-{integration}` modules independently. For additional features, simply import the main `langchain4j` dependency. - -## Highlights - -You can define declarative "AI Services" that are powered by LLMs: - -```java -interface Assistant { - - String chat(String userMessage); -} - -Assistant assistant = AiServices.create(Assistant.class, model); - -String answer = assistant.chat("Hello"); - -System.out.println(answer); // Hello! How can I assist you today? 
-``` - -You can use LLM as a classifier: - -```java -enum Sentiment { - POSITIVE, NEUTRAL, NEGATIVE -} - -interface SentimentAnalyzer { - - @UserMessage("Analyze sentiment of {{it}}") - Sentiment analyzeSentimentOf(String text); - - @UserMessage("Does {{it}} have a positive sentiment?") - boolean isPositive(String text); -} - -SentimentAnalyzer sentimentAnalyzer = AiServices.create(SentimentAnalyzer.class, model); - -Sentiment sentiment = sentimentAnalyzer.analyzeSentimentOf("It is good!"); // POSITIVE - -boolean positive = sentimentAnalyzer.isPositive("It is bad!"); // false -``` - -You can easily extract structured information from unstructured data: - -```java -class Person { - - private String firstName; - private String lastName; - private LocalDate birthDate; -} - -interface PersonExtractor { - - @UserMessage("Extract information about a person from {{text}}") - Person extractPersonFrom(@V("text") String text); -} -PersonExtractor extractor = AiServices.create(PersonExtractor.class, model); +## Getting Started +Getting started guide can be found [here](https://docs.langchain4j.dev/get-started). -String text = "In 1968, amidst the fading echoes of Independence Day, " - + "a child named John arrived under the calm evening sky. " - + "This newborn, bearing the surname Doe, marked the start of a new journey."; -Person person = extractor.extractPersonFrom(text); -// Person { firstName = "John", lastName = "Doe", birthDate = 1968-07-04 } -``` - -You can provide tools that LLMs can use! It can be anything: retrieve information from DB, call APIs, etc. -See example [here](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithToolsExample.java). - -## Compatibility - -- Java: 8 or higher -- Spring Boot: 2 or higher - -## Getting started - -1. Add LangChain4j OpenAI dependency to your project: - - Maven: - ```xml - - dev.langchain4j - langchain4j-open-ai - 0.30.0 - - ``` - - Gradle: - ```groovy - implementation 'dev.langchain4j:langchain4j-open-ai:0.30.0' - ``` - -2. Import your OpenAI API key: - ```java - String apiKey = System.getenv("OPENAI_API_KEY"); - ``` - You can also use the API key `demo` to test OpenAI, which we provide for free. - [How to get an API key?](https://github.com/langchain4j/langchain4j#how-to-get-an-api-key) - - -3. Create an instance of a model and start interacting: - ```java - OpenAiChatModel model = OpenAiChatModel.withApiKey(apiKey); - - String answer = model.generate("Hello world!"); - - System.out.println(answer); // Hello! How can I assist you today? 
- ``` -## Supported LLM Integrations ([Docs](https://docs.langchain4j.dev/category/integrations)) -| Provider | Native Image | [Sync Completion](https://docs.langchain4j.dev/category/language-models) | [Streaming Completion](https://docs.langchain4j.dev/integrations/language-models/response-streaming) | [Embedding](https://docs.langchain4j.dev/category/embedding-models) | [Image Generation](https://docs.langchain4j.dev/category/image-models) | [Scoring](https://docs.langchain4j.dev/category/scoring-models) | [Function Calling](https://docs.langchain4j.dev/tutorials/tools) | -|----------------------------------------------------------------------------------------------------|--------------|--------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------|------------------------------------------------------------------------|-----------------------------------------------------------------|------------------------------------------------------------------| -| [OpenAI](https://docs.langchain4j.dev/integrations/language-models/open-ai) | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | -| [Azure OpenAI](https://docs.langchain4j.dev/integrations/language-models/azure-open-ai) | | ✅ | ✅ | ✅ | ✅ | | ✅ | -| [Hugging Face](https://docs.langchain4j.dev/integrations/language-models/hugging-face) | | ✅ | | ✅ | | | | | -| [Amazon Bedrock](https://docs.langchain4j.dev/integrations/language-models/amazon-bedrock) | | ✅ | | ✅ | ✅ | | | -| [Google Vertex AI Gemini](https://docs.langchain4j.dev/integrations/language-models/google-gemini) | | ✅ | ✅ | | ✅ | | ✅ | -| [Google Vertex AI](https://docs.langchain4j.dev/integrations/language-models/google-palm) | ✅ | ✅ | | ✅ | ✅ | | | -| [Mistral AI](https://docs.langchain4j.dev/integrations/language-models/mistral-ai) | | ✅ | ✅ | ✅ | | | ✅ | -| [DashScope](https://docs.langchain4j.dev/integrations/language-models/dashscope) | | ✅ | ✅ | ✅ | | | | -| [LocalAI](https://docs.langchain4j.dev/integrations/language-models/local-ai) | | ✅ | ✅ | ✅ | | | ✅ | -| [Ollama](https://docs.langchain4j.dev/integrations/language-models/ollama) | | ✅ | ✅ | ✅ | | | | -| [Cohere](https://docs.langchain4j.dev/integrations/reranking-models/cohere) | | | | | | ✅ | | -| [Qianfan](https://docs.langchain4j.dev/integrations/language-models/qianfan) | | ✅ | ✅ | ✅ | | | ✅ | -| [ChatGLM](https://docs.langchain4j.dev/integrations/language-models/chatglm) | | ✅ | | | | | | -| [Nomic](https://docs.langchain4j.dev/integrations/language-models/nomic) | | | | ✅ | | | | -| [Anthropic](https://docs.langchain4j.dev/integrations/language-models/anthropic) | ✅ | ✅ | ✅ | | | | ✅ | -| [Zhipu AI](https://docs.langchain4j.dev/integrations/language-models/zhipu-ai) | | ✅ | ✅ | ✅ | | | ✅ | - -## Disclaimer - -Please note that the library is in active development and: - -- Some features are still missing. We are working hard on implementing them ASAP. -- API might change at any moment. At this point, we prioritize good design in the future over backward compatibility - now. We hope for your understanding. -- We need your input! Please [let us know](https://github.com/langchain4j/langchain4j/issues/new/choose) what features you need and your concerns about the current implementation. 
+## Code Examples +Please see examples of how LangChain4j can be used in [langchain4j-examples](https://github.com/langchain4j/langchain4j-examples) repo: +- [Examples in plain Java](https://github.com/langchain4j/langchain4j-examples/tree/main/other-examples/src/main/java) +- [Examples with Quarkus](https://github.com/quarkiverse/quarkus-langchain4j/tree/main/samples) (uses [quarkus-langchain4j](https://github.com/quarkiverse/quarkus-langchain4j) dependency) +- [Example with Spring Boot](https://github.com/langchain4j/langchain4j-examples/tree/main/spring-boot-example/src/main/java/dev/langchain4j/example) -## Current features (this list is outdated, we have much more): -- AI Services: - - [Simple](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/SimpleServiceExample.java) - - [With Memory](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithMemoryExample.java) - - [With Tools](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithToolsExample.java) - - [With Streaming](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithStreamingExample.java) - - [With RAG](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithRetrieverExample.java) - - [With Auto-Moderation](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithAutoModerationExample.java) - - [With Structured Outputs, Structured Prompts, etc](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/OtherServiceExamples.java) -- Integration with [OpenAI](https://platform.openai.com/docs/introduction) and [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview) for: - - [Chats](https://platform.openai.com/docs/guides/chat) (sync + streaming + functions) - - [Completions](https://platform.openai.com/docs/guides/completion) (sync + streaming) - - [Embeddings](https://platform.openai.com/docs/guides/embeddings) -- Integration with [Google Vertex AI](https://cloud.google.com/vertex-ai) for: - - [Chats](https://cloud.google.com/vertex-ai/docs/generative-ai/chat/chat-prompts) - - [Completions](https://cloud.google.com/vertex-ai/docs/generative-ai/text/text-overview) - - [Embeddings](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings) -- Integration with [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index) for: - - [Chats](https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task) - - [Completions](https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task) - - [Embeddings](https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task) -- Integration with [LocalAI](https://localai.io/) for: - - Chats (sync + streaming + functions) - - Completions (sync + streaming) - - Embeddings -- Integration with [DashScope](https://dashscope.aliyun.com/) for: - - Chats (sync + streaming) - - Completions (sync + streaming) - - Embeddings -- [Chat memory](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ChatMemoryExamples.java) -- [Persistent chat memory](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithPersistentMemoryForEachUserExample.java) -- [Chat with 
Documents](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ChatWithDocumentsExamples.java) -- Integration with [Astra DB](https://www.datastax.com/products/datastax-astra) and [Cassandra](https://cassandra.apache.org/) -- [Integration](https://github.com/langchain4j/langchain4j-examples/blob/main/chroma-example/src/main/java/ChromaEmbeddingStoreExample.java) with [Chroma](https://www.trychroma.com/) -- [Integration](https://github.com/langchain4j/langchain4j-examples/blob/main/elasticsearch-example/src/main/java/ElasticsearchEmbeddingStoreExample.java) with [Elasticsearch](https://www.elastic.co/) -- [Integration](https://github.com/langchain4j/langchain4j-examples/blob/main/milvus-example/src/main/java/MilvusEmbeddingStoreExample.java) with [Milvus](https://milvus.io/) -- [Integration](https://github.com/langchain4j/langchain4j-examples/blob/main/pinecone-example/src/main/java/PineconeEmbeddingStoreExample.java) with [Pinecone](https://www.pinecone.io/) -- [Integration](https://github.com/langchain4j/langchain4j-examples/blob/main/redis-example/src/main/java/RedisEmbeddingStoreExample.java) with [Redis](https://redis.io/) -- [Integration](https://github.com/langchain4j/langchain4j-examples/blob/main/vespa-example/src/main/java/VespaEmbeddingStoreExample.java) with [Vespa](https://vespa.ai/) -- [Integration](https://github.com/langchain4j/langchain4j-examples/blob/main/weaviate-example/src/main/java/WeaviateEmbeddingStoreExample.java) with [Weaviate](https://weaviate.io/) -- [In-memory embedding store](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/embedding/store/InMemoryEmbeddingStoreExample.java) (can be persisted) -- [Structured outputs](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/OtherServiceExamples.java) -- [Prompt templates](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/PromptTemplateExamples.java) -- [Structured prompt templates](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/StructuredPromptTemplateExamples.java) -- [Streaming of LLM responses](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/StreamingExamples.java) -- [Loading txt, html, pdf, doc, xls and ppt documents from the file system and via URL](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/DocumentLoaderExamples.java) -- [Splitting documents into segments](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ChatWithDocumentsExamples.java): - - by paragraphs, lines, sentences, words, etc - - recursively - - with overlap -- Token count estimation (so that you can predict how much you will pay) +## Useful Materials +Useful materials can be found [here](https://docs.langchain4j.dev/useful-materials). -## Coming soon: -- Extending "AI Service" features -- Integration with more LLM providers (commercial and free) -- Integrations with more embedding stores (commercial and free) -- Support for more document types -- Long-term memory for chatbots and agents -- Chain-of-Thought and Tree-of-Thought +## Get Help +Please use [Discord](https://discord.gg/JzTFvyjG6R) or [GitHub discussions](https://github.com/langchain4j/langchain4j/discussions) +to get help. -## Request features -Please [let us know](https://github.com/langchain4j/langchain4j/issues/new/choose) what features you need! 
+## Request Features +Please let us know what features you need by [opening an issue](https://github.com/langchain4j/langchain4j/issues/new/choose). -## Contribution Guidelines +## Contribute Contribution guidelines can be found [here](https://github.com/langchain4j/langchain4j/blob/main/CONTRIBUTING.md). - - -## Use cases - -You might ask why would I need all of this? -Here are a couple of examples: - -- You want to implement a custom AI-powered chatbot that has access to your data and behaves the way you want it: - - Customer support chatbot that can: - - politely answer customer questions - - take /change/cancel orders - - Educational assistant that can: - - Teach various subjects - - Explain unclear parts - - Assess user's understanding/knowledge -- You want to process a lot of unstructured data (files, web pages, etc) and extract structured information from them. - For example: - - extract insights from customer reviews and support chat history - - extract interesting information from the websites of your competitors - - extract insights from CVs of job applicants -- You want to generate information, for example: - - Emails tailored for each of your customers - - Content for your app/website: - - Blog posts - - Stories -- You want to transform information, for example: - - Summarize - - Proofread and rewrite - - Translate - -## Best practices - -We highly recommend -watching [this amazing 90-minute tutorial](https://www.deeplearning.ai/short-courses/chatgpt-prompt-engineering-for-developers/) -on prompt engineering best practices, presented by Andrew Ng (DeepLearning.AI) and Isa Fulford (OpenAI). -This course will teach you how to use LLMs efficiently and achieve the best possible results. Good investment of your -time! - -Here are some best practices for using LLMs: - -- Be responsible. Use AI for Good. -- Be specific. The more specific your query, the best results you will get. -- Add a ["Let’s think step by step" instruction](https://arxiv.org/pdf/2205.11916.pdf) to your prompt. -- Specify steps to achieve the desired goal yourself. This will make the LLM do what you want it to do. -- Provide examples. Sometimes it is best to show LLM a few examples of what you want instead of trying to explain it. -- Ask LLM to provide structured output (JSON, XML, etc). This way you can parse response more easily and distinguish - different parts of it. -- Use unusual delimiters, such as \```triple backticks``` to help the LLM distinguish - data or input from instructions. - -## How to get an API key -You will need an API key from OpenAI (paid) or Hugging Face (free) to use LLMs hosted by them. - -We recommend using OpenAI LLMs (`gpt-3.5-turbo` and `gpt-4`) as they are by far the most capable and are reasonably priced. - -It will cost approximately $0.01 to generate 10 pages (A4 format) of text with `gpt-3.5-turbo`. With `gpt-4`, the cost will be $0.30 to generate the same amount of text. However, for some use cases, this higher cost may be justified. - -[How to get OpenAI API key](https://www.howtogeek.com/885918/how-to-get-an-openai-api-key/). - -For embeddings, we recommend using one of the models from the [Hugging Face MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard). -You'll have to find the best one for your specific use case. 
- -Here's how to get a Hugging Face API key: -- Create an account on https://huggingface.co -- Go to https://huggingface.co/settings/tokens -- Generate a new access token diff --git a/code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot/pom.xml b/code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot/pom.xml index 7cdcdd210a..d68b040347 100644 --- a/code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot/pom.xml +++ b/code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../../langchain4j-parent/pom.xml diff --git a/code-execution-engines/langchain4j-code-execution-engine-judge0/pom.xml b/code-execution-engines/langchain4j-code-execution-engine-judge0/pom.xml new file mode 100644 index 0000000000..662f81d8e1 --- /dev/null +++ b/code-execution-engines/langchain4j-code-execution-engine-judge0/pom.xml @@ -0,0 +1,89 @@ + + + 4.0.0 + + + dev.langchain4j + langchain4j-parent + 0.32.0-SNAPSHOT + ../../langchain4j-parent/pom.xml + + + langchain4j-code-execution-engine-judge0 + LangChain4j :: Integration :: Judge0 + Implementation of JavaScript code execution engine and tool using Judge0 + + + UTF-8 + + + + + + dev.langchain4j + langchain4j-core + + + + com.squareup.okhttp3 + okhttp + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.assertj + assertj-core + test + + + + org.mockito + mockito-core + test + + + + org.mockito + mockito-junit-jupiter + test + + + + dev.langchain4j + langchain4j + test + + + + dev.langchain4j + langchain4j-open-ai + test + + + + org.tinylog + tinylog-impl + test + + + org.tinylog + slf4j-tinylog + test + + + + + \ No newline at end of file diff --git a/langchain4j/src/main/java/dev/langchain4j/code/JavaScriptCodeFixer.java b/code-execution-engines/langchain4j-code-execution-engine-judge0/src/main/java/dev/langchain4j/code/judge0/JavaScriptCodeFixer.java similarity index 96% rename from langchain4j/src/main/java/dev/langchain4j/code/JavaScriptCodeFixer.java rename to code-execution-engines/langchain4j-code-execution-engine-judge0/src/main/java/dev/langchain4j/code/judge0/JavaScriptCodeFixer.java index 5bb146441d..3cfc308f87 100644 --- a/langchain4j/src/main/java/dev/langchain4j/code/JavaScriptCodeFixer.java +++ b/code-execution-engines/langchain4j-code-execution-engine-judge0/src/main/java/dev/langchain4j/code/judge0/JavaScriptCodeFixer.java @@ -1,4 +1,4 @@ -package dev.langchain4j.code; +package dev.langchain4j.code.judge0; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/langchain4j/src/main/java/dev/langchain4j/code/Judge0JavaScriptEngine.java b/code-execution-engines/langchain4j-code-execution-engine-judge0/src/main/java/dev/langchain4j/code/judge0/Judge0JavaScriptEngine.java similarity index 97% rename from langchain4j/src/main/java/dev/langchain4j/code/Judge0JavaScriptEngine.java rename to code-execution-engines/langchain4j-code-execution-engine-judge0/src/main/java/dev/langchain4j/code/judge0/Judge0JavaScriptEngine.java index 746e6e9397..7f4d9d9282 100644 --- a/langchain4j/src/main/java/dev/langchain4j/code/Judge0JavaScriptEngine.java +++ b/code-execution-engines/langchain4j-code-execution-engine-judge0/src/main/java/dev/langchain4j/code/judge0/Judge0JavaScriptEngine.java @@ -1,5 +1,6 @@ -package dev.langchain4j.code; +package dev.langchain4j.code.judge0; +import dev.langchain4j.code.CodeExecutionEngine; import 
dev.langchain4j.internal.Json; import okhttp3.*; import org.slf4j.Logger; diff --git a/langchain4j/src/main/java/dev/langchain4j/code/Judge0JavaScriptExecutionTool.java b/code-execution-engines/langchain4j-code-execution-engine-judge0/src/main/java/dev/langchain4j/code/judge0/Judge0JavaScriptExecutionTool.java similarity index 95% rename from langchain4j/src/main/java/dev/langchain4j/code/Judge0JavaScriptExecutionTool.java rename to code-execution-engines/langchain4j-code-execution-engine-judge0/src/main/java/dev/langchain4j/code/judge0/Judge0JavaScriptExecutionTool.java index 12530c0857..d18d14e0f5 100644 --- a/langchain4j/src/main/java/dev/langchain4j/code/Judge0JavaScriptExecutionTool.java +++ b/code-execution-engines/langchain4j-code-execution-engine-judge0/src/main/java/dev/langchain4j/code/judge0/Judge0JavaScriptExecutionTool.java @@ -1,11 +1,11 @@ -package dev.langchain4j.code; +package dev.langchain4j.code.judge0; import dev.langchain4j.agent.tool.P; import dev.langchain4j.agent.tool.Tool; import java.time.Duration; -import static dev.langchain4j.code.JavaScriptCodeFixer.fixIfNoLogToConsole; +import static dev.langchain4j.code.judge0.JavaScriptCodeFixer.fixIfNoLogToConsole; import static dev.langchain4j.internal.Utils.isNullOrBlank; /** diff --git a/langchain4j/src/test/java/dev/langchain4j/code/JavaScriptCodeFixerTest.java b/code-execution-engines/langchain4j-code-execution-engine-judge0/src/test/java/dev/langchain4j/code/judge0/JavaScriptCodeFixerTest.java similarity index 98% rename from langchain4j/src/test/java/dev/langchain4j/code/JavaScriptCodeFixerTest.java rename to code-execution-engines/langchain4j-code-execution-engine-judge0/src/test/java/dev/langchain4j/code/judge0/JavaScriptCodeFixerTest.java index fcb46e8f00..e94c5c6aab 100644 --- a/langchain4j/src/test/java/dev/langchain4j/code/JavaScriptCodeFixerTest.java +++ b/code-execution-engines/langchain4j-code-execution-engine-judge0/src/test/java/dev/langchain4j/code/judge0/JavaScriptCodeFixerTest.java @@ -1,4 +1,4 @@ -package dev.langchain4j.code; +package dev.langchain4j.code.judge0; import org.junit.jupiter.api.Test; diff --git a/code-execution-engines/langchain4j-code-execution-engine-judge0/src/test/java/dev/langchain4j/code/judge0/Judge0JavaScriptEngineIT.java b/code-execution-engines/langchain4j-code-execution-engine-judge0/src/test/java/dev/langchain4j/code/judge0/Judge0JavaScriptEngineIT.java new file mode 100644 index 0000000000..337601645f --- /dev/null +++ b/code-execution-engines/langchain4j-code-execution-engine-judge0/src/test/java/dev/langchain4j/code/judge0/Judge0JavaScriptEngineIT.java @@ -0,0 +1,30 @@ +package dev.langchain4j.code.judge0; + +import dev.langchain4j.code.CodeExecutionEngine; +import org.junit.jupiter.api.Test; + +import java.time.Duration; + +import static org.assertj.core.api.Assertions.assertThat; + +class Judge0JavaScriptEngineIT { + + private static final int JAVASCRIPT = 93; + + @Test + void should_execute_code() { + + // given + CodeExecutionEngine codeExecutionEngine = new Judge0JavaScriptEngine( + System.getenv("RAPID_API_KEY"), + JAVASCRIPT, + Duration.ofSeconds(60) + ); + + // when + String result = codeExecutionEngine.execute("console.log('hello world');"); + + // then + assertThat(result).isEqualTo("hello world"); + } +} \ No newline at end of file diff --git a/docker/ollama/llama3/Dockerfile b/docker/ollama/llama3/Dockerfile new file mode 100644 index 0000000000..79ab8f9524 --- /dev/null +++ b/docker/ollama/llama3/Dockerfile @@ -0,0 +1,4 @@ +FROM --platform=$BUILDPLATFORM 
ollama/ollama:latest +RUN /bin/sh -c "/bin/ollama serve & sleep 1 && ollama pull llama3" +ENTRYPOINT ["/bin/ollama"] +CMD ["serve"] \ No newline at end of file diff --git a/docker/ollama/llama3/hooks/build b/docker/ollama/llama3/hooks/build new file mode 100644 index 0000000000..d4c3179896 --- /dev/null +++ b/docker/ollama/llama3/hooks/build @@ -0,0 +1,3 @@ +#!/bin/bash +docker buildx create --use +docker buildx build --push --platform=linux/amd64,linux/arm64 -f $DOCKERFILE_PATH -t $IMAGE_NAME . \ No newline at end of file diff --git a/docker/ollama/llama3/hooks/push b/docker/ollama/llama3/hooks/push new file mode 100644 index 0000000000..a9bf588e2f --- /dev/null +++ b/docker/ollama/llama3/hooks/push @@ -0,0 +1 @@ +#!/bin/bash diff --git a/docs/docs/get-started.md b/docs/docs/get-started.md index d5ac56b5b7..2e6348af6f 100644 --- a/docs/docs/get-started.md +++ b/docs/docs/get-started.md @@ -4,33 +4,40 @@ sidebar_position: 5 # Get Started -## Prerequisites :::note -Ensure you have Java 8 or higher installed. Verify it by typing this command in your terminal: -```shell -java --version -``` -::: +If you are using Quarkus, see [Quarkus Integration](/tutorials/quarkus-integration/). -## Write a "Hello World" program +If you are using Spring Boot, see [Spring Boot Integration](/tutorials/spring-boot-integration). +::: -The simplest way to begin is with the OpenAI integration. -LangChain4j offers integration with many LLMs. -Each integration has its own dependency. -In this case, we should add the OpenAI dependency: +LangChain4j offers [integration with many LLM providers](/integrations/language-models/). +Each integration has its own maven dependency. +The simplest way to begin is with the OpenAI integration: - For Maven in `pom.xml`: ```xml dev.langchain4j langchain4j-open-ai - 0.30.0 + 0.31.0 + +``` + +If you wish to use a high-level [AI Services](/tutorials/ai-services) API, you will also need to add +the following dependency: + +```xml + + dev.langchain4j + langchain4j + 0.31.0 ``` - For Gradle in `build.gradle`: ```groovy -implementation 'dev.langchain4j:langchain4j-open-ai:0.30.0' +implementation 'dev.langchain4j:langchain4j-open-ai:0.31.0' +implementation 'dev.langchain4j:langchain4j:0.31.0' ``` Then, import your OpenAI API key. 
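The updated get-started page above stops at "import your OpenAI API key". For reference, the steps it leads into amount to only a few lines; the following is a minimal sketch based on the `OpenAiChatModel.withApiKey(...)` and `generate(...)` calls shown in the README changes elsewhere in this diff (the package `dev.langchain4j.model.openai` is assumed).

```java
import dev.langchain4j.model.openai.OpenAiChatModel;

public class HelloWorld {

    public static void main(String[] args) {

        // Read the key from the environment (the free "demo" key mentioned in the README can also be used for testing)
        String apiKey = System.getenv("OPENAI_API_KEY");

        // Create a model instance and send a first prompt
        OpenAiChatModel model = OpenAiChatModel.withApiKey(apiKey);

        String answer = model.generate("Hello world!");

        System.out.println(answer); // e.g. "Hello! How can I assist you today?"
    }
}
```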
diff --git a/docs/docs/integrations/document-loaders/amazon-s3.md b/docs/docs/integrations/document-loaders/amazon-s3.md new file mode 100644 index 0000000000..d386ac4146 --- /dev/null +++ b/docs/docs/integrations/document-loaders/amazon-s3.md @@ -0,0 +1,13 @@ +--- +sidebar_position: 1 +--- + +# Amazon S3 Loader + +```xml + + dev.langchain4j + langchain4j-document-loader-amazon-s3 + 0.31.0 + +``` \ No newline at end of file diff --git a/docs/docs/integrations/document-loaders/azure-blob-storage.md b/docs/docs/integrations/document-loaders/azure-blob-storage.md new file mode 100644 index 0000000000..9003076498 --- /dev/null +++ b/docs/docs/integrations/document-loaders/azure-blob-storage.md @@ -0,0 +1,7 @@ +--- +sidebar_position: 2 +--- + +# Azure Blob Storage Loader + +Coming soon \ No newline at end of file diff --git a/docs/docs/integrations/document-loaders/doc.md b/docs/docs/integrations/document-loaders/doc.md deleted file mode 100644 index b34cfcbcb1..0000000000 --- a/docs/docs/integrations/document-loaders/doc.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -sidebar_position: 12 ---- - -# .doc \ No newline at end of file diff --git a/docs/docs/integrations/document-loaders/txt.md b/docs/docs/integrations/document-loaders/file-system.md similarity index 54% rename from docs/docs/integrations/document-loaders/txt.md rename to docs/docs/integrations/document-loaders/file-system.md index 892de7526b..f8b06fcb1f 100644 --- a/docs/docs/integrations/document-loaders/txt.md +++ b/docs/docs/integrations/document-loaders/file-system.md @@ -2,4 +2,4 @@ sidebar_position: 3 --- -# .txt \ No newline at end of file +# File System Loader \ No newline at end of file diff --git a/docs/docs/integrations/document-loaders/github.md b/docs/docs/integrations/document-loaders/github.md index 2ca3aff5bf..9c69fccd21 100644 --- a/docs/docs/integrations/document-loaders/github.md +++ b/docs/docs/integrations/document-loaders/github.md @@ -1,5 +1,5 @@ --- -sidebar_position: 10 +sidebar_position: 4 --- # Github Loader diff --git a/docs/docs/integrations/document-loaders/pdf.md b/docs/docs/integrations/document-loaders/pdf.md deleted file mode 100644 index 527b4677f3..0000000000 --- a/docs/docs/integrations/document-loaders/pdf.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -sidebar_position: 9 ---- - -# .pdf \ No newline at end of file diff --git a/docs/docs/integrations/document-loaders/ppt.md b/docs/docs/integrations/document-loaders/ppt.md deleted file mode 100644 index b7f5c27a20..0000000000 --- a/docs/docs/integrations/document-loaders/ppt.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -sidebar_position: 18 ---- - -# .ppt \ No newline at end of file diff --git a/docs/docs/integrations/document-loaders/s3.md b/docs/docs/integrations/document-loaders/s3.md deleted file mode 100644 index 6414035f76..0000000000 --- a/docs/docs/integrations/document-loaders/s3.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -sidebar_position: 24 ---- - -# S3 (Amazon Simple Storage Service) \ No newline at end of file diff --git a/docs/docs/integrations/document-loaders/selenium.md b/docs/docs/integrations/document-loaders/selenium.md new file mode 100644 index 0000000000..6b564cdc26 --- /dev/null +++ b/docs/docs/integrations/document-loaders/selenium.md @@ -0,0 +1,7 @@ +--- +sidebar_position: 5 +--- + +# Selenium Loader + +Coming soon \ No newline at end of file diff --git a/docs/docs/integrations/document-loaders/html.md b/docs/docs/integrations/document-loaders/tencent-cos.md similarity index 54% rename from docs/docs/integrations/document-loaders/html.md rename to 
docs/docs/integrations/document-loaders/tencent-cos.md index a11ab7bc4c..feb04939df 100644 --- a/docs/docs/integrations/document-loaders/html.md +++ b/docs/docs/integrations/document-loaders/tencent-cos.md @@ -2,4 +2,4 @@ sidebar_position: 6 --- -# .html \ No newline at end of file +# Tencent COS Loader \ No newline at end of file diff --git a/docs/docs/integrations/document-loaders/url.md b/docs/docs/integrations/document-loaders/url.md index 99738bd7b7..347df150af 100644 --- a/docs/docs/integrations/document-loaders/url.md +++ b/docs/docs/integrations/document-loaders/url.md @@ -1,5 +1,5 @@ --- -sidebar_position: 21 +sidebar_position: 7 --- -# .url \ No newline at end of file +# URL Loader \ No newline at end of file diff --git a/docs/docs/integrations/document-loaders/xls.md b/docs/docs/integrations/document-loaders/xls.md deleted file mode 100644 index f795b59d2d..0000000000 --- a/docs/docs/integrations/document-loaders/xls.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -sidebar_position: 15 ---- - -# .xls \ No newline at end of file diff --git a/docs/docs/integrations/document-parsers/_category_.json b/docs/docs/integrations/document-parsers/_category_.json new file mode 100644 index 0000000000..c243ffb200 --- /dev/null +++ b/docs/docs/integrations/document-parsers/_category_.json @@ -0,0 +1,8 @@ +{ + "label": "Document Parsers", + "position": 7, + "link": { + "type": "generated-index", + "description": "Document Parsers" + } +} diff --git a/docs/docs/integrations/document-parsers/apache-pdfbox.md b/docs/docs/integrations/document-parsers/apache-pdfbox.md new file mode 100644 index 0000000000..d52d091e8d --- /dev/null +++ b/docs/docs/integrations/document-parsers/apache-pdfbox.md @@ -0,0 +1,14 @@ +--- +sidebar_position: 4 +--- + +# Apache PDFBox + +`ApachePdfBoxDocumentParser` can be found in the following module: +```xml + + dev.langchain4j + langchain4j-document-parser-apache-pdfbox + 0.31.0 + +``` \ No newline at end of file diff --git a/docs/docs/integrations/document-parsers/apache-poi.md b/docs/docs/integrations/document-parsers/apache-poi.md new file mode 100644 index 0000000000..5bd8466ce1 --- /dev/null +++ b/docs/docs/integrations/document-parsers/apache-poi.md @@ -0,0 +1,14 @@ +--- +sidebar_position: 3 +--- + +# Apache POI + +`ApachePoiDocumentParser` can be found in the following module: +```xml + + dev.langchain4j + langchain4j-document-parser-apache-poi + 0.31.0 + +``` \ No newline at end of file diff --git a/docs/docs/integrations/document-parsers/apache-tika.md b/docs/docs/integrations/document-parsers/apache-tika.md new file mode 100644 index 0000000000..b85520f238 --- /dev/null +++ b/docs/docs/integrations/document-parsers/apache-tika.md @@ -0,0 +1,14 @@ +--- +sidebar_position: 2 +--- + +# Apache Tika + +`ApacheTikaDocumentParser` can be found in the following module: +```xml + + dev.langchain4j + langchain4j-document-parser-apache-tika + 0.31.0 + +``` \ No newline at end of file diff --git a/docs/docs/integrations/document-parsers/text.md b/docs/docs/integrations/document-parsers/text.md new file mode 100644 index 0000000000..a909f23086 --- /dev/null +++ b/docs/docs/integrations/document-parsers/text.md @@ -0,0 +1,14 @@ +--- +sidebar_position: 1 +--- + +# Text + +`TextDocumentParser` can be found in the main module: +```xml + + dev.langchain4j + langchain4j + 0.31.0 + +``` \ No newline at end of file diff --git a/docs/docs/integrations/embedding-models/cohere.md b/docs/docs/integrations/embedding-models/cohere.md new file mode 100644 index 0000000000..b0532a1415 --- /dev/null 
+++ b/docs/docs/integrations/embedding-models/cohere.md @@ -0,0 +1,5 @@ +--- +sidebar_position: 4 +--- + +# Cohere \ No newline at end of file diff --git a/docs/docs/integrations/embedding-models/dashscope.md b/docs/docs/integrations/embedding-models/dashscope.md index ff538821c9..c3872ab677 100644 --- a/docs/docs/integrations/embedding-models/dashscope.md +++ b/docs/docs/integrations/embedding-models/dashscope.md @@ -1,5 +1,5 @@ --- -sidebar_position: 4 +sidebar_position: 5 --- # DashScope diff --git a/docs/docs/integrations/embedding-models/google-vertex-ai.md b/docs/docs/integrations/embedding-models/google-vertex-ai.md index 7063f9c1e8..51fe249825 100644 --- a/docs/docs/integrations/embedding-models/google-vertex-ai.md +++ b/docs/docs/integrations/embedding-models/google-vertex-ai.md @@ -1,5 +1,5 @@ --- -sidebar_position: 5 +sidebar_position: 6 --- # Google Vertex AI diff --git a/docs/docs/integrations/embedding-models/hugging-face.md b/docs/docs/integrations/embedding-models/hugging-face.md index e7a578e9c9..e4052132f5 100644 --- a/docs/docs/integrations/embedding-models/hugging-face.md +++ b/docs/docs/integrations/embedding-models/hugging-face.md @@ -1,5 +1,5 @@ --- -sidebar_position: 6 +sidebar_position: 7 --- # Hugging Face diff --git a/docs/docs/integrations/embedding-models/jina.md b/docs/docs/integrations/embedding-models/jina.md new file mode 100644 index 0000000000..00c25a7924 --- /dev/null +++ b/docs/docs/integrations/embedding-models/jina.md @@ -0,0 +1,5 @@ +--- +sidebar_position: 8 +--- + +# Jina \ No newline at end of file diff --git a/docs/docs/integrations/embedding-models/local-ai.md b/docs/docs/integrations/embedding-models/local-ai.md index 5730c1eb78..134e41fad1 100644 --- a/docs/docs/integrations/embedding-models/local-ai.md +++ b/docs/docs/integrations/embedding-models/local-ai.md @@ -1,5 +1,5 @@ --- -sidebar_position: 7 +sidebar_position: 9 --- # LocalAI diff --git a/docs/docs/integrations/embedding-models/mistral-ai.md b/docs/docs/integrations/embedding-models/mistral-ai.md index 7dd661c91f..5e126fe6db 100644 --- a/docs/docs/integrations/embedding-models/mistral-ai.md +++ b/docs/docs/integrations/embedding-models/mistral-ai.md @@ -1,5 +1,5 @@ --- -sidebar_position: 8 +sidebar_position: 10 --- # Mistral AI diff --git a/docs/docs/integrations/embedding-models/nomic.md b/docs/docs/integrations/embedding-models/nomic.md index f38af30868..9e38fad3c4 100644 --- a/docs/docs/integrations/embedding-models/nomic.md +++ b/docs/docs/integrations/embedding-models/nomic.md @@ -1,5 +1,5 @@ --- -sidebar_position: 9 +sidebar_position: 11 --- # Nomic \ No newline at end of file diff --git a/docs/docs/integrations/embedding-models/ollama.md b/docs/docs/integrations/embedding-models/ollama.md index 2a4cad7d0c..5280923b67 100644 --- a/docs/docs/integrations/embedding-models/ollama.md +++ b/docs/docs/integrations/embedding-models/ollama.md @@ -1,5 +1,5 @@ --- -sidebar_position: 10 +sidebar_position: 12 --- # Ollama \ No newline at end of file diff --git a/docs/docs/integrations/embedding-models/open-ai.md b/docs/docs/integrations/embedding-models/open-ai.md index eda23681eb..7dbab6c651 100644 --- a/docs/docs/integrations/embedding-models/open-ai.md +++ b/docs/docs/integrations/embedding-models/open-ai.md @@ -1,5 +1,5 @@ --- -sidebar_position: 11 +sidebar_position: 13 --- # OpenAI diff --git a/docs/docs/integrations/embedding-models/qianfan.md b/docs/docs/integrations/embedding-models/qianfan.md index fc1bd80341..93f772b2ca 100644 --- 
a/docs/docs/integrations/embedding-models/qianfan.md +++ b/docs/docs/integrations/embedding-models/qianfan.md @@ -1,5 +1,5 @@ --- -sidebar_position: 12 +sidebar_position: 14 --- # Qianfan \ No newline at end of file diff --git a/docs/docs/integrations/embedding-models/zhipu-ai.md b/docs/docs/integrations/embedding-models/zhipu-ai.md index 582c565d51..307dcfe497 100644 --- a/docs/docs/integrations/embedding-models/zhipu-ai.md +++ b/docs/docs/integrations/embedding-models/zhipu-ai.md @@ -1,5 +1,5 @@ --- -sidebar_position: 13 +sidebar_position: 15 --- # ZhipuAI diff --git a/docs/docs/integrations/embedding-stores/index.md b/docs/docs/integrations/embedding-stores/index.md index 4386984d6f..8fcd33fa57 100644 --- a/docs/docs/integrations/embedding-stores/index.md +++ b/docs/docs/integrations/embedding-stores/index.md @@ -1,27 +1,21 @@ ---- -title: Comparison Table -hide_title: false -sidebar_position: 0 ---- - | Provider | Storing Metadata | Filtering by Metadata | Local | Cloud | |---------------------------------------------------------------------------------------|:------------------:|:---------------------:|:-------:|:-------:| -| [In-memory](/integrations/embedding-stores/in-memory) | ✅ | ✅ | | | -| [AstraDB](/integrations/embedding-stores/astra-db) | ✅ | ✅ | | ✅ | -| [Azure AI Search](/integrations/embedding-stores/azure-ai-search) | ✅ | | | | -| [Azure CosmosDB Mongo vCore](/integrations/embedding-stores/azure-cosmos-mongo-vcore) | ✅ | | | | -| [Apache Cassandra™](/integrations/embedding-stores/cassandra) | ✅ | ✅ (partial) | ✅ | ✅ | -| [Chroma](/integrations/embedding-stores/chroma) | ✅ | | | | -| [Elasticsearch](/integrations/embedding-stores/elasticsearch) | ✅ | ✅ | | | -| [Infinispan](/integrations/embedding-stores/infinispan) | ✅ | | | | -| [Milvus](/integrations/embedding-stores/milvus) | ✅ | ✅ | | | -| [MongoDB Atlas](/integrations/embedding-stores/mongodb-atlas) | ✅ | | | | -| [Neo4j](/integrations/embedding-stores/neo4j) | | | | | -| [OpenSearch](/integrations/embedding-stores/opensearch) | ✅ | | | | -| [PGVector](/integrations/embedding-stores/pgvector) | ✅ | | | | +| [In-memory](/integrations/embedding-stores/in-memory) | ✅ | ✅ | | | +| [AstraDB](/integrations/embedding-stores/astra-db) | ✅ | ✅ | | ✅ | +| [Azure AI Search](/integrations/embedding-stores/azure-ai-search) | ✅ | | | | +| [Azure CosmosDB Mongo vCore](/integrations/embedding-stores/azure-cosmos-mongo-vcore) | ✅ | | | | +| [Apache Cassandra™](/integrations/embedding-stores/cassandra) | ✅ | ✅ | ✅ | ✅ | +| [Chroma](/integrations/embedding-stores/chroma) | ✅ | | | | +| [Elasticsearch](/integrations/embedding-stores/elasticsearch) | ✅ | ✅ | | | +| [Infinispan](/integrations/embedding-stores/infinispan) | ✅ | | | | +| [Milvus](/integrations/embedding-stores/milvus) | ✅ | ✅ | | | +| [MongoDB Atlas](/integrations/embedding-stores/mongodb-atlas) | ✅ | | | | +| [Neo4j](/integrations/embedding-stores/neo4j) | | | | | +| [OpenSearch](/integrations/embedding-stores/opensearch) | ✅ | | | | +| [PGVector](/integrations/embedding-stores/pgvector) | ✅ | ✅ | ✅ | | [Pinecone](/integrations/embedding-stores/pinecone) | | | | | | [Qdrant](/integrations/embedding-stores/qdrant) | ✅ | | | | | [Redis](/integrations/embedding-stores/redis) | ✅ | | | | | [Vearch](/integrations/embedding-stores/vearch) | ✅ | | | | | [Vespa](/integrations/embedding-stores/vespa) | | | | | -| [Weaviate](/integrations/embedding-stores/weaviate) | | | | | +| [Weaviate](/integrations/embedding-stores/weaviate) | ✅ | | ✅ | \ No newline at end of file diff --git 
a/docs/docs/integrations/frameworks/spring-boot.md b/docs/docs/integrations/frameworks/spring-boot.md index 7059b7c79b..efd189aa27 100644 --- a/docs/docs/integrations/frameworks/spring-boot.md +++ b/docs/docs/integrations/frameworks/spring-boot.md @@ -2,4 +2,6 @@ sidebar_position: 6 --- -# Spring Boot \ No newline at end of file +# Spring Boot + +Documentation on Spring Boot integration can be found [here](/tutorials/spring-boot-integration). \ No newline at end of file diff --git a/docs/docs/integrations/image-models/azure-dall-e.md b/docs/docs/integrations/image-models/azure-dall-e.md index 60e89feda6..059bcfab4d 100644 --- a/docs/docs/integrations/image-models/azure-dall-e.md +++ b/docs/docs/integrations/image-models/azure-dall-e.md @@ -2,4 +2,6 @@ sidebar_position: 4 --- -# Azure OpenAI Dall·E \ No newline at end of file +# Azure OpenAI Dall·E + +Example can be found [here](https://github.com/langchain4j/langchain4j-examples/blob/main/azure-open-ai-examples/src/main/java/AzureOpenAIDallEExample.java). diff --git a/docs/docs/integrations/image-models/dall-e.md b/docs/docs/integrations/image-models/dall-e.md index d3875af0e3..567fdb97fd 100644 --- a/docs/docs/integrations/image-models/dall-e.md +++ b/docs/docs/integrations/image-models/dall-e.md @@ -2,4 +2,5 @@ sidebar_position: 3 --- -# OpenAI Dall·E \ No newline at end of file +# OpenAI Dall·E +Example can be found [here](https://github.com/langchain4j/langchain4j-examples/blob/main/open-ai-examples/src/main/java/OpenAiImageModelExamples.java). diff --git a/docs/docs/integrations/image-models/workers-ai.md b/docs/docs/integrations/image-models/workers-ai.md new file mode 100644 index 0000000000..0b7e3c2180 --- /dev/null +++ b/docs/docs/integrations/image-models/workers-ai.md @@ -0,0 +1,7 @@ +--- +sidebar_position: 7 +--- + +# Cloudflare Workers AI + +https://developers.cloudflare.com/workers-ai/ \ No newline at end of file diff --git a/docs/docs/integrations/index.mdx b/docs/docs/integrations/index.mdx index 73305da499..7bc12517d9 100644 --- a/docs/docs/integrations/index.mdx +++ b/docs/docs/integrations/index.mdx @@ -31,7 +31,7 @@ of course some LLM providers offer large multimodal model (accepting text or ima | [Google Vertex AI Gemini](/integrations/language-models/google-gemini) | | ✅ | ✅ | | ✅ | | ✅ | | [Google Vertex AI](/integrations/language-models/google-palm) | ✅ | ✅ | | ✅ | ✅ | | | | [Mistral AI](/integrations/language-models/mistral-ai) | | ✅ | ✅ | ✅ | | |✅ | -| [DashScope](/integrations/language-models/dashscope) | | ✅ | ✅ |✅ | | | | +| [DashScope](/integrations/language-models/dashscope) | | ✅ | ✅ | ✅ | | | ✅ | | [LocalAI](/integrations/language-models/local-ai) | | ✅ | ✅ | ✅ | | | ✅ | | [Ollama](/integrations/language-models/ollama) | | ✅ | ✅ | ✅ | | | | | Cohere | | | | | | ✅| | diff --git a/docs/docs/integrations/language-models/anthropic.md b/docs/docs/integrations/language-models/anthropic.md index bc2e8a1308..631cf67030 100644 --- a/docs/docs/integrations/language-models/anthropic.md +++ b/docs/docs/integrations/language-models/anthropic.md @@ -13,7 +13,7 @@ sidebar_position: 2 dev.langchain4j langchain4j-anthropic - 0.30.0 + 0.31.0 ``` @@ -90,7 +90,7 @@ Import Spring Boot starter for Anthropic: dev.langchain4j langchain4j-anthropic-spring-boot-starter - 0.30.0 + 0.31.0 ``` diff --git a/docs/docs/integrations/language-models/google-gemini.md b/docs/docs/integrations/language-models/google-gemini.md index 005837336a..fcc7b0ce19 100644 --- a/docs/docs/integrations/language-models/google-gemini.md +++ 
b/docs/docs/integrations/language-models/google-gemini.md @@ -1,4 +1,5 @@ --- +id: gemini-language-model sidebar_position: 6 --- diff --git a/docs/docs/integrations/language-models/index.md b/docs/docs/integrations/language-models/index.md index cd4aef02eb..4115a720f4 100644 --- a/docs/docs/integrations/language-models/index.md +++ b/docs/docs/integrations/language-models/index.md @@ -1,22 +1,24 @@ --- -title: Comparison Table +id: supported-language-models +title: Comparison Table of all supported Language Models hide_title: false sidebar_position: 0 --- -| Provider | [Streaming](/tutorials/response-streaming) | [Tools](/tutorials/tools) | Image Inputs | Local | Native | -|------------------------------------------------------------------------|--------------------------------------------|---------------------------|--------------|---------------------------------------------------|--------| -| [Amazon Bedrock](/integrations/language-models/amazon-bedrock) | | | | | | -| [Anthropic](/integrations/language-models/anthropic) | ✅ | ✅ | ✅ | | ✅ | -| [Azure OpenAI](/integrations/language-models/azure-open-ai) | ✅ | ✅ | ✅ | | | -| [ChatGLM](/integrations/language-models/chatglm) | | | | | | -| [DashScope](/integrations/language-models/dashscope) | ✅ | | | | | -| [Google Vertex AI Gemini](/integrations/language-models/google-gemini) | ✅ | ✅ | ✅ | | | -| [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | | | ✅ | -| [Hugging Face](/integrations/language-models/hugging-face) | | | | | | -| [LocalAI](/integrations/language-models/local-ai) | ✅ | ✅ | | ✅ | | -| [Mistral AI](/integrations/language-models/mistral-ai) | ✅ | ✅ | | | | -| [Ollama](/integrations/language-models/ollama) | ✅ | | ✅ | ✅ | | -| [OpenAI](/integrations/language-models/open-ai) | ✅ | ✅ | ✅ | Compatible with: Ollama, LM Studio, GPT4All, etc. | ✅ | -| [Qianfan](/integrations/language-models/qianfan) | ✅ | ✅ | | | | -| [Zhipu AI](/integrations/language-models/zhipu-ai) | ✅ | ✅ | | | | +| Provider | [Streaming](/tutorials/response-streaming) | [Tools](/tutorials/tools) | Image Inputs | Local | Native | +|------------------------------------------------------------------------|--------------------------------------------|---------------------------|--------------|---------------------------------------------------------|--------| +| [Amazon Bedrock](/integrations/language-models/amazon-bedrock) | | | | | | +| [Anthropic](/integrations/language-models/anthropic) | ✅ | ✅ | ✅ | | ✅ | +| [Azure OpenAI](/integrations/language-models/azure-open-ai) | ✅ | ✅ | ✅ | | | +| [ChatGLM](/integrations/language-models/chatglm) | | | | | | +| [DashScope](/integrations/language-models/dashscope) | ✅ | ✅ | ✅ | | | +| [Google Vertex AI Gemini](/integrations/language-models/google-gemini) | ✅ | ✅ | ✅ | | | +| [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | | | ✅ | +| [Hugging Face](/integrations/language-models/hugging-face) | | | | | | +| [LocalAI](/integrations/language-models/local-ai) | ✅ | ✅ | | ✅ | | +| [Mistral AI](/integrations/language-models/mistral-ai) | ✅ | ✅ | | | | +| [Ollama](/integrations/language-models/ollama) | ✅ | | ✅ | ✅ | | +| [OpenAI](/integrations/language-models/open-ai) | ✅ | ✅ | ✅ | Compatible with: Groq, Ollama, LM Studio, GPT4All, etc. 
| ✅ | +| [Qianfan](/integrations/language-models/qianfan) | ✅ | ✅ | | | | +| [Cloudflare Workers AI](/integrations/language-models/workers-ai) | | | | | | +| [Zhipu AI](/integrations/language-models/zhipu-ai) | ✅ | ✅ | | | | diff --git a/docs/docs/integrations/language-models/mistral-ai.md b/docs/docs/integrations/language-models/mistral-ai.md index 6cf78f0356..a814ca3673 100644 --- a/docs/docs/integrations/language-models/mistral-ai.md +++ b/docs/docs/integrations/language-models/mistral-ai.md @@ -5,7 +5,7 @@ sidebar_position: 10 # MistralAI [MistralAI Documentation](https://docs.mistral.ai/) -### Project setup +## Project setup To install langchain4j to your project, add the following dependency: @@ -32,7 +32,7 @@ For Gradle project `build.gradle` implementation 'dev.langchain4j:langchain4j:{your-version}' implementation 'dev.langchain4j:langchain4j-mistral-ai:{your-version}' ``` -#### API Key setup +### API Key setup Add your MistralAI API key to your project, you can create a class ```ApiKeys.java``` with the following code ```java @@ -47,16 +47,21 @@ SET MISTRAL_AI_API_KEY=your-api-key #For Windows OS ``` More details on how to get your MistralAI API key can be found [here](https://docs.mistral.ai/#api-access) -#### Model Selection +### Model Selection You can use `MistralAiChatModelName.class` enum class to found appropriate model names for your use case. MistralAI updated a new selection and classification of models according to performance and cost trade-offs. -Here a list of available models: -- open-mistral-7b (aka mistral-tiny-2312) -- open-mixtral-8x7b (aka mistral-small-2312) -- mistral-small-latest (aka mistral-small-2402) -- mistral-medium-latest (aka mistral-medium-2312) -- mistral-large-latest (aka mistral-large-2402) +| Model name | Deployment or available from | Description | +|------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| open-mistral-7b | - Mistral AI La Plateforme.
- Cloud platforms (Azure, AWS, GCP).<br/>- Hugging Face.<br/>- Self-hosted (On-premise, IaaS, docker, local). | **OpenSource**<br/>The first dense model released by Mistral AI,<br/>perfect for experimentation,<br/>customization, and quick iteration.<br/><br/>Max tokens 32K<br/><br/>Java Enum<br/>`MistralAiChatModelName.OPEN_MISTRAL_7B` |
+| open-mixtral-8x7b | - Mistral AI La Plateforme.<br/>- Cloud platforms (Azure, AWS, GCP).<br/>- Hugging Face.<br/>- Self-hosted (On-premise, IaaS, docker, local). | **OpenSource**<br/>Ideal for multi-language operations,<br/>code generation, and fine-tuning.<br/>Excellent cost/performance trade-offs.<br/><br/>Max tokens 32K<br/><br/>Java Enum<br/>`MistralAiChatModelName.OPEN_MIXTRAL_8x7B` |
+| open-mixtral-8x22b | - Mistral AI La Plateforme.<br/>- Cloud platforms (Azure, AWS, GCP).<br/>- Hugging Face.<br/>- Self-hosted (On-premise, IaaS, docker, local). | **OpenSource**<br/>It has all Mixtral-8x7B capabilities plus strong maths<br/>and coding, and is natively capable of function calling.<br/><br/>Max tokens 64K<br/><br/>Java Enum<br/>`MistralAiChatModelName.OPEN_MIXTRAL_8X22B` |
+| mistral-small-latest | - Mistral AI La Plateforme.<br/>- Cloud platforms (Azure, AWS, GCP). | **Commercial**<br/>Suitable for simple tasks that one can do in bulk<br/>(Classification, Customer Support, or Text Generation).<br/><br/>Max tokens 32K<br/><br/>Java Enum<br/>`MistralAiChatModelName.MISTRAL_SMALL_LATEST` |
+| mistral-medium-latest | - Mistral AI La Plateforme.<br/>- Cloud platforms (Azure, AWS, GCP). | **Commercial**<br/>Ideal for intermediate tasks that require moderate<br/>reasoning (Data extraction, Summarizing,<br/>Writing emails, Writing descriptions).<br/><br/>Max tokens 32K<br/><br/>Java Enum<br/>`MistralAiChatModelName.MISTRAL_MEDIUM_LATEST` |
+| mistral-large-latest | - Mistral AI La Plateforme.<br/>- Cloud platforms (Azure, AWS, GCP). | **Commercial**<br/>Ideal for complex tasks that require large reasoning<br/>capabilities or are highly specialized<br/>(Text Generation, Code Generation, RAG, or Agents).<br/><br/>Max tokens 32K<br/><br/>Java Enum<br/>`MistralAiChatModelName.MISTRAL_LARGE_LATEST` |
+| mistral-embed | - Mistral AI La Plateforme.<br/>- Cloud platforms (Azure, AWS, GCP). | **Commercial**<br/>Converts text into numerical vectors of<br/>embeddings in 1024 dimensions.<br/>Embedding models enable retrieval and RAG applications.<br/><br/>Max tokens 8K<br/><br/>Java Enum<br/>
`MistralAiEmbeddingModelName.MISTRAL_EMBED` | + +`@Deprecated` models: - mistral-tiny (`@Deprecated`) - mistral-small (`@Deprecated`) - mistral-medium (`@Deprecated`) @@ -145,6 +150,18 @@ In [Set Model Parameters](/tutorials/model-parameters) you will learn how to set ### Function Calling Function calling allows Mistral chat models ([synchronous](#synchronous) and [streaming](#streaming)) to connect to external tools. For example, you can call a `Tool` to get the payment transaction status as shown in the Mistral AI function calling [tutorial](https://docs.mistral.ai/guides/function-calling/). +
+What are the supported mistral models? + +:::note +Currently, function calling is available for the following models: + +- Mistral Small `MistralAiChatModelName.MISTRAL_SMALL_LATEST` +- Mistral Large `MistralAiChatModelName.MISTRAL_LARGE_LATEST` +- Mixtral 8x22B `MistralAiChatModelName.OPEN_MIXTRAL_8X22B` +::: +
+ #### 1. Define a `Tool` class and how get the payment data Let's assume you have a dataset of payment transaction like this. In real applications you should inject a database source or REST API client to get the data. @@ -190,7 +207,7 @@ private String getPaymentData(String transactionId, String data) { } } ``` -It uses a `@Tool` annotation to define the function description and `@P` annotation to define the parameter description of the package `dev.langchain4j.agent.tool.*`. +It uses a `@Tool` annotation to define the function description and `@P` annotation to define the parameter description of the package `dev.langchain4j.agent.tool.*`. More info [here](/tutorials/tools#high-level-tool-api) #### 2. Define an interface as an `agent` to send chat messages. @@ -221,7 +238,7 @@ public class PaymentDataAssistantApp { ChatLanguageModel mistralAiModel = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) // Please use your own Mistral AI API key - .modelName(MistralAiChatModelName.MISTRAL_LARGE_LATEST) + .modelName(MistralAiChatModelName.MISTRAL_LARGE_LATEST) // Also you can use MistralAiChatModelName.OPEN_MIXTRAL_8X22B as open source model .logRequests(true) .logResponses(true) .build(); @@ -250,7 +267,108 @@ and expect an answer like this: ```shell The status of transaction T1005 is Pending. The payment date is October 8, 2021. ``` +### JSON mode +You can also use the JSON mode to get the response in JSON format. To do this, you need to set the `responseFormat` parameter to `json_object` or the java enum `MistralAiResponseFormatType.JSON_OBJECT` in the `MistralAiChatModel` builder OR `MistralAiStreamingChatModel` builder. + +Syncronous example: + +```java +ChatLanguageModel model = MistralAiChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) // Please use your own Mistral AI API key + .responseFormat(MistralAiResponseFormatType.JSON_OBJECT) + .build(); + +String userMessage = "Return JSON with two fields: transactionId and status with the values T123 and paid."; +String json = model.generate(userMessage); + +System.out.println(json); // {"transactionId":"T123","status":"paid"} +``` + +Streaming example: + +```java +StreamingChatLanguageModel streamingModel = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) // Please use your own Mistral AI API key + .responseFormat(MistralAiResponseFormatType.JSON_OBJECT) + .build(); + +String userMessage = "Return JSON with two fields: transactionId and status with the values T123 and paid."; + +CompletableFuture> futureResponse = new CompletableFuture<>(); + +streamingModel.generate(userMessage, new StreamingResponseHandler() { + @Override + public void onNext(String token) { + System.out.print(token); + } + + @Override + public void onComplete(Response response) { + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureResponse.completeExceptionally(error); + } +}); + +String json = futureResponse.get().content().text(); + +System.out.println(json); // {"transactionId":"T123","status":"paid"} +``` +### Guardrailing +Guardrails are a way to limit the behavior of the model to prevent it from generating harmful or unwanted content. You can set optionally `safePrompt` parameter in the `MistralAiChatModel` builder or `MistralAiStreamingChatModel` builder. 
+ +Syncronous example: + +```java +ChatLanguageModel model = MistralAiChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .safePrompt(true) + .build(); + +String userMessage = "What is the best French cheese?"; +String response = model.generate(userMessage); +``` + +Streaming example: + +```java +StreamingChatLanguageModel streamingModel = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .safePrompt(true) + .build(); + +String userMessage = "What is the best French cheese?"; + +CompletableFuture> futureResponse = new CompletableFuture<>(); + +streamingModel.generate(userMessage, new StreamingResponseHandler() { + @Override + public void onNext(String token) { + System.out.print(token); + } + + @Override + public void onComplete(Response response) { + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureResponse.completeExceptionally(error); + } +}); + +futureResponse.join(); +``` +Toggling the safe prompt will prepend your messages with the following `@SystemMessage`: + +```plaintext +Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity. +``` ### More examples If you want to check more MistralAI examples, you can find them in the [langchain4j-examples/mistral-ai-examples](https://github.com/langchain4j/langchain4j-examples/tree/main/mistral-ai-examples) project. diff --git a/docs/docs/integrations/language-models/ollama.md b/docs/docs/integrations/language-models/ollama.md index 619419ead0..338fdaf509 100644 --- a/docs/docs/integrations/language-models/ollama.md +++ b/docs/docs/integrations/language-models/ollama.md @@ -1,4 +1,5 @@ --- +id: ollama-language-model sidebar_position: 11 --- diff --git a/docs/docs/integrations/language-models/open-ai.md b/docs/docs/integrations/language-models/open-ai.md index c1947f17cc..3d7f5840e5 100644 --- a/docs/docs/integrations/language-models/open-ai.md +++ b/docs/docs/integrations/language-models/open-ai.md @@ -16,7 +16,7 @@ sidebar_position: 12 dev.langchain4j langchain4j-open-ai - 0.30.0 + 0.31.0 ``` diff --git a/docs/docs/integrations/language-models/qianfan.md b/docs/docs/integrations/language-models/qianfan.md index fb77c154ac..cda1da5086 100644 --- a/docs/docs/integrations/language-models/qianfan.md +++ b/docs/docs/integrations/language-models/qianfan.md @@ -2,4 +2,273 @@ sidebar_position: 13 --- -# Qianfan \ No newline at end of file +# Qianfan +[百度智能云千帆大模型](https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application) +![image](https://github.com/langchain4j/langchain4j/assets/95265298/600f8006-4484-4a75-829c-c8c16a3130c2) +## Maven Dependency(Maven依赖) + +```xml + + dev.langchain4j + langchain4j + {{version}} + + + dev.langchain4j + langchain4j-qianfan + {{version}} + + +``` + + +## QianfanChatModel +[千帆所有模型及付费状态](https://console.bce.baidu.com/qianfan/ais/console/onlineService) +```java + QianfanChatModel model = QianfanChatModel.builder() + .apiKey("apiKey") + .secretKey("secretKey") + .modelName("Yi-34B-Chat") // 一个免费的模型名称 + .build(); + + String answer = model.generate("雷军"); + + System.out.println(answer) +``` +### Customizing + +```java +QianfanChatModel model = QianfanChatModel.builder() + .baseUrl(...) + .apiKey(...) + .secretKey(...) + .temperature(...) + .maxRetries(...) + .topP(...) + .modelName(...) + .endpoint(...) + .responseFormat(...) + .penaltyScore(...) 
+ .logRequests(...) + .logResponses() + .build(); +``` + +See the description of some of the parameters above [here](https://console.bce.baidu.com/tools/?u=qfdc#/api?product=QIANFAN&project=%E5%8D%83%E5%B8%86%E5%A4%A7%E6%A8%A1%E5%9E%8B%E5%B9%B3%E5%8F%B0&parent=Yi-34B-Chat&api=rpc%2F2.0%2Fai_custom%2Fv1%2Fwenxinworkshop%2Fchat%2Fyi_34b_chat&method=post). +### functions +**IAiService(重点)** +```java +public interface IAiService { + /** + * Ai Services 提供了一种更简单、更灵活的替代方案。 您可以定义自己的 API(具有一个或多个方法的 Java 接口), 并将为其提供实现。 + * @param userMessage + * @return String + */ + String chat(String userMessage); +} +``` +#### QianfanChatWithOnePersonMemory (带有一个人的聊天记忆) + +```java + + QianfanChatModel model = QianfanChatModel.builder() + .apiKey("apiKey") + .secretKey("secretKey") + .modelName("Yi-34B-Chat") + .build(); + /* MessageWindowChatMemory + functions as a sliding window, retaining the N most recent messages and evicting older ones that no longer fit. + However, because each message can contain a varying number of tokens, MessageWindowChatMemory is mostly useful for fast prototyping. + 保留最新的n条消息(包括回复) + */ + /* TokenWindowChatMemory + which also operates as a sliding window but focuses on keeping the N most recent tokens, evicting older messages as needed. Messages are indivisible. + If a message doesn't fit, it is evicted completely. + MessageWindowChatMemory requires a Tokenizer to count the tokens in each ChatMessage. + */ + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .maxMessages(10) + .build(); + + IAiService assistant = AiServices.builder(IAiService.class) + .chatLanguageModel(model) // the model + .chatMemory(chatMemory) // memory + .build(); + String answer = assistant.chat("Hello,my name is xiaoyu"); + System.out.println(answer); // Hello xiaoyu!****** + + String answerWithName = assistant.chat("What's my name?"); + System.out.println(answerWithName); // Your name is xiaoyu.****** + + String answer1 = assistant.chat("I like playing football."); + System.out.println(answer1); // The answer + + String answer2 = assistant.chat("I want to go eat delicious food."); + System.out.println(answer2); // The answer + + String answerWithLike = assistant.chat("What I like to do?"); + System.out.println(answerWithLike);//Playing football.****** +``` + +#### QianfanChatWithMorePersonMemory (带有多个人的聊天记忆) + +```java + QianfanChatModel model = QianfanChatModel.builder() + .apiKey("apiKey") + .secretKey("secretKey") + .modelName("Yi-34B-Chat") + .build(); + IAiService assistant = AiServices.builder(IAiService.class) + .chatLanguageModel(model) // the model + .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(10)) // chatMemory + .build(); + + String answer = assistant.chat(1,"Hello, my name is xiaoyu"); + System.out.println(answer); // Hello xiaoyu!****** + String answer1 = assistant.chat(2,"Hello, my name is xiaomi"); + System.out.println(answer1); // Hello xiaomi!****** + + String answerWithName1 = assistant.chat(1,"What's my name?"); + System.out.println(answerWithName1); // Your name is xiaoyu. + String answerWithName2 = assistant.chat(2,"What's my name?"); + System.out.println(answerWithName2); // Your name is xiaomi. 
+``` + +#### QianfanChatWithPersistentMemory(持久化聊天记忆) + +```xml + + org.mapdb + mapdb + 3.1.0 + +``` +```java +class PersistentChatMemoryStore implements ChatMemoryStore { + private final DB db = DBMaker.fileDB("chat-memory.db").transactionEnable().make(); + private final Map map = db.hashMap("messages", STRING, STRING).createOrOpen(); + + @Override + public List getMessages(Object memoryId) { + String json = map.get((String) memoryId); + return messagesFromJson(json); + } + + @Override + public void updateMessages(Object memoryId, List messages) { + String json = messagesToJson(messages); + map.put((String) memoryId, json); + db.commit(); + } + + @Override + public void deleteMessages(Object memoryId) { + map.remove((String) memoryId); + db.commit(); + } +} + +class PersistentChatMemoryTest{ + public void test(){ + QianfanChatModel chatLanguageModel = QianfanChatModel.builder() + .apiKey("apiKey") + .secretKey("secretKey") + .modelName("Yi-34B-Chat") + .build(); + + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .maxMessages(10) + .chatMemoryStore(new PersistentChatMemoryStore()) + .build(); + + IAiService assistant = AiServices.builder(IAiService.class) + .chatLanguageModel(chatLanguageModel) + .chatMemory(chatMemory) + .build(); + + String answer = assistant.chat("My name is xiaoyu"); + System.out.println(answer); + // Run it once and then comment the top to run the bottom(运行一次后注释上面运行下面) + // String answerWithName = assistant.chat("What is my name?"); + // System.out.println(answerWithName); + } +} + +``` + +#### QianfanStreamingChatModel(流式回复) +LLMs generate text one token at a time, so many LLM providers offer a way to stream the response token-by-token instead of waiting for the entire text to be generated. This significantly improves the user experience, as the user does not need to wait an unknown amount of time and can start reading the response almost immediately.(因此许多LLM提供者提供了一种逐个token地传输响应的方法,而不是等待生成整个文本。这极大地改善了用户体验,因为用户不需要等待未知的时间,几乎可以立即开始阅读响应。) +以下是一个通过StreamingResponseHandler来实现 +```java + QianfanStreamingChatModel qianfanStreamingChatModel = QianfanStreamingChatModel.builder() + .apiKey("apiKey") + .secretKey("secretKey") + .modelName("Yi-34B-Chat") + .build(); + qianfanStreamingChatModel.generate(userMessage, new StreamingResponseHandler() { + @Override + public void onNext(String token) { + System.out.print(token); + } + @Override + public void onComplete(Response response) { + System.out.println("onComplete: " + response); + } + @Override + public void onError(Throwable throwable) { + throwable.printStackTrace(); + } + }); +``` +以下是另一个通过TokenStream来实现 +```java + QianfanStreamingChatModel qianfanStreamingChatModel = QianfanStreamingChatModel.builder() + .apiKey("apiKey") + .secretKey("secretKey") + .modelName("Yi-34B-Chat") + .build(); + IAiService assistant = AiServices.create(IAiService.class, qianfanStreamingChatModel); + + TokenStream tokenStream = assistant.chatInTokenStream("Tell me a story."); + tokenStream.onNext(System.out::println) + .onError(Throwable::printStackTrace) + .start(); +``` +#### QianfanRAG + +程序自动将匹配的内容与用户问题组装成一个Prompt,向大语言模型提问,大语言模型返回答案 + +LangChain4j has an "Easy RAG" feature that makes it as easy as possible to get started with RAG. You don't have to learn about embeddings, choose a vector store, find the right embedding model, figure out how to parse and split documents, etc. Just point to your document(s), and LangChain4j will do its magic. 
+ +- Import the dependency:langchain4j-easy-rag +```xml + + dev.langchain4j + langchain4j-easy-rag + {{version}} + +``` +- Use +```java + + QianfanChatModel chatLanguageModel = QianfanChatModel.builder() + .apiKey(API_KEY) + .secretKey(SECRET_KEY) + .modelName("Yi-34B-Chat") + .build(); + // All files in a directory, txt seems to be faster + List documents = FileSystemDocumentLoader.loadDocuments("/home/langchain4j/documentation"); + // for simplicity, we will use an in-memory one: + InMemoryEmbeddingStore embeddingStore = new InMemoryEmbeddingStore<>(); + EmbeddingStoreIngestor.ingest(documents, embeddingStore); + + IAiService assistant = AiServices.builder(IAiService.class) + .chatLanguageModel(chatLanguageModel) + .chatMemory(MessageWindowChatMemory.withMaxMessages(10)) + .contentRetriever(EmbeddingStoreContentRetriever.from(embeddingStore)) + .build(); + + String answer = assistant.chat("The Question"); + System.out.println(answer); + +``` diff --git a/docs/docs/integrations/language-models/workers-ai.md b/docs/docs/integrations/language-models/workers-ai.md new file mode 100644 index 0000000000..7a7a74b0a9 --- /dev/null +++ b/docs/docs/integrations/language-models/workers-ai.md @@ -0,0 +1,7 @@ +--- +sidebar_position: 14 +--- + +# Cloudflare Workers AI + +https://developers.cloudflare.com/workers-ai/ \ No newline at end of file diff --git a/docs/docs/integrations/language-models/zhipu-ai.md b/docs/docs/integrations/language-models/zhipu-ai.md index 3dfbdee293..307dcfe497 100644 --- a/docs/docs/integrations/language-models/zhipu-ai.md +++ b/docs/docs/integrations/language-models/zhipu-ai.md @@ -1,5 +1,5 @@ --- -sidebar_position: 14 +sidebar_position: 15 --- # ZhipuAI diff --git a/docs/docs/integrations/scoring-reranking-models/1-jina-ai.md b/docs/docs/integrations/scoring-reranking-models/1-jina-ai.md new file mode 100644 index 0000000000..ba75f0b3a4 --- /dev/null +++ b/docs/docs/integrations/scoring-reranking-models/1-jina-ai.md @@ -0,0 +1,48 @@ +--- +sidebar_position: 1 +--- + +# Jina + +- [Jina Reranker Documentation](https://jina.ai/reranker) +- [Jina Reranker API](https://api.jina.ai/redoc#tag/rerank) + + +### Introduction + +A reranker is an advanced AI model that takes the initial set of results from a search—often provided by an embeddings/token-based search—and reevaluates them to ensure they align more closely with the user's intent. +It looks beyond the surface-level matching of terms to consider the deeper interaction between the search query and the content of the documents. + + +### Maven Dependency + +```xml + + dev.langchain4j + langchain4j-jina + 0.31.0 + +``` + +### Usage + +```java + + +ScoringModel scoringModel = JinaScoringModel.withApiKey(System.getenv("JINA_API_KEY"));; + +ContentAggregator contentAggregator = ReRankingContentAggregator.builder() + .scoringModel(scoringModel) + ... + .build(); + +RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder() + ... + .contentAggregator(contentAggregator) + .build(); + +return AiServices.builder(Assistant.class) + .chatLanguageModel(...) 
+ .retrievalAugmentor(retrievalAugmentor) + .build(); +``` diff --git a/docs/docs/integrations/scoring-reranking-models/2-cohere.md b/docs/docs/integrations/scoring-reranking-models/2-cohere.md new file mode 100644 index 0000000000..daeee8b4be --- /dev/null +++ b/docs/docs/integrations/scoring-reranking-models/2-cohere.md @@ -0,0 +1,9 @@ +--- +sidebar_position: 2 +--- + +# Cohere + +`Cohere` + +Information coming soon \ No newline at end of file diff --git a/docs/docs/integrations/scoring-reranking-models/_category_.json b/docs/docs/integrations/scoring-reranking-models/_category_.json new file mode 100644 index 0000000000..7d0ee167a6 --- /dev/null +++ b/docs/docs/integrations/scoring-reranking-models/_category_.json @@ -0,0 +1,8 @@ +{ + "label": "Scoring (Reranking) Models", + "position": 8, + "link": { + "type": "generated-index", + "description": "Scoring (Reranking) Models" + } +} diff --git a/docs/docs/intro.md b/docs/docs/intro.md index 48c20c5db6..2eead1853a 100644 --- a/docs/docs/intro.md +++ b/docs/docs/intro.md @@ -1,4 +1,5 @@ --- +id: introduction sidebar_position: 1 title: Introduction --- @@ -7,26 +8,26 @@ title: Introduction Welcome! -The goal of LangChain4j is to simplify integrating AI/LLM capabilities into Java applications. +The goal of LangChain4j is to simplify integrating LLMs into Java applications. Here's how: 1. **Unified APIs:** -LLM providers (like OpenAI or Google Vertex AI) and embedding (vector) stores (such as Pinecone or Vespa) -use proprietary APIs. LangChain4j offers a unified API to avoid the need for learning and implementing specific APIs for each of them. -To experiment with a different LLM or embedding store, you can easily switch between them without the need to rewrite your code. -LangChain4j currently supports over 10 popular LLM providers and more than 15 embedding stores. -Think of it as a Hibernate, but for LLMs and embedding stores. + LLM providers (like OpenAI or Google Vertex AI) and embedding (vector) stores (such as Pinecone or Milvus) + use proprietary APIs. LangChain4j offers a unified API to avoid the need for learning and implementing specific APIs for each of them. + To experiment with different LLMs or embedding stores, you can easily switch between them without the need to rewrite your code. + LangChain4j currently supports [15+ popular LLM providers](/integrations/language-models/) + and [15+ embedding stores](/integrations/embedding-stores/). 2. **Comprehensive Toolbox:** -During the past year, the community has been building numerous LLM-powered applications, -identifying common patterns, abstractions, and techniques. LangChain4j has refined these into practical code. -Our toolbox includes tools ranging from low-level prompt templating, memory management, and output parsing -to high-level patterns like Agents and RAGs. -For each pattern and abstraction, we provide an interface along with multiple ready-to-use implementations based on proven techniques. -Whether you're building a chatbot or developing a RAG with a complete pipeline from data ingestion to retrieval, -LangChain4j offers a wide variety of options. + Over the past year, the community has been building numerous LLM-powered applications, + identifying common abstractions, patterns, and techniques. LangChain4j has refined these into a ready to use package. + Our toolbox includes tools ranging from low-level prompt templating, chat memory management, and output parsing + to high-level patterns like AI Services and RAG. 
+ For each abstraction, we provide an interface along with multiple ready-to-use implementations based on common techniques. + Whether you're building a chatbot or developing a RAG with a complete pipeline from data ingestion to retrieval, + LangChain4j offers a wide variety of options. 3. **Numerous Examples:** -These [examples](https://github.com/langchain4j/langchain4j-examples) showcase how to begin creating various LLM-powered applications, -providing inspiration and enabling you to start building quickly. + These [examples](https://github.com/langchain4j/langchain4j-examples) showcase how to begin creating various LLM-powered applications, + providing inspiration and enabling you to start building quickly. LangChain4j began development in early 2023 amid the ChatGPT hype. We noticed a lack of Java counterparts to the numerous Python and JavaScript LLM libraries and frameworks, @@ -36,33 +37,29 @@ LlamaIndex, and the broader community, spiced up with a touch of our own innovat We actively monitor community developments, aiming to quickly incorporate new techniques and integrations, ensuring you stay up-to-date. -The library is under active development. While some features from the Python version of LangChain -are still being worked on, the core functionality is in place, allowing you to start building LLM-powered apps now! +The library is under active development. While some features are still being worked on, +the core functionality is in place, allowing you to start building LLM-powered apps now! For easier integration, LangChain4j also includes integration with -Quarkus ([extension](https://quarkus.io/extensions/io.quarkiverse.langchain4j/quarkus-langchain4j-core)) -and Spring Boot ([starters](https://github.com/langchain4j/langchain4j-spring)). +[Quarkus](/tutorials/quarkus-integration) and [Spring Boot](/tutorials/spring-boot-integration). -### Features -- Integration with more than 10 managed and self-hosted language models (LLMs) for chat and completion -- Prompt templates -- Support for texts and images as inputs (multimodality) -- Streaming of responses from language models -- Tools for tokenization and estimation of token counts -- Output parsers for common Java types (e.g., `List`, `LocalDate`, etc.) 
and custom POJOs -- Integration with over three managed and self-hosted image generation models -- Integration with more than 10 managed and self-hosted embedding models -- Integration with more than 15 managed and self-hosted embedding stores + +## LangChain4j Features +- Integration with [15+ LLM providers](/integrations/language-models) +- Integration with [15+ embedding (vector) stores](/integrations/embedding-stores) +- Integration with [10+ embedding models](/category/embedding-models) +- Integration with [4 cloud and local image generation models](/category/image-models) +- Integration with [2 scoring (re-ranking) models](/category/scoring-reranking-models) - Integration with one moderation model: OpenAI -- Integration with one scoring (re-ranking) model: Cohere (with more expected to come) -- Tools (function calling) +- Support for texts and images as inputs (multimodality) +- [AI Services](/tutorials/ai-services) (high-level LLM API) +- Prompt templates +- Implementation of persistent and in-memory [chat memory](/tutorials/chat-memory) algorithms: message window and token window +- [Streaming of responses from LLMs](/tutorials/response-streaming) +- Output parsers for common Java types and custom POJOs +- [Tools (function calling)](/tutorials/tools) - Dynamic Tools (execution of dynamically generated LLM code) -- "Lite" agents (OpenAI functions) -- AI Services -- Chains (legacy) -- Implementation of persistent and in-memory chat memory algorithms: message window and token window -- Text classification -- RAG (Retrieval-Augmented-Generation): +- [RAG (Retrieval-Augmented-Generation)](/tutorials/rag): - Ingestion: - Importing various types of documents (TXT, PDFs, DOC, PPT, XLS etc.) from multiple sources (file system, URL, GitHub, Azure Blob Storage, Amazon S3, etc.) - Splitting documents into smaller segments using multiple splitting algorithms @@ -76,74 +73,62 @@ and Spring Boot ([starters](https://github.com/langchain4j/langchain4j-spring)). - Re-ranking - Reciprocal Rank Fusion - Customization of each step in the RAG flow +- Text classification +- Tools for tokenization and estimation of token counts -### 2 levels of abstraction +## 2 levels of abstraction LangChain4j operates on two levels of abstraction: -- Low level. At this level, you have the most freedom and access to all the low-level components such as +- [Low level](/tutorials/chat-and-language-models). At this level, you have the most freedom and access to all the low-level components such as `ChatLanguageModel`, `UserMessage`, `AiMessage`, `EmbeddingStore`, `Embedding`, etc. These are the "primitives" of your LLM-powered application. You have complete control over how to combine them, but you will need to write more glue code. -- High level. At this level, you interact with LLMs using high-level APIs like `AiServices` and `Chain`s, +- [High level](/tutorials/ai-services). At this level, you interact with LLMs using high-level APIs like `AiServices`, which hides all the complexity and boilerplate from you. You still have the flexibility to adjust and fine-tune the behavior, but it is done in a declarative manner. [![](/img/langchain4j-components.png)](/intro) -### Library Structure + +## LangChain4j Library Structure LangChain4j features a modular design, comprising: - The `langchain4j-core` module, which defines core abstractions (such as `ChatLanguageModel` and `EmbeddingStore`) and their APIs. 
- The main `langchain4j` module, containing useful tools like `ChatMemory`, `OutputParser` as well as a high-level features like `AiServices`. - A wide array of `langchain4j-{integration}` modules, each providing integration with various LLM providers and embedding stores into LangChain4j. You can use the `langchain4j-{integration}` modules independently. For additional features, simply import the main `langchain4j` dependency. -### Tutorials (User Guide) -Discover inspiring [use cases](/tutorials/#or-consider-some-of-the-use-cases) or follow our step-by-step introduction to LangChain4j features under [Tutorials](/category/tutorials). - -You will get a tour of all LangChain4j functionality in steps of increasing complexity. All steps are demonstrated with complete code examples and code explanation. - -### Integrations and Models -LangChain4j offers ready-to-use integrations with models of OpenAI, HuggingFace, Google, Azure, and many more. -It has document loaders for all common document types, and integrations with plenty of embedding models and embedding stores, to facilitate retrieval-augmented generation and AI-powered classification. -All integrations are listed [here](/category/integrations). - -### Code Examples - -You can browse through code examples in the `langchain4j-examples` repo: - -- [Examples in plain Java](https://github.com/langchain4j/langchain4j-examples/tree/main/other-examples/src/main/java) -- [Example with Spring Boot](https://github.com/langchain4j/langchain4j-examples/blob/main/spring-boot-example/src/test/java/dev/example/CustomerSupportApplicationTest.java) - -Quarkus specific examples (leveraging the [quarkus-langchain4j](https://github.com/quarkiverse/quarkus-langchain4j) -dependency which builds on this project) can be -found [here](https://github.com/quarkiverse/quarkus-langchain4j/tree/main/samples) - -### Useful Materials -[Useful Materials](https://docs.langchain4j.dev/useful-materials) - -### Disclaimer - -Please note that the library is in active development and: - -- Some features are still missing. We are working hard on implementing them ASAP. -- API might change at any moment. At this point, we prioritize good design in the future over backward compatibility - now. We hope for your understanding. -- We need your input! Please [let us know](https://github.com/langchain4j/langchain4j/issues/new/choose) what features - you need and your concerns about the current implementation. - -### Coming soon - -- Extending "AI Service" features -- Integration with more LLM providers (commercial and free) -- Integrations with more embedding stores (commercial and free) -- Support for more document types -- Long-term memory for chatbots and agents -- Chain-of-Thought and Tree-of-Thought - -### Request features - -Please [let us know](https://github.com/langchain4j/langchain4j/issues/new/choose) what features you need! - -### Contribute - -Please help us make this open-source library better by contributing to our [github repo](https://github.com/langchain4j/langchain4j). 
+## LangChain4j Repositories +- [Main repository](https://github.com/langchain4j/langchain4j) +- [Quarkus extension](https://github.com/quarkiverse/quarkus-langchain4j) +- [Spring Boot integration](https://github.com/langchain4j/langchain4j-spring) +- [Examples](https://github.com/langchain4j/langchain4j-examples) +- [Community resources](https://github.com/langchain4j/langchain4j-community-resources) +- [In-process embeddings](https://github.com/langchain4j/langchain4j-embeddings) + + +## Use Cases +You might ask why would I need all of this? +Here are some examples: + +- You want to implement a custom AI-powered chatbot that has access to your data and behaves the way you want it: + - Customer support chatbot that can: + - politely answer customer questions + - take /change/cancel orders + - Educational assistant that can: + - Teach various subjects + - Explain unclear parts + - Assess user's understanding/knowledge +- You want to process a lot of unstructured data (files, web pages, etc) and extract structured information from them. + For example: + - extract insights from customer reviews and support chat history + - extract interesting information from the websites of your competitors + - extract insights from CVs of job applicants +- You want to generate information, for example: + - Emails tailored for each of your customers + - Content for your app/website: + - Blog posts + - Stories +- You want to transform information, for example: + - Summarize + - Proofread and rewrite + - Translate diff --git a/docs/docs/tutorials/1-chat-and-language-models.md b/docs/docs/tutorials/1-chat-and-language-models.md index a0a4a72fc0..bdacf470c3 100644 --- a/docs/docs/tutorials/1-chat-and-language-models.md +++ b/docs/docs/tutorials/1-chat-and-language-models.md @@ -4,6 +4,15 @@ sidebar_position: 2 # Chat and Language Models +:::note +This page describes a low-level LLM API. +See [AI Services](/tutorials/ai-services) for a high-level LLM API. +::: + +:::note +All supported LLMs can be found [here](/integrations/language-models). +::: + LLMs are currently available in two API types: - `LanguageModel`s. Their API is very simple - they accept a `String` as input and return a `String` as output. This API is now becoming obsolete in favor of chat API (second API type). diff --git a/docs/docs/tutorials/2-chat-memory.md b/docs/docs/tutorials/2-chat-memory.md index bed5f905ea..e44c199073 100644 --- a/docs/docs/tutorials/2-chat-memory.md +++ b/docs/docs/tutorials/2-chat-memory.md @@ -1,4 +1,5 @@ --- +id: chat-memory sidebar_position: 3 --- @@ -16,6 +17,17 @@ or as a part of a high-level component like [AI Services](/tutorials/ai-services - Special treatment of `SystemMessage` - Special treatment of [tool](/tutorials/tools) messages +## Memory vs History + +Please note that "memory" and "history" are similar, yet distinct concepts. +- History keeps **all** messages between the user and AI **intact**. History is what the user sees in the UI. It represents what was actually said. +- Memory keeps **some information**, which is presented to the LLM to make it behave as if it "remembers" the conversation. +Memory is quite different from history. Depending on the memory algorithm used, it can modify history in various ways: +evict some messages, summarize multiple messages, summarize separate messages, remove unimportant details from messages, +inject extra information (e.g., for RAG) or instructions (e.g., for structured outputs) into messages, and so on. 
+ +LangChain4j currently offers only "memory", not "history". If you need to keep an entire history, please do so manually. + ## Eviction policy An eviction policy is necessary for several reasons: @@ -82,6 +94,12 @@ The `updateMessages()` method is expected to update all messages associated with `ChatMessage`s can be stored either separately (e.g., one record/row/object per message) or together (e.g., one record/row/object for the entire `ChatMemory`). +:::note +Please note that messages evicted from `ChatMemory` will also be evicted from `ChatMemoryStore`. +When a message is evicted, the `updateMessages()` method is called +with a list of messages that does not include the evicted message. +::: + The `getMessages()` method is called whenever the user of the `ChatMemory` requests all messages. This typically happens once during each interaction with the LLM. The value of the `Object memoryId` argument corresponds to the `id` specified diff --git a/docs/docs/tutorials/4-response-streaming.md b/docs/docs/tutorials/4-response-streaming.md index d2202f7854..9d2bf43f23 100644 --- a/docs/docs/tutorials/4-response-streaming.md +++ b/docs/docs/tutorials/4-response-streaming.md @@ -1,9 +1,15 @@ --- +id: response-streaming sidebar_position: 5 --- # Response Streaming +:::note +This page describes response streaming with a low-level LLM API. +See [AI Services](/tutorials/ai-services#streaming) for a high-level LLM API. +::: + LLMs generate text one token at a time, so many LLM providers offer a way to stream the response token-by-token instead of waiting for the entire text to be generated. This significantly improves the user experience, as the user does not need to wait an unknown diff --git a/docs/docs/tutorials/5-ai-services.md b/docs/docs/tutorials/5-ai-services.md index d43621e197..31088e1d99 100644 --- a/docs/docs/tutorials/5-ai-services.md +++ b/docs/docs/tutorials/5-ai-services.md @@ -1,4 +1,5 @@ --- +id: ai-services sidebar_position: 6 --- @@ -163,13 +164,13 @@ you can change the return type of your AI Service method from `String` to someth Currently, AI Services support the following return types: - `String` - `AiMessage` -- `Response` (if you need to access `TokenUsage` or `FinishReason`) -- `boolean`/`Boolean` (if you need to get "yes" or "no" answer) +- `boolean`/`Boolean`, if you need to get "yes" or "no" answer - `byte`/`Byte`/`short`/`Short`/`int`/`Integer`/`BigInteger`/`long`/`Long`/`float`/`Float`/`double`/`Double`/`BigDecimal` - `Date`/`LocalDate`/`LocalTime`/`LocalDateTime` -- `List`/`Set` (if you want to get the answer in the form of a list of bullet points) -- Any `Enum` (if you want to classify text, e.g. sentiment, user intent, etc) +- `List`/`Set`, if you want to get the answer in the form of a list of bullet points +- Any `Enum`, if you want to classify text, e.g. sentiment, user intent, etc. - Any custom POJO +- `Result`, if you need to access `TokenUsage` or sources (`Content`s retrieved during RAG), aside from `T`, which can be of any type listed above. For example: `Result`, `Result` Unless the return type is `String`, `AiMessage`, or `Response`, the AI Service will automatically append instructions to the end of `UserMessage` indicating the format @@ -187,7 +188,7 @@ ChatLanguageModel model = OpenAiChatModel.builder() Now let's take a look at some examples. 
-`Enum` and `boolean` as return types: +### `Enum` and `boolean` as return types ```java enum Sentiment { POSITIVE, NEUTRAL, NEGATIVE @@ -211,7 +212,7 @@ boolean positive = sentimentAnalyzer.isPositive("It's awful!"); // false ``` -Custom POJO as a return type: +### Custom POJO as a return type ```java class Person { String firstName; @@ -297,7 +298,7 @@ AzureOpenAiChatModel.builder() ```java MistralAiChatModel.builder() ... - .responseFormat(JSON_OBJECT) + .responseFormat(MistralAiResponseFormatType.JSON_OBJECT) .build(); ``` @@ -314,25 +315,147 @@ prompt engineering is your best bet. Also, try lowering the `temperature` for mo [More examples](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/OtherServiceExamples.java) + ## Streaming -[Example](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithStreamingExample.java) + +The AI Service can [stream response](/tutorials/response-streaming) token-by-token +when using the `TokenStream` return type: +```java + +interface Assistant { + + TokenStream chat(String message); +} + +StreamingChatLanguageModel model = OpenAiStreamingChatModel.withApiKey(System.getenv("OPENAI_API_KEY")); + +Assistant assistant = AiServices.create(Assistant.class, model); + +TokenStream tokenStream = assistant.chat("Tell me a joke"); + +tokenStream.onNext(System.out::println) + .onComplete(System.out::println) + .onError(Throwable::printStackTrace) + .start(); +``` + +[Streaming example](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithStreamingExample.java) + ## Chat Memory -[Example with a single ChatMemory](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithMemoryExample.java) -[Example with ChatMemory for each user](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithMemoryForEachUserExample.java) -[Example with a single persistent ChatMemory](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithPersistentMemoryExample.java) -[Example with persistent ChatMemory for each user](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithPersistentMemoryForEachUserExample.java) + +The AI Service can use [chat memory](/tutorials/chat-memory) in order to "remember" previous interactions: +```java +Assistant assistant = AiServices.builder(Assistant.class) + .chatLanguageModel(model) + .chatMemory(MessageWindowChatMemory.withMaxMessages(10)) + .build(); +``` +In this scenario, the same `ChatMemory` instance will be used for all invocations of the AI Service. +However, this approach will not work if you have multiple users, +as each user would require their own instance of `ChatMemory` to maintain their individual conversation. + +The solution to this issue is to use `ChatMemoryProvider`: +```java + +interface Assistant { + String chat(@MemoryId int memoryId, @UserMessage String message); +} + +Assistant assistant = AiServices.builder(Assistant.class) + .chatLanguageModel(model) + .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(10)) + .build(); + +String answerToKlaus = assistant.chat(1, "Hello, my name is Klaus"); +String answerToFrancine = assistant.chat(2, "Hello, my name is Francine"); +``` +In this scenario, two distinct instances of `ChatMemory` will be provided by `ChatMemoryProvider`, one for each memory ID. 
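When per-user conversations also need to survive application restarts, the same `ChatMemoryProvider` can be combined with a `ChatMemoryStore`. Below is a minimal sketch; `PersistentChatMemoryStore` is a placeholder for your own `ChatMemoryStore` implementation (for example, one backed by a database), not a class provided by the library:
```java
ChatMemoryStore store = new PersistentChatMemoryStore(); // your own ChatMemoryStore implementation

Assistant assistant = AiServices.builder(Assistant.class)
        .chatLanguageModel(model)
        .chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
                .id(memoryId)
                .maxMessages(10)
                .chatMemoryStore(store)
                .build())
        .build();
```
With this setup, each memory ID still gets its own sliding-window memory, while the store decides where the messages are actually persisted.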
+ +:::note +Please note that if an AI Service method does not have a parameter annotated with `@MemoryId`, +the value of `memoryId` in `ChatMemoryProvider` will default to a string `"default"`. +::: + +- [Example with a single ChatMemory](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithMemoryExample.java) +- [Example with ChatMemory for each user](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithMemoryForEachUserExample.java) +- [Example with a single persistent ChatMemory](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithPersistentMemoryExample.java) +- [Example with persistent ChatMemory for each user](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithPersistentMemoryForEachUserExample.java) + ## Tools (Function Calling) -[Tools](/tutorials/tools) + +AI Service can be configured with tools that LLM can use: + +```java + +class Tools { + + @Tool + int add(int a, int b) { + return a + b; + } + + @Tool + int multiply(int a, int b) { + return a * b; + } +} + +Assistant assistant = AiServices.builder(Assistant.class) + .chatLanguageModel(model) + .tools(new Tools()) + .build(); + +String answer = assistant.chat("What is 1+2 and 3*4?"); +``` +In this scenario, LLM will execute `add(1, 2)` and `multiply(3, 4)` methods before providing an answer. + +More details about tools can be found [here](/tutorials/tools). + ## RAG -[Example](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithRetrieverExample.java) + +AI Service can be configured with a `ContentRetriever` in order to enable RAG: +```java + +EmbeddingStore embeddingStore = ... +EmbeddingModel embeddingModel = ... + +ContentRetriever contentRetriever = new EmbeddingStoreContentRetriever(embeddingStore, embeddingModel); + +Assistant assistant = AiServices.builder(Assistant.class) + .chatLanguageModel(model) + .contentRetriever(contentRetriever) + .build(); +``` + +Configuring a `RetrievalAugmentor` provides even more flexibility, +enabling advanced RAG capabilities such as query transformation, re-ranking, etc: +```java +RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder() + .queryTransformer(...) + .queryRouter(...) + .contentAggregator(...) + .contentInjector(...) + .build(); + +Assistant assistant = AiServices.builder(Assistant.class) + .chatLanguageModel(model) + .retrievalAugmentor(retrievalAugmentor) + .build(); +``` + +More details about RAG can be found [here](/tutorials/rag). + +More RAG examples can be found [here](https://github.com/langchain4j/langchain4j-examples/tree/main/rag-examples/src/main/java). + ## Auto-Moderation [Example](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/ServiceWithAutoModerationExample.java) -## Chaining + +## Chaining multiple AI Services The more complex the logic of your LLM-powered application becomes, the more crucial it is to break it down into smaller parts, as is common practice in software development. @@ -437,7 +560,6 @@ Also, I can integration test `GreetingExpert` and `ChatBot` separately. I can evaluate both of them separately and find the most optimal parameters for each subtask, or, in the long run, even fine-tune a small specialized model for each specific subtask. 
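To make the decomposition described above concrete, here is a minimal sketch of chaining two AI Services. The interfaces, prompts, and model variables (`llama2`, `gpt4`) are illustrative assumptions, not the complete example from the examples repository:
```java
interface GreetingExpert {

    @UserMessage("Is the following text a greeting? Text: {{it}}")
    boolean isGreeting(String text);
}

interface ChatBot {

    @SystemMessage("You are a polite customer support agent of a car rental company")
    String reply(String userMessage);
}

class MilesOfSmiles {

    private final GreetingExpert greetingExpert;
    private final ChatBot chatBot;

    MilesOfSmiles(GreetingExpert greetingExpert, ChatBot chatBot) {
        this.greetingExpert = greetingExpert;
        this.chatBot = chatBot;
    }

    public String handle(String userMessage) {
        // the simple classification subtask can go to a smaller/cheaper model,
        // while the open-ended conversation goes to a more capable one
        if (greetingExpert.isGreeting(userMessage)) {
            return "Greetings! How can I help you today?";
        }
        return chatBot.reply(userMessage);
    }
}

GreetingExpert greetingExpert = AiServices.create(GreetingExpert.class, llama2);
ChatBot chatBot = AiServices.create(ChatBot.class, gpt4);

MilesOfSmiles milesOfSmiles = new MilesOfSmiles(greetingExpert, chatBot);

String response = milesOfSmiles.handle("Hello!");
```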
-TODO ## Related Tutorials - [LangChain4j AiServices Tutorial](https://www.sivalabs.in/langchain4j-ai-services-tutorial/) by [Siva](https://www.sivalabs.in/) diff --git a/docs/docs/tutorials/6-tools.md b/docs/docs/tutorials/6-tools.md index 15011dbb7b..c11b0cb5a4 100644 --- a/docs/docs/tutorials/6-tools.md +++ b/docs/docs/tutorials/6-tools.md @@ -6,6 +6,10 @@ sidebar_position: 7 Some LLMs, in addition to generating text, can also trigger actions. +:::note +All LLMs supporting tools can be found [here](/integrations/language-models) (see the "Tools" column). +::: + There is a concept known as "tools," or "function calling". It allows the LLM to call, when necessary, one or more available tools, usually defined by the developer. A tool can be anything: a web search, a call to an external API, or the execution of a specific piece of code, etc. @@ -111,6 +115,7 @@ Please note that not all models support tools. Currently, the following models have tool support: - `OpenAiChatModel` - `AzureOpenAiChatModel` +- `MistralAiChatModel` - `LocalAiChatModel` - `QianfanChatModel` ::: diff --git a/docs/docs/tutorials/7-rag.md b/docs/docs/tutorials/7-rag.md index 9e5b5e88d9..41d3d32ba0 100644 --- a/docs/docs/tutorials/7-rag.md +++ b/docs/docs/tutorials/7-rag.md @@ -1,24 +1,79 @@ --- +id: rag sidebar_position: 8 --- # RAG (Retrieval-Augmented Generation) -[Great tutorial on RAG](https://www.sivalabs.in/langchain4j-retrieval-augmented-generation-tutorial/) -by [Siva](https://www.sivalabs.in/). - LLM's knowledge is limited to the data it has been trained on. If you want to make an LLM aware of domain-specific knowledge or proprietary data, you can: - Use RAG, which we will cover in this section - Fine-tune the LLM with your data - [Combine both RAG and fine-tuning](https://gorilla.cs.berkeley.edu/blogs/9_raft.html) + ## What is RAG? Simply put, RAG is the way to find and inject relevant pieces of information from your data into the prompt before sending it to the LLM. This way LLM will get (hopefully) relevant information and will be able to reply using this information, which should reduce the probability of hallucinations. +Relevant pieces of information can be found using various +[information retrieval](https://en.wikipedia.org/wiki/Information_retrieval) methods. +The most popular are: +- Full-text (keyword) search. This method uses techniques like TF-IDF and BM25 +to search documents by matching the keywords in a query (e.g., what the user is asking) +against a database of documents. +It ranks results based on the frequency and relevance of these keywords in each document. +- Vector search, also known as "semantic search". +Text documents are converted into vectors of numbers using embedding models. +It then finds and ranks documents based on the cosine similarity +or other similarity/distance measures between the query vector and document vectors, +thus capturing deeper semantic meanings. +- Hybrid. Combining multiple search methods (e.g., full-text + vector) usually improves the effectiveness of the search. + +Currently, this page focuses mostly on vector search. +Full-text and hybrid search are currently supported only by Azure AI Search integration, +see `AzureAiSearchContentRetriever` for more details. +We plan to expand the RAG toolbox to include full-text and hybrid search in the near future. + + +## RAG Stages +THe RAG process is divided into 2 distinct stages: indexing and retrieval. +LangChain4j provides tools for both stages. 
+
+### Indexing
+
+During the indexing stage, documents are pre-processed in a way that enables efficient search during the retrieval stage.
+
+This process can vary depending on the information retrieval method used.
+For vector search, this typically involves cleaning the documents, enriching them with additional data and metadata,
+splitting them into smaller segments (aka chunking), embedding these segments, and finally storing them in an embedding store (aka vector database).
+
+The indexing stage usually occurs offline, meaning it does not require end users to wait for its completion.
+This can be achieved through, for example, a cron job that re-indexes internal company documentation once a week during the weekend.
+The code responsible for indexing can also be a separate application that only handles indexing tasks.
+
+However, in some scenarios, end users may want to upload their custom documents to make them accessible to the LLM.
+In this case, indexing should be performed online and be a part of the main application.
+
+Here is a simplified diagram of the indexing stage:
+[![](/img/rag-ingestion.png)](/tutorials/rag)
+
+
+### Retrieval
+
+The retrieval stage usually occurs online, when a user submits a question that should be answered using the indexed documents.
+
+This process can vary depending on the information retrieval method used.
+For vector search, this typically involves embedding the user's query (question)
+and performing a similarity search in the embedding store.
+Relevant segments (pieces of the original documents) are then injected into the prompt and sent to the LLM.
+
+Here is a simplified diagram of the retrieval stage:
+[![](/img/rag-retrieval.png)](/tutorials/rag)
+
+
## Easy RAG
LangChain4j has an "Easy RAG" feature that makes it as easy as possible to get started with RAG.
You don't have to learn about embeddings, choose a vector store, find the right embedding model,
@@ -42,7 +97,7 @@ adjusting and customizing more and more aspects.
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-easy-rag</artifactId>
-    <version>0.30.0</version>
+    <version>0.31.0</version>
</dependency>
```
@@ -82,8 +137,8 @@ in glob: `glob:**.pdf`.
3. Now, we need to preprocess and store documents in a specialized embedding store, also known as vector database.
-This is necessary to quickly find relevant pieces of information on the fly when a user asks a question.
-We can use any of our 15+ [supported embedding stores](/category/embedding-stores),
+This is necessary to quickly find relevant pieces of information when a user asks a question.
+We can use any of our 15+ [supported embedding stores](/integrations/embedding-stores),
but for simplicity, we will use an in-memory one:
```java
InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
@@ -137,15 +192,74 @@ and retrieve relevant content from an `EmbeddingStore` that contains our documen
String answer = assistant.chat("How to do Easy RAG with LangChain4j?");
```
+## Accessing Sources
+If you wish to access the sources (retrieved `Content`s used to augment the message),
+you can easily do so by wrapping the return type in the `Result` class:
+```java
+interface Assistant {
+
+    Result<String> chat(String userMessage);
+}
+
+Result<String> result = assistant.chat("How to do Easy RAG with LangChain4j?");
+
+String answer = result.content();
+List<Content> sources = result.sources();
+```
+
## RAG APIs
LangChain4j offers a rich set of APIs to make it easy for you to build custom RAG pipelines,
-ranging from very simple ones to very advanced ones. In this section, we will cover the main domain classes and APIs.
+ranging from simple ones to advanced ones. +In this section, we will cover the main domain classes and APIs. + ### Document A `Document` class represents an entire document, such as a single PDF file or a web page. At the moment, the `Document` can only represent textual information, but future updates will enable it to support images and tables as well. +
+Useful methods + +- `Document.text()` returns the text of the `Document` +- `Document.metadata()` returns the `Metadata` of the `Document` (see below) +- `Document.toTextSegment()` converts the `Document` into a `TextSegment` (see below) +- `Document.from(String, Metadata)` creates a `Document` from text and `Metadata` +- `Document.from(String)` creates a `Document` from text with empty `Metadata` +
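+
+Below is a minimal sketch that uses only the methods listed above; the text and the metadata keys/values are invented for illustration:
+```java
+Metadata metadata = new Metadata();
+metadata.put("source", "user-guide.pdf"); // illustrative source name
+metadata.put("owner", "docs-team");       // illustrative owner
+
+Document document = Document.from("LangChain4j helps integrate LLMs into Java applications.", metadata);
+
+String text = document.text();                         // the document's text
+Metadata sameMetadata = document.metadata();           // the attached Metadata
+TextSegment singleSegment = document.toTextSegment();  // the whole Document as one TextSegment
+```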
+ +### Metadata +Each `Document` contains `Metadata`. +It stores meta information about the `Document`, such as its name, source, last update date, owner, +or any other relevant details. + +The `Metadata` is stored as a key-value map, where the key is of the `String` type, +and the value can be one of the following types: `String`, `Integer`, `Long`, `Float`, `Double`. + +`Metadata` is useful for several reasons: +- When including the content of the `Document` in a prompt to the LLM, +metadata entries can also be included, providing the LLM with additional information to consider. +For example, providing the `Document` name and source can help improve the LLM's understanding of the content. +- When searching for relevant content to include in the prompt, +one can filter by `Metadata` entries. +For example, you can narrow down a semantic search to only `Document`s +belonging to a specific owner. +- When the source of the `Document` is updated (for example, a specific page of documentation), +one can easily locate the corresponding `Document` by its metadata entry (for example, "id", "source", etc.) +and update it in the `EmbeddingStore` as well to keep it in sync. + +
+Useful methods + +- `Metadata.from(Map)` creates `Metadata` from a `Map` +- `Metadata.put(String key, String value)` / `put(String, int)` / etc., adds an entry to the `Metadata` +- `Metadata.getString(String key)` / `getInteger(String key)` / etc., returns a value of the `Metadata` entry, casting it to the required type +- `Metadata.containsKey(String key)` checks whether `Metadata` contains an entry with the specified key +- `Metadata.remove(String key)` removes an entry from the `Metadata` by key +- `Metadata.copy()` returns a copy of the `Metadata` +- `Metadata.toMap()` converts `Metadata` into a `Map` +
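+
+For example, a minimal sketch using only the methods listed above (the keys and values are invented):
+```java
+Metadata metadata = new Metadata();
+metadata.put("source", "https://example.com/pricing.html"); // illustrative String entry
+metadata.put("page", 3);                                    // illustrative int entry
+
+String source = metadata.getString("source");     // typed read of a String entry
+Integer page = metadata.getInteger("page");       // typed read of an Integer entry
+boolean hasOwner = metadata.containsKey("owner"); // false in this example
+
+Metadata copy = metadata.copy();                  // independent copy
+metadata.remove("page");                          // the copy still contains "page"
+```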
+ ### Document Loader You can create a `Document` from a `String`, but a simpler method is to use one of our document loaders included in the library: - `FileSystemDocumentLoader` from the `langchain4j` module @@ -155,6 +269,7 @@ You can create a `Document` from a `String`, but a simpler method is to use one - `GitHubDocumentLoader` from the `langchain4j-document-loader-github` module - `TencentCosDocumentLoader` from the `langchain4j-document-loader-tencent-cos` module + ### Document Parser `Document`s can represent files in various formats, such as PDF, DOC, TXT, etc. To parse each of these formats, there's a `DocumentParser` interface with several implementations included in the library: @@ -188,13 +303,22 @@ If no `DocumentParser`s are found through SPI, a `TextDocumentParser` is used as ### Document Transformer -`DocumentTransformer` implementations can perform a variety of tasks such as transforming documents, -cleaning them, filtering, enriching, etc. +`DocumentTransformer` implementations can perform a variety of document transformations such as: +- Cleaning: This involves removing unnecessary noise from the `Document`'s text, which can save tokens and reduce distractions. +- Filtering: to completely exclude particular `Document`s from the search. +- Enriching: Additional information can be added to `Document`s to potentially enhance search results. +- Summarizing: The `Document` can be summarized, and its short summary can be stored in the `Metadata` +to be later included in each `TextSegment` (which we will cover below) to potentially improve the search. +- Etc. + +`Metadata` entries can also be added, modified, or removed at this stage. Currently, the only implementation provided out-of-the-box is `HtmlTextExtractor` in the `langchain4j` module, -which can extract desired text content and metadata from an HTML document. +which can extract desired text content and metadata entries from the raw HTML. + +Since there is no one-size-fits-all solution, we recommend implementing your own `DocumentTransformer`, +tailored to your unique data. -You can implement your own `DocumentTransformer` and plug it into the LangChain4j RAG pipeline. ### Text Segment Once your `Document`s are loaded, it is time to split (chunk) them into smaller segments (pieces). @@ -243,6 +367,15 @@ providing the LLM with additional information before and after the retrieved seg +
+Useful methods + +- `TextSegment.text()` returns the text of the `TextSegment` +- `TextSegment.metadata()` returns the `Metadata` of the `TextSegment` +- `TextSegment.from(String, Metadata)` creates a `TextSegment` from text and `Metadata` +- `TextSegment.from(String)` creates a `TextSegment` from text with empty `Metadata` +
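+
+To make the earlier Document Transformer section more concrete, here is a rough sketch of a custom implementation;
+it assumes that `transform(Document)` is the only method that must be overridden, and the cleaning logic and metadata entry are invented:
+```java
+class CleaningDocumentTransformer implements DocumentTransformer {
+
+    @Override
+    public Document transform(Document document) {
+        // collapse runs of whitespace to save tokens and reduce noise
+        String cleanedText = document.text().replaceAll("\\s+", " ").trim();
+
+        // keep the original metadata and mark the document as cleaned (illustrative entry)
+        Metadata metadata = document.metadata().copy();
+        metadata.put("cleaned", "true");
+
+        return Document.from(cleanedText, metadata);
+    }
+}
+```
+Such a transformer can then be plugged into an `EmbeddingStoreIngestor` (shown later) via its `documentTransformer(...)` builder method.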
+ ### Document Splitter LangChain4j has a `DocumentSplitter` interface with several out-of-the-box implementations: - `DocumentByParagraphSplitter` @@ -266,35 +399,152 @@ a document into sentences, and so on. attempting to include as many units as possible in a single `TextSegment` without exceeding the limit set in step 1. If some of the units are still too large to fit into a `TextSegment`, it calls a sub-splitter. This is another `DocumentSplitter` capable of splitting units that do not fit into more granular units. +All `Metadata` entries are copied from the `Document` to each `TextSegment`. +A unique metadata entry "index" is added to each text segment. +The first `TextSegment` will contain `index=0`, the second `index=1`, and so on. + ### Text Segment Transformer -More details are coming soon. +`TextSegmentTransformer` is similar to `DocumentTransformer` (described above), but it transforms `TextSegment`s. + +As with the `DocumentTransformer`, there is no one-size-fits-all solution, +so we recommend implementing your own `TextSegmentTransformer`, tailored to your unique data. + +One technique that works quite well for improving retrieval is to include the `Document` title or a short summary +in each `TextSegment`. + ### Embedding -More details are coming soon. +The `Embedding` class encapsulates a numerical vector that represents the "semantic meaning" +of the content that has been embedded (usually text, such as a `TextSegment`). + +Read more about vector embeddings here: +- https://www.elastic.co/what-is/vector-embedding +- https://www.pinecone.io/learn/vector-embeddings/ +- https://cloud.google.com/blog/topics/developers-practitioners/meet-ais-multitool-vector-embeddings + +
+Useful methods + +- `Embedding.dimension()` returns the dimension of the embedding vector (its length) +- `CosineSimilarity.between(Embedding, Embedding)` calculates the cosine similarity between 2 `Embedding`s +- `Embedding.normalize()` normalizes the embedding vector (in place) +
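+
+A tiny sketch of working with embeddings directly; the vectors are toy values (real ones come from an `EmbeddingModel`, covered next),
+and `Embedding.from(float[])` is assumed to be available as a factory method:
+```java
+// toy 3-dimensional vectors; real embeddings usually have hundreds of dimensions
+Embedding first = Embedding.from(new float[]{0.9f, 0.1f, 0.0f});
+Embedding second = Embedding.from(new float[]{0.8f, 0.2f, 0.1f});
+
+double similarity = CosineSimilarity.between(first, second); // close to 1 for similar vectors
+int dimension = first.dimension();                           // 3 in this toy example
+```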
+ ### Embedding Model -More details are coming soon. +The `EmbeddingModel` interface represents a special type of model that converts text into an `Embedding`. Currently supported embedding models can be found [here](/category/embedding-models). +
+Useful methods + +- `EmbeddingModel.embed(String)` embeds the given text +- `EmbeddingModel.embed(TextSegment)` embeds the given `TextSegment` +- `EmbeddingModel.embedAll(List)` embeds all the given `TextSegment` +
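+
+For example, assuming the in-process `AllMiniLmL6V2EmbeddingModel` (from the separate `langchain4j-embeddings-all-minilm-l6-v2` module) is on the classpath,
+a sketch of embedding text could look like this; any other `EmbeddingModel` implementation can be used the same way:
+```java
+EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
+
+// embed a plain String
+Embedding queryEmbedding = embeddingModel.embed("What is LangChain4j?").content();
+
+// embed multiple TextSegments in one call
+List<Embedding> embeddings = embeddingModel.embedAll(List.of(
+        TextSegment.from("LangChain4j is a Java library."),
+        TextSegment.from("It simplifies building LLM-powered applications.")
+)).content();
+```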
+ + ### Embedding Store -More details are coming soon. +The `EmbeddingStore` interface represents a store for `Embedding`s, also known as vector database. +It allows for the storage and efficient search of similar (close in the embedding space) `Embedding`s. Currently supported embedding stores can be found [here](/category/embedding-stores). +`EmbeddingStore` can store `Embedding`s alone or together with the corresponding `TextSegment`: +- It can store only `Embedding`, by ID. Original embedded data can be stored elsewhere and correlated using the ID. +- It can store both `Embedding` and the original data that has been embedded (usually `TextSegment`). + +
+Useful methods + +- `EmbeddingStore.add(Embedding)` adds a given `Embedding` to the store and returns a random ID +- `EmbeddingStore.add(String id, Embedding)` adds a given `Embedding` with a specified ID to the store +- `EmbeddingStore.add(Embedding, TextSegment)` adds a given `Embedding` with an associated `TextSegment` to the store and returns a random ID +- `EmbeddingStore.addAll(List)` adds a list of given `Embedding`s to the store and returns a list of random IDs +- `EmbeddingStore.addAll(List, List)` adds a list of given `Embedding`s with associated `TextSegment`s to the store and returns a list of random IDs +- `EmbeddingStore.search(EmbeddingSearchRequest)` searches for the most similar `Embedding`s +- `EmbeddingStore.remove(String id)` removes a single `Embedding` from the store by ID +- `EmbeddingStore.removeAll(Collection ids)` removes multiple `Embedding`s from the store by ID +- `EmbeddingStore.removeAll(Filter)` removes all `Embedding`s that match the specified `Filter` from the store +- `EmbeddingStore.removeAll()` removes all `Embedding`s from the store +
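+
+Here is a rough sketch that combines the methods above with the in-memory store; the texts are invented,
+`embeddingModel` is assumed to be an already configured `EmbeddingModel`,
+and the `EmbeddingSearchRequest` (described just below) is assumed to be created via its builder:
+```java
+EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
+
+TextSegment segment = TextSegment.from("LangChain4j supports 15+ embedding stores.");
+Embedding embedding = embeddingModel.embed(segment).content();
+String id = embeddingStore.add(embedding, segment); // stores the Embedding together with its TextSegment
+
+Embedding queryEmbedding = embeddingModel.embed("Which embedding stores are supported?").content();
+EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
+        .queryEmbedding(queryEmbedding)
+        .maxResults(3)
+        .minScore(0.5)
+        .build();
+
+EmbeddingSearchResult<TextSegment> searchResult = embeddingStore.search(request);
+for (EmbeddingMatch<TextSegment> match : searchResult.matches()) {
+    System.out.println(match.score() + ": " + match.embedded().text());
+}
+```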
+ + +#### EmbeddingSearchRequest +The `EmbeddingSearchRequest` represents a request to search in an `EmbeddingStore`. +It has the following attributes: +- `Embedding queryEmbedding`: The embedding used as a reference. +- `int maxResults`: The maximum number of results to return. This is an optional parameter. Default: 3. +- `double minScore`: The minimum score, ranging from 0 to 1 (inclusive). Only embeddings with a score >= `minScore` will be returned. This is an optional parameter. Default: 0. +- `Filter filter`: The filter to be applied to the `Metadata` during search. Only `TextSegment`s whose `Metadata` matches the `Filter` will be returned. + +#### Filter +More details about `Filter` can be found [here](https://github.com/langchain4j/langchain4j/pull/610). + + +#### EmbeddingSearchResult +The EmbeddingSearchResult represents a result of a search in an `EmbeddingStore`. +It contains the list of `EmbeddingMatch`es. + + +#### Embedding Match +The `EmbeddingMatch` represents a matched `Embedding` along with its relevance score, ID, and original embedded data (usually `TextSegment`). + + ### Embedding Store Ingestor -More details are coming soon. +The `EmbeddingStoreIngestor` represents an ingestion pipeline and is responsible for +ingesting `Document`s into an `EmbeddingStore`. -## RAG Stages +In the simplest configuration, `EmbeddingStoreIngestor` embeds provided `Document`s +using a specified `EmbeddingModel` and stores them, along with their `Embedding`s in a specified `EmbeddingStore`: + +```java +EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder() + .embeddingModel(embeddingModel) + .embeddingStore(embeddingStore) + .build(); + +ingestor.ingest(document1); +ingestor.ingest(document2, document3); +ingestor.ingest(List.of(document4, document5, document6)); +``` -### Ingestion +Optionally, the `EmbeddingStoreIngestor` can transform `Document`s using a specified `DocumentTransformer`. +This can be useful if you want to clean, enrich, or format `Document`s before embedding them. -[![](/img/rag-ingestion.png)](/tutorials/rag) +Optionally, the `EmbeddingStoreIngestor` can split `Document`s into `TextSegment`s using a specified `DocumentSplitter`. +This can be useful if `Document`s are big, and you want to split them into smaller `TextSegment`s to improve the quality +of similarity searches and reduce the size and cost of a prompt sent to the LLM. -### Retrieval +Optionally, the `EmbeddingStoreIngestor` can transform `TextSegment`s using a specified `TextSegmentTransformer`. +This can be useful if you want to clean, enrich, or format `TextSegment`s before embedding them. + +An example: +```java +EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder() + + // adding userId metadata entry to each Document to be able to filter by it later + .documentTransformer(document -> { + document.metadata().put("userId", "12345"); + return document; + }) + + // splitting each Document into TextSegments of 1000 tokens each, with a 200-token overlap + .documentSplitter(DocumentSplitters.recursive(1000, 200, new OpenAiTokenizer())) + + // adding a name of the Document to each TextSegment to improve the quality of search + .textSegmentTransformer(textSegment -> TextSegment.from( + textSegment.metadata("file_name") + "\n" + textSegment.text(), + textSegment.metadata() + )) + + .embeddingModel(embeddingModel) + .embeddingStore(embeddingStore) + .build(); +``` -[![](/img/rag-retrieval.png)](/tutorials/rag) ## Advanced RAG More details are coming soon. 
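+
+As a preview, here is a rough sketch of how these components can be wired together; the two content retrievers,
+the `chatLanguageModel` variable, and the choice of `CompressingQueryTransformer` and `DefaultQueryRouter`
+are illustrative assumptions rather than required building blocks:
+```java
+// compress the chat history and the latest question into a standalone query (illustrative choice)
+QueryTransformer queryTransformer = new CompressingQueryTransformer(chatLanguageModel);
+
+// route every query to two retrievers, e.g. one per knowledge base (illustrative choice)
+QueryRouter queryRouter = new DefaultQueryRouter(productDocsRetriever, supportTicketsRetriever);
+
+RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
+        .queryTransformer(queryTransformer)
+        .queryRouter(queryRouter)
+        .build();
+
+Assistant assistant = AiServices.builder(Assistant.class)
+        .chatLanguageModel(chatLanguageModel)
+        .retrievalAugmentor(retrievalAugmentor)
+        .build();
+```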
@@ -302,6 +552,35 @@ In the meantime, please read [this](https://github.com/langchain4j/langchain4j/p [![](/img/advanced-rag.png)](/tutorials/rag) + +### Retrieval Augmentor +More details are coming soon. + + +### Default Retrieval Augmentor +More details are coming soon. + + +### Query Transformer +More details are coming soon. + + +### Query Router +More details are coming soon. + + +### Content Retriever +More details are coming soon. + + +### Content Aggregator +More details are coming soon. + + +### Content Injector +More details are coming soon. + + ## Examples - [Easy RAG](https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_1_easy/Easy_RAG_Example.java) @@ -310,6 +589,7 @@ In the meantime, please read [this](https://github.com/langchain4j/langchain4j/p - [Advanced RAG with Query Routing](https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_3_advanced/_02_Advanced_RAG_with_Query_Routing_Example.java) - [Advanced RAG with Re-Ranking](https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_3_advanced/_03_Advanced_RAG_with_ReRanking_Example.java) - [Advanced RAG with Including Metadata](https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_3_advanced/_04_Advanced_RAG_with_Metadata_Example.java) +- [Advanced RAG with multiple Retrievers](https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_3_advanced/_07_Advanced_RAG_Multiple_Retrievers_Example.java) - [Skipping Retrieval](https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_3_advanced/_06_Advanced_RAG_Skip_Retrieval_Example.java) - [RAG + Tools](https://github.com/langchain4j/langchain4j-examples/blob/main/customer-support-agent-example/src/test/java/dev/langchain4j/example/CustomerSupportAgentApplicationTest.java) - [Loading Documents](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/DocumentLoaderExamples.java) diff --git a/docs/docs/tutorials/chains.md b/docs/docs/tutorials/chains.md deleted file mode 100644 index aeae7b373a..0000000000 --- a/docs/docs/tutorials/chains.md +++ /dev/null @@ -1,7 +0,0 @@ ---- -sidebar_position: 14 ---- - -# Chains - -Legacy Chains and chaining of AI Services is discussed [here](/tutorials/ai-services). diff --git a/docs/docs/tutorials/embedding-stores.md b/docs/docs/tutorials/embedding-stores.md index c69634cf08..d8067ff0d6 100644 --- a/docs/docs/tutorials/embedding-stores.md +++ b/docs/docs/tutorials/embedding-stores.md @@ -4,6 +4,11 @@ sidebar_position: 13 # Embedding (Vector) Stores +Documentation on embedding stores can be found [here](/tutorials/rag#embedding-store). + +All supported embedding stores can be found [here](/integrations/embedding-stores/). 
+ +## Examples - [Example of using in-memory embedding store](https://github.com/langchain4j/langchain4j-examples/blob/main/other-examples/src/main/java/embedding/store/InMemoryEmbeddingStoreExample.java) - [Example of using Chroma embedding store](https://github.com/langchain4j/langchain4j-examples/blob/main/chroma-example/src/main/java/ChromaEmbeddingStoreExample.java) - [Example of using Elasticsearch embedding store](https://github.com/langchain4j/langchain4j-examples/blob/main/elasticsearch-example/src/main/java/ElasticsearchEmbeddingStoreExample.java) @@ -16,5 +21,3 @@ sidebar_position: 13 - [Example of using Vespa embedding store](https://github.com/langchain4j/langchain4j-examples/blob/main/vespa-example/src/main/java/VespaEmbeddingStoreExample.java) - [Example of using Weaviate embedding store](https://github.com/langchain4j/langchain4j-examples/blob/main/weaviate-example/src/main/java/WeaviateEmbeddingStoreExample.java) - [Example of using PGVector embedding store](https://github.com/langchain4j/langchain4j-examples/blob/main/pgvector-example/src/main/java/PgVectorEmbeddingStoreExample.java) - -More info coming soon diff --git a/docs/docs/tutorials/index.mdx b/docs/docs/tutorials/index.mdx deleted file mode 100644 index b22c18bdad..0000000000 --- a/docs/docs/tutorials/index.mdx +++ /dev/null @@ -1,70 +0,0 @@ ---- -title: Overview -hide_title: false -sidebar_position: 1 ---- - -Here you will find tutorials covering all of LangChain4j's functionality, to guide you through the framework in steps of increasing complexity. - -We will typically use OpenAI models for demonstration purposes, but we support a lot of other model providers too. The entire list of supported models can be found [here](/category/integrations). - -## Need inspiration? - -### Watch the talk by [Lize Raes](https://github.com/LizeRaes) at Devoxx Belgium - - - -### Talk by [Vaadin](https://vaadin.com/) team about Building a RAG AI system in Spring Boot & LangChain4j - - - -### Fireside Chat: LangChain4j & Quarkus by [Quarkusio](https://quarkus.io/) - - - -### The Magic of AI Services with LangChain4j by [Tales from the jar side](https://www.youtube.com/@talesfromthejarside) - - - -### Or consider some of the use cases - -- You want to: - - Implement a custom AI-powered chatbot that has access to your data and behaves the way you want it. - - Implement a customer support chatbot that can: - - politely answer customer questions - - take /change/cancel orders - - - Implement an educational assistant that can: - - Teach various subjects - - Explain unclear parts - - Assess user's understanding/knowledge - - You want to process a lot of unstructured data (files, web pages, etc) and extract structured information from them. 
For example: - - extract insights from customer reviews and support chat history - - extract interesting information from the websites of your competitors - - extract insights from CVs of job applicants - - - Generate information, for example: - - Emails tailored for each of your customers - - - Generate content for your app/website: - - Blog posts - - Stories - - - Transform information, for example: - - Summarize - - Proofread and rewrite - - Translate \ No newline at end of file diff --git a/docs/docs/tutorials/logging.md b/docs/docs/tutorials/logging.md index 48dba9af57..cb951f6a73 100644 --- a/docs/docs/tutorials/logging.md +++ b/docs/docs/tutorials/logging.md @@ -1,4 +1,5 @@ --- +id: logging sidebar_position: 30 --- diff --git a/docs/docs/tutorials/quarkus-integration.md b/docs/docs/tutorials/quarkus-integration.md index faac8e6e6d..19dd216ee9 100644 --- a/docs/docs/tutorials/quarkus-integration.md +++ b/docs/docs/tutorials/quarkus-integration.md @@ -1,10 +1,10 @@ --- +id: quarkus-integration sidebar_position: 24 --- # Quarkus Integration -- [Quarkus-LangChain4j extension repo](https://github.com/quarkiverse/quarkus-langchain4j) -- [Quarkus-LangChain4j extension documentation](https://docs.quarkiverse.io/quarkus-langchain4j/dev/index.html) +[Quarkus](https://quarkus.io/) provides a superb [extension for LangChain4j](https://github.com/quarkiverse/quarkus-langchain4j). -More info coming soon +You can find all the necessary documentation [here](https://docs.quarkiverse.io/quarkus-langchain4j/dev/index.html). diff --git a/docs/docs/tutorials/spring-boot-integration.md b/docs/docs/tutorials/spring-boot-integration.md index 87533d77ac..ff6f11eea1 100644 --- a/docs/docs/tutorials/spring-boot-integration.md +++ b/docs/docs/tutorials/spring-boot-integration.md @@ -1,4 +1,5 @@ --- +id: spring-boot-integration sidebar_position: 27 --- @@ -14,7 +15,7 @@ To use one of the Spring Boot starters, first import the corresponding dependenc dev.langchain4j langchain4j-open-ai-spring-boot-starter - 0.30.0 + 0.31.0 ``` @@ -48,7 +49,7 @@ import `langchain4j-spring-boot-starter`: dev.langchain4j langchain4j-spring-boot-starter - 0.30.0 + 0.31.0 ``` @@ -85,9 +86,9 @@ class AssistantController { More details [here](https://github.com/langchain4j/langchain4j-spring/blob/main/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java). -## Supported Spring Boot Versions +## Supported versions -Spring Boot 2 and 3 are supported. +LangChain4j Spring Boot integration requires Java 17 and Spring Boot 3.2. ## Examples - [Low-level Spring Boot example](https://github.com/langchain4j/langchain4j-examples/blob/main/spring-boot-example/src/main/java/dev/langchain4j/example/lowlevel/ChatLanguageModelController.java) using [ChatLanguageModel API](/tutorials/chat-and-language-models) diff --git a/docs/docs/tutorials/structured-data-extraction.md b/docs/docs/tutorials/structured-data-extraction.md index 4f975f3f15..959c0503a9 100644 --- a/docs/docs/tutorials/structured-data-extraction.md +++ b/docs/docs/tutorials/structured-data-extraction.md @@ -1,10 +1,11 @@ --- +id: structured-data-extraction sidebar_position: 11 --- # Structured Data Extraction -More info coming soon +Documentation on structured data extraction can be found [here](/tutorials/ai-services#output-parsing-aka-structured-outputs). 
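+
+For a quick impression, here is a minimal sketch of extracting a custom `Person` object with an AI Service;
+the `Person` class, the prompt, and the `chatLanguageModel` variable are illustrative, and the page linked above covers the details:
+```java
+class Person {
+
+    String firstName;
+    String lastName;
+    LocalDate birthDate;
+}
+
+interface PersonExtractor {
+
+    @UserMessage("Extract information about a person from the following text: {{it}}")
+    Person extractPersonFrom(String text);
+}
+
+PersonExtractor extractor = AiServices.create(PersonExtractor.class, chatLanguageModel);
+
+Person person = extractor.extractPersonFrom("John Doe was born on 1 January 1990 in Springfield.");
+```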
## Examples diff --git a/docs/docs/useful-materials/index.mdx b/docs/docs/useful-materials/index.mdx index dcc667aa00..d6db5fddc1 100644 --- a/docs/docs/useful-materials/index.mdx +++ b/docs/docs/useful-materials/index.mdx @@ -7,6 +7,7 @@ sidebar_position: 1 ## Learning Materials - [Intro to Large Language Models](https://www.youtube.com/watch?v=zjkBMFhNj_g) by [Andrej Karpathy](https://www.youtube.com/@AndrejKarpathy) - [Short Courses](https://www.deeplearning.ai/short-courses/) by [DeepLearning.AI](https://www.deeplearning.ai/) +- [What We Learned from a Year of Building with LLMs (Part I)](https://www.oreilly.com/radar/what-we-learned-from-a-year-of-building-with-llms-part-i/) ## Local LLMs - [LocalLLaMA on Reddit](https://www.reddit.com/r/LocalLLaMA/) @@ -14,10 +15,20 @@ sidebar_position: 1 - [LocalAI](https://localai.io/) - [Guide to Choosing Quantization Methods and Inference Engines](https://www.reddit.com/r/LocalLLaMA/s/wZ3Sjifnqf) +## Evaluations +- [A Practical Guide to RAG Pipeline Evaluation (Part 1: Retrieval)](https://medium.com/relari/a-practical-guide-to-rag-pipeline-evaluation-part-1-27a472b09893) +- [A Practical Guide to RAG Pipeline Evaluation (Part 2: Generation)](https://medium.com/relari/a-practical-guide-to-rag-evaluation-part-2-generation-c79b1bde0f5d) +- [How important is a Golden Dataset for LLM evaluation?](https://medium.com/relari/how-important-is-a-golden-dataset-for-llm-pipeline-evaluation-4ef6deb14dc5) +- [Case Study: Reference-free vs Reference-based evaluation of RAG pipeline](https://medium.com/relari/case-study-reference-free-vs-reference-based-evaluation-of-rag-pipeline-9a49ef49866c) +- [How to evaluate complex GenAI Apps: a granular approach](https://medium.com/relari/how-to-evaluate-complex-genai-apps-a-granular-approach-0ab929d5b3e2) +- [Generate Synthetic Data to Test LLM Applications](https://medium.com/relari/generate-synthetic-data-to-test-llm-applications-4bffeb51b80e) + ## Leaderboards ### Language Models - [LMSYS Chatbot Arena](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) +- [SEAL Leaderboards](https://scale.com/leaderboard) +- [Comparing models for quality, speed, price, etc.](https://artificialanalysis.ai/) - Hallucinations: [Vectara](https://huggingface.co/spaces/vectara/leaderboard), [Hallucinations](https://huggingface.co/spaces/hallucinations-leaderboard/leaderboard) - Code Generation: [BigCode](https://huggingface.co/spaces/bigcode/bigcode-models-leaderboard) - Tools/Functions: [Gorilla](https://gorilla.cs.berkeley.edu/leaderboard.html), [Nexus](https://huggingface.co/spaces/Nexusflow/Nexus_Function_Calling_Leaderboard), [Toolbench](https://huggingface.co/spaces/qiantong-xu/toolbench-leaderboard) diff --git a/docs/package-lock.json b/docs/package-lock.json index 7a50647e34..5a7fac38b9 100644 --- a/docs/package-lock.json +++ b/docs/package-lock.json @@ -8,10 +8,10 @@ "name": "langchain4j", "version": "0.0.0", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/plugin-content-docs": "^3.1.0", - "@docusaurus/preset-classic": "3.0.1", - "@docusaurus/theme-mermaid": "^3.0.1", + "@docusaurus/core": "^3.4.0", + "@docusaurus/plugin-content-docs": "^3.4.0", + "@docusaurus/preset-classic": "^3.4.0", + "@docusaurus/theme-mermaid": "^3.4.0", "@mdx-js/react": "^3.0.0", "clsx": "^2.0.0", "prism-react-renderer": "^2.3.0", @@ -68,74 +68,74 @@ } }, "node_modules/@algolia/cache-browser-local-storage": { - "version": "4.22.0", - "resolved": 
"https://registry.npmjs.org/@algolia/cache-browser-local-storage/-/cache-browser-local-storage-4.22.0.tgz", - "integrity": "sha512-uZ1uZMLDZb4qODLfTSNHxSi4fH9RdrQf7DXEzW01dS8XK7QFtFh29N5NGKa9S+Yudf1vUMIF+/RiL4i/J0pWlQ==", + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/cache-browser-local-storage/-/cache-browser-local-storage-4.23.3.tgz", + "integrity": "sha512-vRHXYCpPlTDE7i6UOy2xE03zHF2C8MEFjPN2v7fRbqVpcOvAUQK81x3Kc21xyb5aSIpYCjWCZbYZuz8Glyzyyg==", "dependencies": { - "@algolia/cache-common": "4.22.0" + "@algolia/cache-common": "4.23.3" } }, "node_modules/@algolia/cache-common": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/cache-common/-/cache-common-4.22.0.tgz", - "integrity": "sha512-TPwUMlIGPN16eW67qamNQUmxNiGHg/WBqWcrOoCddhqNTqGDPVqmgfaM85LPbt24t3r1z0zEz/tdsmuq3Q6oaA==" + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/cache-common/-/cache-common-4.23.3.tgz", + "integrity": "sha512-h9XcNI6lxYStaw32pHpB1TMm0RuxphF+Ik4o7tcQiodEdpKK+wKufY6QXtba7t3k8eseirEMVB83uFFF3Nu54A==" }, "node_modules/@algolia/cache-in-memory": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/cache-in-memory/-/cache-in-memory-4.22.0.tgz", - "integrity": "sha512-kf4Cio9NpPjzp1+uXQgL4jsMDeck7MP89BYThSvXSjf2A6qV/0KeqQf90TL2ECS02ovLOBXkk98P7qVarM+zGA==", + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/cache-in-memory/-/cache-in-memory-4.23.3.tgz", + "integrity": "sha512-yvpbuUXg/+0rbcagxNT7un0eo3czx2Uf0y4eiR4z4SD7SiptwYTpbuS0IHxcLHG3lq22ukx1T6Kjtk/rT+mqNg==", "dependencies": { - "@algolia/cache-common": "4.22.0" + "@algolia/cache-common": "4.23.3" } }, "node_modules/@algolia/client-account": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/client-account/-/client-account-4.22.0.tgz", - "integrity": "sha512-Bjb5UXpWmJT+yGWiqAJL0prkENyEZTBzdC+N1vBuHjwIJcjLMjPB6j1hNBRbT12Lmwi55uzqeMIKS69w+0aPzA==", + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/client-account/-/client-account-4.23.3.tgz", + "integrity": "sha512-hpa6S5d7iQmretHHF40QGq6hz0anWEHGlULcTIT9tbUssWUriN9AUXIFQ8Ei4w9azD0hc1rUok9/DeQQobhQMA==", "dependencies": { - "@algolia/client-common": "4.22.0", - "@algolia/client-search": "4.22.0", - "@algolia/transporter": "4.22.0" + "@algolia/client-common": "4.23.3", + "@algolia/client-search": "4.23.3", + "@algolia/transporter": "4.23.3" } }, "node_modules/@algolia/client-analytics": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/client-analytics/-/client-analytics-4.22.0.tgz", - "integrity": "sha512-os2K+kHUcwwRa4ArFl5p/3YbF9lN3TLOPkbXXXxOvDpqFh62n9IRZuzfxpHxMPKAQS3Et1s0BkKavnNP02E9Hg==", + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/client-analytics/-/client-analytics-4.23.3.tgz", + "integrity": "sha512-LBsEARGS9cj8VkTAVEZphjxTjMVCci+zIIiRhpFun9jGDUlS1XmhCW7CTrnaWeIuCQS/2iPyRqSy1nXPjcBLRA==", "dependencies": { - "@algolia/client-common": "4.22.0", - "@algolia/client-search": "4.22.0", - "@algolia/requester-common": "4.22.0", - "@algolia/transporter": "4.22.0" + "@algolia/client-common": "4.23.3", + "@algolia/client-search": "4.23.3", + "@algolia/requester-common": "4.23.3", + "@algolia/transporter": "4.23.3" } }, "node_modules/@algolia/client-common": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/client-common/-/client-common-4.22.0.tgz", - "integrity": "sha512-BlbkF4qXVWuwTmYxVWvqtatCR3lzXwxx628p1wj1Q7QP2+LsTmGt1DiUYRuy9jG7iMsnlExby6kRMOOlbhv2Ag==", + 
"version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/client-common/-/client-common-4.23.3.tgz", + "integrity": "sha512-l6EiPxdAlg8CYhroqS5ybfIczsGUIAC47slLPOMDeKSVXYG1n0qGiz4RjAHLw2aD0xzh2EXZ7aRguPfz7UKDKw==", "dependencies": { - "@algolia/requester-common": "4.22.0", - "@algolia/transporter": "4.22.0" + "@algolia/requester-common": "4.23.3", + "@algolia/transporter": "4.23.3" } }, "node_modules/@algolia/client-personalization": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/client-personalization/-/client-personalization-4.22.0.tgz", - "integrity": "sha512-pEOftCxeBdG5pL97WngOBi9w5Vxr5KCV2j2D+xMVZH8MuU/JX7CglDSDDb0ffQWYqcUN+40Ry+xtXEYaGXTGow==", + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/client-personalization/-/client-personalization-4.23.3.tgz", + "integrity": "sha512-3E3yF3Ocr1tB/xOZiuC3doHQBQ2zu2MPTYZ0d4lpfWads2WTKG7ZzmGnsHmm63RflvDeLK/UVx7j2b3QuwKQ2g==", "dependencies": { - "@algolia/client-common": "4.22.0", - "@algolia/requester-common": "4.22.0", - "@algolia/transporter": "4.22.0" + "@algolia/client-common": "4.23.3", + "@algolia/requester-common": "4.23.3", + "@algolia/transporter": "4.23.3" } }, "node_modules/@algolia/client-search": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/client-search/-/client-search-4.22.0.tgz", - "integrity": "sha512-bn4qQiIdRPBGCwsNuuqB8rdHhGKKWIij9OqidM1UkQxnSG8yzxHdb7CujM30pvp5EnV7jTqDZRbxacbjYVW20Q==", + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/client-search/-/client-search-4.23.3.tgz", + "integrity": "sha512-P4VAKFHqU0wx9O+q29Q8YVuaowaZ5EM77rxfmGnkHUJggh28useXQdopokgwMeYw2XUht49WX5RcTQ40rZIabw==", "dependencies": { - "@algolia/client-common": "4.22.0", - "@algolia/requester-common": "4.22.0", - "@algolia/transporter": "4.22.0" + "@algolia/client-common": "4.23.3", + "@algolia/requester-common": "4.23.3", + "@algolia/transporter": "4.23.3" } }, "node_modules/@algolia/events": { @@ -144,47 +144,65 @@ "integrity": "sha512-FQzvOCgoFXAbf5Y6mYozw2aj5KCJoA3m4heImceldzPSMbdyS4atVjJzXKMsfX3wnZTFYwkkt8/z8UesLHlSBQ==" }, "node_modules/@algolia/logger-common": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/logger-common/-/logger-common-4.22.0.tgz", - "integrity": "sha512-HMUQTID0ucxNCXs5d1eBJ5q/HuKg8rFVE/vOiLaM4Abfeq1YnTtGV3+rFEhOPWhRQxNDd+YHa4q864IMc0zHpQ==" + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/logger-common/-/logger-common-4.23.3.tgz", + "integrity": "sha512-y9kBtmJwiZ9ZZ+1Ek66P0M68mHQzKRxkW5kAAXYN/rdzgDN0d2COsViEFufxJ0pb45K4FRcfC7+33YB4BLrZ+g==" }, "node_modules/@algolia/logger-console": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/logger-console/-/logger-console-4.22.0.tgz", - "integrity": "sha512-7JKb6hgcY64H7CRm3u6DRAiiEVXMvCJV5gRE672QFOUgDxo4aiDpfU61g6Uzy8NKjlEzHMmgG4e2fklELmPXhQ==", + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/logger-console/-/logger-console-4.23.3.tgz", + "integrity": "sha512-8xoiseoWDKuCVnWP8jHthgaeobDLolh00KJAdMe9XPrWPuf1by732jSpgy2BlsLTaT9m32pHI8CRfrOqQzHv3A==", "dependencies": { - "@algolia/logger-common": "4.22.0" + "@algolia/logger-common": "4.23.3" + } + }, + "node_modules/@algolia/recommend": { + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/recommend/-/recommend-4.23.3.tgz", + "integrity": "sha512-9fK4nXZF0bFkdcLBRDexsnGzVmu4TSYZqxdpgBW2tEyfuSSY54D4qSRkLmNkrrz4YFvdh2GM1gA8vSsnZPR73w==", + "dependencies": { + 
"@algolia/cache-browser-local-storage": "4.23.3", + "@algolia/cache-common": "4.23.3", + "@algolia/cache-in-memory": "4.23.3", + "@algolia/client-common": "4.23.3", + "@algolia/client-search": "4.23.3", + "@algolia/logger-common": "4.23.3", + "@algolia/logger-console": "4.23.3", + "@algolia/requester-browser-xhr": "4.23.3", + "@algolia/requester-common": "4.23.3", + "@algolia/requester-node-http": "4.23.3", + "@algolia/transporter": "4.23.3" } }, "node_modules/@algolia/requester-browser-xhr": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/requester-browser-xhr/-/requester-browser-xhr-4.22.0.tgz", - "integrity": "sha512-BHfv1h7P9/SyvcDJDaRuIwDu2yrDLlXlYmjvaLZTtPw6Ok/ZVhBR55JqW832XN/Fsl6k3LjdkYHHR7xnsa5Wvg==", + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/requester-browser-xhr/-/requester-browser-xhr-4.23.3.tgz", + "integrity": "sha512-jDWGIQ96BhXbmONAQsasIpTYWslyjkiGu0Quydjlowe+ciqySpiDUrJHERIRfELE5+wFc7hc1Q5hqjGoV7yghw==", "dependencies": { - "@algolia/requester-common": "4.22.0" + "@algolia/requester-common": "4.23.3" } }, "node_modules/@algolia/requester-common": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/requester-common/-/requester-common-4.22.0.tgz", - "integrity": "sha512-Y9cEH/cKjIIZgzvI1aI0ARdtR/xRrOR13g5psCxkdhpgRN0Vcorx+zePhmAa4jdQNqexpxtkUdcKYugBzMZJgQ==" + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/requester-common/-/requester-common-4.23.3.tgz", + "integrity": "sha512-xloIdr/bedtYEGcXCiF2muajyvRhwop4cMZo+K2qzNht0CMzlRkm8YsDdj5IaBhshqfgmBb3rTg4sL4/PpvLYw==" }, "node_modules/@algolia/requester-node-http": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/requester-node-http/-/requester-node-http-4.22.0.tgz", - "integrity": "sha512-8xHoGpxVhz3u2MYIieHIB6MsnX+vfd5PS4REgglejJ6lPigftRhTdBCToe6zbwq4p0anZXjjPDvNWMlgK2+xYA==", + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/requester-node-http/-/requester-node-http-4.23.3.tgz", + "integrity": "sha512-zgu++8Uj03IWDEJM3fuNl34s746JnZOWn1Uz5taV1dFyJhVM/kTNw9Ik7YJWiUNHJQXcaD8IXD1eCb0nq/aByA==", "dependencies": { - "@algolia/requester-common": "4.22.0" + "@algolia/requester-common": "4.23.3" } }, "node_modules/@algolia/transporter": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/@algolia/transporter/-/transporter-4.22.0.tgz", - "integrity": "sha512-ieO1k8x2o77GNvOoC+vAkFKppydQSVfbjM3YrSjLmgywiBejPTvU1R1nEvG59JIIUvtSLrZsLGPkd6vL14zopA==", + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/@algolia/transporter/-/transporter-4.23.3.tgz", + "integrity": "sha512-Wjl5gttqnf/gQKJA+dafnD0Y6Yw97yvfY8R9h0dQltX1GXTgNs1zWgvtWW0tHl1EgMdhAyw189uWiZMnL3QebQ==", "dependencies": { - "@algolia/cache-common": "4.22.0", - "@algolia/logger-common": "4.22.0", - "@algolia/requester-common": "4.22.0" + "@algolia/cache-common": "4.23.3", + "@algolia/logger-common": "4.23.3", + "@algolia/requester-common": "4.23.3" } }, "node_modules/@ampproject/remapping": { @@ -531,9 +549,9 @@ } }, "node_modules/@babel/helper-plugin-utils": { - "version": "7.22.5", - "resolved": "https://registry.npmjs.org/@babel/helper-plugin-utils/-/helper-plugin-utils-7.22.5.tgz", - "integrity": "sha512-uLls06UVKgFG9QD4OeFYLEGteMIAa5kpTPcFL28yuCIIzsf6ZyKZMllKVOCZFhiZ5ptnwX4mtKdWCBE/uT4amg==", + "version": "7.24.7", + "resolved": "https://registry.npmjs.org/@babel/helper-plugin-utils/-/helper-plugin-utils-7.24.7.tgz", + "integrity": 
"sha512-Rq76wjt7yz9AAc1KnlRKNAi/dMSVWgDRx43FHoJEbcYU6xOWaE2dVPwcdTukJrjxS65GITyfbvEYHvkirZ6uEg==", "engines": { "node": ">=6.9.0" } @@ -1609,11 +1627,11 @@ } }, "node_modules/@babel/plugin-transform-react-constant-elements": { - "version": "7.23.3", - "resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-constant-elements/-/plugin-transform-react-constant-elements-7.23.3.tgz", - "integrity": "sha512-zP0QKq/p6O42OL94udMgSfKXyse4RyJ0JqbQ34zDAONWjyrEsghYEyTSK5FIpmXmCpB55SHokL1cRRKHv8L2Qw==", + "version": "7.24.7", + "resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-constant-elements/-/plugin-transform-react-constant-elements-7.24.7.tgz", + "integrity": "sha512-7LidzZfUXyfZ8/buRW6qIIHBY8wAZ1OrY9c/wTr8YhZ6vMPo+Uc/CVFLYY1spZrEQlD4w5u8wjqk5NQ3OVqQKA==", "dependencies": { - "@babel/helper-plugin-utils": "^7.22.5" + "@babel/helper-plugin-utils": "^7.24.7" }, "engines": { "node": ">=6.9.0" @@ -2134,18 +2152,18 @@ } }, "node_modules/@docsearch/css": { - "version": "3.5.2", - "resolved": "https://registry.npmjs.org/@docsearch/css/-/css-3.5.2.tgz", - "integrity": "sha512-SPiDHaWKQZpwR2siD0KQUwlStvIAnEyK6tAE2h2Wuoq8ue9skzhlyVQ1ddzOxX6khULnAALDiR/isSF3bnuciA==" + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/@docsearch/css/-/css-3.6.0.tgz", + "integrity": "sha512-+sbxb71sWre+PwDK7X2T8+bhS6clcVMLwBPznX45Qu6opJcgRjAp7gYSDzVFp187J+feSj5dNBN1mJoi6ckkUQ==" }, "node_modules/@docsearch/react": { - "version": "3.5.2", - "resolved": "https://registry.npmjs.org/@docsearch/react/-/react-3.5.2.tgz", - "integrity": "sha512-9Ahcrs5z2jq/DcAvYtvlqEBHImbm4YJI8M9y0x6Tqg598P40HTEkX7hsMcIuThI+hTFxRGZ9hll0Wygm2yEjng==", + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/@docsearch/react/-/react-3.6.0.tgz", + "integrity": "sha512-HUFut4ztcVNmqy9gp/wxNbC7pTOHhgVVkHVGCACTuLhUKUhKAF9KYHJtMiLUJxEqiFLQiuri1fWF8zqwM/cu1w==", "dependencies": { "@algolia/autocomplete-core": "1.9.3", "@algolia/autocomplete-preset-algolia": "1.9.3", - "@docsearch/css": "3.5.2", + "@docsearch/css": "3.6.0", "algoliasearch": "^4.19.1" }, "peerDependencies": { @@ -2170,9 +2188,9 @@ } }, "node_modules/@docusaurus/core": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/core/-/core-3.0.1.tgz", - "integrity": "sha512-CXrLpOnW+dJdSv8M5FAJ3JBwXtL6mhUWxFA8aS0ozK6jBG/wgxERk5uvH28fCeFxOGbAT9v1e9dOMo1X2IEVhQ==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/core/-/core-3.4.0.tgz", + "integrity": "sha512-g+0wwmN2UJsBqy2fQRQ6fhXruoEa62JDeEa5d8IdTJlMoaDaEDfHh7WjwGRn4opuTQWpjAwP/fbcgyHKlE+64w==", "dependencies": { "@babel/core": "^7.23.3", "@babel/generator": "^7.23.3", @@ -2184,15 +2202,12 @@ "@babel/runtime": "^7.22.6", "@babel/runtime-corejs3": "^7.22.6", "@babel/traverse": "^7.22.8", - "@docusaurus/cssnano-preset": "3.0.1", - "@docusaurus/logger": "3.0.1", - "@docusaurus/mdx-loader": "3.0.1", - "@docusaurus/react-loadable": "5.5.2", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-common": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", - "@slorber/static-site-generator-webpack-plugin": "^4.0.7", - "@svgr/webpack": "^6.5.1", + "@docusaurus/cssnano-preset": "3.4.0", + "@docusaurus/logger": "3.4.0", + "@docusaurus/mdx-loader": "3.4.0", + "@docusaurus/utils": "3.4.0", + "@docusaurus/utils-common": "3.4.0", + "@docusaurus/utils-validation": "3.4.0", "autoprefixer": "^10.4.14", "babel-loader": "^9.1.3", "babel-plugin-dynamic-import-node": "^2.3.3", @@ -2206,12 +2221,13 @@ "copy-webpack-plugin": "^11.0.0", "core-js": "^3.31.1", "css-loader": 
"^6.8.1", - "css-minimizer-webpack-plugin": "^4.2.2", - "cssnano": "^5.1.15", + "css-minimizer-webpack-plugin": "^5.0.1", + "cssnano": "^6.1.2", "del": "^6.1.1", "detect-port": "^1.5.1", "escape-html": "^1.0.3", "eta": "^2.2.0", + "eval": "^0.1.8", "file-loader": "^6.2.0", "fs-extra": "^11.1.1", "html-minifier-terser": "^7.2.0", @@ -2220,12 +2236,13 @@ "leven": "^3.1.0", "lodash": "^4.17.21", "mini-css-extract-plugin": "^2.7.6", + "p-map": "^4.0.0", "postcss": "^8.4.26", "postcss-loader": "^7.3.3", "prompts": "^2.4.2", "react-dev-utils": "^12.0.1", "react-helmet-async": "^1.3.0", - "react-loadable": "npm:@docusaurus/react-loadable@5.5.2", + "react-loadable": "npm:@docusaurus/react-loadable@6.0.0", "react-loadable-ssr-addon-v5-slorber": "^1.0.1", "react-router": "^5.3.4", "react-router-config": "^5.1.1", @@ -2255,14 +2272,26 @@ "react-dom": "^18.0.0" } }, + "node_modules/@docusaurus/core/node_modules/react-loadable": { + "name": "@docusaurus/react-loadable", + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/@docusaurus/react-loadable/-/react-loadable-6.0.0.tgz", + "integrity": "sha512-YMMxTUQV/QFSnbgrP3tjDzLHRg7vsbMn8e9HAa8o/1iXoiomo48b7sk/kkmWEuWNDPJVlKSJRB6Y2fHqdJk+SQ==", + "dependencies": { + "@types/react": "*" + }, + "peerDependencies": { + "react": "*" + } + }, "node_modules/@docusaurus/cssnano-preset": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/cssnano-preset/-/cssnano-preset-3.0.1.tgz", - "integrity": "sha512-wjuXzkHMW+ig4BD6Ya1Yevx9UJadO4smNZCEljqBoQfIQrQskTswBs7lZ8InHP7mCt273a/y/rm36EZhqJhknQ==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/cssnano-preset/-/cssnano-preset-3.4.0.tgz", + "integrity": "sha512-qwLFSz6v/pZHy/UP32IrprmH5ORce86BGtN0eBtG75PpzQJAzp9gefspox+s8IEOr0oZKuQ/nhzZ3xwyc3jYJQ==", "dependencies": { - "cssnano-preset-advanced": "^5.3.10", - "postcss": "^8.4.26", - "postcss-sort-media-queries": "^4.4.1", + "cssnano-preset-advanced": "^6.1.2", + "postcss": "^8.4.38", + "postcss-sort-media-queries": "^5.2.0", "tslib": "^2.6.0" }, "engines": { @@ -2270,9 +2299,9 @@ } }, "node_modules/@docusaurus/logger": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/logger/-/logger-3.0.1.tgz", - "integrity": "sha512-I5L6Nk8OJzkVA91O2uftmo71LBSxe1vmOn9AMR6JRCzYeEBrqneWMH02AqMvjJ2NpMiviO+t0CyPjyYV7nxCWQ==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/logger/-/logger-3.4.0.tgz", + "integrity": "sha512-bZwkX+9SJ8lB9kVRkXw+xvHYSMGG4bpYHKGXeXFvyVc79NMeeBSGgzd4TQLHH+DYeOJoCdl8flrFJVxlZ0wo/Q==", "dependencies": { "chalk": "^4.1.2", "tslib": "^2.6.0" @@ -2282,15 +2311,13 @@ } }, "node_modules/@docusaurus/mdx-loader": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/mdx-loader/-/mdx-loader-3.0.1.tgz", - "integrity": "sha512-ldnTmvnvlrONUq45oKESrpy+lXtbnTcTsFkOTIDswe5xx5iWJjt6eSa0f99ZaWlnm24mlojcIGoUWNCS53qVlQ==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/mdx-loader/-/mdx-loader-3.4.0.tgz", + "integrity": "sha512-kSSbrrk4nTjf4d+wtBA9H+FGauf2gCax89kV8SUSJu3qaTdSIKdWERlngsiHaCFgZ7laTJ8a67UFf+xlFPtuTw==", "dependencies": { - "@babel/parser": "^7.22.7", - "@babel/traverse": "^7.22.8", - "@docusaurus/logger": "3.0.1", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", + "@docusaurus/logger": "3.4.0", + "@docusaurus/utils": "3.4.0", + "@docusaurus/utils-validation": "3.4.0", "@mdx-js/mdx": "^3.0.0", "@slorber/remark-comment": "^1.0.0", "escape-html": "^1.0.3", @@ -2325,6 
+2352,7 @@ "version": "3.0.1", "resolved": "https://registry.npmjs.org/@docusaurus/module-type-aliases/-/module-type-aliases-3.0.1.tgz", "integrity": "sha512-DEHpeqUDsLynl3AhQQiO7AbC7/z/lBra34jTcdYuvp9eGm01pfH1wTVq8YqWZq6Jyx0BgcVl/VJqtE9StRd9Ag==", + "dev": true, "dependencies": { "@docusaurus/react-loadable": "5.5.2", "@docusaurus/types": "3.0.1", @@ -2341,17 +2369,17 @@ } }, "node_modules/@docusaurus/plugin-content-blog": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-content-blog/-/plugin-content-blog-3.0.1.tgz", - "integrity": "sha512-cLOvtvAyaMQFLI8vm4j26svg3ktxMPSXpuUJ7EERKoGbfpJSsgtowNHcRsaBVmfuCsRSk1HZ/yHBsUkTmHFEsg==", - "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/logger": "3.0.1", - "@docusaurus/mdx-loader": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-common": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/plugin-content-blog/-/plugin-content-blog-3.4.0.tgz", + "integrity": "sha512-vv6ZAj78ibR5Jh7XBUT4ndIjmlAxkijM3Sx5MAAzC1gyv0vupDQNhzuFg1USQmQVj3P5I6bquk12etPV3LJ+Xw==", + "dependencies": { + "@docusaurus/core": "3.4.0", + "@docusaurus/logger": "3.4.0", + "@docusaurus/mdx-loader": "3.4.0", + "@docusaurus/types": "3.4.0", + "@docusaurus/utils": "3.4.0", + "@docusaurus/utils-common": "3.4.0", + "@docusaurus/utils-validation": "3.4.0", "cheerio": "^1.0.0-rc.12", "feed": "^4.2.2", "fs-extra": "^11.1.1", @@ -2371,177 +2399,46 @@ "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/plugin-content-docs": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-content-docs/-/plugin-content-docs-3.1.0.tgz", - "integrity": "sha512-el5GxhT8BLrsWD0qGa8Rq+Ttb/Ni6V3DGT2oAPio0qcs/mUAxeyXEAmihkvmLCnAgp6xD27Ce7dISZ5c6BXeqA==", - "dependencies": { - "@docusaurus/core": "3.1.0", - "@docusaurus/logger": "3.1.0", - "@docusaurus/mdx-loader": "3.1.0", - "@docusaurus/module-type-aliases": "3.1.0", - "@docusaurus/types": "3.1.0", - "@docusaurus/utils": "3.1.0", - "@docusaurus/utils-validation": "3.1.0", - "@types/react-router-config": "^5.0.7", - "combine-promises": "^1.1.0", - "fs-extra": "^11.1.1", - "js-yaml": "^4.1.0", - "lodash": "^4.17.21", - "tslib": "^2.6.0", - "utility-types": "^3.10.0", - "webpack": "^5.88.1" - }, - "engines": { - "node": ">=18.0" - }, - "peerDependencies": { - "react": "^18.0.0", - "react-dom": "^18.0.0" - } - }, - "node_modules/@docusaurus/plugin-content-docs/node_modules/@docusaurus/core": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@docusaurus/core/-/core-3.1.0.tgz", - "integrity": "sha512-GWudMGYA9v26ssbAWJNfgeDZk+lrudUTclLPRsmxiknEBk7UMp7Rglonhqbsf3IKHOyHkMU4Fr5jFyg5SBx9jQ==", + "node_modules/@docusaurus/plugin-content-blog/node_modules/@docusaurus/types": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", "dependencies": { - "@babel/core": "^7.23.3", - "@babel/generator": "^7.23.3", - "@babel/plugin-syntax-dynamic-import": "^7.8.3", - "@babel/plugin-transform-runtime": "^7.22.9", - "@babel/preset-env": "^7.22.9", - "@babel/preset-react": "^7.22.5", - "@babel/preset-typescript": "^7.22.5", - "@babel/runtime": "^7.22.6", - "@babel/runtime-corejs3": "^7.22.6", - "@babel/traverse": "^7.22.8", - "@docusaurus/cssnano-preset": "3.1.0", - "@docusaurus/logger": "3.1.0", - 
"@docusaurus/mdx-loader": "3.1.0", - "@docusaurus/react-loadable": "5.5.2", - "@docusaurus/utils": "3.1.0", - "@docusaurus/utils-common": "3.1.0", - "@docusaurus/utils-validation": "3.1.0", - "@slorber/static-site-generator-webpack-plugin": "^4.0.7", - "@svgr/webpack": "^6.5.1", - "autoprefixer": "^10.4.14", - "babel-loader": "^9.1.3", - "babel-plugin-dynamic-import-node": "^2.3.3", - "boxen": "^6.2.1", - "chalk": "^4.1.2", - "chokidar": "^3.5.3", - "clean-css": "^5.3.2", - "cli-table3": "^0.6.3", - "combine-promises": "^1.1.0", + "@mdx-js/mdx": "^3.0.0", + "@types/history": "^4.7.11", + "@types/react": "*", "commander": "^5.1.0", - "copy-webpack-plugin": "^11.0.0", - "core-js": "^3.31.1", - "css-loader": "^6.8.1", - "css-minimizer-webpack-plugin": "^4.2.2", - "cssnano": "^5.1.15", - "del": "^6.1.1", - "detect-port": "^1.5.1", - "escape-html": "^1.0.3", - "eta": "^2.2.0", - "file-loader": "^6.2.0", - "fs-extra": "^11.1.1", - "html-minifier-terser": "^7.2.0", - "html-tags": "^3.3.1", - "html-webpack-plugin": "^5.5.3", - "leven": "^3.1.0", - "lodash": "^4.17.21", - "mini-css-extract-plugin": "^2.7.6", - "postcss": "^8.4.26", - "postcss-loader": "^7.3.3", - "prompts": "^2.4.2", - "react-dev-utils": "^12.0.1", + "joi": "^17.9.2", "react-helmet-async": "^1.3.0", - "react-loadable": "npm:@docusaurus/react-loadable@5.5.2", - "react-loadable-ssr-addon-v5-slorber": "^1.0.1", - "react-router": "^5.3.4", - "react-router-config": "^5.1.1", - "react-router-dom": "^5.3.4", - "rtl-detect": "^1.0.4", - "semver": "^7.5.4", - "serve-handler": "^6.1.5", - "shelljs": "^0.8.5", - "terser-webpack-plugin": "^5.3.9", - "tslib": "^2.6.0", - "update-notifier": "^6.0.2", - "url-loader": "^4.1.1", + "utility-types": "^3.10.0", "webpack": "^5.88.1", - "webpack-bundle-analyzer": "^4.9.0", - "webpack-dev-server": "^4.15.1", - "webpack-merge": "^5.9.0", - "webpackbar": "^5.0.2" - }, - "bin": { - "docusaurus": "bin/docusaurus.mjs" - }, - "engines": { - "node": ">=18.0" + "webpack-merge": "^5.9.0" }, "peerDependencies": { "react": "^18.0.0", "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/plugin-content-docs/node_modules/@docusaurus/cssnano-preset": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@docusaurus/cssnano-preset/-/cssnano-preset-3.1.0.tgz", - "integrity": "sha512-ned7qsgCqSv/e7KyugFNroAfiszuxLwnvMW7gmT2Ywxb/Nyt61yIw7KHyAZCMKglOalrqnYA4gMhLUCK/mVePA==", - "dependencies": { - "cssnano-preset-advanced": "^5.3.10", - "postcss": "^8.4.26", - "postcss-sort-media-queries": "^4.4.1", - "tslib": "^2.6.0" - }, - "engines": { - "node": ">=18.0" - } - }, - "node_modules/@docusaurus/plugin-content-docs/node_modules/@docusaurus/logger": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@docusaurus/logger/-/logger-3.1.0.tgz", - "integrity": "sha512-p740M+HCst1VnKKzL60Hru9xfG4EUYJDarjlEC4hHeBy9+afPmY3BNPoSHx9/8zxuYfUlv/psf7I9NvRVdmdvg==", - "dependencies": { - "chalk": "^4.1.2", - "tslib": "^2.6.0" - }, - "engines": { - "node": ">=18.0" - } - }, - "node_modules/@docusaurus/plugin-content-docs/node_modules/@docusaurus/mdx-loader": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@docusaurus/mdx-loader/-/mdx-loader-3.1.0.tgz", - "integrity": "sha512-D7onDz/3mgBonexWoQXPw3V2E5Bc4+jYRf9gGUUK+KoQwU8xMDaDkUUfsr7t6UBa/xox9p5+/3zwLuXOYMzGSg==", - "dependencies": { - "@babel/parser": "^7.22.7", - "@babel/traverse": "^7.22.8", - "@docusaurus/logger": "3.1.0", - "@docusaurus/utils": "3.1.0", - "@docusaurus/utils-validation": "3.1.0", - "@mdx-js/mdx": "^3.0.0", - 
"@slorber/remark-comment": "^1.0.0", - "escape-html": "^1.0.3", - "estree-util-value-to-estree": "^3.0.1", - "file-loader": "^6.2.0", + "node_modules/@docusaurus/plugin-content-docs": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/plugin-content-docs/-/plugin-content-docs-3.4.0.tgz", + "integrity": "sha512-HkUCZffhBo7ocYheD9oZvMcDloRnGhBMOZRyVcAQRFmZPmNqSyISlXA1tQCIxW+r478fty97XXAGjNYzBjpCsg==", + "dependencies": { + "@docusaurus/core": "3.4.0", + "@docusaurus/logger": "3.4.0", + "@docusaurus/mdx-loader": "3.4.0", + "@docusaurus/module-type-aliases": "3.4.0", + "@docusaurus/types": "3.4.0", + "@docusaurus/utils": "3.4.0", + "@docusaurus/utils-common": "3.4.0", + "@docusaurus/utils-validation": "3.4.0", + "@types/react-router-config": "^5.0.7", + "combine-promises": "^1.1.0", "fs-extra": "^11.1.1", - "image-size": "^1.0.2", - "mdast-util-mdx": "^3.0.0", - "mdast-util-to-string": "^4.0.0", - "rehype-raw": "^7.0.0", - "remark-directive": "^3.0.0", - "remark-emoji": "^4.0.0", - "remark-frontmatter": "^5.0.0", - "remark-gfm": "^4.0.0", - "stringify-object": "^3.3.0", + "js-yaml": "^4.1.0", + "lodash": "^4.17.21", "tslib": "^2.6.0", - "unified": "^11.0.3", - "unist-util-visit": "^5.0.0", - "url-loader": "^4.1.1", - "vfile": "^6.0.1", + "utility-types": "^3.10.0", "webpack": "^5.88.1" }, "engines": { @@ -2553,18 +2450,17 @@ } }, "node_modules/@docusaurus/plugin-content-docs/node_modules/@docusaurus/module-type-aliases": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@docusaurus/module-type-aliases/-/module-type-aliases-3.1.0.tgz", - "integrity": "sha512-XUl7Z4PWlKg4l6KF05JQ3iDHQxnPxbQUqTNKvviHyuHdlalOFv6qeDAm7IbzyQPJD5VA6y4dpRbTWSqP9ClwPg==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/module-type-aliases/-/module-type-aliases-3.4.0.tgz", + "integrity": "sha512-A1AyS8WF5Bkjnb8s+guTDuYmUiwJzNrtchebBHpc0gz0PyHJNMaybUlSrmJjHVcGrya0LKI4YcR3lBDQfXRYLw==", "dependencies": { - "@docusaurus/react-loadable": "5.5.2", - "@docusaurus/types": "3.1.0", + "@docusaurus/types": "3.4.0", "@types/history": "^4.7.11", "@types/react": "*", "@types/react-router-config": "*", "@types/react-router-dom": "*", "react-helmet-async": "*", - "react-loadable": "npm:@docusaurus/react-loadable@5.5.2" + "react-loadable": "npm:@docusaurus/react-loadable@6.0.0" }, "peerDependencies": { "react": "*", @@ -2572,9 +2468,9 @@ } }, "node_modules/@docusaurus/plugin-content-docs/node_modules/@docusaurus/types": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.1.0.tgz", - "integrity": "sha512-VaczOZf7+re8aFBIWnex1XENomwHdsSTkrdX43zyor7G/FY4OIsP6X28Xc3o0jiY0YdNuvIDyA5TNwOtpgkCVw==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", "dependencies": { "@mdx-js/mdx": "^3.0.0", "@types/history": "^4.7.11", @@ -2591,107 +2487,108 @@ "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/plugin-content-docs/node_modules/@docusaurus/utils": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@docusaurus/utils/-/utils-3.1.0.tgz", - "integrity": "sha512-LgZfp0D+UBqAh7PZ//MUNSFBMavmAPku6Si9x8x3V+S318IGCNJ6hUr2O29UO0oLybEWUjD5Jnj9IUN6XyZeeg==", + "node_modules/@docusaurus/plugin-content-docs/node_modules/react-loadable": { + "name": "@docusaurus/react-loadable", + "version": "6.0.0", + "resolved": 
"https://registry.npmjs.org/@docusaurus/react-loadable/-/react-loadable-6.0.0.tgz", + "integrity": "sha512-YMMxTUQV/QFSnbgrP3tjDzLHRg7vsbMn8e9HAa8o/1iXoiomo48b7sk/kkmWEuWNDPJVlKSJRB6Y2fHqdJk+SQ==", "dependencies": { - "@docusaurus/logger": "3.1.0", - "@svgr/webpack": "^6.5.1", - "escape-string-regexp": "^4.0.0", - "file-loader": "^6.2.0", + "@types/react": "*" + }, + "peerDependencies": { + "react": "*" + } + }, + "node_modules/@docusaurus/plugin-content-pages": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/plugin-content-pages/-/plugin-content-pages-3.4.0.tgz", + "integrity": "sha512-h2+VN/0JjpR8fIkDEAoadNjfR3oLzB+v1qSXbIAKjQ46JAHx3X22n9nqS+BWSQnTnp1AjkjSvZyJMekmcwxzxg==", + "dependencies": { + "@docusaurus/core": "3.4.0", + "@docusaurus/mdx-loader": "3.4.0", + "@docusaurus/types": "3.4.0", + "@docusaurus/utils": "3.4.0", + "@docusaurus/utils-validation": "3.4.0", "fs-extra": "^11.1.1", - "github-slugger": "^1.5.0", - "globby": "^11.1.0", - "gray-matter": "^4.0.3", - "jiti": "^1.20.0", - "js-yaml": "^4.1.0", - "lodash": "^4.17.21", - "micromatch": "^4.0.5", - "resolve-pathname": "^3.0.0", - "shelljs": "^0.8.5", "tslib": "^2.6.0", - "url-loader": "^4.1.1", "webpack": "^5.88.1" }, "engines": { "node": ">=18.0" }, "peerDependencies": { - "@docusaurus/types": "*" - }, - "peerDependenciesMeta": { - "@docusaurus/types": { - "optional": true - } + "react": "^18.0.0", + "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/plugin-content-docs/node_modules/@docusaurus/utils-common": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@docusaurus/utils-common/-/utils-common-3.1.0.tgz", - "integrity": "sha512-SfvnRLHoZ9bwTw67knkSs7IcUR0GY2SaGkpdB/J9pChrDiGhwzKNUhcieoPyPYrOWGRPk3rVNYtoy+Bc7psPAw==", + "node_modules/@docusaurus/plugin-content-pages/node_modules/@docusaurus/types": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", "dependencies": { - "tslib": "^2.6.0" - }, - "engines": { - "node": ">=18.0" + "@mdx-js/mdx": "^3.0.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "commander": "^5.1.0", + "joi": "^17.9.2", + "react-helmet-async": "^1.3.0", + "utility-types": "^3.10.0", + "webpack": "^5.88.1", + "webpack-merge": "^5.9.0" }, "peerDependencies": { - "@docusaurus/types": "*" - }, - "peerDependenciesMeta": { - "@docusaurus/types": { - "optional": true - } + "react": "^18.0.0", + "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/plugin-content-docs/node_modules/@docusaurus/utils-validation": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@docusaurus/utils-validation/-/utils-validation-3.1.0.tgz", - "integrity": "sha512-dFxhs1NLxPOSzmcTk/eeKxLY5R+U4cua22g9MsAMiRWcwFKStZ2W3/GDY0GmnJGqNS8QAQepJrxQoyxXkJNDeg==", + "node_modules/@docusaurus/plugin-debug": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/plugin-debug/-/plugin-debug-3.4.0.tgz", + "integrity": "sha512-uV7FDUNXGyDSD3PwUaf5YijX91T5/H9SX4ErEcshzwgzWwBtK37nUWPU3ZLJfeTavX3fycTOqk9TglpOLaWkCg==", "dependencies": { - "@docusaurus/logger": "3.1.0", - "@docusaurus/utils": "3.1.0", - "joi": "^17.9.2", - "js-yaml": "^4.1.0", + "@docusaurus/core": "3.4.0", + "@docusaurus/types": "3.4.0", + "@docusaurus/utils": "3.4.0", + "fs-extra": "^11.1.1", + "react-json-view-lite": "^1.2.0", "tslib": "^2.6.0" }, "engines": { "node": ">=18.0" + }, + "peerDependencies": { + "react": 
"^18.0.0", + "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/plugin-content-pages": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-content-pages/-/plugin-content-pages-3.0.1.tgz", - "integrity": "sha512-oP7PoYizKAXyEttcvVzfX3OoBIXEmXTMzCdfmC4oSwjG4SPcJsRge3mmI6O8jcZBgUPjIzXD21bVGWEE1iu8gg==", + "node_modules/@docusaurus/plugin-debug/node_modules/@docusaurus/types": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/mdx-loader": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", - "fs-extra": "^11.1.1", - "tslib": "^2.6.0", - "webpack": "^5.88.1" - }, - "engines": { - "node": ">=18.0" + "@mdx-js/mdx": "^3.0.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "commander": "^5.1.0", + "joi": "^17.9.2", + "react-helmet-async": "^1.3.0", + "utility-types": "^3.10.0", + "webpack": "^5.88.1", + "webpack-merge": "^5.9.0" }, "peerDependencies": { "react": "^18.0.0", "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/plugin-debug": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-debug/-/plugin-debug-3.0.1.tgz", - "integrity": "sha512-09dxZMdATky4qdsZGzhzlUvvC+ilQ2hKbYF+wez+cM2mGo4qHbv8+qKXqxq0CQZyimwlAOWQLoSozIXU0g0i7g==", + "node_modules/@docusaurus/plugin-google-analytics": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/plugin-google-analytics/-/plugin-google-analytics-3.4.0.tgz", + "integrity": "sha512-mCArluxEGi3cmYHqsgpGGt3IyLCrFBxPsxNZ56Mpur0xSlInnIHoeLDH7FvVVcPJRPSQ9/MfRqLsainRw+BojA==", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils": "3.0.1", - "fs-extra": "^11.1.1", - "react-json-view-lite": "^1.2.0", + "@docusaurus/core": "3.4.0", + "@docusaurus/types": "3.4.0", + "@docusaurus/utils-validation": "3.4.0", "tslib": "^2.6.0" }, "engines": { @@ -2702,18 +2599,20 @@ "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/plugin-google-analytics": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-google-analytics/-/plugin-google-analytics-3.0.1.tgz", - "integrity": "sha512-jwseSz1E+g9rXQwDdr0ZdYNjn8leZBnKPjjQhMBEiwDoenL3JYFcNW0+p0sWoVF/f2z5t7HkKA+cYObrUh18gg==", + "node_modules/@docusaurus/plugin-google-analytics/node_modules/@docusaurus/types": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", - "tslib": "^2.6.0" - }, - "engines": { - "node": ">=18.0" + "@mdx-js/mdx": "^3.0.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "commander": "^5.1.0", + "joi": "^17.9.2", + "react-helmet-async": "^1.3.0", + "utility-types": "^3.10.0", + "webpack": "^5.88.1", + "webpack-merge": "^5.9.0" }, "peerDependencies": { "react": "^18.0.0", @@ -2721,13 +2620,13 @@ } }, "node_modules/@docusaurus/plugin-google-gtag": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-google-gtag/-/plugin-google-gtag-3.0.1.tgz", - "integrity": 
"sha512-UFTDvXniAWrajsulKUJ1DB6qplui1BlKLQZjX4F7qS/qfJ+qkKqSkhJ/F4VuGQ2JYeZstYb+KaUzUzvaPK1aRQ==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/plugin-google-gtag/-/plugin-google-gtag-3.4.0.tgz", + "integrity": "sha512-Dsgg6PLAqzZw5wZ4QjUYc8Z2KqJqXxHxq3vIoyoBWiLEEfigIs7wHR+oiWUQy3Zk9MIk6JTYj7tMoQU0Jm3nqA==", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", + "@docusaurus/core": "3.4.0", + "@docusaurus/types": "3.4.0", + "@docusaurus/utils-validation": "3.4.0", "@types/gtag.js": "^0.0.12", "tslib": "^2.6.0" }, @@ -2739,14 +2638,34 @@ "react-dom": "^18.0.0" } }, + "node_modules/@docusaurus/plugin-google-gtag/node_modules/@docusaurus/types": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", + "dependencies": { + "@mdx-js/mdx": "^3.0.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "commander": "^5.1.0", + "joi": "^17.9.2", + "react-helmet-async": "^1.3.0", + "utility-types": "^3.10.0", + "webpack": "^5.88.1", + "webpack-merge": "^5.9.0" + }, + "peerDependencies": { + "react": "^18.0.0", + "react-dom": "^18.0.0" + } + }, "node_modules/@docusaurus/plugin-google-tag-manager": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-google-tag-manager/-/plugin-google-tag-manager-3.0.1.tgz", - "integrity": "sha512-IPFvuz83aFuheZcWpTlAdiiX1RqWIHM+OH8wS66JgwAKOiQMR3+nLywGjkLV4bp52x7nCnwhNk1rE85Cpy/CIw==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/plugin-google-tag-manager/-/plugin-google-tag-manager-3.4.0.tgz", + "integrity": "sha512-O9tX1BTwxIhgXpOLpFDueYA9DWk69WCbDRrjYoMQtFHSkTyE7RhNgyjSPREUWJb9i+YUg3OrsvrBYRl64FCPCQ==", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", + "@docusaurus/core": "3.4.0", + "@docusaurus/types": "3.4.0", + "@docusaurus/utils-validation": "3.4.0", "tslib": "^2.6.0" }, "engines": { @@ -2757,17 +2676,37 @@ "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/plugin-sitemap": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-sitemap/-/plugin-sitemap-3.0.1.tgz", - "integrity": "sha512-xARiWnjtVvoEniZudlCq5T9ifnhCu/GAZ5nA7XgyLfPcNpHQa241HZdsTlLtVcecEVVdllevBKOp7qknBBaMGw==", + "node_modules/@docusaurus/plugin-google-tag-manager/node_modules/@docusaurus/types": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/logger": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-common": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", + "@mdx-js/mdx": "^3.0.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "commander": "^5.1.0", + "joi": "^17.9.2", + "react-helmet-async": "^1.3.0", + "utility-types": "^3.10.0", + "webpack": "^5.88.1", + "webpack-merge": "^5.9.0" + }, + "peerDependencies": { + "react": "^18.0.0", + "react-dom": "^18.0.0" + } + }, + "node_modules/@docusaurus/plugin-sitemap": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/plugin-sitemap/-/plugin-sitemap-3.4.0.tgz", + "integrity": 
"sha512-+0VDvx9SmNrFNgwPoeoCha+tRoAjopwT0+pYO1xAbyLcewXSemq+eLxEa46Q1/aoOaJQ0qqHELuQM7iS2gp33Q==", + "dependencies": { + "@docusaurus/core": "3.4.0", + "@docusaurus/logger": "3.4.0", + "@docusaurus/types": "3.4.0", + "@docusaurus/utils": "3.4.0", + "@docusaurus/utils-common": "3.4.0", + "@docusaurus/utils-validation": "3.4.0", "fs-extra": "^11.1.1", "sitemap": "^7.1.1", "tslib": "^2.6.0" @@ -2780,24 +2719,44 @@ "react-dom": "^18.0.0" } }, + "node_modules/@docusaurus/plugin-sitemap/node_modules/@docusaurus/types": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", + "dependencies": { + "@mdx-js/mdx": "^3.0.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "commander": "^5.1.0", + "joi": "^17.9.2", + "react-helmet-async": "^1.3.0", + "utility-types": "^3.10.0", + "webpack": "^5.88.1", + "webpack-merge": "^5.9.0" + }, + "peerDependencies": { + "react": "^18.0.0", + "react-dom": "^18.0.0" + } + }, "node_modules/@docusaurus/preset-classic": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/preset-classic/-/preset-classic-3.0.1.tgz", - "integrity": "sha512-il9m9xZKKjoXn6h0cRcdnt6wce0Pv1y5t4xk2Wx7zBGhKG1idu4IFHtikHlD0QPuZ9fizpXspXcTzjL5FXc1Gw==", - "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/plugin-content-blog": "3.0.1", - "@docusaurus/plugin-content-docs": "3.0.1", - "@docusaurus/plugin-content-pages": "3.0.1", - "@docusaurus/plugin-debug": "3.0.1", - "@docusaurus/plugin-google-analytics": "3.0.1", - "@docusaurus/plugin-google-gtag": "3.0.1", - "@docusaurus/plugin-google-tag-manager": "3.0.1", - "@docusaurus/plugin-sitemap": "3.0.1", - "@docusaurus/theme-classic": "3.0.1", - "@docusaurus/theme-common": "3.0.1", - "@docusaurus/theme-search-algolia": "3.0.1", - "@docusaurus/types": "3.0.1" + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/preset-classic/-/preset-classic-3.4.0.tgz", + "integrity": "sha512-Ohj6KB7siKqZaQhNJVMBBUzT3Nnp6eTKqO+FXO3qu/n1hJl3YLwVKTWBg28LF7MWrKu46UuYavwMRxud0VyqHg==", + "dependencies": { + "@docusaurus/core": "3.4.0", + "@docusaurus/plugin-content-blog": "3.4.0", + "@docusaurus/plugin-content-docs": "3.4.0", + "@docusaurus/plugin-content-pages": "3.4.0", + "@docusaurus/plugin-debug": "3.4.0", + "@docusaurus/plugin-google-analytics": "3.4.0", + "@docusaurus/plugin-google-gtag": "3.4.0", + "@docusaurus/plugin-google-tag-manager": "3.4.0", + "@docusaurus/plugin-sitemap": "3.4.0", + "@docusaurus/theme-classic": "3.4.0", + "@docusaurus/theme-common": "3.4.0", + "@docusaurus/theme-search-algolia": "3.4.0", + "@docusaurus/types": "3.4.0" }, "engines": { "node": ">=18.0" @@ -2807,29 +2766,20 @@ "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/preset-classic/node_modules/@docusaurus/plugin-content-docs": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-content-docs/-/plugin-content-docs-3.0.1.tgz", - "integrity": "sha512-dRfAOA5Ivo+sdzzJGXEu33yAtvGg8dlZkvt/NEJ7nwi1F2j4LEdsxtfX2GKeETB2fP6XoGNSQnFXqa2NYGrHFg==", + "node_modules/@docusaurus/preset-classic/node_modules/@docusaurus/types": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/logger": "3.0.1", - 
"@docusaurus/mdx-loader": "3.0.1", - "@docusaurus/module-type-aliases": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", - "@types/react-router-config": "^5.0.7", - "combine-promises": "^1.1.0", - "fs-extra": "^11.1.1", - "js-yaml": "^4.1.0", - "lodash": "^4.17.21", - "tslib": "^2.6.0", + "@mdx-js/mdx": "^3.0.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "commander": "^5.1.0", + "joi": "^17.9.2", + "react-helmet-async": "^1.3.0", "utility-types": "^3.10.0", - "webpack": "^5.88.1" - }, - "engines": { - "node": ">=18.0" + "webpack": "^5.88.1", + "webpack-merge": "^5.9.0" }, "peerDependencies": { "react": "^18.0.0", @@ -2840,6 +2790,7 @@ "version": "5.5.2", "resolved": "https://registry.npmjs.org/@docusaurus/react-loadable/-/react-loadable-5.5.2.tgz", "integrity": "sha512-A3dYjdBGuy0IGT+wyLIGIKLRE+sAk1iNk0f1HjNDysO7u8lhL4N3VEm+FAubmJbAztn94F7MxBTPmnixbiyFdQ==", + "dev": true, "dependencies": { "@types/react": "*", "prop-types": "^15.6.2" @@ -2849,22 +2800,22 @@ } }, "node_modules/@docusaurus/theme-classic": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/theme-classic/-/theme-classic-3.0.1.tgz", - "integrity": "sha512-XD1FRXaJiDlmYaiHHdm27PNhhPboUah9rqIH0lMpBt5kYtsGjJzhqa27KuZvHLzOP2OEpqd2+GZ5b6YPq7Q05Q==", - "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/mdx-loader": "3.0.1", - "@docusaurus/module-type-aliases": "3.0.1", - "@docusaurus/plugin-content-blog": "3.0.1", - "@docusaurus/plugin-content-docs": "3.0.1", - "@docusaurus/plugin-content-pages": "3.0.1", - "@docusaurus/theme-common": "3.0.1", - "@docusaurus/theme-translations": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-common": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/theme-classic/-/theme-classic-3.4.0.tgz", + "integrity": "sha512-0IPtmxsBYv2adr1GnZRdMkEQt1YW6tpzrUPj02YxNpvJ5+ju4E13J5tB4nfdaen/tfR1hmpSPlTFPvTf4kwy8Q==", + "dependencies": { + "@docusaurus/core": "3.4.0", + "@docusaurus/mdx-loader": "3.4.0", + "@docusaurus/module-type-aliases": "3.4.0", + "@docusaurus/plugin-content-blog": "3.4.0", + "@docusaurus/plugin-content-docs": "3.4.0", + "@docusaurus/plugin-content-pages": "3.4.0", + "@docusaurus/theme-common": "3.4.0", + "@docusaurus/theme-translations": "3.4.0", + "@docusaurus/types": "3.4.0", + "@docusaurus/utils": "3.4.0", + "@docusaurus/utils-common": "3.4.0", + "@docusaurus/utils-validation": "3.4.0", "@mdx-js/react": "^3.0.0", "clsx": "^2.0.0", "copy-text-to-clipboard": "^3.2.0", @@ -2887,47 +2838,68 @@ "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/theme-classic/node_modules/@docusaurus/plugin-content-docs": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-content-docs/-/plugin-content-docs-3.0.1.tgz", - "integrity": "sha512-dRfAOA5Ivo+sdzzJGXEu33yAtvGg8dlZkvt/NEJ7nwi1F2j4LEdsxtfX2GKeETB2fP6XoGNSQnFXqa2NYGrHFg==", + "node_modules/@docusaurus/theme-classic/node_modules/@docusaurus/module-type-aliases": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/module-type-aliases/-/module-type-aliases-3.4.0.tgz", + "integrity": "sha512-A1AyS8WF5Bkjnb8s+guTDuYmUiwJzNrtchebBHpc0gz0PyHJNMaybUlSrmJjHVcGrya0LKI4YcR3lBDQfXRYLw==", + "dependencies": { + "@docusaurus/types": "3.4.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "@types/react-router-config": "*", + "@types/react-router-dom": "*", + 
"react-helmet-async": "*", + "react-loadable": "npm:@docusaurus/react-loadable@6.0.0" + }, + "peerDependencies": { + "react": "*", + "react-dom": "*" + } + }, + "node_modules/@docusaurus/theme-classic/node_modules/@docusaurus/types": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", + "dependencies": { + "@mdx-js/mdx": "^3.0.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "commander": "^5.1.0", + "joi": "^17.9.2", + "react-helmet-async": "^1.3.0", + "utility-types": "^3.10.0", + "webpack": "^5.88.1", + "webpack-merge": "^5.9.0" + }, + "peerDependencies": { + "react": "^18.0.0", + "react-dom": "^18.0.0" + } + }, + "node_modules/@docusaurus/theme-classic/node_modules/react-loadable": { + "name": "@docusaurus/react-loadable", + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/@docusaurus/react-loadable/-/react-loadable-6.0.0.tgz", + "integrity": "sha512-YMMxTUQV/QFSnbgrP3tjDzLHRg7vsbMn8e9HAa8o/1iXoiomo48b7sk/kkmWEuWNDPJVlKSJRB6Y2fHqdJk+SQ==", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/logger": "3.0.1", - "@docusaurus/mdx-loader": "3.0.1", - "@docusaurus/module-type-aliases": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", - "@types/react-router-config": "^5.0.7", - "combine-promises": "^1.1.0", - "fs-extra": "^11.1.1", - "js-yaml": "^4.1.0", - "lodash": "^4.17.21", - "tslib": "^2.6.0", - "utility-types": "^3.10.0", - "webpack": "^5.88.1" - }, - "engines": { - "node": ">=18.0" + "@types/react": "*" }, "peerDependencies": { - "react": "^18.0.0", - "react-dom": "^18.0.0" + "react": "*" } }, "node_modules/@docusaurus/theme-common": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/theme-common/-/theme-common-3.0.1.tgz", - "integrity": "sha512-cr9TOWXuIOL0PUfuXv6L5lPlTgaphKP+22NdVBOYah5jSq5XAAulJTjfe+IfLsEG4L7lJttLbhW7LXDFSAI7Ag==", - "dependencies": { - "@docusaurus/mdx-loader": "3.0.1", - "@docusaurus/module-type-aliases": "3.0.1", - "@docusaurus/plugin-content-blog": "3.0.1", - "@docusaurus/plugin-content-docs": "3.0.1", - "@docusaurus/plugin-content-pages": "3.0.1", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-common": "3.0.1", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/theme-common/-/theme-common-3.4.0.tgz", + "integrity": "sha512-0A27alXuv7ZdCg28oPE8nH/Iz73/IUejVaCazqu9elS4ypjiLhK3KfzdSQBnL/g7YfHSlymZKdiOHEo8fJ0qMA==", + "dependencies": { + "@docusaurus/mdx-loader": "3.4.0", + "@docusaurus/module-type-aliases": "3.4.0", + "@docusaurus/plugin-content-blog": "3.4.0", + "@docusaurus/plugin-content-docs": "3.4.0", + "@docusaurus/plugin-content-pages": "3.4.0", + "@docusaurus/utils": "3.4.0", + "@docusaurus/utils-common": "3.4.0", "@types/history": "^4.7.11", "@types/react": "*", "@types/react-router-config": "*", @@ -2945,45 +2917,66 @@ "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/theme-common/node_modules/@docusaurus/plugin-content-docs": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-content-docs/-/plugin-content-docs-3.0.1.tgz", - "integrity": "sha512-dRfAOA5Ivo+sdzzJGXEu33yAtvGg8dlZkvt/NEJ7nwi1F2j4LEdsxtfX2GKeETB2fP6XoGNSQnFXqa2NYGrHFg==", + "node_modules/@docusaurus/theme-common/node_modules/@docusaurus/module-type-aliases": { + "version": "3.4.0", + "resolved": 
"https://registry.npmjs.org/@docusaurus/module-type-aliases/-/module-type-aliases-3.4.0.tgz", + "integrity": "sha512-A1AyS8WF5Bkjnb8s+guTDuYmUiwJzNrtchebBHpc0gz0PyHJNMaybUlSrmJjHVcGrya0LKI4YcR3lBDQfXRYLw==", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/logger": "3.0.1", - "@docusaurus/mdx-loader": "3.0.1", - "@docusaurus/module-type-aliases": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", - "@types/react-router-config": "^5.0.7", - "combine-promises": "^1.1.0", - "fs-extra": "^11.1.1", - "js-yaml": "^4.1.0", - "lodash": "^4.17.21", - "tslib": "^2.6.0", - "utility-types": "^3.10.0", - "webpack": "^5.88.1" + "@docusaurus/types": "3.4.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "@types/react-router-config": "*", + "@types/react-router-dom": "*", + "react-helmet-async": "*", + "react-loadable": "npm:@docusaurus/react-loadable@6.0.0" }, - "engines": { - "node": ">=18.0" + "peerDependencies": { + "react": "*", + "react-dom": "*" + } + }, + "node_modules/@docusaurus/theme-common/node_modules/@docusaurus/types": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", + "dependencies": { + "@mdx-js/mdx": "^3.0.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "commander": "^5.1.0", + "joi": "^17.9.2", + "react-helmet-async": "^1.3.0", + "utility-types": "^3.10.0", + "webpack": "^5.88.1", + "webpack-merge": "^5.9.0" }, "peerDependencies": { "react": "^18.0.0", "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/theme-mermaid": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/theme-mermaid/-/theme-mermaid-3.0.1.tgz", - "integrity": "sha512-jquSDnZfazABnC5i+02GzRIvufXKruKgvbYkQjKbI7/LWo0XvBs0uKAcCDGgHhth0t/ON5+Sn27joRfpeSk3Lw==", + "node_modules/@docusaurus/theme-common/node_modules/react-loadable": { + "name": "@docusaurus/react-loadable", + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/@docusaurus/react-loadable/-/react-loadable-6.0.0.tgz", + "integrity": "sha512-YMMxTUQV/QFSnbgrP3tjDzLHRg7vsbMn8e9HAa8o/1iXoiomo48b7sk/kkmWEuWNDPJVlKSJRB6Y2fHqdJk+SQ==", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/module-type-aliases": "3.0.1", - "@docusaurus/theme-common": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", + "@types/react": "*" + }, + "peerDependencies": { + "react": "*" + } + }, + "node_modules/@docusaurus/theme-mermaid": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/theme-mermaid/-/theme-mermaid-3.4.0.tgz", + "integrity": "sha512-3w5QW0HEZ2O6x2w6lU3ZvOe1gNXP2HIoKDMJBil1VmLBc9PmpAG17VmfhI/p3L2etNmOiVs5GgniUqvn8AFEGQ==", + "dependencies": { + "@docusaurus/core": "3.4.0", + "@docusaurus/module-type-aliases": "3.4.0", + "@docusaurus/theme-common": "3.4.0", + "@docusaurus/types": "3.4.0", + "@docusaurus/utils-validation": "3.4.0", "mermaid": "^10.4.0", "tslib": "^2.6.0" }, @@ -2995,56 +2988,77 @@ "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/theme-search-algolia": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/theme-search-algolia/-/theme-search-algolia-3.0.1.tgz", - "integrity": "sha512-DDiPc0/xmKSEdwFkXNf1/vH1SzJPzuJBar8kMcBbDAZk/SAmo/4lf6GU2drou4Ae60lN2waix+jYWTWcJRahSA==", + 
"node_modules/@docusaurus/theme-mermaid/node_modules/@docusaurus/module-type-aliases": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/module-type-aliases/-/module-type-aliases-3.4.0.tgz", + "integrity": "sha512-A1AyS8WF5Bkjnb8s+guTDuYmUiwJzNrtchebBHpc0gz0PyHJNMaybUlSrmJjHVcGrya0LKI4YcR3lBDQfXRYLw==", "dependencies": { - "@docsearch/react": "^3.5.2", - "@docusaurus/core": "3.0.1", - "@docusaurus/logger": "3.0.1", - "@docusaurus/plugin-content-docs": "3.0.1", - "@docusaurus/theme-common": "3.0.1", - "@docusaurus/theme-translations": "3.0.1", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", - "algoliasearch": "^4.18.0", - "algoliasearch-helper": "^3.13.3", - "clsx": "^2.0.0", - "eta": "^2.2.0", - "fs-extra": "^11.1.1", - "lodash": "^4.17.21", - "tslib": "^2.6.0", - "utility-types": "^3.10.0" + "@docusaurus/types": "3.4.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "@types/react-router-config": "*", + "@types/react-router-dom": "*", + "react-helmet-async": "*", + "react-loadable": "npm:@docusaurus/react-loadable@6.0.0" }, - "engines": { - "node": ">=18.0" + "peerDependencies": { + "react": "*", + "react-dom": "*" + } + }, + "node_modules/@docusaurus/theme-mermaid/node_modules/@docusaurus/types": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.4.0.tgz", + "integrity": "sha512-4jcDO8kXi5Cf9TcyikB/yKmz14f2RZ2qTRerbHAsS+5InE9ZgSLBNLsewtFTcTOXSVcbU3FoGOzcNWAmU1TR0A==", + "dependencies": { + "@mdx-js/mdx": "^3.0.0", + "@types/history": "^4.7.11", + "@types/react": "*", + "commander": "^5.1.0", + "joi": "^17.9.2", + "react-helmet-async": "^1.3.0", + "utility-types": "^3.10.0", + "webpack": "^5.88.1", + "webpack-merge": "^5.9.0" }, "peerDependencies": { "react": "^18.0.0", "react-dom": "^18.0.0" } }, - "node_modules/@docusaurus/theme-search-algolia/node_modules/@docusaurus/plugin-content-docs": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/plugin-content-docs/-/plugin-content-docs-3.0.1.tgz", - "integrity": "sha512-dRfAOA5Ivo+sdzzJGXEu33yAtvGg8dlZkvt/NEJ7nwi1F2j4LEdsxtfX2GKeETB2fP6XoGNSQnFXqa2NYGrHFg==", + "node_modules/@docusaurus/theme-mermaid/node_modules/react-loadable": { + "name": "@docusaurus/react-loadable", + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/@docusaurus/react-loadable/-/react-loadable-6.0.0.tgz", + "integrity": "sha512-YMMxTUQV/QFSnbgrP3tjDzLHRg7vsbMn8e9HAa8o/1iXoiomo48b7sk/kkmWEuWNDPJVlKSJRB6Y2fHqdJk+SQ==", "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/logger": "3.0.1", - "@docusaurus/mdx-loader": "3.0.1", - "@docusaurus/module-type-aliases": "3.0.1", - "@docusaurus/types": "3.0.1", - "@docusaurus/utils": "3.0.1", - "@docusaurus/utils-validation": "3.0.1", - "@types/react-router-config": "^5.0.7", - "combine-promises": "^1.1.0", + "@types/react": "*" + }, + "peerDependencies": { + "react": "*" + } + }, + "node_modules/@docusaurus/theme-search-algolia": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/theme-search-algolia/-/theme-search-algolia-3.4.0.tgz", + "integrity": "sha512-aiHFx7OCw4Wck1z6IoShVdUWIjntC8FHCw9c5dR8r3q4Ynh+zkS8y2eFFunN/DL6RXPzpnvKCg3vhLQYJDmT9Q==", + "dependencies": { + "@docsearch/react": "^3.5.2", + "@docusaurus/core": "3.4.0", + "@docusaurus/logger": "3.4.0", + "@docusaurus/plugin-content-docs": "3.4.0", + "@docusaurus/theme-common": "3.4.0", + "@docusaurus/theme-translations": "3.4.0", + "@docusaurus/utils": "3.4.0", + 
"@docusaurus/utils-validation": "3.4.0", + "algoliasearch": "^4.18.0", + "algoliasearch-helper": "^3.13.3", + "clsx": "^2.0.0", + "eta": "^2.2.0", "fs-extra": "^11.1.1", - "js-yaml": "^4.1.0", "lodash": "^4.17.21", "tslib": "^2.6.0", - "utility-types": "^3.10.0", - "webpack": "^5.88.1" + "utility-types": "^3.10.0" }, "engines": { "node": ">=18.0" @@ -3055,9 +3069,9 @@ } }, "node_modules/@docusaurus/theme-translations": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/theme-translations/-/theme-translations-3.0.1.tgz", - "integrity": "sha512-6UrbpzCTN6NIJnAtZ6Ne9492vmPVX+7Fsz4kmp+yor3KQwA1+MCzQP7ItDNkP38UmVLnvB/cYk/IvehCUqS3dg==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/theme-translations/-/theme-translations-3.4.0.tgz", + "integrity": "sha512-zSxCSpmQCCdQU5Q4CnX/ID8CSUUI3fvmq4hU/GNP/XoAWtXo9SAVnM3TzpU8Gb//H3WCsT8mJcTfyOk3d9ftNg==", "dependencies": { "fs-extra": "^11.1.1", "tslib": "^2.6.0" @@ -3070,6 +3084,7 @@ "version": "3.0.1", "resolved": "https://registry.npmjs.org/@docusaurus/types/-/types-3.0.1.tgz", "integrity": "sha512-plyX2iU1tcUsF46uQ01pAd4JhexR7n0iiQ5MSnBFX6M6NSJgDYdru/i1/YNPKOnQHBoXGLHv0dNT6OAlDWNjrg==", + "devOptional": true, "dependencies": { "@types/history": "^4.7.11", "@types/react": "*", @@ -3086,12 +3101,13 @@ } }, "node_modules/@docusaurus/utils": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/utils/-/utils-3.0.1.tgz", - "integrity": "sha512-TwZ33Am0q4IIbvjhUOs+zpjtD/mXNmLmEgeTGuRq01QzulLHuPhaBTTAC/DHu6kFx3wDgmgpAlaRuCHfTcXv8g==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/utils/-/utils-3.4.0.tgz", + "integrity": "sha512-fRwnu3L3nnWaXOgs88BVBmG1yGjcQqZNHG+vInhEa2Sz2oQB+ZjbEMO5Rh9ePFpZ0YDiDUhpaVjwmS+AU2F14g==", "dependencies": { - "@docusaurus/logger": "3.0.1", - "@svgr/webpack": "^6.5.1", + "@docusaurus/logger": "3.4.0", + "@docusaurus/utils-common": "3.4.0", + "@svgr/webpack": "^8.1.0", "escape-string-regexp": "^4.0.0", "file-loader": "^6.2.0", "fs-extra": "^11.1.1", @@ -3102,10 +3118,12 @@ "js-yaml": "^4.1.0", "lodash": "^4.17.21", "micromatch": "^4.0.5", + "prompts": "^2.4.2", "resolve-pathname": "^3.0.0", "shelljs": "^0.8.5", "tslib": "^2.6.0", "url-loader": "^4.1.1", + "utility-types": "^3.10.0", "webpack": "^5.88.1" }, "engines": { @@ -3121,9 +3139,9 @@ } }, "node_modules/@docusaurus/utils-common": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/utils-common/-/utils-common-3.0.1.tgz", - "integrity": "sha512-W0AxD6w6T8g6bNro8nBRWf7PeZ/nn7geEWM335qHU2DDDjHuV4UZjgUGP1AQsdcSikPrlIqTJJbKzer1lRSlIg==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/utils-common/-/utils-common-3.4.0.tgz", + "integrity": "sha512-NVx54Wr4rCEKsjOH5QEVvxIqVvm+9kh7q8aYTU5WzUU9/Hctd6aTrcZ3G0Id4zYJ+AeaG5K5qHA4CY5Kcm2iyQ==", "dependencies": { "tslib": "^2.6.0" }, @@ -3140,14 +3158,17 @@ } }, "node_modules/@docusaurus/utils-validation": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@docusaurus/utils-validation/-/utils-validation-3.0.1.tgz", - "integrity": "sha512-ujTnqSfyGQ7/4iZdB4RRuHKY/Nwm58IIb+41s5tCXOv/MBU2wGAjOHq3U+AEyJ8aKQcHbxvTKJaRchNHYUVUQg==", - "dependencies": { - "@docusaurus/logger": "3.0.1", - "@docusaurus/utils": "3.0.1", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@docusaurus/utils-validation/-/utils-validation-3.4.0.tgz", + "integrity": "sha512-hYQ9fM+AXYVTWxJOT1EuNaRnrR2WGpRdLDQG07O8UOpsvCPWUVOeo26Rbm0JWY2sGLfzAb+tvJ62yF+8F+TV0g==", + "dependencies": { + 
"@docusaurus/logger": "3.4.0", + "@docusaurus/utils": "3.4.0", + "@docusaurus/utils-common": "3.4.0", + "fs-extra": "^11.2.0", "joi": "^17.9.2", "js-yaml": "^4.1.0", + "lodash": "^4.17.21", "tslib": "^2.6.0" }, "engines": { @@ -3252,9 +3273,9 @@ "integrity": "sha512-Hcv+nVC0kZnQ3tD9GVu5xSMR4VVYOteQIr/hwFPVEvPdlXqgGEuRjiheChHgdM+JyqdgNcmzZOX/tnl0JOiI7A==" }, "node_modules/@mdx-js/mdx": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/@mdx-js/mdx/-/mdx-3.0.0.tgz", - "integrity": "sha512-Icm0TBKBLYqroYbNW3BPnzMGn+7mwpQOK310aZ7+fkCtiU3aqv2cdcX+nd0Ydo3wI5Rx8bX2Z2QmGb/XcAClCw==", + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/@mdx-js/mdx/-/mdx-3.0.1.tgz", + "integrity": "sha512-eIQ4QTrOWyL3LWEe/bu6Taqzq2HQvHcyTMaOrI95P2/LmJE7AsfPfgJGuFLPVqBUE1BC1rik3VIhU+s9u72arA==", "dependencies": { "@types/estree": "^1.0.0", "@types/estree-jsx": "^1.0.0", @@ -3419,25 +3440,12 @@ "micromark-util-symbol": "^1.0.1" } }, - "node_modules/@slorber/static-site-generator-webpack-plugin": { - "version": "4.0.7", - "resolved": "https://registry.npmjs.org/@slorber/static-site-generator-webpack-plugin/-/static-site-generator-webpack-plugin-4.0.7.tgz", - "integrity": "sha512-Ug7x6z5lwrz0WqdnNFOMYrDQNTPAprvHLSh6+/fmml3qUiz6l5eq+2MzLKWtn/q5K5NpSiFsZTP/fck/3vjSxA==", - "dependencies": { - "eval": "^0.1.8", - "p-map": "^4.0.0", - "webpack-sources": "^3.2.2" - }, - "engines": { - "node": ">=14" - } - }, "node_modules/@svgr/babel-plugin-add-jsx-attribute": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-add-jsx-attribute/-/babel-plugin-add-jsx-attribute-6.5.1.tgz", - "integrity": "sha512-9PYGcXrAxitycIjRmZB+Q0JaN07GZIWaTBIGQzfaZv+qr1n8X1XUEJ5rZ/vx6OVD9RRYlrNnXWExQXcmZeD/BQ==", + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-add-jsx-attribute/-/babel-plugin-add-jsx-attribute-8.0.0.tgz", + "integrity": "sha512-b9MIk7yhdS1pMCZM8VeNfUlSKVRhsHZNMl5O9SfaX0l0t5wjdgu4IDzGB8bpnGBBOjGST3rRFVsaaEtI4W6f7g==", "engines": { - "node": ">=10" + "node": ">=14" }, "funding": { "type": "github", @@ -3478,11 +3486,11 @@ } }, "node_modules/@svgr/babel-plugin-replace-jsx-attribute-value": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-replace-jsx-attribute-value/-/babel-plugin-replace-jsx-attribute-value-6.5.1.tgz", - "integrity": "sha512-8DPaVVE3fd5JKuIC29dqyMB54sA6mfgki2H2+swh+zNJoynC8pMPzOkidqHOSc6Wj032fhl8Z0TVn1GiPpAiJg==", + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-replace-jsx-attribute-value/-/babel-plugin-replace-jsx-attribute-value-8.0.0.tgz", + "integrity": "sha512-KVQ+PtIjb1BuYT3ht8M5KbzWBhdAjjUPdlMtpuw/VjT8coTrItWX6Qafl9+ji831JaJcu6PJNKCV0bp01lBNzQ==", "engines": { - "node": ">=10" + "node": ">=14" }, "funding": { "type": "github", @@ -3493,11 +3501,11 @@ } }, "node_modules/@svgr/babel-plugin-svg-dynamic-title": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-svg-dynamic-title/-/babel-plugin-svg-dynamic-title-6.5.1.tgz", - "integrity": "sha512-FwOEi0Il72iAzlkaHrlemVurgSQRDFbk0OC8dSvD5fSBPHltNh7JtLsxmZUhjYBZo2PpcU/RJvvi6Q0l7O7ogw==", + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-svg-dynamic-title/-/babel-plugin-svg-dynamic-title-8.0.0.tgz", + "integrity": "sha512-omNiKqwjNmOQJ2v6ge4SErBbkooV2aAWwaPFs2vUY7p7GhVkzRkJ00kILXQvRhA6miHnNpXv7MRnnSjdRjK8og==", "engines": { - "node": ">=10" + "node": ">=14" }, "funding": { "type": "github", @@ -3508,11 +3516,11 @@ } }, 
"node_modules/@svgr/babel-plugin-svg-em-dimensions": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-svg-em-dimensions/-/babel-plugin-svg-em-dimensions-6.5.1.tgz", - "integrity": "sha512-gWGsiwjb4tw+ITOJ86ndY/DZZ6cuXMNE/SjcDRg+HLuCmwpcjOktwRF9WgAiycTqJD/QXqL2f8IzE2Rzh7aVXA==", + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-svg-em-dimensions/-/babel-plugin-svg-em-dimensions-8.0.0.tgz", + "integrity": "sha512-mURHYnu6Iw3UBTbhGwE/vsngtCIbHE43xCRK7kCw4t01xyGqb2Pd+WXekRRoFOBIY29ZoOhUCTEweDMdrjfi9g==", "engines": { - "node": ">=10" + "node": ">=14" }, "funding": { "type": "github", @@ -3523,11 +3531,11 @@ } }, "node_modules/@svgr/babel-plugin-transform-react-native-svg": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-transform-react-native-svg/-/babel-plugin-transform-react-native-svg-6.5.1.tgz", - "integrity": "sha512-2jT3nTayyYP7kI6aGutkyfJ7UMGtuguD72OjeGLwVNyfPRBD8zQthlvL+fAbAKk5n9ZNcvFkp/b1lZ7VsYqVJg==", + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-transform-react-native-svg/-/babel-plugin-transform-react-native-svg-8.1.0.tgz", + "integrity": "sha512-Tx8T58CHo+7nwJ+EhUwx3LfdNSG9R2OKfaIXXs5soiy5HtgoAEkDay9LIimLOcG8dJQH1wPZp/cnAv6S9CrR1Q==", "engines": { - "node": ">=10" + "node": ">=14" }, "funding": { "type": "github", @@ -3538,9 +3546,9 @@ } }, "node_modules/@svgr/babel-plugin-transform-svg-component": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-transform-svg-component/-/babel-plugin-transform-svg-component-6.5.1.tgz", - "integrity": "sha512-a1p6LF5Jt33O3rZoVRBqdxL350oge54iZWHNI6LJB5tQ7EelvD/Mb1mfBiZNAan0dt4i3VArkFRjA4iObuNykQ==", + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/@svgr/babel-plugin-transform-svg-component/-/babel-plugin-transform-svg-component-8.0.0.tgz", + "integrity": "sha512-DFx8xa3cZXTdb/k3kfPeaixecQLgKh5NVBMwD0AQxOzcZawK4oo1Jh9LbrcACUivsCA7TLG8eeWgrDXjTMhRmw==", "engines": { "node": ">=12" }, @@ -3553,21 +3561,21 @@ } }, "node_modules/@svgr/babel-preset": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/babel-preset/-/babel-preset-6.5.1.tgz", - "integrity": "sha512-6127fvO/FF2oi5EzSQOAjo1LE3OtNVh11R+/8FXa+mHx1ptAaS4cknIjnUA7e6j6fwGGJ17NzaTJFUwOV2zwCw==", + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/@svgr/babel-preset/-/babel-preset-8.1.0.tgz", + "integrity": "sha512-7EYDbHE7MxHpv4sxvnVPngw5fuR6pw79SkcrILHJ/iMpuKySNCl5W1qcwPEpU+LgyRXOaAFgH0KhwD18wwg6ug==", "dependencies": { - "@svgr/babel-plugin-add-jsx-attribute": "^6.5.1", - "@svgr/babel-plugin-remove-jsx-attribute": "*", - "@svgr/babel-plugin-remove-jsx-empty-expression": "*", - "@svgr/babel-plugin-replace-jsx-attribute-value": "^6.5.1", - "@svgr/babel-plugin-svg-dynamic-title": "^6.5.1", - "@svgr/babel-plugin-svg-em-dimensions": "^6.5.1", - "@svgr/babel-plugin-transform-react-native-svg": "^6.5.1", - "@svgr/babel-plugin-transform-svg-component": "^6.5.1" + "@svgr/babel-plugin-add-jsx-attribute": "8.0.0", + "@svgr/babel-plugin-remove-jsx-attribute": "8.0.0", + "@svgr/babel-plugin-remove-jsx-empty-expression": "8.0.0", + "@svgr/babel-plugin-replace-jsx-attribute-value": "8.0.0", + "@svgr/babel-plugin-svg-dynamic-title": "8.0.0", + "@svgr/babel-plugin-svg-em-dimensions": "8.0.0", + "@svgr/babel-plugin-transform-react-native-svg": "8.1.0", + "@svgr/babel-plugin-transform-svg-component": "8.0.0" }, "engines": { - "node": ">=10" + "node": ">=14" }, "funding": { "type": "github", 
@@ -3578,18 +3586,18 @@ } }, "node_modules/@svgr/core": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/core/-/core-6.5.1.tgz", - "integrity": "sha512-/xdLSWxK5QkqG524ONSjvg3V/FkNyCv538OIBdQqPNaAta3AsXj/Bd2FbvR87yMbXO2hFSWiAe/Q6IkVPDw+mw==", + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/@svgr/core/-/core-8.1.0.tgz", + "integrity": "sha512-8QqtOQT5ACVlmsvKOJNEaWmRPmcojMOzCz4Hs2BGG/toAp/K38LcsMRyLp349glq5AzJbCEeimEoxaX6v/fLrA==", "dependencies": { - "@babel/core": "^7.19.6", - "@svgr/babel-preset": "^6.5.1", - "@svgr/plugin-jsx": "^6.5.1", + "@babel/core": "^7.21.3", + "@svgr/babel-preset": "8.1.0", "camelcase": "^6.2.0", - "cosmiconfig": "^7.0.1" + "cosmiconfig": "^8.1.3", + "snake-case": "^3.0.4" }, "engines": { - "node": ">=10" + "node": ">=14" }, "funding": { "type": "github", @@ -3597,15 +3605,15 @@ } }, "node_modules/@svgr/hast-util-to-babel-ast": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/hast-util-to-babel-ast/-/hast-util-to-babel-ast-6.5.1.tgz", - "integrity": "sha512-1hnUxxjd83EAxbL4a0JDJoD3Dao3hmjvyvyEV8PzWmLK3B9m9NPlW7GKjFyoWE8nM7HnXzPcmmSyOW8yOddSXw==", + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/@svgr/hast-util-to-babel-ast/-/hast-util-to-babel-ast-8.0.0.tgz", + "integrity": "sha512-EbDKwO9GpfWP4jN9sGdYwPBU0kdomaPIL2Eu4YwmgP+sJeXT+L7bMwJUBnhzfH8Q2qMBqZ4fJwpCyYsAN3mt2Q==", "dependencies": { - "@babel/types": "^7.20.0", + "@babel/types": "^7.21.3", "entities": "^4.4.0" }, "engines": { - "node": ">=10" + "node": ">=14" }, "funding": { "type": "github", @@ -3613,37 +3621,37 @@ } }, "node_modules/@svgr/plugin-jsx": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/plugin-jsx/-/plugin-jsx-6.5.1.tgz", - "integrity": "sha512-+UdQxI3jgtSjCykNSlEMuy1jSRQlGC7pqBCPvkG/2dATdWo082zHTTK3uhnAju2/6XpE6B5mZ3z4Z8Ns01S8Gw==", + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/@svgr/plugin-jsx/-/plugin-jsx-8.1.0.tgz", + "integrity": "sha512-0xiIyBsLlr8quN+WyuxooNW9RJ0Dpr8uOnH/xrCVO8GLUcwHISwj1AG0k+LFzteTkAA0GbX0kj9q6Dk70PTiPA==", "dependencies": { - "@babel/core": "^7.19.6", - "@svgr/babel-preset": "^6.5.1", - "@svgr/hast-util-to-babel-ast": "^6.5.1", + "@babel/core": "^7.21.3", + "@svgr/babel-preset": "8.1.0", + "@svgr/hast-util-to-babel-ast": "8.0.0", "svg-parser": "^2.0.4" }, "engines": { - "node": ">=10" + "node": ">=14" }, "funding": { "type": "github", "url": "https://github.com/sponsors/gregberge" }, "peerDependencies": { - "@svgr/core": "^6.0.0" + "@svgr/core": "*" } }, "node_modules/@svgr/plugin-svgo": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/plugin-svgo/-/plugin-svgo-6.5.1.tgz", - "integrity": "sha512-omvZKf8ixP9z6GWgwbtmP9qQMPX4ODXi+wzbVZgomNFsUIlHA1sf4fThdwTWSsZGgvGAG6yE+b/F5gWUkcZ/iQ==", + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/@svgr/plugin-svgo/-/plugin-svgo-8.1.0.tgz", + "integrity": "sha512-Ywtl837OGO9pTLIN/onoWLmDQ4zFUycI1g76vuKGEz6evR/ZTJlJuz3G/fIkb6OVBJ2g0o6CGJzaEjfmEo3AHA==", "dependencies": { - "cosmiconfig": "^7.0.1", - "deepmerge": "^4.2.2", - "svgo": "^2.8.0" + "cosmiconfig": "^8.1.3", + "deepmerge": "^4.3.1", + "svgo": "^3.0.2" }, "engines": { - "node": ">=10" + "node": ">=14" }, "funding": { "type": "github", @@ -3654,21 +3662,21 @@ } }, "node_modules/@svgr/webpack": { - "version": "6.5.1", - "resolved": "https://registry.npmjs.org/@svgr/webpack/-/webpack-6.5.1.tgz", - "integrity": "sha512-cQ/AsnBkXPkEK8cLbv4Dm7JGXq2XrumKnL1dRpJD9rIO2fTIlJI9a1uCciYG1F2aUsox/hJQyNGbt3soDxSRkA==", + "version": 
"8.1.0", + "resolved": "https://registry.npmjs.org/@svgr/webpack/-/webpack-8.1.0.tgz", + "integrity": "sha512-LnhVjMWyMQV9ZmeEy26maJk+8HTIbd59cH4F2MJ439k9DqejRisfFNGAPvRYlKETuh9LrImlS8aKsBgKjMA8WA==", "dependencies": { - "@babel/core": "^7.19.6", - "@babel/plugin-transform-react-constant-elements": "^7.18.12", - "@babel/preset-env": "^7.19.4", + "@babel/core": "^7.21.3", + "@babel/plugin-transform-react-constant-elements": "^7.21.3", + "@babel/preset-env": "^7.20.2", "@babel/preset-react": "^7.18.6", - "@babel/preset-typescript": "^7.18.6", - "@svgr/core": "^6.5.1", - "@svgr/plugin-jsx": "^6.5.1", - "@svgr/plugin-svgo": "^6.5.1" + "@babel/preset-typescript": "^7.21.0", + "@svgr/core": "8.1.0", + "@svgr/plugin-jsx": "8.1.0", + "@svgr/plugin-svgo": "8.1.0" }, "engines": { - "node": ">=10" + "node": ">=14" }, "funding": { "type": "github", @@ -3786,9 +3794,9 @@ "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==" }, "node_modules/@types/estree-jsx": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/@types/estree-jsx/-/estree-jsx-1.0.3.tgz", - "integrity": "sha512-pvQ+TKeRHeiUGRhvYwRrQ/ISnohKkSJR14fT2yqyZ4e9K5vqc7hrtY2Y1Dw0ZwAzQ6DQsxsaCUuSIIi8v0Cq6w==", + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@types/estree-jsx/-/estree-jsx-1.0.5.tgz", + "integrity": "sha512-52CcUVNFyfb1A2ALocQw/Dd1BQFNmSdkuC3BkZ6iqhdMfQz7JWOFRuJFloOzjk+6WijU56m9oKXFAXc7o3Towg==", "dependencies": { "@types/estree": "*" } @@ -3821,9 +3829,9 @@ "integrity": "sha512-YQV9bUsemkzG81Ea295/nF/5GijnD2Af7QhEofh7xu+kvCN6RdodgNwwGWXB5GMI3NoyvQo0odNctoH/qLMIpg==" }, "node_modules/@types/hast": { - "version": "3.0.3", - "resolved": "https://registry.npmjs.org/@types/hast/-/hast-3.0.3.tgz", - "integrity": "sha512-2fYGlaDy/qyLlhidX42wAH0KBi2TCjKMH8CHmBXgRlJ3Y+OXTiqsPQ6IWarZKwF1JoUcAJdPogv1d4b0COTpmQ==", + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/hast/-/hast-3.0.4.tgz", + "integrity": "sha512-WPs+bbQw5aCj+x6laNGWLH3wviHtoCv/P3+otBhbOhJgG8qtpdAMlTCxLtsTWA7LH1Oh/bFCHsBn0TPS5m30EQ==", "dependencies": { "@types/unist": "*" } @@ -3883,9 +3891,9 @@ "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==" }, "node_modules/@types/mdast": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-4.0.3.tgz", - "integrity": "sha512-LsjtqsyF+d2/yFOYaN22dHZI1Cpwkrj+g06G8+qtUKlhovPW89YhqSnfKtMbkgmEtYpH2gydRNULd6y8mciAFg==", + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-4.0.4.tgz", + "integrity": "sha512-kGaNbPh1k7AFzgpud/gMdvIm5xuECykRR+JnWKQno9TAXVa6WIVCGTPvYGekIDL4uwCZQSYbUxNBSb1aUo79oA==", "dependencies": { "@types/unist": "*" } @@ -4339,30 +4347,31 @@ } }, "node_modules/algoliasearch": { - "version": "4.22.0", - "resolved": "https://registry.npmjs.org/algoliasearch/-/algoliasearch-4.22.0.tgz", - "integrity": "sha512-gfceltjkwh7PxXwtkS8KVvdfK+TSNQAWUeNSxf4dA29qW5tf2EGwa8jkJujlT9jLm17cixMVoGNc+GJFO1Mxhg==", - "dependencies": { - "@algolia/cache-browser-local-storage": "4.22.0", - "@algolia/cache-common": "4.22.0", - "@algolia/cache-in-memory": "4.22.0", - "@algolia/client-account": "4.22.0", - "@algolia/client-analytics": "4.22.0", - "@algolia/client-common": "4.22.0", - "@algolia/client-personalization": "4.22.0", - "@algolia/client-search": "4.22.0", - "@algolia/logger-common": "4.22.0", - "@algolia/logger-console": "4.22.0", - "@algolia/requester-browser-xhr": "4.22.0", - "@algolia/requester-common": 
"4.22.0", - "@algolia/requester-node-http": "4.22.0", - "@algolia/transporter": "4.22.0" + "version": "4.23.3", + "resolved": "https://registry.npmjs.org/algoliasearch/-/algoliasearch-4.23.3.tgz", + "integrity": "sha512-Le/3YgNvjW9zxIQMRhUHuhiUjAlKY/zsdZpfq4dlLqg6mEm0nL6yk+7f2hDOtLpxsgE4jSzDmvHL7nXdBp5feg==", + "dependencies": { + "@algolia/cache-browser-local-storage": "4.23.3", + "@algolia/cache-common": "4.23.3", + "@algolia/cache-in-memory": "4.23.3", + "@algolia/client-account": "4.23.3", + "@algolia/client-analytics": "4.23.3", + "@algolia/client-common": "4.23.3", + "@algolia/client-personalization": "4.23.3", + "@algolia/client-search": "4.23.3", + "@algolia/logger-common": "4.23.3", + "@algolia/logger-console": "4.23.3", + "@algolia/recommend": "4.23.3", + "@algolia/requester-browser-xhr": "4.23.3", + "@algolia/requester-common": "4.23.3", + "@algolia/requester-node-http": "4.23.3", + "@algolia/transporter": "4.23.3" } }, "node_modules/algoliasearch-helper": { - "version": "3.16.1", - "resolved": "https://registry.npmjs.org/algoliasearch-helper/-/algoliasearch-helper-3.16.1.tgz", - "integrity": "sha512-qxAHVjjmT7USVvrM8q6gZGaJlCK1fl4APfdAA7o8O6iXEc68G0xMNrzRkxoB/HmhhvyHnoteS/iMTiHiTcQQcg==", + "version": "3.21.0", + "resolved": "https://registry.npmjs.org/algoliasearch-helper/-/algoliasearch-helper-3.21.0.tgz", + "integrity": "sha512-hjVOrL15I3Y3K8xG0icwG1/tWE+MocqBrhW6uVBWpU+/kVEMK0BnM2xdssj6mZM61eJ4iRxHR0djEI3ENOpR8w==", "dependencies": { "@algolia/events": "^4.0.1" }, @@ -4481,9 +4490,9 @@ } }, "node_modules/autoprefixer": { - "version": "10.4.16", - "resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.16.tgz", - "integrity": "sha512-7vd3UC6xKp0HLfua5IjZlcXvGAGy7cBAXTg2lyQ/8WpNhd6SiZ8Be+xm3FyBSYJx5GKcpRCzBh7RH4/0dnY+uQ==", + "version": "10.4.19", + "resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.19.tgz", + "integrity": "sha512-BaENR2+zBZ8xXhM4pUaKUxlVdxZ0EZhjvbopwnXmxRUfqDmwSpC2lAi/QXvx7NRdPCo1WKEcEF6mV64si1z4Ew==", "funding": [ { "type": "opencollective", @@ -4499,9 +4508,9 @@ } ], "dependencies": { - "browserslist": "^4.21.10", - "caniuse-lite": "^1.0.30001538", - "fraction.js": "^4.3.6", + "browserslist": "^4.23.0", + "caniuse-lite": "^1.0.30001599", + "fraction.js": "^4.3.7", "normalize-range": "^0.1.2", "picocolors": "^1.0.0", "postcss-value-parser": "^4.2.0" @@ -4710,20 +4719,20 @@ } }, "node_modules/braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dependencies": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" }, "engines": { "node": ">=8" } }, "node_modules/browserslist": { - "version": "4.22.2", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.22.2.tgz", - "integrity": "sha512-0UgcrvQmBDvZHFGdYUehrCNIazki7/lUP3kkoi/r3YB2amZbFM9J43ZRkJTXBUZK4gmx56+Sqk9+Vs9mwZx9+A==", + "version": "4.23.0", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.23.0.tgz", + "integrity": "sha512-QW8HiM1shhT2GuzkvklfjcKDiWFXHOeFCIA/huJPwHsslwcydgk7X+z2zXpEijP98UCY7HbubZt5J2Zgvf0CaQ==", "funding": [ { "type": "opencollective", @@ -4739,8 +4748,8 @@ } ], "dependencies": { - "caniuse-lite": "^1.0.30001565", - "electron-to-chromium": 
"^1.4.601", + "caniuse-lite": "^1.0.30001587", + "electron-to-chromium": "^1.4.668", "node-releases": "^2.0.14", "update-browserslist-db": "^1.0.13" }, @@ -4858,9 +4867,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001571", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001571.tgz", - "integrity": "sha512-tYq/6MoXhdezDLFZuCO/TKboTzuQ/xR5cFdgXPfDtM7/kchBO3b4VWghE/OAi/DV7tTdhmLjZiZBZi1fA/GheQ==", + "version": "1.0.30001629", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001629.tgz", + "integrity": "sha512-c3dl911slnQhmxUIT4HhYzT7wnBK/XYpGnYLOj4nJBaRiw52Ibe7YxlDaAeRECvA786zCuExhxIUJ2K7nHMrBw==", "funding": [ { "type": "opencollective", @@ -5448,18 +5457,28 @@ } }, "node_modules/cosmiconfig": { - "version": "7.1.0", - "resolved": "https://registry.npmjs.org/cosmiconfig/-/cosmiconfig-7.1.0.tgz", - "integrity": "sha512-AdmX6xUzdNASswsFtmwSt7Vj8po9IuqXm0UXz7QKPuEUmPB4XyjGfaAr2PSuELMwkRMVH1EpIkX5bTZGRB3eCA==", + "version": "8.3.6", + "resolved": "https://registry.npmjs.org/cosmiconfig/-/cosmiconfig-8.3.6.tgz", + "integrity": "sha512-kcZ6+W5QzcJ3P1Mt+83OUv/oHFqZHIx8DuxG6eZ5RGMERoLqp4BuGjhHLYGK+Kf5XVkQvqBSmAy/nGWN3qDgEA==", "dependencies": { - "@types/parse-json": "^4.0.0", - "import-fresh": "^3.2.1", - "parse-json": "^5.0.0", - "path-type": "^4.0.0", - "yaml": "^1.10.0" + "import-fresh": "^3.3.0", + "js-yaml": "^4.1.0", + "parse-json": "^5.2.0", + "path-type": "^4.0.0" }, "engines": { - "node": ">=10" + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/d-fischer" + }, + "peerDependencies": { + "typescript": ">=4.9.5" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } } }, "node_modules/cross-spawn": { @@ -5501,11 +5520,11 @@ } }, "node_modules/css-declaration-sorter": { - "version": "6.4.1", - "resolved": "https://registry.npmjs.org/css-declaration-sorter/-/css-declaration-sorter-6.4.1.tgz", - "integrity": "sha512-rtdthzxKuyq6IzqX6jEcIzQF/YqccluefyCYheovBOLhFT/drQA9zj/UbRAa9J7C0o6EG6u3E6g+vKkay7/k3g==", + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/css-declaration-sorter/-/css-declaration-sorter-7.2.0.tgz", + "integrity": "sha512-h70rUM+3PNFuaBDTLe8wF/cdWu+dOZmb7pJt8Z2sedYbAcQVQV/tEchueg3GWxwqS0cxtbxmaHEdkNACqcvsow==", "engines": { - "node": "^10 || ^12 || >=14" + "node": "^14 || ^16 || >=18" }, "peerDependencies": { "postcss": "^8.0.9" @@ -5537,16 +5556,16 @@ } }, "node_modules/css-minimizer-webpack-plugin": { - "version": "4.2.2", - "resolved": "https://registry.npmjs.org/css-minimizer-webpack-plugin/-/css-minimizer-webpack-plugin-4.2.2.tgz", - "integrity": "sha512-s3Of/4jKfw1Hj9CxEO1E5oXhQAxlayuHO2y/ML+C6I9sQ7FdzfEV6QgMLN3vI+qFsjJGIAFLKtQK7t8BOXAIyA==", + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/css-minimizer-webpack-plugin/-/css-minimizer-webpack-plugin-5.0.1.tgz", + "integrity": "sha512-3caImjKFQkS+ws1TGcFn0V1HyDJFq1Euy589JlD6/3rV2kj+w7r5G9WDMgSHvpvXHNZ2calVypZWuEDQd9wfLg==", "dependencies": { - "cssnano": "^5.1.8", - "jest-worker": "^29.1.2", - "postcss": "^8.4.17", - "schema-utils": "^4.0.0", - "serialize-javascript": "^6.0.0", - "source-map": "^0.6.1" + "@jridgewell/trace-mapping": "^0.3.18", + "cssnano": "^6.0.1", + "jest-worker": "^29.4.3", + "postcss": "^8.4.24", + "schema-utils": "^4.0.1", + "serialize-javascript": "^6.0.1" }, "engines": { "node": ">= 14.15.0" @@ -5579,14 +5598,6 @@ } } }, - "node_modules/css-minimizer-webpack-plugin/node_modules/source-map": { - "version": "0.6.1", - "resolved": 
"https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", - "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/css-select": { "version": "5.1.0", "resolved": "https://registry.npmjs.org/css-select/-/css-select-5.1.0.tgz", @@ -5603,23 +5614,15 @@ } }, "node_modules/css-tree": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/css-tree/-/css-tree-1.1.3.tgz", - "integrity": "sha512-tRpdppF7TRazZrjJ6v3stzv93qxRcSsFmW6cX0Zm2NVKpxE1WV1HblnghVv9TreireHkqI/VDEsfolRF1p6y7Q==", + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/css-tree/-/css-tree-2.3.1.tgz", + "integrity": "sha512-6Fv1DV/TYw//QF5IzQdqsNDjx/wc8TrMBZsqjL9eW01tWb7R7k/mq+/VXfJCl7SoD5emsJop9cOByJZfs8hYIw==", "dependencies": { - "mdn-data": "2.0.14", - "source-map": "^0.6.1" + "mdn-data": "2.0.30", + "source-map-js": "^1.0.1" }, "engines": { - "node": ">=8.0.0" - } - }, - "node_modules/css-tree/node_modules/source-map": { - "version": "0.6.1", - "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", - "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", - "engines": { - "node": ">=0.10.0" + "node": "^10 || ^12.20.0 || ^14.13.0 || >=15.0.0" } }, "node_modules/css-what": { @@ -5645,108 +5648,128 @@ } }, "node_modules/cssnano": { - "version": "5.1.15", - "resolved": "https://registry.npmjs.org/cssnano/-/cssnano-5.1.15.tgz", - "integrity": "sha512-j+BKgDcLDQA+eDifLx0EO4XSA56b7uut3BQFH+wbSaSTuGLuiyTa/wbRYthUXX8LC9mLg+WWKe8h+qJuwTAbHw==", + "version": "6.1.2", + "resolved": "https://registry.npmjs.org/cssnano/-/cssnano-6.1.2.tgz", + "integrity": "sha512-rYk5UeX7VAM/u0lNqewCdasdtPK81CgX8wJFLEIXHbV2oldWRgJAsZrdhRXkV1NJzA2g850KiFm9mMU2HxNxMA==", "dependencies": { - "cssnano-preset-default": "^5.2.14", - "lilconfig": "^2.0.3", - "yaml": "^1.10.2" + "cssnano-preset-default": "^6.1.2", + "lilconfig": "^3.1.1" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/cssnano" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/cssnano-preset-advanced": { - "version": "5.3.10", - "resolved": "https://registry.npmjs.org/cssnano-preset-advanced/-/cssnano-preset-advanced-5.3.10.tgz", - "integrity": "sha512-fnYJyCS9jgMU+cmHO1rPSPf9axbQyD7iUhLO5Df6O4G+fKIOMps+ZbU0PdGFejFBBZ3Pftf18fn1eG7MAPUSWQ==", + "version": "6.1.2", + "resolved": "https://registry.npmjs.org/cssnano-preset-advanced/-/cssnano-preset-advanced-6.1.2.tgz", + "integrity": "sha512-Nhao7eD8ph2DoHolEzQs5CfRpiEP0xa1HBdnFZ82kvqdmbwVBUr2r1QuQ4t1pi+D1ZpqpcO4T+wy/7RxzJ/WPQ==", "dependencies": { - "autoprefixer": "^10.4.12", - "cssnano-preset-default": "^5.2.14", - "postcss-discard-unused": "^5.1.0", - "postcss-merge-idents": "^5.1.1", - "postcss-reduce-idents": "^5.2.0", - "postcss-zindex": "^5.1.0" + "autoprefixer": "^10.4.19", + "browserslist": "^4.23.0", + "cssnano-preset-default": "^6.1.2", + "postcss-discard-unused": "^6.0.5", + "postcss-merge-idents": "^6.0.3", + "postcss-reduce-idents": "^6.0.3", + "postcss-zindex": "^6.0.2" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/cssnano-preset-default": { - "version": "5.2.14", - "resolved": 
"https://registry.npmjs.org/cssnano-preset-default/-/cssnano-preset-default-5.2.14.tgz", - "integrity": "sha512-t0SFesj/ZV2OTylqQVOrFgEh5uanxbO6ZAdeCrNsUQ6fVuXwYTxJPNAGvGTxHbD68ldIJNec7PyYZDBrfDQ+6A==", - "dependencies": { - "css-declaration-sorter": "^6.3.1", - "cssnano-utils": "^3.1.0", - "postcss-calc": "^8.2.3", - "postcss-colormin": "^5.3.1", - "postcss-convert-values": "^5.1.3", - "postcss-discard-comments": "^5.1.2", - "postcss-discard-duplicates": "^5.1.0", - "postcss-discard-empty": "^5.1.1", - "postcss-discard-overridden": "^5.1.0", - "postcss-merge-longhand": "^5.1.7", - "postcss-merge-rules": "^5.1.4", - "postcss-minify-font-values": "^5.1.0", - "postcss-minify-gradients": "^5.1.1", - "postcss-minify-params": "^5.1.4", - "postcss-minify-selectors": "^5.2.1", - "postcss-normalize-charset": "^5.1.0", - "postcss-normalize-display-values": "^5.1.0", - "postcss-normalize-positions": "^5.1.1", - "postcss-normalize-repeat-style": "^5.1.1", - "postcss-normalize-string": "^5.1.0", - "postcss-normalize-timing-functions": "^5.1.0", - "postcss-normalize-unicode": "^5.1.1", - "postcss-normalize-url": "^5.1.0", - "postcss-normalize-whitespace": "^5.1.1", - "postcss-ordered-values": "^5.1.3", - "postcss-reduce-initial": "^5.1.2", - "postcss-reduce-transforms": "^5.1.0", - "postcss-svgo": "^5.1.0", - "postcss-unique-selectors": "^5.1.1" - }, - "engines": { - "node": "^10 || ^12 || >=14.0" - }, - "peerDependencies": { - "postcss": "^8.2.15" + "version": "6.1.2", + "resolved": "https://registry.npmjs.org/cssnano-preset-default/-/cssnano-preset-default-6.1.2.tgz", + "integrity": "sha512-1C0C+eNaeN8OcHQa193aRgYexyJtU8XwbdieEjClw+J9d94E41LwT6ivKH0WT+fYwYWB0Zp3I3IZ7tI/BbUbrg==", + "dependencies": { + "browserslist": "^4.23.0", + "css-declaration-sorter": "^7.2.0", + "cssnano-utils": "^4.0.2", + "postcss-calc": "^9.0.1", + "postcss-colormin": "^6.1.0", + "postcss-convert-values": "^6.1.0", + "postcss-discard-comments": "^6.0.2", + "postcss-discard-duplicates": "^6.0.3", + "postcss-discard-empty": "^6.0.3", + "postcss-discard-overridden": "^6.0.2", + "postcss-merge-longhand": "^6.0.5", + "postcss-merge-rules": "^6.1.1", + "postcss-minify-font-values": "^6.1.0", + "postcss-minify-gradients": "^6.0.3", + "postcss-minify-params": "^6.1.0", + "postcss-minify-selectors": "^6.0.4", + "postcss-normalize-charset": "^6.0.2", + "postcss-normalize-display-values": "^6.0.2", + "postcss-normalize-positions": "^6.0.2", + "postcss-normalize-repeat-style": "^6.0.2", + "postcss-normalize-string": "^6.0.2", + "postcss-normalize-timing-functions": "^6.0.2", + "postcss-normalize-unicode": "^6.1.0", + "postcss-normalize-url": "^6.0.2", + "postcss-normalize-whitespace": "^6.0.2", + "postcss-ordered-values": "^6.0.2", + "postcss-reduce-initial": "^6.1.0", + "postcss-reduce-transforms": "^6.0.2", + "postcss-svgo": "^6.0.3", + "postcss-unique-selectors": "^6.0.4" + }, + "engines": { + "node": "^14 || ^16 || >=18.0" + }, + "peerDependencies": { + "postcss": "^8.4.31" } }, "node_modules/cssnano-utils": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/cssnano-utils/-/cssnano-utils-3.1.0.tgz", - "integrity": "sha512-JQNR19/YZhz4psLX/rQ9M83e3z2Wf/HdJbryzte4a3NSuafyp9w/I4U+hx5C2S9g41qlstH7DEWnZaaj83OuEA==", + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/cssnano-utils/-/cssnano-utils-4.0.2.tgz", + "integrity": "sha512-ZR1jHg+wZ8o4c3zqf1SIUSTIvm/9mU343FMR6Obe/unskbvpGhZOo1J6d/r8D1pzkRQYuwbcH3hToOuoA2G7oQ==", "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, 
"peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/csso": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/csso/-/csso-4.2.0.tgz", - "integrity": "sha512-wvlcdIbf6pwKEk7vHj8/Bkc0B4ylXZruLvOgs9doS5eOsOpuodOV2zJChSpkp+pRpYQLQMeF04nr3Z68Sta9jA==", + "version": "5.0.5", + "resolved": "https://registry.npmjs.org/csso/-/csso-5.0.5.tgz", + "integrity": "sha512-0LrrStPOdJj+SPCCrGhzryycLjwcgUSHBtxNA8aIDxf0GLsRh1cKYhB00Gd1lDOS4yGH69+SNn13+TWbVHETFQ==", "dependencies": { - "css-tree": "^1.1.2" + "css-tree": "~2.2.0" }, "engines": { - "node": ">=8.0.0" + "node": "^10 || ^12.20.0 || ^14.13.0 || >=15.0.0", + "npm": ">=7.0.0" + } + }, + "node_modules/csso/node_modules/css-tree": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/css-tree/-/css-tree-2.2.1.tgz", + "integrity": "sha512-OA0mILzGc1kCOCSJerOeqDxDQ4HOh+G8NbOJFOTgOCzpw7fCBubk0fEyxp8AgOL/jvLgYA/uV0cMbe43ElF1JA==", + "dependencies": { + "mdn-data": "2.0.28", + "source-map-js": "^1.0.1" + }, + "engines": { + "node": "^10 || ^12.20.0 || ^14.13.0 || >=15.0.0", + "npm": ">=7.0.0" } }, + "node_modules/csso/node_modules/mdn-data": { + "version": "2.0.28", + "resolved": "https://registry.npmjs.org/mdn-data/-/mdn-data-2.0.28.tgz", + "integrity": "sha512-aylIc7Z9y4yzHYAJNuESG3hfhC+0Ibp/MAMiaOZgNv4pmEdFyfZhhhny4MNiAfWdBQ1RQ2mfDWmM1x8SvGyp8g==" + }, "node_modules/csstype": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz", @@ -6629,9 +6652,9 @@ "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==" }, "node_modules/electron-to-chromium": { - "version": "1.4.616", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.616.tgz", - "integrity": "sha512-1n7zWYh8eS0L9Uy+GskE0lkBUNK83cXTVJI0pU3mGprFsbfSdAc15VTFbo+A+Bq4pwstmL30AVcEU3Fo463lNg==" + "version": "1.4.792", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.792.tgz", + "integrity": "sha512-rkg5/N3L+Y844JyfgPUyuKK0Hk0efo3JNxUDKvz3HgP6EmN4rNGhr2D8boLsfTV/hGo7ZGAL8djw+jlg99zQyA==" }, "node_modules/elkjs": { "version": "0.8.2", @@ -6865,16 +6888,13 @@ } }, "node_modules/estree-util-value-to-estree": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/estree-util-value-to-estree/-/estree-util-value-to-estree-3.0.1.tgz", - "integrity": "sha512-b2tdzTurEIbwRh+mKrEcaWfu1wgb8J1hVsgREg7FFiecWwK/PhO8X0kyc+0bIcKNtD4sqxIdNoRy6/p/TvECEA==", + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/estree-util-value-to-estree/-/estree-util-value-to-estree-3.1.1.tgz", + "integrity": "sha512-5mvUrF2suuv5f5cGDnDphIy4/gW86z82kl5qG6mM9z04SEQI4FB5Apmaw/TGEf3l55nLtMs5s51dmhUzvAHQCA==", "dependencies": { "@types/estree": "^1.0.0", "is-plain-obj": "^4.0.0" }, - "engines": { - "node": ">=16.0.0" - }, "funding": { "url": "https://github.com/sponsors/remcohaszing" } @@ -7221,9 +7241,9 @@ } }, "node_modules/fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dependencies": { "to-regex-range": "^5.0.1" }, @@ -7891,9 +7911,9 @@ } }, "node_modules/hast-util-raw": { - "version": "9.0.1", - 
"resolved": "https://registry.npmjs.org/hast-util-raw/-/hast-util-raw-9.0.1.tgz", - "integrity": "sha512-5m1gmba658Q+lO5uqL5YNGQWeh1MYWZbZmWrM5lncdcuiXuo5E2HT/CIOp0rLF8ksfSwiCVJ3twlgVRyTGThGA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/hast-util-raw/-/hast-util-raw-9.0.3.tgz", + "integrity": "sha512-ICWvVOF2fq4+7CMmtCPD5CM4QKjPbHpPotE6+8tDooV0ZuyJVUzHsrNX+O5NaRbieTf0F7FfeBOMAwi6Td0+yQ==", "dependencies": { "@types/hast": "^3.0.0", "@types/unist": "^3.0.0", @@ -7968,16 +7988,16 @@ } }, "node_modules/hast-util-to-jsx-runtime/node_modules/inline-style-parser": { - "version": "0.2.2", - "resolved": "https://registry.npmjs.org/inline-style-parser/-/inline-style-parser-0.2.2.tgz", - "integrity": "sha512-EcKzdTHVe8wFVOGEYXiW9WmJXPjqi1T+234YpJr98RiFYKHV3cdy1+3mkTE+KHTHxFFLH51SfaGOoUdW+v7ViQ==" + "version": "0.2.3", + "resolved": "https://registry.npmjs.org/inline-style-parser/-/inline-style-parser-0.2.3.tgz", + "integrity": "sha512-qlD8YNDqyTKTyuITrDOffsl6Tdhv+UC4hcdAVuQsK4IMQ99nSgd1MIA/Q+jQYoh9r3hVUXhYh7urSRmXPkW04g==" }, "node_modules/hast-util-to-jsx-runtime/node_modules/style-to-object": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/style-to-object/-/style-to-object-1.0.5.tgz", - "integrity": "sha512-rDRwHtoDD3UMMrmZ6BzOW0naTjMsVZLIjsGleSKS/0Oz+cgCfAPRspaqJuE8rDzpKha/nEvnM0IF4seEAZUTKQ==", + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/style-to-object/-/style-to-object-1.0.6.tgz", + "integrity": "sha512-khxq+Qm3xEyZfKd/y9L3oIWQimxuc4STrQKtQn8aSDRHb8mFgpukgX1hdzfrMEW6JCjyJ8p89x+IUMVnCBI1PA==", "dependencies": { - "inline-style-parser": "0.2.2" + "inline-style-parser": "0.2.3" } }, "node_modules/hast-util-to-parse5": { @@ -8376,9 +8396,9 @@ } }, "node_modules/image-size": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.0.2.tgz", - "integrity": "sha512-xfOoWjceHntRb3qFCrh5ZFORYH8XCdYpASltMhZ/Q0KZiOwjdE/Yl2QCiWdwD+lygV5bMCvauzgu5PxBX/Yerg==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.1.1.tgz", + "integrity": "sha512-541xKlUw6jr/6gGuk92F+mYM5zaFAc5ahphvkqvNe2bQ6gVBkd6bfrmVJ2t4KDAfikAYZyIqTnktX3i6/aQDrQ==", "dependencies": { "queue": "6.0.2" }, @@ -8386,7 +8406,7 @@ "image-size": "bin/image-size.js" }, "engines": { - "node": ">=14.0.0" + "node": ">=16.x" } }, "node_modules/immer": { @@ -8976,11 +8996,14 @@ } }, "node_modules/lilconfig": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-2.1.0.tgz", - "integrity": "sha512-utWOt/GHzuUxnLKxB6dk81RoOeoNeHgbrXiuGk4yyF5qlRz+iIVWu56E2fqGHFrXz0QNUhLB/8nKqvRH66JKGQ==", + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.1.1.tgz", + "integrity": "sha512-O18pf7nyvHTckunPWCV1XUNXU1piu01y2b7ATJ0ppkUkk8ocqVWBrYjJBCwHDjD/ZWcfyrA0P4gKhzWGi5EINQ==", "engines": { - "node": ">=10" + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/antonk52" } }, "node_modules/lines-and-columns": { @@ -9161,9 +9184,9 @@ } }, "node_modules/mdast-util-from-markdown": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/mdast-util-from-markdown/-/mdast-util-from-markdown-2.0.0.tgz", - "integrity": "sha512-n7MTOr/z+8NAX/wmhhDji8O3bRvPTV/U0oTCaZJkjhPSKTPhS3xufVhKGF8s1pJ7Ox4QgoIU7KHseh09S+9rTA==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/mdast-util-from-markdown/-/mdast-util-from-markdown-2.0.1.tgz", + "integrity": 
"sha512-aJEUyzZ6TzlsX2s5B4Of7lN7EQtAxvtradMMglCQDyaTFgse6CmtmdJ15ElnVRlCg1vpNyVtbem0PWzlNieZsA==", "dependencies": { "@types/mdast": "^4.0.0", "@types/unist": "^3.0.0", @@ -9261,9 +9284,9 @@ } }, "node_modules/mdast-util-gfm-autolink-literal/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -9389,9 +9412,9 @@ } }, "node_modules/mdast-util-mdx-jsx": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/mdast-util-mdx-jsx/-/mdast-util-mdx-jsx-3.0.0.tgz", - "integrity": "sha512-XZuPPzQNBPAlaqsTTgRrcJnyFbSOBovSadFgbFu8SnuNgm+6Bdx1K+IWoitsmj6Lq6MNtI+ytOqwN70n//NaBA==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/mdast-util-mdx-jsx/-/mdast-util-mdx-jsx-3.1.2.tgz", + "integrity": "sha512-eKMQDeywY2wlHc97k5eD8VC+9ASMjN8ItEZQNGwJ6E0XWKiW/Z0V5/H8pvoXUf+y+Mj0VIgeRRbujBmFn4FTyA==", "dependencies": { "@types/estree-jsx": "^1.0.0", "@types/hast": "^3.0.0", @@ -9430,9 +9453,9 @@ } }, "node_modules/mdast-util-phrasing": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/mdast-util-phrasing/-/mdast-util-phrasing-4.0.0.tgz", - "integrity": "sha512-xadSsJayQIucJ9n053dfQwVu1kuXg7jCTdYsMK8rqzKZh52nLfSH/k0sAxE0u+pj/zKZX+o5wB+ML5mRayOxFA==", + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/mdast-util-phrasing/-/mdast-util-phrasing-4.1.0.tgz", + "integrity": "sha512-TqICwyvJJpBwvGAMZjj4J2n0X8QWp21b9l0o7eXyVJ25YNWYbJDVIyD1bZXE6WtV6RmKJVYmQAKWa0zWOABz2w==", "dependencies": { "@types/mdast": "^4.0.0", "unist-util-is": "^6.0.0" @@ -9443,9 +9466,9 @@ } }, "node_modules/mdast-util-to-hast": { - "version": "13.0.2", - "resolved": "https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-13.0.2.tgz", - "integrity": "sha512-U5I+500EOOw9e3ZrclN3Is3fRpw8c19SMyNZlZ2IS+7vLsNzb2Om11VpIVOR+/0137GhZsFEF6YiKD5+0Hr2Og==", + "version": "13.1.0", + "resolved": "https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-13.1.0.tgz", + "integrity": "sha512-/e2l/6+OdGp/FB+ctrJ9Avz71AN/GRH3oi/3KAx/kMnoUsD6q0woXlDT8lLEeViVKE7oZxE7RXzvO3T8kF2/sA==", "dependencies": { "@types/hast": "^3.0.0", "@types/mdast": "^4.0.0", @@ -9454,7 +9477,8 @@ "micromark-util-sanitize-uri": "^2.0.0", "trim-lines": "^3.0.0", "unist-util-position": "^5.0.0", - "unist-util-visit": "^5.0.0" + "unist-util-visit": "^5.0.0", + "vfile": "^6.0.0" }, "funding": { "type": "opencollective", @@ -9493,9 +9517,9 @@ } }, "node_modules/mdn-data": { - "version": "2.0.14", - "resolved": "https://registry.npmjs.org/mdn-data/-/mdn-data-2.0.14.tgz", - "integrity": "sha512-dn6wd0uw5GsdswPFfsgMp5NSB0/aDe6fK94YJV/AJDYXL6HVLWBsxeq7js7Ad+mU2K9LAlwpk6kN2D5mwCPVow==" + "version": "2.0.30", + "resolved": "https://registry.npmjs.org/mdn-data/-/mdn-data-2.0.30.tgz", + "integrity": "sha512-GaqWWShW4kv/G9IEucWScBx9G1/vsFZZJUO+tD26M8J8z3Kw5RDQjaoZe03YAClgeS/SWPOcb4nkFBTEi5DUEA==" }, "node_modules/media-typer": { "version": "0.3.0", @@ -10044,9 +10068,9 @@ } }, "node_modules/micromark-core-commonmark": { - "version": "2.0.0", - "resolved": 
"https://registry.npmjs.org/micromark-core-commonmark/-/micromark-core-commonmark-2.0.0.tgz", - "integrity": "sha512-jThOz/pVmAYUtkroV3D5c1osFXAMv9e0ypGDOIZuCeAe91/sD6BoE2Sjzt30yuXtwOYUmySOhMas/PVyh02itA==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-core-commonmark/-/micromark-core-commonmark-2.0.1.tgz", + "integrity": "sha512-CUQyKr1e///ZODyD1U3xit6zXwy1a8q2a1S1HKtIlmgvurrEpaw/Y9y6KSIbF8P59cn/NjzHyO+Q2fAyYLQrAA==", "funding": [ { "type": "GitHub Sponsors", @@ -10096,9 +10120,9 @@ } }, "node_modules/micromark-core-commonmark/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10167,9 +10191,9 @@ } }, "node_modules/micromark-extension-directive/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10216,9 +10240,9 @@ } }, "node_modules/micromark-extension-frontmatter/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10284,9 +10308,9 @@ } }, "node_modules/micromark-extension-gfm-autolink-literal/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10356,9 +10380,9 @@ } }, "node_modules/micromark-extension-gfm-footnote/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", 
+ "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10457,9 +10481,9 @@ } }, "node_modules/micromark-extension-gfm-table/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10538,9 +10562,9 @@ } }, "node_modules/micromark-extension-gfm-task-list-item/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10616,9 +10640,9 @@ } }, "node_modules/micromark-extension-mdx-expression/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10690,9 +10714,9 @@ } }, "node_modules/micromark-extension-mdx-jsx/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10775,9 +10799,9 @@ } }, "node_modules/micromark-extension-mdxjs-esm/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10829,9 +10853,9 @@ } }, "node_modules/micromark-factory-destination/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": 
"https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10884,9 +10908,9 @@ } }, "node_modules/micromark-factory-label/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -10943,9 +10967,9 @@ } }, "node_modules/micromark-factory-mdx-expression/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -11051,9 +11075,9 @@ } }, "node_modules/micromark-factory-title/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -11125,9 +11149,9 @@ } }, "node_modules/micromark-factory-whitespace/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -11246,9 +11270,9 @@ } }, "node_modules/micromark-util-classify-character/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": 
"sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -11353,9 +11377,9 @@ } }, "node_modules/micromark-util-decode-string/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -11528,9 +11552,9 @@ } }, "node_modules/micromark-util-sanitize-uri/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -11562,9 +11586,9 @@ ] }, "node_modules/micromark-util-subtokenize": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/micromark-util-subtokenize/-/micromark-util-subtokenize-2.0.0.tgz", - "integrity": "sha512-vc93L1t+gpR3p8jxeVdaYlbV2jTYteDje19rNSS/H5dlhxUYll5Fy6vJ2cDwP8RnsXi818yGty1ayP55y3W6fg==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-subtokenize/-/micromark-util-subtokenize-2.0.1.tgz", + "integrity": "sha512-jZNtiFl/1aY73yS3UGQkutD0UbhTt68qnRpw2Pifmz5wV9h8gOVsN70v+Lq/f1rKaU/W8pxRe8y8Q9FX1AOe1Q==", "funding": [ { "type": "GitHub Sponsors", @@ -11647,9 +11671,9 @@ } }, "node_modules/micromark/node_modules/micromark-util-character": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.0.1.tgz", - "integrity": "sha512-3wgnrmEAJ4T+mGXAUfMvMAbxU9RDG43XmGce4j6CwPtVxB3vfwXSZ6KhFwDzZ3mZHhmPimMAXg71veiBGzeAZw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -11903,17 +11927,6 @@ "node": ">=0.10.0" } }, - "node_modules/normalize-url": { - "version": "6.1.0", - "resolved": "https://registry.npmjs.org/normalize-url/-/normalize-url-6.1.0.tgz", - "integrity": "sha512-DlL+XwOy3NxAQ8xuC0okPgK46iuVNAK01YN7RueYBqqFeGsBjV9XmCAzAdgt+667bCl5kPh9EqKKDwnaPG1I7A==", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/npm-run-path": { "version": "4.0.1", "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-4.0.1.tgz", @@ -12403,9 +12416,9 @@ } }, "node_modules/postcss": { - "version": "8.4.32", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.32.tgz", - "integrity": "sha512-D/kj5JNu6oo2EIy+XL/26JEDTlIbB8hw85G8StOE6L74RQAVVP5rej6wxCNqyMbR4RkPfqvezVbPw81Ngd6Kcw==", + "version": "8.4.38", + "resolved": 
"https://registry.npmjs.org/postcss/-/postcss-8.4.38.tgz", + "integrity": "sha512-Wglpdk03BSfXkHoQa3b/oulrotAkwrlLDRSOb9D0bN86FdRyE9lppSp33aHNPgBa0JKCoB+drFLZkQoRRYae5A==", "funding": [ { "type": "opencollective", @@ -12423,112 +12436,115 @@ "dependencies": { "nanoid": "^3.3.7", "picocolors": "^1.0.0", - "source-map-js": "^1.0.2" + "source-map-js": "^1.2.0" }, "engines": { "node": "^10 || ^12 || >=14" } }, "node_modules/postcss-calc": { - "version": "8.2.4", - "resolved": "https://registry.npmjs.org/postcss-calc/-/postcss-calc-8.2.4.tgz", - "integrity": "sha512-SmWMSJmB8MRnnULldx0lQIyhSNvuDl9HfrZkaqqE/WHAhToYsAvDq+yAsA/kIyINDszOp3Rh0GFoNuH5Ypsm3Q==", + "version": "9.0.1", + "resolved": "https://registry.npmjs.org/postcss-calc/-/postcss-calc-9.0.1.tgz", + "integrity": "sha512-TipgjGyzP5QzEhsOZUaIkeO5mKeMFpebWzRogWG/ysonUlnHcq5aJe0jOjpfzUU8PeSaBQnrE8ehR0QA5vs8PQ==", "dependencies": { - "postcss-selector-parser": "^6.0.9", + "postcss-selector-parser": "^6.0.11", "postcss-value-parser": "^4.2.0" }, + "engines": { + "node": "^14 || ^16 || >=18.0" + }, "peerDependencies": { "postcss": "^8.2.2" } }, "node_modules/postcss-colormin": { - "version": "5.3.1", - "resolved": "https://registry.npmjs.org/postcss-colormin/-/postcss-colormin-5.3.1.tgz", - "integrity": "sha512-UsWQG0AqTFQmpBegeLLc1+c3jIqBNB0zlDGRWR+dQ3pRKJL1oeMzyqmH3o2PIfn9MBdNrVPWhDbT769LxCTLJQ==", + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/postcss-colormin/-/postcss-colormin-6.1.0.tgz", + "integrity": "sha512-x9yX7DOxeMAR+BgGVnNSAxmAj98NX/YxEMNFP+SDCEeNLb2r3i6Hh1ksMsnW8Ub5SLCpbescQqn9YEbE9554Sw==", "dependencies": { - "browserslist": "^4.21.4", + "browserslist": "^4.23.0", "caniuse-api": "^3.0.0", - "colord": "^2.9.1", + "colord": "^2.9.3", "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-convert-values": { - "version": "5.1.3", - "resolved": "https://registry.npmjs.org/postcss-convert-values/-/postcss-convert-values-5.1.3.tgz", - "integrity": "sha512-82pC1xkJZtcJEfiLw6UXnXVXScgtBrjlO5CBmuDQc+dlb88ZYheFsjTn40+zBVi3DkfF7iezO0nJUPLcJK3pvA==", + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/postcss-convert-values/-/postcss-convert-values-6.1.0.tgz", + "integrity": "sha512-zx8IwP/ts9WvUM6NkVSkiU902QZL1bwPhaVaLynPtCsOTqp+ZKbNi+s6XJg3rfqpKGA/oc7Oxk5t8pOQJcwl/w==", "dependencies": { - "browserslist": "^4.21.4", + "browserslist": "^4.23.0", "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-discard-comments": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/postcss-discard-comments/-/postcss-discard-comments-5.1.2.tgz", - "integrity": "sha512-+L8208OVbHVF2UQf1iDmRcbdjJkuBF6IS29yBDSiWUIzpYaAhtNl6JYnYm12FnkeCwQqF5LeklOu6rAqgfBZqQ==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-discard-comments/-/postcss-discard-comments-6.0.2.tgz", + "integrity": "sha512-65w/uIqhSBBfQmYnG92FO1mWZjJ4GL5b8atm5Yw2UgrwD7HiNiSSNwJor1eCFGzUgYnN/iIknhNRVqjrrpuglw==", "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-discard-duplicates": { - "version": "5.1.0", - "resolved": 
"https://registry.npmjs.org/postcss-discard-duplicates/-/postcss-discard-duplicates-5.1.0.tgz", - "integrity": "sha512-zmX3IoSI2aoenxHV6C7plngHWWhUOV3sP1T8y2ifzxzbtnuhk1EdPwm0S1bIUNaJ2eNbWeGLEwzw8huPD67aQw==", + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/postcss-discard-duplicates/-/postcss-discard-duplicates-6.0.3.tgz", + "integrity": "sha512-+JA0DCvc5XvFAxwx6f/e68gQu/7Z9ud584VLmcgto28eB8FqSFZwtrLwB5Kcp70eIoWP/HXqz4wpo8rD8gpsTw==", "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-discard-empty": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/postcss-discard-empty/-/postcss-discard-empty-5.1.1.tgz", - "integrity": "sha512-zPz4WljiSuLWsI0ir4Mcnr4qQQ5e1Ukc3i7UfE2XcrwKK2LIPIqE5jxMRxO6GbI3cv//ztXDsXwEWT3BHOGh3A==", + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/postcss-discard-empty/-/postcss-discard-empty-6.0.3.tgz", + "integrity": "sha512-znyno9cHKQsK6PtxL5D19Fj9uwSzC2mB74cpT66fhgOadEUPyXFkbgwm5tvc3bt3NAy8ltE5MrghxovZRVnOjQ==", "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-discard-overridden": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/postcss-discard-overridden/-/postcss-discard-overridden-5.1.0.tgz", - "integrity": "sha512-21nOL7RqWR1kasIVdKs8HNqQJhFxLsyRfAnUDm4Fe4t4mCWL9OJiHvlHPjcd8zc5Myu89b/7wZDnOSjFgeWRtw==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-discard-overridden/-/postcss-discard-overridden-6.0.2.tgz", + "integrity": "sha512-j87xzI4LUggC5zND7KdjsI25APtyMuynXZSujByMaav2roV6OZX+8AaCUcZSWqckZpjAjRyFDdpqybgjFO0HJQ==", "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-discard-unused": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/postcss-discard-unused/-/postcss-discard-unused-5.1.0.tgz", - "integrity": "sha512-KwLWymI9hbwXmJa0dkrzpRbSJEh0vVUd7r8t0yOGPcfKzyJJxFM8kLyC5Ev9avji6nY95pOp1W6HqIrfT+0VGw==", + "version": "6.0.5", + "resolved": "https://registry.npmjs.org/postcss-discard-unused/-/postcss-discard-unused-6.0.5.tgz", + "integrity": "sha512-wHalBlRHkaNnNwfC8z+ppX57VhvS+HWgjW508esjdaEYr3Mx7Gnn2xA4R/CKf5+Z9S5qsqC+Uzh4ueENWwCVUA==", "dependencies": { - "postcss-selector-parser": "^6.0.5" + "postcss-selector-parser": "^6.0.16" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-loader": { @@ -12552,136 +12568,111 @@ "webpack": "^5.0.0" } }, - "node_modules/postcss-loader/node_modules/cosmiconfig": { - "version": "8.3.6", - "resolved": "https://registry.npmjs.org/cosmiconfig/-/cosmiconfig-8.3.6.tgz", - "integrity": "sha512-kcZ6+W5QzcJ3P1Mt+83OUv/oHFqZHIx8DuxG6eZ5RGMERoLqp4BuGjhHLYGK+Kf5XVkQvqBSmAy/nGWN3qDgEA==", - "dependencies": { - "import-fresh": "^3.3.0", - "js-yaml": "^4.1.0", - "parse-json": "^5.2.0", - "path-type": "^4.0.0" - }, - "engines": { - "node": ">=14" - }, - "funding": { - "url": "https://github.com/sponsors/d-fischer" - }, - "peerDependencies": { - "typescript": ">=4.9.5" - }, - "peerDependenciesMeta": { - "typescript": { - "optional": true - } - } - }, "node_modules/postcss-merge-idents": { - "version": "5.1.1", - "resolved": 
"https://registry.npmjs.org/postcss-merge-idents/-/postcss-merge-idents-5.1.1.tgz", - "integrity": "sha512-pCijL1TREiCoog5nQp7wUe+TUonA2tC2sQ54UGeMmryK3UFGIYKqDyjnqd6RcuI4znFn9hWSLNN8xKE/vWcUQw==", + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/postcss-merge-idents/-/postcss-merge-idents-6.0.3.tgz", + "integrity": "sha512-1oIoAsODUs6IHQZkLQGO15uGEbK3EAl5wi9SS8hs45VgsxQfMnxvt+L+zIr7ifZFIH14cfAeVe2uCTa+SPRa3g==", "dependencies": { - "cssnano-utils": "^3.1.0", + "cssnano-utils": "^4.0.2", "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-merge-longhand": { - "version": "5.1.7", - "resolved": "https://registry.npmjs.org/postcss-merge-longhand/-/postcss-merge-longhand-5.1.7.tgz", - "integrity": "sha512-YCI9gZB+PLNskrK0BB3/2OzPnGhPkBEwmwhfYk1ilBHYVAZB7/tkTHFBAnCrvBBOmeYyMYw3DMjT55SyxMBzjQ==", + "version": "6.0.5", + "resolved": "https://registry.npmjs.org/postcss-merge-longhand/-/postcss-merge-longhand-6.0.5.tgz", + "integrity": "sha512-5LOiordeTfi64QhICp07nzzuTDjNSO8g5Ksdibt44d+uvIIAE1oZdRn8y/W5ZtYgRH/lnLDlvi9F8btZcVzu3w==", "dependencies": { "postcss-value-parser": "^4.2.0", - "stylehacks": "^5.1.1" + "stylehacks": "^6.1.1" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-merge-rules": { - "version": "5.1.4", - "resolved": "https://registry.npmjs.org/postcss-merge-rules/-/postcss-merge-rules-5.1.4.tgz", - "integrity": "sha512-0R2IuYpgU93y9lhVbO/OylTtKMVcHb67zjWIfCiKR9rWL3GUk1677LAqD/BcHizukdZEjT8Ru3oHRoAYoJy44g==", + "version": "6.1.1", + "resolved": "https://registry.npmjs.org/postcss-merge-rules/-/postcss-merge-rules-6.1.1.tgz", + "integrity": "sha512-KOdWF0gju31AQPZiD+2Ar9Qjowz1LTChSjFFbS+e2sFgc4uHOp3ZvVX4sNeTlk0w2O31ecFGgrFzhO0RSWbWwQ==", "dependencies": { - "browserslist": "^4.21.4", + "browserslist": "^4.23.0", "caniuse-api": "^3.0.0", - "cssnano-utils": "^3.1.0", - "postcss-selector-parser": "^6.0.5" + "cssnano-utils": "^4.0.2", + "postcss-selector-parser": "^6.0.16" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-minify-font-values": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/postcss-minify-font-values/-/postcss-minify-font-values-5.1.0.tgz", - "integrity": "sha512-el3mYTgx13ZAPPirSVsHqFzl+BBBDrXvbySvPGFnQcTI4iNslrPaFq4muTkLZmKlGk4gyFAYUBMH30+HurREyA==", + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/postcss-minify-font-values/-/postcss-minify-font-values-6.1.0.tgz", + "integrity": "sha512-gklfI/n+9rTh8nYaSJXlCo3nOKqMNkxuGpTn/Qm0gstL3ywTr9/WRKznE+oy6fvfolH6dF+QM4nCo8yPLdvGJg==", "dependencies": { "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-minify-gradients": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/postcss-minify-gradients/-/postcss-minify-gradients-5.1.1.tgz", - "integrity": "sha512-VGvXMTpCEo4qHTNSa9A0a3D+dxGFZCYwR6Jokk+/3oB6flu2/PnPXAh2x7x52EkY5xlIHLm+Le8tJxe/7TNhzw==", + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/postcss-minify-gradients/-/postcss-minify-gradients-6.0.3.tgz", + "integrity": 
"sha512-4KXAHrYlzF0Rr7uc4VrfwDJ2ajrtNEpNEuLxFgwkhFZ56/7gaE4Nr49nLsQDZyUe+ds+kEhf+YAUolJiYXF8+Q==", "dependencies": { - "colord": "^2.9.1", - "cssnano-utils": "^3.1.0", + "colord": "^2.9.3", + "cssnano-utils": "^4.0.2", "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-minify-params": { - "version": "5.1.4", - "resolved": "https://registry.npmjs.org/postcss-minify-params/-/postcss-minify-params-5.1.4.tgz", - "integrity": "sha512-+mePA3MgdmVmv6g+30rn57USjOGSAyuxUmkfiWpzalZ8aiBkdPYjXWtHuwJGm1v5Ojy0Z0LaSYhHaLJQB0P8Jw==", + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/postcss-minify-params/-/postcss-minify-params-6.1.0.tgz", + "integrity": "sha512-bmSKnDtyyE8ujHQK0RQJDIKhQ20Jq1LYiez54WiaOoBtcSuflfK3Nm596LvbtlFcpipMjgClQGyGr7GAs+H1uA==", "dependencies": { - "browserslist": "^4.21.4", - "cssnano-utils": "^3.1.0", + "browserslist": "^4.23.0", + "cssnano-utils": "^4.0.2", "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-minify-selectors": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/postcss-minify-selectors/-/postcss-minify-selectors-5.2.1.tgz", - "integrity": "sha512-nPJu7OjZJTsVUmPdm2TcaiohIwxP+v8ha9NehQ2ye9szv4orirRU3SDdtUmKH+10nzn0bAyOXZ0UEr7OpvLehg==", + "version": "6.0.4", + "resolved": "https://registry.npmjs.org/postcss-minify-selectors/-/postcss-minify-selectors-6.0.4.tgz", + "integrity": "sha512-L8dZSwNLgK7pjTto9PzWRoMbnLq5vsZSTu8+j1P/2GB8qdtGQfn+K1uSvFgYvgh83cbyxT5m43ZZhUMTJDSClQ==", "dependencies": { - "postcss-selector-parser": "^6.0.5" + "postcss-selector-parser": "^6.0.16" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-modules-extract-imports": { @@ -12740,192 +12731,191 @@ } }, "node_modules/postcss-normalize-charset": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/postcss-normalize-charset/-/postcss-normalize-charset-5.1.0.tgz", - "integrity": "sha512-mSgUJ+pd/ldRGVx26p2wz9dNZ7ji6Pn8VWBajMXFf8jk7vUoSrZ2lt/wZR7DtlZYKesmZI680qjr2CeFF2fbUg==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-normalize-charset/-/postcss-normalize-charset-6.0.2.tgz", + "integrity": "sha512-a8N9czmdnrjPHa3DeFlwqst5eaL5W8jYu3EBbTTkI5FHkfMhFZh1EGbku6jhHhIzTA6tquI2P42NtZ59M/H/kQ==", "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-normalize-display-values": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/postcss-normalize-display-values/-/postcss-normalize-display-values-5.1.0.tgz", - "integrity": "sha512-WP4KIM4o2dazQXWmFaqMmcvsKmhdINFblgSeRgn8BJ6vxaMyaJkwAzpPpuvSIoG/rmX3M+IrRZEz2H0glrQNEA==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-normalize-display-values/-/postcss-normalize-display-values-6.0.2.tgz", + "integrity": "sha512-8H04Mxsb82ON/aAkPeq8kcBbAtI5Q2a64X/mnRRfPXBq7XeogoQvReqxEfc0B4WPq1KimjezNC8flUtC3Qz6jg==", "dependencies": { "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + 
"postcss": "^8.4.31" } }, "node_modules/postcss-normalize-positions": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/postcss-normalize-positions/-/postcss-normalize-positions-5.1.1.tgz", - "integrity": "sha512-6UpCb0G4eofTCQLFVuI3EVNZzBNPiIKcA1AKVka+31fTVySphr3VUgAIULBhxZkKgwLImhzMR2Bw1ORK+37INg==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-normalize-positions/-/postcss-normalize-positions-6.0.2.tgz", + "integrity": "sha512-/JFzI441OAB9O7VnLA+RtSNZvQ0NCFZDOtp6QPFo1iIyawyXg0YI3CYM9HBy1WvwCRHnPep/BvI1+dGPKoXx/Q==", "dependencies": { "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-normalize-repeat-style": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/postcss-normalize-repeat-style/-/postcss-normalize-repeat-style-5.1.1.tgz", - "integrity": "sha512-mFpLspGWkQtBcWIRFLmewo8aC3ImN2i/J3v8YCFUwDnPu3Xz4rLohDO26lGjwNsQxB3YF0KKRwspGzE2JEuS0g==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-normalize-repeat-style/-/postcss-normalize-repeat-style-6.0.2.tgz", + "integrity": "sha512-YdCgsfHkJ2jEXwR4RR3Tm/iOxSfdRt7jplS6XRh9Js9PyCR/aka/FCb6TuHT2U8gQubbm/mPmF6L7FY9d79VwQ==", "dependencies": { "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-normalize-string": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/postcss-normalize-string/-/postcss-normalize-string-5.1.0.tgz", - "integrity": "sha512-oYiIJOf4T9T1N4i+abeIc7Vgm/xPCGih4bZz5Nm0/ARVJ7K6xrDlLwvwqOydvyL3RHNf8qZk6vo3aatiw/go3w==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-normalize-string/-/postcss-normalize-string-6.0.2.tgz", + "integrity": "sha512-vQZIivlxlfqqMp4L9PZsFE4YUkWniziKjQWUtsxUiVsSSPelQydwS8Wwcuw0+83ZjPWNTl02oxlIvXsmmG+CiQ==", "dependencies": { "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-normalize-timing-functions": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/postcss-normalize-timing-functions/-/postcss-normalize-timing-functions-5.1.0.tgz", - "integrity": "sha512-DOEkzJ4SAXv5xkHl0Wa9cZLF3WCBhF3o1SKVxKQAa+0pYKlueTpCgvkFAHfk+Y64ezX9+nITGrDZeVGgITJXjg==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-normalize-timing-functions/-/postcss-normalize-timing-functions-6.0.2.tgz", + "integrity": "sha512-a+YrtMox4TBtId/AEwbA03VcJgtyW4dGBizPl7e88cTFULYsprgHWTbfyjSLyHeBcK/Q9JhXkt2ZXiwaVHoMzA==", "dependencies": { "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-normalize-unicode": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/postcss-normalize-unicode/-/postcss-normalize-unicode-5.1.1.tgz", - "integrity": "sha512-qnCL5jzkNUmKVhZoENp1mJiGNPcsJCs1aaRmURmeJGES23Z/ajaln+EPTD+rBeNkSryI+2WTdW+lwcVdOikrpA==", + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/postcss-normalize-unicode/-/postcss-normalize-unicode-6.1.0.tgz", + "integrity": 
"sha512-QVC5TQHsVj33otj8/JD869Ndr5Xcc/+fwRh4HAsFsAeygQQXm+0PySrKbr/8tkDKzW+EVT3QkqZMfFrGiossDg==", "dependencies": { - "browserslist": "^4.21.4", + "browserslist": "^4.23.0", "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-normalize-url": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/postcss-normalize-url/-/postcss-normalize-url-5.1.0.tgz", - "integrity": "sha512-5upGeDO+PVthOxSmds43ZeMeZfKH+/DKgGRD7TElkkyS46JXAUhMzIKiCa7BabPeIy3AQcTkXwVVN7DbqsiCew==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-normalize-url/-/postcss-normalize-url-6.0.2.tgz", + "integrity": "sha512-kVNcWhCeKAzZ8B4pv/DnrU1wNh458zBNp8dh4y5hhxih5RZQ12QWMuQrDgPRw3LRl8mN9vOVfHl7uhvHYMoXsQ==", "dependencies": { - "normalize-url": "^6.0.1", "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-normalize-whitespace": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/postcss-normalize-whitespace/-/postcss-normalize-whitespace-5.1.1.tgz", - "integrity": "sha512-83ZJ4t3NUDETIHTa3uEg6asWjSBYL5EdkVB0sDncx9ERzOKBVJIUeDO9RyA9Zwtig8El1d79HBp0JEi8wvGQnA==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-normalize-whitespace/-/postcss-normalize-whitespace-6.0.2.tgz", + "integrity": "sha512-sXZ2Nj1icbJOKmdjXVT9pnyHQKiSAyuNQHSgRCUgThn2388Y9cGVDR+E9J9iAYbSbLHI+UUwLVl1Wzco/zgv0Q==", "dependencies": { "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-ordered-values": { - "version": "5.1.3", - "resolved": "https://registry.npmjs.org/postcss-ordered-values/-/postcss-ordered-values-5.1.3.tgz", - "integrity": "sha512-9UO79VUhPwEkzbb3RNpqqghc6lcYej1aveQteWY+4POIwlqkYE21HKWaLDF6lWNuqCobEAyTovVhtI32Rbv2RQ==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-ordered-values/-/postcss-ordered-values-6.0.2.tgz", + "integrity": "sha512-VRZSOB+JU32RsEAQrO94QPkClGPKJEL/Z9PCBImXMhIeK5KAYo6slP/hBYlLgrCjFxyqvn5VC81tycFEDBLG1Q==", "dependencies": { - "cssnano-utils": "^3.1.0", + "cssnano-utils": "^4.0.2", "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-reduce-idents": { - "version": "5.2.0", - "resolved": "https://registry.npmjs.org/postcss-reduce-idents/-/postcss-reduce-idents-5.2.0.tgz", - "integrity": "sha512-BTrLjICoSB6gxbc58D5mdBK8OhXRDqud/zodYfdSi52qvDHdMwk+9kB9xsM8yJThH/sZU5A6QVSmMmaN001gIg==", + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/postcss-reduce-idents/-/postcss-reduce-idents-6.0.3.tgz", + "integrity": "sha512-G3yCqZDpsNPoQgbDUy3T0E6hqOQ5xigUtBQyrmq3tn2GxlyiL0yyl7H+T8ulQR6kOcHJ9t7/9H4/R2tv8tJbMA==", "dependencies": { "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-reduce-initial": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/postcss-reduce-initial/-/postcss-reduce-initial-5.1.2.tgz", - 
"integrity": "sha512-dE/y2XRaqAi6OvjzD22pjTUQ8eOfc6m/natGHgKFBK9DxFmIm69YmaRVQrGgFlEfc1HePIurY0TmDeROK05rIg==", + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/postcss-reduce-initial/-/postcss-reduce-initial-6.1.0.tgz", + "integrity": "sha512-RarLgBK/CrL1qZags04oKbVbrrVK2wcxhvta3GCxrZO4zveibqbRPmm2VI8sSgCXwoUHEliRSbOfpR0b/VIoiw==", "dependencies": { - "browserslist": "^4.21.4", + "browserslist": "^4.23.0", "caniuse-api": "^3.0.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-reduce-transforms": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/postcss-reduce-transforms/-/postcss-reduce-transforms-5.1.0.tgz", - "integrity": "sha512-2fbdbmgir5AvpW9RLtdONx1QoYG2/EtqpNQbFASDlixBbAYuTcJ0dECwlqNqH7VbaUnEnh8SrxOe2sRIn24XyQ==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-reduce-transforms/-/postcss-reduce-transforms-6.0.2.tgz", + "integrity": "sha512-sB+Ya++3Xj1WaT9+5LOOdirAxP7dJZms3GRcYheSPi1PiTMigsxHAdkrbItHxwYHr4kt1zL7mmcHstgMYT+aiA==", "dependencies": { "postcss-value-parser": "^4.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-selector-parser": { - "version": "6.0.14", - "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.0.14.tgz", - "integrity": "sha512-65xXYsT40i9GyWzlHQ5ShZoK7JZdySeOozi/tz2EezDo6c04q6+ckYMeoY7idaie1qp2dT5KoYQ2yky6JuoHnA==", + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.1.0.tgz", + "integrity": "sha512-UMz42UD0UY0EApS0ZL9o1XnLhSTtvvvLe5Dc2H2O56fvRZi+KulDyf5ctDhhtYJBGKStV2FL1fy6253cmLgqVQ==", "dependencies": { "cssesc": "^3.0.0", "util-deprecate": "^1.0.2" @@ -12935,46 +12925,46 @@ } }, "node_modules/postcss-sort-media-queries": { - "version": "4.4.1", - "resolved": "https://registry.npmjs.org/postcss-sort-media-queries/-/postcss-sort-media-queries-4.4.1.tgz", - "integrity": "sha512-QDESFzDDGKgpiIh4GYXsSy6sek2yAwQx1JASl5AxBtU1Lq2JfKBljIPNdil989NcSKRQX1ToiaKphImtBuhXWw==", + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/postcss-sort-media-queries/-/postcss-sort-media-queries-5.2.0.tgz", + "integrity": "sha512-AZ5fDMLD8SldlAYlvi8NIqo0+Z8xnXU2ia0jxmuhxAU+Lqt9K+AlmLNJ/zWEnE9x+Zx3qL3+1K20ATgNOr3fAA==", "dependencies": { - "sort-css-media-queries": "2.1.0" + "sort-css-media-queries": "2.2.0" }, "engines": { - "node": ">=10.0.0" + "node": ">=14.0.0" }, "peerDependencies": { - "postcss": "^8.4.16" + "postcss": "^8.4.23" } }, "node_modules/postcss-svgo": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/postcss-svgo/-/postcss-svgo-5.1.0.tgz", - "integrity": "sha512-D75KsH1zm5ZrHyxPakAxJWtkyXew5qwS70v56exwvw542d9CRtTo78K0WeFxZB4G7JXKKMbEZtZayTGdIky/eA==", + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/postcss-svgo/-/postcss-svgo-6.0.3.tgz", + "integrity": "sha512-dlrahRmxP22bX6iKEjOM+c8/1p+81asjKT+V5lrgOH944ryx/OHpclnIbGsKVd3uWOXFLYJwCVf0eEkJGvO96g==", "dependencies": { "postcss-value-parser": "^4.2.0", - "svgo": "^2.7.0" + "svgo": "^3.2.0" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >= 18" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-unique-selectors": { - "version": "5.1.1", - "resolved": 
"https://registry.npmjs.org/postcss-unique-selectors/-/postcss-unique-selectors-5.1.1.tgz", - "integrity": "sha512-5JiODlELrz8L2HwxfPnhOWZYWDxVHWL83ufOv84NrcgipI7TaeRsatAhK4Tr2/ZiYldpK/wBvw5BD3qfaK96GA==", + "version": "6.0.4", + "resolved": "https://registry.npmjs.org/postcss-unique-selectors/-/postcss-unique-selectors-6.0.4.tgz", + "integrity": "sha512-K38OCaIrO8+PzpArzkLKB42dSARtC2tmG6PvD4b1o1Q2E9Os8jzfWFfSy/rixsHwohtsDdFtAWGjFVFUdwYaMg==", "dependencies": { - "postcss-selector-parser": "^6.0.5" + "postcss-selector-parser": "^6.0.16" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/postcss-value-parser": { @@ -12983,14 +12973,14 @@ "integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ==" }, "node_modules/postcss-zindex": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/postcss-zindex/-/postcss-zindex-5.1.0.tgz", - "integrity": "sha512-fgFMf0OtVSBR1va1JNHYgMxYk73yhn/qb4uQDq1DLGYolz8gHCyr/sesEuGUaYs58E3ZJRcpoGuPVoB7Meiq9A==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/postcss-zindex/-/postcss-zindex-6.0.2.tgz", + "integrity": "sha512-5BxW9l1evPB/4ZIc+2GobEBoKC+h8gPGCMi+jxsYvd2x0mjq7wazk6DrP71pStqxE9Foxh5TVnonbWpFZzXaYg==", "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/pretty-error": { @@ -13058,9 +13048,9 @@ } }, "node_modules/property-information": { - "version": "6.4.0", - "resolved": "https://registry.npmjs.org/property-information/-/property-information-6.4.0.tgz", - "integrity": "sha512-9t5qARVofg2xQqKtytzt+lZ4d1Qvj8t5B8fEwXK6qOfgRLgH/b13QlgEyDh033NOS31nXeFbYv7CLUDG1CeifQ==", + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/property-information/-/property-information-6.5.0.tgz", + "integrity": "sha512-PgTgs/BlvHxOu8QuEN7wi5A0OmXaBcHpmCSTehcs6Uuu9IkDIEo13Hy7n898RHfrQ49vKCoGeWZSaAK01nwVig==", "funding": { "type": "github", "url": "https://github.com/sponsors/wooorm" @@ -13395,9 +13385,9 @@ "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==" }, "node_modules/react-json-view-lite": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/react-json-view-lite/-/react-json-view-lite-1.2.1.tgz", - "integrity": "sha512-Itc0g86fytOmKZoIoJyGgvNqohWSbh3NXIKNgH6W6FT9PC1ck4xas1tT3Rr/b3UlFXyA9Jjaw9QSXdZy2JwGMQ==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/react-json-view-lite/-/react-json-view-lite-1.4.0.tgz", + "integrity": "sha512-wh6F6uJyYAmQ4fK0e8dSQMEWuvTs2Wr3el3sLD9bambX1+pSWUVXIz1RFaoy3TI1mZ0FqdpKq9YgbgTTgyrmXA==", "engines": { "node": ">=14" }, @@ -13706,9 +13696,9 @@ } }, "node_modules/remark-mdx": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/remark-mdx/-/remark-mdx-3.0.0.tgz", - "integrity": "sha512-O7yfjuC6ra3NHPbRVxfflafAj3LTwx3b73aBvkEFU5z4PsD6FD4vrqJAkE5iNGLz71GdjXfgRqm3SQ0h0VuE7g==", + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/remark-mdx/-/remark-mdx-3.0.1.tgz", + "integrity": "sha512-3Pz3yPQ5Rht2pM5R+0J2MrGoBSrzf+tJG94N+t/ilfdh8YLyyKYtidAYwTveB20BoHAcwIopOUqhcmh2F7hGYA==", "dependencies": { "mdast-util-mdx": "^3.0.0", "micromark-extension-mdxjs": "^3.0.0" @@ -13734,9 +13724,9 @@ } }, "node_modules/remark-rehype": { - "version": "11.0.0", - "resolved": 
"https://registry.npmjs.org/remark-rehype/-/remark-rehype-11.0.0.tgz", - "integrity": "sha512-vx8x2MDMcxuE4lBmQ46zYUDfcFMmvg80WYX+UNLeG6ixjdCCLcw1lrgAukwBTuOFsS78eoAedHGn9sNM0w7TPw==", + "version": "11.1.0", + "resolved": "https://registry.npmjs.org/remark-rehype/-/remark-rehype-11.1.0.tgz", + "integrity": "sha512-z3tJrAs2kIs1AqIIy6pzHmAHlF1hWQ+OdY4/hv+Wxe35EhyLKcajL33iUEn3ScxtFox9nUvRufR/Zre8Q08H/g==", "dependencies": { "@types/hast": "^3.0.0", "@types/mdast": "^4.0.0", @@ -14046,9 +14036,9 @@ "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==" }, "node_modules/sax": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/sax/-/sax-1.3.0.tgz", - "integrity": "sha512-0s+oAmw9zLl1V1cS9BtZN7JAd0cW5e0QH4W3LWEK6a4LaLEA2OTpGYWDY+6XasBLtz6wkm3u1xRw95mRuJ59WA==" + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/sax/-/sax-1.4.1.tgz", + "integrity": "sha512-+aWOz7yVScEGoKNd4PA10LZ8sk0A/z5+nXQG5giUO5rprX9jgYsTdov9qCchZiPIZezbZH+jRut8nPodFAX4Jg==" }, "node_modules/scheduler": { "version": "0.23.0", @@ -14077,9 +14067,9 @@ } }, "node_modules/search-insights": { - "version": "2.13.0", - "resolved": "https://registry.npmjs.org/search-insights/-/search-insights-2.13.0.tgz", - "integrity": "sha512-Orrsjf9trHHxFRuo9/rzm0KIWmgzE8RMlZMzuhZOJ01Rnz3D0YBAe+V6473t6/H6c7irs6Lt48brULAiRWb3Vw==", + "version": "2.14.0", + "resolved": "https://registry.npmjs.org/search-insights/-/search-insights-2.14.0.tgz", + "integrity": "sha512-OLN6MsPMCghDOqlCtsIsYgtsC0pnwVTyT9Mu6A3ewOj1DxvzZF6COrn2g86E/c05xbktB0XN04m/t1Z+n+fTGw==", "peer": true }, "node_modules/section-matter": { @@ -14437,9 +14427,9 @@ "integrity": "sha512-bLGGlR1QxBcynn2d5YmDX4MGjlZvy2MRBDRNHLJ8VI6l6+9FUiyTFNJ0IveOSP0bcXgVDPRcfGqA0pjaqUpfVg==" }, "node_modules/sitemap": { - "version": "7.1.1", - "resolved": "https://registry.npmjs.org/sitemap/-/sitemap-7.1.1.tgz", - "integrity": "sha512-mK3aFtjz4VdJN0igpIJrinf3EO8U8mxOPsTBzSsy06UtjZQJ3YY3o3Xa7zSc5nMqcMrRwlChHZ18Kxg0caiPBg==", + "version": "7.1.2", + "resolved": "https://registry.npmjs.org/sitemap/-/sitemap-7.1.2.tgz", + "integrity": "sha512-ARCqzHJ0p4gWt+j7NlU5eDlIO9+Rkr/JhPFZKKQ1l5GCus7rJH4UdrlVAh0xC/gDS/Qir2UMxqYNHtsKr2rpCw==", "dependencies": { "@types/node": "^17.0.5", "@types/sax": "^1.2.1", @@ -14478,6 +14468,15 @@ "node": ">=8" } }, + "node_modules/snake-case": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/snake-case/-/snake-case-3.0.4.tgz", + "integrity": "sha512-LAOh4z89bGQvl9pFfNF8V146i7o7/CqFPbqzYgP+yYzDIDeS9HaNFtXABamRW+AQzEVODcvE79ljJ+8a9YSdMg==", + "dependencies": { + "dot-case": "^3.0.4", + "tslib": "^2.0.3" + } + }, "node_modules/sockjs": { "version": "0.3.24", "resolved": "https://registry.npmjs.org/sockjs/-/sockjs-0.3.24.tgz", @@ -14489,9 +14488,9 @@ } }, "node_modules/sort-css-media-queries": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/sort-css-media-queries/-/sort-css-media-queries-2.1.0.tgz", - "integrity": "sha512-IeWvo8NkNiY2vVYdPa27MCQiR0MN0M80johAYFVxWWXQ44KU84WNxjslwBHmc/7ZL2ccwkM7/e6S5aiKZXm7jA==", + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/sort-css-media-queries/-/sort-css-media-queries-2.2.0.tgz", + "integrity": "sha512-0xtkGhWCC9MGt/EzgnvbbbKhqWjl1+/rncmhTh5qCpbYguXh6S/qwePfv/JQ8jePXXmqingylxoC49pCkSPIbA==", "engines": { "node": ">= 6.3.0" } @@ -14505,9 +14504,9 @@ } }, "node_modules/source-map-js": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz", - "integrity": 
"sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==", + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.0.tgz", + "integrity": "sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==", "engines": { "node": ">=0.10.0" } @@ -14582,12 +14581,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/stable": { - "version": "0.1.8", - "resolved": "https://registry.npmjs.org/stable/-/stable-0.1.8.tgz", - "integrity": "sha512-ji9qxRnOVfcuLDySj9qzhGSEFVobyt1kIOSkj1qZzYLzq7Tos/oUUWvotUPQLlrsidqsK6tBH89Bc9kL5zHA6w==", - "deprecated": "Modern JS already guarantees Array#sort() is a stable sort, so this library is deprecated. See the compatibility table on MDN: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/sort#browser_compatibility" - }, "node_modules/statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", @@ -14651,9 +14644,9 @@ } }, "node_modules/stringify-entities": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/stringify-entities/-/stringify-entities-4.0.3.tgz", - "integrity": "sha512-BP9nNHMhhfcMbiuQKCqMjhDP5yBCAxsPu4pHFFzJ6Alo9dZgY4VLDPutXqIjpRiMoKdp7Av85Gr73Q5uH9k7+g==", + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/stringify-entities/-/stringify-entities-4.0.4.tgz", + "integrity": "sha512-IwfBptatlO+QCJUo19AqvrPNqlVMpW9YEL2LIVY+Rpv2qsjCGxaDLNRgeGsQWJhfItebuJhsGSLjaBbNSQ+ieg==", "dependencies": { "character-entities-html4": "^2.0.0", "character-entities-legacy": "^3.0.0" @@ -14723,18 +14716,18 @@ } }, "node_modules/stylehacks": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/stylehacks/-/stylehacks-5.1.1.tgz", - "integrity": "sha512-sBpcd5Hx7G6seo7b1LkpttvTz7ikD0LlH5RmdcBNb6fFR0Fl7LQwHDFr300q4cwUqi+IYrFGmsIHieMBfnN/Bw==", + "version": "6.1.1", + "resolved": "https://registry.npmjs.org/stylehacks/-/stylehacks-6.1.1.tgz", + "integrity": "sha512-gSTTEQ670cJNoaeIp9KX6lZmm8LJ3jPB5yJmX8Zq/wQxOsAFXV3qjWzHas3YYk1qesuVIyYWWUpZ0vSE/dTSGg==", "dependencies": { - "browserslist": "^4.21.4", - "postcss-selector-parser": "^6.0.4" + "browserslist": "^4.23.0", + "postcss-selector-parser": "^6.0.16" }, "engines": { - "node": "^10 || ^12 || >=14.0" + "node": "^14 || ^16 || >=18.0" }, "peerDependencies": { - "postcss": "^8.2.15" + "postcss": "^8.4.31" } }, "node_modules/stylis": { @@ -14770,23 +14763,27 @@ "integrity": "sha512-e4hG1hRwoOdRb37cIMSgzNsxyzKfayW6VOflrwvR+/bzrkyxY/31WkbgnQpgtrNp1SdpJvpUAGTa/ZoiPNDuRQ==" }, "node_modules/svgo": { - "version": "2.8.0", - "resolved": "https://registry.npmjs.org/svgo/-/svgo-2.8.0.tgz", - "integrity": "sha512-+N/Q9kV1+F+UeWYoSiULYo4xYSDQlTgb+ayMobAXPwMnLvop7oxKMo9OzIrX5x3eS4L4f2UHhc9axXwY8DpChg==", + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/svgo/-/svgo-3.3.2.tgz", + "integrity": "sha512-OoohrmuUlBs8B8o6MB2Aevn+pRIH9zDALSR+6hhqVfa6fRwG/Qw9VUMSMW9VNg2CFc/MTIfabtdOVl9ODIJjpw==", "dependencies": { "@trysound/sax": "0.2.0", "commander": "^7.2.0", - "css-select": "^4.1.3", - "css-tree": "^1.1.3", - "csso": "^4.2.0", - "picocolors": "^1.0.0", - "stable": "^0.1.8" + "css-select": "^5.1.0", + "css-tree": "^2.3.1", + "css-what": "^6.1.0", + "csso": "^5.0.5", + "picocolors": "^1.0.0" }, "bin": { "svgo": "bin/svgo" }, "engines": { - "node": ">=10.13.0" + "node": ">=14.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/svgo" } }, 
"node_modules/svgo/node_modules/commander": { @@ -14797,69 +14794,6 @@ "node": ">= 10" } }, - "node_modules/svgo/node_modules/css-select": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/css-select/-/css-select-4.3.0.tgz", - "integrity": "sha512-wPpOYtnsVontu2mODhA19JrqWxNsfdatRKd64kmpRbQgh1KtItko5sTnEpPdpSaJszTOhEMlF/RPz28qj4HqhQ==", - "dependencies": { - "boolbase": "^1.0.0", - "css-what": "^6.0.1", - "domhandler": "^4.3.1", - "domutils": "^2.8.0", - "nth-check": "^2.0.1" - }, - "funding": { - "url": "https://github.com/sponsors/fb55" - } - }, - "node_modules/svgo/node_modules/dom-serializer": { - "version": "1.4.1", - "resolved": "https://registry.npmjs.org/dom-serializer/-/dom-serializer-1.4.1.tgz", - "integrity": "sha512-VHwB3KfrcOOkelEG2ZOfxqLZdfkil8PtJi4P8N2MMXucZq2yLp75ClViUlOVwyoHEDjYU433Aq+5zWP61+RGag==", - "dependencies": { - "domelementtype": "^2.0.1", - "domhandler": "^4.2.0", - "entities": "^2.0.0" - }, - "funding": { - "url": "https://github.com/cheeriojs/dom-serializer?sponsor=1" - } - }, - "node_modules/svgo/node_modules/domhandler": { - "version": "4.3.1", - "resolved": "https://registry.npmjs.org/domhandler/-/domhandler-4.3.1.tgz", - "integrity": "sha512-GrwoxYN+uWlzO8uhUXRl0P+kHE4GtVPfYzVLcUxPL7KNdHKj66vvlhiweIHqYYXWlw+T8iLMp42Lm67ghw4WMQ==", - "dependencies": { - "domelementtype": "^2.2.0" - }, - "engines": { - "node": ">= 4" - }, - "funding": { - "url": "https://github.com/fb55/domhandler?sponsor=1" - } - }, - "node_modules/svgo/node_modules/domutils": { - "version": "2.8.0", - "resolved": "https://registry.npmjs.org/domutils/-/domutils-2.8.0.tgz", - "integrity": "sha512-w96Cjofp72M5IIhpjgobBimYEfoPjx1Vx0BSX9P30WBdZW2WIKU0T1Bd0kz2eNZ9ikjKgHbEyKx8BB6H1L3h3A==", - "dependencies": { - "dom-serializer": "^1.0.1", - "domelementtype": "^2.2.0", - "domhandler": "^4.2.0" - }, - "funding": { - "url": "https://github.com/fb55/domutils?sponsor=1" - } - }, - "node_modules/svgo/node_modules/entities": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/entities/-/entities-2.2.0.tgz", - "integrity": "sha512-p92if5Nz619I0w+akJrLZH0MX0Pb5DX39XOwQTtXSdQQOaYH03S1uIQp4mhOZtAXrxq4ViO67YTiLBo2638o9A==", - "funding": { - "url": "https://github.com/fb55/entities?sponsor=1" - } - }, "node_modules/tapable": { "version": "2.2.1", "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", @@ -15060,9 +14994,9 @@ } }, "node_modules/trough": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/trough/-/trough-2.1.0.tgz", - "integrity": "sha512-AqTiAOLcj85xS7vQ8QkAV41hPDIJ71XJB4RCUrzo/1GM2CQwhkJGaf9Hgr7BOugMRpgGUrqRg/DrBDl4H40+8g==", + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/trough/-/trough-2.2.0.tgz", + "integrity": "sha512-tmMpK00BjZiUyVyvrBK7knerNgmgvcV/KLVyuma/SC+TQN167GrMRciANTz09+k3zW8L8t60jWO1GpfkZdjTaw==", "funding": { "type": "github", "url": "https://github.com/sponsors/wooorm" diff --git a/docs/package.json b/docs/package.json index 49b960d6a3..dbfc369268 100644 --- a/docs/package.json +++ b/docs/package.json @@ -14,10 +14,10 @@ "write-heading-ids": "docusaurus write-heading-ids" }, "dependencies": { - "@docusaurus/core": "3.0.1", - "@docusaurus/plugin-content-docs": "^3.1.0", - "@docusaurus/preset-classic": "3.0.1", - "@docusaurus/theme-mermaid": "^3.0.1", + "@docusaurus/core": "^3.4.0", + "@docusaurus/plugin-content-docs": "^3.4.0", + "@docusaurus/preset-classic": "^3.4.0", + "@docusaurus/theme-mermaid": "^3.4.0", "@mdx-js/react": "^3.0.0", "clsx": "^2.0.0", "prism-react-renderer": "^2.3.0", diff --git 
a/docs/src/components/HomepageFeatures/index.js b/docs/src/components/HomepageFeatures/index.js index 547271e5c2..cfe1396245 100644 --- a/docs/src/components/HomepageFeatures/index.js +++ b/docs/src/components/HomepageFeatures/index.js @@ -22,11 +22,11 @@ const FeatureList = [ ), }, { - title: 'AI Services, RAG, Tools, Chains', + title: 'AI Services, RAG, Tools', Svg: require('@site/static/img/functionality-logos.svg').default, description: ( <> - Our extensive toolbox provides a wide range of tools for common LLM operations, from low-level prompt templating, memory management, and output parsing, to high-level patterns like Agents and RAG. + Our extensive toolbox provides a wide range of tools for common LLM operations, from low-level prompt templating, chat memory management, and output parsing, to high-level patterns like AI Services and RAG. ), } diff --git a/docs/static/img/web-search-engine.png b/docs/static/img/web-search-engine.png new file mode 100644 index 0000000000..c4974e0577 Binary files /dev/null and b/docs/static/img/web-search-engine.png differ diff --git a/document-loaders/langchain4j-document-loader-amazon-s3/pom.xml b/document-loaders/langchain4j-document-loader-amazon-s3/pom.xml index 068727fd58..179eefbc0a 100644 --- a/document-loaders/langchain4j-document-loader-amazon-s3/pom.xml +++ b/document-loaders/langchain4j-document-loader-amazon-s3/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../../langchain4j-parent/pom.xml diff --git a/document-loaders/langchain4j-document-loader-amazon-s3/src/test/java/dev/langchain4j/data/document/loader/amazon/s3/AmazonS3DocumentLoaderIT.java b/document-loaders/langchain4j-document-loader-amazon-s3/src/test/java/dev/langchain4j/data/document/loader/amazon/s3/AmazonS3DocumentLoaderIT.java index 7846a62924..f72894ea9b 100644 --- a/document-loaders/langchain4j-document-loader-amazon-s3/src/test/java/dev/langchain4j/data/document/loader/amazon/s3/AmazonS3DocumentLoaderIT.java +++ b/document-loaders/langchain4j-document-loader-amazon-s3/src/test/java/dev/langchain4j/data/document/loader/amazon/s3/AmazonS3DocumentLoaderIT.java @@ -76,8 +76,8 @@ public void should_load_single_document() { // then assertThat(document.text()).isEqualTo(TEST_CONTENT); - assertThat(document.metadata().asMap()).hasSize(1); - assertThat(document.metadata("source")).isEqualTo("s3://test-bucket/test-file.txt"); + assertThat(document.metadata().toMap()).hasSize(1); + assertThat(document.metadata().getString("source")).isEqualTo("s3://test-bucket/test-file.txt"); } @Test @@ -117,12 +117,12 @@ public void should_load_multiple_documents() { assertThat(documents).hasSize(2); assertThat(documents.get(0).text()).isEqualTo(TEST_CONTENT_2); - assertThat(documents.get(0).metadata().asMap()).hasSize(1); - assertThat(documents.get(0).metadata("source")).isEqualTo("s3://test-bucket/test-directory/test-file-2.txt"); + assertThat(documents.get(0).metadata().toMap()).hasSize(1); + assertThat(documents.get(0).metadata().getString("source")).isEqualTo("s3://test-bucket/test-directory/test-file-2.txt"); assertThat(documents.get(1).text()).isEqualTo(TEST_CONTENT); - assertThat(documents.get(1).metadata().asMap()).hasSize(1); - assertThat(documents.get(1).metadata("source")).isEqualTo("s3://test-bucket/test-file.txt"); + assertThat(documents.get(1).metadata().toMap()).hasSize(1); + assertThat(documents.get(1).metadata().getString("source")).isEqualTo("s3://test-bucket/test-file.txt"); } @Test diff --git 
a/document-loaders/langchain4j-document-loader-azure-storage-blob/pom.xml b/document-loaders/langchain4j-document-loader-azure-storage-blob/pom.xml index 6166dcfcde..8c842a2602 100644 --- a/document-loaders/langchain4j-document-loader-azure-storage-blob/pom.xml +++ b/document-loaders/langchain4j-document-loader-azure-storage-blob/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../../langchain4j-parent/pom.xml diff --git a/document-loaders/langchain4j-document-loader-azure-storage-blob/src/main/java/dev/langchain4j/data/document/source/azure/storage/blob/AzureBlobStorageSource.java b/document-loaders/langchain4j-document-loader-azure-storage-blob/src/main/java/dev/langchain4j/data/document/source/azure/storage/blob/AzureBlobStorageSource.java index 02e30d5ed0..311cbb6f9a 100644 --- a/document-loaders/langchain4j-document-loader-azure-storage-blob/src/main/java/dev/langchain4j/data/document/source/azure/storage/blob/AzureBlobStorageSource.java +++ b/document-loaders/langchain4j-document-loader-azure-storage-blob/src/main/java/dev/langchain4j/data/document/source/azure/storage/blob/AzureBlobStorageSource.java @@ -36,10 +36,10 @@ public InputStream inputStream() { @Override public Metadata metadata() { Metadata metadata = new Metadata(); - metadata.add(SOURCE, format("https://%s.blob.core.windows.net/%s/%s", accountName, containerName, blobName)); + metadata.put(SOURCE, format("https://%s.blob.core.windows.net/%s/%s", accountName, containerName, blobName)); metadata.add("azure_storage_blob_creation_time", properties.getCreationTime()); metadata.add("azure_storage_blob_last_modified", properties.getLastModified()); - metadata.add("azure_storage_blob_content_length", String.valueOf(properties.getBlobSize())); + metadata.put("azure_storage_blob_content_length", String.valueOf(properties.getBlobSize())); return metadata; } } diff --git a/document-loaders/langchain4j-document-loader-azure-storage-blob/src/test/java/dev/langchain4j/data/document/loader/azure/storage/blob/AzureBlobStorageDocumentLoaderIT.java b/document-loaders/langchain4j-document-loader-azure-storage-blob/src/test/java/dev/langchain4j/data/document/loader/azure/storage/blob/AzureBlobStorageDocumentLoaderIT.java index 4bc2393955..b3d7cf49fd 100644 --- a/document-loaders/langchain4j-document-loader-azure-storage-blob/src/test/java/dev/langchain4j/data/document/loader/azure/storage/blob/AzureBlobStorageDocumentLoaderIT.java +++ b/document-loaders/langchain4j-document-loader-azure-storage-blob/src/test/java/dev/langchain4j/data/document/loader/azure/storage/blob/AzureBlobStorageDocumentLoaderIT.java @@ -53,8 +53,8 @@ public void should_load_single_document() { Document document = loader.loadDocument(TEST_CONTAINER, TEST_BLOB, parser); assertThat(document.text()).isEqualTo(TEST_CONTENT); - assertThat(document.metadata().asMap()).hasSize(4); - assertThat(document.metadata("source")).endsWith("/test-file.txt"); + assertThat(document.metadata().toMap()).hasSize(4); + assertThat(document.metadata().getString("source")).endsWith("/test-file.txt"); } @Test @@ -65,12 +65,12 @@ public void should_load_multiple_documents() { assertThat(documents).hasSize(2); assertThat(documents.get(0).text()).isEqualTo(TEST_CONTENT_2); - assertThat(documents.get(0).metadata().asMap()).hasSize(4); - assertThat(documents.get(0).metadata("source")).endsWith("/test-directory/test-file-2.txt"); + assertThat(documents.get(0).metadata().toMap()).hasSize(4); + 
assertThat(documents.get(0).metadata().getString("source")).endsWith("/test-directory/test-file-2.txt"); assertThat(documents.get(1).text()).isEqualTo(TEST_CONTENT); - assertThat(documents.get(1).metadata().asMap()).hasSize(4); - assertThat(documents.get(1).metadata("source")).endsWith("/test-file.txt"); + assertThat(documents.get(1).metadata().toMap()).hasSize(4); + assertThat(documents.get(1).metadata().getString("source")).endsWith("/test-file.txt"); } @AfterEach diff --git a/document-loaders/langchain4j-document-loader-azure-storage-blob/src/test/java/dev/langchain4j/data/document/loader/azure/storage/blob/LocalAzureBlobStorageDocumentLoaderIT.java b/document-loaders/langchain4j-document-loader-azure-storage-blob/src/test/java/dev/langchain4j/data/document/loader/azure/storage/blob/LocalAzureBlobStorageDocumentLoaderIT.java index 306874063b..4dc7fe8993 100644 --- a/document-loaders/langchain4j-document-loader-azure-storage-blob/src/test/java/dev/langchain4j/data/document/loader/azure/storage/blob/LocalAzureBlobStorageDocumentLoaderIT.java +++ b/document-loaders/langchain4j-document-loader-azure-storage-blob/src/test/java/dev/langchain4j/data/document/loader/azure/storage/blob/LocalAzureBlobStorageDocumentLoaderIT.java @@ -61,8 +61,8 @@ public void should_load_single_document() { Document document = loader.loadDocument(TEST_CONTAINER, TEST_BLOB, parser); assertThat(document.text()).isEqualTo(TEST_CONTENT); - assertThat(document.metadata().asMap()).hasSize(4); - assertThat(document.metadata("source")).endsWith("/test-file.txt"); + assertThat(document.metadata().toMap()).hasSize(4); + assertThat(document.metadata().getString("source")).endsWith("/test-file.txt"); } @Test @@ -73,12 +73,12 @@ public void should_load_multiple_documents() { assertThat(documents).hasSize(2); assertThat(documents.get(0).text()).isEqualTo(TEST_CONTENT_2); - assertThat(documents.get(0).metadata().asMap()).hasSize(4); - assertThat(documents.get(0).metadata("source")).endsWith("/test-directory/test-file-2.txt"); + assertThat(documents.get(0).metadata().toMap()).hasSize(4); + assertThat(documents.get(0).metadata().getString("source")).endsWith("/test-directory/test-file-2.txt"); assertThat(documents.get(1).text()).isEqualTo(TEST_CONTENT); - assertThat(documents.get(1).metadata().asMap()).hasSize(4); - assertThat(documents.get(1).metadata("source")).endsWith("/test-file.txt"); + assertThat(documents.get(1).metadata().toMap()).hasSize(4); + assertThat(documents.get(1).metadata().getString("source")).endsWith("/test-file.txt"); } @AfterAll diff --git a/document-loaders/langchain4j-document-loader-github/pom.xml b/document-loaders/langchain4j-document-loader-github/pom.xml index 9b82a0d2f5..2cc65eb271 100644 --- a/document-loaders/langchain4j-document-loader-github/pom.xml +++ b/document-loaders/langchain4j-document-loader-github/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../../langchain4j-parent/pom.xml diff --git a/document-loaders/langchain4j-document-loader-github/src/main/java/dev/langchain4j/data/document/source/github/GitHubSource.java b/document-loaders/langchain4j-document-loader-github/src/main/java/dev/langchain4j/data/document/source/github/GitHubSource.java index 5b74d31e21..86181519da 100644 --- a/document-loaders/langchain4j-document-loader-github/src/main/java/dev/langchain4j/data/document/source/github/GitHubSource.java +++ b/document-loaders/langchain4j-document-loader-github/src/main/java/dev/langchain4j/data/document/source/github/GitHubSource.java @@ -28,19 
+28,19 @@ public InputStream inputStream() { @Override public Metadata metadata() { Metadata metadata = new Metadata(); - metadata.add("github_git_url", content.getGitUrl()); + metadata.put("github_git_url", content.getGitUrl()); try { - metadata.add("github_download_url", content.getDownloadUrl()); + metadata.put("github_download_url", content.getDownloadUrl()); } catch (IOException e) { // Ignore if download_url is not available } - metadata.add("github_html_url", content.getHtmlUrl()); - metadata.add("github_url", content.getUrl()); - metadata.add("github_file_name", content.getName()); - metadata.add("github_file_path", content.getPath()); - metadata.add("github_file_sha", content.getSha()); - metadata.add("github_file_size", Long.toString(content.getSize())); - metadata.add("github_file_encoding", content.getEncoding()); + metadata.put("github_html_url", content.getHtmlUrl()); + metadata.put("github_url", content.getUrl()); + metadata.put("github_file_name", content.getName()); + metadata.put("github_file_path", content.getPath()); + metadata.put("github_file_sha", content.getSha()); + metadata.put("github_file_size", Long.toString(content.getSize())); + metadata.put("github_file_encoding", content.getEncoding()); return metadata; } } diff --git a/document-loaders/langchain4j-document-loader-github/src/test/java/dev/langchain4j/data/document/loader/github/GitHubDocumentLoaderIT.java b/document-loaders/langchain4j-document-loader-github/src/test/java/dev/langchain4j/data/document/loader/github/GitHubDocumentLoaderIT.java index 145a241818..312fac8e32 100644 --- a/document-loaders/langchain4j-document-loader-github/src/test/java/dev/langchain4j/data/document/loader/github/GitHubDocumentLoaderIT.java +++ b/document-loaders/langchain4j-document-loader-github/src/test/java/dev/langchain4j/data/document/loader/github/GitHubDocumentLoaderIT.java @@ -36,8 +36,8 @@ public void should_load_file() { Document document = loader.loadDocument(TEST_OWNER, TEST_REPO, "main", "pom.xml", parser); assertThat(document.text()).contains("dev.langchain4j"); - assertThat(document.metadata().asMap()).hasSize(9); - assertThat(document.metadata("github_git_url")).startsWith("https://api.github.com/repos/langchain4j/langchain4j"); + assertThat(document.metadata().toMap()).hasSize(9); + assertThat(document.metadata().getString("github_git_url")).startsWith("https://api.github.com/repos/langchain4j/langchain4j"); } @Test diff --git a/document-loaders/langchain4j-document-loader-selenium/pom.xml b/document-loaders/langchain4j-document-loader-selenium/pom.xml new file mode 100644 index 0000000000..a23744d67f --- /dev/null +++ b/document-loaders/langchain4j-document-loader-selenium/pom.xml @@ -0,0 +1,62 @@ + + + 4.0.0 + + dev.langchain4j + langchain4j-parent + 0.32.0-SNAPSHOT + ../../langchain4j-parent/pom.xml + + + langchain4j-document-loader-selenium + LangChain4j :: Document loader :: Selenium + + + Selenium is a suite of tools for automating web browsers. + Integration with LangChain4j adds web document loading through browser automation. 
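As a hedged illustration of how the loader introduced by this new module is intended to be used (the builder and load(url, parser) API are defined later in this patch; the local ChromeDriver wiring and the example URL are assumptions for illustration only — the module's own integration test drives a Testcontainers-backed RemoteWebDriver instead):

// Illustrative sketch, not part of this patch. Assumes a locally installed Chrome/chromedriver.
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.loader.selenium.SeleniumDocumentLoader;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import java.time.Duration;
import org.openqa.selenium.WebDriver;
import org.openqa.selenium.chrome.ChromeDriver;

public class SeleniumLoaderSketch {

    public static void main(String[] args) {
        WebDriver webDriver = new ChromeDriver(); // assumption: local browser; the IT uses RemoteWebDriver
        try {
            SeleniumDocumentLoader loader = SeleniumDocumentLoader.builder()
                    .webDriver(webDriver)
                    .timeout(Duration.ofSeconds(30)) // waits until document.readyState == "complete"
                    .build();

            // Returns the raw page source parsed by the given DocumentParser;
            // the page URL is stored under the Document.URL metadata key.
            Document document = loader.load("https://docs.langchain4j.dev/", new TextDocumentParser());
            System.out.println(document.metadata().getString(Document.URL));
        } finally {
            webDriver.quit();
        }
    }
}

The raw HTML returned this way can then be passed through HtmlTextExtractor (as the integration test later in this patch does) to keep only the visible text.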
+ https://www.selenium.dev/documentation/about/copyright + + + + 4.13.0 + + + + + dev.langchain4j + langchain4j-core + + + + org.seleniumhq.selenium + selenium-java + ${selenium.webdriver.version} + + + + dev.langchain4j + langchain4j + test + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.testcontainers + selenium + test + + + + org.assertj + assertj-core + test + + + + diff --git a/document-loaders/langchain4j-document-loader-selenium/src/main/java/dev/langchain4j/data/document/loader/selenium/SeleniumDocumentLoader.java b/document-loaders/langchain4j-document-loader-selenium/src/main/java/dev/langchain4j/data/document/loader/selenium/SeleniumDocumentLoader.java new file mode 100644 index 0000000000..dc5a51f444 --- /dev/null +++ b/document-loaders/langchain4j-document-loader-selenium/src/main/java/dev/langchain4j/data/document/loader/selenium/SeleniumDocumentLoader.java @@ -0,0 +1,87 @@ +package dev.langchain4j.data.document.loader.selenium; + +import static java.util.Objects.requireNonNull; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.document.DocumentParser; +import java.io.ByteArrayInputStream; +import java.time.Duration; +import java.util.Objects; +import org.openqa.selenium.JavascriptExecutor; +import org.openqa.selenium.WebDriver; +import org.openqa.selenium.support.ui.ExpectedCondition; +import org.openqa.selenium.support.ui.WebDriverWait; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Utility class for loading web documents using Selenium. + * Returns a {@link Document} object containing the content of the web page. + * + */ +public class SeleniumDocumentLoader { + + private static final Logger logger = LoggerFactory.getLogger(SeleniumDocumentLoader.class); + private static final Duration DEFAULT_TIMEOUT_DURATION = Duration.ofSeconds(30); + + private final WebDriver webDriver; + private final Duration timeout; + + private SeleniumDocumentLoader(WebDriver webDriver, Duration timeout) { + this.webDriver = webDriver; + this.timeout = timeout; + } + + /** + * Loads a document from the specified URL. + * + * @param url The URL of the file. + * @param documentParser The parser to be used for parsing text from the URL. 
+ * @return document + */ + public Document load(String url, DocumentParser documentParser) { + logger.info("Loading document from URL: {}", url); + String pageContent; + try { + webDriver.get(url); + WebDriverWait wait = new WebDriverWait(webDriver, timeout); + logger.debug("Waiting webpage fully loaded: {}", url); + wait.until((ExpectedCondition) wd -> { + if (logger.isTraceEnabled()) { + logger.trace("Waiting for document.readyState to be complete"); + } + return ((JavascriptExecutor) requireNonNull(wd)).executeScript("return document.readyState").equals("complete"); + }); + pageContent = webDriver.getPageSource(); + } catch (Exception e) { + throw new RuntimeException("Failed to load document", e); + } + Document parsedDocument = documentParser.parse(new ByteArrayInputStream(pageContent.getBytes())); + parsedDocument.metadata().put(Document.URL, url); + return parsedDocument; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private WebDriver webDriver; + private Duration timeout = DEFAULT_TIMEOUT_DURATION; + + public Builder webDriver(WebDriver webDriver) { + this.webDriver = webDriver; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public SeleniumDocumentLoader build() { + Objects.requireNonNull(webDriver, "webDriver must be set"); + return new SeleniumDocumentLoader(webDriver, timeout); + } + } +} diff --git a/document-loaders/langchain4j-document-loader-selenium/src/test/java/dev/langchain4j/data/document/loader/selenium/SeleniumDocumentLoaderTestIT.java b/document-loaders/langchain4j-document-loader-selenium/src/test/java/dev/langchain4j/data/document/loader/selenium/SeleniumDocumentLoaderTestIT.java new file mode 100644 index 0000000000..11112d59fb --- /dev/null +++ b/document-loaders/langchain4j-document-loader-selenium/src/test/java/dev/langchain4j/data/document/loader/selenium/SeleniumDocumentLoaderTestIT.java @@ -0,0 +1,64 @@ +package dev.langchain4j.data.document.loader.selenium; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.document.DocumentParser; +import dev.langchain4j.data.document.parser.TextDocumentParser; +import dev.langchain4j.data.document.transformer.HtmlTextExtractor; +import java.time.Duration; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.openqa.selenium.chrome.ChromeOptions; +import org.openqa.selenium.remote.RemoteWebDriver; +import org.testcontainers.containers.BrowserWebDriverContainer; + +class SeleniumDocumentLoaderTestIT { + + static SeleniumDocumentLoader loader; + + DocumentParser parser = new TextDocumentParser(); + HtmlTextExtractor extractor = new HtmlTextExtractor(); + + @BeforeAll + static void beforeAll() { + BrowserWebDriverContainer chromeContainer = new BrowserWebDriverContainer<>() + .withCapabilities(new ChromeOptions()); + chromeContainer.start(); + RemoteWebDriver webDriver = new RemoteWebDriver(chromeContainer.getSeleniumAddress(), new ChromeOptions()); + loader = SeleniumDocumentLoader.builder() + .webDriver(webDriver) + .timeout(Duration.ofSeconds(30)) + .build(); + } + + @Test + void should_load_html_document() { + String url = + "https://raw.githubusercontent.com/langchain4j/langchain4j/main/langchain4j/src/test/resources/test-file-utf8.txt"; + Document document = loader.load(url, parser); + + Document 
textDocument = extractor.transform(document); + + assertThat(textDocument.text()).isEqualTo("test content"); + assertThat(document.text()).contains("test\ncontent"); + assertThat(document.metadata(Document.URL)).isEqualTo(url); + } + + @Test + void should_fail_for_unresolvable_url() { + String url = + "https://a.a"; + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> loader.load(url, parser)); + } + + @Test + void should_fail_for_bad_url() { + String url = + "bad_url"; + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> loader.load(url, parser)); + } +} \ No newline at end of file diff --git a/document-loaders/langchain4j-document-loader-tencent-cos/pom.xml b/document-loaders/langchain4j-document-loader-tencent-cos/pom.xml index 0291265e6a..a800c6f43b 100644 --- a/document-loaders/langchain4j-document-loader-tencent-cos/pom.xml +++ b/document-loaders/langchain4j-document-loader-tencent-cos/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../../langchain4j-parent/pom.xml diff --git a/document-loaders/langchain4j-document-loader-tencent-cos/src/test/java/dev/langchain4j/data/document/loader/tencent/cos/TencentCosDocumentLoaderIT.java b/document-loaders/langchain4j-document-loader-tencent-cos/src/test/java/dev/langchain4j/data/document/loader/tencent/cos/TencentCosDocumentLoaderIT.java index 35df66c93a..9491f80e1e 100644 --- a/document-loaders/langchain4j-document-loader-tencent-cos/src/test/java/dev/langchain4j/data/document/loader/tencent/cos/TencentCosDocumentLoaderIT.java +++ b/document-loaders/langchain4j-document-loader-tencent-cos/src/test/java/dev/langchain4j/data/document/loader/tencent/cos/TencentCosDocumentLoaderIT.java @@ -63,8 +63,8 @@ void should_load_single_document() { // then assertThat(document.text()).isEqualTo(TEST_CONTENT); - assertThat(document.metadata().asMap()).hasSize(1); - assertThat(document.metadata("source")).isEqualTo(String.format("cos://%s/%s", TEST_BUCKET, TEST_KEY)); + assertThat(document.metadata().toMap()).hasSize(1); + assertThat(document.metadata().getString("source")).isEqualTo(String.format("cos://%s/%s", TEST_BUCKET, TEST_KEY)); } @Test @@ -88,11 +88,11 @@ void should_load_multiple_documents() { assertThat(documents).hasSize(2); assertThat(documents.get(0).text()).isEqualTo(TEST_CONTENT_2); - assertThat(documents.get(0).metadata().asMap()).hasSize(1); - assertThat(documents.get(0).metadata("source")).isEqualTo(String.format("cos://%s/%s", TEST_BUCKET, TEST_KEY_2)); + assertThat(documents.get(0).metadata().toMap()).hasSize(1); + assertThat(documents.get(0).metadata().getString("source")).isEqualTo(String.format("cos://%s/%s", TEST_BUCKET, TEST_KEY_2)); assertThat(documents.get(1).text()).isEqualTo(TEST_CONTENT); - assertThat(documents.get(1).metadata().asMap()).hasSize(1); + assertThat(documents.get(1).metadata().toMap()).hasSize(1); assertThat(documents.get(1).metadata("source")).isEqualTo(String.format("cos://%s/%s", TEST_BUCKET, TEST_KEY)); } diff --git a/document-parsers/langchain4j-document-parser-apache-pdfbox/pom.xml b/document-parsers/langchain4j-document-parser-apache-pdfbox/pom.xml index 1102b13878..d155c76c43 100644 --- a/document-parsers/langchain4j-document-parser-apache-pdfbox/pom.xml +++ b/document-parsers/langchain4j-document-parser-apache-pdfbox/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../../langchain4j-parent/pom.xml diff --git 
a/document-parsers/langchain4j-document-parser-apache-pdfbox/src/test/java/dev/langchain4j/data/document/parser/apache/pdfbox/ApachePdfBoxDocumentParserTest.java b/document-parsers/langchain4j-document-parser-apache-pdfbox/src/test/java/dev/langchain4j/data/document/parser/apache/pdfbox/ApachePdfBoxDocumentParserTest.java index 051d54b2a8..c5e1c4c31e 100644 --- a/document-parsers/langchain4j-document-parser-apache-pdfbox/src/test/java/dev/langchain4j/data/document/parser/apache/pdfbox/ApachePdfBoxDocumentParserTest.java +++ b/document-parsers/langchain4j-document-parser-apache-pdfbox/src/test/java/dev/langchain4j/data/document/parser/apache/pdfbox/ApachePdfBoxDocumentParserTest.java @@ -21,7 +21,7 @@ void should_parse_pdf_file() { Document document = parser.parse(inputStream); assertThat(document.text()).isEqualToIgnoringWhitespace("test content"); - assertThat(document.metadata().asMap()).isEmpty(); + assertThat(document.metadata().toMap()).isEmpty(); } @Test diff --git a/document-parsers/langchain4j-document-parser-apache-poi/pom.xml b/document-parsers/langchain4j-document-parser-apache-poi/pom.xml index 153e9b20a0..088e5751ff 100644 --- a/document-parsers/langchain4j-document-parser-apache-poi/pom.xml +++ b/document-parsers/langchain4j-document-parser-apache-poi/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../../langchain4j-parent/pom.xml diff --git a/document-parsers/langchain4j-document-parser-apache-poi/src/test/java/dev/langchain4j/data/document/parser/apache/poi/ApachePoiDocumentParserTest.java b/document-parsers/langchain4j-document-parser-apache-poi/src/test/java/dev/langchain4j/data/document/parser/apache/poi/ApachePoiDocumentParserTest.java index 0f5639e661..166322bbc7 100644 --- a/document-parsers/langchain4j-document-parser-apache-poi/src/test/java/dev/langchain4j/data/document/parser/apache/poi/ApachePoiDocumentParserTest.java +++ b/document-parsers/langchain4j-document-parser-apache-poi/src/test/java/dev/langchain4j/data/document/parser/apache/poi/ApachePoiDocumentParserTest.java @@ -28,7 +28,7 @@ void should_parse_doc_and_ppt_files(String fileName) { Document document = parser.parse(inputStream); assertThat(document.text()).isEqualToIgnoringWhitespace("test content"); - assertThat(document.metadata().asMap()).isEmpty(); + assertThat(document.metadata().toMap()).isEmpty(); } @ParameterizedTest @@ -45,7 +45,7 @@ void should_parse_xls_files(String fileName) { assertThat(document.text()) .isEqualToIgnoringWhitespace("Sheet1\ntest content\nSheet2\ntest content"); - assertThat(document.metadata().asMap()).isEmpty(); + assertThat(document.metadata().toMap()).isEmpty(); } @ParameterizedTest diff --git a/document-parsers/langchain4j-document-parser-apache-tika/pom.xml b/document-parsers/langchain4j-document-parser-apache-tika/pom.xml index bb446cb39e..15e17487d2 100644 --- a/document-parsers/langchain4j-document-parser-apache-tika/pom.xml +++ b/document-parsers/langchain4j-document-parser-apache-tika/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../../langchain4j-parent/pom.xml diff --git a/document-parsers/langchain4j-document-parser-apache-tika/src/main/java/dev/langchain4j/data/document/parser/apache/tika/ApacheTikaDocumentParser.java b/document-parsers/langchain4j-document-parser-apache-tika/src/main/java/dev/langchain4j/data/document/parser/apache/tika/ApacheTikaDocumentParser.java index d02bfffb4f..f58dbfb498 100644 --- 
a/document-parsers/langchain4j-document-parser-apache-tika/src/main/java/dev/langchain4j/data/document/parser/apache/tika/ApacheTikaDocumentParser.java +++ b/document-parsers/langchain4j-document-parser-apache-tika/src/main/java/dev/langchain4j/data/document/parser/apache/tika/ApacheTikaDocumentParser.java @@ -1,7 +1,7 @@ package dev.langchain4j.data.document.parser.apache.tika; -import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.BlankDocumentException; +import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.DocumentParser; import org.apache.tika.exception.ZeroByteFileException; import org.apache.tika.metadata.Metadata; @@ -12,6 +12,7 @@ import org.xml.sax.ContentHandler; import java.io.InputStream; +import java.util.function.Supplier; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.Utils.isNullOrBlank; @@ -25,11 +26,15 @@ public class ApacheTikaDocumentParser implements DocumentParser { private static final int NO_WRITE_LIMIT = -1; + public static final Supplier DEFAULT_PARSER_SUPPLIER = AutoDetectParser::new; + public static final Supplier DEFAULT_METADATA_SUPPLIER = Metadata::new; + public static final Supplier DEFAULT_PARSE_CONTEXT_SUPPLIER = ParseContext::new; + public static final Supplier DEFAULT_CONTENT_HANDLER_SUPPLIER = () -> new BodyContentHandler(NO_WRITE_LIMIT); - private final Parser parser; - private final ContentHandler contentHandler; - private final Metadata metadata; - private final ParseContext parseContext; + private final Supplier parserSupplier; + private final Supplier contentHandlerSupplier; + private final Supplier metadataSupplier; + private final Supplier parseContextSupplier; /** * Creates an instance of an {@code ApacheTikaDocumentParser} with the default Tika components. @@ -37,7 +42,7 @@ public class ApacheTikaDocumentParser implements DocumentParser { * empty {@link Metadata} and empty {@link ParseContext}. */ public ApacheTikaDocumentParser() { - this(null, null, null, null); + this((Supplier) null, null, null, null); } /** @@ -48,15 +53,38 @@ public ApacheTikaDocumentParser() { * @param contentHandler Tika content handler. Default: {@link BodyContentHandler} without write limit * @param metadata Tika metadata. Default: empty {@link Metadata} * @param parseContext Tika parse context. Default: empty {@link ParseContext} + * @deprecated Use the constructor with suppliers for Tika components if you intend to use this parser for multiple files. */ + @Deprecated public ApacheTikaDocumentParser(Parser parser, ContentHandler contentHandler, Metadata metadata, ParseContext parseContext) { - this.parser = getOrDefault(parser, AutoDetectParser::new); - this.contentHandler = getOrDefault(contentHandler, () -> new BodyContentHandler(NO_WRITE_LIMIT)); - this.metadata = getOrDefault(metadata, Metadata::new); - this.parseContext = getOrDefault(parseContext, ParseContext::new); + this( + () -> getOrDefault(parser, DEFAULT_PARSER_SUPPLIER), + () -> getOrDefault(contentHandler, DEFAULT_CONTENT_HANDLER_SUPPLIER), + () -> getOrDefault(metadata, DEFAULT_METADATA_SUPPLIER), + () -> getOrDefault(parseContext, DEFAULT_PARSE_CONTEXT_SUPPLIER) + ); + } + + /** + * Creates an instance of an {@code ApacheTikaDocumentParser} with the provided suppliers for Tika components. + * If some of the suppliers are not provided ({@code null}), the defaults will be used. + * + * @param parserSupplier Supplier for Tika parser to use. 
Default: {@link AutoDetectParser} + * @param contentHandlerSupplier Supplier for Tika content handler. Default: {@link BodyContentHandler} without write limit + * @param metadataSupplier Supplier for Tika metadata. Default: empty {@link Metadata} + * @param parseContextSupplier Supplier for Tika parse context. Default: empty {@link ParseContext} + */ + public ApacheTikaDocumentParser(Supplier parserSupplier, + Supplier contentHandlerSupplier, + Supplier metadataSupplier, + Supplier parseContextSupplier) { + this.parserSupplier = getOrDefault(parserSupplier, () -> DEFAULT_PARSER_SUPPLIER); + this.contentHandlerSupplier = getOrDefault(contentHandlerSupplier, () -> DEFAULT_CONTENT_HANDLER_SUPPLIER); + this.metadataSupplier = getOrDefault(metadataSupplier, () -> DEFAULT_METADATA_SUPPLIER); + this.parseContextSupplier = getOrDefault(parseContextSupplier, () -> DEFAULT_PARSE_CONTEXT_SUPPLIER); } // TODO allow automatically extract metadata (e.g. creator, last-author, created/modified timestamp, etc) @@ -64,6 +92,11 @@ public ApacheTikaDocumentParser(Parser parser, @Override public Document parse(InputStream inputStream) { try { + Parser parser = parserSupplier.get(); + ContentHandler contentHandler = contentHandlerSupplier.get(); + Metadata metadata = metadataSupplier.get(); + ParseContext parseContext = parseContextSupplier.get(); + parser.parse(inputStream, contentHandler, metadata, parseContext); String text = contentHandler.toString(); diff --git a/document-parsers/langchain4j-document-parser-apache-tika/src/test/java/dev/langchain4j/data/document/parser/apache/tika/ApacheTikaDocumentParserTest.java b/document-parsers/langchain4j-document-parser-apache-tika/src/test/java/dev/langchain4j/data/document/parser/apache/tika/ApacheTikaDocumentParserTest.java index 653722b3d6..fea01ecce9 100644 --- a/document-parsers/langchain4j-document-parser-apache-tika/src/test/java/dev/langchain4j/data/document/parser/apache/tika/ApacheTikaDocumentParserTest.java +++ b/document-parsers/langchain4j-document-parser-apache-tika/src/test/java/dev/langchain4j/data/document/parser/apache/tika/ApacheTikaDocumentParserTest.java @@ -1,9 +1,10 @@ package dev.langchain4j.data.document.parser.apache.tika; -import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.BlankDocumentException; +import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.DocumentParser; import org.apache.tika.parser.AutoDetectParser; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -30,7 +31,7 @@ void should_parse_doc_ppt_and_pdf_files(String fileName) { Document document = parser.parse(inputStream); assertThat(document.text()).isEqualToIgnoringWhitespace("test content"); - assertThat(document.metadata().asMap()).isEmpty(); + assertThat(document.metadata().toMap()).isEmpty(); } @ParameterizedTest @@ -40,14 +41,32 @@ void should_parse_doc_ppt_and_pdf_files(String fileName) { }) void should_parse_xls_files(String fileName) { - DocumentParser parser = new ApacheTikaDocumentParser(new AutoDetectParser(), null, null, null); + DocumentParser parser = new ApacheTikaDocumentParser(AutoDetectParser::new, null, null, null); InputStream inputStream = getClass().getClassLoader().getResourceAsStream(fileName); Document document = parser.parse(inputStream); assertThat(document.text()) .isEqualToIgnoringWhitespace("Sheet1\ntest content\nSheet2\ntest content"); - assertThat(document.metadata().asMap()).isEmpty(); + 
assertThat(document.metadata().toMap()).isEmpty(); + } + + @Test + void should_parse_files_stateless() { + + DocumentParser parser = new ApacheTikaDocumentParser(); + InputStream inputStream1 = getClass().getClassLoader().getResourceAsStream("test-file.xls"); + InputStream inputStream2 = getClass().getClassLoader().getResourceAsStream("test-file.xls"); + + Document document1 = parser.parse(inputStream1); + Document document2 = parser.parse(inputStream2); + + assertThat(document1.text()) + .isEqualToIgnoringWhitespace("Sheet1\ntest content\nSheet2\ntest content"); + assertThat(document2.text()) + .isEqualToIgnoringWhitespace("Sheet1\ntest content\nSheet2\ntest content"); + assertThat(document1.metadata().toMap()).isEmpty(); + assertThat(document2.metadata().toMap()).isEmpty(); } @ParameterizedTest diff --git a/embedding-store-filter-parsers/langchain4j-embedding-store-filter-parser-sql/pom.xml b/embedding-store-filter-parsers/langchain4j-embedding-store-filter-parser-sql/pom.xml index d9ff96f86a..3e7519808e 100644 --- a/embedding-store-filter-parsers/langchain4j-embedding-store-filter-parser-sql/pom.xml +++ b/embedding-store-filter-parsers/langchain4j-embedding-store-filter-parser-sql/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../../langchain4j-parent/pom.xml diff --git a/experimental/langchain4j-experimental-sql/pom.xml b/experimental/langchain4j-experimental-sql/pom.xml new file mode 100644 index 0000000000..6a0d13cf36 --- /dev/null +++ b/experimental/langchain4j-experimental-sql/pom.xml @@ -0,0 +1,109 @@ + + + 4.0.0 + + dev.langchain4j + langchain4j-parent + 0.32.0-SNAPSHOT + ../../langchain4j-parent/pom.xml + + + langchain4j-experimental-sql + LangChain4j :: Experimental :: SQL + + + + + dev.langchain4j + langchain4j-core + + + + com.github.jsqlparser + jsqlparser + 4.8 + + + + org.projectlombok + lombok + provided + + + + + + dev.langchain4j + langchain4j-open-ai + test + + + dev.langchain4j + langchain4j-mistral-ai + test + + + + org.junit.jupiter + junit-jupiter-engine + test + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.testcontainers + testcontainers + test + + + org.testcontainers + junit-jupiter + test + + + org.testcontainers + postgresql + 1.19.7 + test + + + org.postgresql + postgresql + 42.7.3 + test + + + + org.assertj + assertj-core + test + + + + ch.qos.logback + logback-classic + test + + + + + + + + org.honton.chas + license-maven-plugin + + + true + + + + + + \ No newline at end of file diff --git a/experimental/langchain4j-experimental-sql/src/main/java/dev/langchain4j/experimental/rag/content/retriever/sql/SqlDatabaseContentRetriever.java b/experimental/langchain4j-experimental-sql/src/main/java/dev/langchain4j/experimental/rag/content/retriever/sql/SqlDatabaseContentRetriever.java new file mode 100644 index 0000000000..44b222745c --- /dev/null +++ b/experimental/langchain4j-experimental-sql/src/main/java/dev/langchain4j/experimental/rag/content/retriever/sql/SqlDatabaseContentRetriever.java @@ -0,0 +1,351 @@ +package dev.langchain4j.experimental.rag.content.retriever.sql; + +import dev.langchain4j.Experimental; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.PromptTemplate; +import dev.langchain4j.rag.content.Content; 
+import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.rag.query.Query; +import lombok.Builder; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.statement.select.Select; + +import javax.sql.DataSource; +import java.sql.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; + +/** + * + * WARNING! Although fun and exciting, this class is dangerous to use! Do not ever use this in production! + * The database user must have very limited READ-ONLY permissions! + * Although the generated SQL is somewhat validated (to ensure that the SQL is a SELECT statement) using JSqlParser, + * this class does not guarantee that the SQL will be harmless. Use it at your own risk! + * + *
+ *
+ * Using the {@link DataSource} and the {@link ChatLanguageModel}, this {@link ContentRetriever} + * attempts to generate and execute SQL queries for given natural language queries. + *
+ * Optionally, {@link #sqlDialect}, {@link #databaseStructure}, {@link #promptTemplate}, and {@link #maxRetries} can be specified + * to customize the behavior. See the javadoc of the constructor for more details. + * Most methods can be overridden to customize the behavior further. + *
+ * The default prompt template is not highly optimized, + * so it is advised to experiment with it and see what works best for your use case. + */ +@Experimental +public class SqlDatabaseContentRetriever implements ContentRetriever { + + private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from( + "You are an expert in writing SQL queries.\n" + + "You have access to a {{sqlDialect}} database with the following structure:\n" + + "{{databaseStructure}}\n" + + "If a user asks a question that can be answered by querying this database, generate an SQL SELECT query.\n" + + "Do not output anything else aside from a valid SQL statement!" + ); + + private final DataSource dataSource; + private final String sqlDialect; + private final String databaseStructure; + + private final PromptTemplate promptTemplate; + private final ChatLanguageModel chatLanguageModel; + + private final int maxRetries; + + /** + * Creates an instance of a {@code SqlDatabaseContentRetriever}. + * + * @param dataSource The {@link DataSource} to be used for executing SQL queries. + * This is a mandatory parameter. + * WARNING! The database user must have very limited READ-ONLY permissions! + * @param sqlDialect The SQL dialect, which will be provided to the LLM in the {@link SystemMessage}. + * The LLM should know the specific SQL dialect in order to generate valid SQL queries. + * Example: "MySQL", "PostgreSQL", etc. + * This is an optional parameter. If not specified, it will be determined from the {@code DataSource}. + * @param databaseStructure The structure of the database, which will be provided to the LLM in the {@code SystemMessage}. + * The LLM should be familiar with available tables, columns, relationships, etc. in order to generate valid SQL queries. + * It is best to specify the complete "CREATE TABLE ..." DDL statement for each table. + * Example (shortened): "CREATE TABLE customers(\n id INT PRIMARY KEY,\n name VARCHAR(50), ...);\n CREATE TABLE products(...);\n ..." + * This is an optional parameter. If not specified, it will be generated from the {@code DataSource}. + * WARNING! In this case, all tables will be visible to the LLM! + * @param promptTemplate The {@link PromptTemplate} to be used for creating a {@code SystemMessage}. + * This is an optional parameter. Default: {@link #DEFAULT_PROMPT_TEMPLATE}. + * @param chatLanguageModel The {@link ChatLanguageModel} to be used for generating SQL queries. + * This is a mandatory parameter. + * @param maxRetries The maximum number of retries to perform if the database cannot execute the generated SQL query. + * An error message will be sent back to the LLM to try correcting the query. + * This is an optional parameter. Default: 1. 
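The parameters documented above map directly onto the Lombok-generated builder. A hedged usage sketch follows; the data source settings and the OpenAI model are illustrative assumptions that mirror the integration test later in this patch, and the database user must be read-only as the warning above insists.

// Illustrative sketch, not part of this patch. Connection details and model choice are assumptions.
import dev.langchain4j.experimental.rag.content.retriever.sql.SqlDatabaseContentRetriever;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import java.util.List;
import org.postgresql.ds.PGSimpleDataSource;

public class SqlRetrieverSketch {

    public static void main(String[] args) {
        // Assumption: a local PostgreSQL database reachable with a READ-ONLY user.
        PGSimpleDataSource dataSource = new PGSimpleDataSource();
        dataSource.setUrl("jdbc:postgresql://localhost:5432/shop");
        dataSource.setUser("readonly_user");
        dataSource.setPassword("secret");

        ChatLanguageModel model = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .build();

        ContentRetriever retriever = SqlDatabaseContentRetriever.builder()
                .dataSource(dataSource)       // mandatory
                .chatLanguageModel(model)     // mandatory
                // sqlDialect, databaseStructure, promptTemplate and maxRetries are optional;
                // when omitted, dialect and DDL are derived from the DataSource as documented above.
                .build();

        List<Content> contents = retriever.retrieve(Query.from("How many customers do we have?"));
        contents.forEach(content -> System.out.println(content.textSegment().text()));
    }
}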
+ */ + @Builder + @Experimental + public SqlDatabaseContentRetriever(DataSource dataSource, + String sqlDialect, + String databaseStructure, + PromptTemplate promptTemplate, + ChatLanguageModel chatLanguageModel, + Integer maxRetries) { + this.dataSource = ensureNotNull(dataSource, "dataSource"); + this.sqlDialect = getOrDefault(sqlDialect, () -> getSqlDialect(dataSource)); + this.databaseStructure = getOrDefault(databaseStructure, () -> generateDDL(dataSource)); + this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE); + this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel"); + this.maxRetries = getOrDefault(maxRetries, 1); + } + + // TODO (for v2) + // - provide a few rows of data for each table in the prompt + // - option to select a list of tables to use/ignore + + public static String getSqlDialect(DataSource dataSource) { + try (Connection connection = dataSource.getConnection()) { + DatabaseMetaData metaData = connection.getMetaData(); + return metaData.getDatabaseProductName(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + private static String generateDDL(DataSource dataSource) { + StringBuilder ddl = new StringBuilder(); + + try (Connection connection = dataSource.getConnection()) { + DatabaseMetaData metaData = connection.getMetaData(); + + ResultSet tables = metaData.getTables(null, null, "%", new String[]{"TABLE"}); + + while (tables.next()) { + String tableName = tables.getString("TABLE_NAME"); + String createTableStatement = generateCreateTableStatement(tableName, metaData); + ddl.append(createTableStatement).append("\n"); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } + + return ddl.toString(); + } + + private static String generateCreateTableStatement(String tableName, DatabaseMetaData metaData) { + StringBuilder createTableStatement = new StringBuilder(); + + try { + ResultSet columns = metaData.getColumns(null, null, tableName, null); + ResultSet pk = metaData.getPrimaryKeys(null, null, tableName); + ResultSet fks = metaData.getImportedKeys(null, null, tableName); + + String primaryKeyColumn = ""; + if (pk.next()) { + primaryKeyColumn = pk.getString("COLUMN_NAME"); + } + + createTableStatement + .append("CREATE TABLE ") + .append(tableName) + .append(" (\n"); + + while (columns.next()) { + String columnName = columns.getString("COLUMN_NAME"); + String columnType = columns.getString("TYPE_NAME"); + int size = columns.getInt("COLUMN_SIZE"); + String nullable = columns.getString("IS_NULLABLE").equals("YES") ? " NULL" : " NOT NULL"; + String columnDef = columns.getString("COLUMN_DEF") != null ? 
" DEFAULT " + columns.getString("COLUMN_DEF") : ""; + String comment = columns.getString("REMARKS"); + + createTableStatement + .append(" ") + .append(columnName) + .append(" ") + .append(columnType) + .append("(") + .append(size) + .append(")") + .append(nullable) + .append(columnDef); + + if (columnName.equals(primaryKeyColumn)) { + createTableStatement.append(" PRIMARY KEY"); + } + + createTableStatement.append(",\n"); + + if (comment != null && !comment.isEmpty()) { + createTableStatement + .append(" COMMENT ON COLUMN ") + .append(tableName) + .append(".") + .append(columnName) + .append(" IS '") + .append(comment) + .append("',\n"); + } + } + + while (fks.next()) { + String fkColumnName = fks.getString("FKCOLUMN_NAME"); + String pkTableName = fks.getString("PKTABLE_NAME"); + String pkColumnName = fks.getString("PKCOLUMN_NAME"); + createTableStatement + .append(" FOREIGN KEY (") + .append(fkColumnName) + .append(") REFERENCES ") + .append(pkTableName) + .append("(") + .append(pkColumnName) + .append("),\n"); + } + + if (createTableStatement.charAt(createTableStatement.length() - 2) == ',') { + createTableStatement.delete(createTableStatement.length() - 2, createTableStatement.length()); + } + + createTableStatement.append(");\n"); + + ResultSet tableRemarks = metaData.getTables(null, null, tableName, null); + if (tableRemarks.next()) { + String tableComment = tableRemarks.getString("REMARKS"); + if (tableComment != null && !tableComment.isEmpty()) { + createTableStatement + .append("COMMENT ON TABLE ") + .append(tableName) + .append(" IS '") + .append(tableComment) + .append("';\n"); + } + } + } catch (SQLException e) { + throw new RuntimeException(e); + } + + return createTableStatement.toString(); + } + + @Override + public List retrieve(Query naturalLanguageQuery) { + + String sqlQuery = null; + String errorMessage = null; + + int attemptsLeft = maxRetries + 1; + while (attemptsLeft > 0) { + attemptsLeft--; + + sqlQuery = generateSqlQuery(naturalLanguageQuery, sqlQuery, errorMessage); + + sqlQuery = clean(sqlQuery); + + if (!isSelect(sqlQuery)) { + return emptyList(); + } + + try { + validate(sqlQuery); + + try (Connection connection = dataSource.getConnection(); + Statement statement = connection.createStatement()) { + + String result = execute(sqlQuery, statement); + Content content = format(result, sqlQuery); + return singletonList(content); + } + } catch (Exception e) { + errorMessage = e.getMessage(); + } + } + + return emptyList(); + } + + protected String generateSqlQuery(Query naturalLanguageQuery, String previousSqlQuery, String previousErrorMessage) { + + List messages = new ArrayList<>(); + messages.add(createSystemPrompt().toSystemMessage()); + messages.add(UserMessage.from(naturalLanguageQuery.text())); + + if (previousSqlQuery != null && previousErrorMessage != null) { + messages.add(AiMessage.from(previousSqlQuery)); + messages.add(UserMessage.from(previousErrorMessage)); + } + + return chatLanguageModel.generate(messages).content().text(); + } + + protected Prompt createSystemPrompt() { + Map variables = new HashMap<>(); + variables.put("sqlDialect", sqlDialect); + variables.put("databaseStructure", databaseStructure); + return promptTemplate.apply(variables); + } + + protected String clean(String sqlQuery) { + if (sqlQuery.contains("```sql")) { + return sqlQuery.substring(sqlQuery.indexOf("```sql") + 6, sqlQuery.lastIndexOf("```")); + } else if (sqlQuery.contains("```")) { + return sqlQuery.substring(sqlQuery.indexOf("```") + 3, sqlQuery.lastIndexOf("```")); + } + 
return sqlQuery; + } + + protected void validate(String sqlQuery) { + + } + + protected boolean isSelect(String sqlQuery) { + try { + net.sf.jsqlparser.statement.Statement statement = CCJSqlParserUtil.parse(sqlQuery); + return statement instanceof Select; + } catch (JSQLParserException e) { + return false; + } + } + + protected String execute(String sqlQuery, Statement statement) throws SQLException { + List resultRows = new ArrayList<>(); + + try (ResultSet resultSet = statement.executeQuery(sqlQuery)) { + int columnCount = resultSet.getMetaData().getColumnCount(); + + // header + List columnNames = new ArrayList<>(); + for (int i = 1; i <= columnCount; i++) { + columnNames.add(resultSet.getMetaData().getColumnName(i)); + } + resultRows.add(String.join(",", columnNames)); + + // rows + while (resultSet.next()) { + List columnValues = new ArrayList<>(); + for (int i = 1; i <= columnCount; i++) { + + String columnValue = resultSet.getObject(i)==null?"":resultSet.getObject(i).toString(); + + if (columnValue.contains(",")) { + columnValue = "\"" + columnValue + "\""; + } + columnValues.add(columnValue); + } + resultRows.add(String.join(",", columnValues)); + } + } + + return String.join("\n", resultRows); + } + + private static Content format(String result, String sqlQuery) { + return Content.from(String.format("Result of executing '%s':\n%s", sqlQuery, result)); + } +} diff --git a/experimental/langchain4j-experimental-sql/src/test/java/dev/langchain4j/experimental/rag/content/retriever/sql/SqlDatabaseContentRetrieverIT.java b/experimental/langchain4j-experimental-sql/src/test/java/dev/langchain4j/experimental/rag/content/retriever/sql/SqlDatabaseContentRetrieverIT.java new file mode 100644 index 0000000000..a414bbf3b2 --- /dev/null +++ b/experimental/langchain4j-experimental-sql/src/test/java/dev/langchain4j/experimental/rag/content/retriever/sql/SqlDatabaseContentRetrieverIT.java @@ -0,0 +1,326 @@ +package dev.langchain4j.experimental.rag.content.retriever.sql; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.mistralai.MistralAiChatModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.rag.query.Query; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.postgresql.ds.PGSimpleDataSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import javax.sql.DataSource; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.sql.*; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +@Testcontainers +class SqlDatabaseContentRetrieverIT { + + static ChatLanguageModel openAiChatModel = OpenAiChatModel.builder() + .baseUrl(System.getenv("OPENAI_BASE_URL")) + .apiKey(System.getenv("OPENAI_API_KEY")) + .organizationId(System.getenv("OPENAI_ORGANIZATION_ID")) + .logRequests(true) + .logResponses(true) + .build(); + + static 
ChatLanguageModel mistralAiChatModel = MistralAiChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .logRequests(true) + .logResponses(true) + .build(); + + @Container + PostgreSQLContainer postgres = new PostgreSQLContainer<>(DockerImageName.parse("postgres:12.18")); + + DataSource dataSource; + + @BeforeEach + void beforeEach() { + dataSource = createDataSource(); + + String createTablesScript = read("sql/create_tables.sql"); + execute(createTablesScript, dataSource); + + String prefillTablesScript = read("sql/prefill_tables.sql"); + execute(prefillTablesScript, dataSource); + } + + private PGSimpleDataSource createDataSource() { + PGSimpleDataSource dataSource = new PGSimpleDataSource(); + dataSource.setUrl(postgres.getJdbcUrl()); + dataSource.setUser(postgres.getUsername()); + dataSource.setPassword(postgres.getPassword()); + return dataSource; + } + + @AfterEach + void afterEach() { + execute("DROP TABLE orders;", dataSource); + execute("DROP TABLE products;", dataSource); + execute("DROP TABLE customers;", dataSource); + } + + @ParameterizedTest + @MethodSource("contentRetrieverProviders") + void should_answer_query_1(Function contentRetrieverProvider) { + + // given + ContentRetriever contentRetriever = contentRetrieverProvider.apply(dataSource); + + // when + List retrieved = contentRetriever.retrieve(Query.from("How many customers do we have?")); + + // then + assertThat(retrieved).hasSize(1); + + assertThat(retrieved.get(0).textSegment().text()) + .contains("SELECT") + .contains("5"); + } + + @ParameterizedTest + @MethodSource("contentRetrieverProviders") + void should_answer_query_2(Function contentRetrieverProvider) { + + // given + ContentRetriever contentRetriever = contentRetrieverProvider.apply(dataSource); + + // when + List retrieved = contentRetriever.retrieve(Query.from("What is the total sales in dollars for each product?")); + + // then + assertThat(retrieved).hasSize(1); + + assertThat(retrieved.get(0).textSegment().text()) + .contains("SELECT") + .contains("99.98", "71.97", "64.95", "22.50", "23.97"); + } + + @ParameterizedTest + @MethodSource("contentRetrieverProviders") + void should_answer_query_3(Function contentRetrieverProvider) { + + // given + ContentRetriever contentRetriever = contentRetrieverProvider.apply(dataSource); + + // when + List retrieved = contentRetriever.retrieve(Query.from("Which quarters show the highest sales?")); + + // then + assertThat(retrieved).hasSize(1); + + assertThat(retrieved.get(0).textSegment().text()) + .contains("SELECT") + .containsAnyOf("2,283.37", "2.0,283.37"); + } + + @ParameterizedTest + @MethodSource("contentRetrieverProviders") + void should_answer_query_4(Function contentRetrieverProvider) { + + // given + ContentRetriever contentRetriever = contentRetrieverProvider.apply(dataSource); + + // when + List retrieved = contentRetriever.retrieve(Query.from("Who is our top customer by total spend?")); + + // then + assertThat(retrieved).hasSize(1); + + assertThat(retrieved.get(0).textSegment().text()) + .contains("SELECT") + .contains("Carol") + .doesNotContain("John", "Jane", "Alice", "Bob"); + } + + @ParameterizedTest + @MethodSource("contentRetrieverProviders") + void should_not_fail_for_unrelated_query(Function contentRetrieverProvider) { + + // given + ContentRetriever contentRetriever = contentRetrieverProvider.apply(dataSource); + + // when-then + assertThatCode(() -> contentRetriever.retrieve(Query.from("hello"))) + .doesNotThrowAnyException(); + } + + @ParameterizedTest + 
@MethodSource("contentRetrieverProviders") + void should_not_drop_table(Function contentRetrieverProvider) { + + // given + ContentRetriever contentRetriever = contentRetrieverProvider.apply(dataSource); + + long customersHash = getTableHash(dataSource, "customers"); + long productsHash = getTableHash(dataSource, "products"); + long ordersHash = getTableHash(dataSource, "orders"); + + // when + List retrieved = contentRetriever.retrieve(Query.from("Drop table with orders")); + + // then + assertThat(retrieved).isEmpty(); + + assertThat(getTableHash(dataSource, "customers")).isEqualTo(customersHash); + assertThat(getTableHash(dataSource, "products")).isEqualTo(productsHash); + assertThat(getTableHash(dataSource, "orders")).isEqualTo(ordersHash); + } + + @ParameterizedTest + @MethodSource("contentRetrieverProviders") + void should_not_delete_existing_rows(Function contentRetrieverProvider) { + + // given + ContentRetriever contentRetriever = contentRetrieverProvider.apply(dataSource); + + long customersHash = getTableHash(dataSource, "customers"); + long productsHash = getTableHash(dataSource, "products"); + long ordersHash = getTableHash(dataSource, "orders"); + + // when + List retrieved = contentRetriever.retrieve(Query.from("Delete customer with ID=1")); + + // then + assertThat(retrieved).isEmpty(); + + assertThat(getTableHash(dataSource, "customers")).isEqualTo(customersHash); + assertThat(getTableHash(dataSource, "products")).isEqualTo(productsHash); + assertThat(getTableHash(dataSource, "orders")).isEqualTo(ordersHash); + } + + @ParameterizedTest + @MethodSource("contentRetrieverProviders") + void should_not_insert_new_rows(Function contentRetrieverProvider) { + + // given + ContentRetriever contentRetriever = contentRetrieverProvider.apply(dataSource); + + long customersHash = getTableHash(dataSource, "customers"); + long productsHash = getTableHash(dataSource, "products"); + long ordersHash = getTableHash(dataSource, "orders"); + + // when + List retrieved = contentRetriever.retrieve(Query.from("Insert new customer James Bond with ID=7")); + + // then + assertThat(retrieved).isEmpty(); + + assertThat(getTableHash(dataSource, "customers")).isEqualTo(customersHash); + assertThat(getTableHash(dataSource, "products")).isEqualTo(productsHash); + assertThat(getTableHash(dataSource, "orders")).isEqualTo(ordersHash); + } + + @ParameterizedTest + @MethodSource("contentRetrieverProviders") + void should_not_update_existing_rows(Function contentRetrieverProvider) { + + // given + ContentRetriever contentRetriever = contentRetrieverProvider.apply(dataSource); + + long customersHash = getTableHash(dataSource, "customers"); + long productsHash = getTableHash(dataSource, "products"); + long ordersHash = getTableHash(dataSource, "orders"); + + // when + List retrieved = contentRetriever.retrieve(Query.from("Update email of customer with ID=1 to bad@guy.com")); + + // then + assertThat(retrieved).isEmpty(); + + assertThat(getTableHash(dataSource, "customers")).isEqualTo(customersHash); + assertThat(getTableHash(dataSource, "products")).isEqualTo(productsHash); + assertThat(getTableHash(dataSource, "orders")).isEqualTo(ordersHash); + } + + private static void execute(String sql, DataSource dataSource) { + try (Connection connection = dataSource.getConnection(); Statement statement = connection.createStatement()) { + for (String sqlStatement : sql.split(";")) { + statement.execute(sqlStatement.trim()); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + private static long 
getTableHash(DataSource dataSource, String tableName) { + String query = "SELECT * FROM " + tableName; + try (Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery(query)) { + StringBuilder dataBuilder = new StringBuilder(); + ResultSetMetaData metaData = rs.getMetaData(); + while (rs.next()) { + for (int i = 1; i <= metaData.getColumnCount(); i++) { + dataBuilder.append(rs.getString(i)); + } + } + return dataBuilder.toString().hashCode(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + static Stream> contentRetrieverProviders() { + return Stream.of( + + // OpenAI + dataSource -> SqlDatabaseContentRetriever.builder() + .dataSource(dataSource) + .sqlDialect("PostgreSQL") + .databaseStructure(read("sql/create_tables.sql")) + .chatLanguageModel(openAiChatModel) + .build(), + dataSource -> SqlDatabaseContentRetriever.builder() + .dataSource(dataSource) + .chatLanguageModel(openAiChatModel) + .build(), + + // Mistral + dataSource -> SqlDatabaseContentRetriever.builder() + .dataSource(dataSource) + .sqlDialect("PostgreSQL") + .databaseStructure(read("sql/create_tables.sql")) + .chatLanguageModel(mistralAiChatModel) + .build(), + dataSource -> SqlDatabaseContentRetriever.builder() + .dataSource(dataSource) + .chatLanguageModel(mistralAiChatModel) + .build() + ); + } + + private static String read(String path) { + try { + return new String(Files.readAllBytes(toPath(path))); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static Path toPath(String fileName) { + try { + return Paths.get(SqlDatabaseContentRetrieverIT.class.getClassLoader().getResource(fileName).toURI()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } +} \ No newline at end of file diff --git a/experimental/langchain4j-experimental-sql/src/test/resources/logback-test.xml b/experimental/langchain4j-experimental-sql/src/test/resources/logback-test.xml new file mode 100644 index 0000000000..c58c8de0fa --- /dev/null +++ b/experimental/langchain4j-experimental-sql/src/test/resources/logback-test.xml @@ -0,0 +1,21 @@ + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger - %msg%n + + + + + + + + + + + + + + + + diff --git a/experimental/langchain4j-experimental-sql/src/test/resources/sql/create_tables.sql b/experimental/langchain4j-experimental-sql/src/test/resources/sql/create_tables.sql new file mode 100644 index 0000000000..de26b9208f --- /dev/null +++ b/experimental/langchain4j-experimental-sql/src/test/resources/sql/create_tables.sql @@ -0,0 +1,25 @@ +CREATE TABLE customers +( + customer_id INT PRIMARY KEY, + first_name VARCHAR(50), + last_name VARCHAR(50), + email VARCHAR(100) +); + +CREATE TABLE products +( + product_id INT PRIMARY KEY, + product_name VARCHAR(100), + price DECIMAL(10, 2) +); + +CREATE TABLE orders +( + order_id INT PRIMARY KEY, + customer_id INT, + product_id INT, + quantity INT, + order_date DATE, + FOREIGN KEY (customer_id) REFERENCES customers (customer_id), + FOREIGN KEY (product_id) REFERENCES products (product_id) +); diff --git a/experimental/langchain4j-experimental-sql/src/test/resources/sql/prefill_tables.sql b/experimental/langchain4j-experimental-sql/src/test/resources/sql/prefill_tables.sql new file mode 100644 index 0000000000..154bc6e66b --- /dev/null +++ b/experimental/langchain4j-experimental-sql/src/test/resources/sql/prefill_tables.sql @@ -0,0 +1,25 @@ +INSERT INTO customers (customer_id, first_name, last_name, email) +VALUES (1, 'John', 'Doe', 
'john.doe@example.com'), + (2, 'Jane', 'Smith', 'jane.smith@example.com'), + (3, 'Alice', 'Johnson', 'alice.johnson@example.com'), + (4, 'Bob', 'Williams', 'bob.williams@example.com'), + (5, 'Carol', 'Brown', 'carol.brown@example.com'); + +INSERT INTO products (product_id, product_name, price) +VALUES (10, 'Notebook', 12.99), + (20, 'Pen', 1.50), + (30, 'Desk Lamp', 23.99), + (40, 'Backpack', 49.99), + (50, 'Stapler', 7.99); + +INSERT INTO orders (order_id, customer_id, product_id, quantity, order_date) +VALUES (100, 1, 10, 2, '2024-04-20'), + (200, 2, 20, 5, '2024-04-21'), + (300, 3, 10, 1, '2024-04-22'), + (400, 4, 30, 1, '2024-04-23'), + (500, 5, 40, 1, '2024-04-24'), + (600, 1, 50, 3, '2024-04-25'), + (700, 2, 10, 2, '2024-04-26'), + (800, 3, 40, 1, '2024-04-27'), + (900, 4, 20, 10, '2024-04-28'), + (10000, 5, 30, 2, '2024-04-29'); diff --git a/langchain4j-anthropic/pom.xml b/langchain4j-anthropic/pom.xml index 8212724e7c..6dbdf93240 100644 --- a/langchain4j-anthropic/pom.xml +++ b/langchain4j-anthropic/pom.xml @@ -5,7 +5,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -30,14 +30,12 @@ com.squareup.retrofit2 - converter-gson - - - - com.google.code.gson - gson - - + converter-jackson + + + + com.fasterxml.jackson.core + jackson-databind diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java index 193f2f5b04..7ef5f13382 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java @@ -3,6 +3,9 @@ import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.image.Image; import dev.langchain4j.data.message.*; +import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest; +import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse; +import dev.langchain4j.model.anthropic.internal.client.AnthropicClient; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.output.Response; import lombok.Builder; @@ -12,9 +15,9 @@ import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307; -import static dev.langchain4j.model.anthropic.AnthropicMapper.*; +import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.*; +import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages; /** * Represents an Anthropic language model with a Messages (chat) API. @@ -31,6 +34,11 @@ *
* The content of {@link SystemMessage}s is sent using the "system" parameter. * If there are multiple {@link SystemMessage}s, they are concatenated with a double newline (\n\n). + *
+ *
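A minimal sketch of how several system messages end up in the "system" parameter (illustrative, not part of the changeset; it relies on the AnthropicMapper.toAnthropicSystemPrompt helper that this diff makes public):

```java
// Sketch: multiple SystemMessages are joined into one "system" value.
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper;

import java.util.List;

import static java.util.Arrays.asList;

class SystemPromptSketch {

    static String systemPrompt() {
        List<ChatMessage> messages = asList(
                SystemMessage.from("You are a helpful assistant."),
                SystemMessage.from("Answer concisely."),
                UserMessage.from("Hello"));

        // Returns "You are a helpful assistant.\n\nAnswer concisely." -
        // the two system messages concatenated with a double newline.
        return AnthropicMapper.toAnthropicSystemPrompt(messages);
    }
}
```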
+ * Sanitization is performed on the {@link ChatMessage}s provided to conform to Anthropic API requirements. This process + * includes verifying that the first message is a {@link UserMessage} and removing any consecutive {@link UserMessage}s. + * Any messages removed during sanitization are logged as warnings and not submitted to the API. */ public class AnthropicChatModel implements ChatLanguageModel { @@ -124,12 +132,13 @@ public Response generate(List messages) { @Override public Response generate(List messages, List toolSpecifications) { - ensureNotEmpty(messages, "messages"); + List sanitizedMessages = sanitizeMessages(messages); + String systemPrompt = toAnthropicSystemPrompt(messages); AnthropicCreateMessageRequest request = AnthropicCreateMessageRequest.builder() .model(modelName) - .messages(toAnthropicMessages(messages)) - .system(toAnthropicSystemPrompt(messages)) + .messages(toAnthropicMessages(sanitizedMessages)) + .system(systemPrompt) .maxTokens(maxTokens) .stopSequences(stopSequences) .stream(false) diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicContent.java deleted file mode 100644 index 59dc693288..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicContent.java +++ /dev/null @@ -1,19 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import lombok.Builder; - -import java.util.Map; - -@Builder -public class AnthropicContent { - - public String type; - - // when type = "text" - public String text; - - // when type = "tool_use" - public String id; - public String name; - public Map input; -} \ No newline at end of file diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicCreateMessageRequest.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicCreateMessageRequest.java deleted file mode 100644 index 3798183d87..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicCreateMessageRequest.java +++ /dev/null @@ -1,26 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder(toBuilder = true) -public class AnthropicCreateMessageRequest { - - String model; - List messages; - String system; - int maxTokens; - List stopSequences; - boolean stream; - Double temperature; - Double topP; - Integer topK; - List tools; -} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicCreateMessageResponse.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicCreateMessageResponse.java deleted file mode 100644 index 802bf250de..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicCreateMessageResponse.java +++ /dev/null @@ -1,15 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import java.util.List; - -public class AnthropicCreateMessageResponse { - - public String id; - public String type; - public String role; - public List content; - public String model; - public String stopReason; - public String stopSequence; - public AnthropicUsage usage; -} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicDelta.java 
b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicDelta.java deleted file mode 100644 index dd398aeea7..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicDelta.java +++ /dev/null @@ -1,12 +0,0 @@ -package dev.langchain4j.model.anthropic; - -public class AnthropicDelta { - - // when AnthropicStreamingData.type = "content_block_delta" - public String type; - public String text; - - // when AnthropicStreamingData.type = "message_delta" - public String stopReason; - public String stopSequence; -} \ No newline at end of file diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicImageContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicImageContent.java deleted file mode 100644 index 23c2c032b3..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicImageContent.java +++ /dev/null @@ -1,16 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import lombok.EqualsAndHashCode; -import lombok.ToString; - -@ToString -@EqualsAndHashCode(callSuper = true) -public class AnthropicImageContent extends AnthropicMessageContent { - - public AnthropicImageContentSource source; - - public AnthropicImageContent(String mediaType, String data) { - super("image"); - this.source = new AnthropicImageContentSource("base64", mediaType, data); - } -} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicImageContentSource.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicImageContentSource.java deleted file mode 100644 index a291ff4b23..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicImageContentSource.java +++ /dev/null @@ -1,11 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import lombok.AllArgsConstructor; - -@AllArgsConstructor -public class AnthropicImageContentSource { - - public String type; - public String mediaType; - public String data; -} \ No newline at end of file diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessage.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessage.java deleted file mode 100644 index 5d8ea9243f..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessage.java +++ /dev/null @@ -1,20 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; - -import java.util.List; - -@Builder -@ToString -@EqualsAndHashCode -@AllArgsConstructor -@Getter -public class AnthropicMessage { - - AnthropicRole role; - List content; -} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessageContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessageContent.java deleted file mode 100644 index 1cae26cc51..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessageContent.java +++ /dev/null @@ -1,13 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import lombok.EqualsAndHashCode; - -@EqualsAndHashCode -public abstract class AnthropicMessageContent { - - public String type; - - public AnthropicMessageContent(String type) { - this.type = type; - } -} diff --git 
a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicResponseMessage.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicResponseMessage.java deleted file mode 100644 index b550d7ae55..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicResponseMessage.java +++ /dev/null @@ -1,15 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import java.util.List; - -public class AnthropicResponseMessage { - - public String id; - public String type; - public String role; - public List content; - public String model; - public String stopReason; - public String stopSequence; - public AnthropicUsage usage; -} \ No newline at end of file diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicRole.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicRole.java deleted file mode 100644 index 326dea0bd2..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicRole.java +++ /dev/null @@ -1,9 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import com.google.gson.annotations.SerializedName; - -public enum AnthropicRole { - - @SerializedName("user") USER, - @SerializedName("assistant") ASSISTANT -} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModel.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModel.java index 6e44231fcd..2efcc16507 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModel.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModel.java @@ -3,6 +3,8 @@ import dev.langchain4j.data.image.Image; import dev.langchain4j.data.message.*; import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest; +import dev.langchain4j.model.anthropic.internal.client.AnthropicClient; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import lombok.Builder; @@ -13,8 +15,9 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307; -import static dev.langchain4j.model.anthropic.AnthropicMapper.toAnthropicMessages; -import static dev.langchain4j.model.anthropic.AnthropicMapper.toAnthropicSystemPrompt; +import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicMessages; +import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicSystemPrompt; +import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages; /** * Represents an Anthropic language model with a Messages (chat) API. @@ -32,6 +35,11 @@ * If there are multiple {@link SystemMessage}s, they are concatenated with a double newline (\n\n). *
*
+ * Sanitization is performed on the {@link ChatMessage}s provided to ensure conformity with Anthropic API requirements. + * This includes ensuring the first message is a {@link UserMessage} and that there are no consecutive {@link UserMessage}s. + * Any messages removed during sanitization are logged as warnings and not submitted to the API. + *
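A minimal sketch of the sanitization described above (illustrative, not part of the changeset; it uses the MessageSanitizer introduced later in this diff):

```java
// Sketch: what sanitizeMessages does to a problematic message list.
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer;

import java.util.List;

import static java.util.Arrays.asList;

class SanitizerSketch {

    static List<ChatMessage> sanitize() {
        List<ChatMessage> messages = asList(
                SystemMessage.from("You are terse."), // stripped: system prompts are sent via the "system" parameter instead
                UserMessage.from("Hello"),
                UserMessage.from("Hello again"));     // consecutive UserMessage: dropped and logged as a warning

        // Result contains only UserMessage.from("Hello"):
        // the list starts with a UserMessage and has no consecutive UserMessages.
        return MessageSanitizer.sanitizeMessages(messages);
    }
}
```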
+ *
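And a usage sketch for the streaming model declared below (illustrative; it assumes an ANTHROPIC_API_KEY environment variable and the standard langchain4j StreamingResponseHandler callbacks):

```java
// Sketch: streaming a response with AnthropicStreamingChatModel.
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel;
import dev.langchain4j.model.output.Response;

import java.util.List;

import static java.util.Collections.singletonList;

class StreamingUsageSketch {

    public static void main(String[] args) {
        AnthropicStreamingChatModel model =
                AnthropicStreamingChatModel.withApiKey(System.getenv("ANTHROPIC_API_KEY"));

        List<ChatMessage> messages = singletonList(UserMessage.from("Tell me a joke"));

        model.generate(messages, new StreamingResponseHandler<AiMessage>() {

            @Override
            public void onNext(String token) {
                System.out.print(token); // partial tokens as they arrive
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println(); // the full AiMessage is available via response.content()
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}
```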
* Does not support tools. */ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel { @@ -114,13 +122,14 @@ public static AnthropicStreamingChatModel withApiKey(String apiKey) { @Override public void generate(List messages, StreamingResponseHandler handler) { - ensureNotEmpty(messages, "messages"); + List sanitizedMessages = sanitizeMessages(messages); + String systemPrompt = toAnthropicSystemPrompt(messages); ensureNotNull(handler, "handler"); AnthropicCreateMessageRequest request = AnthropicCreateMessageRequest.builder() .model(modelName) - .messages(toAnthropicMessages(messages)) - .system(toAnthropicSystemPrompt(messages)) + .messages(toAnthropicMessages(sanitizedMessages)) + .system(systemPrompt) .maxTokens(maxTokens) .stopSequences(stopSequences) .stream(true) diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTextContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTextContent.java deleted file mode 100644 index 5bd677d55d..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTextContent.java +++ /dev/null @@ -1,16 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import lombok.EqualsAndHashCode; -import lombok.ToString; - -@ToString -@EqualsAndHashCode(callSuper = true) -public class AnthropicTextContent extends AnthropicMessageContent { - - public String text; - - public AnthropicTextContent(String text) { - super("text"); - this.text = text; - } -} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTool.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTool.java deleted file mode 100644 index edba4da31d..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTool.java +++ /dev/null @@ -1,15 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.ToString; - -@Builder -@ToString -@EqualsAndHashCode -public class AnthropicTool { - - public String name; - public String description; - public AnthropicToolSchema inputSchema; -} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolSchema.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolSchema.java deleted file mode 100644 index f6b762d9af..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolSchema.java +++ /dev/null @@ -1,19 +0,0 @@ -package dev.langchain4j.model.anthropic; - -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.ToString; - -import java.util.List; -import java.util.Map; - -@Builder -@ToString -@EqualsAndHashCode -public class AnthropicToolSchema { - - @Builder.Default - public String type = "object"; - public Map> properties; - public List required; -} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicUsage.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicUsage.java deleted file mode 100644 index c905612f66..0000000000 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicUsage.java +++ /dev/null @@ -1,7 +0,0 @@ -package dev.langchain4j.model.anthropic; - -public class AnthropicUsage { - - public Integer inputTokens; - public Integer outputTokens; -} \ No newline at end of file diff --git 
a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicApi.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicApi.java similarity index 91% rename from langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicApi.java rename to langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicApi.java index 843128a45a..407fc50ed8 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicApi.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicApi.java @@ -1,10 +1,10 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.api; import okhttp3.ResponseBody; import retrofit2.Call; import retrofit2.http.*; -interface AnthropicApi { +public interface AnthropicApi { String X_API_KEY = "x-api-key"; diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicContent.java new file mode 100644 index 0000000000..9511a51631 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicContent.java @@ -0,0 +1,26 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import java.util.Map; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicContent { + + public String type; + + // when type = "text" + public String text; + + // when type = "tool_use" + public String id; + public String name; + public Map input; +} \ No newline at end of file diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCreateMessageRequest.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCreateMessageRequest.java new file mode 100644 index 0000000000..b5aab94d28 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCreateMessageRequest.java @@ -0,0 +1,35 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder(toBuilder = true) +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicCreateMessageRequest { + + public String model; + public List messages; + public String system; + public int maxTokens; + public List stopSequences; + public boolean stream; + public Double temperature; + public Double topP; + public Integer topK; + 
public List tools; +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCreateMessageResponse.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCreateMessageResponse.java new file mode 100644 index 0000000000..ff570b0f14 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCreateMessageResponse.java @@ -0,0 +1,25 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicCreateMessageResponse { + + public String id; + public String type; + public String role; + public List content; + public String model; + public String stopReason; + public String stopSequence; + public AnthropicUsage usage; +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicDelta.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicDelta.java new file mode 100644 index 0000000000..d4d21f7fa8 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicDelta.java @@ -0,0 +1,22 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicDelta { + + // when AnthropicStreamingData.type = "content_block_delta" + public String type; + public String text; + + // when AnthropicStreamingData.type = "message_delta" + public String stopReason; + public String stopSequence; +} \ No newline at end of file diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicImageContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicImageContent.java new file mode 100644 index 0000000000..e943301e09 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicImageContent.java @@ -0,0 +1,25 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.EqualsAndHashCode; +import lombok.ToString; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@ToString +@EqualsAndHashCode(callSuper = true) +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicImageContent extends 
AnthropicMessageContent { + + public AnthropicImageContentSource source; + + public AnthropicImageContent(String mediaType, String data) { + super("image"); + this.source = new AnthropicImageContentSource("base64", mediaType, data); + } +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicImageContentSource.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicImageContentSource.java new file mode 100644 index 0000000000..2cbf376685 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicImageContentSource.java @@ -0,0 +1,20 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@AllArgsConstructor +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicImageContentSource { + + public String type; + public String mediaType; + public String data; +} \ No newline at end of file diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicMessage.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicMessage.java new file mode 100644 index 0000000000..61ce0a3560 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicMessage.java @@ -0,0 +1,25 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.*; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Builder +@ToString +@EqualsAndHashCode +@NoArgsConstructor +@AllArgsConstructor +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicMessage { + + public AnthropicRole role; + public List content; +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicMessageContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicMessageContent.java new file mode 100644 index 0000000000..6f1071d5e0 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicMessageContent.java @@ -0,0 +1,22 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.EqualsAndHashCode; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@EqualsAndHashCode +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public abstract class 
AnthropicMessageContent { + + public String type; + + public AnthropicMessageContent(String type) { + this.type = type; + } +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicResponseMessage.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicResponseMessage.java new file mode 100644 index 0000000000..1f96fef85c --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicResponseMessage.java @@ -0,0 +1,25 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicResponseMessage { + + public String id; + public String type; + public String role; + public List content; + public String model; + public String stopReason; + public String stopSequence; + public AnthropicUsage usage; +} \ No newline at end of file diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicRole.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicRole.java new file mode 100644 index 0000000000..d0ac18d06f --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicRole.java @@ -0,0 +1,16 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonValue; + +import java.util.Locale; + +public enum AnthropicRole { + + USER, + ASSISTANT; + + @JsonValue + public String serialize() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicStreamingData.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicStreamingData.java similarity index 52% rename from langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicStreamingData.java rename to langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicStreamingData.java index e16644e2c5..b114ceb328 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicStreamingData.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicStreamingData.java @@ -1,5 +1,15 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.api; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) public class AnthropicStreamingData { public String type; diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTextContent.java 
b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTextContent.java new file mode 100644 index 0000000000..d863ad14fd --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTextContent.java @@ -0,0 +1,25 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.EqualsAndHashCode; +import lombok.ToString; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@ToString +@EqualsAndHashCode(callSuper = true) +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicTextContent extends AnthropicMessageContent { + + public String text; + + public AnthropicTextContent(String text) { + super("text"); + this.text = text; + } +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTool.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTool.java new file mode 100644 index 0000000000..f56ba7bc13 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTool.java @@ -0,0 +1,24 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.ToString; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Builder +@ToString +@EqualsAndHashCode +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicTool { + + public String name; + public String description; + public AnthropicToolSchema inputSchema; +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolResultContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicToolResultContent.java similarity index 50% rename from langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolResultContent.java rename to langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicToolResultContent.java index 055aa19dde..597ef681ce 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolResultContent.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicToolResultContent.java @@ -1,10 +1,19 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.api; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; import lombok.EqualsAndHashCode; import lombok.ToString; +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + @ToString 
@EqualsAndHashCode(callSuper = true) +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) public class AnthropicToolResultContent extends AnthropicMessageContent { public String toolUseId; diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicToolSchema.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicToolSchema.java new file mode 100644 index 0000000000..bdbb49f316 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicToolSchema.java @@ -0,0 +1,28 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.ToString; + +import java.util.List; +import java.util.Map; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Builder +@ToString +@EqualsAndHashCode +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicToolSchema { + + @Builder.Default + public String type = "object"; + public Map> properties; + public List required; +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolUseContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicToolUseContent.java similarity index 51% rename from langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolUseContent.java rename to langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicToolUseContent.java index a6fa290379..10d6a55e6f 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolUseContent.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicToolUseContent.java @@ -1,13 +1,22 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.api; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.ToString; import java.util.Map; +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + @ToString @EqualsAndHashCode(callSuper = true) +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) public class AnthropicToolUseContent extends AnthropicMessageContent { public String id; diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicUsage.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicUsage.java new file mode 100644 index 0000000000..a0d1889368 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicUsage.java @@ -0,0 +1,17 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import 
com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class AnthropicUsage { + + public Integer inputTokens; + public Integer outputTokens; +} \ No newline at end of file diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicClient.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicClient.java similarity index 93% rename from langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicClient.java rename to langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicClient.java index 83116a1752..8d3a243334 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicClient.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicClient.java @@ -1,7 +1,9 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.client; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest; +import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse; import dev.langchain4j.spi.ServiceHelper; import java.time.Duration; diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicClientBuilderFactory.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicClientBuilderFactory.java similarity index 73% rename from langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicClientBuilderFactory.java rename to langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicClientBuilderFactory.java index f8bcc8b3ce..1ccf5c5573 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicClientBuilderFactory.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicClientBuilderFactory.java @@ -1,4 +1,4 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.client; import java.util.function.Supplier; diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicHttpException.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicHttpException.java similarity index 87% rename from langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicHttpException.java rename to langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicHttpException.java index b923dea33e..18a1afa9fb 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicHttpException.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicHttpException.java @@ -1,4 +1,4 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.client; public class AnthropicHttpException extends RuntimeException { diff --git 
a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicRequestLoggingInterceptor.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicRequestLoggingInterceptor.java similarity index 97% rename from langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicRequestLoggingInterceptor.java rename to langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicRequestLoggingInterceptor.java index c426726ddb..b8d66f18ca 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicRequestLoggingInterceptor.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicRequestLoggingInterceptor.java @@ -1,4 +1,4 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.client; import lombok.extern.slf4j.Slf4j; import okhttp3.Headers; diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicResponseLoggingInterceptor.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicResponseLoggingInterceptor.java similarity index 88% rename from langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicResponseLoggingInterceptor.java rename to langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicResponseLoggingInterceptor.java index fba15c0789..b4c32b6a3d 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicResponseLoggingInterceptor.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicResponseLoggingInterceptor.java @@ -1,4 +1,4 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.client; import lombok.extern.slf4j.Slf4j; import okhttp3.Interceptor; @@ -7,7 +7,7 @@ import java.io.IOException; -import static dev.langchain4j.model.anthropic.AnthropicRequestLoggingInterceptor.getHeaders; +import static dev.langchain4j.model.anthropic.internal.client.AnthropicRequestLoggingInterceptor.getHeaders; @Slf4j class AnthropicResponseLoggingInterceptor implements Interceptor { diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/DefaultAnthropicClient.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/DefaultAnthropicClient.java similarity index 93% rename from langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/DefaultAnthropicClient.java rename to langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/DefaultAnthropicClient.java index cf754f747d..9806b538dc 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/DefaultAnthropicClient.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/DefaultAnthropicClient.java @@ -1,9 +1,9 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.client; -import com.google.gson.Gson; -import com.google.gson.GsonBuilder; +import com.fasterxml.jackson.databind.ObjectMapper; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.anthropic.internal.api.*; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import okhttp3.OkHttpClient; @@ -15,29 +15,24 @@ import 
org.slf4j.LoggerFactory; import retrofit2.Call; import retrofit2.Retrofit; -import retrofit2.converter.gson.GsonConverterFactory; +import retrofit2.converter.jackson.JacksonConverterFactory; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; -import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES; -import static com.google.gson.ToNumberPolicy.LONG_OR_DOUBLE; +import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT; import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; -import static dev.langchain4j.model.anthropic.AnthropicMapper.toFinishReason; +import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toFinishReason; import static java.util.Collections.synchronizedList; public class DefaultAnthropicClient extends AnthropicClient { private static final Logger LOGGER = LoggerFactory.getLogger(DefaultAnthropicClient.class); - static final Gson GSON = new GsonBuilder() - .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES) - .setObjectToNumberStrategy(LONG_OR_DOUBLE) - .setPrettyPrinting() - .create(); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().enable(INDENT_OUTPUT); private final AnthropicApi anthropicApi; private final OkHttpClient okHttpClient; @@ -84,10 +79,11 @@ public DefaultAnthropicClient build() { this.okHttpClient = okHttpClientBuilder.build(); + Retrofit retrofit = new Retrofit.Builder() .baseUrl(ensureNotBlank(builder.baseUrl, "baseUrl")) .client(okHttpClient) - .addConverterFactory(GsonConverterFactory.create(GSON)) + .addConverterFactory(JacksonConverterFactory.create(OBJECT_MAPPER)) .build(); this.anthropicApi = retrofit.create(AnthropicApi.class); @@ -159,7 +155,7 @@ public void onEvent(EventSource eventSource, String id, String type, String data } try { - AnthropicStreamingData data = GSON.fromJson(dataString, AnthropicStreamingData.class); + AnthropicStreamingData data = OBJECT_MAPPER.readValue(dataString, AnthropicStreamingData.class); if ("message_start".equals(type)) { handleMessageStart(data); diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMapper.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/mapper/AnthropicMapper.java similarity index 92% rename from langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMapper.java rename to langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/mapper/AnthropicMapper.java index e41cbbc733..3100327667 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMapper.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/mapper/AnthropicMapper.java @@ -1,4 +1,4 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.mapper; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolParameters; @@ -6,6 +6,7 @@ import dev.langchain4j.data.image.Image; import dev.langchain4j.data.message.*; import dev.langchain4j.internal.Json; +import dev.langchain4j.model.anthropic.internal.api.*; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; @@ -16,8 +17,8 @@ import static dev.langchain4j.internal.Exceptions.illegalArgument; import static dev.langchain4j.internal.Utils.*; import static 
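To make the Gson-to-Jackson swap above concrete, a small serialization sketch (illustrative, not part of the patch; the model name string and the exact JSON layout are indicative only):

```java
// Sketch: the request POJOs are serialized by Jackson using their class-level
// SnakeCaseStrategy and NON_NULL annotations, replacing the previous Gson setup.
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;

class RequestSerializationSketch {

    public static void main(String[] args) throws JsonProcessingException {
        AnthropicCreateMessageRequest request = AnthropicCreateMessageRequest.builder()
                .model("claude-3-haiku-20240307") // model name is illustrative
                .maxTokens(1024)
                .stream(false)
                .build();

        // Field names become snake_case and null fields (system, tools, ...) are omitted,
        // e.g. {"model":"claude-3-haiku-20240307","max_tokens":1024,"stream":false}
        System.out.println(new ObjectMapper().writeValueAsString(request));
    }
}
```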
dev.langchain4j.internal.ValidationUtils.ensureNotBlank; -import static dev.langchain4j.model.anthropic.AnthropicRole.ASSISTANT; -import static dev.langchain4j.model.anthropic.AnthropicRole.USER; +import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.ASSISTANT; +import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.USER; import static dev.langchain4j.model.output.FinishReason.*; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; @@ -26,7 +27,7 @@ public class AnthropicMapper { - static List toAnthropicMessages(List messages) { + public static List toAnthropicMessages(List messages) { List anthropicMessages = new ArrayList<>(); List toolContents = new ArrayList<>(); @@ -104,7 +105,7 @@ private static List toAnthropicMessageContents(AiMessag return contents; } - static String toAnthropicSystemPrompt(List messages) { + public static String toAnthropicSystemPrompt(List messages) { String systemPrompt = messages.stream() .filter(message -> message instanceof SystemMessage) .map(message -> ((SystemMessage) message).text()) @@ -167,7 +168,7 @@ public static FinishReason toFinishReason(String anthropicStopReason) { } } - static List toAnthropicTools(List toolSpecifications) { + public static List toAnthropicTools(List toolSpecifications) { if (toolSpecifications == null) { return null; } @@ -176,7 +177,7 @@ static List toAnthropicTools(List toolSpecific .collect(toList()); } - static AnthropicTool toAnthropicTool(ToolSpecification toolSpecification) { + public static AnthropicTool toAnthropicTool(ToolSpecification toolSpecification) { ToolParameters parameters = toolSpecification.parameters(); return AnthropicTool.builder() .name(toolSpecification.name()) diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/sanitizer/MessageSanitizer.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/sanitizer/MessageSanitizer.java new file mode 100644 index 0000000000..233ed9f512 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/sanitizer/MessageSanitizer.java @@ -0,0 +1,61 @@ +package dev.langchain4j.model.anthropic.internal.sanitizer; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import lombok.extern.slf4j.Slf4j; + +import java.util.ArrayList; +import java.util.List; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; + +/** + * Sanitizes the messages to conform to the format expected by the Anthropic API. 
+ */
+@Slf4j
+public class MessageSanitizer {
+
+
+    public static List<ChatMessage> sanitizeMessages(List<ChatMessage> messages) {
+        ensureNotEmpty(messages, "messages");
+        List<ChatMessage> sanitizedMessages = new ArrayList<>(messages);
+        stripSystemMessages(sanitizedMessages);
+        ensureFirstMessageIsUserMessage(sanitizedMessages);
+        ensureNoConsecutiveUserMessages(sanitizedMessages);
+
+        return sanitizedMessages;
+    }
+
+    private static void stripSystemMessages(List<ChatMessage> messages) {
+        messages.removeIf(message -> message instanceof SystemMessage);
+    }
+
+    private static void ensureNoConsecutiveUserMessages(List<ChatMessage> messages) {
+        boolean lastWasUserMessage = false;
+        List<ChatMessage> toRemove = new ArrayList<>();
+
+        for (ChatMessage message : messages) {
+            if (message instanceof UserMessage) {
+                if (lastWasUserMessage) {
+                    toRemove.add(message);
+                    log.warn("Removing consecutive UserMessage: {}", ((UserMessage) message).singleText());
+                } else {
+                    lastWasUserMessage = true;
+                }
+            } else {
+                lastWasUserMessage = false;
+            }
+        }
+
+        messages.removeAll(toRemove);
+    }
+
+    private static void ensureFirstMessageIsUserMessage(List<ChatMessage> messages) {
+        while (!messages.isEmpty() && !(messages.get(0) instanceof UserMessage)) {
+            ChatMessage removedMessage = messages.remove(0);
+            log.warn("Dropping non-UserMessage in 1st element: {}", removedMessage);
+        }
+    }
+
+}
diff --git a/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/proxy-config.json b/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/proxy-config.json
index 2dd6700012..cf35351016 100644
--- a/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/proxy-config.json
+++ b/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/proxy-config.json
@@ -1,5 +1,5 @@
 [
   [
-    "dev.langchain4j.model.anthropic.AnthropicApi"
+    "dev.langchain4j.model.anthropic.internal.api.AnthropicApi"
   ]
 ]
\ No newline at end of file
diff --git a/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json b/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json
index f250ab6e83..a1eeb601b4 100644
--- a/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json
+++ b/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json
@@ -1,6 +1,6 @@
 [
   {
-    "name": "dev.langchain4j.model.anthropic.AnthropicContent",
+    "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicContent",
     "allDeclaredConstructors": true,
     "allPublicConstructors": true,
     "allDeclaredMethods": true,
@@ -9,7 +9,7 @@
     "allPublicFields": true
   },
   {
-    "name": "dev.langchain4j.model.anthropic.AnthropicCreateMessageRequest",
+    "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicCreateMessageRequest",
     "allDeclaredConstructors": true,
     "allPublicConstructors": true,
     "allDeclaredMethods": true,
@@ -18,7 +18,7 @@
     "allPublicFields": true
   },
   {
-    "name": "dev.langchain4j.model.anthropic.AnthropicCreateMessageResponse",
+    "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicCreateMessageResponse",
     "allDeclaredConstructors": true,
     "allPublicConstructors": true,
     "allDeclaredMethods": true,
@@ -27,7 +27,7 @@
     "allPublicFields": true
   },
   {
-    "name": "dev.langchain4j.model.anthropic.AnthropicDelta",
+    "name":
"dev.langchain4j.model.anthropic.internal.client.AnthropicDelta", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -36,7 +36,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicImageContent", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicImageContent", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -45,7 +45,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicImageContentSource", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicImageContentSource", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -54,7 +54,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicMessage", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicMessage", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -63,7 +63,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicMessageContent", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicMessageContent", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -72,7 +72,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicResponseMessage", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicResponseMessage", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -81,7 +81,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicRole", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicRole", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -90,7 +90,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicStreamingData", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicStreamingData", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -99,7 +99,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicTextContent", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicTextContent", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -108,7 +108,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicTool", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicTool", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -117,7 +117,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicToolResultContent", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolResultContent", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -126,7 +126,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicToolSchema", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolSchema", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -135,7 +135,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicToolUseContent", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicToolUseContent", 
"allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, @@ -144,7 +144,7 @@ "allPublicFields": true }, { - "name": "dev.langchain4j.model.anthropic.AnthropicUsage", + "name": "dev.langchain4j.model.anthropic.internal.client.AnthropicUsage", "allDeclaredConstructors": true, "allPublicConstructors": true, "allDeclaredMethods": true, diff --git a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java index 90fb11ffa8..aa37438084 100644 --- a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java +++ b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java @@ -6,7 +6,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -24,7 +23,6 @@ import static dev.langchain4j.internal.Utils.readBytes; import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229; import static dev.langchain4j.model.output.FinishReason.*; -import static java.lang.System.getenv; import static java.util.Arrays.asList; import static java.util.Arrays.stream; import static java.util.Collections.singletonList; @@ -64,17 +62,10 @@ class AnthropicChatModelIT { .addParameter("location", OBJECT, property("properties", singletonMap("city", singletonMap("type", "string")))) .build(); - @AfterEach - void afterEach() throws InterruptedException { - Thread.sleep(10_000L); // to avoid hitting rate limits - } - @Test void should_generate_answer_and_return_token_usage_and_finish_reason_stop() { // given - ChatLanguageModel model = AnthropicChatModel.withApiKey(getenv("ANTHROPIC_API_KEY")); - UserMessage userMessage = userMessage("What is the capital of Germany?"); // when @@ -294,26 +285,6 @@ void should_fail_to_create_without_api_key() { "It can be generated here: https://console.anthropic.com/settings/keys"); } - @Test - void should_fail_with_rate_limit_error() { - - ChatLanguageModel model = AnthropicChatModel.builder() - .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .maxTokens(1) - .logRequests(true) - .logResponses(true) - .build(); - - assertThatThrownBy(() -> { - for (int i = 0; i < 100; i++) { - model.generate("Hi"); - } - }) - .isExactlyInstanceOf(RuntimeException.class) // TODO return AnthropicHttpException (not wrapped)? 
- .hasRootCauseExactlyInstanceOf(AnthropicHttpException.class) - .hasMessageContaining("rate_limit_error"); - } - @ParameterizedTest @MethodSource("models_supporting_tools") void should_execute_a_tool_then_answer(AnthropicChatModelName modelName) { diff --git a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicMapperTest.java b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicMapperTest.java index c73b387d67..afa85a4c98 100644 --- a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicMapperTest.java +++ b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicMapperTest.java @@ -3,6 +3,7 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.*; +import dev.langchain4j.model.anthropic.internal.api.*; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -13,10 +14,10 @@ import java.util.stream.Stream; import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING; -import static dev.langchain4j.model.anthropic.AnthropicMapper.toAnthropicMessages; -import static dev.langchain4j.model.anthropic.AnthropicMapper.toAnthropicTool; -import static dev.langchain4j.model.anthropic.AnthropicRole.ASSISTANT; -import static dev.langchain4j.model.anthropic.AnthropicRole.USER; +import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicMessages; +import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicTool; +import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.ASSISTANT; +import static dev.langchain4j.model.anthropic.internal.api.AnthropicRole.USER; import static java.util.Arrays.asList; import static java.util.Collections.*; import static java.util.stream.Collectors.toMap; diff --git a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModelIT.java b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModelIT.java index b9bd4663e5..b20f7a2c83 100644 --- a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModelIT.java +++ b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModelIT.java @@ -7,14 +7,12 @@ import dev.langchain4j.model.chat.TestStreamingResponseHandler; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; import java.time.Duration; import java.util.Base64; -import java.util.concurrent.ExecutionException; import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.internal.Utils.readBytes; @@ -35,11 +33,6 @@ class AnthropicStreamingChatModelIT { .logResponses(true) .build(); - @AfterEach - void afterEach() throws InterruptedException { - Thread.sleep(10_000L); // to avoid hitting rate limits - } - @Test void should_stream_answer_and_return_token_usage_and_finish_reason_stop() { @@ -149,26 +142,4 @@ void should_fail_to_create_without_api_key() { .hasMessage("Anthropic API key must be defined. 
" + "It can be generated here: https://console.anthropic.com/settings/keys"); } - - @Test - void should_fail_with_rate_limit_error() { - - StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder() - .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .maxTokens(1) - .logRequests(true) - .logResponses(true) - .build(); - - assertThatThrownBy(() -> { - for (int i = 0; i < 100; i++) { - TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - model.generate("Hi", handler); - handler.get(); - } - }) - .isExactlyInstanceOf(ExecutionException.class) - .hasRootCauseExactlyInstanceOf(AnthropicHttpException.class) - .hasMessageContaining("rate_limit_error"); - } } \ No newline at end of file diff --git a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicRequestLoggingInterceptorTest.java b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/internal/client/AnthropicRequestLoggingInterceptorTest.java similarity index 82% rename from langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicRequestLoggingInterceptorTest.java rename to langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/internal/client/AnthropicRequestLoggingInterceptorTest.java index 6621e99538..77b5071964 100644 --- a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicRequestLoggingInterceptorTest.java +++ b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/internal/client/AnthropicRequestLoggingInterceptorTest.java @@ -1,10 +1,10 @@ -package dev.langchain4j.model.anthropic; +package dev.langchain4j.model.anthropic.internal.client; import org.junit.jupiter.api.Test; -import static dev.langchain4j.model.anthropic.AnthropicApi.X_API_KEY; -import static dev.langchain4j.model.anthropic.AnthropicRequestLoggingInterceptor.format; -import static dev.langchain4j.model.anthropic.AnthropicRequestLoggingInterceptor.maskSecretKey; +import static dev.langchain4j.model.anthropic.internal.api.AnthropicApi.X_API_KEY; +import static dev.langchain4j.model.anthropic.internal.client.AnthropicRequestLoggingInterceptor.format; +import static dev.langchain4j.model.anthropic.internal.client.AnthropicRequestLoggingInterceptor.maskSecretKey; import static org.assertj.core.api.Assertions.assertThat; class AnthropicRequestLoggingInterceptorTest { diff --git a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/internal/sanitizer/MessageSanitizerTest.java b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/internal/sanitizer/MessageSanitizerTest.java new file mode 100644 index 0000000000..a7a4ad1ea9 --- /dev/null +++ b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/internal/sanitizer/MessageSanitizerTest.java @@ -0,0 +1,214 @@ +package dev.langchain4j.model.anthropic.internal.sanitizer; + +import static org.junit.jupiter.api.Assertions.*; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.*; +import org.junit.jupiter.api.Test; +import java.util.ArrayList; +import java.util.List; + +class MessageSanitizerTest { + + // Default expected message values + private static final String EXPECTED_USER_MESSAGE_CONTENT = "User message"; + private static final String EXPECTED_AI_MESSAGE_CONTENT = "AI message"; + private static final String EXPECTED_SYSTEM_MESSAGE_CONTENT = "System message"; + + @Test + void test_stripSystemMessage() { + List messages = new ArrayList<>(); + + messages.add(new 
SystemMessage(EXPECTED_SYSTEM_MESSAGE_CONTENT)); + messages.add(new UserMessage(EXPECTED_USER_MESSAGE_CONTENT)); + + List sanitized = MessageSanitizer.sanitizeMessages(messages); + + assertEquals(1, sanitized.size()); + assertInstanceOf(UserMessage.class, sanitized.get(0)); + assertEquals(EXPECTED_USER_MESSAGE_CONTENT, ((UserMessage) sanitized.get(0)).singleText()); + } + + @Test + void test_stripMultipleSystemMessages() { + List messages = new ArrayList<>(); + + messages.add(new SystemMessage("System message 1")); + messages.add(new SystemMessage("System message 2")); + messages.add(new UserMessage(EXPECTED_USER_MESSAGE_CONTENT)); + + List sanitized = MessageSanitizer.sanitizeMessages(messages); + + assertEquals(1, sanitized.size()); + assertInstanceOf(UserMessage.class, sanitized.get(0)); + assertEquals(EXPECTED_USER_MESSAGE_CONTENT, ((UserMessage) sanitized.get(0)).singleText()); + } + + @Test + void test_removeSinglePairOfConsecutiveUserMessages() { + List messages = new ArrayList<>(); + String userMessage2 = "User message 2"; + messages.add(new UserMessage(EXPECTED_USER_MESSAGE_CONTENT)); + messages.add(new UserMessage(userMessage2)); + messages.add(new AiMessage(EXPECTED_AI_MESSAGE_CONTENT)); + + List sanitized = MessageSanitizer.sanitizeMessages(messages); + + assertEquals(2, sanitized.size()); + assertInstanceOf(UserMessage.class, sanitized.get(0)); + assertEquals(EXPECTED_USER_MESSAGE_CONTENT, ((UserMessage) sanitized.get(0)).singleText()); + assertInstanceOf(AiMessage.class, sanitized.get(1)); + assertEquals(EXPECTED_AI_MESSAGE_CONTENT, ((AiMessage) sanitized.get(1)).text()); + } + + @Test + void test_removeMultiplePairsOfConsecutiveUserMessages() { + List messages = new ArrayList<>(); + String userMessage1 = "User message 1"; + String userMessage2 = "User message 2"; + String aiMessage1 = "AI message 1"; + String userMessage3 = "User message 3"; + String userMessage4 = "User message 4"; + String aiMessage2 = "AI message 2"; + + messages.add(new UserMessage(userMessage1)); + messages.add(new UserMessage(userMessage2)); + messages.add(new AiMessage(aiMessage1)); + messages.add(new UserMessage(userMessage3)); + messages.add(new UserMessage(userMessage4)); + messages.add(new AiMessage(aiMessage2)); + + List sanitized = MessageSanitizer.sanitizeMessages(messages); + + assertEquals(4, sanitized.size()); + assertInstanceOf(UserMessage.class, sanitized.get(0)); + assertInstanceOf(AiMessage.class, sanitized.get(1)); + assertInstanceOf(UserMessage.class, sanitized.get(2)); + assertInstanceOf(AiMessage.class, sanitized.get(3)); + + assertEquals(userMessage1, ((UserMessage) sanitized.get(0)).singleText()); + assertEquals(aiMessage1, ((AiMessage) sanitized.get(1)).text()); + assertEquals(userMessage3, ((UserMessage) sanitized.get(2)).singleText()); + assertEquals(aiMessage2, ((AiMessage) sanitized.get(3)).text()); + } + + @Test + void test_aiMessageAfterUserMessage() { + List messages = new ArrayList<>(); + messages.add(new UserMessage(EXPECTED_USER_MESSAGE_CONTENT)); + messages.add(new AiMessage(EXPECTED_AI_MESSAGE_CONTENT)); + + List sanitized = MessageSanitizer.sanitizeMessages(messages); + + assertEquals(2, sanitized.size()); + assertInstanceOf(UserMessage.class, sanitized.get(0)); + assertEquals(EXPECTED_USER_MESSAGE_CONTENT, ((UserMessage) sanitized.get(0)).singleText()); + assertInstanceOf(AiMessage.class, sanitized.get(1)); + assertEquals(EXPECTED_AI_MESSAGE_CONTENT, ((AiMessage) sanitized.get(1)).text()); + } + + @Test + void test_firstMessageIsUserMessage_noChange() { + List 
messages = new ArrayList<>(); + messages.add(new UserMessage(EXPECTED_USER_MESSAGE_CONTENT)); + messages.add(new AiMessage(EXPECTED_AI_MESSAGE_CONTENT)); + + List sanitized = MessageSanitizer.sanitizeMessages(messages); + + assertEquals(2, sanitized.size()); + assertInstanceOf(UserMessage.class, sanitized.get(0)); + assertEquals(EXPECTED_USER_MESSAGE_CONTENT, ((UserMessage) sanitized.get(0)).singleText()); + assertInstanceOf(AiMessage.class, sanitized.get(1)); + assertEquals(EXPECTED_AI_MESSAGE_CONTENT, ((AiMessage) sanitized.get(1)).text()); + } + + @Test + void test_firstMessageIsSystemMessage() { + List messages = new ArrayList<>(); + messages.add(new SystemMessage(EXPECTED_SYSTEM_MESSAGE_CONTENT)); + messages.add(new UserMessage(EXPECTED_USER_MESSAGE_CONTENT)); + messages.add(new AiMessage(EXPECTED_AI_MESSAGE_CONTENT)); + + List sanitized = MessageSanitizer.sanitizeMessages(messages); + + assertEquals(2, sanitized.size()); + assertInstanceOf(UserMessage.class, sanitized.get(0)); + assertEquals(EXPECTED_USER_MESSAGE_CONTENT, ((UserMessage) sanitized.get(0)).singleText()); + assertInstanceOf(AiMessage.class, sanitized.get(1)); + assertEquals(EXPECTED_AI_MESSAGE_CONTENT, ((AiMessage) sanitized.get(1)).text()); + } + + @Test + void test_firstMessageIsAiMessage() { + List messages = new ArrayList<>(); + messages.add(new AiMessage(EXPECTED_AI_MESSAGE_CONTENT)); + messages.add(new UserMessage(EXPECTED_USER_MESSAGE_CONTENT)); + messages.add(new AiMessage(EXPECTED_AI_MESSAGE_CONTENT)); + + List sanitized = MessageSanitizer.sanitizeMessages(messages); + + assertEquals(2, sanitized.size()); + assertInstanceOf(UserMessage.class, sanitized.get(0)); + assertEquals(EXPECTED_USER_MESSAGE_CONTENT, ((UserMessage) sanitized.get(0)).singleText()); + assertInstanceOf(AiMessage.class, sanitized.get(1)); + assertEquals(EXPECTED_AI_MESSAGE_CONTENT, ((AiMessage) sanitized.get(1)).text()); + } + + @Test + void test_invalidStartingMessageWithInvalidUserPair() { + List messages = new ArrayList<>(); + String userMessage1 = "User message 1"; + String userMessage2 = "User message 2"; + messages.add(new AiMessage(EXPECTED_AI_MESSAGE_CONTENT)); + messages.add(new UserMessage(userMessage1)); + messages.add(new UserMessage(userMessage2)); + messages.add(new AiMessage(EXPECTED_AI_MESSAGE_CONTENT)); + + List sanitized = MessageSanitizer.sanitizeMessages(messages); + + assertEquals(2, sanitized.size()); + assertInstanceOf(UserMessage.class, sanitized.get(0)); + assertEquals(userMessage1, ((UserMessage) sanitized.get(0)).singleText()); + assertInstanceOf(AiMessage.class, sanitized.get(1)); + assertEquals(EXPECTED_AI_MESSAGE_CONTENT, ((AiMessage) sanitized.get(1)).text()); + } + + @Test + void test_toolExecutionMessages() { + String expectedUserMessageContent = "What is the product of 2x2?"; + String expectedAiMessageAfterTool = "The answer for 2x2 is 4"; + + List messages = new ArrayList<>(); + messages.add(SystemMessage.from("The agent exists to help with arithmetic problems.")); + messages.add(UserMessage.from(expectedUserMessageContent)); + messages.add(AiMessage.from(ToolExecutionRequest.builder() + .id("12345") + .name("calculator") + .arguments("{\"first\": 2, \"second\": 2}") + .build())); + messages.add(ToolExecutionResultMessage.from("12345", "calculator", "4")); + messages.add(AiMessage.from(expectedAiMessageAfterTool)); + + List sanitized = MessageSanitizer.sanitizeMessages(messages); + + assertEquals(4, sanitized.size()); + + assertInstanceOf(UserMessage.class, sanitized.get(0)); + 
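+        // Only the SystemMessage is removed; the user message, the tool request/result exchange
+        // and the final answer are preserved in order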
assertEquals(expectedUserMessageContent, ((UserMessage) sanitized.get(0)).singleText()); + + assertInstanceOf(AiMessage.class, sanitized.get(1)); + ToolExecutionRequest toolExecutionRequest = ((AiMessage) sanitized.get(1)).toolExecutionRequests().get(0); + assertEquals("12345", toolExecutionRequest.id()); + assertEquals("calculator", toolExecutionRequest.name()); + assertEquals("{\"first\": 2, \"second\": 2}", toolExecutionRequest.arguments()); + + assertInstanceOf(ToolExecutionResultMessage.class, sanitized.get(2)); + ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) sanitized.get(2); + assertEquals("12345", toolExecutionResultMessage.id()); + assertEquals("calculator", toolExecutionResultMessage.toolName()); + assertEquals("4", toolExecutionResultMessage.text()); + + assertInstanceOf(AiMessage.class, sanitized.get(3)); + assertEquals(expectedAiMessageAfterTool, ((AiMessage) sanitized.get(3)).text()); + } +} diff --git a/langchain4j-azure-ai-search/pom.xml b/langchain4j-azure-ai-search/pom.xml index e70a0b9ef0..cb63ec402b 100644 --- a/langchain4j-azure-ai-search/pom.xml +++ b/langchain4j-azure-ai-search/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java index 137b0fb068..ed3935ab7f 100644 --- a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java +++ b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetriever.java @@ -58,6 +58,7 @@ public class AzureAiSearchContentRetriever extends AbstractAzureAiSearchEmbeddin private final AzureAiSearchQueryType azureAiSearchQueryType; private final int maxResults; + private final double minScore; public AzureAiSearchContentRetriever(String endpoint, @@ -66,6 +67,7 @@ public AzureAiSearchContentRetriever(String endpoint, boolean createOrUpdateIndex, int dimensions, SearchIndex index, + String indexName, EmbeddingModel embeddingModel, int maxResults, double minScore, @@ -86,15 +88,15 @@ public AzureAiSearchContentRetriever(String endpoint, } if (keyCredential == null) { if (index == null) { - this.initialize(endpoint, null, tokenCredential, createOrUpdateIndex, dimensions, null); + this.initialize(endpoint, null, tokenCredential, createOrUpdateIndex, dimensions, null, indexName); } else { - this.initialize(endpoint, null, tokenCredential, createOrUpdateIndex, 0, index); + this.initialize(endpoint, null, tokenCredential, createOrUpdateIndex, 0, index, indexName); } } else { if (index == null) { - this.initialize(endpoint, keyCredential, null, createOrUpdateIndex, dimensions, null); + this.initialize(endpoint, keyCredential, null, createOrUpdateIndex, dimensions, null, indexName); } else { - this.initialize(endpoint, keyCredential, null, createOrUpdateIndex, 0, index); + this.initialize(endpoint, keyCredential, null, createOrUpdateIndex, 0, index, indexName); } } this.embeddingModel = embeddingModel; @@ -285,6 +287,8 @@ public static class Builder { private SearchIndex index; + private String indexName; + private EmbeddingModel embeddingModel; private int maxResults = EmbeddingStoreContentRetriever.DEFAULT_MAX_RESULTS.apply(null); @@ -361,6 +365,17 @@ 
public Builder index(SearchIndex index) { return this; } + /** + * If no index is provided, set the name of the default index to be used. + * + * @param indexName The index name to be used. + * @return builder + */ + public Builder indexName(String indexName) { + this.indexName = indexName; + return this; + } + /** * Sets the Embedding Model. * @@ -408,7 +423,7 @@ public Builder queryType(AzureAiSearchQueryType azureAiSearchQueryType) { public AzureAiSearchContentRetriever build() { return new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, createOrUpdateIndex, dimensions, index, - embeddingModel, maxResults, minScore, azureAiSearchQueryType); + indexName, embeddingModel, maxResults, minScore, azureAiSearchQueryType); } } } diff --git a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AbstractAzureAiSearchEmbeddingStore.java b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AbstractAzureAiSearchEmbeddingStore.java index dad41ffaff..bcc70f4d0a 100644 --- a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AbstractAzureAiSearchEmbeddingStore.java +++ b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AbstractAzureAiSearchEmbeddingStore.java @@ -23,6 +23,8 @@ import java.util.*; import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.internal.ValidationUtils.ensureTrue; import static java.util.Collections.singletonList; import static java.util.stream.Collectors.toList; @@ -31,7 +33,7 @@ public abstract class AbstractAzureAiSearchEmbeddingStore implements EmbeddingSt private static final Logger log = LoggerFactory.getLogger(AbstractAzureAiSearchEmbeddingStore.class); - public static final String INDEX_NAME = "vectorsearch"; + public static final String DEFAULT_INDEX_NAME = "vectorsearch"; static final String DEFAULT_FIELD_ID = "id"; @@ -57,7 +59,20 @@ public abstract class AbstractAzureAiSearchEmbeddingStore implements EmbeddingSt protected SearchClient searchClient; - protected void initialize(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential, boolean createOrUpdateIndex, int dimensions, SearchIndex index) { + private String indexName; + + protected void initialize(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential, boolean createOrUpdateIndex, int dimensions, SearchIndex index, String indexName) { + ensureNotNull(endpoint, "endpoint"); + if (index != null && isNotNullOrBlank(indexName)) { + // if an index is provided, it has its own name already configured + // if the indexName is provided, it will be used when creating the default index + throw new IllegalArgumentException("index and indexName cannot be both defined"); + } + if (createOrUpdateIndex && index != null) { + this.indexName = index.getName(); + } else { + this.indexName = getOrDefault(indexName, DEFAULT_INDEX_NAME); + } this.createOrUpdateIndex = createOrUpdateIndex; if (keyCredential != null) { if (createOrUpdateIndex) { @@ -70,7 +85,7 @@ protected void initialize(String endpoint, AzureKeyCredential keyCredential, Tok searchClient = new SearchClientBuilder() .endpoint(endpoint) .credential(keyCredential) - .indexName(INDEX_NAME) + .indexName(this.indexName) .buildClient(); } else { if (createOrUpdateIndex) { @@ -83,7 +98,7 @@ protected void 
initialize(String endpoint, AzureKeyCredential keyCredential, Tok searchClient = new SearchClientBuilder() .endpoint(endpoint) .credential(tokenCredential) - .indexName(INDEX_NAME) + .indexName(this.indexName) .buildClient(); } @@ -161,12 +176,12 @@ public void createOrUpdateIndex(int dimensions) { .setContentFields(new SemanticField(DEFAULT_FIELD_CONTENT)) .setKeywordsFields(new SemanticField(DEFAULT_FIELD_CONTENT))))); - index = new SearchIndex(INDEX_NAME) + index = new SearchIndex(this.indexName) .setFields(fields) .setVectorSearch(vectorSearch) .setSemanticSearch(semanticSearch); } else { - index = new SearchIndex(INDEX_NAME) + index = new SearchIndex(this.indexName) .setFields(fields); } @@ -189,7 +204,7 @@ public void deleteIndex() { if (!createOrUpdateIndex) { throw new IllegalArgumentException("createOrUpdateIndex is false, so the index cannot be deleted"); } - searchIndexClient.deleteIndex(INDEX_NAME); + searchIndexClient.deleteIndex(this.indexName); } /** @@ -313,10 +328,10 @@ private void addAllInternal( document.setContent(embedded.get(i).text()); Document.Metadata metadata = new Document.Metadata(); List attributes = new ArrayList<>(); - for (Map.Entry entry : embedded.get(i).metadata().asMap().entrySet()) { + for (Map.Entry entry : embedded.get(i).metadata().toMap().entrySet()) { Document.Metadata.Attribute attribute = new Document.Metadata.Attribute(); attribute.setKey(entry.getKey()); - attribute.setValue(entry.getValue()); + attribute.setValue(String.valueOf(entry.getValue())); attributes.add(attribute); } metadata.setAttributes(attributes); diff --git a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStore.java b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStore.java index f25a6c39bf..625c3e1229 100644 --- a/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStore.java +++ b/langchain4j-azure-ai-search/src/main/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStore.java @@ -14,20 +14,20 @@ */ public class AzureAiSearchEmbeddingStore extends AbstractAzureAiSearchEmbeddingStore implements EmbeddingStore { - public AzureAiSearchEmbeddingStore(String endpoint, AzureKeyCredential keyCredential, boolean createOrUpdateIndex, int dimensions) { - this.initialize(endpoint, keyCredential, null, createOrUpdateIndex, dimensions, null); + public AzureAiSearchEmbeddingStore(String endpoint, AzureKeyCredential keyCredential, boolean createOrUpdateIndex, int dimensions, String indexName) { + this.initialize(endpoint, keyCredential, null, createOrUpdateIndex, dimensions, null, indexName); } - public AzureAiSearchEmbeddingStore(String endpoint, AzureKeyCredential keyCredential, boolean createOrUpdateIndex, SearchIndex index) { - this.initialize(endpoint, keyCredential, null, createOrUpdateIndex, 0, index); + public AzureAiSearchEmbeddingStore(String endpoint, AzureKeyCredential keyCredential, boolean createOrUpdateIndex, SearchIndex index, String indexName) { + this.initialize(endpoint, keyCredential, null, createOrUpdateIndex, 0, index, indexName); } - public AzureAiSearchEmbeddingStore(String endpoint, TokenCredential tokenCredential, boolean createOrUpdateIndex, int dimensions) { - this.initialize(endpoint, null, tokenCredential, createOrUpdateIndex, dimensions, null); + public AzureAiSearchEmbeddingStore(String endpoint, TokenCredential tokenCredential, boolean createOrUpdateIndex, 
int dimensions, String indexName) { + this.initialize(endpoint, null, tokenCredential, createOrUpdateIndex, dimensions, null, indexName); } - public AzureAiSearchEmbeddingStore(String endpoint, TokenCredential tokenCredential, boolean createOrUpdateIndex, SearchIndex index) { - this.initialize(endpoint, null, tokenCredential, createOrUpdateIndex, 0, index); + public AzureAiSearchEmbeddingStore(String endpoint, TokenCredential tokenCredential, boolean createOrUpdateIndex, SearchIndex index, String indexName) { + this.initialize(endpoint, null, tokenCredential, createOrUpdateIndex, 0, index, indexName); } public static Builder builder() { @@ -48,6 +48,8 @@ public static class Builder { private SearchIndex index; + private String indexName; + /** * Sets the Azure AI Search endpoint. This is a mandatory parameter. * @@ -115,21 +117,32 @@ public Builder index(SearchIndex index) { return this; } + /** + * If no index is provided, set the name of the default index to be used. + * + * @param indexName The name of the index to be used. + * @return builder + */ + public Builder indexName(String indexName) { + this.indexName = indexName; + return this; + } + public AzureAiSearchEmbeddingStore build() { ensureNotNull(endpoint, "endpoint"); ensureTrue(keyCredential != null || tokenCredential != null, "either apiKey or tokenCredential must be set"); ensureTrue(dimensions > 0 || index != null, "either dimensions or index must be set"); if (keyCredential == null) { if (index == null) { - return new AzureAiSearchEmbeddingStore(endpoint, tokenCredential, createOrUpdateIndex, dimensions); + return new AzureAiSearchEmbeddingStore(endpoint, tokenCredential, createOrUpdateIndex, dimensions, indexName); } else { - return new AzureAiSearchEmbeddingStore(endpoint, tokenCredential, createOrUpdateIndex, index); + return new AzureAiSearchEmbeddingStore(endpoint, tokenCredential, createOrUpdateIndex, index, indexName); } } else { if (index == null) { - return new AzureAiSearchEmbeddingStore(endpoint, keyCredential, createOrUpdateIndex, dimensions); + return new AzureAiSearchEmbeddingStore(endpoint, keyCredential, createOrUpdateIndex, dimensions, indexName); } else { - return new AzureAiSearchEmbeddingStore(endpoint, keyCredential, createOrUpdateIndex, index); + return new AzureAiSearchEmbeddingStore(endpoint, keyCredential, createOrUpdateIndex, index, indexName); } } } diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverIT.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverIT.java index cb7d051ff6..a1c6127066 100644 --- a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverIT.java +++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverIT.java @@ -19,7 +19,7 @@ import java.util.List; -import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.INDEX_NAME; +import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_INDEX_NAME; import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; @@ -49,7 +49,7 @@ public AzureAiSearchContentRetrieverIT() { .credential(new AzureKeyCredential(System.getenv("AZURE_SEARCH_KEY"))) .buildClient(); - searchIndexClient.deleteIndex(INDEX_NAME); + 
searchIndexClient.deleteIndex(DEFAULT_INDEX_NAME); contentRetrieverWithVector = createContentRetriever(AzureAiSearchQueryType.VECTOR); contentRetrieverWithFullText = createFullTextSearchContentRetriever(); diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java index 6d6ccdf4ac..b63052c093 100644 --- a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java +++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverTest.java @@ -28,7 +28,7 @@ public void testConstructorMandatoryParameters() { // Test empty endpoint try { - new AzureAiSearchContentRetriever(null, keyCredential, tokenCredential, true, dimensions, index, embeddingModel, 3, 0, AzureAiSearchQueryType.VECTOR); + new AzureAiSearchContentRetriever(null, keyCredential, tokenCredential, true, dimensions, index, null, embeddingModel, 3, 0, AzureAiSearchQueryType.VECTOR); fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { assertEquals("endpoint cannot be null", e.getMessage()); @@ -36,7 +36,7 @@ public void testConstructorMandatoryParameters() { // Test no credentials try { - new AzureAiSearchContentRetriever(endpoint, null, null, true, dimensions, index, embeddingModel, 3, 0, AzureAiSearchQueryType.VECTOR); + new AzureAiSearchContentRetriever(endpoint, null, null, true, dimensions, index, null, embeddingModel, 3, 0, AzureAiSearchQueryType.VECTOR); fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { assertEquals("either keyCredential or tokenCredential must be set", e.getMessage()); @@ -44,7 +44,7 @@ public void testConstructorMandatoryParameters() { // Test both credentials try { - new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, true, dimensions, index, embeddingModel, 3, 0, AzureAiSearchQueryType.VECTOR); + new AzureAiSearchContentRetriever(endpoint, keyCredential, tokenCredential, true, dimensions, index, null, embeddingModel, 3, 0, AzureAiSearchQueryType.VECTOR); fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { assertEquals("either keyCredential or tokenCredential must be set", e.getMessage()); @@ -52,7 +52,7 @@ public void testConstructorMandatoryParameters() { // Test no dimensions and no index, for a vector search try { - new AzureAiSearchContentRetriever(endpoint, null, tokenCredential, true, 0, null, embeddingModel, 3, 0, AzureAiSearchQueryType.VECTOR); + new AzureAiSearchContentRetriever(endpoint, null, tokenCredential, true, 0, null, null, embeddingModel, 3, 0, AzureAiSearchQueryType.VECTOR); fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { assertEquals("dimensions must be set to a positive, non-zero integer between 2 and 3072", e.getMessage()); @@ -60,7 +60,7 @@ public void testConstructorMandatoryParameters() { // Test dimensions > 0, for a full text search try { - new AzureAiSearchContentRetriever(endpoint, keyCredential, null, true, dimensions, null, embeddingModel, 3, 0, AzureAiSearchQueryType.FULL_TEXT); + new AzureAiSearchContentRetriever(endpoint, keyCredential, null, true, dimensions, null, null, embeddingModel, 3, 0, 
AzureAiSearchQueryType.FULL_TEXT); fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { assertEquals("for full-text search, dimensions must be 0", e.getMessage()); @@ -68,7 +68,7 @@ public void testConstructorMandatoryParameters() { // Test no embedding model, for a vector search try { - new AzureAiSearchContentRetriever(endpoint, keyCredential, null, true, 0, null, null, 3, 0, AzureAiSearchQueryType.VECTOR); + new AzureAiSearchContentRetriever(endpoint, keyCredential, null, true, 0, null, null, null, 3, 0, AzureAiSearchQueryType.VECTOR); fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { assertEquals("embeddingModel cannot be null", e.getMessage()); diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreIT.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreIT.java index 727dea875d..54652a4efd 100644 --- a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreIT.java +++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreIT.java @@ -1,5 +1,11 @@ package dev.langchain4j.store.embedding.azure.search; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.search.documents.indexes.SearchIndexClient; +import com.azure.search.documents.indexes.SearchIndexClientBuilder; +import com.azure.search.documents.indexes.models.SearchField; +import com.azure.search.documents.indexes.models.SearchFieldDataType; +import com.azure.search.documents.indexes.models.SearchIndex; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; @@ -13,10 +19,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayList; import java.util.List; +import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_FIELD_ID; +import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_INDEX_NAME; import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; @EnabledIfEnvironmentVariable(named = "AZURE_SEARCH_ENDPOINT", matches = ".+") public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreIT { @@ -29,14 +40,18 @@ public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreIT { private int dimensions; + private String AZURE_SEARCH_ENDPOINT = System.getenv("AZURE_SEARCH_ENDPOINT"); + + private String AZURE_SEARCH_KEY = System.getenv("AZURE_SEARCH_KEY"); + public AzureAiSearchEmbeddingStoreIT() { embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); dimensions = embeddingModel.embed("test").content().vector().length; embeddingStore = AzureAiSearchEmbeddingStore.builder() - .endpoint(System.getenv("AZURE_SEARCH_ENDPOINT")) - .apiKey(System.getenv("AZURE_SEARCH_KEY")) + .endpoint(AZURE_SEARCH_ENDPOINT) + .apiKey(AZURE_SEARCH_KEY) .dimensions(dimensions) .build(); } @@ -47,7 +62,54 @@ void setUp() { } @Test - void testAddEmbeddingsAndFindRelevant() { + public void when_an_index_is_provided_its_name_should_be_used() { + String providedIndexName = "provided-index"; + // Clear the 
index before running tests + SearchIndexClient searchIndexClient = new SearchIndexClientBuilder() + .endpoint(AZURE_SEARCH_ENDPOINT) + .credential(new AzureKeyCredential(AZURE_SEARCH_KEY)) + .buildClient(); + try { + searchIndexClient.deleteIndex(providedIndexName); + } catch (Exception e) { + // The index didn't exist, so we can ignore the exception + } + + // Run the tests + List fields = new ArrayList<>(); + fields.add(new SearchField(DEFAULT_FIELD_ID, SearchFieldDataType.STRING) + .setKey(true) + .setFilterable(true)); + SearchIndex providedIndex = new SearchIndex(providedIndexName).setFields(fields); + AzureAiSearchEmbeddingStore store = + new AzureAiSearchEmbeddingStore(AZURE_SEARCH_ENDPOINT, + new AzureKeyCredential(AZURE_SEARCH_KEY), true, providedIndex, null); + + assertEquals(providedIndexName, store.searchClient.getIndexName()); + + try { + new AzureAiSearchEmbeddingStore(AZURE_SEARCH_ENDPOINT, + new AzureKeyCredential(AZURE_SEARCH_KEY), true, providedIndex, "ANOTHER_INDEX_NAME"); + + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertEquals("index and indexName cannot be both defined", e.getMessage()); + } + + // Clear index + searchIndexClient.deleteIndex(providedIndexName); + } + + @Test + public void when_an_index_is_not_provided_the_default_name_is_used() { + AzureAiSearchEmbeddingStore store =new AzureAiSearchEmbeddingStore(AZURE_SEARCH_ENDPOINT, + new AzureKeyCredential(AZURE_SEARCH_KEY), false, null, null); + + assertEquals(DEFAULT_INDEX_NAME, store.searchClient.getIndexName()); + } + + @Test + void test_add_embeddings_and_find_relevant() { String content1 = "banana"; String content2 = "computer"; String content3 = "apple"; diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreTest.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreTest.java new file mode 100644 index 0000000000..94facb4a15 --- /dev/null +++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreTest.java @@ -0,0 +1,37 @@ +package dev.langchain4j.store.embedding.azure.search; + +import com.azure.core.credential.AzureKeyCredential; +import com.azure.search.documents.indexes.models.SearchIndex; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +public class AzureAiSearchEmbeddingStoreTest { + + String endpoint = "http://localhost"; + AzureKeyCredential keyCredential = new AzureKeyCredential("TEST"); + int dimensions = 1536; + SearchIndex index = new SearchIndex("TEST"); + String indexName = "TEST"; + + @Test + public void empty_endpoint_should_not_be_allowed() { + try { + new AzureAiSearchEmbeddingStore(null, keyCredential, false, dimensions, null); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertEquals("endpoint cannot be null", e.getMessage()); + } + } + + @Test + public void index_and_index_name_should_not_both_be_defined() { + try { + new AzureAiSearchEmbeddingStore(endpoint, keyCredential, false, index, indexName); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertEquals("index and indexName cannot be both defined", e.getMessage()); + } + } +} diff --git a/langchain4j-azure-cosmos-mongo-vcore/pom.xml 
b/langchain4j-azure-cosmos-mongo-vcore/pom.xml index 4dfd238165..36c7cb4225 100644 --- a/langchain4j-azure-cosmos-mongo-vcore/pom.xml +++ b/langchain4j-azure-cosmos-mongo-vcore/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-azure-cosmos-nosql/pom.xml b/langchain4j-azure-cosmos-nosql/pom.xml new file mode 100644 index 0000000000..43146daa0e --- /dev/null +++ b/langchain4j-azure-cosmos-nosql/pom.xml @@ -0,0 +1,75 @@ + + + 4.0.0 + + + dev.langchain4j + langchain4j-parent + 0.32.0-SNAPSHOT + ../langchain4j-parent/pom.xml + + + langchain4j-azure-cosmos-nosql + LangChain4j :: Integration :: Azure CosmosDB NoSQL + + + + + dev.langchain4j + langchain4j-core + + + + com.azure + azure-cosmos + 4.60.0 + + + + org.projectlombok + lombok + provided + + + + org.slf4j + slf4j-api + + + + dev.langchain4j + langchain4j-core + tests + test-jar + test + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.assertj + assertj-core + test + + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + test + + + + org.testcontainers + junit-jupiter + test + + + + + \ No newline at end of file diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbedding.java b/langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlDocument.java similarity index 52% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbedding.java rename to langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlDocument.java index 375d6b2fba..63432d221b 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbedding.java +++ b/langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlDocument.java @@ -1,4 +1,4 @@ -package dev.langchain4j.model.mistralai; +package dev.langchain4j.store.embedding.azure.cosmos.nosql; import lombok.AllArgsConstructor; import lombok.Builder; @@ -6,14 +6,15 @@ import lombok.NoArgsConstructor; import java.util.List; +import java.util.Map; @Data @NoArgsConstructor @AllArgsConstructor @Builder -class MistralAiEmbedding { - - private String object; +class AzureCosmosDbNoSqlDocument { + private String id; private List embedding; - private Integer index; + private String text; + private Map metadata; } diff --git a/langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStore.java b/langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStore.java new file mode 100644 index 0000000000..dcce04f5b9 --- /dev/null +++ b/langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStore.java @@ -0,0 +1,165 @@ +package dev.langchain4j.store.embedding.azure.cosmos.nosql; + +import com.azure.cosmos.CosmosClient; +import com.azure.cosmos.CosmosContainer; +import com.azure.cosmos.CosmosDatabase; +import com.azure.cosmos.models.*; +import com.azure.cosmos.util.CosmosPagedIterable; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import lombok.Builder; +import org.slf4j.Logger; +import 
org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureTrue; +import static dev.langchain4j.store.embedding.azure.cosmos.nosql.MappingUtils.toNoSqlDbDocument; +import static java.util.Collections.singletonList; + +/** + * You can read more about vector search using Azure Cosmos DB NoSQL + * here. + */ +public class AzureCosmosDbNoSqlEmbeddingStore implements EmbeddingStore { + + private static final Logger log = LoggerFactory.getLogger(AzureCosmosDbNoSqlEmbeddingStore.class); + + private final CosmosClient cosmosClient; + private final String databaseName; + private final String containerName; + private final CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy; + private final List cosmosVectorIndexes; + private final CosmosContainerProperties containerProperties; + private final String embeddingKey; + private final CosmosDatabase database; + private final CosmosContainer container; + + @Builder + public AzureCosmosDbNoSqlEmbeddingStore(CosmosClient cosmosClient, + String databaseName, + String containerName, + CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy, + List cosmosVectorIndexes, + CosmosContainerProperties containerProperties) { + this.cosmosClient = cosmosClient; + this.databaseName = databaseName; + this.containerName = containerName; + this.cosmosVectorEmbeddingPolicy = cosmosVectorEmbeddingPolicy; + this.cosmosVectorIndexes = cosmosVectorIndexes; + this.containerProperties = containerProperties; + + if (cosmosClient == null) { + throw new IllegalArgumentException("cosmosClient cannot be null or empty for Azure CosmosDB NoSql Embedding Store."); + } + + if (isNullOrBlank(databaseName) || isNullOrBlank(containerName)) { + throw new IllegalArgumentException("databaseName and containerName needs to be provided."); + } + + if (cosmosVectorEmbeddingPolicy == null || cosmosVectorEmbeddingPolicy.getVectorEmbeddings() == null || + cosmosVectorEmbeddingPolicy.getVectorEmbeddings().isEmpty()) { + throw new IllegalArgumentException("cosmosVectorEmbeddingPolicy cannot be null or empty for Azure CosmosDB NoSql Embedding Store."); + } + + if (cosmosVectorIndexes == null || cosmosVectorIndexes.isEmpty()) { + throw new IllegalArgumentException("cosmosVectorIndexes cannot be null or empty for Azure CosmosDB NoSql Embedding Store."); + } + + this.cosmosClient.createDatabaseIfNotExists(this.databaseName); + this.database = this.cosmosClient.getDatabase(this.databaseName); + + containerProperties.setVectorEmbeddingPolicy(this.cosmosVectorEmbeddingPolicy); + containerProperties.getIndexingPolicy().setVectorIndexes(this.cosmosVectorIndexes); + + this.database.createContainerIfNotExists(this.containerProperties); + this.container = this.database.getContainer(this.containerName); + + this.embeddingKey = this.cosmosVectorEmbeddingPolicy.getVectorEmbeddings().get(0).getPath().substring(1); + } + + @Override + public String add(Embedding embedding) { + String id = randomUUID(); + add(id, embedding); + return id; + } + + @Override + public void add(String id, Embedding embedding) { + addInternal(id, embedding, null); + } + + @Override + public String add(Embedding embedding, TextSegment textSegment) { + String id = randomUUID(); + addInternal(id, embedding, textSegment); + return id; + } + + @Override + public List addAll(List embeddings) { + List ids = embeddings.stream() + .map(ignored -> randomUUID()) + 
.collect(Collectors.toList()); + addAllInternal(ids, embeddings, null); + return ids; + } + + @Override + public List addAll(List embeddings, List embedded) { + List ids = embeddings.stream() + .map(ignored -> randomUUID()) + .collect(Collectors.toList()); + addAllInternal(ids, embeddings, embedded); + return ids; + } + + @Override + public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { + String referenceEmbeddingString = referenceEmbedding.vectorAsList().stream() + .map(Object::toString) + .collect(Collectors.joining(",")); + + String query = String.format("SELECT TOP %d c.id, c.%s, c.text, c.metadata, VectorDistance(c.%s,[%s]) AS score FROM c ORDER By " + + "VectorDistance(c.%s,[%s])", maxResults, embeddingKey, embeddingKey, referenceEmbeddingString, embeddingKey, referenceEmbeddingString); + + CosmosPagedIterable results = this.container.queryItems(query, + new CosmosQueryRequestOptions(), AzureCosmosDbNoSqlMatchedDocument.class); + + if (!results.stream().findAny().isPresent()) { + return new ArrayList<>(); + } + return results.stream() + .map(MappingUtils::toEmbeddingMatch) + .collect(Collectors.toList()); + } + + private void addInternal(String id, Embedding embedding, TextSegment embedded) { + addAllInternal(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded)); + } + + private void addAllInternal(List ids, List embeddings, List embedded) { + if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) { + log.info("do not add empty embeddings to Azure CosmosDB NoSQL"); + return; + } + + ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size"); + ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size"); + + List operations = new ArrayList<>(ids.size()); + for (int i = 0; i < ids.size(); i++) { + operations.add(CosmosBulkOperations.getCreateItemOperation( + toNoSqlDbDocument(ids.get(i), embeddings.get(i), embedded == null ? 
null : embedded.get(i)), + new PartitionKey(ids.get(i)))); + } + + this.container.executeBulkOperations(operations); + } +} diff --git a/langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlMatchedDocument.java b/langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlMatchedDocument.java new file mode 100644 index 0000000000..59ac8d445e --- /dev/null +++ b/langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlMatchedDocument.java @@ -0,0 +1,20 @@ +package dev.langchain4j.store.embedding.azure.cosmos.nosql; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; +import java.util.Map; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class AzureCosmosDbNoSqlMatchedDocument { + + private String id; + private List embedding; + private String text; + private Map metadata; + private Double score; +} diff --git a/langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/MappingUtils.java b/langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/MappingUtils.java new file mode 100644 index 0000000000..9f9d389edd --- /dev/null +++ b/langchain4j-azure-cosmos-nosql/src/main/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/MappingUtils.java @@ -0,0 +1,28 @@ +package dev.langchain4j.store.embedding.azure.cosmos.nosql; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; + +class MappingUtils { + + private MappingUtils() throws InstantiationException { + throw new InstantiationException("can't instantiate this class"); + } + + static AzureCosmosDbNoSqlDocument toNoSqlDbDocument(String id, Embedding embedding, TextSegment textSegment) { + if (textSegment == null) { + return new AzureCosmosDbNoSqlDocument(id, embedding.vectorAsList(), null, null); + } + return new AzureCosmosDbNoSqlDocument(id, embedding.vectorAsList(), textSegment.text(), textSegment.metadata().asMap()); + } + + static EmbeddingMatch toEmbeddingMatch(AzureCosmosDbNoSqlMatchedDocument matchedDocument) { + TextSegment textSegment = null; + if (matchedDocument.getText() != null) { + textSegment = TextSegment.from(matchedDocument.getText(), Metadata.from(matchedDocument.getMetadata())); + } + return new EmbeddingMatch<>(matchedDocument.getScore(), matchedDocument.getId(), Embedding.from(matchedDocument.getEmbedding()), textSegment); + } +} diff --git a/langchain4j-azure-cosmos-nosql/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStoreIT.java b/langchain4j-azure-cosmos-nosql/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStoreIT.java new file mode 100644 index 0000000000..515b5099b7 --- /dev/null +++ b/langchain4j-azure-cosmos-nosql/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStoreIT.java @@ -0,0 +1,179 @@ +package dev.langchain4j.store.embedding.azure.cosmos.nosql; + +import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosClient; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.cosmos.CosmosDatabase; +import com.azure.cosmos.models.*; +import dev.langchain4j.data.embedding.Embedding; +import 
dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreIT; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_HOST", matches = ".+") +@EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_MASTER_KEY", matches = ".+") +class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT { + + protected static Logger logger = LoggerFactory.getLogger(AzureCosmosDbNoSqlEmbeddingStoreIT.class); + + private static final String DATABASE_NAME = "test_database_langchain_java"; + private static final String CONTAINER_NAME = "test_container"; + private CosmosClient client; + CosmosDatabase database; + private final EmbeddingModel embeddingModel; + private final EmbeddingStore embeddingStore; + private final int dimensions; + private final String HOST = System.getenv("AZURE_COSMOS_HOST"); + private final String KEY = System.getenv("AZURE_COSMOS_MASTER_KEY"); + + public AzureCosmosDbNoSqlEmbeddingStoreIT() { + embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + dimensions = embeddingModel.embed("hello").content().vector().length; + + client = new CosmosClientBuilder() + .endpoint(HOST) + .key(KEY) + .consistencyLevel(ConsistencyLevel.EVENTUAL) + .contentResponseOnWriteEnabled(true) + .buildClient(); + + embeddingStore = AzureCosmosDbNoSqlEmbeddingStore.builder() + .cosmosClient(client) + .databaseName(DATABASE_NAME) + .containerName(CONTAINER_NAME) + .cosmosVectorEmbeddingPolicy(populateVectorEmbeddingPolicy(dimensions)) + .cosmosVectorIndexes(populateVectorIndexSpec()) + .containerProperties(populateContainerProperties()) + .build(); + database = client.getDatabase(DATABASE_NAME); + } + + @Test + public void testAddEmbeddingsAndFindRelevant() { + String content1 = "banana"; + String content2 = "computer"; + String content3 = "apple"; + String content4 = "pizza"; + String content5 = "strawberry"; + String content6 = "chess"; + + List contents = asList(content1, content2, content3, content4, content5, content6); + + for (String content : contents) { + TextSegment textSegment = TextSegment.from(content); + Embedding embedding = embeddingModel.embed(content).content(); + embeddingStore.add(embedding, textSegment); + } + + awaitUntilPersisted(); + + Embedding relevantEmbedding = embeddingModel.embed("fruit").content(); + List> relevant = embeddingStore.findRelevant(relevantEmbedding, 3); + assertThat(relevant).hasSize(3); + assertThat(relevant.get(0).embedding()).isNotNull(); + assertThat(relevant.get(0).embedded().text()).isIn(content1, content3, content5); + logger.info("#1 relevant item: {}", relevant.get(0).embedded().text()); + assertThat(relevant.get(1).embedding()).isNotNull(); + assertThat(relevant.get(1).embedded().text()).isIn(content1, content3, content5); + logger.info("#2 relevant item: {}", relevant.get(1).embedded().text()); + assertThat(relevant.get(2).embedding()).isNotNull(); + assertThat(relevant.get(2).embedded().text()).isIn(content1, content3, 
content5); + logger.info("#3 relevant item: {}", relevant.get(2).embedded().text()); + + safeDeleteDatabase(database); + safeClose(client); + } + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } + + @Override + protected void awaitUntilPersisted() { + try { + Thread.sleep(1_000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + protected void clearStore() { + } + + private void safeDeleteDatabase(CosmosDatabase database) { + if (database != null) { + try { + database.delete(); + } catch (Exception e) { + } + } + } + + private void safeClose(CosmosClient client) { + if (client != null) { + try { + client.close(); + } catch (Exception e) { + logger.error("failed to close client", e); + } + } + } + + private CosmosVectorEmbeddingPolicy populateVectorEmbeddingPolicy(int dimensions) { + CosmosVectorEmbeddingPolicy vectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy(); + CosmosVectorEmbedding embedding = new CosmosVectorEmbedding(); + embedding.setPath("/embedding"); + embedding.setDataType(CosmosVectorDataType.FLOAT32); + embedding.setDimensions((long) dimensions); + embedding.setDistanceFunction(CosmosVectorDistanceFunction.COSINE); + vectorEmbeddingPolicy.setCosmosVectorEmbeddings(Collections.singletonList(embedding)); + return vectorEmbeddingPolicy; + } + + private List populateVectorIndexSpec() { + CosmosVectorIndexSpec cosmosVectorIndexSpec = new CosmosVectorIndexSpec(); + cosmosVectorIndexSpec.setPath("/embedding"); + cosmosVectorIndexSpec.setType(CosmosVectorIndexType.FLAT.toString()); + return Collections.singletonList(cosmosVectorIndexSpec); + } + + private CosmosContainerProperties populateContainerProperties() { + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); + ArrayList paths = new ArrayList(); + paths.add("/id"); + partitionKeyDef.setPaths(paths); + + CosmosContainerProperties collectionDefinition = new CosmosContainerProperties(CONTAINER_NAME, partitionKeyDef); + + IndexingPolicy indexingPolicy = new IndexingPolicy(); + indexingPolicy.setIndexingMode(IndexingMode.CONSISTENT); + IncludedPath includedPath = new IncludedPath("/*"); + indexingPolicy.setIncludedPaths(Collections.singletonList(includedPath)); + + collectionDefinition.setIndexingPolicy(indexingPolicy); + return collectionDefinition; + } + + +} diff --git a/langchain4j-azure-cosmos-nosql/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStoreTest.java b/langchain4j-azure-cosmos-nosql/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStoreTest.java new file mode 100644 index 0000000000..337e58f4e4 --- /dev/null +++ b/langchain4j-azure-cosmos-nosql/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStoreTest.java @@ -0,0 +1,141 @@ +package dev.langchain4j.store.embedding.azure.cosmos.nosql; + +import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosClient; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.cosmos.models.*; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +@EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_HOST", matches = ".+") 
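For readers evaluating the new module, here is a minimal usage sketch that mirrors the builder parameters of AzureCosmosDbNoSqlEmbeddingStore and the vector policy/index setup exercised in the integration test above. It is illustrative only: the endpoint and key come from placeholder environment variables, the database and container names are made up, imports are omitted, and the query that findRelevant() ultimately issues is shown in the trailing comment.

// Sketch only: placeholders for credentials and names; wiring mirrors AzureCosmosDbNoSqlEmbeddingStoreIT.
CosmosClient client = new CosmosClientBuilder()
        .endpoint(System.getenv("AZURE_COSMOS_HOST"))
        .key(System.getenv("AZURE_COSMOS_MASTER_KEY"))
        .consistencyLevel(ConsistencyLevel.EVENTUAL)
        .contentResponseOnWriteEnabled(true)
        .buildClient();

EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
int dimensions = embeddingModel.embed("hello").content().vector().length;

// Vector embedding policy and index on the "/embedding" path, as in the test above.
CosmosVectorEmbedding vectorEmbedding = new CosmosVectorEmbedding();
vectorEmbedding.setPath("/embedding");
vectorEmbedding.setDataType(CosmosVectorDataType.FLOAT32);
vectorEmbedding.setDimensions((long) dimensions);
vectorEmbedding.setDistanceFunction(CosmosVectorDistanceFunction.COSINE);
CosmosVectorEmbeddingPolicy vectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy();
vectorEmbeddingPolicy.setCosmosVectorEmbeddings(Collections.singletonList(vectorEmbedding));

CosmosVectorIndexSpec vectorIndexSpec = new CosmosVectorIndexSpec();
vectorIndexSpec.setPath("/embedding");
vectorIndexSpec.setType(CosmosVectorIndexType.FLAT.toString());

// Container partitioned on "/id", matching the documents written by the store.
PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition();
partitionKeyDef.setPaths(Collections.singletonList("/id"));
CosmosContainerProperties containerProperties = new CosmosContainerProperties("my_container", partitionKeyDef);

EmbeddingStore<TextSegment> store = AzureCosmosDbNoSqlEmbeddingStore.builder()
        .cosmosClient(client)
        .databaseName("my_database")
        .containerName("my_container")
        .cosmosVectorEmbeddingPolicy(vectorEmbeddingPolicy)
        .cosmosVectorIndexes(Collections.singletonList(vectorIndexSpec))
        .containerProperties(containerProperties)
        .build();

TextSegment segment = TextSegment.from("banana");
store.add(embeddingModel.embed(segment).content(), segment);

// findRelevant(query, 3) issues a Cosmos SQL query of roughly this shape:
// SELECT TOP 3 c.id, c.embedding, c.text, c.metadata, VectorDistance(c.embedding,[...]) AS score
// FROM c ORDER BY VectorDistance(c.embedding,[...])
List<EmbeddingMatch<TextSegment>> matches = store.findRelevant(embeddingModel.embed("fruit").content(), 3);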
+@EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_MASTER_KEY", matches = ".+") +class AzureCosmosDbNoSqlEmbeddingStoreTest { + + private static final String DATABASE_NAME = "test_db"; + private static final String CONTAINER_NAME = "test_container"; + + @Test + void should_fail_if_cosmosClient_missing() { + assertThrows(IllegalArgumentException.class, () -> { + AzureCosmosDbNoSqlEmbeddingStore.builder() + .cosmosClient(null) + .build(); + }); + } + + @Test + void should_fail_if_databaseName_collectionName_missing() { + + CosmosClient cosmosClient = new CosmosClientBuilder() + .endpoint("HOST") + .key("MASTER_KEY") + .consistencyLevel(ConsistencyLevel.EVENTUAL) + .contentResponseOnWriteEnabled(true) + .buildClient(); + + assertThrows(IllegalArgumentException.class, () -> AzureCosmosDbNoSqlEmbeddingStore.builder() + .cosmosClient(cosmosClient) + .databaseName(null) + .build()); + + assertThrows(IllegalArgumentException.class, () -> AzureCosmosDbNoSqlEmbeddingStore.builder() + .cosmosClient(cosmosClient) + .databaseName("") + .build()); + + assertThrows(IllegalArgumentException.class, () -> AzureCosmosDbNoSqlEmbeddingStore.builder() + .cosmosClient(cosmosClient) + .databaseName("test_database") + .containerName(null) + .build()); + + assertThrows(IllegalArgumentException.class, () -> AzureCosmosDbNoSqlEmbeddingStore.builder() + .cosmosClient(cosmosClient) + .databaseName("test_database") + .containerName("") + .build()); + } + + @Test + void should_fail_if_cosmosVectorEmbeddingPolicy_missing() { + CosmosClient cosmosClient = new CosmosClientBuilder() + .endpoint("HOST") + .key("MASTER_KEY") + .consistencyLevel(ConsistencyLevel.EVENTUAL) + .contentResponseOnWriteEnabled(true) + .buildClient(); + + CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy(); + + assertThrows(IllegalArgumentException.class, () -> AzureCosmosDbNoSqlEmbeddingStore.builder() + .cosmosClient(cosmosClient) + .databaseName(DATABASE_NAME) + .containerName(CONTAINER_NAME) + .cosmosVectorEmbeddingPolicy(null) + .build()); + + + cosmosVectorEmbeddingPolicy.setCosmosVectorEmbeddings(null); + assertThrows(IllegalArgumentException.class, () -> AzureCosmosDbNoSqlEmbeddingStore.builder() + .cosmosClient(cosmosClient) + .databaseName(DATABASE_NAME) + .containerName(CONTAINER_NAME) + .cosmosVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy) + .build()); + + cosmosVectorEmbeddingPolicy.setCosmosVectorEmbeddings(new ArrayList<>()); + assertThrows(IllegalArgumentException.class, () -> AzureCosmosDbNoSqlEmbeddingStore.builder() + .cosmosClient(cosmosClient) + .databaseName(DATABASE_NAME) + .containerName(CONTAINER_NAME) + .cosmosVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy) + .build()); + } + + @Test + void should_fail_if_cosmosVectorIndexes_missing() { + CosmosClient cosmosClient = new CosmosClientBuilder() + .endpoint("HOST") + .key("MASTER_KEY") + .consistencyLevel(ConsistencyLevel.EVENTUAL) + .contentResponseOnWriteEnabled(true) + .buildClient(); + + CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy = populateCosmosVectorEmbeddingPolicy(); + + assertThrows(IllegalArgumentException.class, () -> AzureCosmosDbNoSqlEmbeddingStore.builder() + .cosmosClient(cosmosClient) + .databaseName(DATABASE_NAME) + .containerName(CONTAINER_NAME) + .cosmosVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy) + .cosmosVectorIndexes(null) + .build()); + + List cosmosVectorIndexes = new ArrayList(); + assertThrows(IllegalArgumentException.class, () -> AzureCosmosDbNoSqlEmbeddingStore.builder() + 
.cosmosClient(cosmosClient) + .databaseName(DATABASE_NAME) + .containerName(CONTAINER_NAME) + .cosmosVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy) + .cosmosVectorIndexes(cosmosVectorIndexes) + .build()); + } + + private CosmosVectorEmbeddingPolicy populateCosmosVectorEmbeddingPolicy() { + CosmosVectorEmbeddingPolicy policy = new CosmosVectorEmbeddingPolicy(); + CosmosVectorEmbedding embedding1 = new CosmosVectorEmbedding(); + embedding1.setPath("/embedding"); + embedding1.setDataType(CosmosVectorDataType.FLOAT32); + embedding1.setDimensions(128L); + embedding1.setDistanceFunction(CosmosVectorDistanceFunction.COSINE); + policy.setCosmosVectorEmbeddings(Collections.singletonList(embedding1)); + return policy; + } + +} diff --git a/langchain4j-azure-open-ai/pom.xml b/langchain4j-azure-open-ai/pom.xml index 24ca3cff8f..dad533e14c 100644 --- a/langchain4j-azure-open-ai/pom.xml +++ b/langchain4j-azure-open-ai/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -34,10 +34,17 @@ dev.langchain4j - langchain4j-open-ai + langchain4j-core + tests + test-jar test + + com.knuddels + jtokkit + + org.apache.logging.log4j log4j-api diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java index ae66a362d7..e282e3caa9 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java @@ -15,7 +15,6 @@ import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.output.TokenUsage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -25,6 +24,7 @@ import static dev.langchain4j.data.message.AiMessage.aiMessage; import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.util.Collections.singletonList; @@ -120,7 +120,7 @@ public AzureOpenAiChatModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, stop, presencePenalty, frequencyPenalty, dataSources, enhancements, seed, responseFormat); - this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiChatModel(String endpoint, @@ -147,7 +147,7 @@ public AzureOpenAiChatModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, stop, presencePenalty, frequencyPenalty, dataSources, enhancements, seed, responseFormat); - this.client = setupOpenAIClient(endpoint, serviceVersion, keyCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, keyCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiChatModel(String endpoint, @@ -174,7 +174,7 @@ public AzureOpenAiChatModel(String endpoint, boolean 
logRequestsAndResponses) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, stop, presencePenalty, frequencyPenalty, dataSources, enhancements, seed, responseFormat); - this.client = setupOpenAIClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } private AzureOpenAiChatModel(String deploymentName, @@ -245,11 +245,12 @@ private Response generate(List messages, .setSeed(seed) .setResponseFormat(responseFormat); - if (toolSpecifications != null && !toolSpecifications.isEmpty()) { - options.setFunctions(toFunctions(toolSpecifications)); - } if (toolThatMustBeExecuted != null) { - options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name())); + options.setTools(toToolDefinitions(singletonList(toolThatMustBeExecuted))); + options.setToolChoice(toToolChoice(toolThatMustBeExecuted)); + } + if (!isNullOrEmpty(toolSpecifications)) { + options.setTools(toToolDefinitions(toolSpecifications)); } try { diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModelName.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModelName.java new file mode 100644 index 0000000000..b5ab1de43f --- /dev/null +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModelName.java @@ -0,0 +1,65 @@ +package dev.langchain4j.model.azure; + +/** + * You can get the latest model names from the Azure OpenAI documentation or by executing the Azure CLI command: + * az cognitiveservices account list-models --resource-group "$RESOURCE_GROUP" --name "$AI_SERVICE" -o table + */ +public enum AzureOpenAiChatModelName { + + GPT_3_5_TURBO("gpt-35-turbo", "gpt-3.5-turbo"), // alias for the latest gpt-3.5-turbo model + GPT_3_5_TURBO_0301("gpt-35-turbo-0301", "gpt-3.5-turbo", "0301"), // 4k context, functions + GPT_3_5_TURBO_0613("gpt-35-turbo-0613", "gpt-3.5-turbo", "0613"), // 4k context, functions + GPT_3_5_TURBO_1106("gpt-35-turbo-1106", "gpt-3.5-turbo", "1106"), // 16k context, functions + + GPT_3_5_TURBO_16K("gpt-35-turbo-16k", "gpt-3.5-turbo-16k"), // alias for the latest gpt-3.5-turbo-16k model + GPT_3_5_TURBO_16K_0613("gpt-35-turbo-16k-0613", "gpt-3.5-turbo-16k", "0613"), // 16k context, functions + + GPT_4("gpt-4", "gpt-4"), // alias for the latest gpt-4 + GPT_4_0613("gpt-4-0613", "gpt-4", "0613"), // 8k context, functions + GPT_4_0125_PREVIEW("gpt-4-0125-preview", "gpt-4", "0125-preview"), // 8k context + GPT_4_1106_PREVIEW("gpt-4-1106-preview", "gpt-4", "1106-preview"), // 8k context + + GPT_4_TURBO("gpt-4-turbo", "gpt-4-turbo"), // alias for the latest gpt-4-turbo model + GPT_4_TURBO_2024_04_09("gpt-4-turbo-2024-04-09", "gpt-4-turbo", "2024-04-09"), // alias for the latest gpt-4-turbo model + + GPT_4_32K("gpt-4-32k", "gpt-4-32k"), // alias for the latest gpt-32k model + GPT_4_32K_0613("gpt-4-32k-0613", "gpt-4-32k", "0613"), // 32k context, functions + + GPT_4_VISION_PREVIEW("gpt-4-vision-preview", "gpt-4-vision", "preview"), + + GPT_4_O("gpt-4o", "gpt-4o"); // alias for the latest gpt-4o model + + private final String modelName; + // Model type follows the com.knuddels.jtokkit.api.ModelType naming convention + private final String modelType; + private final String modelVersion; + + AzureOpenAiChatModelName(String modelName, String modelType) { + this.modelName = modelName; + 
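The chat model hunk above swaps the deprecated function-calling options (setFunctions / FunctionCallConfig) for the tools API (setTools / setToolChoice via InternalAzureOpenAiHelper). As a hedged sketch of how this surfaces to callers through the existing ChatLanguageModel generate(...) overloads; the weather tool and the already-built chatModel variable are invented for illustration and are not part of this change, and imports are omitted.

// Hypothetical tool, for illustration only.
ToolSpecification weatherTool = ToolSpecification.builder()
        .name("get_weather")
        .description("Returns the current weather for a given city")
        .addParameter("city")
        .build();

Response<AiMessage> response = chatModel.generate(
        singletonList(UserMessage.from("What is the weather in Paris?")),
        singletonList(weatherTool));

// With the new options, a tool call comes back as tool execution requests rather than a legacy function_call.
if (response.content().hasToolExecutionRequests()) {
    ToolExecutionRequest request = response.content().toolExecutionRequests().get(0);
    String name = request.name();           // "get_weather"
    String arguments = request.arguments(); // JSON arguments produced by the model
}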
this.modelType = modelType; + this.modelVersion = null; + } + + AzureOpenAiChatModelName(String modelName, String modelType, String modelVersion) { + this.modelName = modelName; + this.modelType = modelType; + this.modelVersion = modelVersion; + } + + public String modelName() { + return modelName; + } + + public String modelType() { + return modelType; + } + + public String modelVersion() { + return modelVersion; + } + + @Override + public String toString() { + return modelName; + } +} diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiEmbeddingModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiEmbeddingModel.java index cb97754877..ed736b3669 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiEmbeddingModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiEmbeddingModel.java @@ -20,8 +20,9 @@ import java.util.ArrayList; import java.util.List; +import static dev.langchain4j.data.embedding.Embedding.from; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; +import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupSyncClient; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.util.stream.Collectors.toList; @@ -76,7 +77,7 @@ public AzureOpenAiEmbeddingModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, tokenizer); - this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiEmbeddingModel(String endpoint, @@ -90,7 +91,7 @@ public AzureOpenAiEmbeddingModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, tokenizer); - this.client = setupOpenAIClient(endpoint, serviceVersion, keyCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, keyCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiEmbeddingModel(String endpoint, @@ -104,7 +105,7 @@ public AzureOpenAiEmbeddingModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, tokenizer); - this.client = setupOpenAIClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } private AzureOpenAiEmbeddingModel(String deploymentName, @@ -157,14 +158,6 @@ private Response> embedTexts(List texts) { ); } - private static Embedding from(List vector) { - float[] langChainVector = new float[vector.size()]; - for (int index = 0; index < vector.size(); index++) { - langChainVector[index] = vector.get(index).floatValue(); - } - return Embedding.from(langChainVector); - } - @Override public int estimateTokenCount(String text) { return tokenizer.estimateTokenCountInText(text); diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiEmbeddingModelName.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiEmbeddingModelName.java new file mode 100644 index 0000000000..63880badad --- /dev/null +++ 
b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiEmbeddingModelName.java @@ -0,0 +1,47 @@ +package dev.langchain4j.model.azure; + +public enum AzureOpenAiEmbeddingModelName { + + TEXT_EMBEDDING_3_SMALL("text-embedding-3-small", "text-embedding-3-small"), // alias for the latest text-embedding-3-small model + TEXT_EMBEDDING_3_SMALL_1("text-embedding-3-small-1", "text-embedding-3-small", "1"), + TEXT_EMBEDDING_3_LARGE("text-embedding-3-large", "text-embedding-3-large"), + TEXT_EMBEDDING_3_LARGE_1("text-embedding-3-large-1", "text-embedding-3-large", "1"), + + TEXT_EMBEDDING_ADA_002("text-embedding-ada-002", "text-embedding-ada-002"), // alias for the latest text-embedding-ada-002 model + TEXT_EMBEDDING_ADA_002_1("text-embedding-ada-002-1", "text-embedding-ada-002", "1"), + TEXT_EMBEDDING_ADA_002_2("text-embedding-ada-002-2", "text-embedding-ada-002", "2"); + + private final String modelName; + // Model type follows the com.knuddels.jtokkit.api.ModelType naming convention + private final String modelType; + private final String modelVersion; + + AzureOpenAiEmbeddingModelName(String modelName, String modelType) { + this.modelName = modelName; + this.modelType = modelType; + this.modelVersion = null; + } + + AzureOpenAiEmbeddingModelName(String modelName, String modelType, String modelVersion) { + this.modelName = modelName; + this.modelType = modelType; + this.modelVersion = modelVersion; + } + + public String modelName() { + return modelName; + } + + public String modelType() { + return modelType; + } + + public String modelVersion() { + return modelVersion; + } + + @Override + public String toString() { + return modelName; + } +} diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index b8084a06e1..8098f614b4 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -11,12 +11,10 @@ import dev.langchain4j.model.image.ImageModel; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.output.TokenUsage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.time.Duration; -import java.util.Map; import static dev.langchain4j.data.message.AiMessage.aiMessage; import static dev.langchain4j.internal.Utils.getOrDefault; @@ -89,7 +87,7 @@ public AzureOpenAiImageModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, quality, size, user, style, responseFormat); - this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiImageModel(String endpoint, @@ -107,7 +105,7 @@ public AzureOpenAiImageModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, quality, size, user, style, responseFormat); - this.client = setupOpenAIClient(endpoint, serviceVersion, keyCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, keyCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiImageModel(String endpoint, @@ -125,7 +123,7 @@ public
AzureOpenAiImageModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, quality, size, user, style, responseFormat); - this.client = setupOpenAIClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, String style, String responseFormat) { diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModelName.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModelName.java new file mode 100644 index 0000000000..9655cf2ffb --- /dev/null +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModelName.java @@ -0,0 +1,41 @@ +package dev.langchain4j.model.azure; + +public enum AzureOpenAiImageModelName { + + DALL_E_3("dall-e-3", "dall-e-3"), // alias for the latest dall-e-3 model + DALL_E_3_30("dall-e-3-30", "dall-e-3","30"); + + private final String modelName; + // Model type follows the com.knuddels.jtokkit.api.ModelType naming convention + private final String modelType; + private final String modelVersion; + + AzureOpenAiImageModelName(String modelName, String modelType) { + this.modelName = modelName; + this.modelType = modelType; + this.modelVersion = null; + } + + AzureOpenAiImageModelName(String modelName, String modelType, String modelVersion) { + this.modelName = modelName; + this.modelType = modelType; + this.modelVersion = modelVersion; + } + + public String modelName() { + return modelName; + } + + public String modelType() { + return modelType; + } + + public String modelVersion() { + return modelVersion; + } + + @Override + public String toString() { + return modelName; + } +} diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiLanguageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiLanguageModel.java index 3a28e59469..4df7082a14 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiLanguageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiLanguageModel.java @@ -13,7 +13,6 @@ import dev.langchain4j.model.language.TokenCountEstimator; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.output.TokenUsage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -117,7 +116,7 @@ public AzureOpenAiLanguageModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, logprobs, echo, stop, presencePenalty, frequencyPenalty, bestOf); - this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiLanguageModel(String endpoint, @@ -143,7 +142,7 @@ public AzureOpenAiLanguageModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, logprobs, echo, stop, presencePenalty, frequencyPenalty, bestOf); - this.client = setupOpenAIClient(endpoint, serviceVersion, keyCredential, 
timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, keyCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiLanguageModel(String endpoint, @@ -169,7 +168,7 @@ public AzureOpenAiLanguageModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, logprobs, echo, stop, presencePenalty, frequencyPenalty, bestOf); - this.client = setupOpenAIClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } private AzureOpenAiLanguageModel(String deploymentName, diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiLanguageModelName.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiLanguageModelName.java new file mode 100644 index 0000000000..4edfdb6956 --- /dev/null +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiLanguageModelName.java @@ -0,0 +1,44 @@ +package dev.langchain4j.model.azure; + +public enum AzureOpenAiLanguageModelName { + + GPT_3_5_TURBO_INSTRUCT("gpt-35-turbo-instruct", "gpt-3.5-turbo"), // alias for the latest gpt-3.5-turbo-instruct model + GPT_3_5_TURBO_INSTRUCT_0914("gpt-35-turbo-instruct-0914", "gpt-3.5-turbo", "0914"), // 4k context, functions + + TEXT_DAVINCI_002("davinci-002", "text-davinci-002"), + TEXT_DAVINCI_002_1("davinci-002-1", "text-davinci-002", "1"),; + + private final String modelName; + // Model type follows the com.knuddels.jtokkit.api.ModelType naming convention + private final String modelType; + private final String modelVersion; + + AzureOpenAiLanguageModelName(String modelName, String modelType) { + this.modelName = modelName; + this.modelType = modelType; + this.modelVersion = null; + } + + AzureOpenAiLanguageModelName(String modelName, String modelType, String modelVersion) { + this.modelName = modelName; + this.modelType = modelType; + this.modelVersion = modelVersion; + } + + public String modelName() { + return modelName; + } + + public String modelType() { + return modelType; + } + + public String modelVersion() { + return modelVersion; + } + + @Override + public String toString() { + return modelName; + } +} diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java index 0782d0e43b..2e70f37c2e 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java @@ -1,23 +1,34 @@ package dev.langchain4j.model.azure; +/** + * @deprecated use {@link AzureOpenAiChatModelName}, {@link AzureOpenAiEmbeddingModelName}, {@link AzureOpenAiImageModelName} and {@link AzureOpenAiLanguageModelName}, instead. 
+ */ +@Deprecated public class AzureOpenAiModelName { // Use with AzureOpenAiChatModel and AzureOpenAiStreamingChatModel - public static final String GPT_3_5_TURBO = "gpt-3.5-turbo"; // alias for the latest model + public static final String GPT_3_5_TURBO = "gpt-3.5-turbo"; // alias for the latest gpt-3.5-turbo model public static final String GPT_3_5_TURBO_0613 = "gpt-3.5-turbo-0613"; // 4k context, functions public static final String GPT_3_5_TURBO_0125 = "gpt-3.5-turbo-0125"; // 4k context, functions public static final String GPT_3_5_TURBO_1106 = "gpt-3.5-turbo-1106"; // 16k context, functions - public static final String GPT_3_5_TURBO_16K = "gpt-3.5-turbo-16k"; // alias for the latest model + public static final String GPT_3_5_TURBO_16K = "gpt-3.5-turbo-16k"; // alias for the latest gpt-3.5-turbo-16k model public static final String GPT_3_5_TURBO_16K_0613 = "gpt-3.5-turbo-16k-0613"; // 16k context, functions - public static final String GPT_4 = "gpt-4"; // alias for the latest model + public static final String GPT_4 = "gpt-4"; // alias for the latest gpt-4 public static final String GPT_4_1106_PREVIEW = "gpt-4-1106-preview"; // 8k context + public static final String GPT_4_0125_PREVIEW = "gpt-4-0125-preview"; // 8k context public static final String GPT_4_0613 = "gpt-4-0613"; // 8k context, functions public static final String GPT_4_32K = "gpt-4-32k"; // alias for the latest model public static final String GPT_4_32K_0613 = "gpt-4-32k-0613"; // 32k context, functions + public static final String GPT_4_TURBO = "gpt-4-turbo"; // alias for the latest gpt-4-turbo model + public static final String GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09"; // 128k context, functions + + public static final String GPT_4_O = "gpt-4o"; // alias for the latest gpt-4o model + + public static final String GPT_4_VISION_PREVIEW = "gpt-4-vision-preview"; // Use with AzureOpenAiLanguageModel and AzureOpenAiStreamingLanguageModel public static final String TEXT_DAVINCI_002 = "text-davinci-002"; diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java index ab9c05f86e..2360e6f394 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java @@ -1,5 +1,6 @@ package dev.langchain4j.model.azure; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.*; import com.azure.core.credential.KeyCredential; @@ -16,11 +17,12 @@ import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.output.TokenUsage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import java.time.Duration; +import java.util.Collection; import java.util.List; import java.util.Map; @@ -62,6 +64,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiStreamingChatModel.class); private OpenAIClient client; + private OpenAIAsyncClient asyncClient; private final String deploymentName; private final Tokenizer tokenizer; private final Integer maxTokens; @@ -79,6 +82,7 @@ public class 
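With AzureOpenAiModelName now deprecated, a short before/after sketch of the intended migration to the new per-use-case enums; the constant and accessor names are exactly those introduced earlier in this diff, while the local variable names are illustrative.

// Before: untyped string constants (deprecated)
String modelName = AzureOpenAiModelName.GPT_3_5_TURBO;      // "gpt-3.5-turbo"

// After: one enum per use case
AzureOpenAiChatModelName chat = AzureOpenAiChatModelName.GPT_3_5_TURBO;
String azureName = chat.modelName();     // "gpt-35-turbo", the Azure-style name, also returned by toString()
String tokenizerName = chat.modelType(); // "gpt-3.5-turbo", follows the com.knuddels.jtokkit.api.ModelType naming

AzureOpenAiEmbeddingModelName embeddingModel = AzureOpenAiEmbeddingModelName.TEXT_EMBEDDING_3_SMALL;
AzureOpenAiImageModelName imageModel = AzureOpenAiImageModelName.DALL_E_3;
AzureOpenAiLanguageModelName languageModel = AzureOpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;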
AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel private final ChatCompletionsResponseFormat responseFormat; public AzureOpenAiStreamingChatModel(OpenAIClient client, + OpenAIAsyncClient asyncClient, String deploymentName, Tokenizer tokenizer, Integer maxTokens, @@ -96,7 +100,14 @@ public AzureOpenAiStreamingChatModel(OpenAIClient client, ChatCompletionsResponseFormat responseFormat) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, stop, presencePenalty, frequencyPenalty, dataSources, enhancements, seed, responseFormat); - this.client = client; + + if (asyncClient != null) { + this.asyncClient = asyncClient; + } else if(client != null) { + this.client = client; + } else { + throw new IllegalStateException("No client available"); + } } public AzureOpenAiStreamingChatModel(String endpoint, @@ -120,11 +131,14 @@ public AzureOpenAiStreamingChatModel(String endpoint, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, - boolean logRequestsAndResponses) { + boolean logRequestsAndResponses, + boolean useAsyncClient) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, stop, presencePenalty, frequencyPenalty, dataSources, enhancements, seed, responseFormat); - this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); - } + if(useAsyncClient) + this.asyncClient = setupAsyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + else + this.client = setupSyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiStreamingChatModel(String endpoint, String serviceVersion, @@ -147,11 +161,14 @@ public AzureOpenAiStreamingChatModel(String endpoint, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, - boolean logRequestsAndResponses) { + boolean logRequestsAndResponses, + boolean useAsyncClient) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, stop, presencePenalty, frequencyPenalty, dataSources, enhancements, seed, responseFormat); - this.client = setupOpenAIClient(endpoint, serviceVersion, keyCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); - } + if(useAsyncClient) + this.asyncClient = setupAsyncClient(endpoint, serviceVersion, keyCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + else + this.client = setupSyncClient(endpoint, serviceVersion, keyCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiStreamingChatModel(String endpoint, String serviceVersion, @@ -174,11 +191,15 @@ public AzureOpenAiStreamingChatModel(String endpoint, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, - boolean logRequestsAndResponses) { + boolean logRequestsAndResponses, + boolean useAsyncClient) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, stop, presencePenalty, frequencyPenalty, dataSources, enhancements, seed, responseFormat); - this.client = setupOpenAIClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); - } + if(useAsyncClient) + this.asyncClient = setupAsyncClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + else + this.client = setupSyncClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, 
logRequestsAndResponses); } + private AzureOpenAiStreamingChatModel(String deploymentName, Tokenizer tokenizer, @@ -253,20 +274,56 @@ private void generate(List messages, Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages); if (toolThatMustBeExecuted != null) { - options.setFunctions(toFunctions(singletonList(toolThatMustBeExecuted))); - options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name())); - if (tokenizer != null) { - inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted); - } - } else if (!isNullOrEmpty(toolSpecifications)) { - options.setFunctions(toFunctions(toolSpecifications)); - if (tokenizer != null) { - inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications); - } + options.setTools(toToolDefinitions(singletonList(toolThatMustBeExecuted))); + options.setToolChoice(toToolChoice(toolThatMustBeExecuted)); + inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted); + } + if (!isNullOrEmpty(toolSpecifications)) { + options.setTools(toToolDefinitions(toolSpecifications)); + inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications); } AzureOpenAiStreamingResponseBuilder responseBuilder = new AzureOpenAiStreamingResponseBuilder(inputTokenCount); + // Sync version + if(client != null) { + syncCall(toolThatMustBeExecuted, handler, options, responseBuilder); + } else if(asyncClient != null) { + asyncCall(toolThatMustBeExecuted, handler, options, responseBuilder); + } + } + + private void handleResponseException(Throwable throwable, StreamingResponseHandler handler) { + if (throwable instanceof HttpResponseException) { + HttpResponseException httpResponseException = (HttpResponseException) throwable; + logger.info("Error generating response, {}", httpResponseException.getValue()); + FinishReason exceptionFinishReason = contentFilterManagement(httpResponseException, "content_filter"); + Response response = Response.from( + aiMessage(httpResponseException.getMessage()), + null, + exceptionFinishReason + ); + handler.onComplete(response); + } else { + handler.onError(throwable); + } + } + + private void asyncCall(ToolSpecification toolThatMustBeExecuted, StreamingResponseHandler handler, ChatCompletionsOptions options, AzureOpenAiStreamingResponseBuilder responseBuilder) { + Flux chatCompletionsStream = asyncClient.getChatCompletionsStream(deploymentName, options); + + chatCompletionsStream.subscribe(chatCompletion -> { + responseBuilder.append(chatCompletion); + handle(chatCompletion, handler); + }, + throwable -> handleResponseException(throwable, handler), + () -> { + Response response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null); + handler.onComplete(response); + }); + } + + private void syncCall(ToolSpecification toolThatMustBeExecuted, StreamingResponseHandler handler, ChatCompletionsOptions options, AzureOpenAiStreamingResponseBuilder responseBuilder) { try { client.getChatCompletionsStream(deploymentName, options) .stream() @@ -276,20 +333,12 @@ private void generate(List messages, }); Response response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null); handler.onComplete(response); - } catch (HttpResponseException httpResponseException) { - logger.info("Error generating response, {}", httpResponseException.getValue()); - FinishReason exceptionFinishReason = contentFilterManagement(httpResponseException, "content_filter"); - 
Response response = Response.from( - aiMessage(httpResponseException.getMessage()), - null, - exceptionFinishReason - ); - handler.onComplete(response); } catch (Exception exception) { - handler.onError(exception); + handleResponseException(exception, handler); } } + private static void handle(ChatCompletions chatCompletions, StreamingResponseHandler handler) { @@ -343,6 +392,8 @@ public static class Builder { private ProxyOptions proxyOptions; private boolean logRequestsAndResponses; private OpenAIClient openAIClient; + private OpenAIAsyncClient openAIAsyncClient; + private boolean useAsyncClient = true; /** * Sets the Azure OpenAI endpoint. This is a mandatory parameter. @@ -502,16 +553,41 @@ public Builder logRequestsAndResponses(boolean logRequestsAndResponses) { } /** - * Sets the Azure OpenAI client. This is an optional parameter, if you need more flexibility than using the endpoint, serviceVersion, apiKey, deploymentName parameters. - * + * @deprecated If you want to continue using sync client, use {@link AzureOpenAiChatModel} instead. + * @param useAsyncClient {@code true} if you want to use the async client, {@code false} if you want to use the sync client. + * @return builder with the useAsyncClient parameter set + */ + @SuppressWarnings("DeprecatedIsStillUsed") + @Deprecated + public Builder useAsyncClient(boolean useAsyncClient) { + this.useAsyncClient = useAsyncClient; + return this; + } + + /** + * @deprecated Please use {@link #openAIAsyncClient(OpenAIAsyncClient)} instead, if you require response streaming. + * Please use {@link AzureOpenAiChatModel} instead, if you require sync responses. * @param openAIClient The Azure OpenAI client. * @return builder */ + @SuppressWarnings("DeprecatedIsStillUsed") + @Deprecated public Builder openAIClient(OpenAIClient openAIClient) { this.openAIClient = openAIClient; return this; } + /** + * Sets the Azure OpenAI client. This is an optional parameter, if you need more flexibility than using the endpoint, serviceVersion, apiKey, deploymentName parameters. + * + * @param openAIAsyncClient The Azure OpenAI client. 
+ * @return builder + */ + public Builder openAIAsyncClient(OpenAIAsyncClient openAIAsyncClient) { + this.openAIAsyncClient = openAIAsyncClient; + return this; + } + public AzureOpenAiStreamingChatModel build() { if (openAIClient == null) { if (tokenCredential != null) { @@ -537,7 +613,8 @@ public AzureOpenAiStreamingChatModel build() { timeout, maxRetries, proxyOptions, - logRequestsAndResponses + logRequestsAndResponses, + useAsyncClient ); } else if (keyCredential != null) { return new AzureOpenAiStreamingChatModel( @@ -562,7 +639,8 @@ public AzureOpenAiStreamingChatModel build() { timeout, maxRetries, proxyOptions, - logRequestsAndResponses + logRequestsAndResponses, + useAsyncClient ); } return new AzureOpenAiStreamingChatModel( @@ -587,11 +665,13 @@ public AzureOpenAiStreamingChatModel build() { timeout, maxRetries, proxyOptions, - logRequestsAndResponses + logRequestsAndResponses, + useAsyncClient ); } else { return new AzureOpenAiStreamingChatModel( openAIClient, + openAIAsyncClient, deploymentName, tokenizer, maxTokens, diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingLanguageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingLanguageModel.java index 0b0dd5eedf..80b3772bab 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingLanguageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingLanguageModel.java @@ -1,7 +1,9 @@ package dev.langchain4j.model.azure; import com.azure.ai.openai.OpenAIClient; -import com.azure.ai.openai.models.*; +import com.azure.ai.openai.models.Choice; +import com.azure.ai.openai.models.Completions; +import com.azure.ai.openai.models.CompletionsOptions; import com.azure.core.credential.KeyCredential; import com.azure.core.credential.TokenCredential; import com.azure.core.exception.HttpResponseException; @@ -14,7 +16,6 @@ import dev.langchain4j.model.language.TokenCountEstimator; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.output.TokenUsage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -27,7 +28,7 @@ import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.contentFilterManagement; -import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; +import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupSyncClient; import static dev.langchain4j.spi.ServiceHelper.loadFactories; /** @@ -117,7 +118,7 @@ public AzureOpenAiStreamingLanguageModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, logprobs, echo, stop, presencePenalty, frequencyPenalty); - this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiStreamingLanguageModel(String endpoint, @@ -142,7 +143,7 @@ public AzureOpenAiStreamingLanguageModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, logprobs, echo, stop, presencePenalty, frequencyPenalty); - this.client = setupOpenAIClient(endpoint, serviceVersion, keyCredential, 
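Putting the new async plumbing together, a hedged sketch of how a caller might stream with the updated builder; the endpoint, key and deployment name are placeholders, useAsyncClient defaults to true, and openAIAsyncClient(...) is the new injection point for a pre-configured client.

// Sketch only; credentials and deployment name are placeholders.
AzureOpenAiStreamingChatModel model = AzureOpenAiStreamingChatModel.builder()
        .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
        .apiKey(System.getenv("AZURE_OPENAI_KEY"))
        .deploymentName("gpt-4o")
        // .openAIAsyncClient(myAsyncClient) // alternatively, inject a pre-built OpenAIAsyncClient
        // .useAsyncClient(false)            // deprecated escape hatch back to the blocking client
        .logRequestsAndResponses(true)
        .build();

model.generate("Tell me a joke", new StreamingResponseHandler<AiMessage>() {

    @Override
    public void onNext(String token) {
        System.out.print(token);
    }

    @Override
    public void onComplete(Response<AiMessage> response) {
        System.out.println();
    }

    @Override
    public void onError(Throwable error) {
        error.printStackTrace();
    }
});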
timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, keyCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } public AzureOpenAiStreamingLanguageModel(String endpoint, @@ -167,7 +168,7 @@ public AzureOpenAiStreamingLanguageModel(String endpoint, boolean logRequestsAndResponses) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, logprobs, echo, stop, presencePenalty, frequencyPenalty); - this.client = setupOpenAIClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + this.client = setupSyncClient(endpoint, serviceVersion, tokenCredential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } private AzureOpenAiStreamingLanguageModel(String deploymentName, diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingResponseBuilder.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingResponseBuilder.java index f2be0b275b..54255c2a82 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingResponseBuilder.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingResponseBuilder.java @@ -6,11 +6,16 @@ import dev.langchain4j.model.Tokenizer; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.finishReasonFrom; import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; /** * This class needs to be thread safe because it is called when a streaming result comes back @@ -19,9 +24,13 @@ */ class AzureOpenAiStreamingResponseBuilder { + Logger logger = LoggerFactory.getLogger(AzureOpenAiStreamingResponseBuilder.class); + private final StringBuffer contentBuilder = new StringBuffer(); private final StringBuffer toolNameBuilder = new StringBuffer(); private final StringBuffer toolArgumentsBuilder = new StringBuffer(); + private String toolExecutionsIndex = "call_undefined"; + private final Map toolExecutionRequestBuilderHashMap = new HashMap<>(); private volatile CompletionsFinishReason finishReason; private final Integer inputTokenCount; @@ -61,14 +70,29 @@ public void append(ChatCompletions completions) { return; } - FunctionCall functionCall = delta.getFunctionCall(); - if (functionCall != null) { - if (functionCall.getName() != null) { - toolNameBuilder.append(functionCall.getName()); - } - - if (functionCall.getArguments() != null) { - toolArgumentsBuilder.append(functionCall.getArguments()); + if (delta.getToolCalls() != null && !delta.getToolCalls().isEmpty()) { + for (ChatCompletionsToolCall toolCall : delta.getToolCalls()) { + ToolExecutionRequestBuilder toolExecutionRequestBuilder; + if (toolCall.getId() != null) { + toolExecutionsIndex = toolCall.getId(); + toolExecutionRequestBuilder = new ToolExecutionRequestBuilder(); + toolExecutionRequestBuilder.idBuilder.append(toolExecutionsIndex); + toolExecutionRequestBuilderHashMap.put(toolExecutionsIndex, toolExecutionRequestBuilder); + } else { + toolExecutionRequestBuilder = toolExecutionRequestBuilderHashMap.get(toolExecutionsIndex); + if (toolExecutionRequestBuilder == null) { + throw new 
IllegalStateException("Function without an id defined in the tool call"); + } + } + if (toolCall instanceof ChatCompletionsFunctionToolCall) { + ChatCompletionsFunctionToolCall functionCall = (ChatCompletionsFunctionToolCall) toolCall; + if (functionCall.getFunction().getName() != null) { + toolExecutionRequestBuilder.nameBuilder.append(functionCall.getFunction().getName()); + } + if (functionCall.getFunction().getArguments() != null) { + toolExecutionRequestBuilder.argumentsBuilder.append(functionCall.getFunction().getArguments()); + } + } } } } @@ -118,7 +142,22 @@ public Response build(Tokenizer tokenizer, boolean forcefulToolExecut .build(); return Response.from( AiMessage.from(toolExecutionRequest), - tokenUsage(toolExecutionRequest, tokenizer, forcefulToolExecution), + tokenUsage(singletonList(toolExecutionRequest), tokenizer, forcefulToolExecution), + finishReasonFrom(finishReason) + ); + } + + if (!toolExecutionRequestBuilderHashMap.isEmpty()) { + List toolExecutionRequests = toolExecutionRequestBuilderHashMap.values().stream() + .map(it -> ToolExecutionRequest.builder() + .id(it.idBuilder.toString()) + .name(it.nameBuilder.toString()) + .arguments(it.argumentsBuilder.toString()) + .build()) + .collect(toList()); + return Response.from( + AiMessage.from(toolExecutionRequests), + tokenUsage(toolExecutionRequests, tokenizer, forcefulToolExecution), finishReasonFrom(finishReason) ); } @@ -134,7 +173,7 @@ private TokenUsage tokenUsage(String content, Tokenizer tokenizer) { return new TokenUsage(inputTokenCount, outputTokenCount); } - private TokenUsage tokenUsage(ToolExecutionRequest toolExecutionRequest, Tokenizer tokenizer, boolean forcefulToolExecution) { + private TokenUsage tokenUsage(List toolExecutionRequests, Tokenizer tokenizer, boolean forcefulToolExecution) { if (tokenizer == null) { return null; } @@ -142,11 +181,20 @@ private TokenUsage tokenUsage(ToolExecutionRequest toolExecutionRequest, Tokeniz int outputTokenCount = 0; if (forcefulToolExecution) { // OpenAI calculates output tokens differently when tool is executed forcefully - outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest); + for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) { + outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest); + } } else { - outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest)); + outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(toolExecutionRequests); } return new TokenUsage(inputTokenCount, outputTokenCount); } + + private static class ToolExecutionRequestBuilder { + + private final StringBuffer idBuilder = new StringBuffer(); + private final StringBuffer nameBuilder = new StringBuffer(); + private final StringBuffer argumentsBuilder = new StringBuffer(); + } } diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiTokenizer.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiTokenizer.java new file mode 100644 index 0000000000..17cd9fd4b4 --- /dev/null +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiTokenizer.java @@ -0,0 +1,404 @@ +package dev.langchain4j.model.azure; + +import com.knuddels.jtokkit.Encodings; +import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.IntArrayList; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import 
dev.langchain4j.agent.tool.ToolParameters; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.Content; +import dev.langchain4j.data.message.ImageContent; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.TextContent; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.Tokenizer; + +import static dev.langchain4j.internal.Exceptions.illegalArgument; +import static dev.langchain4j.internal.Json.fromJson; +import static dev.langchain4j.internal.Utils.isNullOrBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO_0301; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO_1106; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_1106_PREVIEW; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_0125_PREVIEW; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_TURBO; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_VISION_PREVIEW; +import static java.util.Collections.singletonList; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; + +/** + * This class can be used to estimate the cost (in tokens) before calling OpenAI or when using streaming. + * Magic numbers present in this class were found empirically while testing. + * There are integration tests in place that are making sure that the calculations here are very close to that of OpenAI. + */ +public class AzureOpenAiTokenizer implements Tokenizer { + + private final String modelName; + private final Optional encoding; + + /** + * Creates an instance of the {@code AzureOpenAiTokenizer} for the "gpt-3.5-turbo" model. + * It should be suitable for most OpenAI models, as most of them use the same cl100k_base encoding (except for GPT-4o). + */ + public AzureOpenAiTokenizer() { + this(GPT_3_5_TURBO.modelType()); + } + + /** + * Creates an instance of the {@code AzureOpenAiTokenizer} for a given {@link AzureOpenAiChatModelName}. + */ + public AzureOpenAiTokenizer(AzureOpenAiChatModelName modelName) { + this(modelName.modelType()); + } + + /** + * Creates an instance of the {@code AzureOpenAiTokenizer} for a given {@link AzureOpenAiEmbeddingModelName}. + */ + public AzureOpenAiTokenizer(AzureOpenAiEmbeddingModelName modelName) { + this(modelName.modelType()); + } + + /** + * Creates an instance of the {@code AzureOpenAiTokenizer} for a given {@link AzureOpenAiLanguageModelName}. + */ + public AzureOpenAiTokenizer(AzureOpenAiLanguageModelName modelName) { + this(modelName.modelType()); + } + + /** + * Creates an instance of the {@code AzureOpenAiTokenizer} for a given model name. + */ + public AzureOpenAiTokenizer(String modelName) { + this.modelName = ensureNotBlank(modelName, "modelName"); + // If the model is unknown, we should NOT fail fast during the creation of AzureOpenAiTokenizer. + // Doing so would cause the failure of every OpenAI***Model that uses this tokenizer. + // This is done to account for situations when a new OpenAI model is available, + // but JTokkit does not yet support it. 
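+        // As a result, getEncodingForModel(...) below returns an empty Optional for unrecognized names,
+        // and the token-counting methods fail lazily via unknownModelException() rather than here.
+        // Illustrative usage (an assumed typical call pattern, not prescribed by this class):
+        //   Tokenizer tokenizer = new AzureOpenAiTokenizer("gpt-3.5-turbo");
+        //   int tokenCount = tokenizer.estimateTokenCountInText("Hello, world!");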
+ this.encoding = Encodings.newLazyEncodingRegistry().getEncodingForModel(modelName); + } + + public int estimateTokenCountInText(String text) { + return encoding.orElseThrow(unknownModelException()) + .countTokensOrdinary(text); + } + + @Override + public int estimateTokenCountInMessage(ChatMessage message) { + int tokenCount = 1; // 1 token for role + tokenCount += extraTokensPerMessage(); + + if (message instanceof SystemMessage) { + tokenCount += estimateTokenCountIn((SystemMessage) message); + } else if (message instanceof UserMessage) { + tokenCount += estimateTokenCountIn((UserMessage) message); + } else if (message instanceof AiMessage) { + tokenCount += estimateTokenCountIn((AiMessage) message); + } else if (message instanceof ToolExecutionResultMessage) { + tokenCount += estimateTokenCountIn((ToolExecutionResultMessage) message); + } else { + throw new IllegalArgumentException("Unknown message type: " + message); + } + + return tokenCount; + } + + private int estimateTokenCountIn(SystemMessage systemMessage) { + return estimateTokenCountInText(systemMessage.text()); + } + + private int estimateTokenCountIn(UserMessage userMessage) { + int tokenCount = 0; + + for (Content content : userMessage.contents()) { + if (content instanceof TextContent) { + tokenCount += estimateTokenCountInText(((TextContent) content).text()); + } else if (content instanceof ImageContent) { + tokenCount += 85; // TODO implement for HIGH/AUTO detail level + } else { + throw illegalArgument("Unknown content type: " + content); + } + } + + if (userMessage.name() != null && !modelName.equals(GPT_4_VISION_PREVIEW.toString())) { + tokenCount += extraTokensPerName(); + tokenCount += estimateTokenCountInText(userMessage.name()); + } + + return tokenCount; + } + + private int estimateTokenCountIn(AiMessage aiMessage) { + int tokenCount = 0; + + if (aiMessage.text() != null) { + tokenCount += estimateTokenCountInText(aiMessage.text()); + } + + if (aiMessage.toolExecutionRequests() != null) { + if (isOneOfLatestModels()) { + tokenCount += 6; + } else { + tokenCount += 3; + } + if (aiMessage.toolExecutionRequests().size() == 1) { + tokenCount -= 1; + ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); + tokenCount += estimateTokenCountInText(toolExecutionRequest.name()) * 2; + tokenCount += estimateTokenCountInText(toolExecutionRequest.arguments()); + } else { + tokenCount += 15; + for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { + tokenCount += 7; + tokenCount += estimateTokenCountInText(toolExecutionRequest.name()); + + Map arguments = fromJson(toolExecutionRequest.arguments(), Map.class); + for (Map.Entry argument : arguments.entrySet()) { + tokenCount += 2; + tokenCount += estimateTokenCountInText(argument.getKey().toString()); + tokenCount += estimateTokenCountInText(argument.getValue().toString()); + } + } + } + } + + return tokenCount; + } + + private int estimateTokenCountIn(ToolExecutionResultMessage toolExecutionResultMessage) { + return estimateTokenCountInText(toolExecutionResultMessage.text()); + } + + private int extraTokensPerMessage() { + if (modelName.equals(GPT_3_5_TURBO_0301.modelName())) { + return 4; + } else { + return 3; + } + } + + private int extraTokensPerName() { + if (modelName.equals(GPT_3_5_TURBO_0301.toString())) { + return -1; // if there's a name, the role is omitted + } else { + return 1; + } + } + + @Override + public int estimateTokenCountInMessages(Iterable messages) { + // see 
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + + int tokenCount = 3; // every reply is primed with <|start|>assistant<|message|> + for (ChatMessage message : messages) { + tokenCount += estimateTokenCountInMessage(message); + } + return tokenCount; + } + + @Override + public int estimateTokenCountInToolSpecifications(Iterable toolSpecifications) { + int tokenCount = 16; + for (ToolSpecification toolSpecification : toolSpecifications) { + tokenCount += 6; + tokenCount += estimateTokenCountInText(toolSpecification.name()); + if (toolSpecification.description() != null) { + tokenCount += 2; + tokenCount += estimateTokenCountInText(toolSpecification.description()); + } + tokenCount += estimateTokenCountInToolParameters(toolSpecification.parameters()); + } + return tokenCount; + } + + private int estimateTokenCountInToolParameters(ToolParameters parameters) { + if (parameters == null) { + return 0; + } + + int tokenCount = 3; + Map> properties = parameters.properties(); + if (isOneOfLatestModels()) { + tokenCount += properties.size() - 1; + } + for (String property : properties.keySet()) { + if (isOneOfLatestModels()) { + tokenCount += 2; + } else { + tokenCount += 3; + } + tokenCount += estimateTokenCountInText(property); + for (Map.Entry entry : properties.get(property).entrySet()) { + if ("type".equals(entry.getKey())) { + if ("array".equals(entry.getValue()) && isOneOfLatestModels()) { + tokenCount += 1; + } + // TODO object + } else if ("description".equals(entry.getKey())) { + tokenCount += 2; + tokenCount += estimateTokenCountInText(entry.getValue().toString()); + if (isOneOfLatestModels() && parameters.required().contains(property)) { + tokenCount += 1; + } + } else if ("enum".equals(entry.getKey())) { + if (isOneOfLatestModels()) { + tokenCount -= 2; + } else { + tokenCount -= 3; + } + for (Object enumValue : (Object[]) entry.getValue()) { + tokenCount += 3; + tokenCount += estimateTokenCountInText(enumValue.toString()); + } + } + } + } + return tokenCount; + } + + @Override + public int estimateTokenCountInForcefulToolSpecification(ToolSpecification toolSpecification) { + int tokenCount = estimateTokenCountInToolSpecifications(singletonList(toolSpecification)); + tokenCount += 4; + tokenCount += estimateTokenCountInText(toolSpecification.name()); + if (isOneOfLatestModels()) { + tokenCount += 3; + } + return tokenCount; + } + + public List encode(String text) { + return encoding.orElseThrow(unknownModelException()) + .encodeOrdinary(text).boxed(); + } + + public List encode(String text, int maxTokensToEncode) { + return encoding.orElseThrow(unknownModelException()) + .encodeOrdinary(text, maxTokensToEncode).getTokens().boxed(); + } + + public String decode(List tokens) { + + IntArrayList intArrayList = new IntArrayList(); + for (Integer token : tokens) { + intArrayList.add(token); + } + + return encoding.orElseThrow(unknownModelException()) + .decode(intArrayList); + } + + private Supplier unknownModelException() { + return () -> illegalArgument("Model '%s' is unknown to jtokkit", modelName); + } + + @Override + public int estimateTokenCountInToolExecutionRequests(Iterable toolExecutionRequests) { + + int tokenCount = 0; + + int toolsCount = 0; + int toolsWithArgumentsCount = 0; + int toolsWithoutArgumentsCount = 0; + + int totalArgumentsCount = 0; + + for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) { + tokenCount += 4; + tokenCount += estimateTokenCountInText(toolExecutionRequest.name()); + tokenCount += 
estimateTokenCountInText(toolExecutionRequest.arguments()); + + int argumentCount = countArguments(toolExecutionRequest.arguments()); + if (argumentCount == 0) { + toolsWithoutArgumentsCount++; + } else { + toolsWithArgumentsCount++; + } + totalArgumentsCount += argumentCount; + + toolsCount++; + } + + if (modelName.equals(GPT_3_5_TURBO_1106.toString()) || isOneOfLatestGpt4Models()) { + tokenCount += 16; + tokenCount += 3 * toolsWithoutArgumentsCount; + tokenCount += toolsCount; + if (totalArgumentsCount > 0) { + tokenCount -= 1; + tokenCount -= 2 * totalArgumentsCount; + tokenCount += 2 * toolsWithArgumentsCount; + tokenCount += toolsCount; + } + } + + if (modelName.equals(GPT_4_1106_PREVIEW.toString())) { + tokenCount += 3; + if (toolsCount > 1) { + tokenCount += 18; + tokenCount += 15 * toolsCount; + tokenCount += totalArgumentsCount; + tokenCount -= 3 * toolsWithoutArgumentsCount; + } + } + + return tokenCount; + } + + @Override + public int estimateTokenCountInForcefulToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) { + + if (isOneOfLatestGpt4Models()) { + int argumentsCount = countArguments(toolExecutionRequest.arguments()); + if (argumentsCount == 0) { + return 1; + } else { + return estimateTokenCountInText(toolExecutionRequest.arguments()); + } + } + + int tokenCount = estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest)); + tokenCount -= 4; + tokenCount -= estimateTokenCountInText(toolExecutionRequest.name()); + + if (modelName.equals(GPT_3_5_TURBO_1106.toString())) { + int argumentsCount = countArguments(toolExecutionRequest.arguments()); + if (argumentsCount == 0) { + return 1; + } + tokenCount -= 19; + tokenCount += 2 * argumentsCount; + } + + return tokenCount; + } + + static int countArguments(String arguments) { + if (isNullOrBlank(arguments)) { + return 0; + } + Map argumentsMap = fromJson(arguments, Map.class); + return argumentsMap.size(); + } + + private boolean isOneOfLatestModels() { + return isOneOfLatestGpt3Models() || isOneOfLatestGpt4Models(); + } + + private boolean isOneOfLatestGpt3Models() { + return modelName.equals(GPT_3_5_TURBO_1106.toString()) + || modelName.equals(GPT_3_5_TURBO.toString()); + } + + private boolean isOneOfLatestGpt4Models() { + return modelName.equals(GPT_4_TURBO.toString()) + || modelName.equals(GPT_4_1106_PREVIEW.toString()) + || modelName.equals(GPT_4_0125_PREVIEW.toString()); + } +} diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index 18c2c162f7..56e52bc70f 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -1,5 +1,6 @@ package dev.langchain4j.model.azure; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.OpenAIServiceVersion; @@ -24,7 +25,6 @@ import dev.langchain4j.data.image.Image; import dev.langchain4j.data.message.*; import dev.langchain4j.model.output.FinishReason; -import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,31 +47,17 @@ class InternalAzureOpenAiHelper { public static final String DEFAULT_USER_AGENT = "langchain4j-azure-openai"; - public 
static OpenAIClient setupOpenAIClient(String endpoint, String serviceVersion, String apiKey, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, boolean logRequestsAndResponses) { - OpenAIClientBuilder openAIClientBuilder = setupOpenAIClientBuilder(endpoint, serviceVersion, timeout, maxRetries, proxyOptions, logRequestsAndResponses); - - return openAIClientBuilder - .credential(new AzureKeyCredential(apiKey)) - .buildClient(); + public static OpenAIClient setupSyncClient(String endpoint, String serviceVersion, Object credential, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, boolean logRequestsAndResponses) { + OpenAIClientBuilder openAIClientBuilder = setupOpenAIClientBuilder(endpoint, serviceVersion, credential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + return openAIClientBuilder.buildClient(); } - public static OpenAIClient setupOpenAIClient(String endpoint, String serviceVersion, KeyCredential keyCredential, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, boolean logRequestsAndResponses) { - OpenAIClientBuilder openAIClientBuilder = setupOpenAIClientBuilder(endpoint, serviceVersion, timeout, maxRetries, proxyOptions, logRequestsAndResponses); - - return openAIClientBuilder - .credential(keyCredential) - .buildClient(); - } - - public static OpenAIClient setupOpenAIClient(String endpoint, String serviceVersion, TokenCredential tokenCredential, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, boolean logRequestsAndResponses) { - OpenAIClientBuilder openAIClientBuilder = setupOpenAIClientBuilder(endpoint, serviceVersion, timeout, maxRetries, proxyOptions, logRequestsAndResponses); - - return openAIClientBuilder - .credential(tokenCredential) - .buildClient(); + public static OpenAIAsyncClient setupAsyncClient(String endpoint, String serviceVersion, Object credential, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, boolean logRequestsAndResponses) { + OpenAIClientBuilder openAIClientBuilder = setupOpenAIClientBuilder(endpoint, serviceVersion, credential, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + return openAIClientBuilder.buildAsyncClient(); } - private static OpenAIClientBuilder setupOpenAIClientBuilder(String endpoint, String serviceVersion, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, boolean logRequestsAndResponses) { + private static OpenAIClientBuilder setupOpenAIClientBuilder(String endpoint, String serviceVersion, Object credential, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, boolean logRequestsAndResponses) { timeout = getOrDefault(timeout, ofSeconds(60)); HttpClientOptions clientOptions = new HttpClientOptions(); clientOptions.setConnectTimeout(timeout); @@ -94,13 +80,26 @@ private static OpenAIClientBuilder setupOpenAIClientBuilder(String endpoint, Str exponentialBackoffOptions.setMaxRetries(maxRetries); RetryOptions retryOptions = new RetryOptions(exponentialBackoffOptions); - return new OpenAIClientBuilder() + OpenAIClientBuilder openAIClientBuilder = new OpenAIClientBuilder() .endpoint(ensureNotBlank(endpoint, "endpoint")) .serviceVersion(getOpenAIServiceVersion(serviceVersion)) .httpClient(httpClient) .clientOptions(clientOptions) .httpLogOptions(httpLogOptions) .retryOptions(retryOptions); + + if (credential instanceof String) { + openAIClientBuilder.credential(new AzureKeyCredential((String) credential)); + } else if (credential instanceof KeyCredential) { + openAIClientBuilder.credential((KeyCredential) 
credential); + } else if (credential instanceof TokenCredential) { + openAIClientBuilder.credential((TokenCredential) credential); + } else { + throw new IllegalArgumentException("Unsupported credential type: " + credential.getClass()); + } + + return openAIClientBuilder; + } private static OpenAIClientBuilder authenticate(TokenCredential tokenCredential) { @@ -128,11 +127,11 @@ public static com.azure.ai.openai.models.ChatRequestMessage toOpenAiMessage(Chat if (message instanceof AiMessage) { AiMessage aiMessage = (AiMessage) message; ChatRequestAssistantMessage chatRequestAssistantMessage = new ChatRequestAssistantMessage(getOrDefault(aiMessage.text(), "")); - chatRequestAssistantMessage.setFunctionCall(functionCallFrom(message)); + chatRequestAssistantMessage.setToolCalls(toolExecutionRequestsFrom(message)); return chatRequestAssistantMessage; } else if (message instanceof ToolExecutionResultMessage) { ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message; - return new ChatRequestFunctionMessage(nameFrom(message), toolExecutionResultMessage.text()); + return new ChatRequestToolMessage(toolExecutionResultMessage.text(), toolExecutionResultMessage.id()); } else if (message instanceof SystemMessage) { SystemMessage systemMessage = (SystemMessage) message; return new ChatRequestSystemMessage(systemMessage.text()); @@ -179,37 +178,45 @@ private static String nameFrom(ChatMessage message) { return null; } - private static FunctionCall functionCallFrom(ChatMessage message) { + private static List toolExecutionRequestsFrom(ChatMessage message) { if (message instanceof AiMessage) { AiMessage aiMessage = (AiMessage) message; if (aiMessage.hasToolExecutionRequests()) { - // TODO switch to tools once supported - ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); - return new FunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments()); + return aiMessage.toolExecutionRequests().stream() + .map(toolExecutionRequest -> new ChatCompletionsFunctionToolCall(toolExecutionRequest.id(), new FunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments()))) + .collect(toList()); + } } - return null; } - public static List toFunctions(Collection toolSpecifications) { + public static List toToolDefinitions(Collection toolSpecifications) { return toolSpecifications.stream() - .map(InternalAzureOpenAiHelper::toFunction) + .map(InternalAzureOpenAiHelper::toToolDefinition) .collect(toList()); } - private static FunctionDefinition toFunction(ToolSpecification toolSpecification) { + private static ChatCompletionsToolDefinition toToolDefinition(ToolSpecification toolSpecification) { FunctionDefinition functionDefinition = new FunctionDefinition(toolSpecification.name()); functionDefinition.setDescription(toolSpecification.description()); functionDefinition.setParameters(toOpenAiParameters(toolSpecification.parameters())); - return functionDefinition; + return new ChatCompletionsFunctionToolDefinition(functionDefinition); + } + + public static BinaryData toToolChoice(ToolSpecification toolThatMustBeExecuted) { + FunctionCall functionCall = new FunctionCall(toolThatMustBeExecuted.name(), toOpenAiParameters(toolThatMustBeExecuted.parameters()).toString()); + ChatCompletionsToolCall toolToCall = new ChatCompletionsFunctionToolCall(toolThatMustBeExecuted.name(), functionCall); + return BinaryData.fromObject(toolToCall); } private static final Map NO_PARAMETER_DATA = new HashMap<>(); + static { NO_PARAMETER_DATA.put("type", 
"object"); NO_PARAMETER_DATA.put("properties", new HashMap<>()); } + private static BinaryData toOpenAiParameters(ToolParameters toolParameters) { Parameters parameters = new Parameters(); if (toolParameters == null) { @@ -225,6 +232,7 @@ private static class Parameters { private final String type = "object"; private Map> properties = new HashMap<>(); + private List required = new ArrayList<>(); public String getType() { @@ -252,14 +260,19 @@ public static AiMessage aiMessageFrom(com.azure.ai.openai.models.ChatResponseMes if (chatResponseMessage.getContent() != null) { return aiMessage(chatResponseMessage.getContent()); } else { - FunctionCall functionCall = chatResponseMessage.getFunctionCall(); - - ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() - .name(functionCall.getName()) - .arguments(functionCall.getArguments()) - .build(); - - return aiMessage(toolExecutionRequest); + List toolExecutionRequests = chatResponseMessage.getToolCalls() + .stream() + .filter(toolCall -> toolCall instanceof ChatCompletionsFunctionToolCall) + .map(toolCall -> (ChatCompletionsFunctionToolCall) toolCall) + .map(chatCompletionsFunctionToolCall -> + ToolExecutionRequest.builder() + .id(chatCompletionsFunctionToolCall.getId()) + .name(chatCompletionsFunctionToolCall.getFunction().getName()) + .arguments(chatCompletionsFunctionToolCall.getFunction().getArguments()) + .build()) + .collect(toList()); + + return aiMessage(toolExecutionRequests); } } diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAIResponsibleAIIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAIResponsibleAIIT.java index 7e186a6209..55099e824b 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAIResponsibleAIIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAIResponsibleAIIT.java @@ -8,7 +8,6 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.language.LanguageModel; import dev.langchain4j.model.language.StreamingLanguageModel; -import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.model.output.Response; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -18,7 +17,7 @@ import java.util.concurrent.CompletableFuture; -import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO_INSTRUCT; +import static dev.langchain4j.model.azure.AzureOpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT; import static dev.langchain4j.model.output.FinishReason.CONTENT_FILTER; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; @@ -33,8 +32,7 @@ public class AzureOpenAIResponsibleAIIT { @ParameterizedTest(name = "Deployment name {0} using {1}") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o" }) void chat_message_should_trigger_content_filter_for_violence(String deploymentName, String gptVersion) { @@ -42,7 +40,7 @@ void chat_message_should_trigger_content_filter_for_violence(String deploymentNa .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(deploymentName) - .tokenizer(new OpenAiTokenizer(gptVersion)) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) .logRequestsAndResponses(true) .build(); @@ -56,8 +54,7 @@ void chat_message_should_trigger_content_filter_for_violence(String deploymentNa @ParameterizedTest(name = 
"Deployment name {0} using {1}") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o" }) void chat_message_should_trigger_content_filter_for_self_harm(String deploymentName, String gptVersion) { @@ -65,7 +62,7 @@ void chat_message_should_trigger_content_filter_for_self_harm(String deploymentN .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(deploymentName) - .tokenizer(new OpenAiTokenizer(gptVersion)) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) .logRequestsAndResponses(true) .build(); @@ -102,7 +99,7 @@ void language_model_should_trigger_content_filter_for_violence() { .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName("gpt-35-turbo-instruct") - .tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT)) + .tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT)) .temperature(0.0) .maxTokens(20) .logRequestsAndResponses(true) @@ -118,8 +115,7 @@ void language_model_should_trigger_content_filter_for_violence() { @ParameterizedTest(name = "Deployment name {0} using {1}") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o" }) void streaming_chat_message_should_trigger_content_filter_for_violence(String deploymentName, String gptVersion) throws Exception { @@ -130,7 +126,7 @@ void streaming_chat_message_should_trigger_content_filter_for_violence(String de .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(deploymentName) - .tokenizer(new OpenAiTokenizer(gptVersion)) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) .logRequestsAndResponses(true) .build(); @@ -167,8 +163,7 @@ public void onError(Throwable error) { @ParameterizedTest(name = "Deployment name {0} using {1}") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o" }) void streaming_language_should_trigger_content_filter_for_violence(String deploymentName, String gptVersion) throws Exception { @@ -176,7 +171,7 @@ void streaming_language_should_trigger_content_filter_for_violence(String deploy .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName("gpt-35-turbo-instruct") - .tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT)) + .tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT)) .temperature(0.0) .maxTokens(20) .logRequestsAndResponses(true) diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java index 8cbe024a95..6cbdb473db 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java @@ -1,7 +1,6 @@ package dev.langchain4j.model.azure; import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; -import com.azure.ai.openai.models.ChatCompletionsResponseFormat; import com.azure.core.util.BinaryData; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -10,30 +9,35 @@ import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.*; import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.model.output.Response; import 
dev.langchain4j.model.output.TokenUsage; +import org.assertj.core.data.Percentage; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.EnumSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.*; +import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage; import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.model.output.FinishReason.LENGTH; import static dev.langchain4j.model.output.FinishReason.STOP; +import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; public class AzureOpenAiChatModelIT { Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModelIT.class); + Percentage tokenizerPrecision = withPercentage(5); + @ParameterizedTest(name = "Deployment name {0} using {1}") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o" }) void should_generate_answer_and_return_token_usage_and_finish_reason_stop(String deploymentName, String gptVersion) { @@ -41,7 +45,7 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_stop(String .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(deploymentName) - .tokenizer(new OpenAiTokenizer(gptVersion)) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) .logRequestsAndResponses(true) .build(); @@ -62,8 +66,7 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_stop(String @ParameterizedTest(name = "Deployment name {0} using {1}") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o" }) void should_generate_answer_and_return_token_usage_and_finish_reason_length(String deploymentName, String gptVersion) { @@ -71,7 +74,7 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_length(Stri .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(deploymentName) - .tokenizer(new OpenAiTokenizer(gptVersion)) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) .maxTokens(3) .logRequestsAndResponses(true) .build(); @@ -93,8 +96,7 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_length(Stri @ParameterizedTest(name = "Deployment name {0} using {1}") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o" }) void should_call_function_with_argument(String deploymentName, String gptVersion) { @@ -102,7 +104,7 @@ void should_call_function_with_argument(String deploymentName, String gptVersion .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(deploymentName) - .tokenizer(new OpenAiTokenizer(gptVersion)) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) .logRequestsAndResponses(true) .build(); @@ -121,6 +123,7 @@ void should_call_function_with_argument(String deploymentName, String gptVersion AiMessage aiMessage = response.content(); assertThat(aiMessage.text()).isNull(); + assertThat(response.finishReason()).isEqualTo(STOP); assertThat(aiMessage.toolExecutionRequests()).hasSize(1); ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); @@ -131,8 +134,7 @@ void should_call_function_with_argument(String deploymentName, String gptVersion // We can now 
call the function with the correct parameters. WeatherLocation weatherLocation = BinaryData.fromString(toolExecutionRequest.arguments()).toObject(WeatherLocation.class); - int currentWeather = 0; - currentWeather = getCurrentWeather(weatherLocation); + int currentWeather = getCurrentWeather(weatherLocation); String weather = String.format("The weather in %s is %d degrees %s.", weatherLocation.getLocation(), currentWeather, weatherLocation.getUnit()); @@ -160,15 +162,14 @@ void should_call_function_with_argument(String deploymentName, String gptVersion @ParameterizedTest(name = "Deployment name {0} using {1}") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o" }) void should_call_function_with_no_argument(String deploymentName, String gptVersion) { ChatLanguageModel model = AzureOpenAiChatModel.builder() .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(deploymentName) - .tokenizer(new OpenAiTokenizer(gptVersion)) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) .logRequestsAndResponses(true) .build(); @@ -195,15 +196,92 @@ void should_call_function_with_no_argument(String deploymentName, String gptVers @ParameterizedTest(name = "Deployment name {0} using {1}") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o" + }) + void should_call_three_functions_in_parallel(String deploymentName, String gptVersion) throws Exception { + + ChatLanguageModel model = AzureOpenAiChatModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(deploymentName) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) + .logRequestsAndResponses(true) + .build(); + + UserMessage userMessage = userMessage("Give three numbers, ordered by size: the sum of two plus two, the square of four, and finally the cube of eight."); + + List toolSpecifications = asList( + ToolSpecification.builder() + .name("sum") + .description("returns a sum of two numbers") + .addParameter("first", INTEGER) + .addParameter("second", INTEGER) + .build(), + ToolSpecification.builder() + .name("square") + .description("returns the square of one number") + .addParameter("number", INTEGER) + .build(), + ToolSpecification.builder() + .name("cube") + .description("returns the cube of one number") + .addParameter("number", INTEGER) + .build() + ); + + Response response = model.generate(Collections.singletonList(userMessage), toolSpecifications); + + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).isNull(); + List messages = new ArrayList<>(); + messages.add(userMessage); + messages.add(aiMessage); + assertThat(aiMessage.toolExecutionRequests()).hasSize(3); + for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { + assertThat(toolExecutionRequest.name()).isNotEmpty(); + ToolExecutionResultMessage toolExecutionResultMessage; + if (toolExecutionRequest.name().equals("sum")) { + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "4"); + } else if (toolExecutionRequest.name().equals("square")) { + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 4}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "16"); + } else if (toolExecutionRequest.name().equals("cube")) { + 
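+                // 8^3 = 512; as in the sum (2 + 2 = 4) and square (4^2 = 16) branches above, the test stubs the tool result itself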
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 8}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "512"); + } else { + throw new AssertionError("Unexpected tool name: " + toolExecutionRequest.name()); + } + messages.add(toolExecutionResultMessage); + } + + Response response2 = model.generate(messages); + AiMessage aiMessage2 = response2.content(); + + // then + logger.debug("Final answer is: " + aiMessage2); + assertThat(aiMessage2.text()).contains("4", "16", "512"); + assertThat(aiMessage2.toolExecutionRequests()).isNull(); + + TokenUsage tokenUsage2 = response2.tokenUsage(); + assertThat(tokenUsage2.inputTokenCount()).isCloseTo(112, tokenizerPrecision); + assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage2.totalTokenCount()) + .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); + + assertThat(response2.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest(name = "Deployment name {0} using {1}") + @CsvSource({ + "gpt-4o, gpt-4o" }) void should_use_json_format(String deploymentName, String gptVersion) { ChatLanguageModel model = AzureOpenAiChatModel.builder() .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(deploymentName) - .tokenizer(new OpenAiTokenizer(gptVersion)) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) .responseFormat(new ChatCompletionsJsonResponseFormat()) .logRequestsAndResponses(true) .build(); @@ -219,6 +297,31 @@ void should_use_json_format(String deploymentName, String gptVersion) { assertThat(response.finishReason()).isEqualTo(STOP); } + @ParameterizedTest(name = "Testing model {0}") + @EnumSource(AzureOpenAiChatModelName.class) + void should_support_all_string_model_names(AzureOpenAiChatModelName modelName) { + + // given + String modelNameString = modelName.toString(); + + ChatLanguageModel model = AzureOpenAiChatModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(modelNameString) + .maxTokens(1) + .logRequestsAndResponses(true) + .build(); + + UserMessage userMessage = userMessage("Hi"); + + // when + Response response = model.generate(userMessage); + System.out.println(response); + + // then + assertThat(response.content().text()).isNotBlank(); + } + private static ToolParameters getToolParameters() { Map> properties = new HashMap<>(); diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiEmbeddingModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiEmbeddingModelIT.java index 6ab00b5f28..e6e6ac6d25 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiEmbeddingModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiEmbeddingModelIT.java @@ -3,17 +3,18 @@ import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.List; -import static 
dev.langchain4j.model.azure.AzureOpenAiModelName.TEXT_EMBEDDING_ADA_002; +import static dev.langchain4j.model.azure.AzureOpenAiEmbeddingModelName.TEXT_EMBEDDING_ADA_002; import static org.assertj.core.api.Assertions.assertThat; public class AzureOpenAiEmbeddingModelIT { @@ -24,7 +25,7 @@ public class AzureOpenAiEmbeddingModelIT { .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName("text-embedding-ada-002") - .tokenizer(new OpenAiTokenizer(TEXT_EMBEDDING_ADA_002)) + .tokenizer(new AzureOpenAiTokenizer(TEXT_EMBEDDING_ADA_002)) .logRequestsAndResponses(true) .build(); @@ -68,4 +69,26 @@ void should_embed_in_batches() { assertThat(response.finishReason()).isNull(); } + + @ParameterizedTest(name = "Testing model {0}") + @EnumSource(AzureOpenAiEmbeddingModelName.class) + void should_support_all_string_model_names(AzureOpenAiEmbeddingModelName modelName) { + + // given + String modelNameString = modelName.toString(); + + EmbeddingModel model = AzureOpenAiEmbeddingModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(modelNameString) + .logRequestsAndResponses(true) + .build(); + + // when + Response response = model.embed("hello world"); + System.out.println(response.toString()); + + // then + assertThat(response.content().vector()).isNotEmpty(); + } } diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java index c647a31da7..4b30d0f9ab 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java @@ -4,6 +4,8 @@ import dev.langchain4j.data.image.Image; import dev.langchain4j.model.output.Response; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -71,4 +73,30 @@ void should_generate_image_in_base64() throws IOException { assertThat(image.revisedPrompt()).isNotNull(); logger.info("The revised prompt is: {}", image.revisedPrompt()); } + + @ParameterizedTest(name = "Testing model {0}") + @EnumSource(AzureOpenAiImageModelName.class) + void should_support_all_string_model_names(AzureOpenAiImageModelName modelName) { + + // given + String modelNameString = modelName.toString(); + + AzureOpenAiImageModel model = AzureOpenAiImageModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(modelNameString) + .logRequestsAndResponses(true) + .build(); + + // when + Response response = model.generate("A coffee mug in Paris, France"); + logger.info(response.toString()); + + // then + Image image = response.content(); + assertThat(image).isNotNull(); + assertThat(image.url()).isNotNull(); + assertThat(image.base64Data()).isNull(); + assertThat(image.revisedPrompt()).isNotNull(); + } } diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiLanguageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiLanguageModelIT.java index 8ae8ac392a..dffbf9b6bc 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiLanguageModelIT.java +++ 
b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiLanguageModelIT.java @@ -1,14 +1,15 @@ package dev.langchain4j.model.azure; import dev.langchain4j.model.language.LanguageModel; -import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO_INSTRUCT; +import static dev.langchain4j.model.azure.AzureOpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT; import static dev.langchain4j.model.output.FinishReason.LENGTH; import static dev.langchain4j.model.output.FinishReason.STOP; import static org.assertj.core.api.Assertions.assertThat; @@ -21,7 +22,7 @@ class AzureOpenAiLanguageModelIT { .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName("gpt-35-turbo-instruct") - .tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT)) + .tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT)) .temperature(0.0) .maxTokens(20) .logRequestsAndResponses(true) @@ -54,4 +55,27 @@ void should_generate_answer_and_finish_reason_length() { assertThat(response.finishReason()).isEqualTo(LENGTH); } + + @ParameterizedTest(name = "Testing model {0}") + @EnumSource(AzureOpenAiLanguageModelName.class) + void should_support_all_string_model_names(AzureOpenAiLanguageModelName modelName) { + + // given + String modelNameString = modelName.toString(); + + LanguageModel model = AzureOpenAiLanguageModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(modelNameString) + .logRequestsAndResponses(true) + .build(); + + // when + String prompt = "Describe the capital of France in 100 words: "; + Response response = model.generate(prompt); + System.out.println(response.toString()); + + // then + assertThat(response.finishReason()).isEqualTo(LENGTH); + } } \ No newline at end of file diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelIT.java index f51d0a005b..5b31722ec5 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelIT.java @@ -1,45 +1,53 @@ package dev.langchain4j.model.azure; +import com.azure.ai.openai.OpenAIAsyncClient; +import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; -import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.StreamingResponseHandler; -import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; -import dev.langchain4j.model.openai.OpenAiTokenizer; +import dev.langchain4j.model.chat.TestStreamingResponseHandler; import 
dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import org.assertj.core.data.Percentage; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.Arrays; -import java.util.Collections; +import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; +import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage; import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.model.output.FinishReason.STOP; -import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION; +import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; class AzureOpenAiStreamingChatModelIT { Logger logger = LoggerFactory.getLogger(AzureOpenAiStreamingChatModelIT.class); - @ParameterizedTest(name = "Deployment name {0} using {1}") + Percentage tokenizerPrecision = withPercentage(5); + + @ParameterizedTest(name = "Deployment name {0} using {1} with async client set to {2}") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o, true", + "gpt-4o, gpt-4o, false" }) - void should_stream_answer(String deploymentName, String gptVersion) throws Exception { + void should_stream_answer(String deploymentName, String gptVersion, boolean useAsyncClient) throws Exception { CompletableFuture futureAnswer = new CompletableFuture<>(); CompletableFuture> futureResponse = new CompletableFuture<>(); @@ -48,7 +56,8 @@ void should_stream_answer(String deploymentName, String gptVersion) throws Excep .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(deploymentName) - .tokenizer(new OpenAiTokenizer(gptVersion)) + .useAsyncClient(useAsyncClient) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) .logRequestsAndResponses(true) .build(); @@ -90,30 +99,33 @@ public void onError(Throwable error) { assertThat(response.finishReason()).isEqualTo(STOP); } - @ParameterizedTest(name = "Deployment name {0} using {1}") + @ParameterizedTest(name = "Deployment name {0} using {1} with custom async client set to {2} ") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o, true", + "gpt-4o, gpt-4o, false" }) - void should_use_json_format(String deploymentName, String gptVersion) throws Exception { + void should_custom_models_work(String deploymentName, String gptVersion, boolean useCustomAsyncClient) throws Exception { CompletableFuture futureAnswer = new CompletableFuture<>(); CompletableFuture> futureResponse = new CompletableFuture<>(); + OpenAIAsyncClient asyncClient = null; + OpenAIClient client = null; + if(useCustomAsyncClient) { + asyncClient = InternalAzureOpenAiHelper.setupAsyncClient(System.getenv("AZURE_OPENAI_ENDPOINT"), gptVersion, System.getenv("AZURE_OPENAI_KEY"), Duration.ofSeconds(30), 5, null, true); + } else { + client = InternalAzureOpenAiHelper.setupSyncClient(System.getenv("AZURE_OPENAI_ENDPOINT"), gptVersion, System.getenv("AZURE_OPENAI_KEY"), Duration.ofSeconds(30), 5, null, true); + } 
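+            // Exactly one of 'asyncClient' / 'client' is non-null at this point; the builder below is
+            // assumed to pick up whichever custom client was supplied and to build its own client otherwise.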
StreamingChatLanguageModel model = AzureOpenAiStreamingChatModel.builder() + .openAIAsyncClient(asyncClient) + .openAIClient(client) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(deploymentName) - .tokenizer(new OpenAiTokenizer(gptVersion)) - .responseFormat(new ChatCompletionsJsonResponseFormat()) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) .logRequestsAndResponses(true) .build(); - - SystemMessage systemMessage = SystemMessage.systemMessage("You are a helpful assistant designed to output JSON."); - UserMessage userMessage = userMessage("List teams in the past French presidents, with their first name, last name, dates of service."); - - List messages = Arrays.asList(systemMessage, userMessage); - model.generate(messages, new StreamingResponseHandler() { + model.generate("What is the capital of France?", new StreamingResponseHandler() { private final StringBuilder answerBuilder = new StringBuilder(); @@ -140,25 +152,152 @@ public void onError(Throwable error) { String answer = futureAnswer.get(30, SECONDS); Response response = futureResponse.get(30, SECONDS); - assertThat(response.content().text()).contains("Chirac", "Sarkozy", "Hollande", "Macron"); + assertThat(answer).contains("Paris"); + assertThat(response.content().text()).isEqualTo(answer); + + assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(14); + assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0); + assertThat(response.tokenUsage().totalTokenCount()) + .isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount()); + assertThat(response.finishReason()).isEqualTo(STOP); } + @ParameterizedTest(name = "Deployment name {0}") + @ValueSource(strings = {"gpt-4o"}) + void should_use_json_format(String deploymentName) { + + StreamingChatLanguageModel model = AzureOpenAiStreamingChatModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(deploymentName) + .responseFormat(new ChatCompletionsJsonResponseFormat()) + .temperature(0.0) + .maxTokens(50) + .logRequestsAndResponses(true) + .build(); + + String userMessage = "Return JSON with two fields: name and surname of Klaus Heisler."; + + String expectedJson = "{\"name\": \"Klaus\", \"surname\": \"Heisler\"}"; + + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(userMessage, handler); + Response response = handler.get(); + + assertThat(response.content().text()).isEqualToIgnoringWhitespace(expectedJson); + } + @ParameterizedTest(name = "Deployment name {0} using {1}") @CsvSource({ - "gpt-35-turbo, gpt-3.5-turbo", - "gpt-4, gpt-4" + "gpt-4o, gpt-4o" }) - void should_return_tool_execution_request(String deploymentName, String gptVersion) throws Exception { + void should_call_function_with_argument(String deploymentName, String gptVersion) throws Exception { + + CompletableFuture> futureResponse = new CompletableFuture<>(); + + StreamingChatLanguageModel model = AzureOpenAiStreamingChatModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(deploymentName) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) + .logRequestsAndResponses(true) + .build(); + + UserMessage userMessage = userMessage("Two plus two?"); + + String toolName = "calculator"; ToolSpecification toolSpecification = ToolSpecification.builder() - .name("calculator") + .name(toolName) .description("returns a 
sum of two numbers") .addParameter("first", INTEGER) .addParameter("second", INTEGER) .build(); - UserMessage userMessage = userMessage("Two plus two?"); + model.generate(singletonList(userMessage), toolSpecification, new StreamingResponseHandler() { + + @Override + public void onNext(String token) { + logger.info("onNext: '" + token + "'"); + Exception e = new IllegalStateException("onNext() should never be called when tool is executed"); + futureResponse.completeExceptionally(e); + } + + @Override + public void onComplete(Response response) { + logger.info("onComplete: '" + response + "'"); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureResponse.completeExceptionally(error); + } + }); + + Response response = futureResponse.get(30, SECONDS); + + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).isNull(); + + assertThat(aiMessage.toolExecutionRequests()).hasSize(1); + ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.name()).isEqualTo(toolName); + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}"); + + assertThat(response.tokenUsage().inputTokenCount()).isCloseTo(58, tokenizerPrecision); + assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0); + assertThat(response.tokenUsage().totalTokenCount()) + .isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + + ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "four"); + List messages = asList(userMessage, aiMessage, toolExecutionResultMessage); + + CompletableFuture> futureResponse2 = new CompletableFuture<>(); + + model.generate(messages, new StreamingResponseHandler() { + + @Override + public void onNext(String token) { + logger.info("onNext: '" + token + "'"); + } + + @Override + public void onComplete(Response response) { + logger.info("onComplete: '" + response + "'"); + futureResponse2.complete(response); + } + + @Override + public void onError(Throwable error) { + futureResponse2.completeExceptionally(error); + } + }); + + Response response2 = futureResponse2.get(30, SECONDS); + AiMessage aiMessage2 = response2.content(); + + // then + assertThat(aiMessage2.text()).contains("four"); + assertThat(aiMessage2.toolExecutionRequests()).isNull(); + + TokenUsage tokenUsage2 = response2.tokenUsage(); + assertThat(tokenUsage2.inputTokenCount()).isCloseTo(33, tokenizerPrecision); + assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage2.totalTokenCount()) + .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); + + assertThat(response2.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest(name = "Deployment name {0} using {1}") + @CsvSource({ + "gpt-4o, gpt-4o" + }) + void should_call_three_functions_in_parallel(String deploymentName, String gptVersion) throws Exception { CompletableFuture> futureResponse = new CompletableFuture<>(); @@ -166,11 +305,32 @@ void should_return_tool_execution_request(String deploymentName, String gptVersi .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(deploymentName) - .tokenizer(new OpenAiTokenizer(gptVersion)) + .tokenizer(new AzureOpenAiTokenizer(gptVersion)) .logRequestsAndResponses(true) .build(); - 
model.generate(singletonList(userMessage), singletonList(toolSpecification), new StreamingResponseHandler() { + UserMessage userMessage = userMessage("Give three numbers, ordered by size: the sum of two plus two, the square of four, and finally the cube of eight."); + + List toolSpecifications = asList( + ToolSpecification.builder() + .name("sum") + .description("returns a sum of two numbers") + .addParameter("first", INTEGER) + .addParameter("second", INTEGER) + .build(), + ToolSpecification.builder() + .name("square") + .description("returns the square of one number") + .addParameter("number", INTEGER) + .build(), + ToolSpecification.builder() + .name("cube") + .description("returns the cube of one number") + .addParameter("number", INTEGER) + .build() + ); + + model.generate(singletonList(userMessage), toolSpecifications, new StreamingResponseHandler() { @Override public void onNext(String token) { @@ -195,17 +355,62 @@ public void onError(Throwable error) { AiMessage aiMessage = response.content(); assertThat(aiMessage.text()).isNull(); + List messages = new ArrayList<>(); + messages.add(userMessage); + messages.add(aiMessage); + assertThat(aiMessage.toolExecutionRequests()).hasSize(3); + for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { + assertThat(toolExecutionRequest.name()).isNotEmpty(); + ToolExecutionResultMessage toolExecutionResultMessage; + if (toolExecutionRequest.name().equals("sum")) { + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "4"); + } else if (toolExecutionRequest.name().equals("square")) { + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 4}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "16"); + } else if (toolExecutionRequest.name().equals("cube")) { + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 8}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "512"); + } else { + throw new AssertionError("Unexpected tool name: " + toolExecutionRequest.name()); + } + messages.add(toolExecutionResultMessage); + } + CompletableFuture> futureResponse2 = new CompletableFuture<>(); - assertThat(aiMessage.toolExecutionRequests()).hasSize(1); - ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); - assertThat(toolExecutionRequest.name()).isEqualTo("calculator"); - assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}"); + model.generate(messages, new StreamingResponseHandler() { - assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(53); - assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0); - assertThat(response.tokenUsage().totalTokenCount()) - .isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount()); + @Override + public void onNext(String token) { + logger.info("onNext: '" + token + "'"); + } + + @Override + public void onComplete(Response response) { + logger.info("onComplete: '" + response + "'"); + futureResponse2.complete(response); + } + + @Override + public void onError(Throwable error) { + futureResponse2.completeExceptionally(error); + } + }); + + Response response2 = futureResponse2.get(30, SECONDS); + AiMessage aiMessage2 = response2.content(); + + // then + logger.debug("Final answer is: " + 
aiMessage2); + assertThat(aiMessage2.text()).contains("4", "16", "512"); + assertThat(aiMessage2.toolExecutionRequests()).isNull(); + + TokenUsage tokenUsage2 = response2.tokenUsage(); + assertThat(tokenUsage2.inputTokenCount()).isCloseTo(119, tokenizerPrecision); + assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage2.totalTokenCount()) + .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); - assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION); + assertThat(response2.finishReason()).isEqualTo(STOP); } } diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingLanguageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingLanguageModelIT.java index 86e347ef2e..910fecee11 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingLanguageModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingLanguageModelIT.java @@ -2,7 +2,6 @@ import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.language.StreamingLanguageModel; -import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.model.output.Response; import org.junit.jupiter.api.Test; import org.slf4j.Logger; @@ -10,7 +9,7 @@ import java.util.concurrent.CompletableFuture; -import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO_INSTRUCT; +import static dev.langchain4j.model.azure.AzureOpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT; import static dev.langchain4j.model.output.FinishReason.LENGTH; import static dev.langchain4j.model.output.FinishReason.STOP; import static java.util.concurrent.TimeUnit.SECONDS; @@ -24,7 +23,7 @@ class AzureOpenAiStreamingLanguageModelIT { .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName("gpt-35-turbo-instruct") - .tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT)) + .tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT)) .temperature(0.0) .maxTokens(20) .logRequestsAndResponses(true) diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiTokenizerIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiTokenizerIT.java new file mode 100644 index 0000000000..7dafdeca7a --- /dev/null +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiTokenizerIT.java @@ -0,0 +1,1767 @@ +package dev.langchain4j.model.azure; + +import static dev.langchain4j.agent.tool.JsonSchemaProperty.ARRAY; +import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; +import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING; +import static dev.langchain4j.agent.tool.JsonSchemaProperty.description; +import static dev.langchain4j.agent.tool.JsonSchemaProperty.enums; +import static dev.langchain4j.agent.tool.JsonSchemaProperty.items; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import static dev.langchain4j.data.message.AiMessage.aiMessage; +import dev.langchain4j.data.message.ChatMessage; +import static dev.langchain4j.data.message.SystemMessage.systemMessage; +import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage; +import dev.langchain4j.data.message.UserMessage; +import static 
dev.langchain4j.data.message.UserMessage.userMessage; +import dev.langchain4j.model.Tokenizer; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO_0613; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO_1106; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_0125_PREVIEW; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_0613; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_1106_PREVIEW; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_32K; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_32K_0613; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_TURBO_2024_04_09; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_VISION_PREVIEW; +import dev.langchain4j.model.output.Response; +import static java.util.Arrays.asList; +import static java.util.Arrays.stream; +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.opentest4j.AssertionFailedError; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Stream; + +// TODO use exact model for Tokenizer (the one returned by LLM) +@Disabled("this test is very long and expensive, we will need to set a schedule for it to run maybe 1 time per month") +class AzureOpenAiTokenizerIT { + + // my API key does not have access to these models + private static final Set MODELS_WITHOUT_ACCESS = new HashSet<>(asList( + GPT_3_5_TURBO_0613, + GPT_4_32K, + GPT_4_32K_0613 + )); + + private static final Set MODELS_WITHOUT_TOOL_SUPPORT = new HashSet<>(asList( + GPT_4_0613, + GPT_4_VISION_PREVIEW + )); + + private static final Set MODELS_WITH_PARALLEL_TOOL_SUPPORT = new HashSet<>(asList( + // TODO add GPT_3_5_TURBO once it points to GPT_3_5_TURBO_1106 + GPT_3_5_TURBO_1106, + GPT_4_TURBO_2024_04_09, + GPT_4_1106_PREVIEW, + GPT_4_0125_PREVIEW + )); + + @ParameterizedTest + @MethodSource + void should_count_tokens_in_messages(List messages, AzureOpenAiChatModelName modelName) { + + // given + AzureOpenAiChatModel model = AzureOpenAiChatModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(modelName.toString()) + .maxTokens(1) // we don't need outputs, let's not waste tokens + .logRequestsAndResponses(true) + .build(); + + int expectedTokenCount = model.generate(messages).tokenUsage().inputTokenCount(); + + Tokenizer tokenizer = new AzureOpenAiTokenizer("gpt-3.5-turbo"); + + // when + int tokenCount = tokenizer.estimateTokenCountInMessages(messages); + + // then + assertThat(tokenCount).isEqualTo(expectedTokenCount); + } + + static Stream should_count_tokens_in_messages() { + return stream(AzureOpenAiChatModelName.values()) + .filter(model -> !MODELS_WITHOUT_ACCESS.contains(model)) + .flatMap(model -> Stream.of( + arguments(singletonList(systemMessage("Be friendly.")), model), + arguments(singletonList(systemMessage("You are a helpful assistant, 
help the user!")), model), + + arguments(singletonList(userMessage("Hi")), model), + arguments(singletonList(userMessage("Hello, how are you?")), model), + + arguments(singletonList(userMessage("Stan", "Hi")), model), + arguments(singletonList(userMessage("Klaus", "Hi")), model), + arguments(singletonList(userMessage("Giovanni", "Hi")), model), + + arguments(singletonList(aiMessage("Hi")), model), + arguments(singletonList(aiMessage("Hello, how can I help you?")), model), + + arguments(asList( + systemMessage("Be helpful"), + userMessage("hi") + ), model), + + arguments(asList( + systemMessage("Be helpful"), + userMessage("hi"), + aiMessage("Hello, how can I help you?"), + userMessage("tell me a joke") + ), model), + + arguments(asList( + systemMessage("Be helpful"), + userMessage("hi"), + aiMessage("Hello, how can I help you?"), + userMessage("tell me a joke"), + aiMessage("Why don't scientists trust atoms?\n\nBecause they make up everything!"), + userMessage("tell me another one, this one is not funny") + ), model) + )); + } + + @ParameterizedTest + @MethodSource + void should_count_tokens_in_messages_with_single_tool(List messages, AzureOpenAiChatModelName modelName) { + + // given + AzureOpenAiChatModel model = AzureOpenAiChatModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(modelName.toString()) + .maxTokens(1) // we don't need outputs, let's not waste tokens + .logRequestsAndResponses(true) + .build(); + + int expectedTokenCount = model.generate(messages).tokenUsage().inputTokenCount(); + + Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName.modelVersion()); + + // when + int tokenCount = tokenizer.estimateTokenCountInMessages(messages); + + // then + assertThat(tokenCount).isCloseTo(expectedTokenCount, withPercentage(4)); + } + + static Stream should_count_tokens_in_messages_with_single_tool() { + return stream(AzureOpenAiChatModelName.values()) + .filter(model -> !MODELS_WITHOUT_ACCESS.contains(model)) + .filter(model -> !MODELS_WITHOUT_TOOL_SUPPORT.contains(model)) + .flatMap(model -> Stream.of( + + // various tool "name" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") // 1 token + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("b") + .name("current_time") // 2 tokens + .arguments("{}") + .build() + ), + toolExecutionResultMessage("b", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("c") + .name("get_current_time") // 3 tokens + .arguments("{}") + .build() + ), + toolExecutionResultMessage("c", null, "23:59") + ), model), + + // 1 argument, various argument "name" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("b") + .name("current_time") + .arguments("{\"target_city\":\"Berlin\"}") // 2 tokens + .build() + ), + toolExecutionResultMessage("b", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("c") + .name("current_time") + .arguments("{\"target_city_name\":\"Berlin\"}") // 3 tokens + .build() + ), + toolExecutionResultMessage("c", null, "23:59") + ), model), + + // 1 
argument, various argument "value" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("b") + .name("current_time") + .arguments("{\"city\":\"Munich\"}") // 3 tokens + .build() + ), + toolExecutionResultMessage("b", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("c") + .name("current_time") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 8 tokens + .build() + ), + toolExecutionResultMessage("c", null, "23:59") + ), model), + + // 1 argument, various numeric argument "value" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city_id\": 189}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("b") + .name("current_time") + .arguments("{\"city_id\": 189647}") // 2 tokens + .build() + ), + toolExecutionResultMessage("b", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("c") + .name("current_time") + .arguments("{\"city_id\": 189647852}") // 3 tokens + .build() + ), + toolExecutionResultMessage("c", null, "23:59") + ), model), + + // 2 arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\"}") + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + + // 3 arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\",\"format\":\"24\"}") + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + + // various result lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "23") // 1 token + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("b") + .name("current_time") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("b", null, "23:59") // 3 tokens + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("c") + .name("current_time") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("c", null, "23:59:59") // 5 tokens + ), model) + )); + } + + @ParameterizedTest + @MethodSource + void should_count_tokens_in_messages_with_multiple_tools(List messages, + AzureOpenAiChatModelName modelName) { + // given + AzureOpenAiChatModel model = AzureOpenAiChatModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(modelName.toString()) + .maxTokens(1) // we don't need outputs, let's not waste tokens + .logRequestsAndResponses(true) + .build(); + + int expectedTokenCount = model.generate(messages).tokenUsage().inputTokenCount(); + + Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName.modelVersion()); + + // when + int tokenCount = tokenizer.estimateTokenCountInMessages(messages); + + // then + assertThat(tokenCount).isCloseTo(expectedTokenCount, withPercentage(4)); + } + + static Stream 
should_count_tokens_in_messages_with_multiple_tools() { + return stream(AzureOpenAiChatModelName.values()) + .filter(model -> !MODELS_WITHOUT_ACCESS.contains(model)) + .filter(MODELS_WITH_PARALLEL_TOOL_SUPPORT::contains) + .flatMap(model -> Stream.of( + + // various tool "name" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") // 1 token + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("temperature") // 1 token + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") // 2 tokens + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") // 2 tokens + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("get_current_time") // 3 tokens + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("get_current_temperature") // 3 tokens + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + + // 1 argument, various argument "name" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"target_city\":\"Berlin\"}") // 2 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"target_city\":\"Berlin\"}") // 2 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"target_city_name\":\"Berlin\"}") // 3 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"target_city_name\":\"Berlin\"}") // 3 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + + // 1 argument, various argument "value" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Munich\"}") // 3 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Munich\"}") // 3 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + 
), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 8 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 8 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + + // 1 argument, various numeric argument "value" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city_id\": 189}") // 1 token + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city_id\": 189}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city_id\": 189647}") // 2 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city_id\": 189647}") // 2 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city_id\": 189647852}") // 3 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city_id\": 189647852}") // 3 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + + // 2 arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\"}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\"}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + + // 3 arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\",\"format\":\"24\"}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\",\"unit\":\"C\"}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + + // various result lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "23"), // 1 token + toolExecutionResultMessage("b", null, "17") // 1 token + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "23:59"), // 3 tokens + toolExecutionResultMessage("b", null, "17.5") // 3 tokens + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + 
.arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "23:59:59"), // 5 tokens + toolExecutionResultMessage("b", null, "17.5 grad C") // 5 tokens + ), model), + + // 3 tools without arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("temperature") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("c") + .name("weather") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2"), + toolExecutionResultMessage("c", null, "3") + ), model), + + // 3 tools with arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("temperature") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .id("c") + .name("weather") + .arguments("{\"city\":\"Berlin\"}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2"), + toolExecutionResultMessage("c", null, "3") + ), model), + + // 4 tools without arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("temperature") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("c") + .name("weather") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("d") + .name("UV") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2"), + toolExecutionResultMessage("c", null, "3"), + toolExecutionResultMessage("d", null, "4") + ), model), + + // 4 tools with arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("temperature") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .id("c") + .name("weather") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .id("d") + .name("UV") + .arguments("{\"city\":\"Berlin\"}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2"), + toolExecutionResultMessage("c", null, "3"), + toolExecutionResultMessage("d", null, "4") + ), model) + )); + } + + @ParameterizedTest + @MethodSource + void should_count_tokens_in_tool_specifications(List toolSpecifications, + AzureOpenAiChatModelName modelName) { + // given + AzureOpenAiChatModel model = AzureOpenAiChatModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(modelName.toString()) + .maxTokens(2) // we don't need outputs, let's not waste tokens + .logRequestsAndResponses(true) + .build(); + + List dummyMessages = singletonList(userMessage("hi")); + + Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName.modelVersion()); + + int expectedTokenCount = model.generate(dummyMessages, toolSpecifications).tokenUsage().inputTokenCount() + - tokenizer.estimateTokenCountInMessages(dummyMessages); + + // when + int tokenCount = 
tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications); + + // then + assertThat(tokenCount).isCloseTo(expectedTokenCount, withPercentage(2)); + } + + static Stream should_count_tokens_in_tool_specifications() { + return stream(AzureOpenAiChatModelName.values()) + .filter(model -> !MODELS_WITHOUT_ACCESS.contains(model)) + .filter(model -> !MODELS_WITHOUT_TOOL_SUPPORT.contains(model)) + .flatMap(model -> Stream.of( + + // "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("time") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("get_current_time") // 3 tokens + .build()), model), + + // "description" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("time") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("returns current time in 24-hour format") // 8 tokens + .build()), model), + + // 1 parameter with "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city_name") // 3 tokens + .build()), model), + + // 1 parameter with "description" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city", description("city")) // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city", description("target city name")) // 3 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city", description("city for which time should be returned")) // 7 tokens + .build()), model), + + // 1 parameter with varying "type" + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city", STRING) + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city", INTEGER) + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("cities", ARRAY, items(INTEGER)) + .build()), model), + + // 1 parameter with "enum" of various range of values + arguments(singletonList(ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("unit", enums("C")) + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("unit", enums("C", "K")) + .build()), model), + arguments(singletonList(ToolSpecification.builder() + 
.name("current_temperature") + .description("current temperature") + .addParameter("unit", enums("C", "K", "F")) + .build()), model), + + // 1 parameter with "enum" of various name lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("unit", enums("celsius", "kelvin", "fahrenheit")) // 2 tokens each + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("unit", enums("CELSIUS", "KELVIN", "FAHRENHEIT")) // 3-5 tokens + .build()), model), + + // 2 parameters with "name" of various length + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city") // 1 token + .addParameter("country") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city") // 2 tokens + .addParameter("target_country") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city_name") // 3 tokens + .addParameter("target_country_name") // 3 tokens + .build()), model), + + // 3 parameters with "name" of various length + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city") // 1 token + .addParameter("country") // 1 token + .addParameter("format", enums("12H", "24H")) // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city") // 2 tokens + .addParameter("target_country") // 2 tokens + .addParameter("time_format", enums("12H", "24H")) // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city_name") // 3 tokens + .addParameter("target_country_name") // 3 tokens + .addParameter("current_time_format", enums("12H", "24H")) // 3 tokens + .build()), model), + + // 1 optional parameter with "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addOptionalParameter("city") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addOptionalParameter("target_city") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addOptionalParameter("target_city_name") // 3 tokens + .build()), model), + + // 2 optional parameters with "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addOptionalParameter("city") // 1 token + .addOptionalParameter("country") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addOptionalParameter("target_city") // 2 tokens + .addOptionalParameter("target_country") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addOptionalParameter("target_city_name") // 3 tokens + 
.addOptionalParameter("target_country_name") // 3 tokens + .build()), model), + + // 1 mandatory, 1 optional parameters with "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city") // 1 token + .addOptionalParameter("country") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city") // 2 tokens + .addOptionalParameter("target_country") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city_name") // 3 tokens + .addOptionalParameter("target_country_name") // 3 tokens + .build()), model), + + // 2 tools + arguments(asList( + ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city_name", description("city name")) + .addOptionalParameter("country_name", description("optional country name")) + .build(), + ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("city_name", description("city name")) + .addOptionalParameter("country_name", description("optional country name")) + .build() + ), model), + + // 3 tools + arguments(asList( + ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city_name", description("city name")) + .addOptionalParameter("country_name", description("optional country name")) + .build(), + ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("city_name", description("city name")) + .addOptionalParameter("country_name", description("optional country name")) + .build(), + ToolSpecification.builder() + .name("current_weather") + .description("current weather") + .addParameter("city_name", description("city name")) + .addOptionalParameter("country_name", description("optional country name")) + .build() + ), model) + )); + } + + @ParameterizedTest + @MethodSource + void should_count_tokens_in_tool_execution_request(UserMessage userMessage, + ToolSpecification toolSpecification, + ToolExecutionRequest expectedToolExecutionRequest, + AzureOpenAiChatModelName modelName) { + // given + AzureOpenAiChatModel model = AzureOpenAiChatModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(modelName.toString()) + .logRequestsAndResponses(true) + .build(); + + Response response = model.generate(singletonList(userMessage), singletonList(toolSpecification)); + + List toolExecutionRequests = response.content().toolExecutionRequests(); + // we need to ensure that model generated expected tool execution request, + // then we can use output token count as a reference + assertThat(toolExecutionRequests).hasSize(1); + ToolExecutionRequest toolExecutionRequest = toolExecutionRequests.get(0); + assertThat(toolExecutionRequest.name()).isEqualTo(expectedToolExecutionRequest.name()); + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace(expectedToolExecutionRequest.arguments()); + + int expectedTokenCount = response.tokenUsage().outputTokenCount(); + + Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName.modelVersion()); + + // when + int tokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(toolExecutionRequests); + + // then + try { + 
assertThat(tokenCount).isEqualTo(expectedTokenCount); + } catch (AssertionFailedError e) { + if (modelName == GPT_3_5_TURBO_1106) { + // sometimes GPT_3_5_TURBO_1106 calculates tokens wrongly + // see https://community.openai.com/t/inconsistent-token-billing-for-tool-calls-in-gpt-3-5-turbo-1106 + // TODO remove once they fix it + e.printStackTrace(); + // there is some pattern to it, so we are going to check if this is really the case or our calculation is wrong + Tokenizer tokenizer2 = new AzureOpenAiTokenizer(GPT_3_5_TURBO.toString()); + int tokenCount2 = tokenizer2.estimateTokenCountInToolExecutionRequests(toolExecutionRequests); + assertThat(tokenCount2).isEqualTo(expectedTokenCount - 3); + } else { + throw e; + } + } + } + + @ParameterizedTest + @MethodSource("should_count_tokens_in_tool_execution_request") + void should_count_tokens_in_forceful_tool_specification_and_execution_request(UserMessage userMessage, + ToolSpecification toolSpecification, + ToolExecutionRequest expectedToolExecutionRequest, + AzureOpenAiChatModelName modelName) { + // given + AzureOpenAiChatModel model = AzureOpenAiChatModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(modelName.toString()) + .logRequestsAndResponses(true) + .build(); + + Response response = model.generate(singletonList(userMessage), toolSpecification); + + Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName.modelVersion()); + + int expectedTokenCountInSpecification = response.tokenUsage().inputTokenCount() + - tokenizer.estimateTokenCountInMessages(singletonList(userMessage)); + + // when + int tokenCountInSpecification = tokenizer.estimateTokenCountInForcefulToolSpecification(toolSpecification); + + // then + assertThat(tokenCountInSpecification).isEqualTo(expectedTokenCountInSpecification); + + // given + List toolExecutionRequests = response.content().toolExecutionRequests(); + // we need to ensure that model generated expected tool execution request, + // then we can use output token count as a reference + assertThat(toolExecutionRequests).hasSize(1); + ToolExecutionRequest toolExecutionRequest = toolExecutionRequests.get(0); + assertThat(toolExecutionRequest.name()).isEqualTo(expectedToolExecutionRequest.name()); + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace(expectedToolExecutionRequest.arguments()); + + int expectedTokenCountInToolRequest = response.tokenUsage().outputTokenCount(); + + // when + int tokenCountInToolRequest = tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest); + + // then + assertThat(tokenCountInToolRequest).isEqualTo(expectedTokenCountInToolRequest); + } + + static Stream should_count_tokens_in_tool_execution_request() { + return stream(AzureOpenAiChatModelName.values()) + .filter(model -> !MODELS_WITHOUT_ACCESS.contains(model)) + .filter(model -> !MODELS_WITHOUT_TOOL_SUPPORT.contains(model)) + .flatMap(model -> Stream.of( + + // no arguments, different lengths of "name" + arguments( + userMessage("What is the time now?"), + ToolSpecification.builder() + .name("time") // 1 token + .description("returns current time") + .build(), + ToolExecutionRequest.builder() + .name("time") // 1 token + .arguments("{}") + .build(), + model + ), + arguments( + userMessage("What is the time now?"), + ToolSpecification.builder() + .name("current_time") // 2 tokens + .description("returns current time") + .build(), + ToolExecutionRequest.builder() + .name("current_time") // 2 tokens + 
.arguments("{}") + .build(), + model + ), + arguments( + userMessage("What is the time now?"), + ToolSpecification.builder() + .name("get_current_time") // 3 tokens + .description("returns current time") + .build(), + ToolExecutionRequest.builder() + .name("get_current_time") // 3 tokens + .arguments("{}") + .build(), + model + ), + + // one argument, different lengths of "arguments" + arguments( + userMessage("What is the time in Berlin now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 5 tokens + .build(), + model + ), + arguments( + userMessage("What is the time in Munich now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Munich\"}") // 7 tokens + .build(), + model + ), + arguments( + userMessage("What is the time in Pietramontecorvino now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 12 tokens + .build(), + model + ), + + // two arguments, different lengths of "arguments" + arguments( + userMessage("What is the time in Berlin now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Berlin\"}") // 9 tokens + .build(), + model + ), + arguments( + userMessage("What is the time in Munich now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Munich\"}") // 11 tokens + .build(), + model + ), + arguments( + userMessage("What is the time in Pietramontecorvino now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\"}") // 16 tokens + .build(), + model + ), + + // three arguments, different lengths of "arguments" + arguments( + userMessage("What is the time in Berlin now in 24-hour format?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .addParameter("format", enums("12", "24")) + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Berlin\",\"format\":\"24\"}") // 13 tokens + .build(), + model + ), + arguments( + userMessage("What is the time in Munich now in 24-hour format?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .addParameter("format", enums("12", "24")) + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Munich\",\"format\":\"24\"}") // 15 tokens + .build(), + model 
+ ), + arguments( + userMessage("What is the time in Pietramontecorvino now in 24-hour format?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .addParameter("format", enums("12", "24")) + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\",\"format\":\"24\"}") // 20 tokens + .build(), + model + ) + )); + } + + @ParameterizedTest + @MethodSource + void should_count_tokens_in_multiple_tool_execution_requests(UserMessage userMessage, + List toolSpecifications, + List expectedToolExecutionRequests, + AzureOpenAiChatModelName modelName) { + // given + AzureOpenAiChatModel model = AzureOpenAiChatModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(modelName.toString()) + .logRequestsAndResponses(true) + .build(); + + Response response = model.generate(singletonList(userMessage), toolSpecifications); + + List toolExecutionRequests = response.content().toolExecutionRequests(); + // we need to ensure that model generated expected tool execution requests, + // then we can use output token count as a reference + assertThat(toolExecutionRequests).hasSize(expectedToolExecutionRequests.size()); + for (int i = 0; i < toolExecutionRequests.size(); i++) { + ToolExecutionRequest toolExecutionRequest = toolExecutionRequests.get(i); + ToolExecutionRequest expectedToolExecutionRequest = expectedToolExecutionRequests.get(i); + assertThat(toolExecutionRequest.name()).isEqualTo(expectedToolExecutionRequest.name()); + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace(expectedToolExecutionRequest.arguments()); + } + + int expectedTokenCount = response.tokenUsage().outputTokenCount(); + + Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName.modelVersion()); + + // when + int tokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(toolExecutionRequests); + + // then + assertThat(tokenCount).isEqualTo(expectedTokenCount); + } + + static Stream should_count_tokens_in_multiple_tool_execution_requests() { + return stream(AzureOpenAiChatModelName.values()) + .filter(model -> !MODELS_WITHOUT_ACCESS.contains(model)) + .filter(MODELS_WITH_PARALLEL_TOOL_SUPPORT::contains) + .flatMap(model -> Stream.of( + + // no arguments, different lengths of "name" + arguments( + userMessage("What is the time and date now?"), + asList( + ToolSpecification.builder() + .name("time") + .description("returns current time") + .build(), + ToolSpecification.builder() + .name("date") + .description("returns current date") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("time") // 1 token + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .name("date") // 1 token + .arguments("{}") + .build() + ), + model + ), + arguments( + userMessage("What is the time and date now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") // 2 tokens + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .name("current_date") // 2 tokens + .arguments("{}") + .build() + ), + model + ), + arguments( + userMessage("What is the time and date now?"), + asList( + ToolSpecification.builder() + 
.name("get_current_time") + .description("returns current time") + .build(), + ToolSpecification.builder() + .name("get_current_date") + .description("returns current date") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("get_current_time") // 3 tokens + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .name("get_current_date") // 3 tokens + .arguments("{}") + .build() + ), + model + ), + + // no arguments, 3 tools + arguments( + userMessage("What is the time and date and location?"), + asList( + ToolSpecification.builder() + .name("time") + .description("returns current time") + .build(), + ToolSpecification.builder() + .name("date") + .description("returns current date") + .build(), + ToolSpecification.builder() + .name("location") + .description("returns current location") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("time") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .name("date") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .name("location") + .arguments("{}") + .build() + ), + model + ), + + // no arguments, 1 argument + arguments( + userMessage("What is the time in Munich and date now?"), + asList( + ToolSpecification.builder() + .name("time") + .description("returns current time") + .addParameter("city") + .build(), + ToolSpecification.builder() + .name("date") + .description("returns current date") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("time") + .arguments("{\"city\":\"Munich\"}") + .build(), + ToolExecutionRequest.builder() + .name("date") + .arguments("{}") + .build() + ), + model + ), + + // one argument, 2 different tools, different lengths of "arguments" + arguments( + userMessage("What is the time and date in Berlin now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 5 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") + .arguments("{\"city\":\"Berlin\"}") // 5 tokens + .build() + ), + model + ), + arguments( + userMessage("What is the time and date in Pietramontecorvino now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 12 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 12 tokens + .build() + ), + model + ), + + // different tools, different lengths of argument values + arguments( + userMessage("What is the time in Berlin and date in Munich now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 5 tokens + .build(), + 
ToolExecutionRequest.builder() + .name("current_date") + .arguments("{\"city\":\"Munich\"}") // 7 tokens + .build() + ), + model + ), + + // different tools, different lengths of "name", different lengths of argument values + arguments( + userMessage("What is the time in Berlin and date in Munich now?"), + asList( + ToolSpecification.builder() + .name("time") + .description("returns current time") + .addParameter("city") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("time") // 1 tokens + .arguments("{\"city\":\"Berlin\"}") // 5 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") // 2 tokens + .arguments("{\"city\":\"Munich\"}") // 7 tokens + .build() + ), + model + ), + + // one argument, 4 tool requests + arguments( + userMessage("What is the time in Berlin, Munich, London and Paris now?"), + singletonList( + ToolSpecification.builder() + .name("time") + .description("returns current time") + .addParameter("city") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("time") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .name("time") + .arguments("{\"city\":\"Munich\"}") + .build(), + ToolExecutionRequest.builder() + .name("time") + .arguments("{\"city\":\"London\"}") + .build(), + ToolExecutionRequest.builder() + .name("time") + .arguments("{\"city\":\"Paris\"}") + .build() + ), + model + ), + + // two arguments, different lengths of "arguments" + arguments( + userMessage("What is the time and date in Berlin now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .addParameter("country") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Berlin\"}") // 9 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") // 2 tokens + .arguments("{\"country\":\"Germany\",\"city\":\"Berlin\"}") // 9 tokens + .build() + ), + model + ), + arguments( + userMessage("What is the time and date in Pietramontecorvino now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .addParameter("country") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\"}") // 16 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") + .arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\"}") // 16 tokens + .build() + ), + model + ), + + // three arguments, different lengths of "arguments" + arguments( + userMessage("What is the time in Berlin and Pietramontecorvino in 24-hour format?"), + singletonList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .addParameter("format", enums("12", "24")) + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") // 2 tokens + 
.arguments("{\"country\":\"Germany\",\"city\":\"Berlin\",\"format\":\"24\"}") // 13 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_time") // 2 tokens + .arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\",\"format\":\"24\"}") // 20 tokens + .build() + ), + model + ), + + // three tool execution requests, different tools and lengths of "arguments" + arguments( + userMessage("What is the time in Berlin and Pietramontecorvino in 24-hour format now and date in Munich?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .addParameter("format", enums("12", "24")) + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .addParameter("country") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Berlin\",\"format\": \"24\"}") // 14 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\",\"format\": \"24\"}") // 21 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") + .arguments("{\"country\":\"Germany\",\"city\":\"Munich\"}") // 11 tokens + .build() + ), + model + ) + )); + } +} \ No newline at end of file diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiTokenizerTest.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiTokenizerTest.java new file mode 100644 index 0000000000..9eb443645e --- /dev/null +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiTokenizerTest.java @@ -0,0 +1,145 @@ +package dev.langchain4j.model.azure; + +import dev.langchain4j.model.Tokenizer; +import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO; +import static dev.langchain4j.model.azure.AzureOpenAiTokenizer.countArguments; +import static org.assertj.core.api.Assertions.assertThat; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +import java.util.ArrayList; +import java.util.List; + +class AzureOpenAiTokenizerTest { + + AzureOpenAiTokenizer tokenizer = new AzureOpenAiTokenizer(GPT_3_5_TURBO.modelType()); + + @Test + void should_encode_and_decode_text() { + String originalText = "This is a text which will be encoded and decoded back."; + + List tokens = tokenizer.encode(originalText); + String decodedText = tokenizer.decode(tokens); + + assertThat(decodedText).isEqualTo(originalText); + } + + @Test + void should_encode_with_truncation_and_decode_text() { + String originalText = "This is a text which will be encoded with truncation and decoded back."; + + List tokens = tokenizer.encode(originalText, 10); + assertThat(tokens).hasSize(10); + + String decodedText = tokenizer.decode(tokens); + assertThat(decodedText).isEqualTo("This is a text which will be encoded with trunc"); + } + + @Test + void should_count_tokens_in_short_texts() { + assertThat(tokenizer.estimateTokenCountInText("Hello")).isEqualTo(1); + assertThat(tokenizer.estimateTokenCountInText("Hello!")).isEqualTo(2); + assertThat(tokenizer.estimateTokenCountInText("Hello, how are you?")).isEqualTo(6); + + assertThat(tokenizer.estimateTokenCountInText("")).isEqualTo(0); + 
assertThat(tokenizer.estimateTokenCountInText("\n")).isEqualTo(1); + assertThat(tokenizer.estimateTokenCountInText("\n\n")).isEqualTo(1); + assertThat(tokenizer.estimateTokenCountInText("\n \n\n")).isEqualTo(2); + } + + @Test + void should_count_tokens_in_average_text() { + String text1 = "Hello, how are you doing? What do you want to talk about?"; + assertThat(tokenizer.estimateTokenCountInText(text1)).isEqualTo(15); + + String text2 = String.join(" ", repeat("Hello, how are you doing? What do you want to talk about?", 2)); + assertThat(tokenizer.estimateTokenCountInText(text2)).isEqualTo(2 * 15); + + String text3 = String.join(" ", repeat("Hello, how are you doing? What do you want to talk about?", 3)); + assertThat(tokenizer.estimateTokenCountInText(text3)).isEqualTo(3 * 15); + } + + @Test + void should_count_tokens_in_large_text() { + String text1 = String.join(" ", repeat("Hello, how are you doing? What do you want to talk about?", 10)); + assertThat(tokenizer.estimateTokenCountInText(text1)).isEqualTo(10 * 15); + + String text2 = String.join(" ", repeat("Hello, how are you doing? What do you want to talk about?", 50)); + assertThat(tokenizer.estimateTokenCountInText(text2)).isEqualTo(50 * 15); + + String text3 = String.join(" ", repeat("Hello, how are you doing? What do you want to talk about?", 100)); + assertThat(tokenizer.estimateTokenCountInText(text3)).isEqualTo(100 * 15); + } + + @Test + void should_count_arguments() { + assertThat(countArguments(null)).isEqualTo(0); + assertThat(countArguments("")).isEqualTo(0); + assertThat(countArguments(" ")).isEqualTo(0); + assertThat(countArguments("{}")).isEqualTo(0); + assertThat(countArguments("{ }")).isEqualTo(0); + + assertThat(countArguments("{\"one\":1}")).isEqualTo(1); + assertThat(countArguments("{\"one\": 1}")).isEqualTo(1); + assertThat(countArguments("{\"one\" : 1}")).isEqualTo(1); + + assertThat(countArguments("{\"one\":1,\"two\":2}")).isEqualTo(2); + assertThat(countArguments("{\"one\": 1,\"two\": 2}")).isEqualTo(2); + assertThat(countArguments("{\"one\" : 1,\"two\" : 2}")).isEqualTo(2); + + assertThat(countArguments("{\"one\":1,\"two\":2,\"three\":3}")).isEqualTo(3); + assertThat(countArguments("{\"one\": 1,\"two\": 2,\"three\": 3}")).isEqualTo(3); + assertThat(countArguments("{\"one\" : 1,\"two\" : 2,\"three\" : 3}")).isEqualTo(3); + } + + @ParameterizedTest + @EnumSource(AzureOpenAiChatModelName.class) + void should_support_all_chat_models(AzureOpenAiChatModelName modelName) { + + // given + Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName); + + // when + int tokenCount = tokenizer.estimateTokenCountInText("a"); + + // then + assertThat(tokenCount).isEqualTo(1); + } + + @ParameterizedTest + @EnumSource(AzureOpenAiEmbeddingModelName.class) + void should_support_all_embedding_models(AzureOpenAiEmbeddingModelName modelName) { + + // given + Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName); + + // when + int tokenCount = tokenizer.estimateTokenCountInText("a"); + + // then + assertThat(tokenCount).isEqualTo(1); + } + + @ParameterizedTest + @EnumSource(AzureOpenAiLanguageModelName.class) + void should_support_all_language_models(AzureOpenAiLanguageModelName modelName) { + + // given + Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName); + + // when + int tokenCount = tokenizer.estimateTokenCountInText("a"); + + // then + assertThat(tokenCount).isEqualTo(1); + } + + static List repeat(String strings, int n) { + List result = new ArrayList<>(); + for (int i = 0; i < n; i++) { + result.add(strings); + } + 
return result; + } +} \ No newline at end of file diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelperTest.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelperTest.java index 36761c8b0a..e745a361eb 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelperTest.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelperTest.java @@ -1,11 +1,9 @@ package dev.langchain4j.model.azure; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIServiceVersion; -import com.azure.ai.openai.models.ChatRequestMessage; -import com.azure.ai.openai.models.ChatRequestUserMessage; -import com.azure.ai.openai.models.CompletionsFinishReason; -import com.azure.ai.openai.models.FunctionDefinition; +import com.azure.ai.openai.models.*; import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.ChatMessage; @@ -19,6 +17,7 @@ import java.util.List; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; class InternalAzureOpenAiHelperTest { @@ -32,11 +31,26 @@ void setupOpenAIClientShouldReturnClientWithCorrectConfiguration() { Integer maxRetries = 5; boolean logRequestsAndResponses = true; - OpenAIClient client = InternalAzureOpenAiHelper.setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, null, logRequestsAndResponses); + OpenAIClient client = InternalAzureOpenAiHelper.setupSyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, null, logRequestsAndResponses); assertThat(client).isNotNull(); } + @Test + void setupOpenAIAsyncClientShouldReturnClientWithCorrectConfiguration() { + String endpoint = "test-endpoint"; + String serviceVersion = "test-service-version"; + String apiKey = "test-api-key"; + Duration timeout = Duration.ofSeconds(30); + Integer maxRetries = 5; + boolean logRequestsAndResponses = true; + + OpenAIAsyncClient client = InternalAzureOpenAiHelper.setupAsyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, null, logRequestsAndResponses); + + assertThat(client).isNotNull(); + } + + @Test void getOpenAIServiceVersionShouldReturnCorrectVersion() { String serviceVersion = "2023-05-15"; @@ -67,7 +81,7 @@ void toOpenAiMessagesShouldReturnCorrectMessages() { } @Test - void toFunctionsShouldReturnCorrectFunctions() { + void toToolDefinitionsShouldReturnCorrectToolDefinition() { Collection<ToolSpecification> toolSpecifications = new ArrayList<>(); toolSpecifications.add(ToolSpecification.builder() .name("test-tool") @@ -75,10 +89,11 @@ void toFunctionsShouldReturnCorrectFunctions() { .parameters(ToolParameters.builder().build()) .build()); - List<FunctionDefinition> functions = InternalAzureOpenAiHelper.toFunctions(toolSpecifications); + List<ChatCompletionsToolDefinition> tools = InternalAzureOpenAiHelper.toToolDefinitions(toolSpecifications); - assertThat(functions).hasSize(toolSpecifications.size()); - assertThat(functions.get(0).getName()).isEqualTo(toolSpecifications.iterator().next().name()); + assertEquals(toolSpecifications.size(), tools.size()); + assertInstanceOf(ChatCompletionsFunctionToolDefinition.class, tools.get(0)); + assertEquals(toolSpecifications.iterator().next().name(), ((ChatCompletionsFunctionToolDefinition) tools.get(0)).getFunction().getName()); }
@Test diff --git a/langchain4j-azure-open-ai/src/test/script/deploy-azure-openai-models.sh b/langchain4j-azure-open-ai/src/test/script/deploy-azure-openai-models.sh index fe9a0f53b7..479850add3 100755 --- a/langchain4j-azure-open-ai/src/test/script/deploy-azure-openai-models.sh +++ b/langchain4j-azure-open-ai/src/test/script/deploy-azure-openai-models.sh @@ -5,11 +5,11 @@ echo "Setting up environment variables..." echo "----------------------------------" -PROJECT="langchain4j" +PROJECT="langchain4j-eastus" RESOURCE_GROUP="rg-$PROJECT" -LOCATION="swedencentral" -TAG="$PROJECT" +LOCATION="eastus" AI_SERVICE="ai-$PROJECT" +TAG="$PROJECT" echo "Creating the resource group..." echo "------------------------------" @@ -18,6 +18,9 @@ az group create \ --location "$LOCATION" \ --tags system="$TAG" +# If you want to know the available SKUs, run the following Azure CLI command: +# az cognitiveservices account list-skus --location "$LOCATION" -o table + echo "Creating the Cognitive Service..." echo "---------------------------------" az cognitiveservices account create \ @@ -28,56 +31,212 @@ az cognitiveservices account create \ --tags system="$TAG" \ --kind "OpenAI" \ --sku "S0" - + # If you want to know the available models, run the following Azure CLI command: # az cognitiveservices account list-models --resource-group "$RESOURCE_GROUP" --name "$AI_SERVICE" -o table -echo "Deploying a gpt-35-turbo model..." +# Chat Models +echo "Deploying Chat Models" +echo "=====================" + +echo "Deploying a gpt-35-turbo-0301 model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "gpt-35-turbo-0301" \ + --model-name "gpt-35-turbo" \ + --model-version "0125" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a gpt-35-turbo-0613 model..." echo "----------------------" az cognitiveservices account deployment create \ --name "$AI_SERVICE" \ --resource-group "$RESOURCE_GROUP" \ - --deployment-name "gpt-35-turbo" \ + --deployment-name "gpt-35-turbo-0613" \ + --model-name "gpt-35-turbo" \ + --model-version "0613" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a gpt-35-turbo-1106 model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "gpt-35-turbo-1106" \ --model-name "gpt-35-turbo" \ --model-version "1106" \ --model-format "OpenAI" \ - --sku-capacity 120 \ + --sku-capacity 1 \ --sku-name "Standard" -echo "Deploying a gpt-35-turbo-instruct model..." +echo "Deploying a gpt-35-turbo-16k-0613 model..." echo "----------------------" az cognitiveservices account deployment create \ --name "$AI_SERVICE" \ --resource-group "$RESOURCE_GROUP" \ - --deployment-name "gpt-35-turbo-instruct" \ - --model-name "gpt-35-turbo-instruct" \ - --model-version "0914" \ + --deployment-name "gpt-35-turbo-16k-0613" \ + --model-name "gpt-35-turbo-16k" \ + --model-version "0613" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a gpt-4-0613 model..." 
+echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "gpt-4-0613" \ + --model-name "gpt-4" \ + --model-version "0613" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a gpt-4-0125-preview model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "gpt-4-0125-preview" \ + --model-name "gpt-4" \ + --model-version "0125-preview" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a gpt-4-1106-preview model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "gpt-4-1106-preview" \ + --model-name "gpt-4" \ + --model-version "1106-preview" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a gpt-4-turbo-2024-04-09 model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "gpt-4-turbo-2024-04-09" \ + --model-name "gpt-4" \ + --model-version "turbo-2024-04-09" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a gpt-4-32k-0613 model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "gpt-4-32k-0613" \ + --model-name "gpt-4-32k" \ + --model-version "0613" \ --model-format "OpenAI" \ - --sku-capacity 120 \ + --sku-capacity 1 \ --sku-name "Standard" -echo "Deploying a gpt-4 model..." +echo "Deploying a gpt-4-vision-preview model..." echo "----------------------" az cognitiveservices account deployment create \ --name "$AI_SERVICE" \ --resource-group "$RESOURCE_GROUP" \ - --deployment-name "gpt-4" \ + --deployment-name "gpt-4-vision-preview" \ --model-name "gpt-4" \ - --model-version "1106-Preview" \ + --model-version "vision-preview" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a gpt-4o model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "gpt-4o" \ + --model-name "gpt-4o" \ + --model-version "2024-05-13" \ --model-format "OpenAI" \ - --sku-capacity 10 \ + --sku-capacity 1 \ --sku-name "Standard" -echo "Deploying a text-embedding-ada-002 model..." +# Embedding Models +echo "Deploying Embedding Models" +echo "==========================" + +echo "Deploying a text-embedding-ada-002-1 model..." echo "----------------------" az cognitiveservices account deployment create \ --name "$AI_SERVICE" \ --resource-group "$RESOURCE_GROUP" \ - --deployment-name "text-embedding-ada-002" \ + --deployment-name "text-embedding-ada-002-1" \ + --model-name "text-embedding-ada-002" \ + --model-version "1" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a text-embedding-ada-002-2 model..." 
+echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "text-embedding-ada-002-2" \ --model-name "text-embedding-ada-002" \ --model-version "2" \ --model-format "OpenAI" \ - --sku-capacity 120 \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a text-embedding-3-small-1 model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "text-embedding-3-small-1" \ + --model-name "text-embedding-3-small" \ + --model-version "1" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a text-embedding-3-large-1 model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "text-embedding-3-large-1" \ + --model-name "text-embedding-3-large" \ + --model-version "1" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +# Image Models +echo "Deploying Image Models" +echo "======================" + +echo "Deploying a dall-e-3 model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "dall-e-2-20" \ + --model-name "dall-e-2" \ + --model-version "2.0" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ --sku-name "Standard" echo "Deploying a dall-e-3 model..." @@ -85,13 +244,42 @@ echo "----------------------" az cognitiveservices account deployment create \ --name "$AI_SERVICE" \ --resource-group "$RESOURCE_GROUP" \ - --deployment-name "dall-e-3" \ + --deployment-name "dall-e-3-30" \ --model-name "dall-e-3" \ --model-version "3.0" \ --model-format "OpenAI" \ --sku-capacity 1 \ --sku-name "Standard" +# Language Models +echo "Deploying Language Models" +echo "=========================" + +echo "Deploying a gpt-35-turbo-instruct-0914 model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "gpt-35-turbo-instruct-0914" \ + --model-name "gpt-35-turbo-instruct" \ + --model-version "0914" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + +echo "Deploying a davinci-002-1 model..." +echo "----------------------" +az cognitiveservices account deployment create \ + --name "$AI_SERVICE" \ + --resource-group "$RESOURCE_GROUP" \ + --deployment-name "davinci-002-1" \ + --model-name "davinci-002" \ + --model-version "1" \ + --model-format "OpenAI" \ + --sku-capacity 1 \ + --sku-name "Standard" + + echo "Storing the key and endpoint in environment variables..." echo "--------------------------------------------------------" AZURE_OPENAI_KEY=$( @@ -109,3 +297,10 @@ AZURE_OPENAI_ENDPOINT=$( echo "AZURE_OPENAI_KEY=$AZURE_OPENAI_KEY" echo "AZURE_OPENAI_ENDPOINT=$AZURE_OPENAI_ENDPOINT" + +# Once you finish the tests, you can delete the resource group with the following command: +echo "Deleting the resource group..." 
+echo "------------------------------" +az group delete \ + --name "$RESOURCE_GROUP" \ + --yes diff --git a/langchain4j-bedrock/pom.xml b/langchain4j-bedrock/pom.xml index 2d341f2f16..aaacc59085 100644 --- a/langchain4j-bedrock/pom.xml +++ b/langchain4j-bedrock/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -55,6 +55,14 @@ test + + dev.langchain4j + langchain4j-core + tests + test-jar + test + + diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java new file mode 100644 index 0000000000..465925eca3 --- /dev/null +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java @@ -0,0 +1,33 @@ +package dev.langchain4j.model.bedrock; + +import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel; +import lombok.Builder; +import lombok.Getter; +import lombok.experimental.SuperBuilder; + +@Getter +@SuperBuilder +public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel { + @Builder.Default + private final String model = BedrockAnthropicStreamingChatModel.Types.AnthropicClaudeV2.getValue(); + + @Override + protected String getModelId() { + return model; + } + + @Getter + /** + * Bedrock Anthropic model ids + */ + public enum Types { + AnthropicClaudeV2("anthropic.claude-v2"), + AnthropicClaudeV2_1("anthropic.claude-v2:1"); + + private final String value; + + Types(String modelID) { + this.value = modelID; + } + } +} diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java index 71a552fbb7..77fba2d57d 100644 --- a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java @@ -13,6 +13,7 @@ import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; @@ -30,47 +31,14 @@ */ @Getter @SuperBuilder -public abstract class AbstractBedrockChatModel implements ChatLanguageModel { - private static final String HUMAN_PROMPT = "Human:"; - private static final String ASSISTANT_PROMPT = "Assistant:"; - - @Builder.Default - private final String humanPrompt = HUMAN_PROMPT; - @Builder.Default - private final String assistantPrompt = ASSISTANT_PROMPT; - @Builder.Default - private final Integer maxRetries = 5; - @Builder.Default - private final Region region = Region.US_EAST_1; - @Builder.Default - private final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build(); - @Builder.Default - private final int maxTokens = 300; - @Builder.Default - private final float temperature = 1; - @Builder.Default - private final float topP = 0.999f; - @Builder.Default - private final String[] stopSequences = new String[]{}; +public 
abstract class AbstractBedrockChatModel extends AbstractSharedBedrockChatModel implements ChatLanguageModel { @Getter(lazy = true) private final BedrockRuntimeClient client = initClient(); @Override public Response generate(List messages) { - final String context = messages.stream() - .filter(message -> message.type() == ChatMessageType.SYSTEM) - .map(ChatMessage::text) - .collect(joining("\n")); - - final String userMessages = messages.stream() - .filter(message -> message.type() != ChatMessageType.SYSTEM) - .map(this::chatMessageToString) - .collect(joining("\n")); - - final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT); - final Map requestParameters = getRequestParameters(prompt); - final String body = Json.toJson(requestParameters); + final String body = convertMessagesToAwsBody(messages); InvokeModelResponse invokeModelResponse = withRetry(() -> invoke(body), maxRetries); final String response = invokeModelResponse.body().asUtf8String(); @@ -81,26 +49,6 @@ public Response generate(List messages) { result.getFinishReason()); } - /** - * Convert chat message to string - * - * @param message chat message - * @return string - */ - protected String chatMessageToString(ChatMessage message) { - switch (message.type()) { - case SYSTEM: - return message.text(); - case USER: - return humanPrompt + " " + message.text(); - case AI: - return assistantPrompt + " " + message.text(); - case TOOL_EXECUTION_RESULT: - throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); - } - - throw new IllegalArgumentException("Unknown message type: " + message.type()); - } /** * Get request parameters @@ -110,13 +58,6 @@ protected String chatMessageToString(ChatMessage message) { */ protected abstract Map getRequestParameters(final String prompt); - /** - * Get model id - * - * @return model id - */ - protected abstract String getModelId(); - /** * Get response class type diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java new file mode 100644 index 0000000000..f71f8412b1 --- /dev/null +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java @@ -0,0 +1,87 @@ +package dev.langchain4j.model.bedrock.internal; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.internal.Json; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.Response; +import lombok.Getter; +import lombok.experimental.SuperBuilder; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Bedrock Streaming chat model + */ +@Getter +@SuperBuilder +public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements 
StreamingChatLanguageModel { + @Getter + private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient(); + + class StreamingResponse { + public String completion; + } + + @Override + public void generate(String userMessage, StreamingResponseHandler handler) { + List messages = new ArrayList<>(); + messages.add(new UserMessage(userMessage)); + generate(messages, handler); + } + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder() + .body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages))) + .modelId(getModelId()) + .contentType("application/json") + .accept("application/json") + .build(); + + StringBuffer finalCompletion = new StringBuffer(); + + InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder() + .onChunk(chunk -> { + StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class); + finalCompletion.append(sr.completion); + handler.onNext(sr.completion); + }) + .build(); + + InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder() + .onEventStream(stream -> stream.subscribe(event -> event.accept(visitor))) + .onComplete(() -> { + handler.onComplete(Response.from(new AiMessage(finalCompletion.toString()))); + }) + .onError(handler::onError) + .build(); + asyncClient.invokeModelWithResponseStream(request, h).join(); + + } + + /** + * Initialize async bedrock client + * + * @return async bedrock client + */ + private BedrockRuntimeAsyncClient initAsyncClient() { + BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder() + .region(region) + .credentialsProvider(credentialsProvider) + .build(); + return client; + } + + + +} diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java new file mode 100644 index 0000000000..681f4781f1 --- /dev/null +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java @@ -0,0 +1,112 @@ +package dev.langchain4j.model.bedrock.internal; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import dev.langchain4j.internal.Json; +import lombok.Builder; +import lombok.Getter; +import lombok.experimental.SuperBuilder; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.stream.Collectors.joining; + +@Getter +@SuperBuilder +public abstract class AbstractSharedBedrockChatModel { + // Claude requires you to enclose the prompt as follows: + // String enclosedPrompt = "Human: " + prompt + "\n\nAssistant:"; + protected static final String HUMAN_PROMPT = "Human:"; + protected static final String ASSISTANT_PROMPT = "Assistant:"; + protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31"; + + @Builder.Default + protected final String humanPrompt = HUMAN_PROMPT; + @Builder.Default + protected final String assistantPrompt = ASSISTANT_PROMPT; + 
@Builder.Default + protected final Integer maxRetries = 5; + @Builder.Default + protected final Region region = Region.US_EAST_1; + @Builder.Default + protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build(); + @Builder.Default + protected final int maxTokens = 300; + @Builder.Default + protected final double temperature = 1; + @Builder.Default + protected final float topP = 0.999f; + @Builder.Default + protected final String[] stopSequences = new String[]{}; + @Builder.Default + protected final int topK = 250; + @Builder.Default + protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION; + + + /** + * Convert chat message to string + * + * @param message chat message + * @return string + */ + protected String chatMessageToString(ChatMessage message) { + switch (message.type()) { + case SYSTEM: + return message.text(); + case USER: + return humanPrompt + " " + message.text(); + case AI: + return assistantPrompt + " " + message.text(); + case TOOL_EXECUTION_RESULT: + throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); + } + + throw new IllegalArgumentException("Unknown message type: " + message.type()); + } + + protected String convertMessagesToAwsBody(List messages) { + final String context = messages.stream() + .filter(message -> message.type() == ChatMessageType.SYSTEM) + .map(ChatMessage::text) + .collect(joining("\n")); + + final String userMessages = messages.stream() + .filter(message -> message.type() != ChatMessageType.SYSTEM) + .map(this::chatMessageToString) + .collect(joining("\n")); + + final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT); + final Map requestParameters = getRequestParameters(prompt); + final String body = Json.toJson(requestParameters); + return body; + } + + protected Map getRequestParameters(String prompt) { + final Map parameters = new HashMap<>(7); + + parameters.put("prompt", prompt); + parameters.put("max_tokens_to_sample", getMaxTokens()); + parameters.put("temperature", getTemperature()); + parameters.put("top_k", topK); + parameters.put("top_p", getTopP()); + parameters.put("stop_sequences", getStopSequences()); + parameters.put("anthropic_version", anthropicVersion); + + return parameters; + } + + /** + * Get model id + * + * @return model id + */ + protected abstract String getModelId(); + +} diff --git a/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockChatModelIT.java b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockChatModelIT.java index c369cab650..6ecf298bb7 100644 --- a/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockChatModelIT.java +++ b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockChatModelIT.java @@ -1,10 +1,5 @@ package dev.langchain4j.model.bedrock; -import static dev.langchain4j.internal.Utils.readBytes; -import static dev.langchain4j.model.bedrock.BedrockMistralAiChatModel.Types.Mistral7bInstructV0_2; -import static dev.langchain4j.model.bedrock.BedrockMistralAiChatModel.Types.MistralMixtral8x7bInstructV0_1; -import static org.assertj.core.api.Assertions.assertThat; - import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.ImageContent; @@ -12,119 +7,121 @@ import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; +import 
org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import software.amazon.awssdk.regions.Region; + import java.util.Arrays; import java.util.Base64; import java.util.List; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.regions.Region; -public class BedrockChatModelIT { - +import static dev.langchain4j.internal.Utils.readBytes; +import static dev.langchain4j.model.bedrock.BedrockMistralAiChatModel.Types.Mistral7bInstructV0_2; +import static dev.langchain4j.model.bedrock.BedrockMistralAiChatModel.Types.MistralMixtral8x7bInstructV0_1; +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".+") +class BedrockChatModelIT { + private static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png"; - + @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockAnthropicV3SonnetChatModel() { - + BedrockAnthropicMessageChatModel bedrockChatModel = BedrockAnthropicMessageChatModel - .builder() - .temperature(0.50f) - .maxTokens(300) - .region(Region.US_EAST_1) - .model(BedrockAnthropicMessageChatModel.Types.AnthropicClaude3SonnetV1.getValue()) - .maxRetries(1) - .build(); - + .builder() + .temperature(0.50f) + .maxTokens(300) + .region(Region.US_EAST_1) + .model(BedrockAnthropicMessageChatModel.Types.AnthropicClaude3SonnetV1.getValue()) + .maxRetries(1) + .build(); + assertThat(bedrockChatModel).isNotNull(); - + Response response = bedrockChatModel.generate(UserMessage.from("hi, how are you doing?")); - + assertThat(response).isNotNull(); assertThat(response.content().text()).isNotBlank(); assertThat(response.tokenUsage()).isNotNull(); assertThat(response.finishReason()).isIn(FinishReason.STOP, FinishReason.LENGTH); } - + @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockAnthropicV3SonnetChatModelImageContent() { - + BedrockAnthropicMessageChatModel bedrockChatModel = BedrockAnthropicMessageChatModel - .builder() - .temperature(0.50f) - .maxTokens(300) - .region(Region.US_EAST_1) - .model(BedrockAnthropicMessageChatModel.Types.AnthropicClaude3SonnetV1.getValue()) - .maxRetries(1) - .build(); - + .builder() + .temperature(0.50f) + .maxTokens(300) + .region(Region.US_EAST_1) + .model(BedrockAnthropicMessageChatModel.Types.AnthropicClaude3SonnetV1.getValue()) + .maxRetries(1) + .build(); + assertThat(bedrockChatModel).isNotNull(); - + String base64Data = Base64.getEncoder().encodeToString(readBytes(CAT_IMAGE_URL)); ImageContent imageContent = ImageContent.from(base64Data, "image/png"); UserMessage userMessage = UserMessage.from(imageContent); - + Response response = bedrockChatModel.generate(userMessage); - + assertThat(response).isNotNull(); assertThat(response.content().text()).isNotBlank(); assertThat(response.tokenUsage()).isNotNull(); assertThat(response.finishReason()).isIn(FinishReason.STOP, FinishReason.LENGTH); } - + @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockAnthropicV3HaikuChatModel() { - + BedrockAnthropicMessageChatModel bedrockChatModel = BedrockAnthropicMessageChatModel - .builder() - .temperature(0.50f) - .maxTokens(300) - .region(Region.US_EAST_1) - 
.model(BedrockAnthropicMessageChatModel.Types.AnthropicClaude3HaikuV1.getValue()) - .maxRetries(1) - .build(); - + .builder() + .temperature(0.50f) + .maxTokens(300) + .region(Region.US_EAST_1) + .model(BedrockAnthropicMessageChatModel.Types.AnthropicClaude3HaikuV1.getValue()) + .maxRetries(1) + .build(); + assertThat(bedrockChatModel).isNotNull(); - + Response response = bedrockChatModel.generate(UserMessage.from("hi, how are you doing?")); - + assertThat(response).isNotNull(); assertThat(response.content().text()).isNotBlank(); assertThat(response.tokenUsage()).isNotNull(); assertThat(response.finishReason()).isIn(FinishReason.STOP, FinishReason.LENGTH); } - + @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockAnthropicV3HaikuChatModelImageContent() { - + BedrockAnthropicMessageChatModel bedrockChatModel = BedrockAnthropicMessageChatModel - .builder() - .temperature(0.50f) - .maxTokens(300) - .region(Region.US_EAST_1) - .model(BedrockAnthropicMessageChatModel.Types.AnthropicClaude3HaikuV1.getValue()) - .maxRetries(1) - .build(); - + .builder() + .temperature(0.50f) + .maxTokens(300) + .region(Region.US_EAST_1) + .model(BedrockAnthropicMessageChatModel.Types.AnthropicClaude3HaikuV1.getValue()) + .maxRetries(1) + .build(); + assertThat(bedrockChatModel).isNotNull(); - + String base64Data = Base64.getEncoder().encodeToString(readBytes(CAT_IMAGE_URL)); ImageContent imageContent = ImageContent.from(base64Data, "image/png"); UserMessage userMessage = UserMessage.from(imageContent); - + Response response = bedrockChatModel.generate(userMessage); - + assertThat(response).isNotNull(); assertThat(response.content().text()).isNotBlank(); assertThat(response.tokenUsage()).isNotNull(); assertThat(response.finishReason()).isIn(FinishReason.STOP, FinishReason.LENGTH); } - + @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockAnthropicV2ChatModelEnumModelType() { BedrockAnthropicCompletionChatModel bedrockChatModel = BedrockAnthropicCompletionChatModel @@ -145,24 +142,23 @@ void testBedrockAnthropicV2ChatModelEnumModelType() { assertThat(response.tokenUsage()).isNull(); assertThat(response.finishReason()).isIn(FinishReason.STOP, FinishReason.LENGTH); } - + @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockAnthropicV2ChatModelStringModelType() { - + BedrockAnthropicCompletionChatModel bedrockChatModel = BedrockAnthropicCompletionChatModel - .builder() - .temperature(0.50f) - .maxTokens(300) - .region(Region.US_EAST_1) - .model("anthropic.claude-v2") - .maxRetries(1) - .build(); - + .builder() + .temperature(0.50f) + .maxTokens(300) + .region(Region.US_EAST_1) + .model("anthropic.claude-v2") + .maxRetries(1) + .build(); + assertThat(bedrockChatModel).isNotNull(); - + Response response = bedrockChatModel.generate(UserMessage.from("hi, how are you doing?")); - + assertThat(response).isNotNull(); assertThat(response.content().text()).isNotBlank(); assertThat(response.tokenUsage()).isNull(); @@ -170,7 +166,6 @@ void testBedrockAnthropicV2ChatModelStringModelType() { } @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockTitanChatModel() { BedrockTitanChatModel bedrockChatModel = BedrockTitanChatModel @@ -198,7 +193,6 @@ void testBedrockTitanChatModel() { } @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void 
testBedrockCohereChatModel() { BedrockCohereChatModel bedrockChatModel = BedrockCohereChatModel @@ -220,7 +214,6 @@ void testBedrockCohereChatModel() { } @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockStabilityChatModel() { BedrockStabilityAIChatModel bedrockChatModel = BedrockStabilityAIChatModel @@ -241,97 +234,93 @@ void testBedrockStabilityChatModel() { assertThat(response.tokenUsage()).isNull(); assertThat(response.finishReason()).isIn(FinishReason.STOP, FinishReason.LENGTH); } - + @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockLlama13BChatModel() { - + BedrockLlamaChatModel bedrockChatModel = BedrockLlamaChatModel - .builder() - .temperature(0.50f) - .maxTokens(300) - .region(Region.US_EAST_1) - .model(BedrockLlamaChatModel.Types.MetaLlama2Chat13B.getValue()) - .maxRetries(1) - .build(); - + .builder() + .temperature(0.50f) + .maxTokens(300) + .region(Region.US_EAST_1) + .model(BedrockLlamaChatModel.Types.MetaLlama2Chat13B.getValue()) + .maxRetries(1) + .build(); + assertThat(bedrockChatModel).isNotNull(); - + Response response = bedrockChatModel.generate(UserMessage.from("hi, how are you doing?")); - + assertThat(response).isNotNull(); assertThat(response.content().text()).isNotBlank(); assertThat(response.tokenUsage()).isNotNull(); assertThat(response.finishReason()).isIn(FinishReason.STOP, FinishReason.LENGTH); } - + @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockLlama70BChatModel() { - + BedrockLlamaChatModel bedrockChatModel = BedrockLlamaChatModel - .builder() - .temperature(0.50f) - .maxTokens(300) - .region(Region.US_EAST_1) - .model(BedrockLlamaChatModel.Types.MetaLlama2Chat70B.getValue()) - .maxRetries(1) - .build(); - + .builder() + .temperature(0.50f) + .maxTokens(300) + .region(Region.US_EAST_1) + .model(BedrockLlamaChatModel.Types.MetaLlama2Chat70B.getValue()) + .maxRetries(1) + .build(); + assertThat(bedrockChatModel).isNotNull(); - + Response response = bedrockChatModel.generate(UserMessage.from("hi, how are you doing?")); - + assertThat(response).isNotNull(); assertThat(response.content().text()).isNotBlank(); assertThat(response.tokenUsage()).isNotNull(); assertThat(response.finishReason()).isIn(FinishReason.STOP, FinishReason.LENGTH); } - + @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockMistralAi7bInstructChatModel() { - + BedrockMistralAiChatModel bedrockChatModel = BedrockMistralAiChatModel - .builder() - .temperature(0.50f) - .maxTokens(300) - .region(Region.US_EAST_1) - .model(Mistral7bInstructV0_2.getValue()) - .maxRetries(1) - .build(); - + .builder() + .temperature(0.50f) + .maxTokens(300) + .region(Region.US_EAST_1) + .model(Mistral7bInstructV0_2.getValue()) + .maxRetries(1) + .build(); + assertThat(bedrockChatModel).isNotNull(); - + List messages = Arrays.asList( - UserMessage.from("hi, how are you doing"), - AiMessage.from("I am an AI model so I don't have feelings"), - UserMessage.from("Ok no worries, tell me story about a man who wears a tin hat.")); - + UserMessage.from("hi, how are you doing"), + AiMessage.from("I am an AI model so I don't have feelings"), + UserMessage.from("Ok no worries, tell me story about a man who wears a tin hat.")); + Response response = bedrockChatModel.generate(messages); - + assertThat(response).isNotNull(); assertThat(response.content().text()).isNotBlank(); 
assertThat(response.finishReason()).isIn(FinishReason.STOP, FinishReason.LENGTH); } - + @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockMistralAiMixtral8x7bInstructChatModel() { - + BedrockMistralAiChatModel bedrockChatModel = BedrockMistralAiChatModel - .builder() - .temperature(0.50f) - .maxTokens(300) - .region(Region.US_EAST_1) - .model(MistralMixtral8x7bInstructV0_1.getValue()) - .maxRetries(1) - .build(); - + .builder() + .temperature(0.50f) + .maxTokens(300) + .region(Region.US_EAST_1) + .model(MistralMixtral8x7bInstructV0_1.getValue()) + .maxRetries(1) + .build(); + assertThat(bedrockChatModel).isNotNull(); - + Response response = bedrockChatModel.generate(UserMessage.from("hi, how are you doing?")); - + assertThat(response).isNotNull(); assertThat(response.content().text()).isNotBlank(); assertThat(response.finishReason()).isIn(FinishReason.STOP, FinishReason.LENGTH); diff --git a/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockEmbeddingIT.java b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockEmbeddingIT.java index be6bbe6ee2..edcf657fa4 100644 --- a/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockEmbeddingIT.java +++ b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockEmbeddingIT.java @@ -4,8 +4,8 @@ import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import software.amazon.awssdk.regions.Region; import java.util.Collections; @@ -13,10 +13,10 @@ import static org.assertj.core.api.Assertions.assertThat; -public class BedrockEmbeddingIT { +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".+") +class BedrockEmbeddingIT { @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockTitanChatModel() { BedrockTitanEmbeddingModel embeddingModel = BedrockTitanEmbeddingModel @@ -46,5 +46,4 @@ void testBedrockTitanChatModel() { assertThat(response.finishReason()).isNull(); } - } diff --git a/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java new file mode 100644 index 0000000000..1ec1ff4e4b --- /dev/null +++ b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java @@ -0,0 +1,38 @@ +package dev.langchain4j.model.bedrock; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.TestStreamingResponseHandler; +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import software.amazon.awssdk.regions.Region; + +import static dev.langchain4j.data.message.UserMessage.userMessage; +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".+") +class BedrockStreamingChatModelIT { + + @Test + void testBedrockAnthropicStreamingChatModel() { + //given + BedrockAnthropicStreamingChatModel bedrockChatModel = BedrockAnthropicStreamingChatModel + .builder() + 
.temperature(0.5) + .maxTokens(300) + .region(Region.US_EAST_1) + .maxRetries(1) + .build(); + UserMessage userMessage = userMessage("What's the capital of Poland?"); + + //when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + bedrockChatModel.generate(singletonList(userMessage), handler); + Response response = handler.get(); + + //then + assertThat(response.content().text()).contains("Warsaw"); + } +} diff --git a/langchain4j-bom/pom.xml b/langchain4j-bom/pom.xml index 02655ec06d..96e0cf2115 100644 --- a/langchain4j-bom/pom.xml +++ b/langchain4j-bom/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -41,7 +41,7 @@ dev.langchain4j - langchain4j-open-ai + langchain4j-anthropic ${project.version} @@ -53,43 +53,55 @@ dev.langchain4j - langchain4j-hugging-face + langchain4j-bedrock ${project.version} dev.langchain4j - langchain4j-local-ai + langchain4j-chatglm ${project.version} dev.langchain4j - langchain4j-vertex-ai + langchain4j-cohere ${project.version} dev.langchain4j - langchain4j-vertex-ai-gemini + langchain4j-dashscope ${project.version} dev.langchain4j - langchain4j-dashscope + langchain4j-hugging-face ${project.version} dev.langchain4j - langchain4j-qianfan + langchain4j-jina ${project.version} dev.langchain4j - langchain4j-zhipu-ai + langchain4j-local-ai + ${project.version} + + + + dev.langchain4j + langchain4j-mistral-ai + ${project.version} + + + + dev.langchain4j + langchain4j-nomic ${project.version} @@ -101,42 +113,54 @@ dev.langchain4j - langchain4j-bedrock + langchain4j-open-ai ${project.version} dev.langchain4j - langchain4j-chatglm + langchain4j-qianfan ${project.version} dev.langchain4j - langchain4j-mistral-ai + langchain4j-vertex-ai ${project.version} dev.langchain4j - langchain4j-cohere + langchain4j-vertex-ai-gemini ${project.version} dev.langchain4j - langchain4j-nomic + langchain4j-workers-ai ${project.version} dev.langchain4j - langchain4j-anthropic + langchain4j-zhipu-ai ${project.version} + + dev.langchain4j + langchain4j-azure-ai-search + ${project.version} + + + + dev.langchain4j + langchain4j-azure-cosmos-mongo-vcore + ${project.version} + + dev.langchain4j langchain4j-cassandra @@ -169,61 +193,67 @@ dev.langchain4j - langchain4j-opensearch + langchain4j-mongodb-atlas ${project.version} dev.langchain4j - langchain4j-pgvector + langchain4j-neo4j ${project.version} dev.langchain4j - langchain4j-pinecone + langchain4j-opensearch ${project.version} dev.langchain4j - langchain4j-qdrant + langchain4j-pgvector ${project.version} dev.langchain4j - langchain4j-redis + langchain4j-pinecone ${project.version} dev.langchain4j - langchain4j-vespa + langchain4j-qdrant ${project.version} dev.langchain4j - langchain4j-weaviate + langchain4j-redis ${project.version} dev.langchain4j - langchain4j-neo4j + langchain4j-vearch ${project.version} dev.langchain4j - langchain4j-vearch + langchain4j-vespa ${project.version} dev.langchain4j - langchain4j-mongodb-atlas + langchain4j-weaviate + ${project.version} + + + + dev.langchain4j + langchain4j-azure-cosmos-nosql ${project.version} @@ -255,7 +285,7 @@ dev.langchain4j - langchain4j-embeddings-bge-small-v15-en + langchain4j-embeddings-bge-small-en-v15 ${project.version} @@ -277,6 +307,18 @@ ${project.version} + + dev.langchain4j + langchain4j-embeddings-bge-small-zh-v15 + ${project.version} + + + + dev.langchain4j + langchain4j-embeddings-bge-small-zh-v15-q + ${project.version} + + dev.langchain4j langchain4j-embeddings-e5-small-v2 @@ 
-297,17 +339,23 @@ ${project.version} + + dev.langchain4j + langchain4j-code-execution-engine-judge0 + ${project.version} + + dev.langchain4j - langchain4j-document-loader-tencent-cos + langchain4j-document-loader-amazon-s3 ${project.version} dev.langchain4j - langchain4j-document-loader-amazon-s3 + langchain4j-document-loader-azure-storage-blob ${project.version} @@ -319,7 +367,13 @@ dev.langchain4j - langchain4j-document-loader-azure-storage-blob + langchain4j-document-loader-selenium + ${project.version} + + + + dev.langchain4j + langchain4j-document-loader-tencent-cos ${project.version} @@ -351,6 +405,57 @@ ${project.version} + + + dev.langchain4j + langchain4j-web-search-engine-google-custom + ${project.version} + + + + + dev.langchain4j + langchain4j-experimental-sql + ${project.version} + + + + + dev.langchain4j + langchain4j-spring-boot-starter + ${project.version} + + + + dev.langchain4j + langchain4j-anthropic-spring-boot-starter + ${project.version} + + + + dev.langchain4j + langchain4j-azure-ai-search-spring-boot-starter + ${project.version} + + + + dev.langchain4j + langchain4j-azure-open-ai-spring-boot-starter + ${project.version} + + + + dev.langchain4j + langchain4j-ollama-spring-boot-starter + ${project.version} + + + + dev.langchain4j + langchain4j-open-ai-spring-boot-starter + ${project.version} + + diff --git a/langchain4j-cassandra/pom.xml b/langchain4j-cassandra/pom.xml index b486400d67..dd9eadf472 100644 --- a/langchain4j-cassandra/pom.xml +++ b/langchain4j-cassandra/pom.xml @@ -10,14 +10,16 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml - 4.18.0 + 1.2.4 + 4.17.0 1.4.14 3.25.3 + 2.16.1 11 @@ -92,7 +94,6 @@ ch.qos.logback logback-classic - ${logback.version} test diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreIT.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreIT.java index 202448d5bc..6c684f20ad 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreIT.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreIT.java @@ -89,8 +89,8 @@ void should_retrieve_inserted_vector_by_ann_and_metadata() { String sourceSentence = "In GOD we trust, everything else we test!"; Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content(); TextSegment sourceTextSegment = TextSegment.from(sourceSentence, new Metadata() - .add("user", "GOD") - .add("test", "false")); + .put("user", "GOD") + .put("test", "false")); String id = embeddingStore().add(sourceEmbedding, sourceTextSegment); assertThat(id != null && !id.isEmpty()).isTrue(); diff --git a/langchain4j-chatglm/pom.xml b/langchain4j-chatglm/pom.xml index 30ca7d57ce..9fafcd362b 100644 --- a/langchain4j-chatglm/pom.xml +++ b/langchain4j-chatglm/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-chroma/pom.xml b/langchain4j-chroma/pom.xml index fcf1aaefc6..1dc391697c 100644 --- a/langchain4j-chroma/pom.xml +++ b/langchain4j-chroma/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-cohere/pom.xml b/langchain4j-cohere/pom.xml index b3c9a662a7..f6ad6b5008 100644 --- a/langchain4j-cohere/pom.xml +++ b/langchain4j-cohere/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j 
langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/BilledUnits.java b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/BilledUnits.java index 68e7ac23fc..eaf9327f65 100644 --- a/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/BilledUnits.java +++ b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/BilledUnits.java @@ -5,5 +5,7 @@ @Getter class BilledUnits { + private Integer inputTokens; + private Integer outputTokens; private Integer searchUnits; } diff --git a/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereApi.java b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereApi.java index 77ebf5a6a4..fcd7512101 100644 --- a/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereApi.java +++ b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereApi.java @@ -8,7 +8,11 @@ interface CohereApi { + @POST("embed") + @Headers({"accept: application/json", "content-type: application/json"}) + Call<EmbedResponse> embed(@Body EmbedRequest request, @Header("Authorization") String authorizationHeader); + @POST("rerank") - @Headers({"Content-Type: application/json"}) + @Headers({"accept: application/json", "content-type: application/json"}) Call<RerankResponse> rerank(@Body RerankRequest request, @Header("Authorization") String authorizationHeader); } diff --git a/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereClient.java b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereClient.java index ed608672df..f4abcfb67f 100644 --- a/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereClient.java +++ b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereClient.java @@ -49,7 +49,22 @@ class CohereClient { this.authorizationHeader = "Bearer " + ensureNotBlank(apiKey, "apiKey"); } - public RerankResponse rerank(RerankRequest request) { + EmbedResponse embed(EmbedRequest request) { + try { + retrofit2.Response<EmbedResponse> retrofitResponse + = cohereApi.embed(request, authorizationHeader).execute(); + + if (retrofitResponse.isSuccessful()) { + return retrofitResponse.body(); + } else { + throw toException(retrofitResponse); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + RerankResponse rerank(RerankRequest request) { try { retrofit2.Response<RerankResponse> retrofitResponse = cohereApi.rerank(request, authorizationHeader).execute(); diff --git a/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereEmbeddingModel.java b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereEmbeddingModel.java new file mode 100644 index 0000000000..52431ab342 --- /dev/null +++ b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereEmbeddingModel.java @@ -0,0 +1,86 @@ +package dev.langchain4j.model.cohere; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import lombok.Builder; + +import java.time.Duration; +import java.util.List; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static java.time.Duration.ofSeconds; +import static java.util.Arrays.stream; +import static java.util.stream.Collectors.toList; + +/** + * An implementation of an {@link EmbeddingModel} 
that uses + * Cohere Embed API. + */ +public class CohereEmbeddingModel implements EmbeddingModel { + + private static final String DEFAULT_BASE_URL = "https://api.cohere.ai/v1/"; + + private final CohereClient client; + private final String modelName; + private final String inputType; + + @Builder + public CohereEmbeddingModel(String baseUrl, + String apiKey, + String modelName, + String inputType, + Duration timeout, + Boolean logRequests, + Boolean logResponses) { + this.client = CohereClient.builder() + .baseUrl(getOrDefault(baseUrl, DEFAULT_BASE_URL)) + .apiKey(ensureNotBlank(apiKey, "apiKey")) + .timeout(getOrDefault(timeout, ofSeconds(60))) + .logRequests(getOrDefault(logRequests, false)) + .logResponses(getOrDefault(logResponses, false)) + .build(); + this.modelName = modelName; + this.inputType = inputType; + } + + public static CohereEmbeddingModel withApiKey(String apiKey) { + return CohereEmbeddingModel.builder().apiKey(apiKey).build(); + } + + @Override + public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) { + + List<String> texts = textSegments.stream() + .map(TextSegment::text) + .collect(toList()); + + EmbedRequest request = EmbedRequest.builder() + .texts(texts) + .inputType(inputType) + .model(modelName) + .build(); + + EmbedResponse response = this.client.embed(request); + + return Response.from(getEmbeddings(response), getTokenUsage(response)); + } + + private static List<Embedding> getEmbeddings(EmbedResponse response) { + return stream(response.getEmbeddings()) + .map(Embedding::from) + .collect(toList()); + } + + private static TokenUsage getTokenUsage(EmbedResponse response) { + if (response.getMeta() != null + && response.getMeta().getBilledUnits() != null + && response.getMeta().getBilledUnits().getInputTokens() != null) { + return new TokenUsage(response.getMeta().getBilledUnits().getInputTokens(), 0); + } + return null; + } +} \ No newline at end of file diff --git a/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/EmbedRequest.java b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/EmbedRequest.java new file mode 100644 index 0000000000..7805ef112d --- /dev/null +++ b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/EmbedRequest.java @@ -0,0 +1,13 @@ +package dev.langchain4j.model.cohere; + +import lombok.Builder; + +import java.util.List; + +@Builder +class EmbedRequest { + + private List<String> texts; + private String model; + private String inputType; +} \ No newline at end of file diff --git a/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/EmbedResponse.java b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/EmbedResponse.java new file mode 100644 index 0000000000..1a63ef549e --- /dev/null +++ b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/EmbedResponse.java @@ -0,0 +1,12 @@ +package dev.langchain4j.model.cohere; + +import lombok.Getter; + +@Getter +class EmbedResponse { + + private String id; + private String[] texts; + private float[][] embeddings; + private Meta meta; +} \ No newline at end of file diff --git a/langchain4j-cohere/src/test/java/dev/langchain4j/model/cohere/CohereEmbeddingModelIT.java b/langchain4j-cohere/src/test/java/dev/langchain4j/model/cohere/CohereEmbeddingModelIT.java new file mode 100644 index 0000000000..e0f994f8a2 --- /dev/null +++ b/langchain4j-cohere/src/test/java/dev/langchain4j/model/cohere/CohereEmbeddingModelIT.java @@ -0,0 +1,74 @@ +package dev.langchain4j.model.cohere; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.store.embedding.CosineSimilarity; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.List; + +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; + +class CohereEmbeddingModelIT { + + @Test + void should_embed_single_text() { + + // given + EmbeddingModel model = CohereEmbeddingModel.withApiKey(System.getenv("COHERE_API_KEY")); + + // when + Response<Embedding> response = model.embed("Hello World"); + + // then + assertThat(response.content().dimension()).isEqualTo(4096); + + assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(2); + assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(0); + assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(2); + + assertThat(response.finishReason()).isNull(); + } + + @Test + public void should_embed_multiple_segments() { + + // given + EmbeddingModel model = CohereEmbeddingModel.builder() + .baseUrl("https://api.cohere.ai/v1/") + .apiKey(System.getenv("COHERE_API_KEY")) + .modelName("embed-english-light-v3.0") + .inputType("search_document") + .timeout(Duration.ofSeconds(60)) + .logRequests(true) + .logResponses(true) + .build(); + + TextSegment segment1 = TextSegment.from("hello"); + TextSegment segment2 = TextSegment.from("hi"); + + // when + Response<List<Embedding>> response = model.embedAll(asList(segment1, segment2)); + + // then + assertThat(response.content()).hasSize(2); + + Embedding embedding1 = response.content().get(0); + assertThat(embedding1.dimension()).isEqualTo(384); + + Embedding embedding2 = response.content().get(1); + assertThat(embedding2.dimension()).isEqualTo(384); + + assertThat(CosineSimilarity.between(embedding1, embedding2)).isGreaterThan(0.9); + + assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(2); + assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(0); + assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(2); + + assertThat(response.finishReason()).isNull(); + } +} \ No newline at end of file diff --git a/langchain4j-core/pom.xml b/langchain4j-core/pom.xml index 9a740810b4..410420da2e 100644 --- a/langchain4j-core/pom.xml +++ b/langchain4j-core/pom.xml @@ -12,7 +12,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -125,6 +125,8 @@ dev.langchain4j.data.document + dev.langchain4j.model.chat.listener + dev.langchain4j.model.listener dev.langchain4j.store.embedding dev.langchain4j.store.embedding.filter dev.langchain4j.store.embedding.filter.logical @@ -201,4 +203,4 @@ - \ No newline at end of file + diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/document/Document.java b/langchain4j-core/src/main/java/dev/langchain4j/data/document/Document.java index 85ddf9b44b..1bb58be407 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/document/Document.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/document/Document.java @@ -76,8 +76,11 @@ public Metadata metadata() { * * @param key the key to look up. * @return the metadata value for the given key, or null if the key is not present. + * @deprecated as of 0.31.0, use {@link #metadata()} and then {@link Metadata#getString(String)}, + * {@link Metadata#getInteger(String)}, {@link Metadata#getLong(String)}, {@link Metadata#getFloat(String)}, + * {@link Metadata#getDouble(String)} instead.
*/ - // TODO deprecate once the new experimental API is settled + @Deprecated public String metadata(String key) { return metadata.get(key); } @@ -88,7 +91,7 @@ public String metadata(String key) { * @return a TextSegment. */ public TextSegment toTextSegment() { - return TextSegment.from(text, metadata.copy().add("index", "0")); + return TextSegment.from(text, metadata.copy().put("index", "0")); } @Override diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java b/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java index 4c277fb2b6..3a59cbfce6 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/document/Metadata.java @@ -1,6 +1,5 @@ package dev.langchain4j.data.document; -import dev.langchain4j.Experimental; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingStore; @@ -84,8 +83,10 @@ private static void validate(String key, Object value) { * * @param key the key * @return the value associated with the given key, or {@code null} if the key is not present. + * @deprecated as of 0.31.0, use {@link #getString(String)}, {@link #getInteger(String)}, {@link #getLong(String)}, + * {@link #getFloat(String)}, {@link #getDouble(String)} instead. */ - // TODO deprecate once the new experimental API is settled + @Deprecated public String get(String key) { Object value = metadata.get(key); if (value != null) { @@ -102,7 +103,6 @@ public String get(String key) { * @return the {@code String} value associated with the given key, or {@code null} if the key is not present. * @throws RuntimeException if the value is not of type String */ - @Experimental public String getString(String key) { if (!containsKey(key)) { return null; @@ -133,7 +133,6 @@ public String getString(String key) { * @return the {@link Integer} value associated with the given key, or {@code null} if the key is not present. * @throws RuntimeException if the value is not {@link Number} */ - @Experimental public Integer getInteger(String key) { if (!containsKey(key)) { return null; @@ -166,7 +165,6 @@ public Integer getInteger(String key) { * @return the {@code Long} value associated with the given key, or {@code null} if the key is not present. * @throws RuntimeException if the value is not {@link Number} */ - @Experimental public Long getLong(String key) { if (!containsKey(key)) { return null; @@ -199,7 +197,6 @@ public Long getLong(String key) { * @return the {@code Float} value associated with the given key, or {@code null} if the key is not present. * @throws RuntimeException if the value is not {@link Number} */ - @Experimental public Float getFloat(String key) { if (!containsKey(key)) { return null; @@ -232,7 +229,6 @@ public Float getFloat(String key) { * @return the {@code Double} value associated with the given key, or {@code null} if the key is not present. * @throws RuntimeException if the value is not {@link Number} */ - @Experimental public Double getDouble(String key) { if (!containsKey(key)) { return null; @@ -255,7 +251,6 @@ public Double getDouble(String key) { * @param key the key * @return {@code true} if this metadata contains a given key; {@code false} otherwise. 
*/ - @Experimental public boolean containsKey(String key) { return metadata.containsKey(key); } @@ -266,8 +261,10 @@ public boolean containsKey(String key) { * @param key the key * @param value the value * @return {@code this} + * @deprecated as of 0.31.0, use {@link #put(String, String)}, {@link #put(String, int)}, {@link #put(String, long)}, + * {@link #put(String, float)}, {@link #put(String, double)} instead. */ - // TODO deprecate once the new experimental API is settled + @Deprecated public Metadata add(String key, Object value) { return put(key, value.toString()); } @@ -278,8 +275,10 @@ public Metadata add(String key, Object value) { * @param key the key * @param value the value * @return {@code this} + * @deprecated as of 0.31.0, use {@link #put(String, String)}, {@link #put(String, int)}, {@link #put(String, long)}, + * {@link #put(String, float)}, {@link #put(String, double)} instead. */ - // TODO deprecate once the new experimental API is settled + @Deprecated public Metadata add(String key, String value) { validate(key, value); this.metadata.put(key, value); @@ -293,7 +292,6 @@ public Metadata add(String key, String value) { * @param value the value * @return {@code this} */ - @Experimental public Metadata put(String key, String value) { validate(key, value); this.metadata.put(key, value); @@ -307,7 +305,6 @@ public Metadata put(String key, String value) { * @param value the value * @return {@code this} */ - @Experimental public Metadata put(String key, int value) { validate(key, value); this.metadata.put(key, value); @@ -321,7 +318,6 @@ public Metadata put(String key, int value) { * @param value the value * @return {@code this} */ - @Experimental public Metadata put(String key, long value) { validate(key, value); this.metadata.put(key, value); @@ -335,7 +331,6 @@ public Metadata put(String key, long value) { * @param value the value * @return {@code this} */ - @Experimental public Metadata put(String key, float value) { validate(key, value); this.metadata.put(key, value); @@ -349,7 +344,6 @@ public Metadata put(String key, float value) { * @param value the value * @return {@code this} */ - @Experimental public Metadata put(String key, double value) { validate(key, value); this.metadata.put(key, value); @@ -380,8 +374,9 @@ public Metadata copy() { * Get a copy of the metadata as a map of key-value pairs. * * @return the metadata as a map of key-value pairs. + * @deprecated as of 0.31.0, use {@link #toMap()} instead. */ - // TODO deprecate once the new experimental API is settled + @Deprecated public Map asMap() { Map map = new HashMap<>(); for (Map.Entry entry : metadata.entrySet()) { @@ -395,7 +390,6 @@ public Map asMap() { * * @return the metadata as a map of key-value pairs. */ - @Experimental public Map toMap() { return new HashMap<>(metadata); } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/segment/TextSegment.java b/langchain4j-core/src/main/java/dev/langchain4j/data/segment/TextSegment.java index 988dc28a6a..5d7df51cd9 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/segment/TextSegment.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/segment/TextSegment.java @@ -53,8 +53,11 @@ public Metadata metadata() { * * @param key the key. * @return the metadata value, or null if not found. 
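For orientation, a small, hypothetical sketch of the typed Metadata API that the deprecated add(key, value)/get(key) pair is being replaced by in this diff; the keys and values below are made up for illustration:

    import dev.langchain4j.data.document.Metadata;

    class MetadataTypedAccessExample {

        public static void main(String[] args) {
            Metadata metadata = new Metadata()
                    .put("source", "wikipedia") // replaces the deprecated add("source", "wikipedia")
                    .put("year", 2024);         // typed put, no toString() round-trip

            String source = metadata.getString("source"); // "wikipedia"
            Integer year = metadata.getInteger("year");   // 2024
            System.out.println(source + " / " + year);
        }
    }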
+ * @deprecated as of 0.31.0, use {@link #metadata()} and then {@link Metadata#getString(String)}, + * {@link Metadata#getInteger(String)}, {@link Metadata#getLong(String)}, {@link Metadata#getFloat(String)}, + * {@link Metadata#getDouble(String)} instead. */ - // TODO deprecate once the new experimental API is settled + @Deprecated public String metadata(String key) { return metadata.get(key); } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/internal/RetryUtils.java b/langchain4j-core/src/main/java/dev/langchain4j/internal/RetryUtils.java index e6657ab6d9..5b53bb645d 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/internal/RetryUtils.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/internal/RetryUtils.java @@ -191,7 +191,7 @@ public T withRetry(Callable action, int maxAttempts) { try { return action.call(); } catch (Exception e) { - if (attempt == maxAttempts) { + if (attempt >= maxAttempts) { throw new RuntimeException(e); } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/internal/Utils.java b/langchain4j-core/src/main/java/dev/langchain4j/internal/Utils.java index 08b5abc5ca..c04ef8bc57 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/internal/Utils.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/internal/Utils.java @@ -107,6 +107,15 @@ public static boolean isNullOrEmpty(Collection collection) { return collection == null || collection.isEmpty(); } + /** + * Is the iterable object {@code null} or empty? + * @param iterable The iterable object to check. + * @return {@code true} if the iterable object is {@code null} or there are no objects to iterate over, otherwise {@code false}. + */ + public static boolean isNullOrEmpty(Iterable iterable) { + return iterable == null || !iterable.iterator().hasNext(); + } + /** * @deprecated Use {@link #isNullOrEmpty(Collection)} instead. * @param collection The collection to check. diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/LambdaStreamingResponseHandler.java b/langchain4j-core/src/main/java/dev/langchain4j/model/LambdaStreamingResponseHandler.java new file mode 100644 index 0000000000..c34576587c --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/LambdaStreamingResponseHandler.java @@ -0,0 +1,57 @@ +package dev.langchain4j.model; + +import java.util.function.Consumer; + +/** + * Utility class with lambda-based streaming response handlers. + * + * Lets you use Java lambda functions to receive onNext and onError events, + * from your streaming chat model, instead of creating an anonymous inner class + * implementing StreamingResponseHandler. + * + * Example: + *
+ * import static dev.langchain4j.model.LambdaStreamingResponseHandler.*;
+ *
+ * model.generate("Why is the sky blue?",
+ *       onNext(text -> System.out.println(text)));
+ * model.generate("Why is the sky blue?",
+ *       onNext(System.out::println));
+ * model.generate("Why is the sky blue?",
+ *       onNextAndError(System.out::println, Throwable::printStackTrace));
+ * 
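To complement the Javadoc example above, a self-contained usage sketch; "model" stands for any configured StreamingChatLanguageModel implementation (not part of this diff), and the generic type parameters that this diff rendering drops are written out explicitly:

    import static dev.langchain4j.model.LambdaStreamingResponseHandler.onNext;
    import static dev.langchain4j.model.LambdaStreamingResponseHandler.onNextAndError;

    import dev.langchain4j.model.chat.StreamingChatLanguageModel;

    class LambdaHandlerSketch {

        void stream(StreamingChatLanguageModel model) {
            // Print each token as it arrives; the onNext() handler rethrows errors as RuntimeException.
            model.generate("Why is the sky blue?", onNext(System.out::println));

            // Print tokens and handle errors explicitly.
            model.generate("Why is the sky blue?",
                    onNextAndError(System.out::println, Throwable::printStackTrace));
        }
    }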
+ * + * @param The type of the response. + * + * @see StreamingResponseHandler#onNext(String) + * @see StreamingResponseHandler#onError(Throwable) + */ +public class LambdaStreamingResponseHandler { + public static StreamingResponseHandler onNext(Consumer nextLambda) { + return new StreamingResponseHandler() { + @Override + public void onNext(String text) { + nextLambda.accept(text); + } + + @Override + public void onError(Throwable error) { + throw new RuntimeException(error); + } + }; + } + + public static StreamingResponseHandler onNextAndError(Consumer nextLambda, Consumer errorLambda) { + return new StreamingResponseHandler() { + @Override + public void onNext(String text) { + nextLambda.accept(text); + } + + @Override + public void onError(Throwable error) { + errorLambda.accept(error); + } + }; + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelErrorContext.java b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelErrorContext.java new file mode 100644 index 0000000000..4254660768 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelErrorContext.java @@ -0,0 +1,65 @@ +package dev.langchain4j.model.chat.listener; + +import dev.langchain4j.Experimental; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; + +import java.util.Map; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * The error context. It contains the error, corresponding {@link ChatModelRequest}, + * partial {@link ChatModelResponse} (if available) and attributes. + * The attributes can be used to pass data between methods of a {@link ChatModelListener} + * or between multiple {@link ChatModelListener}s. + */ +@Experimental +public class ChatModelErrorContext { + + private final Throwable error; + private final ChatModelRequest request; + private final ChatModelResponse partialResponse; + private final Map attributes; + + public ChatModelErrorContext(Throwable error, + ChatModelRequest request, + ChatModelResponse partialResponse, + Map attributes) { + this.error = ensureNotNull(error, "error"); + this.request = ensureNotNull(request, "request"); + this.partialResponse = partialResponse; + this.attributes = ensureNotNull(attributes, "attributes"); + } + + /** + * @return The error that occurred. + */ + public Throwable error() { + return error; + } + + /** + * @return The request to the {@link ChatLanguageModel} the error corresponds to. + */ + public ChatModelRequest request() { + return request; + } + + /** + * @return The partial response from the {@link ChatLanguageModel}, if available. + * When used with {@link StreamingChatLanguageModel}, it might contain the tokens + * that were received before the error occurred. + */ + public ChatModelResponse partialResponse() { + return partialResponse; + } + + /** + * @return The attributes map. It can be used to pass data between methods of a {@link ChatModelListener} + * or between multiple {@link ChatModelListener}s. 
+ */ + public Map attributes() { + return attributes; + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelListener.java b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelListener.java new file mode 100644 index 0000000000..3a22828bfd --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelListener.java @@ -0,0 +1,50 @@ +package dev.langchain4j.model.chat.listener; + +import dev.langchain4j.Experimental; +import dev.langchain4j.model.chat.ChatLanguageModel; + +/** + * A {@link ChatLanguageModel} listener that listens for requests, responses and errors. + */ +@Experimental +public interface ChatModelListener { + + /** + * This method is called before the request is sent to the model. + * + * @param requestContext The request context. It contains the {@link ChatModelRequest} and attributes. + * The attributes can be used to pass data between methods of this listener + * or between multiple listeners. + */ + @Experimental + default void onRequest(ChatModelRequestContext requestContext) { + + } + + /** + * This method is called after the response is received from the model. + * + * @param responseContext The response context. + * It contains {@link ChatModelResponse}, corresponding {@link ChatModelRequest} and attributes. + * The attributes can be used to pass data between methods of this listener + * or between multiple listeners. + */ + @Experimental + default void onResponse(ChatModelResponseContext responseContext) { + + } + + /** + * This method is called when an error occurs during interaction with the model. + * + * @param errorContext The error context. + * It contains the error, corresponding {@link ChatModelRequest}, + * partial {@link ChatModelResponse} (if available) and attributes. + * The attributes can be used to pass data between methods of this listener + * or between multiple listeners. + */ + @Experimental + default void onError(ChatModelErrorContext errorContext) { + + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequest.java b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequest.java new file mode 100644 index 0000000000..a4f766eca6 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequest.java @@ -0,0 +1,66 @@ +package dev.langchain4j.model.chat.listener; + +import dev.langchain4j.Experimental; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import lombok.Builder; + +import java.util.List; + +import static dev.langchain4j.internal.Utils.copyIfNotNull; + +/** + * A request to the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel}, + * intended to be used with {@link ChatModelListener}. 
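To make the listener contract above concrete, a minimal, hypothetical logging listener could look like the sketch below; the logger and the "startedAt" attribute key are illustrative, not part of this PR:

    import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
    import dev.langchain4j.model.chat.listener.ChatModelListener;
    import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
    import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    class LoggingChatModelListener implements ChatModelListener {

        private static final Logger log = LoggerFactory.getLogger(LoggingChatModelListener.class);

        @Override
        public void onRequest(ChatModelRequestContext requestContext) {
            // The attributes map carries data from onRequest() to onResponse()/onError().
            requestContext.attributes().put("startedAt", System.nanoTime());
            log.info("Calling model '{}'", requestContext.request().model());
        }

        @Override
        public void onResponse(ChatModelResponseContext responseContext) {
            long startedAt = (long) responseContext.attributes().get("startedAt");
            log.info("Response '{}' received after {} ns",
                    responseContext.response().id(), System.nanoTime() - startedAt);
        }

        @Override
        public void onError(ChatModelErrorContext errorContext) {
            log.warn("Model call failed", errorContext.error());
        }
    }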
+ */ +@Experimental +public class ChatModelRequest { + + private final String model; + private final Double temperature; + private final Double topP; + private final Integer maxTokens; + private final List messages; + private final List toolSpecifications; + + @Builder + public ChatModelRequest(String model, + Double temperature, + Double topP, + Integer maxTokens, + List messages, + List toolSpecifications) { + this.model = model; + this.temperature = temperature; + this.topP = topP; + this.maxTokens = maxTokens; + this.messages = copyIfNotNull(messages); + this.toolSpecifications = copyIfNotNull(toolSpecifications); + } + + public String model() { + return model; + } + + public Double temperature() { + return temperature; + } + + public Double topP() { + return topP; + } + + public Integer maxTokens() { + return maxTokens; + } + + public List messages() { + return messages; + } + + public List toolSpecifications() { + return toolSpecifications; + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequestContext.java b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequestContext.java new file mode 100644 index 0000000000..a870607177 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelRequestContext.java @@ -0,0 +1,40 @@ +package dev.langchain4j.model.chat.listener; + +import dev.langchain4j.Experimental; +import dev.langchain4j.model.chat.ChatLanguageModel; + +import java.util.Map; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * The request context. It contains the {@link ChatModelRequest} and attributes. + * The attributes can be used to pass data between methods of a {@link ChatModelListener} + * or between multiple {@link ChatModelListener}s. + */ +@Experimental +public class ChatModelRequestContext { + + private final ChatModelRequest request; + private final Map attributes; + + public ChatModelRequestContext(ChatModelRequest request, Map attributes) { + this.request = ensureNotNull(request, "request"); + this.attributes = ensureNotNull(attributes, "attributes"); + } + + /** + * @return The request to the {@link ChatLanguageModel}. + */ + public ChatModelRequest request() { + return request; + } + + /** + * @return The attributes map. It can be used to pass data between methods of a {@link ChatModelListener} + * or between multiple {@link ChatModelListener}s. + */ + public Map attributes() { + return attributes; + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponse.java b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponse.java new file mode 100644 index 0000000000..05397668d5 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponse.java @@ -0,0 +1,56 @@ +package dev.langchain4j.model.chat.listener; + +import dev.langchain4j.Experimental; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.TokenUsage; +import lombok.Builder; + +/** + * A response from the {@link ChatLanguageModel} or {@link StreamingChatLanguageModel}, + * intended to be used with {@link ChatModelListener}. 
+ */ +@Experimental +public class ChatModelResponse { + + private final String id; + private final String model; + private final TokenUsage tokenUsage; + private final FinishReason finishReason; + private final AiMessage aiMessage; + + @Builder + public ChatModelResponse(String id, + String model, + TokenUsage tokenUsage, + FinishReason finishReason, + AiMessage aiMessage) { + this.id = id; + this.model = model; + this.tokenUsage = tokenUsage; + this.finishReason = finishReason; + this.aiMessage = aiMessage; + } + + public String id() { + return id; + } + + public String model() { + return model; + } + + public TokenUsage tokenUsage() { + return tokenUsage; + } + + public FinishReason finishReason() { + return finishReason; + } + + public AiMessage aiMessage() { + return aiMessage; + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponseContext.java b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponseContext.java new file mode 100644 index 0000000000..865aaf0894 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/chat/listener/ChatModelResponseContext.java @@ -0,0 +1,51 @@ +package dev.langchain4j.model.chat.listener; + +import dev.langchain4j.Experimental; +import dev.langchain4j.model.chat.ChatLanguageModel; + +import java.util.Map; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * The response context. It contains {@link ChatModelResponse}, corresponding {@link ChatModelRequest} and attributes. + * The attributes can be used to pass data between methods of a {@link ChatModelListener} + * or between multiple {@link ChatModelListener}s. + */ +@Experimental +public class ChatModelResponseContext { + + private final ChatModelResponse response; + private final ChatModelRequest request; + private final Map attributes; + + public ChatModelResponseContext(ChatModelResponse response, + ChatModelRequest request, + Map attributes) { + this.response = ensureNotNull(response, "response"); + this.request = ensureNotNull(request, "request"); + this.attributes = ensureNotNull(attributes, "attributes"); + } + + /** + * @return The response from the {@link ChatLanguageModel}. + */ + public ChatModelResponse response() { + return response; + } + + /** + * @return The request to the {@link ChatLanguageModel} the response corresponds to. + */ + public ChatModelRequest request() { + return request; + } + + /** + * @return The attributes map. It can be used to pass data between methods of a {@link ChatModelListener} + * or between multiple {@link ChatModelListener}s. + */ + public Map attributes() { + return attributes; + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/input/Prompt.java b/langchain4j-core/src/main/java/dev/langchain4j/model/input/Prompt.java index 2aabee9398..764d76fdfa 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/model/input/Prompt.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/input/Prompt.java @@ -45,6 +45,14 @@ public SystemMessage toSystemMessage() { return systemMessage(text); } + /** + * Convert this prompt to a UserMessage with specified userName. + * @return the UserMessage. + */ + public UserMessage toUserMessage(String userName) { + return userMessage(userName, text); + } + /** * Convert this prompt to a UserMessage. * @return the UserMessage. 
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/output/TokenUsage.java b/langchain4j-core/src/main/java/dev/langchain4j/model/output/TokenUsage.java index c4f034503c..63000b9a11 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/model/output/TokenUsage.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/output/TokenUsage.java @@ -79,6 +79,33 @@ public Integer totalTokenCount() { return totalTokenCount; } + /** + * Adds two token usages. + *
+ * If one of the token usages is null, the other is returned without changes. + *
+ * Fields which are null in both responses will be null in the result. + * + * @param first The first token usage to add. + * @param second The second token usage to add. + * @return a new {@link TokenUsage} instance with the sum of token usages. + */ + public static TokenUsage sum(TokenUsage first, TokenUsage second) { + if (first == null) { + return second; + } + + if (second == null) { + return first; + } + + return new TokenUsage( + sum(first.inputTokenCount, second.inputTokenCount), + sum(first.outputTokenCount, second.outputTokenCount), + sum(first.totalTokenCount, second.totalTokenCount) + ); + } + /** * Adds the token usage of two responses together. * @@ -86,7 +113,9 @@ public Integer totalTokenCount() { * * @param that The token usage to add to this one. * @return a new {@link TokenUsage} instance with the token usage of both responses added together. + * @deprecated use {@link #sum(TokenUsage, TokenUsage)} instead */ + @Deprecated public TokenUsage add(TokenUsage that) { if (that == null) { return new TokenUsage(inputTokenCount, outputTokenCount, totalTokenCount); diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationRequest.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationRequest.java new file mode 100644 index 0000000000..6c595ae586 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationRequest.java @@ -0,0 +1,39 @@ +package dev.langchain4j.rag; + + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.rag.query.Metadata; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * Represents a request for {@link ChatMessage} augmentation. + */ +public class AugmentationRequest { + + /** + * The chat message to be augmented. + * Currently, it is a {@link UserMessage}, but soon it could also be a {@link SystemMessage}. + */ + private final ChatMessage chatMessage; + + /** + * Additional metadata related to the augmentation request. + */ + private final Metadata metadata; + + public AugmentationRequest(ChatMessage chatMessage, Metadata metadata) { + this.chatMessage = ensureNotNull(chatMessage, "chatMessage"); + this.metadata = ensureNotNull(metadata, "metadata"); + } + + public ChatMessage chatMessage() { + return chatMessage; + } + + public Metadata metadata() { + return metadata; + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationResult.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationResult.java new file mode 100644 index 0000000000..4f70e6a993 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/AugmentationResult.java @@ -0,0 +1,40 @@ +package dev.langchain4j.rag; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.rag.content.Content; +import lombok.Builder; + +import java.util.List; + +import static dev.langchain4j.internal.Utils.copyIfNotNull; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * Represents the result of a {@link ChatMessage} augmentation. + */ +public class AugmentationResult { + + /** + * The augmented chat message. + */ + private final ChatMessage chatMessage; + + /** + * A list of content used to augment the original chat message. 
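Stepping back to the TokenUsage.sum(first, second) method added above, a quick worked example of its semantics (the token counts are made up):

    import dev.langchain4j.model.output.TokenUsage;

    class TokenUsageSumExample {

        public static void main(String[] args) {
            TokenUsage first = new TokenUsage(10, 20);  // input 10, output 20, total 30
            TokenUsage second = new TokenUsage(5, 7);   // input 5, output 7, total 12

            TokenUsage combined = TokenUsage.sum(first, second);
            System.out.println(combined.inputTokenCount());  // 15
            System.out.println(combined.outputTokenCount()); // 27
            System.out.println(combined.totalTokenCount());  // 42

            // A null operand: the other token usage is returned without changes.
            System.out.println(TokenUsage.sum(first, null) == first); // true
        }
    }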
+ */ + private final List contents; + + @Builder + public AugmentationResult(ChatMessage chatMessage, List contents) { + this.chatMessage = ensureNotNull(chatMessage, "chatMessage"); + this.contents = copyIfNotNull(contents); + } + + public ChatMessage chatMessage() { + return chatMessage; + } + + public List contents() { + return contents; + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java index de4c4be4e2..da58e57d16 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java @@ -1,5 +1,6 @@ package dev.langchain4j.rag; +import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.rag.content.Content; import dev.langchain4j.rag.content.aggregator.ContentAggregator; @@ -26,6 +27,7 @@ import java.util.concurrent.Executors; import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.Utils.isNotNullOrBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static java.util.concurrent.CompletableFuture.allOf; import static java.util.concurrent.CompletableFuture.supplyAsync; @@ -119,10 +121,23 @@ public DefaultRetrievalAugmentor(QueryTransformer queryTransformer, this.executor = getOrDefault(executor, Executors::newCachedThreadPool); } + /** + * @deprecated use {@link #augment(AugmentationRequest)} instead. + */ @Override + @Deprecated public UserMessage augment(UserMessage userMessage, Metadata metadata) { + AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata); + return (UserMessage) augment(augmentationRequest).chatMessage(); + } + + @Override + public AugmentationResult augment(AugmentationRequest augmentationRequest) { + + ChatMessage chatMessage = augmentationRequest.chatMessage(); + Metadata metadata = augmentationRequest.metadata(); - Query originalQuery = Query.from(userMessage.text(), metadata); + Query originalQuery = Query.from(chatMessage.text(), metadata); Collection queries = queryTransformer.transform(originalQuery); logQueries(originalQuery, queries); @@ -145,10 +160,13 @@ public UserMessage augment(UserMessage userMessage, Metadata metadata) { List contents = contentAggregator.aggregate(queryToContents); log(queryToContents, contents); - UserMessage augmentedUserMessage = contentInjector.inject(contents, userMessage); - log(augmentedUserMessage); + ChatMessage augmentedChatMessage = contentInjector.inject(contents, chatMessage); + log(augmentedChatMessage); - return augmentedUserMessage; + return AugmentationResult.builder() + .chatMessage(augmentedChatMessage) + .contents(contents) + .build(); } private CompletableFuture>> retrieveFromAll(Collection retrievers, @@ -247,8 +265,8 @@ private static void log(Map>> queryToContents, L .collect(joining("\n"))); } - private static void log(UserMessage augmentedUserMessage) { - log.trace("Augmented user message: " + escapeNewlines(augmentedUserMessage.singleText())); + private static void log(ChatMessage augmentedChatMessage) { + log.trace("Augmented chat message: {}", escapeNewlines(augmentedChatMessage.text())); } private static String escapeNewlines(String text) { diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/RetrievalAugmentor.java 
b/langchain4j-core/src/main/java/dev/langchain4j/rag/RetrievalAugmentor.java index f45fa971a1..d3e1631b69 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/RetrievalAugmentor.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/RetrievalAugmentor.java @@ -1,11 +1,15 @@ package dev.langchain4j.rag; import dev.langchain4j.Experimental; +import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.rag.content.Content; import dev.langchain4j.rag.query.Metadata; +import static dev.langchain4j.internal.Exceptions.runtime; + /** - * Augments the provided {@link UserMessage} with retrieved content. + * Augments the provided {@link ChatMessage} with retrieved {@link Content}s. *
* This serves as an entry point into the RAG flow in LangChain4j. *
@@ -16,12 +20,37 @@ @Experimental public interface RetrievalAugmentor { + /** + * Augments the {@link ChatMessage} provided in the {@link AugmentationRequest} with retrieved {@link Content}s. + *
+ * This method has a default implementation in order to temporarily support + * current custom implementations of {@code RetrievalAugmentor}. The default implementation will be removed soon. + * + * @param augmentationRequest The {@code AugmentationRequest} containing the {@code ChatMessage} to augment. + * @return The {@link AugmentationResult} containing the augmented {@code ChatMessage}. + */ + default AugmentationResult augment(AugmentationRequest augmentationRequest) { + + if (!(augmentationRequest.chatMessage() instanceof UserMessage)) { + throw runtime("Please implement 'AugmentationResult augment(AugmentationRequest)' method " + + "in order to augment " + augmentationRequest.chatMessage().getClass()); + } + + UserMessage augmented = augment((UserMessage) augmentationRequest.chatMessage(), augmentationRequest.metadata()); + + return AugmentationResult.builder() + .chatMessage(augmented) + .build(); + } + /** * Augments the provided {@link UserMessage} with retrieved content. * * @param userMessage The {@link UserMessage} to be augmented. * @param metadata The {@link Metadata} that may be useful or necessary for retrieval and augmentation. * @return The augmented {@link UserMessage}. + * @deprecated Use/implement {@link #augment(AugmentationRequest)} instead. */ + @Deprecated UserMessage augment(UserMessage userMessage, Metadata metadata); } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/ContentInjector.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/ContentInjector.java index 56f6304251..9d8ca9a4c5 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/ContentInjector.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/ContentInjector.java @@ -1,11 +1,15 @@ package dev.langchain4j.rag.content.injector; import dev.langchain4j.Experimental; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.rag.content.Content; import java.util.List; +import static dev.langchain4j.internal.Exceptions.runtime; + /** * Injects given {@link Content}s into a given {@link UserMessage}. *
@@ -17,12 +21,35 @@ @Experimental public interface ContentInjector { + /** + * Injects given {@link Content}s into a given {@link ChatMessage}. + *
+ * This method has a default implementation in order to temporarily support + * current custom implementations of {@code ContentInjector}. The default implementation will be removed soon. + * + * @param contents The list of {@link Content} to be injected. + * @param chatMessage The {@link ChatMessage} into which the {@link Content}s are to be injected. + * Can be either a {@link UserMessage} or a {@link SystemMessage}. + * @return The {@link UserMessage} with the injected {@link Content}s. + */ + default ChatMessage inject(List contents, ChatMessage chatMessage) { + + if (!(chatMessage instanceof UserMessage)) { + throw runtime("Please implement 'ChatMessage inject(List, ChatMessage)' method " + + "in order to inject contents into " + chatMessage); + } + + return inject(contents, (UserMessage) chatMessage); + } + /** * Injects given {@link Content}s into a given {@link UserMessage}. * * @param contents The list of {@link Content} to be injected. * @param userMessage The {@link UserMessage} into which the {@link Content}s are to be injected. * @return The {@link UserMessage} with the injected {@link Content}s. + * @deprecated Use/implement {@link #inject(List, ChatMessage)} instead. */ + @Deprecated UserMessage inject(List contents, UserMessage userMessage); } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/DefaultContentInjector.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/DefaultContentInjector.java index 9ea5873c4f..dde94d0a56 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/DefaultContentInjector.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/injector/DefaultContentInjector.java @@ -1,6 +1,7 @@ package dev.langchain4j.rag.content.injector; import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.input.Prompt; @@ -12,6 +13,7 @@ import java.util.List; import java.util.Map; +import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; @@ -71,6 +73,29 @@ public DefaultContentInjector(PromptTemplate promptTemplate, List metada } @Override + public ChatMessage inject(List contents, ChatMessage chatMessage) { + + if (contents.isEmpty()) { + return chatMessage; + } + + Prompt prompt = createPrompt(chatMessage, contents); + if (chatMessage instanceof UserMessage && isNotNullOrBlank(((UserMessage)chatMessage).name())) { + return prompt.toUserMessage(((UserMessage)chatMessage).name()); + } + + return prompt.toUserMessage(); + } + + protected Prompt createPrompt(ChatMessage chatMessage, List contents) { + return createPrompt((UserMessage) chatMessage, contents); + } + + /** + * @deprecated use {@link #inject(List, ChatMessage)} instead. + */ + @Override + @Deprecated public UserMessage inject(List contents, UserMessage userMessage) { if (contents.isEmpty()) { @@ -78,9 +103,16 @@ public UserMessage inject(List contents, UserMessage userMessage) { } Prompt prompt = createPrompt(userMessage, contents); + if (isNotNullOrBlank(userMessage.name())) { + return prompt.toUserMessage(userMessage.name()); + } return prompt.toUserMessage(); } + /** + * @deprecated implement/override {@link #createPrompt(ChatMessage, List)} instead. 
+ */ + @Deprecated protected Prompt createPrompt(UserMessage userMessage, List contents) { Map variables = new HashMap<>(); variables.put("userMessage", userMessage.text()); diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/ContentRetriever.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/ContentRetriever.java index eaadae9425..b3752bc0de 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/ContentRetriever.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/ContentRetriever.java @@ -13,15 +13,16 @@ * The underlying data source can be virtually anything: *
  * - Embedding (vector) store (see {@link EmbeddingStoreContentRetriever})
- * - Full-text search engine (e.g., Apache Lucene, Elasticsearch, Vespa)
- * - Hybrid of keyword and vector search
- * - The Web (e.g., Google, Bing)
- * - Knowledge graph
- * - Relational database
+ * - Full-text search engine (see {@code AzureAiSearchContentRetriever} in {@code langchain4j-azure-ai-search} module)
+ * - Hybrid of vector and full-text search (see {@code AzureAiSearchContentRetriever} in {@code langchain4j-azure-ai-search} module)
+ * - Web Search Engine (see {@link WebSearchContentRetriever})
+ * - Knowledge graph (see {@code Neo4jContentRetriever} in {@code langchain4j-neo4j} module)
+ * - SQL database (see {@code SqlDatabaseContentRetriever} in {@code langchain4j-experimental-sql} module)
  * - etc.
  * 
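As a rough illustration of the contract behind the list above, a ContentRetriever only needs to turn a Query into a list of Content. The retriever below is a hypothetical stand-in that "searches" a fixed list of strings; real implementations would delegate to an embedding store, a web search engine, a knowledge graph, and so on:

    import dev.langchain4j.rag.content.Content;
    import dev.langchain4j.rag.content.retriever.ContentRetriever;
    import dev.langchain4j.rag.query.Query;

    import java.util.List;
    import java.util.stream.Collectors;

    class StaticContentRetriever implements ContentRetriever {

        private final List<String> knowledgeBase;

        StaticContentRetriever(List<String> knowledgeBase) {
            this.knowledgeBase = knowledgeBase;
        }

        @Override
        public List<Content> retrieve(Query query) {
            // Naive keyword match, purely for illustration.
            return knowledgeBase.stream()
                    .filter(text -> text.toLowerCase().contains(query.text().toLowerCase()))
                    .map(Content::from)
                    .collect(Collectors.toList());
        }
    }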
* * @see EmbeddingStoreContentRetriever + * @see WebSearchContentRetriever */ public interface ContentRetriever { diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetriever.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetriever.java new file mode 100644 index 0000000000..061ede10df --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetriever.java @@ -0,0 +1,49 @@ +package dev.langchain4j.rag.content.retriever; + +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.query.Query; +import dev.langchain4j.web.search.WebSearchEngine; +import dev.langchain4j.web.search.WebSearchRequest; +import dev.langchain4j.web.search.WebSearchResults; +import lombok.Builder; + +import java.util.List; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static java.util.stream.Collectors.toList; + +/** + * A {@link ContentRetriever} that retrieves relevant {@link Content} from the web using a {@link WebSearchEngine}. + *
+ * It returns one {@link Content} for each result that a {@link WebSearchEngine} has returned for a given {@link Query}. + *
+ * Depending on the {@link WebSearchEngine} implementation, the {@link Content#textSegment()} + * can contain either a snippet of a web page or a complete content of a web page. + */ +public class WebSearchContentRetriever implements ContentRetriever { + + private final WebSearchEngine webSearchEngine; + private final int maxResults; + + @Builder + public WebSearchContentRetriever(WebSearchEngine webSearchEngine, Integer maxResults) { + this.webSearchEngine = ensureNotNull(webSearchEngine, "webSearchEngine"); + this.maxResults = getOrDefault(maxResults, 5); + } + + @Override + public List retrieve(Query query) { + + WebSearchRequest webSearchRequest = WebSearchRequest.builder() + .searchTerms(query.text()) + .maxResults(maxResults) + .build(); + + WebSearchResults webSearchResults = webSearchEngine.search(webSearchRequest); + + return webSearchResults.toTextSegments().stream() + .map(Content::from) + .collect(toList()); + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformer.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformer.java index 9a0c52da55..90c020f9e5 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformer.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformer.java @@ -71,8 +71,11 @@ public Collection transform(Query query) { } Prompt prompt = createPrompt(query, format(chatMemory)); - String compressedQuery = chatLanguageModel.generate(prompt.text()); - return singletonList(Query.from(compressedQuery)); + String compressedQueryText = chatLanguageModel.generate(prompt.text()); + Query compressedQuery = query.metadata() == null + ? Query.from(compressedQueryText) + : Query.from(compressedQueryText, query.metadata()); + return singletonList(compressedQuery); } protected String format(List chatMemory) { diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformer.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformer.java index ffd163f61c..1f7e556e83 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformer.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformer.java @@ -73,7 +73,12 @@ public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemp public Collection transform(Query query) { Prompt prompt = createPrompt(query); String response = chatLanguageModel.generate(prompt.text()); - return parse(response); + List queries = parse(response); + return queries.stream() + .map(queryText -> query.metadata() == null + ? 
Query.from(queryText) + : Query.from(queryText, query.metadata())) + .collect(toList()); } protected Prompt createPrompt(Query query) { @@ -83,10 +88,9 @@ protected Prompt createPrompt(Query query) { return promptTemplate.apply(variables); } - protected List parse(String queries) { + protected List parse(String queries) { return stream(queries.split("\n")) .filter(Utils::isNotNullOrBlank) - .map(Query::from) .collect(toList()); } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchRequest.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchRequest.java index 7f626a523f..eb85ba492e 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchRequest.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchRequest.java @@ -1,6 +1,5 @@ package dev.langchain4j.store.embedding; -import dev.langchain4j.Experimental; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; @@ -15,7 +14,6 @@ /** * Represents a request to search in an {@link EmbeddingStore}. */ -@Experimental @ToString @EqualsAndHashCode public class EmbeddingSearchRequest { @@ -41,7 +39,6 @@ public class EmbeddingSearchRequest { * This is an optional parameter. Default: no filtering */ @Builder - @Experimental public EmbeddingSearchRequest(Embedding queryEmbedding, Integer maxResults, Double minScore, Filter filter) { this.queryEmbedding = ensureNotNull(queryEmbedding, "queryEmbedding"); this.maxResults = ensureGreaterThanZero(getOrDefault(maxResults, 3), "maxResults"); @@ -49,22 +46,18 @@ public EmbeddingSearchRequest(Embedding queryEmbedding, Integer maxResults, Doub this.filter = filter; } - @Experimental public Embedding queryEmbedding() { return queryEmbedding; } - @Experimental public int maxResults() { return maxResults; } - @Experimental public double minScore() { return minScore; } - @Experimental public Filter filter() { return filter; } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchResult.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchResult.java index 236680f4dd..044cfe145d 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchResult.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingSearchResult.java @@ -1,7 +1,5 @@ package dev.langchain4j.store.embedding; -import dev.langchain4j.Experimental; - import java.util.List; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; @@ -9,17 +7,14 @@ /** * Represents a result of a search in an {@link EmbeddingStore}. 
*/ -@Experimental public class EmbeddingSearchResult { private final List> matches; - @Experimental public EmbeddingSearchResult(List> matches) { this.matches = ensureNotNull(matches, "matches"); } - @Experimental public List> matches() { return matches; } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingStore.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingStore.java index e200657bc1..160a13b4f6 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingStore.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/EmbeddingStore.java @@ -1,11 +1,17 @@ package dev.langchain4j.store.embedding; import dev.langchain4j.Experimental; +import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.filter.Filter; +import java.util.Collection; import java.util.List; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static java.util.Collections.singletonList; + /** * Represents a store for embeddings, also known as a vector database. * @@ -55,6 +61,47 @@ public interface EmbeddingStore { */ List addAll(List embeddings, List embedded); + /** + * Removes a single embedding from the store by ID. + * + * @param id The unique ID of the embedding to be removed. + */ + @Experimental + default void remove(String id) { + ensureNotBlank(id, "id"); + this.removeAll(singletonList(id)); + } + + /** + * Removes all embeddings that match the specified IDs from the store. + * + * @param ids A collection of unique IDs of the embeddings to be removed. + */ + @Experimental + default void removeAll(Collection ids) { + throw new UnsupportedOperationException("Not supported yet."); + } + + /** + * Removes all embeddings that match the specified {@link Filter} from the store. + * + * @param filter The filter to be applied to the {@link Metadata} of the {@link TextSegment} during removal. + * Only embeddings whose {@code TextSegment}'s {@code Metadata} + * match the {@code Filter} will be removed. + */ + @Experimental + default void removeAll(Filter filter) { + throw new UnsupportedOperationException("Not supported yet."); + } + + /** + * Removes all embeddings from the store. + */ + @Experimental + default void removeAll() { + throw new UnsupportedOperationException("Not supported yet."); + } + /** * Searches for the most similar (closest in the embedding space) {@link Embedding}s. *
@@ -66,7 +113,6 @@ public interface EmbeddingStore { * @param request A request to search in an {@link EmbeddingStore}. Contains all search criteria. * @return An {@link EmbeddingSearchResult} containing all found {@link Embedding}s. */ - @Experimental default EmbeddingSearchResult search(EmbeddingSearchRequest request) { List> matches = findRelevant(request.queryEmbedding(), request.maxResults(), request.minScore()); @@ -82,8 +128,9 @@ default EmbeddingSearchResult search(EmbeddingSearchRequest request) { * @return A list of embedding matches. * Each embedding match includes a relevance score (derivative of cosine distance), * ranging from 0 (not relevant) to 1 (highly relevant). + * @deprecated as of 0.31.0, use {@link #search(EmbeddingSearchRequest)} instead. */ - // TODO deprecate once the new experimental API is settled + @Deprecated default List> findRelevant(Embedding referenceEmbedding, int maxResults) { return findRelevant(referenceEmbedding, maxResults, 0); } @@ -98,8 +145,9 @@ default List> findRelevant(Embedding referenceEmbedding * @return A list of embedding matches. * Each embedding match includes a relevance score (derivative of cosine distance), * ranging from 0 (not relevant) to 1 (highly relevant). + * @deprecated as of 0.31.0, use {@link #search(EmbeddingSearchRequest)} instead. */ - // TODO deprecate once the new experimental API is settled + @Deprecated default List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder() .queryEmbedding(referenceEmbedding) @@ -120,8 +168,9 @@ default List> findRelevant(Embedding referenceEmbedding * @return A list of embedding matches. * Each embedding match includes a relevance score (derivative of cosine distance), * ranging from 0 (not relevant) to 1 (highly relevant). + * @deprecated as of 0.31.0, use {@link #search(EmbeddingSearchRequest)} instead. */ - // TODO deprecate once the new experimental API is settled + @Deprecated default List> findRelevant( Object memoryId, Embedding referenceEmbedding, int maxResults) { return findRelevant(memoryId, referenceEmbedding, maxResults, 0); @@ -138,8 +187,9 @@ default List> findRelevant( * @return A list of embedding matches. * Each embedding match includes a relevance score (derivative of cosine distance), * ranging from 0 (not relevant) to 1 (highly relevant). + * @deprecated as of 0.31.0, use {@link #search(EmbeddingSearchRequest)} instead. 
*/ - // TODO deprecate once the new experimental API is settled + @Deprecated default List> findRelevant( Object memoryId, Embedding referenceEmbedding, int maxResults, double minScore) { throw new RuntimeException("Not implemented"); diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/Filter.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/Filter.java index a151be6b8e..f24e580232 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/Filter.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/Filter.java @@ -1,6 +1,5 @@ package dev.langchain4j.store.embedding.filter; -import dev.langchain4j.Experimental; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.filter.comparison.*; import dev.langchain4j.store.embedding.filter.logical.And; @@ -31,7 +30,6 @@ * @see Not * @see Or */ -@Experimental public interface Filter { /** diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/FilterParser.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/FilterParser.java index 81345be997..f049ff015a 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/FilterParser.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/FilterParser.java @@ -1,14 +1,11 @@ package dev.langchain4j.store.embedding.filter; -import dev.langchain4j.Experimental; - /** * Parses a filter expression string into a {@link Filter} object. *
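A minimal sketch tying together the EmbeddingStore changes above: the no-longer-experimental search(EmbeddingSearchRequest) API and the new removal methods. The store and embedding model instances are assumed to exist, metadataKey(...) comes from MetadataFilterBuilder (also touched in this diff), and stores that have not implemented removal yet will throw UnsupportedOperationException:

    import dev.langchain4j.data.embedding.Embedding;
    import dev.langchain4j.data.segment.TextSegment;
    import dev.langchain4j.model.embedding.EmbeddingModel;
    import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
    import dev.langchain4j.store.embedding.EmbeddingSearchResult;
    import dev.langchain4j.store.embedding.EmbeddingStore;

    import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;

    class EmbeddingStoreUsageSketch {

        void example(EmbeddingStore<TextSegment> store, EmbeddingModel embeddingModel) {
            // Search via EmbeddingSearchRequest instead of the deprecated findRelevant(...) methods.
            Embedding queryEmbedding = embeddingModel.embed("What is LangChain4j?").content();
            EmbeddingSearchResult<TextSegment> result = store.search(EmbeddingSearchRequest.builder()
                    .queryEmbedding(queryEmbedding)
                    .maxResults(5)
                    .minScore(0.7)
                    .filter(metadataKey("userId").isEqualTo("42"))
                    .build());
            result.matches().forEach(match -> System.out.println(match.score()));

            // New removal API: delete by metadata filter, or wipe the whole store.
            store.removeAll(metadataKey("userId").isEqualTo("42"));
            store.removeAll();
        }
    }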
* Currently, there is only one implementation: {@code SqlFilterParser} * in the {@code langchain4j-embedding-store-filter-parser-sql} module. */ -@Experimental public interface FilterParser { /** diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/MetadataFilterBuilder.java b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/MetadataFilterBuilder.java index a82d8697ff..1a7fc585dd 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/MetadataFilterBuilder.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/embedding/filter/MetadataFilterBuilder.java @@ -1,6 +1,5 @@ package dev.langchain4j.store.embedding.filter; -import dev.langchain4j.Experimental; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.store.embedding.filter.comparison.*; @@ -16,7 +15,6 @@ /** * A helper class for building a {@link Filter} for {@link Metadata} key. */ -@Experimental public class MetadataFilterBuilder { private final String key; @@ -25,7 +23,6 @@ public MetadataFilterBuilder(String key) { this.key = ensureNotBlank(key, "key"); } - @Experimental public static MetadataFilterBuilder metadataKey(String key) { return new MetadataFilterBuilder(key); } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/store/memory/chat/ChatMemoryStore.java b/langchain4j-core/src/main/java/dev/langchain4j/store/memory/chat/ChatMemoryStore.java index 867da77842..1a95917a59 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/store/memory/chat/ChatMemoryStore.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/store/memory/chat/ChatMemoryStore.java @@ -10,9 +10,14 @@ /** * Represents a store for the {@link ChatMemory} state. * Allows for flexibility in terms of where and how chat memory is stored. - * Currently, the only implementation available is {@link InMemoryChatMemoryStore}. We are in the process of adding - * ready implementations for popular stores like SQL DBs, document stores, etc. + *
+ *
+ * Currently, the only implementation available is {@link InMemoryChatMemoryStore}. + * Over time, out-of-the-box implementations will be added for popular stores like SQL databases, document stores, etc. * In the meantime, you can implement this interface to connect to any storage of your choice. + *
+ *
+ * More documentation can be found here. */ public interface ChatMemoryStore { diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchEngine.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchEngine.java new file mode 100644 index 0000000000..4fb927e93e --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchEngine.java @@ -0,0 +1,25 @@ +package dev.langchain4j.web.search; + +/** + * Represents a web search engine that can be used to perform searches on the Web in response to a user query. + */ +public interface WebSearchEngine { + + /** + * Performs a search query on the web search engine and returns the search results. + * + * @param query the search query + * @return the search results + */ + default WebSearchResults search(String query) { + return search(WebSearchRequest.from(query)); + } + + /** + * Performs a search request on the web search engine and returns the search results. + * + * @param webSearchRequest the search request + * @return the web search results + */ + WebSearchResults search(WebSearchRequest webSearchRequest); +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchInformationResult.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchInformationResult.java new file mode 100644 index 0000000000..c29919b01d --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchInformationResult.java @@ -0,0 +1,117 @@ +package dev.langchain4j.web.search; + +import java.util.Map; +import java.util.Objects; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * Represents general information about the web search performed. + * This includes the total number of results, the page number, and metadata. + *

+ * The total number of results is the total number of web pages that are found by the search engine in response to a search query. + * The page number is the current page number of the search results. + * The metadata is a map of key-value pairs that provide additional information about the search. + * For example, it could include the search query, the search engine used, the time it took to perform the search, etc. + */ +public class WebSearchInformationResult { + + private final Long totalResults; + private final Integer pageNumber; + private final Map metadata; + + /** + * Constructs a new WebSearchInformationResult with the specified total results. + * + * @param totalResults The total number of results. + */ + public WebSearchInformationResult(Long totalResults) { + this(totalResults, null, null); + } + + /** + * Constructs a new WebSearchInformationResult with the specified total results, page number, and metadata. + * + * @param totalResults The total number of results. + * @param pageNumber The page number. + * @param metadata The metadata. + */ + public WebSearchInformationResult(Long totalResults, Integer pageNumber, Map metadata) { + this.totalResults = ensureNotNull(totalResults, "totalResults"); + this.pageNumber = pageNumber; + this.metadata = metadata; + } + + /** + * Gets the total number of results. + * + * @return The total number of results. + */ + public Long totalResults() { + return totalResults; + } + + /** + * Gets the page number. + * + * @return The page number. + */ + public Integer pageNumber() { + return pageNumber; + } + + /** + * Gets the metadata. + * + * @return The metadata. + */ + public Map metadata() { + return metadata; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WebSearchInformationResult that = (WebSearchInformationResult) o; + return Objects.equals(totalResults, that.totalResults) + && Objects.equals(pageNumber, that.pageNumber) + && Objects.equals(metadata, that.metadata); + } + + @Override + public int hashCode() { + return Objects.hash(totalResults, pageNumber, metadata); + } + + @Override + public String toString() { + return "WebSearchInformationResult{" + + "totalResults=" + totalResults + + ", pageNumber=" + pageNumber + + ", metadata=" + metadata + + '}'; + } + + /** + * Creates a new WebSearchInformationResult with the specified total results. + * + * @param totalResults The total number of results. + * @return The new WebSearchInformationResult. + */ + public static WebSearchInformationResult from(Long totalResults) { + return new WebSearchInformationResult(totalResults); + } + + /** + * Creates a new WebSearchInformationResult with the specified total results, page number, and metadata. + * + * @param totalResults The total number of results. + * @param pageNumber The page number. + * @param metadata The metadata. + * @return The new WebSearchInformationResult. 
+ */ + public static WebSearchInformationResult from(Long totalResults, Integer pageNumber, Map metadata) { + return new WebSearchInformationResult(totalResults, pageNumber, metadata); + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchOrganicResult.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchOrganicResult.java new file mode 100644 index 0000000000..8df9fe9cdb --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchOrganicResult.java @@ -0,0 +1,230 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.segment.TextSegment; + +import java.net.URI; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.Utils.isNotNullOrBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * Represents an organic search results are the web pages that are returned by the search engine in response to a search query. + * This includes the title, URL, snippet and/or content, and metadata of the web page. + *

+ * These results are typically ranked by relevance to the search query. + *

+ */ +public class WebSearchOrganicResult { + private final String title; + private final URI url; + private final String snippet; + private final String content; + private final Map metadata; + + + /** + * Constructs a WebSearchOrganicResult object with the given title and URL. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + */ + public WebSearchOrganicResult(String title, URI url) { + this.title = ensureNotBlank(title, "title"); + this.url = ensureNotNull(url, "url"); + this.snippet = null; + this.content = null; + this.metadata = null; + } + + /** + * Constructs a WebSearchOrganicResult object with the given title, URL, snippet and/or content. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + * @param snippet The snippet of the search result, in plain text. + * @param content The most query related content from the scraped url. + */ + public WebSearchOrganicResult(String title, URI url, String snippet, String content) { + this.title = ensureNotBlank(title, "title"); + this.url = ensureNotNull(url, "url"); + this.snippet = snippet; + this.content = content; + this.metadata = null; + } + + /** + * Constructs a WebSearchOrganicResult object with the given title, URL, snippet and/or content, and metadata. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + * @param snippet The snippet of the search result, in plain text. + * @param content The most query related content from the scraped url. + * @param metadata The metadata associated with the search result. + */ + public WebSearchOrganicResult(String title, URI url, String snippet, String content, Map metadata) { + this.title = ensureNotBlank(title, "title"); + this.url = ensureNotNull(url,"url"); + this.snippet = snippet; + this.content = content; + this.metadata = getOrDefault(metadata, new HashMap<>()); + } + + /** + * Returns the title of the web page. + * + * @return The title of the web page. + */ + public String title() { + return title; + } + + /** + * Returns the URL associated with the web page. + * + * @return The URL associated with the web page. + */ + public URI url() { + return url; + } + + /** + * Returns the snippet associated with the web page. + * + * @return The snippet associated with the web page. + */ + public String snippet() { + return snippet; + } + + /** + * Returns the content scraped from the web page. + * + * @return The content scraped from the web page. + */ + public String content() { + return content; + } + + /** + * Returns the result metadata associated with the search result. + * + * @return The result metadata associated with the search result. 
+ */ + public Map metadata() { + return metadata; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WebSearchOrganicResult that = (WebSearchOrganicResult) o; + return Objects.equals(title, that.title) + && Objects.equals(url, that.url) + && Objects.equals(snippet, that.snippet) + && Objects.equals(content, that.content) + && Objects.equals(metadata, that.metadata); + } + + @Override + public int hashCode() { + return Objects.hash(title, url, snippet, content, metadata); + } + + @Override + public String toString() { + return "WebSearchOrganicResult{" + + "title='" + title + '\'' + + ", url=" + url + + ", snippet='" + snippet + '\'' + + ", content='" + content + '\'' + + ", metadata=" + metadata + + '}'; + } + + /** + * Converts this WebSearchOrganicResult to a TextSegment. + * + * @return The TextSegment representation of this WebSearchOrganicResult. + */ + public TextSegment toTextSegment() { + return TextSegment.from(copyToText(), copyToMetadata()); + } + + /** + * Converts this WebSearchOrganicResult to a Document. + * + * @return The Document representation of this WebSearchOrganicResult. + */ + public Document toDocument() { + return Document.from(copyToText(), copyToMetadata()); + } + + private String copyToText() { + StringBuilder text = new StringBuilder(); + text.append(title); + text.append("\n"); + if (isNotNullOrBlank(content)) { + text.append(content); + } else if (isNotNullOrBlank(snippet)) { + text.append(snippet); + } + return text.toString(); + } + + private Metadata copyToMetadata() { + Metadata docMetadata = new Metadata(); + docMetadata.add("url", url); + if (metadata != null) { + for (Map.Entry entry : metadata.entrySet()) { + docMetadata.put(entry.getKey(), entry.getValue()); + } + } + return docMetadata; + } + + /** + * Creates a WebSearchOrganicResult object from the given title and URL. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + * @return The created WebSearchOrganicResult object. + */ + public static WebSearchOrganicResult from(String title, URI url) { + return new WebSearchOrganicResult(title, url); + } + + /** + * Creates a WebSearchOrganicResult object from the given title, URL, snippet and/or content. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + * @param snippet The snippet of the search result, in plain text. + * @param content The most query related content from the scraped url. + * @return The created WebSearchOrganicResult object. + */ + public static WebSearchOrganicResult from(String title, URI url, String snippet, String content) { + return new WebSearchOrganicResult(title, url, snippet, content); + } + + /** + * Creates a WebSearchOrganicResult object from the given title, URL, snippet and/or content, and result metadata. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + * @param snippet The snippet of the search result, in plain text. + * @param content The most query related content from the scraped url. + * @param metadata The metadata associated with the search result. + * @return The created WebSearchOrganicResult object. 
+ */ + public static WebSearchOrganicResult from(String title, URI url, String snippet, String content, Map metadata) { + return new WebSearchOrganicResult(title, url, snippet, content, metadata); + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchRequest.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchRequest.java new file mode 100644 index 0000000000..c79eb68fd9 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchRequest.java @@ -0,0 +1,312 @@ +package dev.langchain4j.web.search; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; + + +/** + * Represents a search request that can be made by the user to perform searches in any implementation of {@link WebSearchEngine}. + *
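A short sketch of how an organic result defined above can be converted into the existing document types; the values below are purely illustrative:

// Sketch only: illustrative values.
WebSearchOrganicResult result = WebSearchOrganicResult.from(
        "LangChain4j",
        URI.create("https://docs.langchain4j.dev"),
        "Java library for building LLM-powered applications",
        null); // no scraped content available

// The text becomes "title\n" followed by the content (or the snippet when content is absent);
// the URL is stored in the metadata under the "url" key.
TextSegment segment = result.toTextSegment();
Document document = result.toDocument();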

+ * {@link WebSearchRequest} follows the OpenSearch specification implemented by most web search engines, such as Google, Bing, and Yahoo (see OpenSearch#parameters).
+ * <p>
+ * The {@link #searchTerms} are the keywords that the search client wants to search for. This parameter is mandatory to perform a search.
+ * <p>
+ * Configurable parameters (optional):
+ * <ul>
+ *     <li>{@link #maxResults} - The maximum number of results expected for the search request. Each search engine may have a different limit for the maximum number of results that can be returned.</li>
+ *     <li>{@link #language} - The desired language for search results. Each search engine may have a different set of supported languages.</li>
+ *     <li>{@link #geoLocation} - The desired geolocation for search results. Each search engine may have a different set of supported geolocations.</li>
+ *     <li>{@link #startPage} - The page number of the set of search results desired by the search user.</li>
+ *     <li>{@link #startIndex} - The index of the first search result desired by the search user. Each search engine may support different start indexes in combination with the start page number.</li>
+ *     <li>{@link #safeSearch} - A flag indicating whether the search client wants safe search enabled or disabled.</li>
+ *     <li>{@link #additionalParams} - A map of key-value pairs representing additional parameters for the search request. It is a way to stay flexible and pass custom parameters to each search engine.</li>
+ * </ul>
+ */ +public class WebSearchRequest { + + private final String searchTerms; + private final Integer maxResults; + private final String language; + private final String geoLocation; + private final Integer startPage; + private final Integer startIndex; + private final Boolean safeSearch; + private final Map additionalParams; + + private WebSearchRequest(Builder builder){ + this.searchTerms = ensureNotBlank(builder.searchTerms,"searchTerms"); + this.maxResults = builder.maxResults; + this.language = builder.language; + this.geoLocation = builder.geoLocation; + this.startPage = getOrDefault(builder.startPage,1); + this.startIndex = builder.startIndex; + this.safeSearch = getOrDefault(builder.safeSearch,true); + this.additionalParams = getOrDefault(builder.additionalParams, () -> new HashMap<>()); + } + + /** + * Get the search terms. + * + * @return The search terms. + */ + public String searchTerms() { + return searchTerms; + } + + /** + * Get the maximum number of results. + * + * @return The maximum number of results. + */ + public Integer maxResults() { + return maxResults; + } + + /** + * Get the desired language for search results. + * + * @return The desired language for search results. + */ + public String language() { + return language; + } + + /** + * Get the desired geolocation for search results. + * + * @return The desired geolocation for search results. + */ + public String geoLocation() { + return geoLocation; + } + + /** + * Get the start page number for search results. + * + * @return The start page number for search results. + */ + public Integer startPage() { + return startPage; + } + + /** + * Get the start index for search results. + * + * @return The start index for search results. + */ + public Integer startIndex() { + return startIndex; + } + + /** + * Get the safe search flag. + * + * @return The safe search flag. + */ + public Boolean safeSearch() { + return safeSearch; + } + + /** + * Get the additional parameters for the search request. + * + * @return The additional parameters for the search request. 
+ */ + public Map additionalParams() { + return additionalParams; + } + + @Override + public boolean equals(Object another) { + if (this == another) return true; + return another instanceof WebSearchRequest + && equalTo((WebSearchRequest) another); + } + + private boolean equalTo(WebSearchRequest another){ + return Objects.equals(searchTerms, another.searchTerms) + && Objects.equals(maxResults, another.maxResults) + && Objects.equals(language, another.language) + && Objects.equals(geoLocation, another.geoLocation) + && Objects.equals(startPage, another.startPage) + && Objects.equals(startIndex, another.startIndex) + && Objects.equals(safeSearch, another.safeSearch) + && Objects.equals(additionalParams, another.additionalParams); + } + + @Override + public int hashCode() { + int h = 5381; + h += (h << 5) + Objects.hashCode(searchTerms); + h += (h << 5) + Objects.hashCode(maxResults); + h += (h << 5) + Objects.hashCode(language); + h += (h << 5) + Objects.hashCode(geoLocation); + h += (h << 5) + Objects.hashCode(startPage); + h += (h << 5) + Objects.hashCode(startIndex); + h += (h << 5) + Objects.hashCode(safeSearch); + h += (h << 5) + Objects.hashCode(additionalParams); + return h; + } + + @Override + public String toString() { + return "WebSearchRequest{" + + "searchTerms='" + searchTerms + '\'' + + ", maxResults=" + maxResults + + ", language='" + language + '\'' + + ", geoLocation='" + geoLocation + '\'' + + ", startPage=" + startPage + + ", startIndex=" + startIndex + + ", siteRestrict=" + safeSearch + + ", additionalParams=" + additionalParams + + '}'; + } + + /** + * Create a new builder instance. + * + * @return A new builder instance. + */ + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + private String searchTerms; + private Integer maxResults; + private String language; + private String geoLocation; + private Integer startPage; + private Integer startIndex; + private Boolean safeSearch; + private Map additionalParams; + + private Builder() { + } + + /** + * Set the search terms. + * + * @param searchTerms The keyword or keywords desired by the search user. + * @return The builder instance. + */ + public Builder searchTerms(String searchTerms) { + this.searchTerms = searchTerms; + return this; + } + + /** + * Set the maximum number of results. + * + * @param maxResults The maximum number of results. + * @return The builder instance. + */ + public Builder maxResults(Integer maxResults) { + this.maxResults = maxResults; + return this; + } + + /** + * Set the desired language for search results. + * + * @param language The desired language for search results. + * @return The builder instance. + */ + public Builder language(String language) { + this.language = language; + return this; + } + + /** + * Set the desired geolocation for search results. + * + * @param geoLocation The desired geolocation for search results. + * @return The builder instance. + */ + public Builder geoLocation(String geoLocation) { + this.geoLocation = geoLocation; + return this; + } + + /** + * Set the start page number for search results. + * + * @param startPage The start page number for search results. + * @return The builder instance. + */ + public Builder startPage(Integer startPage) { + this.startPage = startPage; + return this; + } + + /** + * Set the start index for search results. + * + * @param startIndex The start index for search results. + * @return The builder instance. 
+ */ + public Builder startIndex(Integer startIndex) { + this.startIndex = startIndex; + return this; + } + + /** + * Set the safe search flag. + * + * @param safeSearch The safe search flag. + * @return The builder instance. + */ + public Builder safeSearch(Boolean safeSearch) { + this.safeSearch = safeSearch; + return this; + } + + /** + * Set the additional parameters for the search request. + * + * @param additionalParams The additional parameters for the search request. + * @return The builder instance. + */ + public Builder additionalParams(Map additionalParams) { + this.additionalParams = additionalParams; + return this; + } + + /** + * Build the web search request. + * + * @return The web search request. + */ + public WebSearchRequest build() { + return new WebSearchRequest(this); + } + } + + /** + * Create a web search request with the given search terms. + * + * @param searchTerms The search terms. + * @return The web search request. + */ + public static WebSearchRequest from(String searchTerms) { + return WebSearchRequest.builder().searchTerms(searchTerms).build(); + } + + /** + * Create a web search request with the given search terms and maximum number of results. + * + * @param searchTerms The search terms. + * @param maxResults The maximum number of results. + * @return The web search request. + */ + public static WebSearchRequest from(String searchTerms, Integer maxResults) { + return WebSearchRequest.builder().searchTerms(searchTerms).maxResults(maxResults).build(); + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchResults.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchResults.java new file mode 100644 index 0000000000..d3667db84d --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchResults.java @@ -0,0 +1,149 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.segment.TextSegment; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static java.util.stream.Collectors.toList; + +/** + * Represents the response of a web search performed. + * This includes the list of organic search results, information about the search, and pagination information. + *
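A usage sketch for the WebSearchRequest builder defined above; the parameter values are illustrative, and unset fields fall back to the documented defaults (startPage = 1, safeSearch = true):

// Sketch only: illustrative values.
WebSearchRequest request = WebSearchRequest.builder()
        .searchTerms("current weather in New York") // mandatory
        .maxResults(5)
        .language("en")
        .build(); // startPage defaults to 1, safeSearch to true

// Shorthand factories for the common cases:
WebSearchRequest simple = WebSearchRequest.from("current weather in New York");
WebSearchRequest limited = WebSearchRequest.from("current weather in New York", 5);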

+ * {@link WebSearchResults} follows the OpenSearch specification implemented by most web search engines, such as Google, Bing, and Yahoo (see OpenSearch#response).
+ * <p>
+ * The organic search results are the web pages that are returned by the search engine in response to a search query. + * These results are typically ranked by relevance to the search query. + */ +public class WebSearchResults { + + private final Map searchMetadata; + private final WebSearchInformationResult searchInformation; + private final List results; + + /** + * Constructs a new instance of WebSearchResults. + * + * @param searchInformation The information about the web search. + * @param results The list of organic search results. + */ + public WebSearchResults(WebSearchInformationResult searchInformation, List results) { + this(null, searchInformation, results); + } + + /** + * Constructs a new instance of WebSearchResults. + * + * @param searchMetadata The metadata associated with the web search. + * @param searchInformation The information about the web search. + * @param results The list of organic search results. + */ + public WebSearchResults(Map searchMetadata, WebSearchInformationResult searchInformation, List results) { + this.searchMetadata = searchMetadata; + this.searchInformation = ensureNotNull(searchInformation, "searchInformation"); + this.results = ensureNotEmpty(results, "results"); + } + + /** + * Gets the metadata associated with the web search. + * + * @return The metadata associated with the web search. + */ + public Map searchMetadata() { + return searchMetadata; + } + + /** + * Gets the information about the web search. + * + * @return The information about the web search. + */ + public WebSearchInformationResult searchInformation() { + return searchInformation; + } + + /** + * Gets the list of organic search results. + * + * @return The list of organic search results. + */ + public List results() { + return results; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WebSearchResults that = (WebSearchResults) o; + return Objects.equals(searchMetadata, that.searchMetadata) + && Objects.equals(searchInformation, that.searchInformation) + && Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(searchMetadata, searchInformation, results); + } + + @Override + public String toString() { + return "WebSearchResults{" + + "searchMetadata=" + searchMetadata + + ", searchInformation=" + searchInformation + + ", results=" + results + + '}'; + } + + /** + * Converts the organic search results to a list of text segments. + * + * @return The list of text segments. + */ + public List toTextSegments() { + return results.stream() + .map(WebSearchOrganicResult::toTextSegment) + .collect(toList()); + } + + /** + * Converts the organic search results to a list of documents. + * + * @return The list of documents. + */ + public List toDocuments() { + return results.stream() + .map(WebSearchOrganicResult::toDocument) + .collect(toList()); + } + + /** + * Creates a new instance of WebSearchResults from the specified parameters. + * + * @param results The list of organic search results. + * @param searchInformation The information about the web search. + * @return The new instance of WebSearchResults. + */ + public static WebSearchResults from(WebSearchInformationResult searchInformation, List results) { + return new WebSearchResults(searchInformation, results); + } + + /** + * Creates a new instance of WebSearchResults from the specified parameters. + * + * @param searchMetadata The metadata associated with the search results. 
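A sketch of assembling a WebSearchResults instance and feeding it into a RAG pipeline via the conversion helpers shown above; the organic results are illustrative:

// Sketch only: illustrative results.
WebSearchResults results = WebSearchResults.from(
        WebSearchInformationResult.from(2L),
        java.util.Arrays.asList(
                WebSearchOrganicResult.from("Title 1", URI.create("https://one.example")),
                WebSearchOrganicResult.from("Title 2", URI.create("https://two.example"))));

// Convenience conversions for downstream RAG components.
List<TextSegment> segments = results.toTextSegments();
List<Document> documents = results.toDocuments();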
+ * @param searchInformation The information about the web search. + * @param results The list of organic search results. + * @return The new instance of WebSearchResults. + */ + public static WebSearchResults from(Map searchMetadata, WebSearchInformationResult searchInformation, List results) { + return new WebSearchResults(searchMetadata, searchInformation, results); + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchTool.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchTool.java new file mode 100644 index 0000000000..4b550b4d83 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchTool.java @@ -0,0 +1,48 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.agent.tool.P; +import dev.langchain4j.agent.tool.Tool; + +import java.util.stream.Collectors; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +public class WebSearchTool { + + private final WebSearchEngine searchEngine; + + public WebSearchTool(WebSearchEngine searchEngine) { + this.searchEngine = ensureNotNull(searchEngine, "searchEngine"); + } + + /** + * Runs a search query on the web search engine and returns a pretty-string representation of the search results. + * + * @param query the search user query + * @return a pretty-string representation of the search results + */ + @Tool("This tool can be used to perform web searches using search engines such as Google, particularly when seeking information about recent events.") + public String searchWeb(@P("Web search query") String query) { + WebSearchResults results = searchEngine.search(query); + return format(results); + } + + private String format(WebSearchResults results) { + return results.results() + .stream() + .map(organicResult -> "Title: " + organicResult.title() + "\n" + + "Source: " + organicResult.url().toString() + "\n" + + (organicResult.content() != null ? "Content:" + "\n" + organicResult.content() : "Snippet:" + "\n" + organicResult.snippet())) + .collect(Collectors.joining("\n\n")); + } + + /** + * Creates a new WebSearchTool with the specified web search engine. 
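A sketch of using the new WebSearchTool; the concrete engine is a hypothetical stand-in, and only APIs introduced in this change are used:

// Sketch only: "MyWebSearchEngine" stands in for any concrete WebSearchEngine implementation.
WebSearchTool webSearchTool = WebSearchTool.from(new MyWebSearchEngine());

// The @Tool-annotated method can also be invoked directly; it returns one
// "Title / Source / Content-or-Snippet" block per organic result, separated by blank lines.
String formatted = webSearchTool.searchWeb("latest LangChain4j release");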
+ * + * @param searchEngine the web search engine to use for searching the web + * @return a new WebSearchTool + */ + public static WebSearchTool from(WebSearchEngine searchEngine) { + return new WebSearchTool(searchEngine); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/data/document/DocumentLoaderTest.java b/langchain4j-core/src/test/java/dev/langchain4j/data/document/DocumentLoaderTest.java index b3bced0a00..80a454e952 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/data/document/DocumentLoaderTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/data/document/DocumentLoaderTest.java @@ -70,9 +70,9 @@ public Document parse(InputStream inputStream) { @Test public void test_load() { - StringSource source = new StringSource("Hello, world!", new Metadata().add("foo", "bar")); + StringSource source = new StringSource("Hello, world!", new Metadata().put("foo", "bar")); Document document = DocumentLoader.load(source, new TrivialParser()); - assertThat(document).isEqualTo(Document.from("Hello, world!", new Metadata().add("foo", "bar"))); + assertThat(document).isEqualTo(Document.from("Hello, world!", new Metadata().put("foo", "bar"))); assertThatExceptionOfType(RuntimeException.class) .isThrownBy(() -> DocumentLoader.load(new DocumentSource() { diff --git a/langchain4j-core/src/test/java/dev/langchain4j/data/document/MetadataTest.java b/langchain4j-core/src/test/java/dev/langchain4j/data/document/MetadataTest.java index b2273f3231..b94f14cfb7 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/data/document/MetadataTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/data/document/MetadataTest.java @@ -17,12 +17,12 @@ class MetadataTest implements WithAssertions { public void test_add_get_put() { Metadata m = new Metadata(); - assertThat(m.get("foo")).isNull(); - m.add("foo", "bar"); - assertThat(m.get("foo")).isEqualTo("bar"); + assertThat(m.getString("foo")).isNull(); + m.put("foo", "bar"); + assertThat(m.getString("foo")).isEqualTo("bar"); - m.add("xyz", 2); - assertThat(m.get("xyz")).isEqualTo("2"); + m.put("xyz", 2); + assertThat(m.getInteger("xyz").toString()).isEqualTo("2"); } @Test @@ -34,14 +34,14 @@ public void test_map_constructor_copies() { Map sourceCopy = new HashMap<>(source); source.put("baz", "qux"); - assertThat(m.asMap()).isEqualTo(sourceCopy); + assertThat(m.toMap()).isEqualTo(sourceCopy); } @Test public void test_toString() { Metadata m = new Metadata(); - m.add("foo", "bar"); - m.add("baz", "qux"); + m.put("foo", "bar"); + m.put("baz", "qux"); assertThat(m.toString()).isEqualTo("Metadata { metadata = {foo=bar, baz=qux} }"); } @@ -49,13 +49,13 @@ public void test_toString() { public void test_equals_hash() { Metadata m1 = new Metadata(); Metadata m2 = new Metadata(); - m1.add("foo", "bar"); - m2.add("foo", "bar"); + m1.put("foo", "bar"); + m2.put("foo", "bar"); Metadata m3 = new Metadata(); Metadata m4 = new Metadata(); - m3.add("different", "value"); - m4.add("different", "value"); + m3.put("different", "value"); + m4.put("different", "value"); assertThat(m1) .isNotEqualTo(null) @@ -75,10 +75,10 @@ public void test_equals_hash() { @Test public void test_copy() { Metadata m1 = new Metadata(); - m1.add("foo", "bar"); + m1.put("foo", "bar"); Metadata m2 = m1.copy(); assertThat(m1).isEqualTo(m2); - m1.add("foo", "baz"); + m1.put("foo", "baz"); assertThat(m1).isNotEqualTo(m2); } @@ -93,34 +93,34 @@ public void test_builders() { .isEqualTo(new Metadata(emptyMap)); assertThat(Metadata.from(map)) - .isEqualTo(new 
Metadata().add("foo", "bar").add("baz", "qux")); + .isEqualTo(new Metadata().put("foo", "bar").put("baz", "qux")); assertThat(Metadata.from("foo", "bar")) - .isEqualTo(new Metadata().add("foo", "bar")); + .isEqualTo(new Metadata().put("foo", "bar")); assertThat(Metadata.metadata("foo", "bar")) - .isEqualTo(new Metadata().add("foo", "bar")); + .isEqualTo(new Metadata().put("foo", "bar")); assertThat(Metadata.from("foo", 2)) - .isEqualTo(new Metadata().add("foo", "2")); + .isEqualTo(new Metadata().put("foo", "2")); assertThat(Metadata.metadata("foo", 2)) - .isEqualTo(new Metadata().add("foo", "2")); + .isEqualTo(new Metadata().put("foo", "2")); } @Test public void test_remove() { Metadata m1 = new Metadata(); - m1.add("foo", "bar"); - m1.add("baz", "qux"); + m1.put("foo", "bar"); + m1.put("baz", "qux"); assertThat(m1.remove("foo")).isSameAs(m1); - assertThat(m1).isEqualTo(new Metadata().add("baz", "qux")); + assertThat(m1).isEqualTo(new Metadata().put("baz", "qux")); } @Test void test_asMap() { Metadata metadata = Metadata.from("key", "value"); - Map map = metadata.asMap(); + Map map = metadata.toMap(); assertThat(map).containsKey("key").containsValue("value"); } @@ -132,7 +132,7 @@ void test_create_from_map() { Metadata metadata = Metadata.from(map); - assertThat(metadata.get("key")).isEqualTo("value"); + assertThat(metadata.getString("key")).isEqualTo("value"); } @Test @@ -353,6 +353,6 @@ void should_convert_to_map() { @Test void test_containsKey() { assertThat(new Metadata().containsKey("key")).isFalse(); - assertThat(new Metadata().add("key", "value").containsKey("key")).isTrue(); + assertThat(new Metadata().put("key", "value").containsKey("key")).isTrue(); } } \ No newline at end of file diff --git a/langchain4j-core/src/test/java/dev/langchain4j/data/segment/TextSegmentTest.java b/langchain4j-core/src/test/java/dev/langchain4j/data/segment/TextSegmentTest.java index 68a77354ba..81eaa52e34 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/data/segment/TextSegmentTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/data/segment/TextSegmentTest.java @@ -25,12 +25,12 @@ public void test_equals_hashCode() { TextSegment ts2 = TextSegment.from("text"); Metadata m1 = new Metadata(); - m1.add("abc", "123"); + m1.put("abc", "123"); Metadata m2 = new Metadata(); - m2.add("abc", "123"); + m2.put("abc", "123"); Metadata m3 = new Metadata(); - m3.add("abc", "xyz"); + m3.put("abc", "xyz"); TextSegment ts3 = TextSegment.from("text", m1); TextSegment ts4 = TextSegment.from("text", m1); @@ -56,7 +56,7 @@ public void test_equals_hashCode() { @Test public void test_accessors() { Metadata metadata = new Metadata(); - metadata.add("abc", "123"); + metadata.put("abc", "123"); TextSegment ts = TextSegment.from("text", metadata); assertThat(ts.text()).isEqualTo("text"); @@ -76,7 +76,7 @@ public void test_builders() { .isEqualTo(TextSegment.textSegment("abc", new Metadata())); Metadata metadata = new Metadata(); - metadata.add("abc", "123"); + metadata.put("abc", "123"); assertThat(new TextSegment("abc", metadata)) .isEqualTo(TextSegment.from("abc", metadata)) diff --git a/langchain4j-core/src/test/java/dev/langchain4j/data/segment/TextSegmentTransformerTest.java b/langchain4j-core/src/test/java/dev/langchain4j/data/segment/TextSegmentTransformerTest.java index 47b8535f08..99735b2295 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/data/segment/TextSegmentTransformerTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/data/segment/TextSegmentTransformerTest.java @@ -22,7 
+22,7 @@ public TextSegment transform(TextSegment segment) { public void test_transformAll() { TextSegmentTransformer transformer = new LowercaseFnordTransformer(); TextSegment ts1 = TextSegment.from("Text"); - ts1.metadata().add("abc", "123"); // metadata is copied over (not transformed + ts1.metadata().put("abc", "123"); // metadata is copied over (not transformed TextSegment ts2 = TextSegment.from("Segment"); TextSegment ts3 = TextSegment.from("Fnord will be filtered out"); diff --git a/langchain4j-core/src/test/java/dev/langchain4j/internal/RetryUtilsTest.java b/langchain4j-core/src/test/java/dev/langchain4j/internal/RetryUtilsTest.java index 518b4493ef..cb1a66bf3c 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/internal/RetryUtilsTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/internal/RetryUtilsTest.java @@ -106,4 +106,36 @@ void testMaxAttemptsReached() throws Exception { verify(mockAction, times(3)).call(); verifyNoMoreInteractions(mockAction); } + + @Test + void testZeroAttemptsReached() throws Exception { + @SuppressWarnings("unchecked") + Callable mockAction = mock(Callable.class); + when(mockAction.call()).thenThrow(new RuntimeException()); + + RetryUtils.RetryPolicy policy = RetryUtils.retryPolicyBuilder() + .delayMillis(100) + .build(); + + assertThatThrownBy(() -> policy.withRetry(mockAction, 0)) + .isInstanceOf(RuntimeException.class); + verify(mockAction, times(1)).call(); + verifyNoMoreInteractions(mockAction); + } + + @Test + void testIllegalAttemptsReached() throws Exception { + @SuppressWarnings("unchecked") + Callable mockAction = mock(Callable.class); + when(mockAction.call()).thenThrow(new RuntimeException()); + + RetryUtils.RetryPolicy policy = RetryUtils.retryPolicyBuilder() + .delayMillis(100) + .build(); + + assertThatThrownBy(() -> policy.withRetry(mockAction, -1)) + .isInstanceOf(RuntimeException.class); + verify(mockAction, times(1)).call(); + verifyNoMoreInteractions(mockAction); + } } \ No newline at end of file diff --git a/langchain4j-core/src/test/java/dev/langchain4j/internal/UtilsTest.java b/langchain4j-core/src/test/java/dev/langchain4j/internal/UtilsTest.java index 52969249cb..1c595d4565 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/internal/UtilsTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/internal/UtilsTest.java @@ -83,6 +83,13 @@ public void test_collection_isNullOrEmpty() { assertThat(Utils.isNullOrEmpty(Collections.singletonList("abc"))).isFalse(); } + @Test + public void test_iterable_isNullOrEmpty() { + assertThat(Utils.isNullOrEmpty((Iterable) null)).isTrue(); + assertThat(Utils.isNullOrEmpty((Iterable) emptyList())).isTrue(); + assertThat(Utils.isNullOrEmpty((Iterable) Collections.singletonList("abc"))).isFalse(); + } + @Test @SuppressWarnings("deprecation") public void test_isCollectionEmpty() { diff --git a/langchain4j-core/src/test/java/dev/langchain4j/model/LambdaStreamingResponseHandlerTest.java b/langchain4j-core/src/test/java/dev/langchain4j/model/LambdaStreamingResponseHandlerTest.java new file mode 100644 index 0000000000..5060a50368 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/model/LambdaStreamingResponseHandlerTest.java @@ -0,0 +1,103 @@ +package dev.langchain4j.model; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import 
org.assertj.core.api.WithAssertions; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static dev.langchain4j.model.LambdaStreamingResponseHandler.onNext; +import static dev.langchain4j.model.LambdaStreamingResponseHandler.onNextAndError; + +public class LambdaStreamingResponseHandlerTest implements WithAssertions { + @Test + void testOnNext() { + // given + List tokens = new ArrayList<>(); + tokens.add("The sky "); + tokens.add("is blue because of "); + tokens.add("a phenomenon called "); + tokens.add("Rayleigh scattering."); + + StreamingChatLanguageModel model = new DummyModel(tokens); + + // when + List receivedTokens = new ArrayList<>(); + model.generate("Why is the sky blue?", + onNext(text -> receivedTokens.add(text))); + + // then + assertThat(receivedTokens).containsSequence(tokens); + } + + @Test + void testOnNextAndError() { + // given + List tokens = new ArrayList<>(); + tokens.add("Three "); + tokens.add("Two "); + tokens.add("One "); + tokens.add(new RuntimeException("BOOM")); + + StreamingChatLanguageModel model = new DummyModel(tokens); + + // when + List receivedTokens = new ArrayList<>(); + final Throwable[] thrown = { null }; + + model.generate("Create a countdown", + onNextAndError(text -> receivedTokens.add(text), t -> thrown[0] = t)); + + // then + assertThat(tokens).containsSubsequence(receivedTokens); + assertThat(thrown[0]).isNotNull(); + assertThat(thrown[0]).isInstanceOf(RuntimeException.class); + assertThat(((Throwable)thrown[0]).getMessage()).isEqualTo("BOOM"); + } + + class DummyModel implements StreamingChatLanguageModel { + private final List stringsAndError; + + public DummyModel(List stringsAndError) { + this.stringsAndError = stringsAndError; + } + + @Override + public void generate(String userMessage, StreamingResponseHandler handler) { + StreamingChatLanguageModel.super.generate(userMessage, handler); + } + + @Override + public void generate(UserMessage userMessage, StreamingResponseHandler handler) { + StreamingChatLanguageModel.super.generate(userMessage, handler); + } + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + stringsAndError.forEach(obj -> { + if (obj instanceof String) { + String msg = (String)obj; + handler.onNext(msg); + } else if (obj instanceof Throwable) { + Throwable problem = (Throwable) obj; + handler.onError(problem); + } + }); + } + + @Override + public void generate(List messages, List toolSpecifications, StreamingResponseHandler handler) { + StreamingChatLanguageModel.super.generate(messages, toolSpecifications, handler); + } + + @Override + public void generate(List messages, ToolSpecification toolSpecification, StreamingResponseHandler handler) { + StreamingChatLanguageModel.super.generate(messages, toolSpecification, handler); + } + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/model/input/PromptTest.java b/langchain4j-core/src/test/java/dev/langchain4j/model/input/PromptTest.java index cec61b645f..f0d1222c70 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/model/input/PromptTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/model/input/PromptTest.java @@ -24,6 +24,8 @@ public void test_constructor() { .isEqualTo(systemMessage("abc")); assertThat(p.toUserMessage()) .isEqualTo(userMessage("abc")); + assertThat(p.toUserMessage("userName")) + .isEqualTo(userMessage("userName", "abc")); assertThat(p.toAiMessage()) .isEqualTo(aiMessage("abc")); } diff --git 
a/langchain4j-core/src/test/java/dev/langchain4j/model/output/TokenUsageTest.java b/langchain4j-core/src/test/java/dev/langchain4j/model/output/TokenUsageTest.java index be8b571a70..8e3e043cbb 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/model/output/TokenUsageTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/model/output/TokenUsageTest.java @@ -3,6 +3,8 @@ import org.assertj.core.api.WithAssertions; import org.junit.jupiter.api.Test; +import static dev.langchain4j.model.output.TokenUsage.sum; + class TokenUsageTest implements WithAssertions { @Test public void test_constructors() { @@ -69,28 +71,34 @@ public void test_toString() { } @Test - public void test_add() { - assertThat( - new TokenUsage(1, 2, 3) - .add(new TokenUsage(4, 5, 6))) - .isEqualTo(new TokenUsage(5, 7, 9)); - - assertThat( - new TokenUsage(1, 2, 3) - .add(new TokenUsage(null, null, null))) - .isEqualTo(new TokenUsage(1, 2, 3)); - - assertThat( + public void test_sum() { + assertThat(sum( + new TokenUsage(1, 2, 3), + new TokenUsage(4, 5, 6) + )).isEqualTo(new TokenUsage(5, 7, 9)); + + assertThat(sum( + new TokenUsage(1, 2, 3), new TokenUsage(null, null, null) - .add(new TokenUsage(4, 5, 6))) - .isEqualTo(new TokenUsage(4, 5, 6)); + )).isEqualTo(new TokenUsage(1, 2, 3)); + + assertThat(sum(new TokenUsage(null, null, null), + new TokenUsage(4, 5, 6) + )).isEqualTo(new TokenUsage(4, 5, 6)); - assertThat( + assertThat(sum( + new TokenUsage(null, null, null), new TokenUsage(null, null, null) - .add(new TokenUsage(null, null, null))) - .isEqualTo(new TokenUsage(null, null, null)); + )).isEqualTo(new TokenUsage(null, null, null)); - assertThat(new TokenUsage(1, 2, 3).add(null)) - .isEqualTo(new TokenUsage(1, 2, 3)); + assertThat(sum( + new TokenUsage(1, 2, 3), + null + )).isEqualTo(new TokenUsage(1, 2, 3)); + + assertThat(sum( + null, + new TokenUsage(4, 5, 6) + )).isEqualTo(new TokenUsage(4, 5, 6)); } } \ No newline at end of file diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java index 434f04019b..f9a6737403 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java @@ -1,5 +1,6 @@ package dev.langchain4j.rag; +import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.rag.content.Content; import dev.langchain4j.rag.content.aggregator.ContentAggregator; @@ -115,6 +116,10 @@ void should_augment_user_message(Executor executor) { content1, content2, content3, content4, content1, content2, content3, content4 ), userMessage); + verify(contentInjector).inject(asList( + content1, content2, content3, content4, + content1, content2, content3, content4 + ), (ChatMessage) userMessage); verifyNoMoreInteractions(contentInjector); } diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/injector/DefaultContentInjectorTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/injector/DefaultContentInjectorTest.java index 44b357f79b..498f64ae29 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/injector/DefaultContentInjectorTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/injector/DefaultContentInjectorTest.java @@ -60,6 +60,28 @@ void should_inject_single_content() { ); } + @Test + void 
should_inject_single_content_with_userName() { + // given + UserMessage userMessage = UserMessage.from("ape", "Tell me about bananas."); + + List contents = singletonList(Content.from("Bananas are awesome!")); + + ContentInjector injector = new DefaultContentInjector(); + + // when + UserMessage injected = injector.inject(contents, userMessage); + + // then + assertThat(injected.text()).isEqualTo( + "Tell me about bananas.\n" + + "\n" + + "Answer using the following information:\n" + + "Bananas are awesome!" + ); + assertThat(injected.name()).isEqualTo("ape"); + } + @Test void should_inject_single_content_with_metadata() { @@ -128,12 +150,12 @@ void should_inject_multiple_contents_with_multiple_metadata_entries( TextSegment segment1 = TextSegment.from( "Bananas are awesome!", Metadata.from("source", "trust me bro") - .add("date", "today") + .put("date", "today") ); TextSegment segment2 = TextSegment.from( "Bananas are healthy!", Metadata.from("source", "my doctor") - .add("reliability", "100%") + .put("reliability", "100%") ); List contents = asList(Content.from(segment1), Content.from(segment2)); diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverIT.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverIT.java new file mode 100644 index 0000000000..3324b636d9 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverIT.java @@ -0,0 +1,41 @@ +package dev.langchain4j.rag.content.retriever; + +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.query.Query; +import dev.langchain4j.web.search.WebSearchEngine; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +public abstract class WebSearchContentRetrieverIT { + + protected abstract WebSearchEngine searchEngine(); + + @Test + void should_retrieve_web_page_as_content() { + + // given + WebSearchContentRetriever contentRetriever = WebSearchContentRetriever.builder() + .webSearchEngine(searchEngine()) + .build(); + + Query query = Query.from("What is the current weather in New York?"); + + // when + List contents = contentRetriever.retrieve(query); + + // then + assertThat(contents) + .as("At least one content should be contains 'weather' and 'New York' ignoring case") + .anySatisfy(content -> { + assertThat(content.textSegment().text()) + .containsIgnoringCase("weather") + .containsIgnoringCase("New York"); + assertThat(content.textSegment().metadata().get("url")) + .startsWith("https://"); + } + ); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverTest.java new file mode 100644 index 0000000000..718ec1f2fa --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverTest.java @@ -0,0 +1,71 @@ +package dev.langchain4j.rag.content.retriever; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.query.Query; +import dev.langchain4j.web.search.*; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.HashMap; +import java.util.List; + +import static 
java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.*; + +class WebSearchContentRetrieverTest { + + WebSearchEngine webSearchEngine; + + @BeforeEach + void mockWebSearchEngine() { + webSearchEngine = mock(WebSearchEngine.class); + when(webSearchEngine.search(any(WebSearchRequest.class))).thenReturn( + new WebSearchResults( + WebSearchInformationResult.from(3L, 1, new HashMap<>()), + asList( + WebSearchOrganicResult.from("title 1", URI.create("https://one.com"), "snippet 1", null), + WebSearchOrganicResult.from("title 2", URI.create("https://two.com"), null, "content 2"), + WebSearchOrganicResult.from("title 3", URI.create("https://three.com"), "snippet 3", "content 3"), + WebSearchOrganicResult.from("title 4", URI.create("https://four.com"), "snippet 4", "content 4"), + WebSearchOrganicResult.from("title 5", URI.create("https://five.com"), "snippet 5", "content 5") + ) + ) + ); + } + + @AfterEach + void resetWebSearchEngine() { + reset(webSearchEngine); + } + + @Test + void should_retrieve_web_pages_back() { + + // given + ContentRetriever contentRetriever = WebSearchContentRetriever.builder() + .webSearchEngine(webSearchEngine) + .build(); + + Query query = Query.from("query"); + + // when + List contents = contentRetriever.retrieve(query); + + // then + assertThat(contents).containsExactly( + Content.from(TextSegment.from("title 1\nsnippet 1", Metadata.from("url", "https://one.com"))), + Content.from(TextSegment.from("title 2\ncontent 2", Metadata.from("url", "https://two.com"))), + Content.from(TextSegment.from("title 3\ncontent 3", Metadata.from("url", "https://three.com"))), + Content.from(TextSegment.from("title 4\ncontent 4", Metadata.from("url", "https://four.com"))), + Content.from(TextSegment.from("title 5\ncontent 5", Metadata.from("url", "https://five.com"))) + ); + + verify(webSearchEngine).search(WebSearchRequest.builder().searchTerms(query.text()).maxResults(5).build()); + verifyNoMoreInteractions(webSearchEngine); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformerTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformerTest.java index 9be2a492cc..f57297a283 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformerTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/query/transformer/CompressingQueryTransformerTest.java @@ -43,16 +43,16 @@ void should_compress_query_and_chat_memory_into_single_query() { Query query = Query.from(userMessage.text(), metadata); - String expectedResultingQuery = "How old is Klaus Heisler?"; + String expectedCompressedQuery = "How old is Klaus Heisler?"; - ChatModelMock model = ChatModelMock.thatAlwaysResponds(expectedResultingQuery); + ChatModelMock model = ChatModelMock.thatAlwaysResponds(expectedCompressedQuery); CompressingQueryTransformer transformer = new CompressingQueryTransformer(model); // when Collection queries = transformer.transform(query); // then - assertThat(queries).containsExactly(Query.from(expectedResultingQuery)); + assertThat(queries).containsExactly(Query.from(expectedCompressedQuery, metadata)); assertThat(model.userMessageText()).isEqualTo( "Read and understand the conversation between the User and the AI. 
" + @@ -110,8 +110,8 @@ void should_compress_query_and_chat_memory_into_single_query_using_custom_prompt Metadata metadata = Metadata.from(userMessage, "default", chatMemory); Query query = Query.from(userMessage.text(), metadata); - String expectedResultingQuery = "How old is Klaus Heisler?"; - ChatModelMock model = ChatModelMock.thatAlwaysResponds(expectedResultingQuery); + String expectedCompressedQuery = "How old is Klaus Heisler?"; + ChatModelMock model = ChatModelMock.thatAlwaysResponds(expectedCompressedQuery); CompressingQueryTransformer transformer = new CompressingQueryTransformer(model, promptTemplate); @@ -119,7 +119,7 @@ void should_compress_query_and_chat_memory_into_single_query_using_custom_prompt Collection queries = transformer.transform(query); // then - assertThat(queries).containsExactly(Query.from(expectedResultingQuery)); + assertThat(queries).containsExactly(Query.from(expectedCompressedQuery, metadata)); assertThat(model.userMessageText()).isEqualTo( "Given the following conversation: " + @@ -144,8 +144,8 @@ void should_compress_query_and_chat_memory_into_single_query_using_custom_prompt Metadata metadata = Metadata.from(userMessage, "default", chatMemory); Query query = Query.from(userMessage.text(), metadata); - String expectedResultingQuery = "How old is Klaus Heisler?"; - ChatModelMock model = ChatModelMock.thatAlwaysResponds(expectedResultingQuery); + String expectedCompressedQuery = "How old is Klaus Heisler?"; + ChatModelMock model = ChatModelMock.thatAlwaysResponds(expectedCompressedQuery); CompressingQueryTransformer transformer = CompressingQueryTransformer.builder() .chatLanguageModel(model) @@ -156,7 +156,7 @@ void should_compress_query_and_chat_memory_into_single_query_using_custom_prompt Collection queries = transformer.transform(query); // then - assertThat(queries).containsExactly(Query.from(expectedResultingQuery)); + assertThat(queries).containsExactly(Query.from(expectedCompressedQuery, metadata)); assertThat(model.userMessageText()).isEqualTo( "Given the following conversation: " + @@ -165,4 +165,4 @@ void should_compress_query_and_chat_memory_into_single_query_using_custom_prompt "reformulate the following query: How old is he?" 
); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformerTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformerTest.java index ad76bf0916..ede326ef1b 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformerTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/query/transformer/ExpandingQueryTransformerTest.java @@ -1,14 +1,20 @@ package dev.langchain4j.rag.query.transformer; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.chat.mock.ChatModelMock; import dev.langchain4j.model.input.PromptTemplate; +import dev.langchain4j.rag.query.Metadata; import dev.langchain4j.rag.query.Query; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import java.util.Collection; +import java.util.List; +import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; class ExpandingQueryTransformerTest { @@ -21,7 +27,14 @@ class ExpandingQueryTransformerTest { void should_expand_query(String queriesString) { // given - Query query = Query.from("query"); + List chatMemory = asList( + UserMessage.from("Hi"), + AiMessage.from("Hello") + ); + UserMessage userMessage = UserMessage.from("query"); + Metadata metadata = Metadata.from(userMessage, "default", chatMemory); + + Query query = Query.from(userMessage.singleText(), metadata); ChatModelMock model = ChatModelMock.thatAlwaysResponds(queriesString); @@ -32,9 +45,9 @@ void should_expand_query(String queriesString) { // then assertThat(queries).containsExactly( - Query.from("query 1"), - Query.from("query 2"), - Query.from("query 3") + Query.from("query 1", metadata), + Query.from("query 2", metadata), + Query.from("query 3", metadata) ); assertThat(model.userMessageText()).isEqualTo( "Generate 3 different versions of a provided user query. " + diff --git a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIT.java b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIT.java index 47fa92c188..84906e7e8f 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIT.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIT.java @@ -30,7 +30,7 @@ void should_add_embedding_with_segment_with_metadata() { // Not returned. 
TextSegment altSegment = TextSegment.from("hello?"); Embedding altEmbedding = embeddingModel().embed(altSegment.text()).content(); - embeddingStore().add(altEmbedding, segment); + embeddingStore().add(altEmbedding, altSegment); } awaitUntilPersisted(); diff --git a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithFilteringIT.java b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithFilteringIT.java index 963092abe3..879b438077 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithFilteringIT.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithFilteringIT.java @@ -8,6 +8,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import java.util.ArrayList; import java.util.List; import java.util.stream.Stream; @@ -29,21 +30,29 @@ void should_filter_by_metadata(Filter metadataFilter, List matchingMetadatas, List notMatchingMetadatas) { // given + List embeddings = new ArrayList<>(); + List segments = new ArrayList<>(); + for (Metadata matchingMetadata : matchingMetadatas) { TextSegment matchingSegment = TextSegment.from("matching", matchingMetadata); Embedding matchingEmbedding = embeddingModel().embed(matchingSegment).content(); - embeddingStore().add(matchingEmbedding, matchingSegment); + embeddings.add(matchingEmbedding); + segments.add(matchingSegment); } for (Metadata notMatchingMetadata : notMatchingMetadatas) { TextSegment notMatchingSegment = TextSegment.from("not matching", notMatchingMetadata); Embedding notMatchingEmbedding = embeddingModel().embed(notMatchingSegment).content(); - embeddingStore().add(notMatchingEmbedding, notMatchingSegment); + embeddings.add(notMatchingEmbedding); + segments.add(notMatchingSegment); } TextSegment notMatchingSegmentWithoutMetadata = TextSegment.from("not matching, without metadata"); - Embedding notMatchingWithoutMetadataEmbedding = embeddingModel().embed(notMatchingSegmentWithoutMetadata).content(); - embeddingStore().add(notMatchingWithoutMetadataEmbedding, notMatchingSegmentWithoutMetadata); + Embedding notMatchingEmbeddingWithoutMetadata = embeddingModel().embed(notMatchingSegmentWithoutMetadata).content(); + embeddings.add(notMatchingEmbeddingWithoutMetadata); + segments.add(notMatchingSegmentWithoutMetadata); + + embeddingStore().addAll(embeddings, segments); awaitUntilPersisted(); @@ -1138,21 +1147,29 @@ void should_filter_by_metadata_not(Filter metadataFilter, List matchingMetadatas, List notMatchingMetadatas) { // given + List embeddings = new ArrayList<>(); + List segments = new ArrayList<>(); + for (Metadata matchingMetadata : matchingMetadatas) { TextSegment matchingSegment = TextSegment.from("matching", matchingMetadata); Embedding matchingEmbedding = embeddingModel().embed(matchingSegment).content(); - embeddingStore().add(matchingEmbedding, matchingSegment); + embeddings.add(matchingEmbedding); + segments.add(matchingSegment); } for (Metadata notMatchingMetadata : notMatchingMetadatas) { TextSegment notMatchingSegment = TextSegment.from("not matching", notMatchingMetadata); Embedding notMatchingEmbedding = embeddingModel().embed(notMatchingSegment).content(); - embeddingStore().add(notMatchingEmbedding, notMatchingSegment); + embeddings.add(notMatchingEmbedding); + segments.add(notMatchingSegment); } TextSegment notMatchingSegmentWithoutMetadata = TextSegment.from("matching"); - Embedding notMatchingWithoutMetadataEmbedding = 
embeddingModel().embed(notMatchingSegmentWithoutMetadata).content(); - embeddingStore().add(notMatchingWithoutMetadataEmbedding, notMatchingSegmentWithoutMetadata); + Embedding notMatchingEmbeddingWithoutMetadata = embeddingModel().embed(notMatchingSegmentWithoutMetadata).content(); + embeddings.add(notMatchingEmbeddingWithoutMetadata); + segments.add(notMatchingSegmentWithoutMetadata); + + embeddingStore().addAll(embeddings, segments); awaitUntilPersisted(); diff --git a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java new file mode 100644 index 0000000000..ce25fbcaef --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java @@ -0,0 +1,163 @@ +package dev.langchain4j.store.embedding; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.filter.Filter; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullAndEmptySource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.Collection; +import java.util.List; + +import static dev.langchain4j.data.document.Metadata.metadata; +import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public abstract class EmbeddingStoreWithRemovalIT { + + protected abstract EmbeddingStore embeddingStore(); + + protected abstract EmbeddingModel embeddingModel(); + + @Test + void should_remove_by_id() { + + // given + Embedding embedding1 = embeddingModel().embed("test1").content(); + String id1 = embeddingStore().add(embedding1); + + Embedding embedding2 = embeddingModel().embed("test2").content(); + String id2 = embeddingStore().add(embedding2); + + assertThat(getAllEmbeddings()).hasSize(2); + + // when + embeddingStore().remove(id1); + + // then + List> relevant = getAllEmbeddings(); + assertThat(relevant).hasSize(1); + assertThat(relevant.get(0).embeddingId()).isEqualTo(id2); + } + + @ParameterizedTest + @NullAndEmptySource + @ValueSource(strings = " ") + void should_fail_to_remove_by_id(String id) { + + assertThatThrownBy(() -> embeddingStore().remove(id)) + .isExactlyInstanceOf(IllegalArgumentException.class) + .hasMessage("id cannot be null or blank"); + } + + @Test + void should_remove_all_by_ids() { + + // given + Embedding embedding1 = embeddingModel().embed("test1").content(); + String id1 = embeddingStore().add(embedding1); + + Embedding embedding2 = embeddingModel().embed("test2").content(); + String id2 = embeddingStore().add(embedding2); + + Embedding embedding3 = embeddingModel().embed("test3").content(); + String id3 = embeddingStore().add(embedding3); + + assertThat(getAllEmbeddings()).hasSize(3); + + // when + embeddingStore().removeAll(asList(id1, id2)); + + // then + List> relevant = getAllEmbeddings(); + assertThat(relevant).hasSize(1); + assertThat(relevant.get(0).embeddingId()).isEqualTo(id3); + } + + @Test + void should_fail_to_remove_all_by_ids_null() { + + assertThatThrownBy(() -> embeddingStore().removeAll((Collection) null)) + 
.isExactlyInstanceOf(IllegalArgumentException.class) + .hasMessage("ids cannot be null or empty"); + } + + @Test + void should_fail_to_remove_all_by_ids_empty() { + + assertThatThrownBy(() -> embeddingStore().removeAll(emptyList())) + .isExactlyInstanceOf(IllegalArgumentException.class) + .hasMessage("ids cannot be null or empty"); + } + + @Test + void should_remove_all_by_filter() { + + // given + TextSegment segment1 = TextSegment.from("matching", metadata("type", "a")); + Embedding embedding1 = embeddingModel().embed(segment1).content(); + embeddingStore().add(embedding1, segment1); + + TextSegment segment2 = TextSegment.from("matching", metadata("type", "a")); + Embedding embedding2 = embeddingModel().embed(segment2).content(); + embeddingStore().add(embedding2, segment2); + + Embedding embedding3 = embeddingModel().embed("not matching").content(); + String id3 = embeddingStore().add(embedding3); + + assertThat(getAllEmbeddings()).hasSize(3); + + // when + embeddingStore().removeAll(metadataKey("type").isEqualTo("a")); + + // then + List> relevant = getAllEmbeddings(); + assertThat(relevant).hasSize(1); + assertThat(relevant.get(0).embeddingId()).isEqualTo(id3); + } + + @Test + void should_fail_to_remove_all_by_filter_null() { + + assertThatThrownBy(() -> embeddingStore().removeAll((Filter) null)) + .isExactlyInstanceOf(IllegalArgumentException.class) + .hasMessage("filter cannot be null"); + } + + @Test + void should_remove_all() { + + // given + Embedding embedding1 = embeddingModel().embed("test1").content(); + embeddingStore().add(embedding1); + + Embedding embedding2 = embeddingModel().embed("test2").content(); + embeddingStore().add(embedding2); + + assertThat(getAllEmbeddings()).hasSize(2); + + // when + embeddingStore().removeAll(); + + // then + assertThat(getAllEmbeddings()).isEmpty(); + } + + private List> getAllEmbeddings() { + + EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddingModel().embed("test").content()) + .maxResults(1000) + .build(); + + EmbeddingSearchResult searchResult = embeddingStore().search(embeddingSearchRequest); + + return searchResult.matches(); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchEngineIT.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchEngineIT.java new file mode 100644 index 0000000000..0e537c0657 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchEngineIT.java @@ -0,0 +1,54 @@ +package dev.langchain4j.web.search; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * A minimum set of tests that each implementation of {@link WebSearchEngine} must pass. 
+ */ +public abstract class WebSearchEngineIT { + + protected abstract WebSearchEngine searchEngine(); + + @Test + void should_search() { + + // when + WebSearchResults webSearchResults = searchEngine().search("LangChain4j"); + + // then + List results = webSearchResults.results(); + assertThat(results).hasSize(5); + + results.forEach(result -> { + assertThat(result.title()).isNotBlank(); + assertThat(result.url()).isNotNull(); + assertThat(result.snippet()).isNotBlank(); + assertThat(result.content()).isNull(); + }); + + assertThat(results).anyMatch(result -> result.url().toString().contains("https://github.com/langchain4j")); + } + + @Test + void should_search_with_max_results() { + + // given + int maxResults = 7; + + WebSearchRequest request = WebSearchRequest.builder() + .searchTerms("LangChain4j") + .maxResults(maxResults) + .build(); + + // when + WebSearchResults webSearchResults = searchEngine().search(request); + + // then + List results = webSearchResults.results(); + assertThat(results).hasSize(maxResults); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchInformationResultTest.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchInformationResultTest.java new file mode 100644 index 0000000000..8664c45134 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchInformationResultTest.java @@ -0,0 +1,53 @@ +package dev.langchain4j.web.search; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class WebSearchInformationResultTest { + + @Test + void should_return_webSearchInformationResult_with_default_values(){ + WebSearchInformationResult webSearchInformationResult = new WebSearchInformationResult(1L); + + assertThat(webSearchInformationResult.totalResults()).isEqualTo(1L); + assertThat(webSearchInformationResult.pageNumber()).isNull(); + assertThat(webSearchInformationResult.metadata()).isNull(); + + assertThat(webSearchInformationResult).hasToString("WebSearchInformationResult{totalResults=1, pageNumber=null, metadata=null}"); + } + + @Test + void should_return_webSearchInformationResult_with_informationResult(){ + WebSearchInformationResult webSearchInformationResult = WebSearchInformationResult.from(1L); + + assertThat(webSearchInformationResult.totalResults()).isEqualTo(1L); + assertThat(webSearchInformationResult.pageNumber()).isNull(); + assertThat(webSearchInformationResult.metadata()).isNull(); + + assertThat(webSearchInformationResult).hasToString("WebSearchInformationResult{totalResults=1, pageNumber=null, metadata=null}"); + } + + @Test + void test_equals_and_hash(){ + WebSearchInformationResult wsi1 = WebSearchInformationResult.from(1L); + WebSearchInformationResult wsi2 = WebSearchInformationResult.from(1L); + + assertThat(wsi1) + .isEqualTo(wsi1) + .isNotEqualTo(null) + .isNotEqualTo(new Object()) + .isEqualTo(wsi2) + .hasSameHashCodeAs(wsi2); + + assertThat(WebSearchInformationResult.from(2L)) + .isNotEqualTo(wsi1); + } + + @Test + void should_throw_illegalArgumentException(){ + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> WebSearchInformationResult.from(null)); + assertThat(exception.getMessage()).isEqualTo("totalResults cannot be null"); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchOrganicResultTest.java 
b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchOrganicResultTest.java new file mode 100644 index 0000000000..f96675fd93 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchOrganicResultTest.java @@ -0,0 +1,134 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.data.document.Metadata; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.AbstractMap; +import java.util.Map; +import java.util.stream.Stream; + +import static java.util.stream.Collectors.toMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class WebSearchOrganicResultTest { + + @Test + void should_build_webSearchOrganicResult_with_default_values(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com")); + + assertThat(webSearchOrganicResult.title()).isEqualTo("title"); + assertThat(webSearchOrganicResult.url().toString()).isEqualTo("https://google.com"); + assertThat(webSearchOrganicResult.snippet()).isNull(); + assertThat(webSearchOrganicResult.content()).isNull(); + assertThat(webSearchOrganicResult.metadata()).isNull(); + } + + @Test + void should_build_webSearchOrganicResult_with_custom_snippet(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null); + + assertThat(webSearchOrganicResult.title()).isEqualTo("title"); + assertThat(webSearchOrganicResult.url().toString()).isEqualTo("https://google.com"); + assertThat(webSearchOrganicResult.snippet()).isEqualTo("snippet"); + assertThat(webSearchOrganicResult.content()).isNull(); + assertThat(webSearchOrganicResult.metadata()).isNull(); + + assertThat(webSearchOrganicResult).hasToString("WebSearchOrganicResult{title='title', url=https://google.com, snippet='snippet', content='null', metadata=null}"); + } + + @Test + void should_build_webSearchOrganicResult_with_custom_content(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com"), null, "content"); + + assertThat(webSearchOrganicResult.title()).isEqualTo("title"); + assertThat(webSearchOrganicResult.url().toString()).isEqualTo("https://google.com"); + assertThat(webSearchOrganicResult.snippet()).isNull(); + assertThat(webSearchOrganicResult.content()).isEqualTo("content"); + assertThat(webSearchOrganicResult.metadata()).isNull(); + + assertThat(webSearchOrganicResult).hasToString("WebSearchOrganicResult{title='title', url=https://google.com, snippet='null', content='content', metadata=null}"); + } + + @Test + void should_build_webSearchOrganicResult_with_custom_title_link_and_metadata(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue))); + + assertThat(webSearchOrganicResult.title()).isEqualTo("title"); + assertThat(webSearchOrganicResult.url().toString()).isEqualTo("https://google.com"); + assertThat(webSearchOrganicResult.snippet()).isEqualTo("snippet"); + assertThat(webSearchOrganicResult.metadata()).containsExactly(new AbstractMap.SimpleEntry<>("key", "value")); + + assertThat(webSearchOrganicResult).hasToString("WebSearchOrganicResult{title='title', url=https://google.com, snippet='snippet', content='null', metadata={key=value}}"); + } + + 
@Test + void test_equals_and_hash(){ + WebSearchOrganicResult wsor1 = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue))); + + WebSearchOrganicResult wsor2 = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue))); + + assertThat(wsor1) + .isEqualTo(wsor1) + .isNotEqualTo(null) + .isNotEqualTo(new Object()) + .isEqualTo(wsor2) + .hasSameHashCodeAs(wsor2); + + assertThat(WebSearchOrganicResult.from("other title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue)))) + .isNotEqualTo(wsor1); + + assertThat(WebSearchOrganicResult.from("title", URI.create("https://docs.langchain4j.dev"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue)))) + .isNotEqualTo(wsor1); + + assertThat(WebSearchOrganicResult.from("title", URI.create("https://google.com"), "other snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue)))) + .isNotEqualTo(wsor1); + + assertThat(WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("other key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue)))) + .isNotEqualTo(wsor1); + } + + @Test + void should_return_textSegment(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue))); + + assertThat(webSearchOrganicResult.toTextSegment().text()).isEqualTo("title\nsnippet"); + assertThat(webSearchOrganicResult.toTextSegment().metadata()).isEqualTo( + Metadata.from(Stream.of( + new AbstractMap.SimpleEntry<>("url", "https://google.com"), + new AbstractMap.SimpleEntry<>("key", "value")) + .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)) + ) + ); + } + + @Test + void should_return_document(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue))); + + assertThat(webSearchOrganicResult.toDocument().text()).isEqualTo("title\nsnippet"); + assertThat(webSearchOrganicResult.toDocument().metadata()).isEqualTo( + Metadata.from(Stream.of( + new AbstractMap.SimpleEntry<>("url", "https://google.com"), + new AbstractMap.SimpleEntry<>("key", "value")) + .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)) + ) + ); + } + + @Test + void should_throw_illegalArgumentException_without_title(){ + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> WebSearchOrganicResult.from(null, URI.create("https://google.com"), "snippet", "content")); + assertThat(exception).hasMessage("title cannot be null or blank"); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchRequestTest.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchRequestTest.java new file mode 100644 index 
0000000000..375e48586b --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchRequestTest.java @@ -0,0 +1,99 @@ +package dev.langchain4j.web.search; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class WebSearchRequestTest { + + @Test + void should_build_webSearchRequest_with_default_values(){ + WebSearchRequest webSearchRequest = WebSearchRequest.from("query"); + + assertThat(webSearchRequest.searchTerms()).isEqualTo("query"); + assertThat(webSearchRequest.startPage()).isEqualTo(1); + assertThat(webSearchRequest.maxResults()).isNull(); + assertThat(webSearchRequest.language()).isNull(); + assertThat(webSearchRequest.geoLocation()).isNull(); + assertThat(webSearchRequest.startIndex()).isNull(); + assertThat(webSearchRequest.safeSearch()).isTrue(); + assertThat(webSearchRequest.additionalParams()).isEmpty(); + + assertThat(webSearchRequest).hasToString("WebSearchRequest{searchTerms='query', maxResults=null, language='null', geoLocation='null', startPage=1, startIndex=null, siteRestrict=true, additionalParams={}}"); + } + + @Test + void should_build_webSearchRequest_with_default_values_builder(){ + WebSearchRequest webSearchRequest = WebSearchRequest.builder().searchTerms("query").build(); + + assertThat(webSearchRequest.searchTerms()).isEqualTo("query"); + assertThat(webSearchRequest.startPage()).isEqualTo(1); + assertThat(webSearchRequest.maxResults()).isNull(); + assertThat(webSearchRequest.language()).isNull(); + assertThat(webSearchRequest.geoLocation()).isNull(); + assertThat(webSearchRequest.startIndex()).isNull(); + assertThat(webSearchRequest.safeSearch()).isTrue(); + assertThat(webSearchRequest.additionalParams()).isEmpty(); + + assertThat(webSearchRequest).hasToString("WebSearchRequest{searchTerms='query', maxResults=null, language='null', geoLocation='null', startPage=1, startIndex=null, siteRestrict=true, additionalParams={}}"); + } + + @Test + void should_build_webSearchRequest_with_custom_maxResults(){ + WebSearchRequest webSearchRequest = WebSearchRequest.from("query", 10); + + assertThat(webSearchRequest.searchTerms()).isEqualTo("query"); + assertThat(webSearchRequest.startPage()).isEqualTo(1); + assertThat(webSearchRequest.maxResults()).isEqualTo(10); + assertThat(webSearchRequest.language()).isNull(); + assertThat(webSearchRequest.geoLocation()).isNull(); + assertThat(webSearchRequest.startIndex()).isNull(); + assertThat(webSearchRequest.safeSearch()).isTrue(); + assertThat(webSearchRequest.additionalParams()).isEmpty(); + + assertThat(webSearchRequest).hasToString("WebSearchRequest{searchTerms='query', maxResults=10, language='null', geoLocation='null', startPage=1, startIndex=null, siteRestrict=true, additionalParams={}}"); + } + + @Test + void should_build_webSearchRequest_with_custom_maxResults_builder(){ + WebSearchRequest webSearchRequest = WebSearchRequest.builder().searchTerms("query").maxResults(10).build(); + + assertThat(webSearchRequest.searchTerms()).isEqualTo("query"); + assertThat(webSearchRequest.startPage()).isEqualTo(1); + assertThat(webSearchRequest.maxResults()).isEqualTo(10); + assertThat(webSearchRequest.language()).isNull(); + assertThat(webSearchRequest.geoLocation()).isNull(); + assertThat(webSearchRequest.startIndex()).isNull(); + assertThat(webSearchRequest.safeSearch()).isTrue(); + assertThat(webSearchRequest.additionalParams()).isEmpty(); + + 
assertThat(webSearchRequest).hasToString("WebSearchRequest{searchTerms='query', maxResults=10, language='null', geoLocation='null', startPage=1, startIndex=null, siteRestrict=true, additionalParams={}}"); + } + + @Test + void test_equals_and_hash(){ + WebSearchRequest wsr1 = WebSearchRequest.from("query", 10); + WebSearchRequest wsr2 = WebSearchRequest.from("query", 10); + + assertThat(wsr1) + .isEqualTo(wsr1) + .isNotEqualTo(null) + .isNotEqualTo(new Object()) + .isEqualTo(wsr2) + .hasSameHashCodeAs(wsr2); + + assertThat(WebSearchRequest.from("other query", 10)) + .isNotEqualTo(wsr1); + + assertThat(WebSearchRequest.from("query", 20)) + .isNotEqualTo(wsr1); + } + + @Test + void should_throw_illegalArgumentException_without_searchTerms(){ + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + WebSearchRequest.builder().build()); + assertThat(exception).hasMessage("searchTerms cannot be null or blank"); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchResultsTest.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchResultsTest.java new file mode 100644 index 0000000000..d73285b510 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchResultsTest.java @@ -0,0 +1,109 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.data.document.Metadata; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.AbstractMap; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; + +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.anyList; + +class WebSearchResultsTest { + + @Test + void should_build_webSearchResults(){ + WebSearchResults webSearchResults = WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com")))); + + assertThat(webSearchResults.results()).hasSize(1); + assertThat(webSearchResults.results().get(0).url().toString()).isEqualTo("https://google.com"); + assertThat(webSearchResults.searchInformation().totalResults()).isEqualTo(1L); + + assertThat(webSearchResults).hasToString("WebSearchResults{searchMetadata=null, searchInformation=WebSearchInformationResult{totalResults=1, pageNumber=null, metadata=null}, results=[WebSearchOrganicResult{title='title', url=https://google.com, snippet='null', content='null', metadata=null}]}"); + } + + @Test + void test_equals_and_hash(){ + WebSearchResults wsr1 = WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com")))); + + WebSearchResults wsr2 = WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com")))); + + assertThat(wsr1) + .isEqualTo(wsr1) + .isNotEqualTo(null) + .isNotEqualTo(new Object()) + .isEqualTo(wsr2) + .hasSameHashCodeAs(wsr2); + + assertThat(WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://docs.langchain4j.dev"))))) + .isNotEqualTo(wsr1); + + assertThat(WebSearchResults.from( + WebSearchInformationResult.from(2L), + 
singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com"))))) + .isNotEqualTo(wsr1); + } + + @Test + void should_return_array_of_textSegments_with_snippet(){ + WebSearchResults webSearchResults = WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com"),"snippet", null))); + + assertThat(webSearchResults.toTextSegments()).hasSize(1); + assertThat(webSearchResults.toTextSegments().get(0).text()).isEqualTo("title\nsnippet"); + assertThat(webSearchResults.toTextSegments().get(0).metadata()).isEqualTo(Metadata.from("url", "https://google.com")); + } + + @Test + void should_return_array_of_documents_with_content(){ + WebSearchResults webSearchResults = WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com"),null, "content"))); + + assertThat(webSearchResults.toDocuments()).hasSize(1); + assertThat(webSearchResults.toDocuments().get(0).text()).isEqualTo("title\ncontent"); + assertThat(webSearchResults.toDocuments().get(0).metadata()).isEqualTo(Metadata.from("url", "https://google.com")); + } + + @Test + void should_throw_illegalArgumentException_without_searchInformation(){ + // given + Map searchMetadata = new HashMap<>(); + searchMetadata.put("key", "value"); + + // then + assertThrows(IllegalArgumentException.class, () -> new WebSearchResults( + searchMetadata, + null, + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com"),"snippet",null)))); + } + + @Test + void should_throw_illegalArgumentException_without_results(){ + // given + Map searchMetadata = new HashMap<>(); + searchMetadata.put("key", "value"); + + // then + assertThrows(IllegalArgumentException.class, () -> new WebSearchResults( + searchMetadata, + WebSearchInformationResult.from(1L), + emptyList())); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolIT.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolIT.java new file mode 100644 index 0000000000..282812dc7b --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolIT.java @@ -0,0 +1,60 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.agent.tool.ToolSpecifications; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; + +public abstract class WebSearchToolIT { + + protected abstract WebSearchEngine searchEngine(); + + protected abstract ChatLanguageModel chatLanguageModel(); + + @Test + void should_be_usable_tool_with_chatLanguageModel() { + // given + WebSearchTool webSearchTool = WebSearchTool.from(searchEngine()); + List tools = ToolSpecifications.toolSpecificationsFrom(webSearchTool); + + UserMessage userMessage = UserMessage.from("What is LangChain4j project?"); + + // when + AiMessage aiMessage = chatLanguageModel().generate(singletonList(userMessage), tools).content(); + + // then + assertThat(aiMessage.hasToolExecutionRequests()).isTrue(); + assertThat(aiMessage.toolExecutionRequests()) + .anySatisfy(toolSpec -> { + assertThat(toolSpec.name()) + .containsIgnoringCase("searchWeb"); + 
assertThat(toolSpec.arguments()) + .isNotBlank(); + } + ); + } + + @Test + void should_return_pretty_result_as_a_tool() { + // given + WebSearchTool webSearchTool = WebSearchTool.from(searchEngine()); + String searchTerm = "What is LangChain4j project?"; + + // when + String strResult = webSearchTool.searchWeb(searchTerm); + + // then + assertThat(strResult).isNotBlank(); + assertThat(strResult) + .as("At least the string result should be contains 'java' and 'AI' ignoring case") + .containsIgnoringCase("Java") + .containsIgnoringCase("AI"); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolTest.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolTest.java new file mode 100644 index 0000000000..8ad8a04e9a --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolTest.java @@ -0,0 +1,58 @@ +package dev.langchain4j.web.search; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.HashMap; + +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +class WebSearchToolTest { + + WebSearchEngine webSearchEngine; + + @BeforeEach + void mockWebSearchEngine(){ + webSearchEngine = mock(WebSearchEngine.class); + when(webSearchEngine.search(anyString())).thenReturn( + new WebSearchResults( + WebSearchInformationResult.from(3L,1, new HashMap<>()), + asList( + WebSearchOrganicResult.from("title 1", URI.create("https://google.com"), "snippet 1", "content 1"), + WebSearchOrganicResult.from("title 2", URI.create("https://docs.langchain4j.dev"), "snippet 2", "content 2"), + WebSearchOrganicResult.from("title 3", URI.create("https://github.com/dewitt/opensearch/blob/master/README.md"), "snippet 3","content 3") + ) + ) + ); + } + + @AfterEach + void resetWebSearchEngine(){ + reset(webSearchEngine); + } + + @Test + void should_build_webSearchTool(){ + // given + String searchTerm = "Any text to search"; + WebSearchTool webSearchTool = WebSearchTool.from(webSearchEngine); + + // when + String strResult = webSearchTool.searchWeb(searchTerm); + + // then + assertThat(strResult).isNotBlank(); + assertThat(strResult) + .as("At least one result should be contains 'title 1' and 'https://google.com' and 'content 1'") + .contains("Title: title 1\nSource: https://google.com\nContent:\ncontent 1"); + + verify(webSearchEngine).search(searchTerm); + verifyNoMoreInteractions(webSearchEngine); + } +} diff --git a/langchain4j-dashscope/pom.xml b/langchain4j-dashscope/pom.xml index 7fe46715d6..c3b1730634 100644 --- a/langchain4j-dashscope/pom.xml +++ b/langchain4j-dashscope/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -24,7 +24,7 @@ com.alibaba dashscope-sdk-java - 2.10.1 + 2.14.7 diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java index 731a594df6..8daa9bd9aa 100644 --- a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenChatModel.java @@ -1,8 +1,8 @@ package dev.langchain4j.model.dashscope; import 
com.alibaba.dashscope.aigc.generation.Generation; +import com.alibaba.dashscope.aigc.generation.GenerationParam; import com.alibaba.dashscope.aigc.generation.GenerationResult; -import com.alibaba.dashscope.aigc.generation.models.QwenParam; import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversation; import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationParam; import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationResult; @@ -10,6 +10,7 @@ import com.alibaba.dashscope.exception.NoApiKeyException; import com.alibaba.dashscope.exception.UploadFileException; import com.alibaba.dashscope.protocol.Protocol; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.internal.Utils; @@ -18,12 +19,18 @@ import dev.langchain4j.model.output.Response; import lombok.Builder; +import java.util.Collections; import java.util.List; -import static com.alibaba.dashscope.aigc.generation.models.QwenParam.ResultFormat.MESSAGE; +import static com.alibaba.dashscope.aigc.conversation.ConversationParam.ResultFormat.MESSAGE; +import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.model.dashscope.QwenHelper.*; import static dev.langchain4j.spi.ServiceHelper.loadFactories; +/** + * Represents a Qwen language model with a chat completion interface. + * More details are available here. + */ public class QwenChatModel implements ChatLanguageModel { private final String apiKey; private final String modelName; @@ -80,12 +87,30 @@ protected QwenChatModel(String baseUrl, @Override public Response generate(List messages) { - return isMultimodalModel ? generateByMultimodalModel(messages) : generateByNonMultimodalModel(messages); + return isMultimodalModel ? + generateByMultimodalModel(messages, null, null) : + generateByNonMultimodalModel(messages, null, null); } - private Response generateByNonMultimodalModel(List messages) { + @Override + public Response generate(List messages, List toolSpecifications) { + return isMultimodalModel ? + generateByMultimodalModel(messages, toolSpecifications, null) : + generateByNonMultimodalModel(messages, toolSpecifications, null); + } + + @Override + public Response generate(List messages, ToolSpecification toolSpecification) { + return isMultimodalModel ? 
+ generateByMultimodalModel(messages, null, toolSpecification) : + generateByNonMultimodalModel(messages, null, toolSpecification); + } + + private Response generateByNonMultimodalModel(List messages, + List toolSpecifications, + ToolSpecification toolThatMustBeExecuted) { try { - QwenParam.QwenParamBuilder builder = QwenParam.builder() + GenerationParam.GenerationParamBuilder builder = GenerationParam.builder() .apiKey(apiKey) .model(modelName) .topP(topP) @@ -102,17 +127,32 @@ private Response generateByNonMultimodalModel(List messa builder.stopStrings(stops); } + if (!isNullOrEmpty(toolSpecifications)) { + builder.tools(toToolFunctions(toolSpecifications)); + } else if (toolThatMustBeExecuted != null) { + builder.tools(toToolFunctions(Collections.singleton(toolThatMustBeExecuted))); + builder.toolChoice(toToolFunction(toolThatMustBeExecuted)); + } + GenerationResult generationResult = generation.call(builder.build()); - String answer = answerFrom(generationResult); - return Response.from(AiMessage.from(answer), - tokenUsageFrom(generationResult), finishReasonFrom(generationResult)); + return Response.from( + aiMessageFrom(generationResult), + tokenUsageFrom(generationResult), + finishReasonFrom(generationResult) + ); } catch (NoApiKeyException | InputRequiredException e) { throw new RuntimeException(e); } } - private Response generateByMultimodalModel(List messages) { + private Response generateByMultimodalModel(List messages, + List toolSpecifications, + ToolSpecification toolThatMustBeExecuted) { + if (toolThatMustBeExecuted != null || !isNullOrEmpty(toolSpecifications)) { + throw new IllegalArgumentException("Tools are currently not supported by this model"); + } + try { MultiModalConversationParam param = MultiModalConversationParam.builder() .apiKey(apiKey) diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java index f6c61aaffd..bdadccee6b 100644 --- a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenEmbeddingModel.java @@ -19,11 +19,15 @@ import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.util.Collections.singletonList; +/** + * An implementation of an {@link EmbeddingModel} that uses + * DashScope Embeddings API. 
+ */ public class QwenEmbeddingModel implements EmbeddingModel { - public static final String TYPE_KEY = "type"; public static final String TYPE_QUERY = "query"; public static final String TYPE_DOCUMENT = "document"; + private static final int BATCH_SIZE = 25; private final String apiKey; private final String modelName; @@ -42,18 +46,37 @@ public QwenEmbeddingModel(String apiKey, String modelName) { private boolean containsDocuments(List textSegments) { return textSegments.stream() .map(TextSegment::metadata) - .map(metadata -> metadata.get(TYPE_KEY)) + .map(metadata -> metadata.getString(TYPE_KEY)) .anyMatch(TYPE_DOCUMENT::equalsIgnoreCase); } private boolean containsQueries(List textSegments) { return textSegments.stream() .map(TextSegment::metadata) - .map(metadata -> metadata.get(TYPE_KEY)) + .map(metadata -> metadata.getString(TYPE_KEY)) .anyMatch(TYPE_QUERY::equalsIgnoreCase); } - private Response> embedTexts(List textSegments, TextEmbeddingParam.TextType textType) { + private Response> embedTexts(List textSegments, + TextEmbeddingParam.TextType textType) { + int size = textSegments.size(); + if (size < BATCH_SIZE) { + return batchEmbedTexts(textSegments, textType); + } + + List allEmbeddings = new ArrayList<>(size); + TokenUsage allUsage = null; + for (int i = 0; i < size; i += BATCH_SIZE) { + List batchTextSegments = textSegments.subList(i, Math.min(size, i + BATCH_SIZE)); + Response> batchResponse = batchEmbedTexts(batchTextSegments, textType); + allEmbeddings.addAll(batchResponse.content()); + allUsage = TokenUsage.sum(allUsage, batchResponse.tokenUsage()); + } + + return Response.from(allEmbeddings, allUsage); + } + + private Response> batchEmbedTexts(List textSegments, TextEmbeddingParam.TextType textType) { TextEmbeddingParam param = TextEmbeddingParam.builder() .apiKey(apiKey) .model(modelName) @@ -99,7 +122,7 @@ public Response> embedAll(List textSegments) { Integer tokens = null; for (TextSegment textSegment : textSegments) { Response> result; - if (TYPE_QUERY.equalsIgnoreCase(textSegment.metadata(TYPE_KEY))) { + if (TYPE_QUERY.equalsIgnoreCase(textSegment.metadata().getString(TYPE_KEY))) { result = embedTexts(singletonList(textSegment), QUERY); } else { result = embedTexts(singletonList(textSegment), DOCUMENT); diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenHelper.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenHelper.java index f2aa1a52ea..5b12757373 100644 --- a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenHelper.java +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenHelper.java @@ -7,6 +7,12 @@ import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationResult; import com.alibaba.dashscope.common.Message; import com.alibaba.dashscope.common.MultiModalMessage; +import com.alibaba.dashscope.tools.*; +import com.alibaba.dashscope.utils.JsonUtils; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import com.google.gson.JsonObject; +import dev.langchain4j.agent.tool.ToolParameters; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.image.Image; import dev.langchain4j.data.message.*; import dev.langchain4j.internal.Utils; @@ -23,8 +29,9 @@ import java.util.stream.Collectors; import static com.alibaba.dashscope.common.Role.*; -import static dev.langchain4j.model.output.FinishReason.LENGTH; -import static dev.langchain4j.model.output.FinishReason.STOP; +import static 
dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.Utils.isNullOrEmpty; +import static dev.langchain4j.model.output.FinishReason.*; import static java.util.stream.Collectors.toList; class QwenHelper { @@ -45,6 +52,9 @@ static Message toQwenMessage(ChatMessage message) { return Message.builder() .role(roleFrom(message)) .content(toSingleText(message)) + .name(nameFrom(message)) + .toolCallId(toolCallIdFrom(message)) + .toolCalls(toolCallsFrom(message)) .build(); } @@ -58,7 +68,7 @@ static String toSingleText(ChatMessage message) { .map(TextContent::text) .collect(Collectors.joining("\n")); case AI: - return ((AiMessage) message).text(); + return ((AiMessage) message).hasToolExecutionRequests() ? "" : ((AiMessage) message).text(); case SYSTEM: return ((SystemMessage) message).text(); case TOOL_EXECUTION_RESULT: @@ -68,6 +78,31 @@ static String toSingleText(ChatMessage message) { } } + static String nameFrom(ChatMessage message) { + switch (message.type()) { + case USER: + return ((UserMessage) message).name(); + case TOOL_EXECUTION_RESULT: + return ((ToolExecutionResultMessage) message).toolName(); + default: + return null; + } + } + + static String toolCallIdFrom(ChatMessage message) { + if (message.type() == ChatMessageType.TOOL_EXECUTION_RESULT) { + return ((ToolExecutionResultMessage) message).id(); + } + return null; + } + + static List toolCallsFrom(ChatMessage message) { + if (message.type() == ChatMessageType.AI && ((AiMessage) message).hasToolExecutionRequests()) { + return toToolCalls(((AiMessage) message).toolExecutionRequests()); + } + return null; + } + static List toQwenMultiModalMessages(List messages) { return messages.stream() .map(QwenHelper::toQwenMultiModalMessage) @@ -84,7 +119,7 @@ static MultiModalMessage toQwenMultiModalMessage(ChatMessage message) { static List> toMultiModalContents(ChatMessage message) { switch (message.type()) { case USER: - return((UserMessage) message).contents() + return ((UserMessage) message).contents() .stream() .map(QwenHelper::toMultiModalContent) .collect(Collectors.toList()); @@ -154,10 +189,12 @@ private static String saveImageAsTemporaryFile(String base64Data, String mimeTyp } static String roleFrom(ChatMessage message) { - if (message instanceof AiMessage) { + if (message.type() == ChatMessageType.AI) { return ASSISTANT.getValue(); - } else if (message instanceof SystemMessage) { + } else if (message.type() == ChatMessageType.SYSTEM) { return SYSTEM.getValue(); + } else if (message.type() == ChatMessageType.TOOL_EXECUTION_RESULT) { + return TOOL.getValue(); } else { return USER.getValue(); } @@ -228,19 +265,23 @@ static TokenUsage tokenUsageFrom(MultiModalConversationResult result) { } static FinishReason finishReasonFrom(GenerationResult result) { - String finishReason = Optional.of(result) - .map(GenerationResult::getOutput) - .map(GenerationOutput::getChoices) - .filter(choices -> !choices.isEmpty()) - .map(choices -> choices.get(0)) - .map(Choice::getFinishReason) - .orElse(""); + Choice choice = result.getOutput().getChoices().get(0); + String finishReason = choice.getFinishReason(); + if (finishReason == null) { + if (isNullOrEmpty(choice.getMessage().getToolCalls())) { + return null; + } + // Upon observation, when tool_calls occur, the returned finish_reason may be null, not "tool_calls". 
+ finishReason = "tool_calls"; + } switch (finishReason) { case "stop": return STOP; case "length": return LENGTH; + case "tool_calls": + return TOOL_EXECUTION; default: return null; } @@ -269,4 +310,96 @@ public static boolean isMultimodalModel(String modelName) { // for now, multimodal models start with "qwen-vl" return modelName.startsWith("qwen-vl"); } -} + + static List toToolFunctions(Collection toolSpecifications) { + if (isNullOrEmpty(toolSpecifications)) { + return Collections.emptyList(); + } + + return toolSpecifications.stream() + .map(QwenHelper::toToolFunction) + .collect(Collectors.toList()); + } + + static ToolBase toToolFunction(ToolSpecification toolSpecification) { + FunctionDefinition functionDefinition = FunctionDefinition.builder() + .name(toolSpecification.name()) + .description(toolSpecification.description()) + .parameters(toParameters(toolSpecification.parameters())) + .build(); + return ToolFunction.builder().function(functionDefinition).build(); + } + + private static JsonObject toParameters(ToolParameters toolParameters) { + return toolParameters == null ? + JsonUtils.toJsonObject(Collections.emptyMap()) : + JsonUtils.toJsonObject(toolParameters); + } + + static AiMessage aiMessageFrom(GenerationResult result) { + return isFunctionToolCalls(result) ? + new AiMessage(functionToolCallsFrom(result)) : new AiMessage(answerFrom(result)); + } + + private static List functionToolCallsFrom(GenerationResult result) { + List toolCalls = Optional.of(result) + .map(GenerationResult::getOutput) + .map(GenerationOutput::getChoices) + .filter(choices -> !choices.isEmpty()) + .map(choices -> choices.get(0)) + .map(Choice::getMessage) + .map(Message::getToolCalls) + .orElseThrow(IllegalStateException::new); + + return toolCalls.stream() + .filter(ToolCallFunction.class::isInstance) + .map(ToolCallFunction.class::cast) + .map(toolCall -> ToolExecutionRequest.builder() + .id(getOrDefault(toolCall.getId(), () -> toolCallIdFromMessage(result))) + .name(toolCall.getFunction().getName()) + .arguments(toolCall.getFunction().getArguments()) + .build()) + .collect(Collectors.toList()); + } + + static String toolCallIdFromMessage(GenerationResult result) { + // Not sure about the difference between Message::getToolCallId() and ToolCallFunction::getId(). + // Currently, they all return null. + // Encapsulate a method to get the ID using Message::getToolCallId() when ToolCallFunction::getId() is null. 
+ return Optional.of(result) + .map(GenerationResult::getOutput) + .map(GenerationOutput::getChoices) + .filter(choices -> !choices.isEmpty()) + .map(choices -> choices.get(0)) + .map(Choice::getMessage) + .map(Message::getToolCallId) + .orElse(null); + } + + static boolean isFunctionToolCalls(GenerationResult result) { + Optional> toolCallBases = Optional.of(result) + .map(GenerationResult::getOutput) + .map(GenerationOutput::getChoices) + .filter(choices -> !choices.isEmpty()) + .map(choices -> choices.get(0)) + .map(Choice::getMessage) + .map(Message::getToolCalls); + return toolCallBases.isPresent() && !isNullOrEmpty(toolCallBases.get()); + } + + private static List toToolCalls(Collection toolExecutionRequests) { + return toolExecutionRequests.stream() + .map(QwenHelper::toToolCall) + .collect(toList()); + } + + private static ToolCallBase toToolCall(ToolExecutionRequest toolExecutionRequest) { + ToolCallFunction toolCallFunction = new ToolCallFunction(); + toolCallFunction.setId(toolExecutionRequest.id()); + ToolCallFunction.CallFunction callFunction = toolCallFunction.new CallFunction(); + callFunction.setName(toolExecutionRequest.name()); + callFunction.setArguments(toolExecutionRequest.arguments()); + toolCallFunction.setFunction(callFunction); + return toolCallFunction; + } +} \ No newline at end of file diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java index da1d289ca5..00bbda2900 100644 --- a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenLanguageModel.java @@ -1,8 +1,8 @@ package dev.langchain4j.model.dashscope; import com.alibaba.dashscope.aigc.generation.Generation; +import com.alibaba.dashscope.aigc.generation.GenerationParam; import com.alibaba.dashscope.aigc.generation.GenerationResult; -import com.alibaba.dashscope.aigc.generation.models.QwenParam; import com.alibaba.dashscope.exception.InputRequiredException; import com.alibaba.dashscope.exception.NoApiKeyException; import com.alibaba.dashscope.protocol.Protocol; @@ -14,12 +14,16 @@ import java.util.List; -import static com.alibaba.dashscope.aigc.generation.models.QwenParam.ResultFormat.MESSAGE; +import static com.alibaba.dashscope.aigc.generation.GenerationParam.ResultFormat.MESSAGE; import static dev.langchain4j.internal.Utils.isNullOrBlank; import static dev.langchain4j.model.dashscope.QwenHelper.*; import static dev.langchain4j.model.dashscope.QwenModelName.QWEN_PLUS; import static dev.langchain4j.spi.ServiceHelper.loadFactories; +/** + * Represents a Qwen language model with a text interface. + * More details are available here. 
+ */ public class QwenLanguageModel implements LanguageModel { private final String apiKey; private final String modelName; @@ -71,7 +75,7 @@ public QwenLanguageModel(String baseUrl, @Override public Response generate(String prompt) { try { - QwenParam.QwenParamBuilder builder = QwenParam.builder() + GenerationParam.GenerationParamBuilder builder = GenerationParam.builder() .apiKey(apiKey) .model(modelName) .topP(topP) diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java index 7538eb861f..016f60520d 100644 --- a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenModelName.java @@ -4,13 +4,23 @@ * The LLMs provided by Alibaba Cloud, performs better than most LLMs in Asia languages. */ public class QwenModelName { - // Use with QwenChatModel and QwenLanguageModel public static final String QWEN_TURBO = "qwen-turbo"; // Qwen base model, 4k context. public static final String QWEN_PLUS = "qwen-plus"; // Qwen plus model, 8k context. public static final String QWEN_MAX = "qwen-max"; // Qwen max model, 200-billion-parameters, 8k context. - public static final String QWEN_7B_CHAT = "qwen-7b-chat"; // Qwen open sourced 7-billion-parameters version - public static final String QWEN_14B_CHAT = "qwen-14b-chat"; // Qwen open sourced 14-billion-parameters version + public static final String QWEN_MAX_LONGCONTEXT = "qwen-max-longcontext"; // Qwen max model, 200-billion-parameters, 30k context. + public static final String QWEN_7B_CHAT = "qwen-7b-chat"; // Qwen open sourced 7-billion-parameters model + public static final String QWEN_14B_CHAT = "qwen-14b-chat"; // Qwen open sourced 14-billion-parameters model + public static final String QWEN_72B_CHAT = "qwen-72b-chat"; // Qwen open sourced 72-billion-parameters model + public static final String QWEN1_5_7B_CHAT = "qwen1.5-7b-chat"; // Qwen open sourced 7-billion-parameters model (v1.5) + public static final String QWEN1_5_14B_CHAT = "qwen1.5-14b-chat"; // Qwen open sourced 14-billion-parameters model (v1.5) + public static final String QWEN1_5_32B_CHAT = "qwen1.5-32b-chat"; // Qwen open sourced 32-billion-parameters model (v1.5) + public static final String QWEN1_5_72B_CHAT = "qwen1.5-72b-chat"; // Qwen open sourced 72-billion-parameters model (v1.5) + public static final String QWEN2_0_5B_INSTRUCT = "qwen2-0.5b-instruct"; // Qwen open sourced 0.5-billion-parameters model (v2) + public static final String QWEN2_1_5B_INSTRUCT = "qwen2-1.5b-instruct"; // Qwen open sourced 1.5-billion-parameters model (v2) + public static final String QWEN2_7B_INSTRUCT = "qwen2-7b-instruct"; // Qwen open sourced 7-billion-parameters model (v2) + public static final String QWEN2_72B_INSTRUCT = "qwen2-72b-instruct"; // Qwen open sourced 72-billion-parameters model (v2) + public static final String QWEN2_57B_A14B_INSTRUCT = "qwen2-57b-a14b-instruct"; // Qwen open sourced 57-billion-parameters and 14-billion-activation-parameters MOE model (v2) public static final String QWEN_VL_PLUS = "qwen-vl-plus"; // Qwen multi-modal model, supports image and text information. public static final String QWEN_VL_MAX = "qwen-vl-max"; // Qwen multi-modal model, offers optimal performance on a wider range of complex tasks. 
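Note on the DashScope changes above: replacing QwenParam with GenerationParam and extending QwenHelper with toToolFunctions(...)/aiMessageFrom(...) is what enables the new generate(messages, toolSpecifications) overloads in QwenChatModel. The following is a rough usage sketch only, not part of this patch; the model name, API-key handling, and tool shape are assumptions borrowed from the integration tests further down.

    import dev.langchain4j.agent.tool.JsonSchemaProperty;
    import dev.langchain4j.agent.tool.ToolExecutionRequest;
    import dev.langchain4j.agent.tool.ToolSpecification;
    import dev.langchain4j.data.message.AiMessage;
    import dev.langchain4j.data.message.UserMessage;
    import dev.langchain4j.model.chat.ChatLanguageModel;
    import dev.langchain4j.model.dashscope.QwenChatModel;
    import dev.langchain4j.model.dashscope.QwenModelName;
    import dev.langchain4j.model.output.Response;

    import static java.util.Collections.singletonList;

    public class QwenToolCallingSketch {

        public static void main(String[] args) {
            // Sketch only: API key read from the environment, model name chosen arbitrarily.
            ChatLanguageModel model = QwenChatModel.builder()
                    .apiKey(System.getenv("DASHSCOPE_API_KEY"))
                    .modelName(QwenModelName.QWEN_MAX)
                    .build();

            // Declare a tool the model may call; name, description and parameter mirror the integration test.
            ToolSpecification weatherTool = ToolSpecification.builder()
                    .name("getCurrentWeather")
                    .description("Query the weather of a specified city")
                    .addParameter("cityName", JsonSchemaProperty.STRING)
                    .build();

            UserMessage question = UserMessage.from("Weather in Beijing?");

            // The new generate(messages, toolSpecifications) overload routes the tools through GenerationParam.
            Response<AiMessage> response = model.generate(singletonList(question), singletonList(weatherTool));

            if (response.content().hasToolExecutionRequests()) {
                ToolExecutionRequest request = response.content().toolExecutionRequests().get(0);
                System.out.println(request.name() + " -> " + request.arguments());
                // Execute the tool, append a ToolExecutionResultMessage, and call generate(...) again for the final answer.
            } else {
                System.out.println(response.content().text());
            }
        }
    }

When the model does decide to call the tool, finishReasonFrom(...) above maps the (possibly null) finish_reason to TOOL_EXECUTION, which is the signal to run the tool and send a ToolExecutionResultMessage back before requesting the final answer, as the QwenChatModelIT tests below demonstrate.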
diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java index 6b60fc597a..05cc8da279 100644 --- a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingChatModel.java @@ -1,8 +1,8 @@ package dev.langchain4j.model.dashscope; import com.alibaba.dashscope.aigc.generation.Generation; +import com.alibaba.dashscope.aigc.generation.GenerationParam; import com.alibaba.dashscope.aigc.generation.GenerationResult; -import com.alibaba.dashscope.aigc.generation.models.QwenParam; import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversation; import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationParam; import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationResult; @@ -21,11 +21,17 @@ import java.util.List; -import static com.alibaba.dashscope.aigc.generation.models.QwenParam.ResultFormat.MESSAGE; +import static com.alibaba.dashscope.aigc.conversation.ConversationParam.ResultFormat.MESSAGE; import static dev.langchain4j.model.dashscope.QwenHelper.toQwenMessages; import static dev.langchain4j.model.dashscope.QwenHelper.toQwenMultiModalMessages; import static dev.langchain4j.spi.ServiceHelper.loadFactories; +/** + * Represents a Qwen language model with a chat completion interface. + * The model's response is streamed token by token and should be handled with {@link StreamingResponseHandler}. + *
+ * More details are available here + */ public class QwenStreamingChatModel implements StreamingChatLanguageModel { private final String apiKey; private final String modelName; @@ -91,7 +97,7 @@ public void generate(List messages, StreamingResponseHandler messages, StreamingResponseHandler handler) { try { - QwenParam.QwenParamBuilder builder = QwenParam.builder() + GenerationParam.GenerationParamBuilder builder = GenerationParam.builder() .apiKey(apiKey) .model(modelName) .topP(topP) diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java index 956f682fce..513f97eae5 100644 --- a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenStreamingLanguageModel.java @@ -1,8 +1,8 @@ package dev.langchain4j.model.dashscope; import com.alibaba.dashscope.aigc.generation.Generation; +import com.alibaba.dashscope.aigc.generation.GenerationParam; import com.alibaba.dashscope.aigc.generation.GenerationResult; -import com.alibaba.dashscope.aigc.generation.models.QwenParam; import com.alibaba.dashscope.common.ResultCallback; import com.alibaba.dashscope.exception.InputRequiredException; import com.alibaba.dashscope.exception.NoApiKeyException; @@ -17,11 +17,17 @@ import java.util.List; -import static com.alibaba.dashscope.aigc.generation.models.QwenParam.ResultFormat.MESSAGE; +import static com.alibaba.dashscope.aigc.generation.GenerationParam.ResultFormat.MESSAGE; import static dev.langchain4j.internal.Utils.isNullOrBlank; import static dev.langchain4j.model.dashscope.QwenModelName.QWEN_PLUS; import static dev.langchain4j.spi.ServiceHelper.loadFactories; +/** + * Represents a Qwen language model with a text interface. + * The model's response is streamed token by token and should be handled with {@link StreamingResponseHandler}. + *
+ * More details are available here. + */ public class QwenStreamingLanguageModel implements StreamingLanguageModel { private final String apiKey; private final String modelName; @@ -73,7 +79,7 @@ public QwenStreamingLanguageModel(String baseUrl, @Override public void generate(String prompt, StreamingResponseHandler handler) { try { - QwenParam.QwenParamBuilder builder = QwenParam.builder() + GenerationParam.GenerationParamBuilder builder = GenerationParam.builder() .apiKey(apiKey) .model(modelName) .topP(topP) diff --git a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenTokenizer.java b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenTokenizer.java index f1a269600f..ab7dd3811f 100644 --- a/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenTokenizer.java +++ b/langchain4j-dashscope/src/main/java/dev/langchain4j/model/dashscope/QwenTokenizer.java @@ -1,6 +1,6 @@ package dev.langchain4j.model.dashscope; -import com.alibaba.dashscope.aigc.generation.models.QwenParam; +import com.alibaba.dashscope.aigc.generation.GenerationParam; import com.alibaba.dashscope.exception.InputRequiredException; import com.alibaba.dashscope.exception.NoApiKeyException; import com.alibaba.dashscope.tokenizers.Tokenization; @@ -12,13 +12,11 @@ import java.util.Collections; -import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.internal.Utils.isNullOrBlank; +import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.model.dashscope.QwenHelper.toQwenMessages; import static dev.langchain4j.model.dashscope.QwenModelName.QWEN_PLUS; public class QwenTokenizer implements Tokenizer { - private final String apiKey; private final String modelName; private final Tokenization tokenizer; @@ -34,15 +32,17 @@ public QwenTokenizer(String apiKey, String modelName) { @Override public int estimateTokenCountInText(String text) { + String prompt = isBlank(text) ? text + "_" : text; try { - QwenParam param = QwenParam.builder() + GenerationParam param = GenerationParam.builder() .apiKey(apiKey) .model(modelName) - .prompt(text) + .prompt(prompt) .build(); TokenizationResult result = tokenizer.call(param); - return result.getUsage().getInputTokens(); + int tokenCount = result.getUsage().getInputTokens(); + return prompt == text ? tokenCount : tokenCount - 1; } catch (NoApiKeyException | InputRequiredException e) { throw new RuntimeException(e); } @@ -55,8 +55,12 @@ public int estimateTokenCountInMessage(ChatMessage message) { @Override public int estimateTokenCountInMessages(Iterable messages) { + if (isNullOrEmpty(messages)) { + return 0; + } + try { - QwenParam param = QwenParam.builder() + GenerationParam param = GenerationParam.builder() .apiKey(apiKey) .model(modelName) .messages(toQwenMessages(messages)) @@ -78,4 +82,14 @@ public int estimateTokenCountInToolSpecifications(Iterable to public int estimateTokenCountInToolExecutionRequests(Iterable toolExecutionRequests) { throw new IllegalArgumentException("Tools are currently not supported by this tokenizer"); } + + public static boolean isBlank(CharSequence cs) { + int strLen = cs == null ? 
0 : cs.length(); + for (int i = 0; i < strLen; ++i) { + if (!Character.isWhitespace(cs.charAt(i))) { + return false; + } + } + return true; + } } diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java index e6fc0163f9..63c599d5fc 100644 --- a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java +++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenChatModelIT.java @@ -1,13 +1,29 @@ package dev.langchain4j.model.dashscope; +import dev.langchain4j.agent.tool.JsonSchemaProperty; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import java.util.List; + +import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; +import static dev.langchain4j.data.message.ToolExecutionResultMessage.from; +import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.model.dashscope.QwenTestHelper.*; +import static dev.langchain4j.model.output.FinishReason.STOP; +import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "DASHSCOPE_API_KEY", matches = ".+") @@ -27,13 +43,179 @@ public void should_send_non_multimodal_messages_and_receive_response(String mode assertThat(response.content().text()).containsIgnoringCase("rain"); } + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#functionCallChatModelNameProvider") + public void should_call_function_with_no_argument_then_answer(String modelName) { + ChatLanguageModel model = QwenChatModel.builder() + .apiKey(apiKey()) + .modelName(modelName) + .build(); + + String toolName = "getCurrentDateAndTime"; + ToolSpecification noArgToolSpec = ToolSpecification.builder() + .name(toolName) + .description("Get the current date and time") + .build(); + + UserMessage userMessage = UserMessage.from("What time is it?"); + + Response response = model.generate(singletonList(userMessage), singletonList(noArgToolSpec)); + + assertThat(response.content().text()).isNull(); + assertThat(response.content().toolExecutionRequests()).hasSize(1); + ToolExecutionRequest toolExecutionRequest = response.content().toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.name()).isEqualTo(toolName); + assertThat(toolExecutionRequest.arguments()).isEqualTo("{}"); + assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION); + + ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "10 o'clock"); + List messages = asList(userMessage, response.content(), toolExecutionResultMessage); + + Response secondResponse = model.generate(messages, singletonList(noArgToolSpec)); + + AiMessage secondAiMessage = 
secondResponse.content(); + assertThat(secondAiMessage.text()).contains("10"); + assertThat(secondAiMessage.toolExecutionRequests()).isNull(); + + TokenUsage secondTokenUsage = secondResponse.tokenUsage(); + assertThat(secondTokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.totalTokenCount()) + .isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount()); + assertThat(secondResponse.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#functionCallChatModelNameProvider") + public void should_call_function_with_argument_then_answer(String modelName) { + ChatLanguageModel model = QwenChatModel.builder() + .apiKey(apiKey()) + .modelName(modelName) + .build(); + + String toolName = "getCurrentWeather"; + ToolSpecification hasArgToolSpec = ToolSpecification.builder() + .name(toolName) + .description("Query the weather of a specified city") + .addParameter("cityName", JsonSchemaProperty.STRING) + .build(); + + UserMessage userMessage = UserMessage.from("Weather in Beijing?"); + + Response response = model.generate(singletonList(userMessage), singletonList(hasArgToolSpec)); + + assertThat(response.content().text()).isNull(); + assertThat(response.content().toolExecutionRequests()).hasSize(1); + ToolExecutionRequest toolExecutionRequest = response.content().toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.name()).isEqualTo(toolName); + assertThat(toolExecutionRequest.arguments()).contains("Beijing"); + assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION); + + ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "rainy"); + List messages = asList(userMessage, response.content(), toolExecutionResultMessage); + + Response secondResponse = model.generate(messages, singletonList(hasArgToolSpec)); + + AiMessage secondAiMessage = secondResponse.content(); + assertThat(secondAiMessage.text()).contains("rainy"); + assertThat(secondAiMessage.toolExecutionRequests()).isNull(); + + TokenUsage secondTokenUsage = secondResponse.tokenUsage(); + assertThat(secondTokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.totalTokenCount()) + .isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount()); + assertThat(secondResponse.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#functionCallChatModelNameProvider") + public void should_call_must_be_executed_call_function(String modelName) { + ChatLanguageModel model = QwenChatModel.builder() + .apiKey(apiKey()) + .modelName(modelName) + .build(); + + String toolName = "getCurrentWeather"; + ToolSpecification mustBeExecutedTool = ToolSpecification.builder() + .name(toolName) + .description("Query the weather of a specified city") + .addParameter("cityName", JsonSchemaProperty.STRING) + .build(); + + // not related to tools + UserMessage userMessage = UserMessage.from("How many students in the classroom?"); + + Response response = model.generate(singletonList(userMessage), mustBeExecutedTool); + + assertThat(response.content().text()).isNull(); + assertThat(response.content().toolExecutionRequests()).hasSize(1); + ToolExecutionRequest toolExecutionRequest = response.content().toolExecutionRequests().get(0); + 
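Illustrative sketch (editorial, not part of the patch) of the tool-calling round trip these tests exercise, assuming DASHSCOPE_API_KEY is set; the types and static imports are the ones already listed at the top of QwenChatModelIT:

ChatLanguageModel model = QwenChatModel.builder()
        .apiKey(System.getenv("DASHSCOPE_API_KEY"))
        .modelName(QwenModelName.QWEN_MAX)
        .build();

ToolSpecification weatherTool = ToolSpecification.builder()
        .name("getCurrentWeather")
        .description("Query the weather of a specified city")
        .addParameter("cityName", JsonSchemaProperty.STRING)
        .build();

UserMessage question = UserMessage.from("Weather in Beijing?");

// 1) the model responds with a tool execution request instead of a textual answer
Response<AiMessage> first = model.generate(singletonList(question), singletonList(weatherTool));
ToolExecutionRequest toolCall = first.content().toolExecutionRequests().get(0);

// 2) run the tool yourself, send the result back, and receive the final answer
ToolExecutionResultMessage toolResult = ToolExecutionResultMessage.from(toolCall, "rainy");
List<ChatMessage> messages = asList(question, first.content(), toolResult);
Response<AiMessage> second = model.generate(messages, singletonList(weatherTool));
System.out.println(second.content().text());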
assertThat(toolExecutionRequest.name()).isEqualTo(toolName); + assertThat(toolExecutionRequest.arguments()).hasSizeGreaterThan(0); + assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION); + } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#functionCallChatModelNameProvider") + void should_call_must_be_executed_call_function_with_argument_then_answer(String modelName) { + ChatLanguageModel model = QwenChatModel.builder() + .apiKey(apiKey()) + .modelName(modelName) + .build(); + + String toolName = "calculator"; + ToolSpecification calculator = ToolSpecification.builder() + .name(toolName) + .description("returns a sum of two numbers") + .addParameter("first", INTEGER) + .addParameter("second", INTEGER) + .build(); + + UserMessage userMessage = userMessage("2+2=?"); + + Response response = model.generate(singletonList(userMessage), calculator); + + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).isNull(); + assertThat(aiMessage.toolExecutionRequests()).hasSize(1); + + ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.id()).isNotNull(); + assertThat(toolExecutionRequest.name()).isEqualTo(toolName); + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION); + + ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4"); + List messages = asList(userMessage, aiMessage, toolExecutionResultMessage); + + Response secondResponse = model.generate(messages, singletonList(calculator)); + + AiMessage secondAiMessage = secondResponse.content(); + assertThat(secondAiMessage.text()).contains("4"); + assertThat(secondAiMessage.toolExecutionRequests()).isNull(); + + TokenUsage secondTokenUsage = secondResponse.tokenUsage(); + assertThat(secondTokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.totalTokenCount()) + .isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount()); + assertThat(secondResponse.finishReason()).isEqualTo(STOP); + } + @ParameterizedTest @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#multimodalChatModelNameProvider") public void should_send_multimodal_image_url_and_receive_response(String modelName) { ChatLanguageModel model = QwenChatModel.builder() .apiKey(apiKey()) .modelName(modelName) - .build();; + .build(); Response response = model.generate(multimodalChatMessagesWithImageUrl()); System.out.println(response); @@ -47,7 +229,7 @@ public void should_send_multimodal_image_data_and_receive_response(String modelN ChatLanguageModel model = QwenChatModel.builder() .apiKey(apiKey()) .modelName(modelName) - .build();; + .build(); Response response = model.generate(multimodalChatMessagesWithImageData()); System.out.println(response); diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java index c3e67a47b0..9b06c4f9df 100644 --- 
a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java +++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenEmbeddingModelIT.java @@ -7,9 +7,14 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static dev.langchain4j.data.segment.TextSegment.textSegment; +import static dev.langchain4j.model.dashscope.QwenEmbeddingModel.TYPE_KEY; +import static dev.langchain4j.model.dashscope.QwenEmbeddingModel.TYPE_QUERY; import static dev.langchain4j.model.dashscope.QwenTestHelper.apiKey; import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; @@ -52,8 +57,8 @@ void should_embed_documents(String modelName) { void should_embed_queries(String modelName) { EmbeddingModel model = getModel(modelName); List embeddings = model.embedAll(asList( - textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)), - textSegment("how are you?", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)) + textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)), + textSegment("how are you?", Metadata.from(TYPE_KEY, TYPE_QUERY)) )).content(); assertThat(embeddings).hasSize(2); @@ -66,7 +71,7 @@ void should_embed_queries(String modelName) { void should_embed_mix_segments(String modelName) { EmbeddingModel model = getModel(modelName); List embeddings = model.embedAll(asList( - textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)), + textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)), textSegment("how are you?") )).content(); @@ -74,4 +79,39 @@ void should_embed_mix_segments(String modelName) { assertThat(embeddings.get(0).vector()).isNotEmpty(); assertThat(embeddings.get(1).vector()).isNotEmpty(); } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider") + void should_embed_large_amounts_of_documents(String modelName) { + EmbeddingModel model = getModel(modelName); + List embeddings = model.embedAll( + Collections.nCopies(50, textSegment("hello"))).content(); + + assertThat(embeddings).hasSize(50); + } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider") + void should_embed_large_amounts_of_queries(String modelName) { + EmbeddingModel model = getModel(modelName); + List embeddings = model.embedAll( + Collections.nCopies(50, textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY))) + ).content(); + + assertThat(embeddings).hasSize(50); + } + + @ParameterizedTest + @MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider") + void should_embed_large_amounts_of_mix_segments(String modelName) { + EmbeddingModel model = getModel(modelName); + List embeddings = model.embedAll( + Stream.concat( + Collections.nCopies(50, textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY))).stream(), + Collections.nCopies(50, textSegment("how are you?")).stream() + ).collect(Collectors.toList()) + ).content(); + + assertThat(embeddings).hasSize(100); + } } diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java index cfa01e159d..fa2955c8fd 100644 --- 
a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java +++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTestHelper.java @@ -20,14 +20,48 @@ public class QwenTestHelper { public static Stream languageModelNameProvider() { return Stream.of( Arguments.of(QwenModelName.QWEN_TURBO), - Arguments.of(QwenModelName.QWEN_PLUS) + Arguments.of(QwenModelName.QWEN_PLUS), + Arguments.of(QwenModelName.QWEN_MAX), + Arguments.of(QwenModelName.QWEN_MAX_LONGCONTEXT), + Arguments.of(QwenModelName.QWEN_7B_CHAT), + Arguments.of(QwenModelName.QWEN_14B_CHAT), + Arguments.of(QwenModelName.QWEN_72B_CHAT), + Arguments.of(QwenModelName.QWEN1_5_7B_CHAT), + Arguments.of(QwenModelName.QWEN1_5_14B_CHAT), + Arguments.of(QwenModelName.QWEN1_5_32B_CHAT), + Arguments.of(QwenModelName.QWEN1_5_72B_CHAT), + Arguments.of(QwenModelName.QWEN2_0_5B_INSTRUCT), + Arguments.of(QwenModelName.QWEN2_1_5B_INSTRUCT), + Arguments.of(QwenModelName.QWEN2_7B_INSTRUCT), + Arguments.of(QwenModelName.QWEN2_72B_INSTRUCT), + Arguments.of(QwenModelName.QWEN2_57B_A14B_INSTRUCT) ); } public static Stream nonMultimodalChatModelNameProvider() { return Stream.of( Arguments.of(QwenModelName.QWEN_TURBO), - Arguments.of(QwenModelName.QWEN_PLUS) + Arguments.of(QwenModelName.QWEN_PLUS), + Arguments.of(QwenModelName.QWEN_MAX), + Arguments.of(QwenModelName.QWEN_MAX_LONGCONTEXT), + Arguments.of(QwenModelName.QWEN_7B_CHAT), + Arguments.of(QwenModelName.QWEN_14B_CHAT), + Arguments.of(QwenModelName.QWEN_72B_CHAT), + Arguments.of(QwenModelName.QWEN1_5_7B_CHAT), + Arguments.of(QwenModelName.QWEN1_5_14B_CHAT), + Arguments.of(QwenModelName.QWEN1_5_32B_CHAT), + Arguments.of(QwenModelName.QWEN1_5_72B_CHAT), + Arguments.of(QwenModelName.QWEN2_0_5B_INSTRUCT), + Arguments.of(QwenModelName.QWEN2_1_5B_INSTRUCT), + Arguments.of(QwenModelName.QWEN2_7B_INSTRUCT), + Arguments.of(QwenModelName.QWEN2_72B_INSTRUCT), + Arguments.of(QwenModelName.QWEN2_57B_A14B_INSTRUCT) + ); + } + + public static Stream functionCallChatModelNameProvider() { + return Stream.of( + Arguments.of(QwenModelName.QWEN_MAX) ); } diff --git a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTokenizerIT.java b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTokenizerIT.java index 1be4781015..df3f8d6abf 100644 --- a/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTokenizerIT.java +++ b/langchain4j-dashscope/src/test/java/dev/langchain4j/model/dashscope/QwenTokenizerIT.java @@ -10,6 +10,7 @@ import org.junit.jupiter.params.provider.MethodSource; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.stream.Stream; @@ -17,6 +18,7 @@ import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.model.dashscope.QwenTestHelper.apiKey; import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; @@ -53,6 +55,11 @@ void should_count_tokens_in_short_texts() { assertThat(tokenizer.estimateTokenCountInText("Hello")).isEqualTo(1); assertThat(tokenizer.estimateTokenCountInText("Hello!")).isEqualTo(2); assertThat(tokenizer.estimateTokenCountInText("Hello, how are you?")).isEqualTo(6); + + assertThat(tokenizer.estimateTokenCountInText("")).isEqualTo(0); + assertThat(tokenizer.estimateTokenCountInText("\n")).isEqualTo(1); + 
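Illustrative sketch (editorial, not part of the patch) of how the tokenizer behind these assertions is used, with the same static imports as QwenTokenizerIT and DASHSCOPE_API_KEY assumed to be set. The blank-string cases work because QwenTokenizer now appends "_" to blank input before calling DashScope and subtracts one token from the result, and it returns 0 for null/empty message lists without making a remote call:

Tokenizer tokenizer = new QwenTokenizer(System.getenv("DASHSCOPE_API_KEY"), QwenModelName.QWEN_PLUS);

int promptTokens = tokenizer.estimateTokenCountInText("Hello, how are you?");  // 6, per the assertion above
int blankTokens  = tokenizer.estimateTokenCountInText("");                     // 0, blank input no longer fails
int emptyTokens  = tokenizer.estimateTokenCountInMessages(emptyList());        // 0, no API call is made
int chatTokens   = tokenizer.estimateTokenCountInMessages(
        asList(userMessage("Hello"), userMessage("How are you?")));            // counted via the Tokenization API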
assertThat(tokenizer.estimateTokenCountInText("\n\n")).isEqualTo(1); + assertThat(tokenizer.estimateTokenCountInText("\n \n\n")).isEqualTo(2); } @Test @@ -79,6 +86,12 @@ void should_count_tokens_in_large_text() { assertThat(tokenizer.estimateTokenCountInText(text3)).isEqualTo(100 * 15); } + @Test + void should_count_empty_messages_and_return_0() { + assertThat(tokenizer.estimateTokenCountInMessages(null)).isEqualTo(0); + assertThat(tokenizer.estimateTokenCountInMessages(emptyList())).isEqualTo(0); + } + public static List repeat(String s, int n) { List result = new ArrayList<>(); for (int i = 0; i < n; i++) { diff --git a/langchain4j-easy-rag/pom.xml b/langchain4j-easy-rag/pom.xml index 90e09c99fe..e739fdbfae 100644 --- a/langchain4j-easy-rag/pom.xml +++ b/langchain4j-easy-rag/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-elasticsearch/pom.xml b/langchain4j-elasticsearch/pom.xml index 164d31dea9..ae82d0c36e 100644 --- a/langchain4j-elasticsearch/pom.xml +++ b/langchain4j-elasticsearch/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -80,6 +80,7 @@ elasticsearch test
+ org.testcontainers junit-jupiter diff --git a/langchain4j-elasticsearch/src/main/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStore.java b/langchain4j-elasticsearch/src/main/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStore.java index a4017c61ca..dc2686ca05 100644 --- a/langchain4j-elasticsearch/src/main/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStore.java +++ b/langchain4j-elasticsearch/src/main/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStore.java @@ -1,6 +1,8 @@ package dev.langchain4j.store.embedding.elasticsearch; import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch._types.BulkIndexByScrollFailure; +import co.elastic.clients.elasticsearch._types.ErrorCause; import co.elastic.clients.elasticsearch._types.InlineScript; import co.elastic.clients.elasticsearch._types.mapping.DenseVectorProperty; import co.elastic.clients.elasticsearch._types.mapping.Property; @@ -10,6 +12,7 @@ import co.elastic.clients.elasticsearch._types.query_dsl.ScriptScoreQuery; import co.elastic.clients.elasticsearch.core.BulkRequest; import co.elastic.clients.elasticsearch.core.BulkResponse; +import co.elastic.clients.elasticsearch.core.DeleteByQueryResponse; import co.elastic.clients.elasticsearch.core.SearchResponse; import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; import co.elastic.clients.json.JsonData; @@ -41,10 +44,7 @@ import org.slf4j.LoggerFactory; import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; +import java.util.*; import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.internal.ValidationUtils.*; @@ -263,12 +263,29 @@ public EmbeddingSearchResult search(EmbeddingSearchRequest embeddin return new EmbeddingSearchResult<>(toMatches(response)); } catch (IOException e) { - // TODO improve - log.error("[ElasticSearch encounter I/O Exception]", e); throw new ElasticsearchRequestFailedException(e.getMessage()); } } + @Override + public void removeAll(Collection ids) { + ensureNotEmpty(ids, "ids"); + removeByIds(ids); + } + + @Override + public void removeAll(Filter filter) { + ensureNotNull(filter, "filter"); + Query query = ElasticsearchMetadataFilterMapper.map(filter); + removeByQuery(query); + } + + @Override + public void removeAll() { + Query query = Query.of(q -> q.matchAll(m -> m)); + removeByQuery(query); + } + private ScriptScoreQuery buildScriptScoreQuery(float[] vector, float minScore, Filter filter @@ -314,9 +331,8 @@ private void addAllInternal(List ids, List embeddings, List c.properties(properties)); } - private void bulk(List ids, List embeddings, List embedded) throws IOException { + private void bulkIndex(List ids, List embeddings, List embedded) throws IOException { int size = ids.size(); BulkRequest.Builder bulkBuilder = new BulkRequest.Builder(); for (int i = 0; i < size; i++) { @@ -358,15 +374,57 @@ private void bulk(List ids, List embeddings, List delete + .index(indexName) + .query(query)); + if (!response.failures().isEmpty()) { + for (BulkIndexByScrollFailure item : response.failures()) { + throwIfError(item.cause()); } } + } catch (IOException e) { + throw new ElasticsearchRequestFailedException(e.getMessage()); } } + private void removeByIds(Collection ids) { + try { + bulkRemove(ids); + } catch (IOException e) { + throw new ElasticsearchRequestFailedException(e.getMessage()); + } + } + + private 
void bulkRemove(Collection ids) throws IOException { + BulkRequest.Builder bulkBuilder = new BulkRequest.Builder(); + for (String id : ids) { + bulkBuilder.operations(op -> op.delete(dlt -> dlt + .index(indexName) + .id(id))); + } + BulkResponse response = client.bulk(bulkBuilder.build()); + handleBulkResponseErrors(response); + } + private List> toMatches(SearchResponse response) { return response.hits().hits().stream() .map(hit -> Optional.ofNullable(hit.source()) diff --git a/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreRemoveIT.java b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreRemoveIT.java new file mode 100644 index 0000000000..a5aa9ae47b --- /dev/null +++ b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreRemoveIT.java @@ -0,0 +1,210 @@ +package dev.langchain4j.store.embedding.elasticsearch; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.filter.Filter; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.testcontainers.elasticsearch.ElasticsearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.shaded.org.awaitility.Awaitility; +import org.testcontainers.shaded.org.awaitility.core.ThrowingRunnable; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; + +import static dev.langchain4j.internal.Utils.randomUUID; +import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@Testcontainers +class ElasticsearchEmbeddingStoreRemoveIT { + + @Container + private static final ElasticsearchContainer elasticsearch = + new ElasticsearchContainer("docker.elastic.co/elasticsearch/elasticsearch:8.9.0") + .withEnv("xpack.security.enabled", "false"); + + EmbeddingStore embeddingStore = ElasticsearchEmbeddingStore.builder() + .serverUrl(elasticsearch.getHttpHostAddress()) + .indexName(randomUUID()) + .dimension(384) + .build(); + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @BeforeEach + void beforeEach() { + embeddingStore.removeAll(); + EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddingModel.embed("empty").content()) + .build(); + awaitAssertion(() -> assertThat(embeddingStore.search(request).matches()).hasSize(0)); + } + + @Test + void remove_all() { + // given + Embedding embedding = embeddingModel.embed("hello").content(); + Embedding embedding2 = embeddingModel.embed("hello2").content(); + Embedding embedding3 = embeddingModel.embed("hello3").content(); + embeddingStore.add(embedding); + embeddingStore.add(embedding2); + embeddingStore.add(embedding3); + + EmbeddingSearchRequest request = 
EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(10) + .build(); + awaitAssertion(() -> assertThat(embeddingStore.search(request).matches()).hasSize(3)); + + // when + embeddingStore.removeAll(); + + // then + awaitAssertion(() -> assertThat(embeddingStore.search(request).matches()).hasSize(0)); + } + + @Test + void remove_by_id() { + // given + Embedding embedding = embeddingModel.embed("hello").content(); + Embedding embedding2 = embeddingModel.embed("hello2").content(); + Embedding embedding3 = embeddingModel.embed("hello3").content(); + + String id = embeddingStore.add(embedding); + String id2 = embeddingStore.add(embedding2); + String id3 = embeddingStore.add(embedding3); + + EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(10) + .build(); + awaitAssertion(() -> assertThat(embeddingStore.search(request).matches()).hasSize(3)); + + // when + embeddingStore.remove(id); + awaitAssertion(() -> assertThat(embeddingStore.search(request).matches()).hasSize(2)); + + // then + List> matches = embeddingStore.search(request).matches(); + List matchingIds = matches.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList()); + assertThat(matchingIds).containsExactly(id2, id3); + } + + @Test + void remove_all_by_ids() { + // given + Embedding embedding = embeddingModel.embed("hello").content(); + Embedding embedding2 = embeddingModel.embed("hello2").content(); + Embedding embedding3 = embeddingModel.embed("hello3").content(); + + String id = embeddingStore.add(embedding); + String id2 = embeddingStore.add(embedding2); + String id3 = embeddingStore.add(embedding3); + + EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(10) + .build(); + awaitAssertion(() -> assertThat(embeddingStore.search(request).matches()).hasSize(3)); + + // when + embeddingStore.removeAll(Arrays.asList(id2, id3)); + awaitAssertion(() -> assertThat(embeddingStore.search(request).matches()).hasSize(1)); + + // then + List> matches = embeddingStore.search(request).matches(); + List matchingIds = matches.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList()); + assertThat(matchingIds).containsExactly(id); + } + + @Test + void remove_all_by_ids_null() { + assertThatThrownBy(() -> embeddingStore.removeAll((Collection) null)) + .isExactlyInstanceOf(IllegalArgumentException.class) + .hasMessage("ids cannot be null or empty"); + } + + @Test + void remove_all_by_filter() { + // given + Metadata metadata = Metadata.metadata("id", "1"); + TextSegment segment = TextSegment.from("matching", metadata); + Embedding embedding = embeddingModel.embed(segment).content(); + embeddingStore.add(embedding, segment); + + EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(10) + .build(); + awaitAssertion(() -> assertThat(embeddingStore.search(request).matches()).hasSize(1)); + + Embedding embedding2 = embeddingModel.embed("hello2").content(); + Embedding embedding3 = embeddingModel.embed("hello3").content(); + + String id2 = embeddingStore.add(embedding2); + String id3 = embeddingStore.add(embedding3); + + awaitAssertion(() -> assertThat(embeddingStore.search(request).matches()).hasSize(3)); + + // when + embeddingStore.removeAll(metadataKey("id").isEqualTo("1")); + awaitAssertion(() -> assertThat(embeddingStore.search(request).matches()).hasSize(2)); + + // then + List> matches = embeddingStore.search(request).matches(); + 
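Illustrative sketch (editorial, not part of the patch) of the three removal flavours these tests cover, assuming a reachable local Elasticsearch and the same imports as ElasticsearchEmbeddingStoreRemoveIT; the server URL and index name are placeholders:

EmbeddingStore<TextSegment> store = ElasticsearchEmbeddingStore.builder()
        .serverUrl("http://localhost:9200")  // placeholder
        .indexName("my-index")               // placeholder
        .dimension(384)
        .build();

EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
String id = store.add(embeddingModel.embed("hello").content());

store.removeAll(Arrays.asList(id));                       // bulk delete by ids
store.removeAll(metadataKey("userId").isEqualTo("42"));   // delete-by-query driven by a metadata filter
store.removeAll();                                        // match_all delete-by-query, wipes the index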
List matchingIds = matches.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList()); + assertThat(matchingIds).hasSize(2); + assertThat(matchingIds).containsExactly(id2, id3); + } + + @Test + void remove_all_by_filter_not_matching() { + // given + Embedding embedding = embeddingModel.embed("hello").content(); + Embedding embedding2 = embeddingModel.embed("hello2").content(); + Embedding embedding3 = embeddingModel.embed("hello3").content(); + + embeddingStore.add(embedding); + embeddingStore.add(embedding2); + embeddingStore.add(embedding3); + EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(10) + .build(); + awaitAssertion(() -> assertThat(embeddingStore.search(request).matches()).hasSize(3)); + + // when + embeddingStore.removeAll(metadataKey("unknown").isEqualTo("1")); + + // then + List> matches = embeddingStore.search(request).matches(); + List matchingIds = matches.stream().map(EmbeddingMatch::embeddingId).collect(Collectors.toList()); + assertThat(matchingIds).hasSize(3); + } + + @Test + void remove_all_by_filter_null() { + assertThatThrownBy(() -> embeddingStore.removeAll((Filter) null)) + .isExactlyInstanceOf(IllegalArgumentException.class) + .hasMessage("filter cannot be null"); + } + + private static void awaitAssertion(ThrowingRunnable assertionRunnable) { + Awaitility.await().pollInterval(Duration.ofSeconds(1)) + .atMost(Duration.ofSeconds(5)) + .untilAsserted(assertionRunnable); + } +} diff --git a/langchain4j-hugging-face/pom.xml b/langchain4j-hugging-face/pom.xml index 87aeedba24..eadf4a9979 100644 --- a/langchain4j-hugging-face/pom.xml +++ b/langchain4j-hugging-face/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-infinispan/pom.xml b/langchain4j-infinispan/pom.xml index fbaa0dbcf6..f33f4b4c45 100644 --- a/langchain4j-infinispan/pom.xml +++ b/langchain4j-infinispan/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-jina/pom.xml b/langchain4j-jina/pom.xml new file mode 100644 index 0000000000..dc7a37c44e --- /dev/null +++ b/langchain4j-jina/pom.xml @@ -0,0 +1,69 @@ + + 4.0.0 + + dev.langchain4j + langchain4j-parent + 0.32.0-SNAPSHOT + ../langchain4j-parent/pom.xml + + + langchain4j-jina + LangChain4j :: Integration :: Jina + + + + + dev.langchain4j + langchain4j-core + + + + com.squareup.retrofit2 + retrofit + + + + com.squareup.retrofit2 + converter-jackson + + + + org.projectlombok + lombok + provided + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.assertj + assertj-core + test + + + + org.tinylog + tinylog-impl + test + + + + org.tinylog + slf4j-tinylog + test + + + + + \ No newline at end of file diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/JinaEmbeddingModel.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/JinaEmbeddingModel.java new file mode 100644 index 0000000000..de668cedee --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/JinaEmbeddingModel.java @@ -0,0 +1,78 @@ +package dev.langchain4j.model.jina; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.jina.internal.api.JinaEmbeddingRequest; +import 
dev.langchain4j.model.jina.internal.api.JinaEmbeddingResponse; +import dev.langchain4j.model.jina.internal.client.JinaClient; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import lombok.Builder; + +import java.time.Duration; +import java.util.List; + +import static dev.langchain4j.internal.RetryUtils.withRetry; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static java.time.Duration.ofSeconds; +import static java.util.stream.Collectors.toList; + +/** + * An implementation of an {@link EmbeddingModel} that uses + * Jina Embeddings API. + */ +public class JinaEmbeddingModel implements EmbeddingModel { + + private static final String DEFAULT_BASE_URL = "https://api.jina.ai/"; + private static final String DEFAULT_MODEL = "jina-embeddings-v2-base-en"; + + private final JinaClient client; + private final String modelName; + private final Integer maxRetries; + + @Builder + public JinaEmbeddingModel(String baseUrl, + String apiKey, + String modelName, + Duration timeout, + Integer maxRetries, + Boolean logRequests, + Boolean logResponses) { + this.client = JinaClient.builder() + .baseUrl(getOrDefault(baseUrl, DEFAULT_BASE_URL)) + .apiKey(apiKey) + .timeout(getOrDefault(timeout, ofSeconds(60))) + .logRequests(getOrDefault(logRequests, false)) + .logResponses(getOrDefault(logResponses, false)) + .build(); + this.modelName = getOrDefault(modelName, DEFAULT_MODEL); + this.maxRetries = getOrDefault(maxRetries, 3); + } + + public static JinaEmbeddingModel withApiKey(String apiKey) { + return JinaEmbeddingModel.builder().apiKey(apiKey).build(); + } + + @Override + public Response> embedAll(List textSegments) { + + JinaEmbeddingRequest request = JinaEmbeddingRequest.builder() + .model(modelName) + .input(textSegments.stream().map(TextSegment::text).collect(toList())) + .build(); + + JinaEmbeddingResponse response = withRetry(() -> client.embed(request), maxRetries); + + List embeddings = response.data.stream() + .map(jinaEmbedding -> Embedding.from(jinaEmbedding.embedding)) + .collect(toList()); + + TokenUsage tokenUsage = new TokenUsage( + response.usage.promptTokens, + 0, + response.usage.totalTokens + ); + return Response.from(embeddings, tokenUsage); + } +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/JinaScoringModel.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/JinaScoringModel.java new file mode 100644 index 0000000000..55e5b4f088 --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/JinaScoringModel.java @@ -0,0 +1,84 @@ +package dev.langchain4j.model.jina; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.jina.internal.api.JinaRerankingRequest; +import dev.langchain4j.model.jina.internal.api.JinaRerankingResponse; +import dev.langchain4j.model.jina.internal.client.JinaClient; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.model.scoring.ScoringModel; +import lombok.Builder; + +import java.time.Duration; +import java.util.List; + +import static dev.langchain4j.internal.RetryUtils.withRetry; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static java.time.Duration.ofSeconds; +import static java.util.Comparator.comparingInt; +import static java.util.stream.Collectors.toList; + +/** + * An implementation of a {@link ScoringModel} that uses + * Jina Reranker API. 
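Illustrative sketch (editorial, not part of the patch): embedding text with the JinaEmbeddingModel introduced above, assuming JINA_API_KEY is set; the model name shown is the module's default:

EmbeddingModel embeddingModel = JinaEmbeddingModel.builder()
        .apiKey(System.getenv("JINA_API_KEY"))
        .modelName("jina-embeddings-v2-base-en")
        .build();

Response<List<Embedding>> response =
        embeddingModel.embedAll(asList(TextSegment.from("hello"), TextSegment.from("hi")));
System.out.println(response.content().get(0).dimension());  // 768 for this model, per the IT further down
System.out.println(response.tokenUsage());                   // prompt/total token counts reported by Jina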
+ */ +public class JinaScoringModel implements ScoringModel { + + private static final String DEFAULT_BASE_URL = "https://api.jina.ai/v1/"; + private static final String DEFAULT_MODEL = "jina-reranker-v1-base-en"; + + private final JinaClient client; + private final String modelName; + private final Integer maxRetries; + + @Builder + public JinaScoringModel(String baseUrl, + String apiKey, + String modelName, + Duration timeout, + Integer maxRetries, + Boolean logRequests, + Boolean logResponses) { + this.client = JinaClient.builder() + .baseUrl(getOrDefault(baseUrl, DEFAULT_BASE_URL)) + .apiKey(ensureNotBlank(apiKey, "apiKey")) + .timeout(getOrDefault(timeout, ofSeconds(60))) + .logRequests(getOrDefault(logRequests, false)) + .logResponses(getOrDefault(logResponses, false)) + .build(); + this.modelName = getOrDefault(modelName, DEFAULT_MODEL); + this.maxRetries = getOrDefault(maxRetries, 3); + } + + public static JinaScoringModel withApiKey(String apiKey) { + return JinaScoringModel.builder().apiKey(apiKey).build(); + } + + @Override + public Response> scoreAll(List segments, String query) { + + JinaRerankingRequest request = JinaRerankingRequest.builder() + .model(modelName) + .query(query) + .documents(segments.stream() + .map(TextSegment::text) + .collect(toList())) + .returnDocuments(false) // decreasing response size, do not include text in response + .build(); + + JinaRerankingResponse response = withRetry(() -> client.rerank(request), maxRetries); + + List scores = response.results.stream() + .sorted(comparingInt(result -> result.index)) + .map(result -> result.relevanceScore) + .collect(toList()); + + TokenUsage tokenUsage = new TokenUsage( + response.usage.promptTokens, + 0, + response.usage.totalTokens + ); + return Response.from(scores, tokenUsage); + } +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaApi.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaApi.java new file mode 100644 index 0000000000..c0c64f7c0c --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaApi.java @@ -0,0 +1,20 @@ +package dev.langchain4j.model.jina.internal.api; + +import retrofit2.Call; +import retrofit2.http.Body; +import retrofit2.http.Header; +import retrofit2.http.Headers; +import retrofit2.http.POST; + +public interface JinaApi { + + @POST("v1/embeddings") + @Headers({"Content-Type: application/json"}) + Call embed(@Body JinaEmbeddingRequest request, + @Header("Authorization") String authorizationHeader); + + @POST("rerank") + @Headers({"Content-Type: application/json"}) + Call rerank(@Body JinaRerankingRequest request, + @Header("Authorization") String authorizationHeader); +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaDocument.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaDocument.java new file mode 100644 index 0000000000..5107fe4ffa --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaDocument.java @@ -0,0 +1,16 @@ +package dev.langchain4j.model.jina.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) 
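Illustrative sketch (editorial, not part of the patch): reranking with the JinaScoringModel defined above, assuming JINA_API_KEY is set; jina-reranker-v1-base-en is the default model declared in the class:

ScoringModel scoringModel = JinaScoringModel.builder()
        .apiKey(System.getenv("JINA_API_KEY"))
        .build();

List<TextSegment> candidates = asList(
        TextSegment.from("maine coon"),
        TextSegment.from("labrador retriever"));

Response<List<Double>> scores = scoringModel.scoreAll(candidates, "tell me about dogs");
// scores come back in input order (results are re-sorted by index), and returnDocuments(false)
// keeps the response small because the candidate texts are not echoed back
System.out.println(scores.content());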
+@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public class JinaDocument { + + public String text; +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaEmbedding.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaEmbedding.java new file mode 100644 index 0000000000..f88ba09fbb --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaEmbedding.java @@ -0,0 +1,18 @@ +package dev.langchain4j.model.jina.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public class JinaEmbedding { + + public long index; + public float[] embedding; + public String object; +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaEmbeddingRequest.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaEmbeddingRequest.java new file mode 100644 index 0000000000..fd1486534b --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaEmbeddingRequest.java @@ -0,0 +1,21 @@ +package dev.langchain4j.model.jina.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Builder; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public class JinaEmbeddingRequest { + + public String model; + public List input; +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaEmbeddingResponse.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaEmbeddingResponse.java new file mode 100644 index 0000000000..a5d93a38d8 --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaEmbeddingResponse.java @@ -0,0 +1,20 @@ +package dev.langchain4j.model.jina.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public class JinaEmbeddingResponse { + + public String model; + public List data; + public JinaUsage usage; +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaRerankingRequest.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaRerankingRequest.java new file mode 100644 index 0000000000..eecefe3f46 --- /dev/null +++ 
b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaRerankingRequest.java @@ -0,0 +1,23 @@ +package dev.langchain4j.model.jina.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Builder; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public class JinaRerankingRequest { + + public String model; + public String query; + public List documents; + public Boolean returnDocuments; +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaRerankingResponse.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaRerankingResponse.java new file mode 100644 index 0000000000..7367ee3c57 --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaRerankingResponse.java @@ -0,0 +1,20 @@ +package dev.langchain4j.model.jina.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public class JinaRerankingResponse { + + public String model; + public List results; + public JinaUsage usage; +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaRerankingResult.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaRerankingResult.java new file mode 100644 index 0000000000..4ce0d1e213 --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaRerankingResult.java @@ -0,0 +1,18 @@ +package dev.langchain4j.model.jina.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public class JinaRerankingResult { + + public Integer index; + public JinaDocument document; + public Double relevanceScore; +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaUsage.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaUsage.java new file mode 100644 index 0000000000..1d17b216b2 --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/api/JinaUsage.java @@ -0,0 +1,17 @@ +package dev.langchain4j.model.jina.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import 
com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public class JinaUsage { + + public Integer promptTokens; + public Integer totalTokens; +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/client/JinaClient.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/client/JinaClient.java new file mode 100644 index 0000000000..e47c952800 --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/client/JinaClient.java @@ -0,0 +1,84 @@ +package dev.langchain4j.model.jina.internal.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import dev.langchain4j.model.jina.internal.api.*; +import lombok.Builder; +import okhttp3.OkHttpClient; +import retrofit2.Retrofit; +import retrofit2.converter.jackson.JacksonConverterFactory; + +import java.io.IOException; +import java.time.Duration; + +import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; + +public class JinaClient { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().enable(INDENT_OUTPUT); + + private final JinaApi jinaApi; + private final String authorizationHeader; + + @Builder + JinaClient(String baseUrl, String apiKey, Duration timeout, boolean logRequests, boolean logResponses) { + + OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder() + .callTimeout(timeout) + .connectTimeout(timeout) + .readTimeout(timeout) + .writeTimeout(timeout); + + if (logRequests) { + okHttpClientBuilder.addInterceptor(new RequestLoggingInterceptor()); + } + if (logResponses) { + okHttpClientBuilder.addInterceptor(new ResponseLoggingInterceptor()); + } + + Retrofit retrofit = new Retrofit.Builder() + .baseUrl(baseUrl) + .client(okHttpClientBuilder.build()) + .addConverterFactory(JacksonConverterFactory.create(OBJECT_MAPPER)) + .build(); + + this.jinaApi = retrofit.create(JinaApi.class); + this.authorizationHeader = "Bearer " + ensureNotBlank(apiKey, "apiKey"); + } + + public JinaEmbeddingResponse embed(JinaEmbeddingRequest request) { + try { + retrofit2.Response retrofitResponse + = jinaApi.embed(request, authorizationHeader).execute(); + if (retrofitResponse.isSuccessful()) { + return retrofitResponse.body(); + } else { + throw toException(retrofitResponse); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public JinaRerankingResponse rerank(JinaRerankingRequest request) { + try { + retrofit2.Response retrofitResponse + = jinaApi.rerank(request, authorizationHeader).execute(); + + if (retrofitResponse.isSuccessful()) { + return retrofitResponse.body(); + } else { + throw toException(retrofitResponse); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static RuntimeException toException(retrofit2.Response response) throws IOException { + int code = response.code(); + String body = response.errorBody().string(); + String errorMessage = String.format("status code: %s; body: %s", code, body); + return new RuntimeException(errorMessage); + } +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/client/RequestLoggingInterceptor.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/client/RequestLoggingInterceptor.java new file 
mode 100644 index 0000000000..a5c9ff8b7d --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/client/RequestLoggingInterceptor.java @@ -0,0 +1,82 @@ +package dev.langchain4j.model.jina.internal.client; + +import okhttp3.Headers; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import okio.Buffer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static java.util.stream.StreamSupport.stream; + +class RequestLoggingInterceptor implements Interceptor { + + private static final Logger log = LoggerFactory.getLogger(RequestLoggingInterceptor.class); + + private static final Pattern BEARER_PATTERN = Pattern.compile("(Bearer\\s)(\\w{2})(\\w+)(\\w{2})"); + + public Response intercept(Interceptor.Chain chain) throws IOException { + Request request = chain.request(); + log(request); + return chain.proceed(request); + } + + private void log(Request request) { + log.debug( + "Request:\n" + + "- method: {}\n" + + "- url: {}\n" + + "- headers: {}\n" + + "- body: {}", + request.method(), + request.url(), + inOneLine(request.headers()), + getBody(request) + ); + } + + static String inOneLine(Headers headers) { + return stream(headers.spliterator(), false) + .map((header) -> { + String headerKey = header.component1(); + String headerValue = header.component2(); + if (headerKey.equals("Authorization")) { + headerValue = maskAuthorizationHeaderValue(headerValue); + } + return String.format("[%s: %s]", headerKey, headerValue); + }).collect(Collectors.joining(", ")); + } + + private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) { + try { + Matcher matcher = BEARER_PATTERN.matcher(authorizationHeaderValue); + StringBuffer sb = new StringBuffer(); + + while (matcher.find()) { + matcher.appendReplacement(sb, matcher.group(1) + matcher.group(2) + "..." + matcher.group(4)); + } + + matcher.appendTail(sb); + return sb.toString(); + } catch (Exception e) { + return "[failed to mask the API key]"; + } + } + + private static String getBody(Request request) { + try { + Buffer buffer = new Buffer(); + request.body().writeTo(buffer); + return buffer.readUtf8(); + } catch (Exception e) { + log.warn("Exception happened while reading request body", e); + return "[Exception happened while reading request body. 
Check logs for more details.]"; + } + } +} diff --git a/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/client/ResponseLoggingInterceptor.java b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/client/ResponseLoggingInterceptor.java new file mode 100644 index 0000000000..c410b66f2c --- /dev/null +++ b/langchain4j-jina/src/main/java/dev/langchain4j/model/jina/internal/client/ResponseLoggingInterceptor.java @@ -0,0 +1,42 @@ +package dev.langchain4j.model.jina.internal.client; + +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +class ResponseLoggingInterceptor implements Interceptor { + + private static final Logger log = LoggerFactory.getLogger(ResponseLoggingInterceptor.class); + + public Response intercept(Interceptor.Chain chain) throws IOException { + Request request = chain.request(); + Response response = chain.proceed(request); + log(response); + return response; + } + + void log(Response response) { + log.debug( + "Response:\n" + + "- status code: {}\n" + + "- headers: {}\n" + + "- body: {}", + response.code(), + RequestLoggingInterceptor.inOneLine(response.headers()), + getBody(response) + ); + } + + private String getBody(Response response) { + try { + return response.peekBody(Long.MAX_VALUE).string(); + } catch (IOException e) { + log.warn("Failed to log response", e); + return "[failed to log response]"; + } + } +} diff --git a/langchain4j-jina/src/test/java/dev/langchain4j/model/jina/JinaEmbeddingModelIT.java b/langchain4j-jina/src/test/java/dev/langchain4j/model/jina/JinaEmbeddingModelIT.java new file mode 100644 index 0000000000..a7f8c4228b --- /dev/null +++ b/langchain4j-jina/src/test/java/dev/langchain4j/model/jina/JinaEmbeddingModelIT.java @@ -0,0 +1,76 @@ +package dev.langchain4j.model.jina; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.store.embedding.CosineSimilarity; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static java.time.Duration.ofSeconds; +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; + +public class JinaEmbeddingModelIT { + + @Test + public void should_embed_single_text() { + + // given + EmbeddingModel model = JinaEmbeddingModel.withApiKey(System.getenv("JINA_API_KEY")); + + String text = "hello"; + + // when + Response response = model.embed(text); + + // then + assertThat(response.content().dimension()).isEqualTo(768); + + assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(3); + assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(0); + assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(3); + + assertThat(response.finishReason()).isNull(); + } + + @Test + public void should_embed_multiple_segments() { + + // given + EmbeddingModel model = JinaEmbeddingModel.builder() + .baseUrl("https://api.jina.ai/") + .apiKey(System.getenv("JINA_API_KEY")) + .modelName("jina-embeddings-v2-base-en") + .timeout(ofSeconds(10)) + .maxRetries(2) + .logRequests(true) + .logResponses(true) + .build(); + + TextSegment segment1 = TextSegment.from("hello"); + TextSegment segment2 = TextSegment.from("hi"); + + // when + Response> response = model.embedAll(asList(segment1, segment2)); + + // then + 
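Editorial note with a sketch (not part of the patch): the logRequests/logResponses switches used in the integration test above install the two interceptors defined earlier, which log method, URL, headers and body at DEBUG level via SLF4J; RequestLoggingInterceptor masks the Authorization header so only the first and last two characters of the key appear (e.g. "Bearer ab...yz"). Assuming JINA_API_KEY is set:

EmbeddingModel model = JinaEmbeddingModel.builder()
        .apiKey(System.getenv("JINA_API_KEY"))
        .logRequests(true)
        .logResponses(true)
        .build();

model.embed("hello");  // request and response details show up in the DEBUG logs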
assertThat(response.content()).hasSize(2); + + Embedding embedding1 = response.content().get(0); + assertThat(embedding1.dimension()).isEqualTo(768); + + Embedding embedding2 = response.content().get(1); + assertThat(embedding2.dimension()).isEqualTo(768); + + assertThat(CosineSimilarity.between(embedding1, embedding2)).isGreaterThan(0.9); + + assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(6); + assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(0); + assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(6); + + assertThat(response.finishReason()).isNull(); + } +} diff --git a/langchain4j-jina/src/test/java/dev/langchain4j/model/jina/JinaScoringModelIT.java b/langchain4j-jina/src/test/java/dev/langchain4j/model/jina/JinaScoringModelIT.java new file mode 100644 index 0000000000..a8e3967eab --- /dev/null +++ b/langchain4j-jina/src/test/java/dev/langchain4j/model/jina/JinaScoringModelIT.java @@ -0,0 +1,69 @@ +package dev.langchain4j.model.jina; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.scoring.ScoringModel; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.List; + +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; + +class JinaScoringModelIT { + + @Test + void should_score_single_text() { + + // given + ScoringModel model = JinaScoringModel.withApiKey(System.getenv("JINA_API_KEY")); + + String text = "labrador retriever"; + String query = "tell me about dogs"; + + // when + Response response = model.score(text, query); + + // then + assertThat(response.content()).isCloseTo(0.25, withPercentage(1)); + + assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(12); + + assertThat(response.finishReason()).isNull(); + } + + @Test + void should_score_multiple_segments_with_all_parameters() { + + // given + ScoringModel model = JinaScoringModel.builder() + .baseUrl("https://api.jina.ai/v1/") + .apiKey(System.getenv("JINA_API_KEY")) + .modelName("jina-reranker-v1-turbo-en") + .timeout(Duration.ofSeconds(10)) + .maxRetries(2) + .logRequests(true) + .logResponses(true) + .build(); + + TextSegment catSegment = TextSegment.from("maine coon"); + TextSegment dogSegment = TextSegment.from("labrador retriever"); + List segments = asList(catSegment, dogSegment); + + String query = "tell me about dogs"; + + // when + Response> response = model.scoreAll(segments, query); + + // then + List scores = response.content(); + assertThat(scores).hasSize(2); + assertThat(scores.get(0)).isLessThan(scores.get(1)); + + assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(16); + + assertThat(response.finishReason()).isNull(); + } +} \ No newline at end of file diff --git a/langchain4j-local-ai/pom.xml b/langchain4j-local-ai/pom.xml index 1b665d0e86..e7672fea2b 100644 --- a/langchain4j-local-ai/pom.xml +++ b/langchain4j-local-ai/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-milvus/pom.xml b/langchain4j-milvus/pom.xml index 8c1a6bef24..77e39c7f8f 100644 --- a/langchain4j-milvus/pom.xml +++ b/langchain4j-milvus/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -28,19 +28,6 @@ io.milvus milvus-sdk-java - 2.3.4 - - - - io.netty - netty-codec - - - - - io.netty - netty-codec - 
${netty.version} diff --git a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/CollectionOperationsExecutor.java b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/CollectionOperationsExecutor.java index bbff24c75a..e5cfa382ab 100644 --- a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/CollectionOperationsExecutor.java +++ b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/CollectionOperationsExecutor.java @@ -18,7 +18,6 @@ import io.milvus.response.QueryResultsWrapper; import io.milvus.response.SearchResultsWrapper; -import java.lang.String; import java.util.List; import static dev.langchain4j.store.embedding.milvus.CollectionRequestBuilder.*; @@ -45,27 +44,30 @@ static void createCollection(MilvusServiceClient milvusClient, String collection CreateCollectionParam request = CreateCollectionParam.newBuilder() .withCollectionName(collectionName) - .addFieldType(FieldType.newBuilder() - .withName(ID_FIELD_NAME) - .withDataType(VarChar) - .withMaxLength(36) - .withPrimaryKey(true) - .withAutoID(false) - .build()) - .addFieldType(FieldType.newBuilder() - .withName(TEXT_FIELD_NAME) - .withDataType(VarChar) - .withMaxLength(65535) - .build()) - .addFieldType(FieldType.newBuilder() - .withName(METADATA_FIELD_NAME) - .withDataType(JSON) - .build()) - .addFieldType(FieldType.newBuilder() - .withName(VECTOR_FIELD_NAME) - .withDataType(FloatVector) - .withDimension(dimension) - .build()) + .withSchema(CollectionSchemaParam.newBuilder() + .addFieldType(FieldType.newBuilder() + .withName(ID_FIELD_NAME) + .withDataType(VarChar) + .withMaxLength(36) + .withPrimaryKey(true) + .withAutoID(false) + .build()) + .addFieldType(FieldType.newBuilder() + .withName(TEXT_FIELD_NAME) + .withDataType(VarChar) + .withMaxLength(65535) + .build()) + .addFieldType(FieldType.newBuilder() + .withName(METADATA_FIELD_NAME) + .withDataType(JSON) + .build()) + .addFieldType(FieldType.newBuilder() + .withName(VECTOR_FIELD_NAME) + .withDataType(FloatVector) + .withDimension(dimension) + .build()) + .build() + ) .build(); R response = milvusClient.createCollection(request); @@ -124,6 +126,13 @@ static QueryResultsWrapper queryForVectors(MilvusServiceClient milvusClient, return new QueryResultsWrapper(response.getData()); } + static void removeForVector(MilvusServiceClient milvusClient, + String collectionName, + String expr) { + R response = milvusClient.delete(buildDeleteRequest(collectionName, expr)); + checkResponseNotFailed(response); + } + private static void checkResponseNotFailed(R response) { if (response == null) { throw new RequestToMilvusFailedException("Request to Milvus DB failed. 
Response is null"); diff --git a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/CollectionRequestBuilder.java b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/CollectionRequestBuilder.java index 3edb3ac19f..02ac9c2f3b 100644 --- a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/CollectionRequestBuilder.java +++ b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/CollectionRequestBuilder.java @@ -7,6 +7,7 @@ import io.milvus.param.collection.FlushParam; import io.milvus.param.collection.HasCollectionParam; import io.milvus.param.collection.LoadCollectionParam; +import io.milvus.param.dml.DeleteParam; import io.milvus.param.dml.InsertParam; import io.milvus.param.dml.QueryParam; import io.milvus.param.dml.SearchParam; @@ -85,6 +86,14 @@ static QueryParam buildQueryRequest(String collectionName, .build(); } + static DeleteParam buildDeleteRequest(String collectionName, + String expr) { + return DeleteParam.newBuilder() + .withCollectionName(collectionName) + .withExpr(expr) + .build(); + } + private static String buildQueryExpression(List rowIds) { return rowIds.stream() .map(id -> format("%s == '%s'", ID_FIELD_NAME, id)) diff --git a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStore.java b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStore.java index d035f28a96..39ac2cf1e1 100644 --- a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStore.java +++ b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStore.java @@ -1,25 +1,13 @@ package dev.langchain4j.store.embedding.milvus; -import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; -import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.*; -import static dev.langchain4j.store.embedding.milvus.CollectionRequestBuilder.buildSearchRequest; -import static dev.langchain4j.store.embedding.milvus.Generator.generateRandomIds; -import static dev.langchain4j.store.embedding.milvus.Mapper.*; -import static io.milvus.common.clientenum.ConsistencyLevelEnum.EVENTUALLY; -import static io.milvus.param.IndexType.FLAT; -import static io.milvus.param.MetricType.COSINE; -import static java.util.Collections.singletonList; -import static java.util.stream.Collectors.toList; - import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.internal.Utils; import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.filter.Filter; import io.milvus.client.MilvusServiceClient; import io.milvus.common.clientenum.ConsistencyLevelEnum; @@ -29,9 +17,27 @@ import io.milvus.param.dml.InsertParam; import io.milvus.param.dml.SearchParam; import io.milvus.response.SearchResultsWrapper; + import java.util.ArrayList; +import java.util.Collection; import java.util.List; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static 
dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.*; +import static dev.langchain4j.store.embedding.milvus.CollectionRequestBuilder.buildSearchRequest; +import static dev.langchain4j.store.embedding.milvus.Generator.generateRandomIds; +import static dev.langchain4j.store.embedding.milvus.Mapper.*; +import static dev.langchain4j.store.embedding.milvus.MilvusMetadataFilterMapper.formatValues; +import static dev.langchain4j.store.embedding.milvus.MilvusMetadataFilterMapper.map; +import static io.milvus.common.clientenum.ConsistencyLevelEnum.EVENTUALLY; +import static io.milvus.param.IndexType.FLAT; +import static io.milvus.param.MetricType.COSINE; +import static java.lang.String.format; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; + /** * Represents an Milvus index as an embedding store. *
@@ -42,305 +48,389 @@ */ public class MilvusEmbeddingStore implements EmbeddingStore { - static final String ID_FIELD_NAME = "id"; - static final String TEXT_FIELD_NAME = "text"; - static final String METADATA_FIELD_NAME = "metadata"; - static final String VECTOR_FIELD_NAME = "vector"; - - private final MilvusServiceClient milvusClient; - private final String collectionName; - private final MetricType metricType; - private final ConsistencyLevelEnum consistencyLevel; - private final boolean retrieveEmbeddingsOnSearch; - - public MilvusEmbeddingStore( - String host, - Integer port, - String collectionName, - Integer dimension, - IndexType indexType, - MetricType metricType, - String uri, - String token, - String username, - String password, - ConsistencyLevelEnum consistencyLevel, - Boolean retrieveEmbeddingsOnSearch, - String databaseName - ) { - ConnectParam.Builder connectBuilder = ConnectParam - .newBuilder() - .withHost(getOrDefault(host, "localhost")) - .withPort(getOrDefault(port, 19530)) - .withUri(uri) - .withToken(token) - .withAuthorization(username, password); - - if (databaseName != null) { - connectBuilder.withDatabaseName(databaseName); + static final String ID_FIELD_NAME = "id"; + static final String TEXT_FIELD_NAME = "text"; + static final String METADATA_FIELD_NAME = "metadata"; + static final String VECTOR_FIELD_NAME = "vector"; + + private final MilvusServiceClient milvusClient; + private final String collectionName; + private final MetricType metricType; + private final ConsistencyLevelEnum consistencyLevel; + private final boolean retrieveEmbeddingsOnSearch; + private final boolean autoFlushOnInsert; + + public MilvusEmbeddingStore( + String host, + Integer port, + String collectionName, + Integer dimension, + IndexType indexType, + MetricType metricType, + String uri, + String token, + String username, + String password, + ConsistencyLevelEnum consistencyLevel, + Boolean retrieveEmbeddingsOnSearch, + Boolean autoFlushOnInsert, + String databaseName + ) { + ConnectParam.Builder connectBuilder = ConnectParam + .newBuilder() + .withHost(getOrDefault(host, "localhost")) + .withPort(getOrDefault(port, 19530)) + .withUri(uri) + .withToken(token) + .withAuthorization(getOrDefault(username, ""), getOrDefault(password, "")); + + if (databaseName != null) { + connectBuilder.withDatabaseName(databaseName); + } + + this.milvusClient = new MilvusServiceClient(connectBuilder.build()); + this.collectionName = getOrDefault(collectionName, "default"); + this.metricType = getOrDefault(metricType, COSINE); + this.consistencyLevel = getOrDefault(consistencyLevel, EVENTUALLY); + this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false); + this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false); + + if (!hasCollection(this.milvusClient, this.collectionName)) { + createCollection(this.milvusClient, this.collectionName, ensureNotNull(dimension, "dimension")); + createIndex(this.milvusClient, this.collectionName, getOrDefault(indexType, FLAT), this.metricType); + } + + loadCollectionInMemory(this.milvusClient, collectionName); } - this.milvusClient = new MilvusServiceClient(connectBuilder.build()); - this.collectionName = getOrDefault(collectionName, "default"); - this.metricType = getOrDefault(metricType, COSINE); - this.consistencyLevel = getOrDefault(consistencyLevel, EVENTUALLY); - this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false); - - if (!hasCollection(milvusClient, this.collectionName)) { - createCollection(milvusClient, 
this.collectionName, ensureNotNull(dimension, "dimension")); - createIndex(milvusClient, this.collectionName, getOrDefault(indexType, FLAT), this.metricType); + public static Builder builder() { + return new Builder(); } - loadCollectionInMemory(milvusClient, collectionName); - } - - public void dropCollection(String collectionName) { - CollectionOperationsExecutor.dropCollection(milvusClient, collectionName); - } - - public String add(Embedding embedding) { - String id = Utils.randomUUID(); - add(id, embedding); - return id; - } - - public void add(String id, Embedding embedding) { - addInternal(id, embedding, null); - } - - public String add(Embedding embedding, TextSegment textSegment) { - String id = Utils.randomUUID(); - addInternal(id, embedding, textSegment); - return id; - } - - public List addAll(List embeddings) { - List ids = generateRandomIds(embeddings.size()); - addAllInternal(ids, embeddings, null); - return ids; - } - - public List addAll(List embeddings, List embedded) { - List ids = generateRandomIds(embeddings.size()); - addAllInternal(ids, embeddings, embedded); - return ids; - } - - @Override - public EmbeddingSearchResult search(EmbeddingSearchRequest embeddingSearchRequest) { - - SearchParam searchParam = buildSearchRequest( - collectionName, - embeddingSearchRequest.queryEmbedding().vectorAsList(), - embeddingSearchRequest.filter(), - embeddingSearchRequest.maxResults(), - metricType, - consistencyLevel - ); - - SearchResultsWrapper resultsWrapper = CollectionOperationsExecutor.search(milvusClient, searchParam); - - List> matches = toEmbeddingMatches( - milvusClient, - resultsWrapper, - collectionName, - consistencyLevel, - retrieveEmbeddingsOnSearch - ); - - List> result = matches.stream() - .filter(match -> match.score() >= embeddingSearchRequest.minScore()) - .collect(toList()); - - return new EmbeddingSearchResult<>(result); - } - - private void addInternal(String id, Embedding embedding, TextSegment textSegment) { - addAllInternal( - singletonList(id), - singletonList(embedding), - textSegment == null ? null : singletonList(textSegment) - ); - } - - private void addAllInternal(List ids, List embeddings, List textSegments) { - List fields = new ArrayList<>(); - fields.add(new InsertParam.Field(ID_FIELD_NAME, ids)); - fields.add(new InsertParam.Field(TEXT_FIELD_NAME, toScalars(textSegments, ids.size()))); - fields.add(new InsertParam.Field(METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size()))); - fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings))); - - insert(milvusClient, collectionName, fields); - flush(milvusClient, collectionName); - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - private String host; - private Integer port; - private String collectionName; - private Integer dimension; - private IndexType indexType; - private MetricType metricType; - private String uri; - private String token; - private String username; - private String password; - private ConsistencyLevelEnum consistencyLevel; - private Boolean retrieveEmbeddingsOnSearch; - private String databaseName; - - /** - * @param host The host of the self-managed Milvus instance. - * Default value: "localhost". - * @return builder - */ - public Builder host(String host) { - this.host = host; - return this; + public void dropCollection(String collectionName) { + CollectionOperationsExecutor.dropCollection(this.milvusClient, collectionName); } - /** - * @param port The port of the self-managed Milvus instance. 
- * Default value: 19530. - * @return builder - */ - public Builder port(Integer port) { - this.port = port; - return this; + public String add(Embedding embedding) { + String id = Utils.randomUUID(); + add(id, embedding); + return id; } - /** - * @param collectionName The name of the Milvus collection. - * If there is no such collection yet, it will be created automatically. - * Default value: "default". - * @return builder - */ - public Builder collectionName(String collectionName) { - this.collectionName = collectionName; - return this; + public void add(String id, Embedding embedding) { + addInternal(id, embedding, null); } - /** - * @param dimension The dimension of the embedding vector. (e.g. 384) - * Mandatory if a new collection should be created. - * @return builder - */ - public Builder dimension(Integer dimension) { - this.dimension = dimension; - return this; + public String add(Embedding embedding, TextSegment textSegment) { + String id = Utils.randomUUID(); + addInternal(id, embedding, textSegment); + return id; } - /** - * @param indexType The type of the index. - * Default value: FLAT. - * @return builder - */ - public Builder indexType(IndexType indexType) { - this.indexType = indexType; - return this; + public List addAll(List embeddings) { + List ids = generateRandomIds(embeddings.size()); + addAllInternal(ids, embeddings, null); + return ids; } - /** - * @param metricType The type of the metric used for similarity search. - * Default value: COSINE. - * @return builder - */ - public Builder metricType(MetricType metricType) { - this.metricType = metricType; - return this; + public List addAll(List embeddings, List embedded) { + List ids = generateRandomIds(embeddings.size()); + addAllInternal(ids, embeddings, embedded); + return ids; } - /** - * @param uri The URI of the managed Milvus instance. (e.g. "https://xxx.api.gcp-us-west1.zillizcloud.com") - * @return builder - */ - public Builder uri(String uri) { - this.uri = uri; - return this; + @Override + public EmbeddingSearchResult search(EmbeddingSearchRequest embeddingSearchRequest) { + + SearchParam searchParam = buildSearchRequest( + collectionName, + embeddingSearchRequest.queryEmbedding().vectorAsList(), + embeddingSearchRequest.filter(), + embeddingSearchRequest.maxResults(), + metricType, + consistencyLevel + ); + + SearchResultsWrapper resultsWrapper = CollectionOperationsExecutor.search(milvusClient, searchParam); + + List> matches = toEmbeddingMatches( + milvusClient, + resultsWrapper, + collectionName, + consistencyLevel, + retrieveEmbeddingsOnSearch + ); + + List> result = matches.stream() + .filter(match -> match.score() >= embeddingSearchRequest.minScore()) + .collect(toList()); + + return new EmbeddingSearchResult<>(result); } - /** - * @param token The token (API key) of the managed Milvus instance. - * @return builder - */ - public Builder token(String token) { - this.token = token; - return this; + private void addInternal(String id, Embedding embedding, TextSegment textSegment) { + addAllInternal( + singletonList(id), + singletonList(embedding), + textSegment == null ? null : singletonList(textSegment) + ); } - /** - * @param username The username. See details here. 
- * @return builder - */ - public Builder username(String username) { - this.username = username; - return this; + private void addAllInternal(List ids, List embeddings, List textSegments) { + List fields = new ArrayList<>(); + fields.add(new InsertParam.Field(ID_FIELD_NAME, ids)); + fields.add(new InsertParam.Field(TEXT_FIELD_NAME, toScalars(textSegments, ids.size()))); + fields.add(new InsertParam.Field(METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size()))); + fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings))); + + insert(this.milvusClient, this.collectionName, fields); + if (autoFlushOnInsert) { + flush(this.milvusClient, this.collectionName); + } } /** - * @param password The password. See details here. - * @return builder + * Removes a single embedding from the store by ID. + *

+ * <b>CAUTION</b>
+ * <ul>
+ *     <li>Deleted entities can still be retrieved immediately after the deletion if the consistency level is set lower than {@code Strong}.</li>
+ *     <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be retrieved again.</li>
+ *     <li>Frequent deletion operations will impact the system performance.</li>
+ *     <li>Before deleting entities by complex boolean expressions, make sure the collection has been loaded.</li>
+ *     <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore, if it fails halfway through, some data may still be deleted.</li>
+ *     <li>Deleting entities by complex boolean expressions is supported only when the consistency level is set to {@code Bounded}. For details, see Consistency.</li>
+ * </ul>
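+ * <p>
+ * A minimal usage sketch (illustrative only; the builder values and the {@code embedding} variable are assumptions, not part of this class):
+ * <pre>{@code
+ * MilvusEmbeddingStore store = MilvusEmbeddingStore.builder()
+ *         .host("localhost")
+ *         .port(19530)
+ *         .collectionName("example_collection")
+ *         .dimension(384)
+ *         .build();
+ *
+ * String id = store.add(embedding); // embedding produced by an EmbeddingModel elsewhere
+ * store.removeAll(java.util.Collections.singletonList(id)); // delete that single entry by ID
+ * }</pre>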
+ * + * @param ids A collection of unique IDs of the embeddings to be removed. + * @since Milvus version 2.3.x */ - public Builder password(String password) { - this.password = password; - return this; + @Override + public void removeAll(Collection ids) { + ensureNotEmpty(ids, "ids"); + removeForVector(this.milvusClient, this.collectionName, format("%s in %s", ID_FIELD_NAME, formatValues(ids))); } - /** - * @param consistencyLevel The consistency level used by Milvus. - * Default value: EVENTUALLY. - * @return builder - */ - public Builder consistencyLevel(ConsistencyLevelEnum consistencyLevel) { - this.consistencyLevel = consistencyLevel; - return this; - } /** - * @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling findRelevant()), - * the embedding itself is not retrieved. - * To retrieve the embedding, an additional query is required. - * Setting this parameter to "true" will ensure that embedding is retrieved. - * Be aware that this will impact the performance of the search. - * Default value: false. - * @return builder + * Removes all embeddings that match the specified {@link Filter} from the store. + *

+ * <b>CAUTION</b>
+ * <ul>
+ *     <li>Deleted entities can still be retrieved immediately after the deletion if the consistency level is set lower than {@code Strong}.</li>
+ *     <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be retrieved again.</li>
+ *     <li>Frequent deletion operations will impact the system performance.</li>
+ *     <li>Before deleting entities by complex boolean expressions, make sure the collection has been loaded.</li>
+ *     <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore, if it fails halfway through, some data may still be deleted.</li>
+ *     <li>Deleting entities by complex boolean expressions is supported only when the consistency level is set to {@code Bounded}. For details, see Consistency.</li>
+ * </ul>
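+ * <p>
+ * A minimal, illustrative sketch, assuming an existing {@code store} instance and the
+ * {@code metadataKey(..)} builder from {@code dev.langchain4j.store.embedding.filter.MetadataFilterBuilder} (langchain4j-core):
+ * <pre>{@code
+ * // removes every embedding whose TextSegment metadata contains userId == "42"
+ * store.removeAll(metadataKey("userId").isEqualTo("42"));
+ * }</pre>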
+ * + * @param filter The filter to be applied to the {@link Metadata} of the {@link TextSegment} during removal. + * Only embeddings whose {@code TextSegment}'s {@code Metadata} + * match the {@code Filter} will be removed. + * @since Milvus version 2.3.x */ - public Builder retrieveEmbeddingsOnSearch(Boolean retrieveEmbeddingsOnSearch) { - this.retrieveEmbeddingsOnSearch = retrieveEmbeddingsOnSearch; - return this; + @Override + public void removeAll(Filter filter) { + ensureNotNull(filter, "filter"); + removeForVector(this.milvusClient, this.collectionName, map(filter)); } /** - * @param databaseName Milvus name of database. - * Default value: null. In this case default Milvus database name will be used. - * @return builder + * Removes all embeddings from the store. + *

+ * <b>CAUTION</b>
+ * <ul>
+ *     <li>Deleted entities can still be retrieved immediately after the deletion if the consistency level is set lower than {@code Strong}.</li>
+ *     <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be retrieved again.</li>
+ *     <li>Frequent deletion operations will impact the system performance.</li>
+ *     <li>Before deleting entities by complex boolean expressions, make sure the collection has been loaded.</li>
+ *     <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore, if it fails halfway through, some data may still be deleted.</li>
+ *     <li>Deleting entities by complex boolean expressions is supported only when the consistency level is set to {@code Bounded}. For details, see Consistency.</li>
+ * </ul>
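+ * <p>
+ * Illustrative sketch, assuming an existing {@code store} instance:
+ * <pre>{@code
+ * store.removeAll(); // wipes every embedding in the configured collection
+ * }</pre>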
+ * + * @since Milvus version 2.3.x */ - public Builder databaseName(String databaseName) { - this.databaseName = databaseName; - return this; + @Override + public void removeAll() { + removeForVector(this.milvusClient, this.collectionName, format("%s != \"\"", ID_FIELD_NAME)); } - public MilvusEmbeddingStore build() { - return new MilvusEmbeddingStore( - host, - port, - collectionName, - dimension, - indexType, - metricType, - uri, - token, - username, - password, - consistencyLevel, - retrieveEmbeddingsOnSearch, - databaseName - ); + public static class Builder { + + private String host; + private Integer port; + private String collectionName; + private Integer dimension; + private IndexType indexType; + private MetricType metricType; + private String uri; + private String token; + private String username; + private String password; + private ConsistencyLevelEnum consistencyLevel; + private Boolean retrieveEmbeddingsOnSearch; + private String databaseName; + private Boolean autoFlushOnInsert; + + /** + * @param host The host of the self-managed Milvus instance. + * Default value: "localhost". + * @return builder + */ + public Builder host(String host) { + this.host = host; + return this; + } + + /** + * @param port The port of the self-managed Milvus instance. + * Default value: 19530. + * @return builder + */ + public Builder port(Integer port) { + this.port = port; + return this; + } + + /** + * @param collectionName The name of the Milvus collection. + * If there is no such collection yet, it will be created automatically. + * Default value: "default". + * @return builder + */ + public Builder collectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + /** + * @param dimension The dimension of the embedding vector. (e.g. 384) + * Mandatory if a new collection should be created. + * @return builder + */ + public Builder dimension(Integer dimension) { + this.dimension = dimension; + return this; + } + + /** + * @param indexType The type of the index. + * Default value: FLAT. + * @return builder + */ + public Builder indexType(IndexType indexType) { + this.indexType = indexType; + return this; + } + + /** + * @param metricType The type of the metric used for similarity search. + * Default value: COSINE. + * @return builder + */ + public Builder metricType(MetricType metricType) { + this.metricType = metricType; + return this; + } + + /** + * @param uri The URI of the managed Milvus instance. (e.g. "https://xxx.api.gcp-us-west1.zillizcloud.com") + * @return builder + */ + public Builder uri(String uri) { + this.uri = uri; + return this; + } + + /** + * @param token The token (API key) of the managed Milvus instance. + * @return builder + */ + public Builder token(String token) { + this.token = token; + return this; + } + + /** + * @param username The username. See details here. + * @return builder + */ + public Builder username(String username) { + this.username = username; + return this; + } + + /** + * @param password The password. See details here. + * @return builder + */ + public Builder password(String password) { + this.password = password; + return this; + } + + /** + * @param consistencyLevel The consistency level used by Milvus. + * Default value: EVENTUALLY. 
+ * @return builder + */ + public Builder consistencyLevel(ConsistencyLevelEnum consistencyLevel) { + this.consistencyLevel = consistencyLevel; + return this; + } + + /** + * @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling findRelevant()), + * the embedding itself is not retrieved. + * To retrieve the embedding, an additional query is required. + * Setting this parameter to "true" will ensure that embedding is retrieved. + * Be aware that this will impact the performance of the search. + * Default value: false. + * @return builder + */ + public Builder retrieveEmbeddingsOnSearch(Boolean retrieveEmbeddingsOnSearch) { + this.retrieveEmbeddingsOnSearch = retrieveEmbeddingsOnSearch; + return this; + } + + /** + * @param autoFlushOnInsert Whether to automatically flush after each insert + * ({@code add(...)} or {@code addAll(...)} methods). + * Default value: false. + * More info can be found + * here. + * @return builder + */ + public Builder autoFlushOnInsert(Boolean autoFlushOnInsert) { + this.autoFlushOnInsert = autoFlushOnInsert; + return this; + } + + /** + * @param databaseName Milvus name of database. + * Default value: null. In this case default Milvus database name will be used. + * @return builder + */ + public Builder databaseName(String databaseName) { + this.databaseName = databaseName; + return this; + } + + public MilvusEmbeddingStore build() { + return new MilvusEmbeddingStore( + host, + port, + collectionName, + dimension, + indexType, + metricType, + uri, + token, + username, + password, + consistencyLevel, + retrieveEmbeddingsOnSearch, + autoFlushOnInsert, + databaseName + ); + } } - } } diff --git a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusMetadataFilterMapper.java b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusMetadataFilterMapper.java index 2c563dacba..6b178aa7a9 100644 --- a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusMetadataFilterMapper.java +++ b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusMetadataFilterMapper.java @@ -66,11 +66,11 @@ private static String mapLessThanOrEqual(IsLessThanOrEqualTo isLessThanOrEqualTo return format("%s <= %s", formatKey(isLessThanOrEqualTo.key()), formatValue(isLessThanOrEqualTo.comparisonValue())); } - public static String mapIn(IsIn isIn) { + private static String mapIn(IsIn isIn) { return format("%s in %s", formatKey(isIn.key()), formatValues(isIn.comparisonValues())); } - public static String mapNotIn(IsNotIn isNotIn) { + private static String mapNotIn(IsNotIn isNotIn) { return format("%s not in %s", formatKey(isNotIn.key()), formatValues(isNotIn.comparisonValues())); } @@ -90,7 +90,7 @@ private static String formatKey(String key) { return "metadata[\"" + key + "\"]"; } - private static String formatValue(Object value) { + static String formatValue(Object value) { if (value instanceof String) { return "\"" + value + "\""; } else { @@ -98,7 +98,7 @@ private static String formatValue(Object value) { } } - private static List formatValues(Collection values) { + static List formatValues(Collection values) { return values.stream().map(value -> { if (value instanceof String) { return "\"" + value + "\""; diff --git a/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreCloudIT.java b/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreCloudIT.java index 813112d4c3..c190c673a6 100644 --- 
a/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreCloudIT.java +++ b/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreCloudIT.java @@ -5,6 +5,7 @@ import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; import org.junit.jupiter.api.AfterEach; @@ -12,6 +13,7 @@ import java.util.List; +import static io.milvus.common.clientenum.ConsistencyLevelEnum.STRONG; import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; @@ -22,7 +24,10 @@ class MilvusEmbeddingStoreCloudIT extends EmbeddingStoreWithFilteringIT { MilvusEmbeddingStore embeddingStore = MilvusEmbeddingStore.builder() .uri(System.getenv("MILVUS_URI")) .token(System.getenv("MILVUS_API_KEY")) + .username(System.getenv("MILVUS_USERNAME")) + .password(System.getenv("MILVUS_PASSWORD")) .collectionName(COLLECTION_NAME) + .consistencyLevel(STRONG) .dimension(384) .retrieveEmbeddingsOnSearch(true) .build(); @@ -52,7 +57,8 @@ void should_not_retrieve_embeddings_when_searching() { EmbeddingStore embeddingStore = MilvusEmbeddingStore.builder() .uri(System.getenv("MILVUS_URI")) .token(System.getenv("MILVUS_API_KEY")) - .collectionName("test") + .collectionName(COLLECTION_NAME) + .consistencyLevel(STRONG) .dimension(384) .retrieveEmbeddingsOnSearch(retrieveEmbeddingsOnSearch) .build(); @@ -61,9 +67,12 @@ void should_not_retrieve_embeddings_when_searching() { Embedding secondEmbedding = embeddingModel.embed("hi").content(); embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - assertThat(relevant.get(0).embedding()).isNull(); - assertThat(relevant.get(1).embedding()).isNull(); + List> matches = embeddingStore.search(EmbeddingSearchRequest.builder() + .queryEmbedding(firstEmbedding) + .maxResults(10) + .build()).matches(); + assertThat(matches).hasSize(2); + assertThat(matches.get(0).embedding()).isNull(); + assertThat(matches.get(1).embedding()).isNull(); } } \ No newline at end of file diff --git a/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreIT.java b/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreIT.java index e6619050a4..3b9cd6c387 100644 --- a/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreIT.java +++ b/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreIT.java @@ -5,6 +5,7 @@ import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; import org.junit.jupiter.api.AfterEach; @@ -15,6 +16,7 @@ import java.util.List; +import static io.milvus.common.clientenum.ConsistencyLevelEnum.STRONG; import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; @@ -24,11 +26,14 @@ class 
MilvusEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT { private static final String COLLECTION_NAME = "test_collection"; @Container - private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.3.1"); + private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.3.16"); MilvusEmbeddingStore embeddingStore = MilvusEmbeddingStore.builder() .uri(milvus.getEndpoint()) .collectionName(COLLECTION_NAME) + .consistencyLevel(STRONG) + .username(System.getenv("MILVUS_USERNAME")) + .password(System.getenv("MILVUS_PASSWORD")) .dimension(384) .retrieveEmbeddingsOnSearch(true) .build(); @@ -57,6 +62,7 @@ void should_not_retrieve_embeddings_when_searching() { .host(milvus.getHost()) .port(milvus.getMappedPort(19530)) .collectionName(COLLECTION_NAME) + .consistencyLevel(STRONG) .dimension(384) .retrieveEmbeddingsOnSearch(false) .build(); @@ -65,9 +71,12 @@ void should_not_retrieve_embeddings_when_searching() { Embedding secondEmbedding = embeddingModel.embed("hi").content(); embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - assertThat(relevant.get(0).embedding()).isNull(); - assertThat(relevant.get(1).embedding()).isNull(); + List> matches = embeddingStore.search(EmbeddingSearchRequest.builder() + .queryEmbedding(firstEmbedding) + .maxResults(10) + .build()).matches(); + assertThat(matches).hasSize(2); + assertThat(matches.get(0).embedding()).isNull(); + assertThat(matches.get(1).embedding()).isNull(); } } \ No newline at end of file diff --git a/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreRemovalIT.java b/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreRemovalIT.java new file mode 100644 index 0000000000..b209a52626 --- /dev/null +++ b/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreRemovalIT.java @@ -0,0 +1,41 @@ +package dev.langchain4j.store.embedding.milvus; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.milvus.MilvusContainer; + +import static dev.langchain4j.internal.Utils.randomUUID; +import static io.milvus.common.clientenum.ConsistencyLevelEnum.STRONG; + +@Testcontainers +class MilvusEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemovalIT { + + @Container + static MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.3.16"); + + MilvusEmbeddingStore embeddingStore = MilvusEmbeddingStore.builder() + .uri(milvus.getEndpoint()) + .collectionName("test_collection_" + randomUUID().replace("-", "")) + .username(System.getenv("MILVUS_USERNAME")) + .password(System.getenv("MILVUS_PASSWORD")) + .consistencyLevel(STRONG) + .dimension(384) + .build(); + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } +} \ No newline at end of file diff --git a/langchain4j-mistral-ai/pom.xml 
b/langchain4j-mistral-ai/pom.xml index d10a7fdce2..352f7d7584 100644 --- a/langchain4j-mistral-ai/pom.xml +++ b/langchain4j-mistral-ai/pom.xml @@ -5,7 +5,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -26,20 +26,19 @@ com.squareup.retrofit2 - converter-gson - - - - com.google.code.gson - gson - - + converter-jackson + + + + com.fasterxml.jackson.core + jackson-databind com.squareup.okhttp3 okhttp + com.squareup.okhttp3 okhttp-sse diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatCompletionChoice.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatCompletionChoice.java deleted file mode 100644 index ad05b32f61..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatCompletionChoice.java +++ /dev/null @@ -1,19 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -public class MistralAiChatCompletionChoice { - - private Integer index; - private MistralAiChatMessage message; - private MistralAiDeltaMessage delta; - private String finishReason; - private MistralAiUsage usage; // usageInfo is returned only when the prompt is finished in stream mode -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatCompletionResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatCompletionResponse.java deleted file mode 100644 index 1960780b3c..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatCompletionResponse.java +++ /dev/null @@ -1,22 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -public class MistralAiChatCompletionResponse { - - private String id; - private String object; - private Integer created; - private String model; - private List choices; - private MistralAiUsage usage; -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatMessage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatMessage.java deleted file mode 100644 index 08bdb6b0a6..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatMessage.java +++ /dev/null @@ -1,20 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -public class MistralAiChatMessage { - - private MistralAiRole role; - private String content; - private String name; - private List toolCalls; -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java index 23d58128ac..2bf5229aef 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java @@ -4,6 +4,11 @@ import 
dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.mistralai.internal.api.MistralAiChatCompletionRequest; +import dev.langchain4j.model.mistralai.internal.api.MistralAiChatCompletionResponse; +import dev.langchain4j.model.mistralai.internal.api.MistralAiResponseFormatType; +import dev.langchain4j.model.mistralai.internal.api.MistralAiToolChoiceName; +import dev.langchain4j.model.mistralai.internal.client.MistralAiClient; import dev.langchain4j.model.mistralai.spi.MistralAiChatModelBuilderFactory; import dev.langchain4j.model.output.Response; import lombok.Builder; @@ -15,7 +20,7 @@ import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; -import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; +import static dev.langchain4j.model.mistralai.internal.mapper.MistralAiMapper.*; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.util.Collections.singletonList; @@ -74,7 +79,7 @@ public MistralAiChatModel(String baseUrl, Integer maxRetries) { this.client = MistralAiClient.builder() - .baseUrl(getOrDefault(baseUrl, MISTRALAI_API_URL)) + .baseUrl(getOrDefault(baseUrl, "https://api.mistral.ai/v1")) .apiKey(apiKey) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) .logRequests(getOrDefault(logRequests, false)) @@ -177,6 +182,7 @@ public static MistralAiChatModelBuilder builder() { } public static class MistralAiChatModelBuilder { + public MistralAiChatModelBuilder() { } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModelName.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModelName.java index dc55d76019..3cd8b27bfc 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModelName.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModelName.java @@ -32,6 +32,7 @@ public enum MistralAiChatModelName { MISTRAL_TINY("mistral-tiny"), OPEN_MIXTRAL_8x7B("open-mixtral-8x7b"), // aka mistral-small-2312 + OPEN_MIXTRAL_8X22B("open-mixtral-8x22b"), // aka open-mixtral-8x22b /** * @deprecated As of release 0.29.0, replaced by {@link #MISTRAL_SMALL_LATEST} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiDeltaMessage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiDeltaMessage.java deleted file mode 100644 index ad8dff41a4..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiDeltaMessage.java +++ /dev/null @@ -1,19 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -public class MistralAiDeltaMessage { - - private MistralAiRole role; - private String content; - private List toolCalls; -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java index 9c2af7fd63..cc5451f255 100644 --- 
a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java @@ -3,6 +3,9 @@ import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.mistralai.internal.api.MistralAiEmbeddingRequest; +import dev.langchain4j.model.mistralai.internal.api.MistralAiEmbeddingResponse; +import dev.langchain4j.model.mistralai.internal.client.MistralAiClient; import dev.langchain4j.model.mistralai.spi.MistralAiEmbeddingModelBuilderFactory; import dev.langchain4j.model.output.Response; import lombok.Builder; @@ -12,7 +15,7 @@ import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; +import static dev.langchain4j.model.mistralai.internal.mapper.MistralAiMapper.*; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.util.stream.Collectors.toList; @@ -22,6 +25,7 @@ */ public class MistralAiEmbeddingModel implements EmbeddingModel { + private static final String EMBEDDINGS_ENCODING_FORMAT = "float"; private final MistralAiClient client; private final String modelName; private final Integer maxRetries; @@ -48,7 +52,7 @@ public MistralAiEmbeddingModel(String baseUrl, Boolean logResponses, Integer maxRetries) { this.client = MistralAiClient.builder() - .baseUrl(getOrDefault(baseUrl, MISTRALAI_API_URL)) + .baseUrl(getOrDefault(baseUrl, "https://api.mistral.ai/v1")) .apiKey(apiKey) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) .logRequests(getOrDefault(logRequests, false)) @@ -80,7 +84,7 @@ public Response> embedAll(List textSegments) { MistralAiEmbeddingRequest request = MistralAiEmbeddingRequest.builder() .model(modelName) .input(textSegments.stream().map(TextSegment::text).collect(toList())) - .encodingFormat(MISTRALAI_API_CREATE_EMBEDDINGS_ENCODING_FORMAT) + .encodingFormat(EMBEDDINGS_ENCODING_FORMAT) .build(); MistralAiEmbeddingResponse response = withRetry(() -> client.embedding(request), maxRetries); @@ -103,6 +107,7 @@ public static MistralAiEmbeddingModelBuilder builder() { } public static class MistralAiEmbeddingModelBuilder { + public MistralAiEmbeddingModelBuilder() { } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingRequest.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingRequest.java deleted file mode 100644 index 4dde5bfcca..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingRequest.java +++ /dev/null @@ -1,19 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -public class MistralAiEmbeddingRequest { - - private String model; - private List input; - private String encodingFormat; -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingResponse.java deleted file mode 100644 index 8a2c242b56..0000000000 --- 
a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingResponse.java +++ /dev/null @@ -1,21 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -public class MistralAiEmbeddingResponse { - - private String id; - private String object; - private String model; - private List data; - private MistralAiUsage usage; -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiFunction.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiFunction.java deleted file mode 100644 index a21e5037ce..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiFunction.java +++ /dev/null @@ -1,23 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import dev.langchain4j.agent.tool.JsonSchemaProperty; -import dev.langchain4j.agent.tool.ToolSpecification; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.HashMap; -import java.util.Map; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -class MistralAiFunction { - - private String name; - private String description; - private MistralAiParameters parameters; - -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelCard.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelCard.java deleted file mode 100644 index 526884eece..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelCard.java +++ /dev/null @@ -1,23 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -public class MistralAiModelCard { - - private String id; - private String object; - private Integer created; - private String ownerBy; - private String root; - private String parent; - private List permission; -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelResponse.java deleted file mode 100644 index 9799c12c41..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelResponse.java +++ /dev/null @@ -1,18 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -public class MistralAiModelResponse { - - private String object; - private List data; -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java index 03f38c0ada..6e069a0593 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java @@ -1,5 +1,8 @@ package dev.langchain4j.model.mistralai; +import 
dev.langchain4j.model.mistralai.internal.api.MistralAiModelCard; +import dev.langchain4j.model.mistralai.internal.api.MistralAiModelResponse; +import dev.langchain4j.model.mistralai.internal.client.MistralAiClient; import dev.langchain4j.model.mistralai.spi.MistralAiModelsBuilderFactory; import dev.langchain4j.model.output.Response; import lombok.Builder; @@ -9,7 +12,6 @@ import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.MISTRALAI_API_URL; import static dev.langchain4j.spi.ServiceHelper.loadFactories; /** @@ -39,7 +41,7 @@ public MistralAiModels(String baseUrl, Boolean logResponses, Integer maxRetries) { this.client = MistralAiClient.builder() - .baseUrl(getOrDefault(baseUrl, MISTRALAI_API_URL)) + .baseUrl(getOrDefault(baseUrl, "https://api.mistral.ai/v1")) .apiKey(apiKey) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) .logRequests(getOrDefault(logRequests, false)) @@ -77,7 +79,6 @@ public static MistralAiModelsBuilder builder() { public static class MistralAiModelsBuilder { public MistralAiModelsBuilder(){ - } } } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiParameters.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiParameters.java deleted file mode 100644 index 45cf6abc45..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiParameters.java +++ /dev/null @@ -1,29 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import dev.langchain4j.agent.tool.ToolParameters; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; -import java.util.Map; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -class MistralAiParameters { - - @Builder.Default - private String type="object"; - private Map> properties; - private List required; - - static MistralAiParameters from(ToolParameters toolParameters){ - return MistralAiParameters.builder() - .properties(toolParameters.properties()) - .required(toolParameters.required()) - .build(); - } -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java deleted file mode 100644 index 67fe325601..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java +++ /dev/null @@ -1,47 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import okhttp3.Interceptor; -import okhttp3.Request; -import okhttp3.Response; -import okio.Buffer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; - -import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.getHeaders; - -class MistralAiRequestLoggingInterceptor implements Interceptor { - - private static final Logger LOGGER = LoggerFactory.getLogger(MistralAiRequestLoggingInterceptor.class); - - @Override - public Response intercept(Chain chain) throws IOException { - Request request = chain.request(); - this.log(request); - return chain.proceed(request); - } - - private void log(Request request) { - try { - LOGGER.debug("Request:\n- method: {}\n- url: {}\n- headers: {}\n- body: {}", - request.method(), request.url(), getHeaders(request.headers()), getBody(request)); - } catch 
(Exception e) { - LOGGER.warn("Error while logging request: {}", e.getMessage()); - } - } - - private static String getBody(Request request) { - try { - Buffer buffer = new Buffer(); - if (request.body() == null) { - return ""; - } - request.body().writeTo(buffer); - return buffer.readUtf8(); - } catch (Exception e) { - LOGGER.warn("Exception while getting body", e); - return "Exception while getting body: " + e.getMessage(); - } - } -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseFormat.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseFormat.java deleted file mode 100644 index 1d49e5363a..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseFormat.java +++ /dev/null @@ -1,21 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -class MistralAiResponseFormat { - - private Object type; - - static MistralAiResponseFormat fromType(MistralAiResponseFormatType type) { - return MistralAiResponseFormat.builder() - .type(type.toString()) - .build(); - } -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRole.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRole.java deleted file mode 100644 index 3524d4c22b..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRole.java +++ /dev/null @@ -1,16 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import com.google.gson.annotations.SerializedName; -import lombok.Getter; - -@Getter -public enum MistralAiRole { - - @SerializedName("system") SYSTEM, - @SerializedName("user") USER, - @SerializedName("assistant") ASSISTANT, - @SerializedName("tool") TOOL; - - MistralAiRole() { - } -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java index 95c68623c0..7b18215aaa 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java @@ -5,6 +5,10 @@ import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.mistralai.internal.api.MistralAiChatCompletionRequest; +import dev.langchain4j.model.mistralai.internal.api.MistralAiResponseFormatType; +import dev.langchain4j.model.mistralai.internal.api.MistralAiToolChoiceName; +import dev.langchain4j.model.mistralai.internal.client.MistralAiClient; import dev.langchain4j.model.mistralai.spi.MistralAiStreamingChatModelBuilderFactory; import lombok.Builder; @@ -14,8 +18,8 @@ import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; -import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; -import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.toMistralAiTools; +import static dev.langchain4j.model.mistralai.internal.mapper.MistralAiMapper.*; +import static 
dev.langchain4j.model.mistralai.internal.mapper.MistralAiMapper.toMistralAiTools; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.util.Collections.singletonList; @@ -67,7 +71,7 @@ public MistralAiStreamingChatModel(String baseUrl, Duration timeout) { this.client = MistralAiClient.builder() - .baseUrl(getOrDefault(baseUrl, MISTRALAI_API_URL)) + .baseUrl(getOrDefault(baseUrl, "https://api.mistral.ai/v1")) .apiKey(apiKey) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) .logRequests(getOrDefault(logRequests, false)) @@ -165,6 +169,7 @@ public static MistralAiStreamingChatModelBuilder builder() { } public static class MistralAiStreamingChatModelBuilder { + public MistralAiStreamingChatModelBuilder() { } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiTool.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiTool.java deleted file mode 100644 index a5c901448a..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiTool.java +++ /dev/null @@ -1,23 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -class MistralAiTool { - - private MistralAiToolType type; - private MistralAiFunction function; - - static MistralAiTool from(MistralAiFunction function){ - return MistralAiTool.builder() - .type(MistralAiToolType.FUNCTION) - .function(function) - .build(); - } -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiToolCall.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiToolCall.java deleted file mode 100644 index 982f7d85fb..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiToolCall.java +++ /dev/null @@ -1,18 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -class MistralAiToolCall { - - private String id; - @Builder.Default - private MistralAiToolType type = MistralAiToolType.FUNCTION; - private MistralAiFunctionCall function; -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiToolChoiceName.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiToolChoiceName.java deleted file mode 100644 index f63e9910c1..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiToolChoiceName.java +++ /dev/null @@ -1,15 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import com.google.gson.annotations.SerializedName; -import lombok.Getter; - -@Getter -public enum MistralAiToolChoiceName { - - @SerializedName("auto") AUTO, - @SerializedName("any") ANY, - @SerializedName("none") NONE; - - MistralAiToolChoiceName() { - } -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiToolType.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiToolType.java deleted file mode 100644 index 413fd8a8e4..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiToolType.java +++ /dev/null @@ -1,13 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import 
com.google.gson.annotations.SerializedName; -import lombok.Getter; - -@Getter -public enum MistralAiToolType { - - @SerializedName("function") FUNCTION; - - MistralAiToolType() { - } -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiUsage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiUsage.java deleted file mode 100644 index f97f17af8d..0000000000 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiUsage.java +++ /dev/null @@ -1,17 +0,0 @@ -package dev.langchain4j.model.mistralai; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -@Data -@NoArgsConstructor -@AllArgsConstructor -@Builder -public class MistralAiUsage { - - private Integer promptTokens; - private Integer totalTokens; - private Integer completionTokens; -} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApi.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiApi.java similarity index 89% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApi.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiApi.java index 9658afcae0..9df46b7e59 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApi.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiApi.java @@ -1,10 +1,10 @@ -package dev.langchain4j.model.mistralai; +package dev.langchain4j.model.mistralai.internal.api; import okhttp3.ResponseBody; import retrofit2.Call; import retrofit2.http.*; -interface MistralAiApi { +public interface MistralAiApi { @POST("chat/completions") @Headers({"Content-Type: application/json"}) diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatCompletionChoice.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatCompletionChoice.java new file mode 100644 index 0000000000..1d7556792a --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatCompletionChoice.java @@ -0,0 +1,23 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Builder; +import lombok.Data; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiChatCompletionChoice { + + private Integer index; + private MistralAiChatMessage message; + private MistralAiDeltaMessage delta; + private String finishReason; + private MistralAiUsage usage; // usageInfo is returned only when the prompt is finished in stream mode +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatCompletionRequest.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatCompletionRequest.java similarity index 54% rename from 
langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatCompletionRequest.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatCompletionRequest.java index cebf532cae..5790c6e4c3 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatCompletionRequest.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatCompletionRequest.java @@ -1,5 +1,9 @@ -package dev.langchain4j.model.mistralai; +package dev.langchain4j.model.mistralai.internal.api; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -7,10 +11,15 @@ import java.util.List; +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + @Data @NoArgsConstructor @AllArgsConstructor -@Builder +@Builder(toBuilder = true) +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) public class MistralAiChatCompletionRequest { private String model; @@ -24,5 +33,4 @@ public class MistralAiChatCompletionRequest { private List tools; private MistralAiToolChoiceName toolChoice; private MistralAiResponseFormat responseFormat; - } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatCompletionResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatCompletionResponse.java new file mode 100644 index 0000000000..fd95aa5028 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatCompletionResponse.java @@ -0,0 +1,25 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Data; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiChatCompletionResponse { + + private String id; + private String object; + private Integer created; + private String model; + private List choices; + private MistralAiUsage usage; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatMessage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatMessage.java new file mode 100644 index 0000000000..5577b4cf0b --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiChatMessage.java @@ -0,0 +1,28 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.*; + +import java.util.List; + +import static 
com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@ToString +@EqualsAndHashCode +@NoArgsConstructor +@AllArgsConstructor +@Data +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiChatMessage { + + private MistralAiRole role; + private String content; + private String name; + private List toolCalls; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiDeltaMessage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiDeltaMessage.java new file mode 100644 index 0000000000..04d733b33c --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiDeltaMessage.java @@ -0,0 +1,27 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.*; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@ToString +@EqualsAndHashCode +@NoArgsConstructor +@AllArgsConstructor +@Builder +@Data +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiDeltaMessage { + + private MistralAiRole role; + private String content; + private List toolCalls; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiEmbedding.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiEmbedding.java new file mode 100644 index 0000000000..a8a8f18b66 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiEmbedding.java @@ -0,0 +1,28 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiEmbedding { + + private String object; + private List embedding; + private Integer index; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiEmbeddingRequest.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiEmbeddingRequest.java new file mode 100644 index 0000000000..e0205a4b39 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiEmbeddingRequest.java @@ -0,0 +1,28 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import 
com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder(toBuilder = true)@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) + +public class MistralAiEmbeddingRequest { + + private String model; + private List input; + private String encodingFormat; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiEmbeddingResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiEmbeddingResponse.java new file mode 100644 index 0000000000..2608623809 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiEmbeddingResponse.java @@ -0,0 +1,30 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiEmbeddingResponse { + + private String id; + private String object; + private String model; + private List data; + private MistralAiUsage usage; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiFunction.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiFunction.java new file mode 100644 index 0000000000..133d1b2071 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiFunction.java @@ -0,0 +1,27 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiFunction { + + private String name; + private String description; + private MistralAiParameters parameters; + +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiFunctionCall.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiFunctionCall.java new file mode 100644 index 0000000000..9c49b6ebd4 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiFunctionCall.java @@ 
-0,0 +1,25 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiFunctionCall { + + private String name; + private String arguments; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiModelCard.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiModelCard.java new file mode 100644 index 0000000000..4ff1ec6498 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiModelCard.java @@ -0,0 +1,32 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiModelCard { + + private String id; + private String object; + private Integer created; + private String ownerBy; + private String root; + private String parent; + private List permission; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelPermission.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiModelPermission.java similarity index 54% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelPermission.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiModelPermission.java index 407189b0ac..76c3fc7163 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelPermission.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiModelPermission.java @@ -1,15 +1,23 @@ -package dev.langchain4j.model.mistralai; - +package dev.langchain4j.model.mistralai.internal.api; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + @Data @NoArgsConstructor @AllArgsConstructor @Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = 
true) +@JsonNaming(SnakeCaseStrategy.class) public class MistralAiModelPermission { private String id; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiModelResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiModelResponse.java new file mode 100644 index 0000000000..dc3dcecbdd --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiModelResponse.java @@ -0,0 +1,27 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiModelResponse { + + private String object; + private List data; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiParameters.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiParameters.java new file mode 100644 index 0000000000..79f8d3f83b --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiParameters.java @@ -0,0 +1,38 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.agent.tool.ToolParameters; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; +import java.util.Map; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiParameters { + + @Builder.Default + private String type="object"; + private Map> properties; + private List required; + + public static MistralAiParameters from(ToolParameters toolParameters){ + return MistralAiParameters.builder() + .properties(toolParameters.properties()) + .required(toolParameters.required()) + .build(); + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiResponseFormat.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiResponseFormat.java new file mode 100644 index 0000000000..f4daf97649 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiResponseFormat.java @@ -0,0 +1,30 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import 
com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiResponseFormat { + + private Object type; + + public static MistralAiResponseFormat fromType(MistralAiResponseFormatType type) { + return MistralAiResponseFormat.builder() + .type(type.toString()) + .build(); + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseFormatType.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiResponseFormatType.java similarity index 73% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseFormatType.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiResponseFormatType.java index 3233ecab5a..c6183d3bbb 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseFormatType.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiResponseFormatType.java @@ -1,7 +1,6 @@ -package dev.langchain4j.model.mistralai; +package dev.langchain4j.model.mistralai.internal.api; -import com.google.gson.annotations.SerializedName; -import lombok.Getter; +import com.fasterxml.jackson.annotation.JsonProperty; /** * Represents the value of the 'type' field in the response_format parameter of the MistralAi Chat completions request. @@ -12,11 +11,10 @@ *
 * <li>{@link MistralAiResponseFormatType#JSON_OBJECT}</li>
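// --- Illustrative sketch (not part of the original patch) -------------------------------
// Shows how the JSON mode documented above can be requested with the relocated internal
// DTOs. MistralAiResponseFormat.fromType(..) and the Lombok builder on
// MistralAiChatCompletionRequest are both introduced by this change set; the model id
// string below is a placeholder used only for this example.
import dev.langchain4j.model.mistralai.internal.api.MistralAiChatCompletionRequest;
import dev.langchain4j.model.mistralai.internal.api.MistralAiResponseFormat;
import dev.langchain4j.model.mistralai.internal.api.MistralAiResponseFormatType;

class JsonModeRequestSketch {

    static MistralAiChatCompletionRequest jsonModeRequest() {
        // Ask the API to return a JSON object instead of plain text.
        MistralAiResponseFormat jsonFormat =
                MistralAiResponseFormat.fromType(MistralAiResponseFormatType.JSON_OBJECT);

        return MistralAiChatCompletionRequest.builder()
                .model("mistral-large-latest") // placeholder model id
                .responseFormat(jsonFormat)
                .build();
    }
}
// -----------------------------------------------------------------------------------------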
  • * */ -@Getter public enum MistralAiResponseFormatType { - @SerializedName("text") TEXT, - @SerializedName("json_object") JSON_OBJECT; + @JsonProperty("text") TEXT, + @JsonProperty("json_object") JSON_OBJECT; MistralAiResponseFormatType() { } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiRole.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiRole.java new file mode 100644 index 0000000000..8f2e5cdfc5 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiRole.java @@ -0,0 +1,14 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public enum MistralAiRole { + + @JsonProperty("system") SYSTEM, + @JsonProperty("user") USER, + @JsonProperty("assistant") ASSISTANT, + @JsonProperty("tool") TOOL; + + MistralAiRole() { + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiTool.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiTool.java new file mode 100644 index 0000000000..61a244bbfd --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiTool.java @@ -0,0 +1,32 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiTool { + + private MistralAiToolType type; + private MistralAiFunction function; + + public static MistralAiTool from(MistralAiFunction function){ + return MistralAiTool.builder() + .type(MistralAiToolType.FUNCTION) + .function(function) + .build(); + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiToolCall.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiToolCall.java new file mode 100644 index 0000000000..e082c74f2c --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiToolCall.java @@ -0,0 +1,27 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiToolCall { + + private String id; + @Builder.Default + private MistralAiToolType type = 
MistralAiToolType.FUNCTION; + private MistralAiFunctionCall function; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiToolChoiceName.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiToolChoiceName.java new file mode 100644 index 0000000000..bba4bbb994 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiToolChoiceName.java @@ -0,0 +1,13 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public enum MistralAiToolChoiceName { + + @JsonProperty("auto") AUTO, + @JsonProperty("any") ANY, + @JsonProperty("none") NONE; + + MistralAiToolChoiceName() { + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiToolType.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiToolType.java new file mode 100644 index 0000000000..45b1a6e596 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiToolType.java @@ -0,0 +1,11 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public enum MistralAiToolType { + + @JsonProperty("function") FUNCTION; + + MistralAiToolType() { + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiUsage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiUsage.java new file mode 100644 index 0000000000..0b152ae896 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/api/MistralAiUsage.java @@ -0,0 +1,20 @@ +package dev.langchain4j.model.mistralai.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Data; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(SnakeCaseStrategy.class) +public class MistralAiUsage { + + private Integer promptTokens; + private Integer totalTokens; + private Integer completionTokens; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiClient.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/DefaultMistralAiClient.java similarity index 91% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiClient.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/DefaultMistralAiClient.java index ff0fc971a1..7d2b66c4a7 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiClient.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/DefaultMistralAiClient.java @@ -1,11 +1,10 @@ -package dev.langchain4j.model.mistralai; +package dev.langchain4j.model.mistralai.internal.client; -import com.google.gson.FieldNamingPolicy; -import com.google.gson.Gson; -import com.google.gson.GsonBuilder; +import com.fasterxml.jackson.databind.ObjectMapper; import 
dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.mistralai.internal.api.*; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; @@ -17,21 +16,19 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import retrofit2.Retrofit; -import retrofit2.converter.gson.GsonConverterFactory; +import retrofit2.converter.jackson.JacksonConverterFactory; import java.io.IOException; import java.util.List; +import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT; import static dev.langchain4j.internal.Utils.*; -import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; +import static dev.langchain4j.model.mistralai.internal.mapper.MistralAiMapper.*; -class DefaultMistralAiClient extends MistralAiClient { +public class DefaultMistralAiClient extends MistralAiClient { private static final Logger LOGGER = LoggerFactory.getLogger(DefaultMistralAiClient.class); - private static final Gson GSON = new GsonBuilder() - .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) - .setPrettyPrinting() - .create(); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().enable(INDENT_OUTPUT); private final MistralAiApi mistralAiApi; private final OkHttpClient okHttpClient; @@ -71,7 +68,7 @@ public DefaultMistralAiClient build() { Retrofit retrofit = new Retrofit.Builder() .baseUrl(formattedUrlForRetrofit(builder.baseUrl)) .client(okHttpClient) - .addConverterFactory(GsonConverterFactory.create(GSON)) + .addConverterFactory(JacksonConverterFactory.create(OBJECT_MAPPER)) .build(); mistralAiApi = retrofit.create(MistralAiApi.class); @@ -132,11 +129,11 @@ public void onEvent(EventSource eventSource, String id, String type, String data handler.onComplete(response); } else { try { - MistralAiChatCompletionResponse chatCompletionResponse = GSON.fromJson(data, MistralAiChatCompletionResponse.class); + MistralAiChatCompletionResponse chatCompletionResponse = OBJECT_MAPPER.readValue(data, MistralAiChatCompletionResponse.class); MistralAiChatCompletionChoice choice = chatCompletionResponse.getChoices().get(0); String chunk = choice.getDelta().getContent(); - if (isNotNullOrBlank(chunk)) { + if (isNotNullOrEmpty(chunk)) { contentBuilder.append(chunk); handler.onNext(chunk); } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApiKeyInterceptor.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiApiKeyInterceptor.java similarity index 90% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApiKeyInterceptor.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiApiKeyInterceptor.java index c1ff8d3cd4..7dcdc8b2fd 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApiKeyInterceptor.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiApiKeyInterceptor.java @@ -1,4 +1,4 @@ -package dev.langchain4j.model.mistralai; +package dev.langchain4j.model.mistralai.internal.client; import okhttp3.Interceptor; import okhttp3.Request; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java 
b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiClient.java similarity index 96% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiClient.java index 0e8f087278..a2b19d9a33 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiClient.java @@ -1,7 +1,8 @@ -package dev.langchain4j.model.mistralai; +package dev.langchain4j.model.mistralai.internal.client; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.mistralai.internal.api.*; import dev.langchain4j.spi.ServiceHelper; import java.time.Duration; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClientBuilderFactory.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiClientBuilderFactory.java similarity index 73% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClientBuilderFactory.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiClientBuilderFactory.java index b0653282dd..982e99bdc4 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClientBuilderFactory.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiClientBuilderFactory.java @@ -1,4 +1,4 @@ -package dev.langchain4j.model.mistralai; +package dev.langchain4j.model.mistralai.internal.client; import java.util.function.Supplier; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiRequestLoggingInterceptor.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiRequestLoggingInterceptor.java new file mode 100644 index 0000000000..352c1e891a --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiRequestLoggingInterceptor.java @@ -0,0 +1,78 @@ +package dev.langchain4j.model.mistralai.internal.client; + +import lombok.extern.slf4j.Slf4j; +import okhttp3.Headers; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import okio.Buffer; + +import java.io.IOException; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +@Slf4j +class MistralAiRequestLoggingInterceptor implements Interceptor { + + @Override + public Response intercept(Chain chain) throws IOException { + Request request = chain.request(); + this.log(request); + return chain.proceed(request); + } + + private void log(Request request) { + try { + log.debug("Request:\n- method: {}\n- url: {}\n- headers: {}\n- body: {}", + request.method(), request.url(), getHeaders(request.headers()), getBody(request)); + } catch (Exception e) { + log.warn("Error while logging request: {}", e.getMessage()); + } + } + + static String getHeaders(Headers headers) { + return StreamSupport.stream(headers.spliterator(), false).map(header -> { + String headerKey = header.component1(); + String headerValue = header.component2(); + if (headerKey.equals("Authorization")) { + 
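// --- Illustrative sketch (not part of the original patch) -------------------------------
// Demonstrates what the Authorization masking in this interceptor does: the pattern and
// the appendReplacement logic mirror maskAuthorizationHeaderValue(..) shown below; the
// standalone main(..) wrapper and the dummy key exist only for this example.
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class AuthorizationMaskingSketch {

    public static void main(String[] args) {
        Pattern bearerPattern = Pattern.compile("^(Bearer\\s*) ([A-Za-z0-9]{1,32})$");
        Matcher matcher = bearerPattern.matcher("Bearer 1234567890abcdef"); // dummy key
        StringBuffer sb = new StringBuffer();
        while (matcher.find()) {
            String bearer = matcher.group(1);
            String token = matcher.group(2);
            // Keep only the first and last two characters of the key.
            matcher.appendReplacement(sb,
                    bearer + " " + token.substring(0, 2) + "..." + token.substring(token.length() - 2));
        }
        matcher.appendTail(sb);
        System.out.println(sb); // prints: Bearer 12...ef
    }
}
// -----------------------------------------------------------------------------------------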
headerValue = maskAuthorizationHeaderValue(headerValue); + } + return String.format("[%s: %s]", headerKey, headerValue); + }).collect(Collectors.joining(", ")); + } + + private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) { + try { + Pattern apiKeyBearerPattern = Pattern.compile("^(Bearer\\s*) ([A-Za-z0-9]{1,32})$"); + + Matcher matcher = apiKeyBearerPattern.matcher(authorizationHeaderValue); + StringBuffer sb = new StringBuffer(); + + while (matcher.find()) { + String bearer = matcher.group(1); + String token = matcher.group(2); + matcher.appendReplacement(sb, bearer + " " + token.substring(0, 2) + "..." + token.substring(token.length() - 2)); + } + matcher.appendTail(sb); + return sb.toString(); + } catch (Exception e) { + return "Error while masking Authorization header value"; + } + } + + private static String getBody(Request request) { + try { + Buffer buffer = new Buffer(); + if (request.body() == null) { + return ""; + } + request.body().writeTo(buffer); + return buffer.readUtf8(); + } catch (Exception e) { + log.warn("Exception while getting body", e); + return "Exception while getting body: " + e.getMessage(); + } + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiResponseLoggingInterceptor.java similarity index 70% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiResponseLoggingInterceptor.java index e630623b07..b5d35993ce 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/client/MistralAiResponseLoggingInterceptor.java @@ -1,19 +1,17 @@ -package dev.langchain4j.model.mistralai; +package dev.langchain4j.model.mistralai.internal.client; +import lombok.extern.slf4j.Slf4j; import okhttp3.Interceptor; import okhttp3.Request; import okhttp3.Response; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.IOException; -import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.getHeaders; +import static dev.langchain4j.model.mistralai.internal.client.MistralAiRequestLoggingInterceptor.getHeaders; +@Slf4j class MistralAiResponseLoggingInterceptor implements Interceptor { - private static final Logger LOGGER = LoggerFactory.getLogger(MistralAiResponseLoggingInterceptor.class); - @Override public Response intercept(Chain chain) throws IOException { Request request = chain.request(); @@ -24,10 +22,10 @@ public Response intercept(Chain chain) throws IOException { private void log(Response response) { try { - LOGGER.debug("Response:\n- status code: {}\n- headers: {}\n- body: {}", + log.debug("Response:\n- status code: {}\n- headers: {}\n- body: {}", response.code(), getHeaders(response.headers()), this.getBody(response)); } catch (Exception e) { - LOGGER.warn("Error while logging response: {}", e.getMessage()); + log.warn("Error while logging response: {}", e.getMessage()); } } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/mapper/MistralAiMapper.java similarity index 74% rename from 
langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/mapper/MistralAiMapper.java index 3d2044612b..b198a8f4d5 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/internal/mapper/MistralAiMapper.java @@ -1,9 +1,10 @@ -package dev.langchain4j.model.mistralai; +package dev.langchain4j.model.mistralai.internal.mapper; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.*; +import dev.langchain4j.model.mistralai.internal.api.*; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; import okhttp3.Headers; @@ -14,22 +15,16 @@ import java.util.stream.Collectors; import java.util.stream.StreamSupport; -import static dev.langchain4j.data.message.AiMessage.aiMessage; import static dev.langchain4j.internal.Utils.isNullOrBlank; import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.model.output.FinishReason.*; import static java.util.stream.Collectors.toList; -public class DefaultMistralAiHelper { +public class MistralAiMapper { - static final String MISTRALAI_API_URL = "https://api.mistral.ai/v1"; - static final String MISTRALAI_API_CREATE_EMBEDDINGS_ENCODING_FORMAT = "float"; - private static final Pattern MISTRAI_API_KEY_BEARER_PATTERN = - Pattern.compile("^(Bearer\\s*) ([A-Za-z0-9]{1,32})$"); - - static List toMistralAiMessages(List messages) { + public static List toMistralAiMessages(List messages) { return messages.stream() - .map(DefaultMistralAiHelper::toMistralAiMessage) + .map(MistralAiMapper::toMistralAiMessage) .collect(toList()); } @@ -52,7 +47,7 @@ static MistralAiChatMessage toMistralAiMessage(ChatMessage message) { } List toolCalls = aiMessage.toolExecutionRequests().stream() - .map(DefaultMistralAiHelper::toMistralAiToolCall) + .map(MistralAiMapper::toMistralAiToolCall) .collect(toList()); if (isNullOrBlank(aiMessage.text())){ @@ -140,7 +135,7 @@ public static AiMessage aiMessageFrom(MistralAiChatCompletionResponse response) public static List toToolExecutionRequests(List mistralAiToolCalls) { return mistralAiToolCalls.stream() .filter(toolCall -> toolCall.getType() == MistralAiToolType.FUNCTION) - .map(DefaultMistralAiHelper::toToolExecutionRequest) + .map(MistralAiMapper::toToolExecutionRequest) .collect(toList()); } @@ -152,9 +147,9 @@ public static ToolExecutionRequest toToolExecutionRequest(MistralAiToolCall mist .build(); } - static List toMistralAiTools(List toolSpecifications) { + public static List toMistralAiTools(List toolSpecifications) { return toolSpecifications.stream() - .map(DefaultMistralAiHelper::toMistralAiTool) + .map(MistralAiMapper::toMistralAiTool) .collect(toList()); } @@ -174,7 +169,7 @@ static MistralAiParameters toMistralAiParameters(ToolParameters parameters){ return MistralAiParameters.from(parameters); } - static MistralAiResponseFormat toMistralAiResponseFormat(String responseFormat) { + public static MistralAiResponseFormat toMistralAiResponseFormat(String responseFormat) { if (responseFormat == null) { return null; } @@ -187,32 +182,4 @@ static MistralAiResponseFormat toMistralAiResponseFormat(String responseFormat) throw new IllegalArgumentException("Unknown response 
format: " + responseFormat); } } - - static String getHeaders(Headers headers) { - return StreamSupport.stream(headers.spliterator(), false).map(header -> { - String headerKey = header.component1(); - String headerValue = header.component2(); - if (headerKey.equals("Authorization")) { - headerValue = maskAuthorizationHeaderValue(headerValue); - } - return String.format("[%s: %s]", headerKey, headerValue); - }).collect(Collectors.joining(", ")); - } - - private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) { - try { - Matcher matcher = MISTRAI_API_KEY_BEARER_PATTERN.matcher(authorizationHeaderValue); - StringBuffer sb = new StringBuffer(); - - while (matcher.find()) { - String bearer = matcher.group(1); - String token = matcher.group(2); - matcher.appendReplacement(sb, bearer + " " + token.substring(0, 2) + "..." + token.substring(token.length() - 2)); - } - matcher.appendTail(sb); - return sb.toString(); - } catch (Exception e) { - return "Error while masking Authorization header value"; - } - } } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/spi/MistralAiChatModelBuilderFactory.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/spi/MistralAiChatModelBuilderFactory.java index 6944ed8bf0..f4871d24d4 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/spi/MistralAiChatModelBuilderFactory.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/spi/MistralAiChatModelBuilderFactory.java @@ -7,5 +7,5 @@ /** * A factory for building {@link dev.langchain4j.model.mistralai.MistralAiChatModel.MistralAiChatModelBuilder} instances. */ -public interface MistralAiChatModelBuilderFactory extends Supplier{ +public interface MistralAiChatModelBuilderFactory extends Supplier { } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/spi/MistralAiStreamingChatModelBuilderFactory.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/spi/MistralAiStreamingChatModelBuilderFactory.java index ba4d4d9fb4..5072810ad4 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/spi/MistralAiStreamingChatModelBuilderFactory.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/spi/MistralAiStreamingChatModelBuilderFactory.java @@ -7,5 +7,5 @@ /** * A factory for building {@link MistralAiStreamingChatModel.MistralAiStreamingChatModelBuilder} instances. 
*/ -public interface MistralAiStreamingChatModelBuilderFactory extends Supplier{ +public interface MistralAiStreamingChatModelBuilderFactory extends Supplier { } diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java index 61cc04de1f..a4e8e099a6 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java @@ -7,6 +7,7 @@ import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.mistralai.internal.api.MistralAiResponseFormatType; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import org.junit.jupiter.api.Test; @@ -17,7 +18,6 @@ import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING; import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.model.output.FinishReason.*; -import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; @@ -38,13 +38,21 @@ class MistralAiChatModelIT { .logResponses(true) .build(); - ChatLanguageModel model = MistralAiChatModel.builder() + ChatLanguageModel defaultModel = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) .temperature(0.1) .logRequests(true) .logResponses(true) .build(); + ChatLanguageModel openMixtral8x22BModel = MistralAiChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(MistralAiChatModelName.OPEN_MIXTRAL_8X22B) + .temperature(0.1) + .logRequests(true) + .logResponses(true) + .build(); + @Test void should_generate_answer_and_return_token_usage_and_finish_reason_stop() { @@ -52,13 +60,13 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_stop() { UserMessage userMessage = userMessage("What is the capital of Peru?"); // when - Response response = model.generate(userMessage); + Response response = defaultModel.generate(userMessage); // then assertThat(response.content().text()).contains("Lima"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -85,8 +93,8 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_length() { assertThat(response.content().text()).isNotBlank(); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); - assertThat(tokenUsage.outputTokenCount()).isEqualTo(4); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -100,6 +108,7 @@ void should_generate_system_prompt_to_enforce_guardrails() { ChatLanguageModel model = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) .safePrompt(true) + .temperature(0.0) .build(); // 
given @@ -111,16 +120,14 @@ void should_generate_system_prompt_to_enforce_guardrails() { // then AiMessage aiMessage = response.content(); assertThat(aiMessage.text()).contains("respect"); - assertThat(aiMessage.text()).contains("truth"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isGreaterThan(50); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); assertThat(response.finishReason()).isEqualTo(STOP); - } @Test @@ -132,7 +139,7 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_stop_with_m UserMessage userMessage3 = userMessage("What is the capital of Canada?"); // when - Response response = model.generate(userMessage1, userMessage2, userMessage3); + Response response = defaultModel.generate(userMessage1, userMessage2, userMessage3); // then assertThat(response.content().text()).contains("Lima"); @@ -140,7 +147,7 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_stop_with_m assertThat(response.content().text()).contains("Ottawa"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(11 + 11 + 11); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -169,7 +176,7 @@ void should_generate_answer_in_french_using_model_small_and_return_token_usage_a assertThat(response.content().text()).contains("Lima"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(18); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -198,7 +205,7 @@ void should_generate_answer_in_spanish_using_model_small_and_return_token_usage_ assertThat(response.content().text()).contains("Lima"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(19); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -227,8 +234,8 @@ void should_generate_answer_using_model_medium_and_return_token_usage_and_finish assertThat(response.content().text()).contains("Lima"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); - assertThat(tokenUsage.outputTokenCount()).isEqualTo(10); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -236,14 +243,14 @@ void should_generate_answer_using_model_medium_and_return_token_usage_and_finish } @Test - void should_execute_tool_and_return_finishReason_tool_execution(){ + void should_execute_tool_using_model_open8x22B_and_return_finishReason_tool_execution() { // given UserMessage userMessage = userMessage("What is the status of transaction T123?"); List toolSpecifications = singletonList(retrievePaymentStatus); 
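// --- Illustrative sketch (not part of the original patch) -------------------------------
// The retrievePaymentStatus specification referenced in these tests is declared outside
// the visible hunks; this is only a plausible reconstruction. The builder methods and the
// STRING schema property appear elsewhere in this diff, but the exact name and description
// strings are assumptions.
import dev.langchain4j.agent.tool.ToolSpecification;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING;

class RetrievePaymentStatusSketch {

    static final ToolSpecification RETRIEVE_PAYMENT_STATUS = ToolSpecification.builder()
            .name("retrieve-payment-status")                              // assumed tool name
            .description("Retrieve the payment status of a transaction")  // assumed description
            .addParameter("transactionId", STRING) // matches the asserted {"transactionId":"T123"} arguments
            .build();
}
// -----------------------------------------------------------------------------------------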
// when - Response response = mistralLargeModel.generate(singletonList(userMessage), toolSpecifications); + Response response = openMixtral8x22BModel.generate(singletonList(userMessage), toolSpecifications); // then AiMessage aiMessage = response.content(); @@ -255,8 +262,8 @@ void should_execute_tool_and_return_finishReason_tool_execution(){ assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"transactionId\":\"T123\"}"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(78); - assertThat(tokenUsage.outputTokenCount()).isEqualTo(28); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -264,7 +271,7 @@ void should_execute_tool_and_return_finishReason_tool_execution(){ } @Test - void should_execute_tool_when_toolChoice_is_auto_and_answer(){ + void should_execute_tool_using_model_open8x22B_when_toolChoice_is_auto_and_answer() { // given ToolSpecification retrievePaymentDate = ToolSpecification.builder() .name("retrieve-payment-date") @@ -276,10 +283,10 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){ UserMessage userMessage = userMessage("What is the status of transaction T123?"); chatMessages.add(userMessage); - List toolSpecifications = asList(retrievePaymentStatus,retrievePaymentDate); + List toolSpecifications = asList(retrievePaymentStatus, retrievePaymentDate); // when - Response response = mistralLargeModel.generate(chatMessages, toolSpecifications); + Response response = openMixtral8x22BModel.generate(chatMessages, toolSpecifications); // then AiMessage aiMessage = response.content(); @@ -299,7 +306,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){ chatMessages.add(toolExecutionResultMessage); // when - Response response2 = mistralLargeModel.generate(chatMessages); + Response response2 = openMixtral8x22BModel.generate(chatMessages); // then AiMessage aiMessage2 = response2.content(); @@ -308,7 +315,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){ assertThat(aiMessage2.toolExecutionRequests()).isNull(); TokenUsage tokenUsage2 = response2.tokenUsage(); - assertThat(tokenUsage2.inputTokenCount()).isEqualTo(69); + assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.totalTokenCount()) .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); @@ -317,7 +324,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){ } @Test - void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { + void should_execute_tool_forcefully_using_model_open8x22B_when_toolChoice_is_any_and_answer() { // given ToolSpecification retrievePaymentDate = ToolSpecification.builder() .name("retrieve-payment-date") @@ -330,7 +337,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { chatMessages.add(userMessage); // when - Response response = mistralLargeModel.generate(singletonList(userMessage), retrievePaymentDate); + Response response = openMixtral8x22BModel.generate(singletonList(userMessage), retrievePaymentDate); // then AiMessage aiMessage = response.content(); @@ -342,7 +349,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { 
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"transactionId\":\"T123\"}"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(79); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -355,7 +362,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { chatMessages.add(toolExecutionResultMessage); // when - Response response2 = mistralLargeModel.generate(chatMessages); + Response response2 = openMixtral8x22BModel.generate(chatMessages); // then AiMessage aiMessage2 = response2.content(); @@ -364,7 +371,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { assertThat(aiMessage2.toolExecutionRequests()).isNull(); TokenUsage tokenUsage2 = response2.tokenUsage(); - assertThat(tokenUsage2.inputTokenCount()).isEqualTo(78); + assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.totalTokenCount()) .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); @@ -373,7 +380,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { } @Test - void should_return_valid_json_object(){ + void should_return_valid_json_object_using_model_large() { // given String userMessage = "Return JSON with two fields: transactionId and status with the values T123 and paid."; @@ -397,7 +404,7 @@ void should_return_valid_json_object(){ } @Test - void should_execute_multiple_tools_then_answer(){ + void should_execute_multiple_tools_using_model_open8x22B_then_answer() { // given ToolSpecification retrievePaymentDate = ToolSpecification.builder() .name("retrieve-payment-date") @@ -409,7 +416,7 @@ void should_execute_multiple_tools_then_answer(){ UserMessage userMessage = userMessage("What is the status and the payment date of transaction T123?"); chatMessages.add(userMessage); - List toolSpecifications = asList(retrievePaymentStatus,retrievePaymentDate); + List toolSpecifications = asList(retrievePaymentStatus, retrievePaymentDate); // when Response response = mistralLargeModel.generate(chatMessages, toolSpecifications); @@ -448,7 +455,7 @@ void should_execute_multiple_tools_then_answer(){ assertThat(aiMessage2.toolExecutionRequests()).isNull(); TokenUsage tokenUsage2 = response2.tokenUsage(); - assertThat(tokenUsage2.inputTokenCount()).isEqualTo(128); + assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.totalTokenCount()) .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java index 3d18e5675c..22c75cdd3c 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java @@ -1,5 +1,6 @@ package dev.langchain4j.model.mistralai; +import dev.langchain4j.model.mistralai.internal.api.MistralAiModelCard; import dev.langchain4j.model.output.Response; import org.junit.jupiter.api.Test; diff --git 
a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java index ed1ad16ff0..3c11a15df4 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java @@ -8,6 +8,7 @@ import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.TestStreamingResponseHandler; +import dev.langchain4j.model.mistralai.internal.api.MistralAiResponseFormatType; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import org.junit.jupiter.api.Test; @@ -18,7 +19,6 @@ import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING; import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.model.output.FinishReason.*; -import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; @@ -39,11 +39,19 @@ class MistralAiStreamingChatModelIT { .logResponses(true) .build(); - StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() + StreamingChatLanguageModel defaultModel = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) .temperature(0.1) + .logRequests(true) .logResponses(true) + .build(); + + StreamingChatLanguageModel openMixtral8x22BModel = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(MistralAiChatModelName.OPEN_MIXTRAL_8X22B) + .temperature(0.1) .logRequests(true) + .logResponses(true) .build(); @Test @@ -54,7 +62,7 @@ void should_stream_answer_and_return_token_usage_and_finish_reason_stop() { // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - model.generate(singletonList(userMessage), handler); + defaultModel.generate(singletonList(userMessage), handler); Response response = handler.get(); @@ -62,7 +70,7 @@ void should_stream_answer_and_return_token_usage_and_finish_reason_stop() { assertThat(response.content().text()).containsIgnoringCase("Lima"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -91,8 +99,8 @@ void should_stream_answer_and_return_token_usage_and_finish_reason_length() { assertThat(response.content().text()).containsIgnoringCase("Lima"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); - assertThat(tokenUsage.outputTokenCount()).isEqualTo(10); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -107,6 +115,7 @@ void should_stream_answer_and_system_prompt_to_enforce_guardrails() { StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) 
.safePrompt(true) + .temperature(0.0) .build(); // given @@ -121,7 +130,7 @@ void should_stream_answer_and_system_prompt_to_enforce_guardrails() { assertThat(response.content().text()).containsIgnoringCase("respect"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isGreaterThan(50); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -139,7 +148,7 @@ void should_stream_answer_and_return_token_usage_and_finish_reason_stop_with_mul // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - model.generate(asList(userMessage1, userMessage2, userMessage3), handler); + defaultModel.generate(asList(userMessage1, userMessage2, userMessage3), handler); Response response = handler.get(); // then @@ -148,7 +157,7 @@ void should_stream_answer_and_return_token_usage_and_finish_reason_stop_with_mul assertThat(response.content().text()).containsIgnoringCase("ottawa"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(11 + 11 + 11); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -177,8 +186,8 @@ void should_stream_answer_in_french_using_model_small_and_return_token_usage_and assertThat(response.content().text()).containsIgnoringCase("lima"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(13); - assertThat(tokenUsage.outputTokenCount()).isGreaterThan(1); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -206,8 +215,8 @@ void should_stream_answer_in_spanish_using_model_small_and_return_token_usage_an assertThat(response.content().text()).containsIgnoringCase("lima"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(14); - assertThat(tokenUsage.outputTokenCount()).isGreaterThan(1); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -235,8 +244,8 @@ void should_stream_answer_using_model_medium_and_return_token_usage_and_finish_r assertThat(response.content().text()).containsIgnoringCase("lima"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); - assertThat(tokenUsage.outputTokenCount()).isEqualTo(10); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -244,7 +253,7 @@ void should_stream_answer_using_model_medium_and_return_token_usage_and_finish_r } @Test - void should_execute_tool_and_return_finishReason_tool_execution(){ + void should_execute_tool_using_model_open8x22B_and_return_finishReason_tool_execution() { // given UserMessage userMessage = userMessage("What is the status of transaction T123?"); 
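
A minimal sketch of the streaming tool-call flow these MistralAiStreamingChatModelIT changes exercise, using only the builder options, the OPEN_MIXTRAL_8X22B model name, and the TestStreamingResponseHandler helper that appear in this diff. The retrievePaymentStatus definition below is an illustrative stand-in for the tool declared earlier in the test class (its real description is not shown here).

    import dev.langchain4j.agent.tool.ToolSpecification;
    import dev.langchain4j.data.message.AiMessage;
    import dev.langchain4j.model.chat.StreamingChatLanguageModel;
    import dev.langchain4j.model.chat.TestStreamingResponseHandler;
    import dev.langchain4j.model.mistralai.MistralAiChatModelName;
    import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel;
    import dev.langchain4j.model.output.Response;

    import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING;
    import static dev.langchain4j.data.message.UserMessage.userMessage;
    import static java.util.Collections.singletonList;

    public class MixtralToolStreamingSketch {

        public static void main(String[] args) {
            // Same builder setup as the openMixtral8x22BModel field in the IT.
            StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder()
                    .apiKey(System.getenv("MISTRAL_AI_API_KEY"))
                    .modelName(MistralAiChatModelName.OPEN_MIXTRAL_8X22B)
                    .temperature(0.1)
                    .build();

            // Illustrative tool definition; the original retrievePaymentStatus is defined
            // earlier in the test class and is not reproduced in this diff.
            ToolSpecification retrievePaymentStatus = ToolSpecification.builder()
                    .name("retrieve-payment-status")
                    .description("Retrieves the payment status of a transaction")
                    .addParameter("transactionId", STRING)
                    .build();

            // Stream the response and collect it with the test helper used in the ITs.
            TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
            model.generate(
                    singletonList(userMessage("What is the status of transaction T123?")),
                    singletonList(retrievePaymentStatus),
                    handler);

            Response<AiMessage> response = handler.get();
            System.out.println(response.content().toolExecutionRequests());
        }
    }
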
@@ -252,7 +261,7 @@ void should_execute_tool_and_return_finishReason_tool_execution(){ // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - mistralLargeStreamingModel.generate(singletonList(userMessage), toolSpecifications,handler); + openMixtral8x22BModel.generate(singletonList(userMessage), toolSpecifications, handler); Response response = handler.get(); @@ -266,8 +275,8 @@ void should_execute_tool_and_return_finishReason_tool_execution(){ assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"transactionId\":\"T123\"}"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(78); - assertThat(tokenUsage.outputTokenCount()).isEqualTo(28); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -275,7 +284,7 @@ void should_execute_tool_and_return_finishReason_tool_execution(){ } @Test - void should_execute_tool_when_toolChoice_is_auto_and_answer(){ + void should_execute_tool_using_model_open8x22B_when_toolChoice_is_auto_and_answer() { // given ToolSpecification retrievePaymentDate = ToolSpecification.builder() .name("retrieve-payment-date") @@ -287,11 +296,11 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){ UserMessage userMessage = userMessage("What is the status of transaction T123?"); chatMessages.add(userMessage); - List toolSpecifications = asList(retrievePaymentStatus,retrievePaymentDate); + List toolSpecifications = asList(retrievePaymentStatus, retrievePaymentDate); // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - mistralLargeStreamingModel.generate(chatMessages, toolSpecifications, handler); + openMixtral8x22BModel.generate(chatMessages, toolSpecifications, handler); Response response = handler.get(); // then @@ -313,7 +322,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){ // when TestStreamingResponseHandler handler2 = new TestStreamingResponseHandler<>(); - mistralLargeStreamingModel.generate(chatMessages, handler2); + openMixtral8x22BModel.generate(chatMessages, handler2); Response response2 = handler2.get(); // then @@ -323,7 +332,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){ assertThat(aiMessage2.toolExecutionRequests()).isNull(); TokenUsage tokenUsage2 = response2.tokenUsage(); - assertThat(tokenUsage2.inputTokenCount()).isEqualTo(69); + assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.totalTokenCount()) .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); @@ -332,7 +341,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){ } @Test - void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { + void should_execute_tool_forcefully_using_model_open8x22B_when_toolChoice_is_any_and_answer() { // given ToolSpecification retrievePaymentDate = ToolSpecification.builder() .name("retrieve-payment-date") @@ -346,7 +355,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - mistralLargeStreamingModel.generate(singletonList(userMessage), retrievePaymentDate, handler); + openMixtral8x22BModel.generate(singletonList(userMessage), retrievePaymentDate, handler); 
Response response = handler.get(); // then @@ -359,7 +368,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"transactionId\":\"T123\"}"); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(79); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -373,7 +382,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { // when TestStreamingResponseHandler handler2 = new TestStreamingResponseHandler<>(); - mistralLargeStreamingModel.generate(chatMessages, handler2); + openMixtral8x22BModel.generate(chatMessages, handler2); Response response2 = handler2.get(); // then @@ -383,7 +392,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { assertThat(aiMessage2.toolExecutionRequests()).isNull(); TokenUsage tokenUsage2 = response2.tokenUsage(); - assertThat(tokenUsage2.inputTokenCount()).isEqualTo(78); + assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.totalTokenCount()) .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); @@ -392,7 +401,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() { } @Test - void should_execute_multiple_tools_then_answer(){ + void should_execute_multiple_tools_using_model_large_then_answer() { // given ToolSpecification retrievePaymentDate = ToolSpecification.builder() .name("retrieve-payment-date") @@ -404,7 +413,7 @@ void should_execute_multiple_tools_then_answer(){ UserMessage userMessage = userMessage("What is the status and the payment date of transaction T123?"); chatMessages.add(userMessage); - List toolSpecifications = asList(retrievePaymentStatus,retrievePaymentDate); + List toolSpecifications = asList(retrievePaymentStatus, retrievePaymentDate); // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); @@ -447,7 +456,7 @@ void should_execute_multiple_tools_then_answer(){ assertThat(aiMessage2.toolExecutionRequests()).isNull(); TokenUsage tokenUsage2 = response2.tokenUsage(); - assertThat(tokenUsage2.inputTokenCount()).isEqualTo(128); + assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage2.totalTokenCount()) .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); @@ -456,7 +465,7 @@ void should_execute_multiple_tools_then_answer(){ } @Test - void should_return_valid_json_object(){ + void should_return_valid_json_object_using_model_large() { // given String userMessage = "Return JSON with two fields: transactionId and status with the values T123 and paid."; @@ -480,4 +489,23 @@ void should_return_valid_json_object(){ // then assertThat(json).isEqualToIgnoringWhitespace(expectedJson); } + + @Test + void bugfix_1218_allow_blank() { + // given + StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(MistralAiChatModelName.MISTRAL_SMALL_LATEST) + .temperature(0d) + .build(); + + String userMessage = "What was inflation rate in germany in 2020? Answer in 1 short sentence. 
Begin your answer with 'In 2020, ...'"; + + // when + TestStreamingResponseHandler responseHandler = new TestStreamingResponseHandler<>(); + model.generate(userMessage, responseHandler); + + // results in: "In2020, Germany's inflation rate was0.5%." + assertThat(responseHandler.get().content().text()).containsIgnoringCase("In 2020"); + } } diff --git a/langchain4j-mongodb-atlas/pom.xml b/langchain4j-mongodb-atlas/pom.xml index 861fb3d465..927a2b79c5 100644 --- a/langchain4j-mongodb-atlas/pom.xml +++ b/langchain4j-mongodb-atlas/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreCloudIT.java b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreCloudIT.java index 44dc3cc1b0..b4a389f2bd 100644 --- a/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreCloudIT.java +++ b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreCloudIT.java @@ -82,6 +82,6 @@ protected void clearStore() { @Override @SneakyThrows protected void awaitUntilPersisted() { - Thread.sleep(2000); + Thread.sleep(3000); } } diff --git a/langchain4j-neo4j/pom.xml b/langchain4j-neo4j/pom.xml index b5c3f17278..fe15738e30 100644 --- a/langchain4j-neo4j/pom.xml +++ b/langchain4j-neo4j/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-nomic/pom.xml b/langchain4j-nomic/pom.xml index 8b6ae8ec1a..3b37ee6747 100644 --- a/langchain4j-nomic/pom.xml +++ b/langchain4j-nomic/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-ollama/pom.xml b/langchain4j-ollama/pom.xml index 730e853e44..6df6f8b4db 100644 --- a/langchain4j-ollama/pom.xml +++ b/langchain4j-ollama/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java index eed53baffe..6bba995057 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaChatModel.java @@ -10,6 +10,7 @@ import java.time.Duration; import java.util.List; +import java.util.Map; import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; @@ -45,10 +46,16 @@ public OllamaChatModel(String baseUrl, List stop, String format, Duration timeout, - Integer maxRetries) { + Integer maxRetries, + Map customHeaders, + Boolean logRequests, + Boolean logResponses) { this.client = OllamaClient.builder() .baseUrl(baseUrl) .timeout(getOrDefault(timeout, ofSeconds(60))) + .customHeaders(customHeaders) + .logRequests(getOrDefault(logRequests, false)) + .logResponses(logResponses) .build(); this.modelName = ensureNotBlank(modelName, "modelName"); this.options = Options.builder() diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java index 963af6d458..14b4999d16 100644 --- 
a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaClient.java @@ -7,8 +7,12 @@ import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import lombok.Builder; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Interceptor; import okhttp3.OkHttpClient; +import okhttp3.Request; import okhttp3.ResponseBody; +import org.jetbrains.annotations.NotNull; import retrofit2.Call; import retrofit2.Callback; import retrofit2.Retrofit; @@ -19,10 +23,14 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES; import static java.lang.Boolean.TRUE; +@Slf4j class OllamaClient { private static final Gson GSON = new GsonBuilder() @@ -30,16 +38,31 @@ class OllamaClient { .create(); private final OllamaApi ollamaApi; + private final boolean logStreamingResponses; @Builder - public OllamaClient(String baseUrl, Duration timeout) { - - OkHttpClient okHttpClient = new OkHttpClient.Builder() + public OllamaClient(String baseUrl, + Duration timeout, + Boolean logRequests, Boolean logResponses, Boolean logStreamingResponses, + Map customHeaders) { + OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder() .callTimeout(timeout) .connectTimeout(timeout) .readTimeout(timeout) - .writeTimeout(timeout) - .build(); + .writeTimeout(timeout); + if (logRequests != null && logRequests) { + okHttpClientBuilder.addInterceptor(new OllamaRequestLoggingInterceptor()); + } + if (logResponses != null && logResponses) { + okHttpClientBuilder.addInterceptor(new OllamaResponseLoggingInterceptor()); + } + this.logStreamingResponses = logStreamingResponses != null && logStreamingResponses; + + // add custom header interceptor + if (customHeaders != null && !customHeaders.isEmpty()) { + okHttpClientBuilder.addInterceptor(new GenericHeadersInterceptor(customHeaders)); + } + OkHttpClient okHttpClient = okHttpClientBuilder.build(); Retrofit retrofit = new Retrofit.Builder() .baseUrl(baseUrl) @@ -91,8 +114,10 @@ public void onResponse(Call call, retrofit2.Response byte[] bytes = new byte[1024]; int len = inputStream.read(bytes); String partialResponse = new String(bytes, 0, len); + if (logStreamingResponses) { + log.debug("Streaming partial response: {}", partialResponse); + } CompletionResponse completionResponse = GSON.fromJson(partialResponse, CompletionResponse.class); - contentBuilder.append(completionResponse.getResponse()); handler.onNext(completionResponse.getResponse()); @@ -130,6 +155,11 @@ public void onResponse(Call call, retrofit2.Response StringBuilder contentBuilder = new StringBuilder(); while (true) { String partialResponse = reader.readLine(); + + if (logStreamingResponses) { + log.debug("Streaming partial response: {}", partialResponse); + } + ChatResponse chatResponse = GSON.fromJson(partialResponse, ChatResponse.class); String content = chatResponse.getMessage().getContent(); @@ -207,4 +237,25 @@ private RuntimeException toException(retrofit2.Response response) throws IOEx String errorMessage = String.format("status code: %s; body: %s", code, body); return new RuntimeException(errorMessage); } + + static class GenericHeadersInterceptor implements Interceptor { + + private final Map headers = new HashMap<>(); + + GenericHeadersInterceptor(Map headers) { + Optional.ofNullable(headers) + 
.ifPresent(this.headers::putAll); + } + + @NotNull + @Override + public okhttp3.Response intercept(Chain chain) throws IOException { + Request.Builder builder = chain.request().newBuilder(); + + // Add headers + this.headers.forEach(builder::addHeader); + + return chain.proceed(builder.build()); + } + } } diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java index fdf6c235b7..d00d63a1a2 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaEmbeddingModel.java @@ -10,6 +10,7 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; +import java.util.Map; import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; @@ -30,10 +31,16 @@ public class OllamaEmbeddingModel implements EmbeddingModel { public OllamaEmbeddingModel(String baseUrl, String modelName, Duration timeout, - Integer maxRetries) { + Integer maxRetries, + Boolean logRequests, + Boolean logResponses, + Map customHeaders) { this.client = OllamaClient.builder() .baseUrl(baseUrl) .timeout(getOrDefault(timeout, ofSeconds(60))) + .logRequests(logRequests) + .logResponses(logResponses) + .customHeaders(customHeaders) .build(); this.modelName = ensureNotBlank(modelName, "modelName"); this.maxRetries = getOrDefault(maxRetries, 3); diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java index 50668c7591..be8bb3c8ed 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaLanguageModel.java @@ -8,6 +8,7 @@ import java.time.Duration; import java.util.List; +import java.util.Map; import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; @@ -41,10 +42,17 @@ public OllamaLanguageModel(String baseUrl, List stop, String format, Duration timeout, - Integer maxRetries) { + Integer maxRetries, + Boolean logRequests, + Boolean logResponses, + Map customHeaders + ) { this.client = OllamaClient.builder() .baseUrl(baseUrl) .timeout(getOrDefault(timeout, ofSeconds(60))) + .logRequests(logRequests) + .logResponses(logResponses) + .customHeaders(customHeaders) .build(); this.modelName = ensureNotBlank(modelName, "modelName"); this.options = Options.builder() diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaModels.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaModels.java index 887d60a1d9..d8592b922c 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaModels.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaModels.java @@ -17,10 +17,15 @@ public class OllamaModels { @Builder public OllamaModels(String baseUrl, Duration timeout, - Integer maxRetries) { + Integer maxRetries, + Boolean logRequests, + Boolean logResponses + ) { this.client = OllamaClient.builder() .baseUrl(baseUrl) .timeout((getOrDefault(timeout, Duration.ofSeconds(60)))) + .logRequests(logRequests) + .logResponses(logResponses) .build(); this.maxRetries = getOrDefault(maxRetries, 3); } diff --git 
a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaRequestLoggingInterceptor.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaRequestLoggingInterceptor.java new file mode 100644 index 0000000000..50b3ea7bf3 --- /dev/null +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaRequestLoggingInterceptor.java @@ -0,0 +1,79 @@ +package dev.langchain4j.model.ollama; + +import lombok.extern.slf4j.Slf4j; +import okhttp3.Headers; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import okio.Buffer; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.StreamSupport; + +import static dev.langchain4j.internal.Utils.isNullOrBlank; +import static java.util.Arrays.asList; +import static java.util.stream.Collectors.joining; + +@Slf4j +class OllamaRequestLoggingInterceptor implements Interceptor { + + private static final Set COMMON_SECRET_HEADERS = + new HashSet<>(asList("authorization", "x-api-key", "x-auth-token")); + + @Override + public Response intercept(Chain chain) throws IOException { + Request request = chain.request(); + this.log(request); + return chain.proceed(request); + } + + private void log(Request request) { + try { + log.debug("Request:\n- method: {}\n- url: {}\n- headers: {}\n- body: {}", + request.method(), request.url(), getHeaders(request.headers()), getBody(request)); + } catch (Exception e) { + log.warn("Error while logging request: {}", e.getMessage()); + } + } + + private static String getBody(Request request) { + try { + Buffer buffer = new Buffer(); + if (request.body() == null) { + return ""; + } + request.body().writeTo(buffer); + return buffer.readUtf8(); + } catch (Exception e) { + log.warn("Exception while getting body", e); + return "Exception while getting body: " + e.getMessage(); + } + } + + private static String getHeaders(Headers headers) { + return StreamSupport.stream(headers.spliterator(), false) + .map(header -> formatHeader(header.component1(), header.component2())) + .collect(joining(", ")); + } + + private static String formatHeader(String headerKey, String headerValue) { + if (COMMON_SECRET_HEADERS.contains(headerKey.toLowerCase())) { + headerValue = maskSecretKey(headerValue); + } + return String.format("[%s: %s]", headerKey, headerValue); + } + + private static String maskSecretKey(String key) { + if (isNullOrBlank(key)) { + return key; + } + + if (key.length() >= 7) { + return key.substring(0, 5) + "..." 
+ key.substring(key.length() - 2); + } else { + return "..."; // to short to be masked + } + } +} diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaResponseLoggingInterceptor.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaResponseLoggingInterceptor.java new file mode 100644 index 0000000000..a053b6b2bb --- /dev/null +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaResponseLoggingInterceptor.java @@ -0,0 +1,40 @@ +package dev.langchain4j.model.ollama; + +import lombok.extern.slf4j.Slf4j; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; + +import java.io.IOException; + +@Slf4j +class OllamaResponseLoggingInterceptor implements Interceptor { + + @Override + public Response intercept(Chain chain) throws IOException { + Request request = chain.request(); + Response response = chain.proceed(request); + this.log(response); + return response; + } + + private void log(Response response) { + try { + log.debug("Response:\n- status code: {}\n- headers: {}\n- body: {}", + response.code(), response.headers(), this.getBody(response)); + } catch (Exception e) { + log.warn("Error while logging response: {}", e.getMessage()); + } + } + + private String getBody(Response response) throws IOException { + return isEventStream(response) + ? "[skipping response body due to streaming]" + : response.peekBody(Long.MAX_VALUE).string(); + } + + private static boolean isEventStream(Response response) { + String contentType = response.header("Content-Type"); + return contentType != null && contentType.contains("event-stream"); + } +} diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java index ab5f9acb70..f99da7b02f 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingChatModel.java @@ -9,6 +9,7 @@ import java.time.Duration; import java.util.List; +import java.util.Map; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; @@ -41,10 +42,17 @@ public OllamaStreamingChatModel(String baseUrl, Integer numCtx, List stop, String format, - Duration timeout) { + Duration timeout, + Boolean logRequests, + Boolean logResponses, + Map customHeaders + ) { this.client = OllamaClient.builder() .baseUrl(baseUrl) .timeout(getOrDefault(timeout, ofSeconds(60))) + .logRequests(logRequests) + .logStreamingResponses(logResponses) + .customHeaders(customHeaders) .build(); this.modelName = ensureNotBlank(modelName, "modelName"); this.options = Options.builder() diff --git a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java index 972aba876e..c2181b9a59 100644 --- a/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java +++ b/langchain4j-ollama/src/main/java/dev/langchain4j/model/ollama/OllamaStreamingLanguageModel.java @@ -7,6 +7,7 @@ import java.time.Duration; import java.util.List; +import java.util.Map; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; @@ -37,10 +38,17 @@ public OllamaStreamingLanguageModel(String 
baseUrl, Integer numCtx, List stop, String format, - Duration timeout) { + Duration timeout, + Boolean logRequests, + Boolean logResponses, + Map customHeaders + ) { this.client = OllamaClient.builder() .baseUrl(baseUrl) .timeout(getOrDefault(timeout, ofSeconds(60))) + .logRequests(logRequests) + .logStreamingResponses(logResponses) + .customHeaders(customHeaders) .build(); this.modelName = ensureNotBlank(modelName, "modelName"); this.options = Options.builder() diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java index 748e0ed7fd..d633d552ce 100644 --- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java +++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaChatModelIT.java @@ -21,6 +21,8 @@ class OllamaChatModelIT extends AbstractOllamaLanguageModelInfrastructure { .baseUrl(ollama.getEndpoint()) .modelName(TINY_DOLPHIN_MODEL) .temperature(0.0) + .logRequests(true) + .logResponses(true) .build(); @Test @@ -39,7 +41,7 @@ void should_generate_response() { assertThat(aiMessage.toolExecutionRequests()).isNull(); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(13); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaOpenAiChatModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaOpenAiChatModelIT.java index e8e6415be9..233a66e842 100644 --- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaOpenAiChatModelIT.java +++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaOpenAiChatModelIT.java @@ -43,7 +43,7 @@ void should_generate_response() { assertThat(aiMessage.toolExecutionRequests()).isNull(); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(35); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); diff --git a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelIT.java b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelIT.java index a49b57f459..338dac9b18 100644 --- a/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelIT.java +++ b/langchain4j-ollama/src/test/java/dev/langchain4j/model/ollama/OllamaStreamingChatModelIT.java @@ -25,6 +25,8 @@ class OllamaStreamingChatModelIT extends AbstractOllamaLanguageModelInfrastructu .baseUrl(ollama.getEndpoint()) .modelName(TINY_DOLPHIN_MODEL) .temperature(0.0) + .logRequests(true) + .logResponses(true) .build(); @Test diff --git a/langchain4j-open-ai/pom.xml b/langchain4j-open-ai/pom.xml index b5e9ff1bce..2283521e6c 100644 --- a/langchain4j-open-ai/pom.xml +++ b/langchain4j-open-ai/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/InternalOpenAiHelper.java 
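
For the Ollama changes above, a hedged sketch of how the new logging and custom-header options could be combined on OllamaChatModel. The logRequests, logResponses, and customHeaders builder methods come from the constructor parameters added in this diff; the base URL, model name, and header values are placeholders, not values taken from the diff.

    import dev.langchain4j.model.chat.ChatLanguageModel;
    import dev.langchain4j.model.ollama.OllamaChatModel;

    import static java.util.Collections.singletonMap;

    public class OllamaLoggingSketch {

        public static void main(String[] args) {
            ChatLanguageModel model = OllamaChatModel.builder()
                    .baseUrl("http://localhost:11434")               // assumed local Ollama endpoint
                    .modelName("tinydolphin")                        // any locally pulled model
                    .temperature(0.0)
                    .logRequests(true)                               // request logging interceptor
                    .logResponses(true)                              // response logging interceptor
                    .customHeaders(singletonMap("X-Tenant", "demo")) // sent via GenericHeadersInterceptor
                    .build();

            System.out.println(model.generate("Say hello in one word."));
        }
    }
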
b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/InternalOpenAiHelper.java index b2dd71e667..483c390aa9 100644 --- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/InternalOpenAiHelper.java +++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/InternalOpenAiHelper.java @@ -10,6 +10,8 @@ import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.data.message.*; +import dev.langchain4j.model.chat.listener.ChatModelRequest; +import dev.langchain4j.model.chat.listener.ChatModelResponse; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; @@ -280,4 +282,33 @@ static boolean isOpenAiModel(String modelName) { static Response removeTokenUsage(Response response) { return Response.from(response.content(), null, response.finishReason()); } + + static ChatModelRequest createModelListenerRequest(ChatCompletionRequest request, + List messages, + List toolSpecifications) { + return ChatModelRequest.builder() + .model(request.model()) + .temperature(request.temperature()) + .topP(request.topP()) + .maxTokens(request.maxTokens()) + .messages(messages) + .toolSpecifications(toolSpecifications) + .build(); + } + + static ChatModelResponse createModelListenerResponse(String responseId, + String responseModel, + Response response) { + if (response == null) { + return null; + } + + return ChatModelResponse.builder() + .id(responseId) + .model(responseModel) + .tokenUsage(response.tokenUsage()) + .finishReason(response.finishReason()) + .aiMessage(response.content()) + .build(); + } } diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java index 7eb30e1de2..20ee785449 100644 --- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java +++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java @@ -1,6 +1,7 @@ package dev.langchain4j.model.openai; import dev.ai4j.openai4j.OpenAiClient; +import dev.ai4j.openai4j.OpenAiHttpException; import dev.ai4j.openai4j.chat.ChatCompletionRequest; import dev.ai4j.openai4j.chat.ChatCompletionResponse; import dev.langchain4j.agent.tool.ToolSpecification; @@ -9,15 +10,18 @@ import dev.langchain4j.model.Tokenizer; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.TokenCountEstimator; +import dev.langchain4j.model.chat.listener.*; import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory; import dev.langchain4j.model.output.Response; import lombok.Builder; +import lombok.extern.slf4j.Slf4j; import java.net.Proxy; import java.time.Duration; -import java.util.HashMap; +import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; @@ -25,12 +29,14 @@ import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.time.Duration.ofSeconds; +import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; /** * Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4. * You can find description of parameters here. 
*/ +@Slf4j public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { private final OpenAiClient client; @@ -47,6 +53,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { private final String user; private final Integer maxRetries; private final Tokenizer tokenizer; + private final List listeners; @Builder public OpenAiChatModel(String baseUrl, @@ -69,7 +76,8 @@ public OpenAiChatModel(String baseUrl, Boolean logRequests, Boolean logResponses, Tokenizer tokenizer, - Map customHeaders) { + Map customHeaders, + List listeners) { baseUrl = getOrDefault(baseUrl, OPENAI_URL); if (OPENAI_DEMO_API_KEY.equals(apiKey)) { @@ -105,6 +113,7 @@ public OpenAiChatModel(String baseUrl, this.user = user; this.maxRetries = getOrDefault(maxRetries, 3); this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new); + this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners); } public String modelName() { @@ -153,13 +162,71 @@ private Response generate(List messages, ChatCompletionRequest request = requestBuilder.build(); - ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request).execute(), maxRetries); + ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications); + Map attributes = new ConcurrentHashMap<>(); + ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes); + listeners.forEach(listener -> { + try { + listener.onRequest(requestContext); + } catch (Exception e) { + log.warn("Exception while calling model listener", e); + } + }); - return Response.from( - aiMessageFrom(response), - tokenUsageFrom(response.usage()), - finishReasonFrom(response.choices().get(0).finishReason()) - ); + try { + ChatCompletionResponse chatCompletionResponse = withRetry(() -> client.chatCompletion(request).execute(), maxRetries); + + Response response = Response.from( + aiMessageFrom(chatCompletionResponse), + tokenUsageFrom(chatCompletionResponse.usage()), + finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason()) + ); + + ChatModelResponse modelListenerResponse = createModelListenerResponse( + chatCompletionResponse.id(), + chatCompletionResponse.model(), + response + ); + ChatModelResponseContext responseContext = new ChatModelResponseContext( + modelListenerResponse, + modelListenerRequest, + attributes + ); + listeners.forEach(listener -> { + try { + listener.onResponse(responseContext); + } catch (Exception e) { + log.warn("Exception while calling model listener", e); + } + }); + + return response; + } catch (RuntimeException e) { + + Throwable error; + if (e.getCause() instanceof OpenAiHttpException) { + error = e.getCause(); + } else { + error = e; + } + + ChatModelErrorContext errorContext = new ChatModelErrorContext( + error, + modelListenerRequest, + null, + attributes + ); + + listeners.forEach(listener -> { + try { + listener.onError(errorContext); + } catch (Exception e2) { + log.warn("Exception while calling model listener", e2); + } + }); + + throw e; + } } @Override diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiChatModelName.java b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiChatModelName.java index 73022c8ef8..d83cdb22a7 100644 --- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiChatModelName.java +++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiChatModelName.java @@ -22,7 +22,9 @@ public enum 
OpenAiChatModelName { GPT_4_32K_0314("gpt-4-32k-0314"), GPT_4_32K_0613("gpt-4-32k-0613"), - GPT_4_VISION_PREVIEW("gpt-4-vision-preview"); + GPT_4_VISION_PREVIEW("gpt-4-vision-preview"), + + GPT_4_O("gpt-4o"); private final String stringValue; diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java index 0de6e17bc9..af4a0ca7f6 100644 --- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java +++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java @@ -12,21 +12,26 @@ import dev.langchain4j.model.Tokenizer; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.TokenCountEstimator; +import dev.langchain4j.model.chat.listener.*; import dev.langchain4j.model.openai.spi.OpenAiStreamingChatModelBuilderFactory; import dev.langchain4j.model.output.Response; import lombok.Builder; +import lombok.extern.slf4j.Slf4j; import java.net.Proxy; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; -import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.internal.Utils.isNullOrEmpty; +import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.model.openai.InternalOpenAiHelper.*; import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.time.Duration.ofSeconds; +import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; /** @@ -34,6 +39,7 @@ * The model's response is streamed token by token and should be handled with {@link StreamingResponseHandler}. * You can find description of parameters here. */ +@Slf4j public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator { private final OpenAiClient client; @@ -50,6 +56,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok private final String user; private final Tokenizer tokenizer; private final boolean isOpenAiModel; + private final List listeners; @Builder public OpenAiStreamingChatModel(String baseUrl, @@ -71,7 +78,8 @@ public OpenAiStreamingChatModel(String baseUrl, Boolean logRequests, Boolean logResponses, Tokenizer tokenizer, - Map customHeaders) { + Map customHeaders, + List listeners) { timeout = getOrDefault(timeout, ofSeconds(60)); @@ -102,6 +110,7 @@ public OpenAiStreamingChatModel(String baseUrl, this.user = user; this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new); this.isOpenAiModel = isOpenAiModel(this.modelName); + this.listeners = listeners == null ? 
emptyList() : new ArrayList<>(listeners); } public String modelName() { @@ -152,25 +161,96 @@ private void generate(List messages, ChatCompletionRequest request = requestBuilder.build(); + ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications); + Map attributes = new ConcurrentHashMap<>(); + ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes); + listeners.forEach(listener -> { + try { + listener.onRequest(requestContext); + } catch (Exception e) { + log.warn("Exception while calling model listener", e); + } + }); + int inputTokenCount = countInputTokens(messages, toolSpecifications, toolThatMustBeExecuted); OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount); + AtomicReference responseId = new AtomicReference<>(); + AtomicReference responseModel = new AtomicReference<>(); + client.chatCompletion(request) .onPartialResponse(partialResponse -> { responseBuilder.append(partialResponse); handle(partialResponse, handler); + + if (!isNullOrBlank(partialResponse.id())) { + responseId.set(partialResponse.id()); + } + if (!isNullOrBlank(partialResponse.model())) { + responseModel.set(partialResponse.model()); + } }) .onComplete(() -> { - Response response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null); - if (!isOpenAiModel) { - response = removeTokenUsage(response); - } + Response response = createResponse(responseBuilder, toolThatMustBeExecuted); + + ChatModelResponse modelListenerResponse = createModelListenerResponse( + responseId.get(), + responseModel.get(), + response + ); + ChatModelResponseContext responseContext = new ChatModelResponseContext( + modelListenerResponse, + modelListenerRequest, + attributes + ); + listeners.forEach(listener -> { + try { + listener.onResponse(responseContext); + } catch (Exception e) { + log.warn("Exception while calling model listener", e); + } + }); + handler.onComplete(response); }) - .onError(handler::onError) + .onError(error -> { + Response response = createResponse(responseBuilder, toolThatMustBeExecuted); + + ChatModelResponse modelListenerPartialResponse = createModelListenerResponse( + responseId.get(), + responseModel.get(), + response + ); + + ChatModelErrorContext errorContext = new ChatModelErrorContext( + error, + modelListenerRequest, + modelListenerPartialResponse, + attributes + ); + + listeners.forEach(listener -> { + try { + listener.onError(errorContext); + } catch (Exception e) { + log.warn("Exception while calling model listener", e); + } + }); + + handler.onError(error); + }) .execute(); } + private Response createResponse(OpenAiStreamingResponseBuilder responseBuilder, + ToolSpecification toolThatMustBeExecuted) { + Response response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null); + if (isOpenAiModel) { + return response; + } + return removeTokenUsage(response); + } + private int countInputTokens(List messages, List toolSpecifications, ToolSpecification toolThatMustBeExecuted) { diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java index 95a3e12c2d..ae3f71e264 100644 --- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java +++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java @@ -1,28 +1,33 @@ package dev.langchain4j.model.openai; +import 
dev.ai4j.openai4j.OpenAiHttpException; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.*; import dev.langchain4j.model.Tokenizer; import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.listener.*; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import org.junit.jupiter.api.Test; import java.util.Base64; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; import static dev.langchain4j.data.message.ToolExecutionResultMessage.from; import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.internal.Utils.readBytes; import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO; +import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O; import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO_1106; -import static dev.langchain4j.model.openai.OpenAiModelName.GPT_4_VISION_PREVIEW; import static dev.langchain4j.model.output.FinishReason.*; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Fail.fail; +import static org.junit.jupiter.api.Assertions.assertThrows; class OpenAiChatModelIT { @@ -49,7 +54,7 @@ class OpenAiChatModelIT { .baseUrl(System.getenv("OPENAI_BASE_URL")) .apiKey(System.getenv("OPENAI_API_KEY")) .organizationId(System.getenv("OPENAI_ORGANIZATION_ID")) - .modelName(GPT_4_VISION_PREVIEW) + .modelName(GPT_4_O) .temperature(0.0) .logRequests(true) .logResponses(true) @@ -216,6 +221,8 @@ void should_execute_multiple_tools_in_parallel_then_answer() { .organizationId(System.getenv("OPENAI_ORGANIZATION_ID")) .modelName(GPT_3_5_TURBO_1106) // supports parallel function calling .temperature(0.0) + .logRequests(true) + .logResponses(true) .build(); UserMessage userMessage = userMessage("2+2=? 
3+3=?"); @@ -463,4 +470,130 @@ public int estimateTokenCountInToolExecutionRequests(Iterable requestReference = new AtomicReference<>(); + AtomicReference responseReference = new AtomicReference<>(); + + ChatModelListener listener = new ChatModelListener() { + + @Override + public void onRequest(ChatModelRequestContext requestContext) { + requestReference.set(requestContext.request()); + requestContext.attributes().put("id", "12345"); + } + + @Override + public void onResponse(ChatModelResponseContext responseContext) { + responseReference.set(responseContext.response()); + assertThat(responseContext.request()).isSameAs(requestReference.get()); + assertThat(responseContext.attributes().get("id")).isEqualTo("12345"); + } + + @Override + public void onError(ChatModelErrorContext errorContext) { + fail("onError() must not be called"); + } + }; + + OpenAiChatModelName modelName = GPT_3_5_TURBO; + double temperature = 0.7; + double topP = 1.0; + int maxTokens = 7; + + OpenAiChatModel model = OpenAiChatModel.builder() + .baseUrl(System.getenv("OPENAI_BASE_URL")) + .apiKey(System.getenv("OPENAI_API_KEY")) + .organizationId(System.getenv("OPENAI_ORGANIZATION_ID")) + .modelName(modelName) + .temperature(temperature) + .topP(topP) + .maxTokens(maxTokens) + .logRequests(true) + .logResponses(true) + .listeners(singletonList(listener)) + .build(); + + UserMessage userMessage = UserMessage.from("hello"); + + ToolSpecification toolSpecification = ToolSpecification.builder() + .name("add") + .addParameter("a", INTEGER) + .addParameter("b", INTEGER) + .build(); + + // when + AiMessage aiMessage = model.generate(singletonList(userMessage), singletonList(toolSpecification)).content(); + + // then + ChatModelRequest request = requestReference.get(); + assertThat(request.model()).isEqualTo(modelName.toString()); + assertThat(request.temperature()).isEqualTo(temperature); + assertThat(request.topP()).isEqualTo(topP); + assertThat(request.maxTokens()).isEqualTo(maxTokens); + assertThat(request.messages()).containsExactly(userMessage); + assertThat(request.toolSpecifications()).containsExactly(toolSpecification); + + ChatModelResponse response = responseReference.get(); + assertThat(response.id()).isNotBlank(); + assertThat(response.model()).isNotBlank(); + assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0); + assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0); + assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0); + assertThat(response.finishReason()).isNotNull(); + assertThat(response.aiMessage()).isEqualTo(aiMessage); + } + + @Test + void should_listen_error() { + + // given + String wrongApiKey = "banana"; + + AtomicReference requestReference = new AtomicReference<>(); + AtomicReference errorReference = new AtomicReference<>(); + + ChatModelListener listener = new ChatModelListener() { + + @Override + public void onRequest(ChatModelRequestContext requestContext) { + requestReference.set(requestContext.request()); + requestContext.attributes().put("id", "12345"); + } + + @Override + public void onResponse(ChatModelResponseContext responseContext) { + fail("onResponse() must not be called"); + } + + @Override + public void onError(ChatModelErrorContext errorContext) { + errorReference.set(errorContext.error()); + assertThat(errorContext.request()).isSameAs(requestReference.get()); + assertThat(errorContext.partialResponse()).isNull(); + assertThat(errorContext.attributes().get("id")).isEqualTo("12345"); + } + }; + + OpenAiChatModel model = 
OpenAiChatModel.builder() + .apiKey(wrongApiKey) + .maxRetries(0) + .logRequests(true) + .logResponses(true) + .listeners(singletonList(listener)) + .build(); + + String userMessage = "this message will fail"; + + // when + assertThrows(RuntimeException.class, () -> model.generate(userMessage)); + + // then + Throwable throwable = errorReference.get(); + assertThat(throwable).isExactlyInstanceOf(OpenAiHttpException.class); + assertThat(throwable).hasMessageContaining("Incorrect API key provided"); + } } diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelIT.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelIT.java index c6adb9378b..671d9dee69 100644 --- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelIT.java +++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelIT.java @@ -1,5 +1,6 @@ package dev.langchain4j.model.openai; +import dev.ai4j.openai4j.OpenAiHttpException; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.*; @@ -7,6 +8,7 @@ import dev.langchain4j.model.Tokenizer; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.TestStreamingResponseHandler; +import dev.langchain4j.model.chat.listener.*; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import org.assertj.core.data.Percentage; @@ -15,6 +17,7 @@ import java.util.Base64; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; import static dev.langchain4j.data.message.ToolExecutionResultMessage.from; @@ -31,6 +34,7 @@ import static java.util.Collections.singletonList; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Fail.fail; import static org.assertj.core.data.Percentage.withPercentage; class OpenAiStreamingChatModelIT { @@ -639,4 +643,152 @@ public int estimateTokenCountInToolExecutionRequests(Iterable requestReference = new AtomicReference<>(); + AtomicReference responseReference = new AtomicReference<>(); + + ChatModelListener listener = new ChatModelListener() { + + @Override + public void onRequest(ChatModelRequestContext requestContext) { + requestReference.set(requestContext.request()); + requestContext.attributes().put("id", "12345"); + } + + @Override + public void onResponse(ChatModelResponseContext responseContext) { + responseReference.set(responseContext.response()); + assertThat(responseContext.request()).isSameAs(requestReference.get()); + assertThat(responseContext.attributes().get("id")).isEqualTo("12345"); + } + + @Override + public void onError(ChatModelErrorContext errorContext) { + fail("onError() must not be called"); + } + }; + + OpenAiChatModelName modelName = GPT_3_5_TURBO; + double temperature = 0.7; + double topP = 1.0; + int maxTokens = 7; + + StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder() + .baseUrl(System.getenv("OPENAI_BASE_URL")) + .apiKey(System.getenv("OPENAI_API_KEY")) + .organizationId(System.getenv("OPENAI_ORGANIZATION_ID")) + .modelName(modelName) + .temperature(temperature) + .topP(topP) + .maxTokens(maxTokens) + .logRequests(true) + .logResponses(true) + 
.listeners(singletonList(listener)) + .build(); + + UserMessage userMessage = UserMessage.from("hello"); + + ToolSpecification toolSpecification = ToolSpecification.builder() + .name("add") + .addParameter("a", INTEGER) + .addParameter("b", INTEGER) + .build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(singletonList(userMessage), singletonList(toolSpecification), handler); + AiMessage aiMessage = handler.get().content(); + + // then + ChatModelRequest request = requestReference.get(); + assertThat(request.model()).isEqualTo(modelName.toString()); + assertThat(request.temperature()).isEqualTo(temperature); + assertThat(request.topP()).isEqualTo(topP); + assertThat(request.maxTokens()).isEqualTo(maxTokens); + assertThat(request.messages()).containsExactly(userMessage); + assertThat(request.toolSpecifications()).containsExactly(toolSpecification); + + ChatModelResponse response = responseReference.get(); + assertThat(response.id()).isNotBlank(); + assertThat(response.model()).isNotBlank(); + assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0); + assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0); + assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0); + assertThat(response.finishReason()).isNotNull(); + assertThat(response.aiMessage()).isEqualTo(aiMessage); + } + + @Test + void should_listen_error() throws Exception { + + // given + String wrongApiKey = "banana"; + + AtomicReference requestReference = new AtomicReference<>(); + AtomicReference errorReference = new AtomicReference<>(); + + ChatModelListener listener = new ChatModelListener() { + + @Override + public void onRequest(ChatModelRequestContext requestContext) { + requestReference.set(requestContext.request()); + requestContext.attributes().put("id", "12345"); + } + + @Override + public void onResponse(ChatModelResponseContext responseContext) { + fail("onResponse() must not be called"); + } + + @Override + public void onError(ChatModelErrorContext errorContext) { + errorReference.set(errorContext.error()); + assertThat(errorContext.request()).isSameAs(requestReference.get()); + assertThat(errorContext.partialResponse()).isNull(); // can be non-null if it fails in the middle of streaming + assertThat(errorContext.attributes().get("id")).isEqualTo("12345"); + } + }; + + StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder() + .apiKey(wrongApiKey) + .logRequests(true) + .logResponses(true) + .listeners(singletonList(listener)) + .build(); + + String userMessage = "this message will fail"; + + CompletableFuture future = new CompletableFuture<>(); + StreamingResponseHandler handler = new StreamingResponseHandler() { + + @Override + public void onNext(String token) { + fail("onNext() must not be called"); + } + + @Override + public void onError(Throwable error) { + future.complete(error); + } + + @Override + public void onComplete(Response response) { + fail("onComplete() must not be called"); + } + }; + + // when + model.generate(userMessage, handler); + Throwable throwable = future.get(5, SECONDS); + + // then + assertThat(throwable).isExactlyInstanceOf(OpenAiHttpException.class); + assertThat(throwable).hasMessageContaining("Incorrect API key provided"); + + assertThat(errorReference.get()).isSameAs(throwable); + } } \ No newline at end of file diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java 
b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java index dea171c833..c3edc418d7 100644 --- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java +++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java @@ -45,6 +45,11 @@ void should_count_tokens_in_short_texts() { assertThat(tokenizer.estimateTokenCountInText("Hello")).isEqualTo(1); assertThat(tokenizer.estimateTokenCountInText("Hello!")).isEqualTo(2); assertThat(tokenizer.estimateTokenCountInText("Hello, how are you?")).isEqualTo(6); + + assertThat(tokenizer.estimateTokenCountInText("")).isEqualTo(0); + assertThat(tokenizer.estimateTokenCountInText("\n")).isEqualTo(1); + assertThat(tokenizer.estimateTokenCountInText("\n\n")).isEqualTo(1); + assertThat(tokenizer.estimateTokenCountInText("\n \n\n")).isEqualTo(2); } @Test diff --git a/langchain4j-opensearch/pom.xml b/langchain4j-opensearch/pom.xml index e2de0a84ee..65f9640bd2 100644 --- a/langchain4j-opensearch/pom.xml +++ b/langchain4j-opensearch/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-parent/pom.xml b/langchain4j-parent/pom.xml index a5e67ae553..a513ed1e37 100644 --- a/langchain4j-parent/pom.xml +++ b/langchain4j-parent/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT pom LangChain4j :: Parent POM @@ -17,13 +17,13 @@ 8 8 UTF-8 - 1705675871 + 1714382357 0.17.0 - 1.0.0-beta.7 - 11.6.3 - 12.25.3 - 12.24.3 - 1.12.0 + 1.0.0-beta.8 + 11.6.5 + 12.26.1 + 12.25.1 + 1.12.2 2.9.0 4.12.0 1.0.0 @@ -47,6 +47,7 @@ 5.0.0 2.21.44 1.318 + 2.3.6 4.1.104.Final 4.2.0 2.9.0 @@ -118,6 +119,18 @@ ${azure.identity.version}
    + + dev.langchain4j + langchain4j-mistral-ai + ${project.version} + + + + dev.langchain4j + langchain4j-ollama + ${project.version} + + dev.langchain4j langchain4j-embeddings-all-minilm-l6-v2-q @@ -142,6 +155,12 @@ ${retrofit.version} + + com.squareup.retrofit2 + converter-jackson + ${retrofit.version} + + com.squareup.okhttp3 okhttp @@ -295,6 +314,18 @@ pom + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + + com.fasterxml.jackson.dataformat + jackson-dataformat-xml + ${jackson.version} + + org.apache.opennlp opennlp-tools @@ -340,6 +371,12 @@ ${github-api.version} + + io.milvus + milvus-sdk-java + ${milvus-sdk-java.version} + + org.infinispan infinispan-bom @@ -347,6 +384,13 @@ pom import + + + ch.qos.logback + logback-classic + 1.3.14 + + diff --git a/langchain4j-pgvector/pom.xml b/langchain4j-pgvector/pom.xml index 71122e5faf..0099951892 100644 --- a/langchain4j-pgvector/pom.xml +++ b/langchain4j-pgvector/pom.xml @@ -7,13 +7,13 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml - 0.1.3 - 42.7.2 + 0.1.4 + 42.7.3 2.40.0 @@ -86,6 +86,12 @@ test + + org.mockito + mockito-core + test + + org.testcontainers postgresql @@ -101,7 +107,6 @@ ch.qos.logback logback-classic - 1.3.14 test diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/ColumnFilterMapper.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/ColumnFilterMapper.java new file mode 100644 index 0000000000..a84b0e4312 --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/ColumnFilterMapper.java @@ -0,0 +1,13 @@ +package dev.langchain4j.store.embedding.pgvector; + +class ColumnFilterMapper extends PgVectorFilterMapper { + + String formatKey(String key, Class valueType) { + return String.format("%s::%s", key, SQL_TYPE_MAP.get(valueType)); + } + + String formatKeyAsString(String key) { + return key; + } + +} diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/ColumnsMetadataHandler.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/ColumnsMetadataHandler.java new file mode 100644 index 0000000000..18731e9244 --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/ColumnsMetadataHandler.java @@ -0,0 +1,108 @@ +package dev.langchain4j.store.embedding.pgvector; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.store.embedding.filter.Filter; + +import java.sql.*; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; + +/** + * Handle Metadata stored in independent columns + */ +class ColumnsMetadataHandler implements MetadataHandler { + + final List columnsDefinition; + final List columnsName; + final PgVectorFilterMapper filterMapper; + final List indexes; + final String indexType; + + /** + * MetadataHandler constructor + * @param config {@link MetadataStorageConfig} configuration + */ + public ColumnsMetadataHandler(MetadataStorageConfig config) { + List columnsDefinitionList = ensureNotEmpty(config.columnDefinitions(), "Metadata definition"); + this.columnsDefinition = columnsDefinitionList.stream() + .map(MetadataColumDefinition::from).collect(Collectors.toList()); + this.columnsName = 
columnsDefinition.stream() + .map(MetadataColumDefinition::getName).collect(Collectors.toList()); + this.filterMapper = new ColumnFilterMapper(); + this.indexes = getOrDefault(config.indexes(), Collections.emptyList()); + this.indexType = config.indexType(); + } + + @Override + public String columnDefinitionsString() { + return this.columnsDefinition.stream() + .map(MetadataColumDefinition::getFullDefinition).collect(Collectors.joining(",")); + } + + @Override + public List columnsNames() { + return this.columnsName; + } + + @Override + public void createMetadataIndexes(Statement statement, String table) { + String indexTypeSql = indexType == null ? "" : "USING " + indexType; + this.indexes.stream().map(String::trim) + .forEach(index -> { + String indexSql = String.format("create index if not exists %s_%s on %s %s ( %s )", + table, index, table, indexTypeSql, index); + try { + statement.executeUpdate(indexSql); + } catch (SQLException e) { + throw new RuntimeException(String.format("Cannot create indexes %s: %s", index, e)); + } + }); + } + + @Override + public String insertClause() { + return this.columnsName.stream().map(c -> String.format("%s = EXCLUDED.%s", c, c)) + .collect(Collectors.joining(",")); + } + + @Override + public void setMetadata(PreparedStatement upsertStmt, Integer parameterInitialIndex, Metadata metadata) { + int i = 0; + // only column names fields will be stored + for (String c : this.columnsName) { + try { + upsertStmt.setObject(parameterInitialIndex + i, metadata.get(c), Types.OTHER); + i++; + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + } + + @Override + public String whereClause(Filter filter) { + return filterMapper.map(filter); + } + + @Override + public Metadata fromResultSet(ResultSet resultSet) { + try { + Map metadataMap = new HashMap<>(); + for (String c : this.columnsName) { + if (resultSet.getObject(c) != null) { + metadataMap.put(c, resultSet.getObject(c)); + } + } + return new Metadata(metadataMap); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + +} diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/DefaultMetadataStorageConfig.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/DefaultMetadataStorageConfig.java new file mode 100644 index 0000000000..eb32b8d9a7 --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/DefaultMetadataStorageConfig.java @@ -0,0 +1,44 @@ +package dev.langchain4j.store.embedding.pgvector; + +import lombok.*; +import lombok.experimental.Accessors; + +import java.util.Collections; +import java.util.List; + +/** + * Metadata configuration implementation + */ +@Builder +@Getter +@Accessors(fluent = true) +@AllArgsConstructor +public class DefaultMetadataStorageConfig implements MetadataStorageConfig { + @NonNull + private MetadataStorageMode storageMode; + @NonNull + private List columnDefinitions; + private List indexes; + private String indexType; + + /** + * Just for warnings ? + */ + @SuppressWarnings("unused") + public DefaultMetadataStorageConfig(){ + // Just for javadoc warning ? 
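A hedged usage sketch (not part of this diff) of the Lombok `@Builder` generated for `DefaultMetadataStorageConfig` above, opting into column-per-key metadata storage; the column definitions and index names are invented for illustration:

```java
// Sketch only; assumes java.util.Arrays is imported and that "user_id"/"age"
// are metadata keys known in advance (COLUMN_PER_KEY requires a fixed set).
MetadataStorageConfig config = DefaultMetadataStorageConfig.builder()
        .storageMode(MetadataStorageMode.COLUMN_PER_KEY)
        .columnDefinitions(Arrays.asList("user_id uuid null", "age integer null"))
        .indexes(Arrays.asList("user_id", "age"))
        .build();
```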
+ } + + /** + * Default configuration + * + * @return Default configuration + */ + public static MetadataStorageConfig defaultConfig() { + return DefaultMetadataStorageConfig.builder() + .storageMode(MetadataStorageMode.COMBINED_JSON) + .columnDefinitions(Collections.singletonList("metadata JSON NULL")) + .build(); + } + +} diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/JSONBMetadataHandler.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/JSONBMetadataHandler.java new file mode 100644 index 0000000000..ade036c602 --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/JSONBMetadataHandler.java @@ -0,0 +1,52 @@ +package dev.langchain4j.store.embedding.pgvector; + +import java.sql.SQLException; +import java.sql.Statement; + +/** + * Handle metadata as JSONB column. + */ +class JSONBMetadataHandler extends JSONMetadataHandler { + + final String indexType; + + /** + * MetadataHandler constructor + * @param config {@link MetadataStorageConfig} configuration + */ + public JSONBMetadataHandler(MetadataStorageConfig config) { + super(config); + if (!this.columnDefinition.getType().equals("jsonb")) { + throw new RuntimeException("Your column definition type should be JSONB"); + } + indexType = config.indexType(); + } + + @Override + public void createMetadataIndexes(Statement statement, String table) { + String indexTypeSql = indexType == null ? "" : "USING " + indexType; + for (String str : this.indexes) { + String index = str.trim(); + String indexName = formatIndex(index); + try { + String indexSql = String.format("create index if not exists %s_%s on %s %s (%s)", + table, indexName, table, indexTypeSql, index); + statement.executeUpdate(indexSql); + } catch (SQLException e) { + throw new RuntimeException(String.format("Cannot create index %s: %s", index, e)); + } + } + } + + String formatIndex(String index) { + // (metadata_b->'name') + String indexName; + if (index.contains("->")) { + indexName = columnName + "_" + index.substring(index.indexOf("->") + 2, index.length() - 1) + .trim().replaceAll("'", ""); + } else { + indexName = index.replaceAll(" ", "_"); + } + return indexName; + } +} diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/JSONFilterMapper.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/JSONFilterMapper.java new file mode 100644 index 0000000000..a7d9d7eb6a --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/JSONFilterMapper.java @@ -0,0 +1,18 @@ +package dev.langchain4j.store.embedding.pgvector; + +class JSONFilterMapper extends PgVectorFilterMapper { + final String metadataColumn; + + public JSONFilterMapper(String metadataColumn) { + this.metadataColumn = metadataColumn; + } + + String formatKey(String key, Class valueType) { + return String.format("(%s->>'%s')::%s", metadataColumn, key, SQL_TYPE_MAP.get(valueType)); + } + + String formatKeyAsString(String key) { + return metadataColumn + "->>'" + key + "'"; + } + +} diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/JSONMetadataHandler.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/JSONMetadataHandler.java new file mode 100644 index 0000000000..5b12e2e189 --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/JSONMetadataHandler.java @@ -0,0 +1,85 @@ +package 
dev.langchain4j.store.embedding.pgvector; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.internal.Json; +import dev.langchain4j.store.embedding.filter.Filter; + +import java.sql.*; +import java.util.*; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.Utils.getOrDefault; + +/** + * Handle metadata as JSON column. + */ +class JSONMetadataHandler implements MetadataHandler { + + final MetadataColumDefinition columnDefinition; + final String columnName; + final JSONFilterMapper filterMapper; + final List indexes; + + /** + * MetadataHandler constructor + * @param config {@link MetadataStorageConfig} configuration + */ + public JSONMetadataHandler(MetadataStorageConfig config) { + List definition = ensureNotEmpty(config.columnDefinitions(), "Metadata definition"); + if (definition.size()>1) { + throw new IllegalArgumentException("Metadata definition should be an unique column definition, " + + "example: metadata JSON NULL"); + } + this.columnDefinition = MetadataColumDefinition.from(definition.get(0)); + this.columnName = this.columnDefinition.getName(); + this.filterMapper = new JSONFilterMapper(columnName); + this.indexes = getOrDefault(config.indexes(), Collections.emptyList()); + } + + @Override + public String columnDefinitionsString() { + return columnDefinition.getFullDefinition(); + } + + @Override + public List columnsNames() { + return Collections.singletonList(this.columnName); + } + + @Override + public void createMetadataIndexes(Statement statement, String table) { + if (!this.indexes.isEmpty()) { + throw new RuntimeException("Indexes are not allowed for JSON metadata, use JSONB instead"); + } + } + + @Override + public String whereClause(Filter filter) { + return filterMapper.map(filter); + } + + @Override + @SuppressWarnings("unchecked") + public Metadata fromResultSet(ResultSet resultSet) { + try { + String metadataJson = getOrDefault(resultSet.getString(columnsNames().get(0)),"{}"); + return new Metadata(Json.fromJson(metadataJson, Map.class)); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + @Override + public String insertClause() { + return String.format("%s = EXCLUDED.%s", this.columnName, this.columnName); + } + + @Override + public void setMetadata(PreparedStatement upsertStmt, Integer parameterInitialIndex, Metadata metadata) { + try { + upsertStmt.setObject(parameterInitialIndex, Json.toJson(metadata.asMap()), Types.OTHER); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } +} diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataColumDefinition.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataColumDefinition.java new file mode 100644 index 0000000000..1633ae6db0 --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataColumDefinition.java @@ -0,0 +1,42 @@ +package dev.langchain4j.store.embedding.pgvector; + +import dev.langchain4j.internal.ValidationUtils; +import lombok.Getter; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +/** + * MetadataColumDefinition used to define column definition from sql String + */ +@Getter +public class MetadataColumDefinition { + private final String fullDefinition; + private final String name; + private final String type; + + private MetadataColumDefinition(String fullDefinition, String name, String type) { + this.fullDefinition = 
fullDefinition; + this.name = name; + this.type = type; + } + + /** + * transform sql string to MetadataColumDefinition + * @param sqlDefinition sql definition string + * @return MetadataColumDefinition + */ + public static MetadataColumDefinition from(String sqlDefinition) { + String fullDefinition = ValidationUtils.ensureNotNull(sqlDefinition, "Metadata column definition"); + List tokens = Arrays.stream(fullDefinition.split(" ")) + .filter(s -> !s.isEmpty()).collect(Collectors.toList()); + if (tokens.size() < 2) { + throw new IllegalArgumentException("Definition format should be: column type" + + " [ NULL | NOT NULL ] [ UNIQUE ] [ DEFAULT value ]"); + } + String name = tokens.get(0); + String type = tokens.get(1).toLowerCase(); + return new MetadataColumDefinition(fullDefinition, name, type); + } +} diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataHandler.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataHandler.java new file mode 100644 index 0000000000..5f92bd3d63 --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataHandler.java @@ -0,0 +1,73 @@ +package dev.langchain4j.store.embedding.pgvector; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.store.embedding.filter.Filter; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.List; + +/** + * Handle PV Store metadata. + */ +interface MetadataHandler { + + /** + * String definition used to create the metadata field(s) in embeddings table + * + * @return the sql clause that creates metadata field(s) + * + */ + String columnDefinitionsString(); + + /** + * Setup indexes for metadata fields + * By default, no index is created. + * + * @param statement used to execute indexes creation. + * @param table table name. 
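As a quick illustration (not in the diff itself) of what `MetadataColumDefinition.from(...)` above extracts from a column definition string; the `"user_id uuid null"` input is hypothetical:

```java
// Parses "name type [constraints...]" by splitting on whitespace.
MetadataColumDefinition def = MetadataColumDefinition.from("user_id uuid null");
String full = def.getFullDefinition(); // "user_id uuid null"
String name = def.getName();           // "user_id"  (first token)
String type = def.getType();           // "uuid"     (second token, lower-cased)
```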
+ */ + void createMetadataIndexes(Statement statement, String table); + + /** + * Metadata columns name + * + * @return list of columns used as metadata + */ + List columnsNames(); + + /** + * Generate the SQL where clause following @{@link Filter} + * + * @param filter filter + * @return the sql where clause + */ + String whereClause(Filter filter); + + /** + * Extract Metadata from Resultset and Metadata definition + * + * @param resultSet resultSet + * @return metadata object + */ + Metadata fromResultSet(ResultSet resultSet); + + /** + * Generate the SQL insert clause following Metadata definition + * + * @return the sql insert clause + */ + String insertClause(); + + /** + * Set meta data values following metadata and metadata definition + * + * @param upsertStmt statement to set values + * @param parameterInitialIndex initial parameter index + * @param metadata metadata values + */ + void setMetadata(PreparedStatement upsertStmt, Integer parameterInitialIndex, Metadata metadata); + + +} diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataHandlerFactory.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataHandlerFactory.java new file mode 100644 index 0000000000..1efd3cead9 --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataHandlerFactory.java @@ -0,0 +1,29 @@ +package dev.langchain4j.store.embedding.pgvector; + +/** + * MetadataHandlerFactory class + * Use the {@link MetadataStorageConfig#storageMode()} to switch between different Handler implementation + */ +class MetadataHandlerFactory { + /** + * Default Constructor + */ + public MetadataHandlerFactory() {} + /** + * Retrieve the handler associated to the config + * @param config MetadataConfig config + * @return MetadataHandler + */ + static MetadataHandler get(MetadataStorageConfig config) { + switch(config.storageMode()) { + case COMBINED_JSON: + return new JSONMetadataHandler(config); + case COMBINED_JSONB: + return new JSONBMetadataHandler(config); + case COLUMN_PER_KEY: + return new ColumnsMetadataHandler(config); + default: + throw new RuntimeException(String.format("Type %s not handled.", config.storageMode())); + } + } +} diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataStorageConfig.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataStorageConfig.java new file mode 100644 index 0000000000..e185e92eb8 --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataStorageConfig.java @@ -0,0 +1,53 @@ +package dev.langchain4j.store.embedding.pgvector; + +import java.util.List; + +/** + * Metadata configuration. + */ +public interface MetadataStorageConfig { + /** + * Metadata storage mode + *
<ul>
+ * <li>COMBINED_JSON: for dynamic metadata, when you don't know in advance which metadata keys will be used.</li>
+ * <li>COMBINED_JSONB: same as JSON, but stored in binary form; optimized for querying large datasets.</li>
+ * <li>COLUMN_PER_KEY: for static metadata, when you know the list of metadata keys in advance.</li>
+ * </ul>
+ * @return Metadata storage mode
+ */
+ MetadataStorageMode storageMode();
+ /**
+ * SQL definition of the metadata field(s).
+ * Example:
+ * <ul>
+ * <li>COMBINED_JSON: Collections.singletonList("metadata JSON NULL")</li>
+ * <li>COMBINED_JSONB: Collections.singletonList("metadata JSONB NULL")</li>
+ * <li>COLUMN_PER_KEY: Arrays.asList("condominium_id uuid null", "user uuid null")</li>
+ * </ul>
+ * @return list of column definitions
+ */
+ List<String> columnDefinitions();
+ /**
+ * Metadata indexes, the list of fields to index.
+ * Example:
+ * <ul>
+ * <li>COMBINED_JSON: Collections.singletonList("metadata") or Arrays.asList("(metadata->'key')", "(metadata->'name')", "(metadata->'age')")</li>
+ * <li>COMBINED_JSONB: Collections.singletonList("metadata") or Arrays.asList("(metadata->'key')", "(metadata->'name')", "(metadata->'age')")</li>
+ * <li>COLUMN_PER_KEY: Arrays.asList("key", "name", "age")</li>
+ * </ul>
+ * @return list of metadata indexes
+ */
+ List<String> indexes();
+ /**
+ * Index type:
+ * <ul>
+ * <li>BTREE (default)</li>
+ * <li>GIN</li>
+ * <li>... other PostgreSQL index types</li>
+ * </ul>
+ * @return Index type
+ */
+ String indexType();
+}
diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataStorageMode.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataStorageMode.java new file mode 100644 index 0000000000..521eb38f83 --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/MetadataStorageMode.java @@ -0,0 +1,27 @@ +package dev.langchain4j.store.embedding.pgvector;
+
+/**
+ * Metadata storage mode
+ * <ul>
+ * <li>COLUMN_PER_KEY: for static metadata, when you know the list of metadata keys in advance.</li>
+ * <li>COMBINED_JSON: for dynamic metadata, when you don't know in advance which metadata keys will be used.</li>
+ * <li>COMBINED_JSONB: same as JSON, but stored in binary form; optimized for querying large datasets.</li>
+ * </ul>
+ * <p>
    + * Default value: COMBINED_JSON + */ +public enum MetadataStorageMode { + /** + * COLUMN_PER_KEY: for static metadata, when you know in advance the list of metadata + */ + COLUMN_PER_KEY, + /** + * COMBINED_JSON: For dynamic metadata, when you don't know the list of metadata that will be used. + */ + COMBINED_JSON, + /** + * COMBINED_JSONB: Same as JSON, but stored in a binary way. Optimized for query on large dataset. + */ + COMBINED_JSONB +} + diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStore.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStore.java index 84917da2fc..47bff4f872 100644 --- a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStore.java +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStore.java @@ -1,23 +1,29 @@ package dev.langchain4j.store.embedding.pgvector; -import com.google.gson.Gson; -import com.google.gson.reflect.TypeToken; import com.pgvector.PGvector; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.filter.Filter; import lombok.Builder; +import lombok.NoArgsConstructor; +import org.postgresql.ds.PGSimpleDataSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.lang.reflect.Type; +import javax.sql.DataSource; import java.sql.*; import java.util.*; +import java.util.stream.IntStream; import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.internal.ValidationUtils.*; +import static java.lang.String.join; +import static java.util.Collections.nCopies; import static java.util.Collections.singletonList; import static java.util.stream.Collectors.toList; @@ -27,36 +33,74 @@ * Only cosine similarity is used. * Only ivfflat index is used. */ +@NoArgsConstructor(force = true) // Needed for inherited bean injection validation public class PgVectorEmbeddingStore implements EmbeddingStore { - private static final Logger log = LoggerFactory.getLogger(PgVectorEmbeddingStore.class); + /** + * Datasource used to create the store + */ + protected final DataSource datasource; + /** + * Embeddings table name + */ + protected final String table; + /** + * Metadata handler + */ + final MetadataHandler metadataHandler; - private static final Gson GSON = new Gson(); + /** + * Constructor for PgVectorEmbeddingStore Class + * + * @param datasource The datasource to use + * @param table The database table + * @param dimension The vector dimension + * @param useIndex Should use IVFFlat index + * @param indexListSize The IVFFlat number of lists + * @param createTable Should create table automatically + * @param dropTableFirst Should drop table first, usually for testing + * @param metadataStorageConfig The {@link MetadataStorageConfig} config. 
+ */ + @Builder(builderMethodName = "datasourceBuilder", builderClassName = "DatasourceBuilder") + protected PgVectorEmbeddingStore(DataSource datasource, + String table, + Integer dimension, + Boolean useIndex, + Integer indexListSize, + Boolean createTable, + Boolean dropTableFirst, + MetadataStorageConfig metadataStorageConfig) { + this.datasource = ensureNotNull(datasource, "datasource"); + this.table = ensureNotBlank(table, "table"); + MetadataStorageConfig config = getOrDefault(metadataStorageConfig, DefaultMetadataStorageConfig.defaultConfig()); + this.metadataHandler = MetadataHandlerFactory.get(config); + useIndex = getOrDefault(useIndex, false); + createTable = getOrDefault(createTable, true); + dropTableFirst = getOrDefault(dropTableFirst, false); - private final String host; - private final Integer port; - private final String user; - private final String password; - private final String database; - private final String table; + initTable(dropTableFirst, createTable, useIndex, dimension, indexListSize); + } /** - * All args constructor for PgVectorEmbeddingStore Class + * Constructor for PgVectorEmbeddingStore Class + * Use this builder when you don't have datasource management. * - * @param host The database host - * @param port The database port - * @param user The database user - * @param password The database password - * @param database The database name - * @param table The database table - * @param dimension The vector dimension - * @param useIndex Should use IVFFlat index - * @param indexListSize The IVFFlat number of lists - * @param createTable Should create table automatically - * @param dropTableFirst Should drop table first, usually for testing + * @param host The database host + * @param port The database port + * @param user The database user + * @param password The database password + * @param database The database name + * @param table The database table + * @param dimension The vector dimension + * @param useIndex Should use IVFFlat index + * @param indexListSize The IVFFlat number of lists + * @param createTable Should create table automatically + * @param dropTableFirst Should drop table first, usually for testing + * @param metadataStorageConfig The {@link MetadataStorageConfig} config. 
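To make the new `datasourceBuilder` entry point above concrete, here is a hedged usage sketch (not from this PR); `dataSource` is assumed to be an existing PostgreSQL `javax.sql.DataSource` with the pgvector extension available, and the table name and dimension are illustrative:

```java
// Sketch only: reuse an externally managed DataSource instead of host/port credentials.
EmbeddingStore<TextSegment> store = PgVectorEmbeddingStore.datasourceBuilder()
        .datasource(dataSource)
        .table("embeddings")
        .dimension(384) // must match the embedding model's vector size
        .metadataStorageConfig(DefaultMetadataStorageConfig.defaultConfig())
        .build();
```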
*/ + @SuppressWarnings("unused") @Builder - public PgVectorEmbeddingStore( + protected PgVectorEmbeddingStore( String host, Integer port, String user, @@ -67,59 +111,69 @@ public PgVectorEmbeddingStore( Boolean useIndex, Integer indexListSize, Boolean createTable, - Boolean dropTableFirst) { - this.host = ensureNotBlank(host, "host"); - this.port = ensureGreaterThanZero(port, "port"); - this.user = ensureNotBlank(user, "user"); - this.password = ensureNotBlank(password, "password"); - this.database = ensureNotBlank(database, "database"); - this.table = ensureNotBlank(table, "table"); + Boolean dropTableFirst, + MetadataStorageConfig metadataStorageConfig + ) { + this(createDataSource(host, port, user, password, database), + table, dimension, useIndex, indexListSize, createTable, dropTableFirst, metadataStorageConfig); + } - useIndex = getOrDefault(useIndex, false); - createTable = getOrDefault(createTable, true); - dropTableFirst = getOrDefault(dropTableFirst, false); + private static DataSource createDataSource(String host, Integer port, String user, String password, String database) { + host = ensureNotBlank(host, "host"); + port = ensureGreaterThanZero(port, "port"); + user = ensureNotBlank(user, "user"); + password = ensureNotBlank(password, "password"); + database = ensureNotBlank(database, "database"); + + PGSimpleDataSource source = new PGSimpleDataSource(); + source.setServerNames(new String[]{host}); + source.setPortNumbers(new int[]{port}); + source.setDatabaseName(database); + source.setUser(user); + source.setPassword(password); + + return source; + } - try (Connection connection = setupConnection()) { + /** + * Initialize metadata table following configuration + * + * @param dropTableFirst Should drop table first, usually for testing + * @param createTable Should create table automatically + * @param useIndex Should use IVFFlat index + * @param dimension The vector dimension + * @param indexListSize The IVFFlat number of lists + */ + protected void initTable(Boolean dropTableFirst, Boolean createTable, Boolean useIndex, Integer dimension, + Integer indexListSize) { + String query = "init"; + try (Connection connection = getConnection(); Statement statement = connection.createStatement()) { if (dropTableFirst) { - connection.createStatement().executeUpdate(String.format("DROP TABLE IF EXISTS %s", table)); + statement.executeUpdate(String.format("DROP TABLE IF EXISTS %s", table)); } - if (createTable) { - connection.createStatement().executeUpdate(String.format( - "CREATE TABLE IF NOT EXISTS %s (" + - "embedding_id UUID PRIMARY KEY, " + - "embedding vector(%s), " + - "text TEXT NULL, " + - "metadata JSON NULL" + - ")", - table, ensureGreaterThanZero(dimension, "dimension"))); + query = String.format("CREATE TABLE IF NOT EXISTS %s (embedding_id UUID PRIMARY KEY, " + + "embedding vector(%s), text TEXT NULL, %s )", + table, ensureGreaterThanZero(dimension, "dimension"), + metadataHandler.columnDefinitionsString()); + statement.executeUpdate(query); + metadataHandler.createMetadataIndexes(statement, table); } - if (useIndex) { final String indexName = table + "_ivfflat_index"; - connection.createStatement().executeUpdate(String.format( + query = String.format( "CREATE INDEX IF NOT EXISTS %s ON %s " + "USING ivfflat (embedding vector_cosine_ops) " + "WITH (lists = %s)", - indexName, table, ensureGreaterThanZero(indexListSize, "indexListSize"))); + indexName, table, ensureGreaterThanZero(indexListSize, "indexListSize")); + statement.executeUpdate(query); } } catch (SQLException e) { - 
throw new RuntimeException(e); + throw new RuntimeException(String.format("Failed to execute '%s'", query), e); } } - private Connection setupConnection() throws SQLException { - Connection connection = DriverManager.getConnection( - String.format("jdbc:postgresql://%s:%s/%s", host, port, database), - user, - password - ); - connection.createStatement().executeUpdate("CREATE EXTENSION IF NOT EXISTS vector"); - PGvector.addVectorType(connection); - return connection; - } - /** * Adds a given embedding to the store. * @@ -185,52 +239,92 @@ public List addAll(List embeddings, List embedde return ids; } + @Override + public void removeAll(Collection ids) { + ensureNotEmpty(ids, "ids"); + String sql = String.format("DELETE FROM %s WHERE embedding_id = ANY (?)", table); + try (Connection connection = getConnection(); + PreparedStatement statement = connection.prepareStatement(sql)) { + Array array = connection.createArrayOf("uuid", ids.stream().map(UUID::fromString).toArray()); + statement.setArray(1, array); + statement.executeUpdate(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + @Override + public void removeAll(Filter filter) { + ensureNotNull(filter, "filter"); + String whereClause = metadataHandler.whereClause(filter); + String sql = String.format("DELETE FROM %s WHERE %s", table, whereClause); + try (Connection connection = getConnection(); + PreparedStatement statement = connection.prepareStatement(sql)) { + statement.executeUpdate(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + @Override + public void removeAll() { + try (Connection connection = getConnection(); + Statement statement = connection.createStatement()) { + statement.executeUpdate(String.format("TRUNCATE TABLE %s", table)); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + /** - * Finds the most relevant (closest in space) embeddings to the provided reference embedding. + * Searches for the most similar (closest in the embedding space) {@link Embedding}s. + *
<br> + * All search criteria are defined inside the {@link EmbeddingSearchRequest}. + * <br>
    + * {@link EmbeddingSearchRequest#filter()} is used to filter by meta dada. * - * @param referenceEmbedding The embedding used as a reference. Returned embeddings should be relevant (closest) to this one. - * @param maxResults The maximum number of embeddings to be returned. - * @param minScore The minimum relevance score, ranging from 0 to 1 (inclusive). - * Only embeddings with a score of this value or higher will be returned. - * @return A list of embedding matches. - * Each embedding match includes a relevance score (derivative of cosine distance), - * ranging from 0 (not relevant) to 1 (highly relevant). + * @param request A request to search in an {@link EmbeddingStore}. Contains all search criteria. + * @return An {@link EmbeddingSearchResult} containing all found {@link Embedding}s. */ @Override - public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { + public EmbeddingSearchResult search(EmbeddingSearchRequest request) { + Embedding referenceEmbedding = request.queryEmbedding(); + int maxResults = request.maxResults(); + double minScore = request.minScore(); + Filter filter = request.filter(); + List> result = new ArrayList<>(); - try (Connection connection = setupConnection()) { + try (Connection connection = getConnection()) { String referenceVector = Arrays.toString(referenceEmbedding.vector()); + String whereClause = (filter == null) ? "" : metadataHandler.whereClause(filter); + whereClause = (whereClause.isEmpty()) ? "" : "WHERE " + whereClause; String query = String.format( - "WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s) SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;", - referenceVector, table, minScore, maxResults); - PreparedStatement selectStmt = connection.prepareStatement(query); - - ResultSet resultSet = selectStmt.executeQuery(); - while (resultSet.next()) { - double score = resultSet.getDouble("score"); - String embeddingId = resultSet.getString("embedding_id"); - - PGvector vector = (PGvector) resultSet.getObject("embedding"); - Embedding embedding = new Embedding(vector.toArray()); - - String text = resultSet.getString("text"); - TextSegment textSegment = null; - if (isNotNullOrBlank(text)) { - String metadataJson = Optional.ofNullable(resultSet.getString("metadata")).orElse("{}"); - Type type = new TypeToken>() { - }.getType(); - Metadata metadata = new Metadata(new HashMap<>(GSON.fromJson(metadataJson, type))); - textSegment = TextSegment.from(text, metadata); + "WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, " + + "%s FROM %s %s) SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;", + referenceVector, join(",", metadataHandler.columnsNames()), table, whereClause, minScore, maxResults); + try (PreparedStatement selectStmt = connection.prepareStatement(query)) { + try (ResultSet resultSet = selectStmt.executeQuery()) { + while (resultSet.next()) { + double score = resultSet.getDouble("score"); + String embeddingId = resultSet.getString("embedding_id"); + + PGvector vector = (PGvector) resultSet.getObject("embedding"); + Embedding embedding = new Embedding(vector.toArray()); + + String text = resultSet.getString("text"); + TextSegment textSegment = null; + if (isNotNullOrBlank(text)) { + Metadata metadata = metadataHandler.fromResultSet(resultSet); + textSegment = TextSegment.from(text, metadata); + } + result.add(new EmbeddingMatch<>(score, embeddingId, embedding, textSegment)); + 
} } - - result.add(new EmbeddingMatch<>(score, embeddingId, embedding, textSegment)); } } catch (SQLException e) { throw new RuntimeException(e); } - - return result; + return new EmbeddingSearchResult<>(result); } private void addInternal(String id, Embedding embedding, TextSegment embedded) { @@ -250,36 +344,61 @@ private void addAllInternal( ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size"); - try (Connection connection = setupConnection()) { + try (Connection connection = getConnection()) { String query = String.format( - "INSERT INTO %s (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?)" + + "INSERT INTO %s (embedding_id, embedding, text, %s) VALUES (?, ?, ?, %s)" + "ON CONFLICT (embedding_id) DO UPDATE SET " + "embedding = EXCLUDED.embedding," + "text = EXCLUDED.text," + - "metadata = EXCLUDED.metadata;", - table); - - PreparedStatement upsertStmt = connection.prepareStatement(query); - - for (int i = 0; i < ids.size(); ++i) { - upsertStmt.setObject(1, UUID.fromString(ids.get(i))); - upsertStmt.setObject(2, new PGvector(embeddings.get(i).vector())); - - if (embedded != null && embedded.get(i) != null) { - upsertStmt.setObject(3, embedded.get(i).text()); - Map metadata = embedded.get(i).metadata().asMap(); - upsertStmt.setObject(4, GSON.toJson(metadata), Types.OTHER); - } else { - upsertStmt.setNull(3, Types.VARCHAR); - upsertStmt.setNull(4, Types.OTHER); + "%s;", + table, join(",", metadataHandler.columnsNames()), + join(",", nCopies(metadataHandler.columnsNames().size(), "?")), + metadataHandler.insertClause()); + try (PreparedStatement upsertStmt = connection.prepareStatement(query)) { + for (int i = 0; i < ids.size(); ++i) { + upsertStmt.setObject(1, UUID.fromString(ids.get(i))); + upsertStmt.setObject(2, new PGvector(embeddings.get(i).vector())); + + if (embedded != null && embedded.get(i) != null) { + upsertStmt.setObject(3, embedded.get(i).text()); + metadataHandler.setMetadata(upsertStmt, 4, embedded.get(i).metadata()); + } else { + upsertStmt.setNull(3, Types.VARCHAR); + IntStream.range(4, 4 + metadataHandler.columnsNames().size()).forEach( + j -> { + try { + upsertStmt.setNull(j, Types.OTHER); + } catch (SQLException e) { + throw new RuntimeException(e); + } + }); + } + upsertStmt.addBatch(); } - upsertStmt.addBatch(); + upsertStmt.executeBatch(); } - - upsertStmt.executeBatch(); - - } catch (SQLException e) { + } catch (Exception e) { throw new RuntimeException(e); } } + + /** + * Datasource connection + * Creates the vector extension and add the vector type if it does not exist. + * Could be overridden in case extension creation and adding type is done at datasource initialization step. + * + * @return Datasource connection + * @throws SQLException exception + */ + protected Connection getConnection() throws SQLException { + Connection connection = datasource.getConnection(); + // Find a way to do the following code in connection initialization. 
+ // Here we assume the datasource could handle a connection pool + // and we should add the vector type on each connection + try (Statement statement = connection.createStatement()) { + statement.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector"); + } + PGvector.addVectorType(connection); + return connection; + } } diff --git a/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/PgVectorFilterMapper.java b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/PgVectorFilterMapper.java new file mode 100644 index 0000000000..893a5b4915 --- /dev/null +++ b/langchain4j-pgvector/src/main/java/dev/langchain4j/store/embedding/pgvector/PgVectorFilterMapper.java @@ -0,0 +1,127 @@ +package dev.langchain4j.store.embedding.pgvector; + +import dev.langchain4j.store.embedding.filter.Filter; +import dev.langchain4j.store.embedding.filter.comparison.*; +import dev.langchain4j.store.embedding.filter.logical.And; +import dev.langchain4j.store.embedding.filter.logical.Not; +import dev.langchain4j.store.embedding.filter.logical.Or; + +import java.util.Collection; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.lang.String.format; +import static java.util.AbstractMap.SimpleEntry; + +abstract class PgVectorFilterMapper { + + static final Map, String> SQL_TYPE_MAP = Stream.of( + new SimpleEntry<>(Integer.class, "int"), + new SimpleEntry<>(Long.class, "bigint"), + new SimpleEntry<>(Float.class, "float"), + new SimpleEntry<>(Double.class, "float8"), + new SimpleEntry<>(String.class, "text"), + new SimpleEntry<>(Boolean.class, "boolean"), + // Default + new SimpleEntry<>(Object.class, "text")) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + public String map(Filter filter) { + if (filter instanceof IsEqualTo) { + return mapEqual((IsEqualTo) filter); + } else if (filter instanceof IsNotEqualTo) { + return mapNotEqual((IsNotEqualTo) filter); + } else if (filter instanceof IsGreaterThan) { + return mapGreaterThan((IsGreaterThan) filter); + } else if (filter instanceof IsGreaterThanOrEqualTo) { + return mapGreaterThanOrEqual((IsGreaterThanOrEqualTo) filter); + } else if (filter instanceof IsLessThan) { + return mapLessThan((IsLessThan) filter); + } else if (filter instanceof IsLessThanOrEqualTo) { + return mapLessThanOrEqual((IsLessThanOrEqualTo) filter); + } else if (filter instanceof IsIn) { + return mapIn((IsIn) filter); + } else if (filter instanceof IsNotIn) { + return mapNotIn((IsNotIn) filter); + } else if (filter instanceof And) { + return mapAnd((And) filter); + } else if (filter instanceof Not) { + return mapNot((Not) filter); + } else if (filter instanceof Or) { + return mapOr((Or) filter); + } else { + throw new UnsupportedOperationException("Unsupported filter type: " + filter.getClass().getName()); + } + } + + private String mapEqual(IsEqualTo isEqualTo) { + String key = formatKey(isEqualTo.key(), isEqualTo.comparisonValue().getClass()); + return format("%s is not null and %s = %s", key, key, + formatValue(isEqualTo.comparisonValue())); + } + + private String mapNotEqual(IsNotEqualTo isNotEqualTo) { + String key = formatKey(isNotEqualTo.key(), isNotEqualTo.comparisonValue().getClass()); + return format("%s is null or %s != %s", key, key, + formatValue(isNotEqualTo.comparisonValue())); + } + + private String mapGreaterThan(IsGreaterThan isGreaterThan) { + return format("%s > %s", formatKey(isGreaterThan.key(), isGreaterThan.comparisonValue().getClass()), + 
formatValue(isGreaterThan.comparisonValue())); + } + + private String mapGreaterThanOrEqual(IsGreaterThanOrEqualTo isGreaterThanOrEqualTo) { + return format("%s >= %s", formatKey(isGreaterThanOrEqualTo.key(), isGreaterThanOrEqualTo.comparisonValue().getClass()), + formatValue(isGreaterThanOrEqualTo.comparisonValue())); + } + + private String mapLessThan(IsLessThan isLessThan) { + return format("%s < %s", formatKey(isLessThan.key(), isLessThan.comparisonValue().getClass()), + formatValue(isLessThan.comparisonValue())); + } + + private String mapLessThanOrEqual(IsLessThanOrEqualTo isLessThanOrEqualTo) { + return format("%s <= %s", formatKey(isLessThanOrEqualTo.key(), isLessThanOrEqualTo.comparisonValue().getClass()), + formatValue(isLessThanOrEqualTo.comparisonValue())); + } + + private String mapIn(IsIn isIn) { + return format("%s in %s", formatKeyAsString(isIn.key()), formatValuesAsString(isIn.comparisonValues())); + } + + private String mapNotIn(IsNotIn isNotIn) { + String key = formatKeyAsString(isNotIn.key()); + return format("%s is null or %s not in %s", key, key, formatValuesAsString(isNotIn.comparisonValues())); + } + + private String mapAnd(And and) { + return format("%s and %s", map(and.left()), map(and.right())); + } + + private String mapNot(Not not) { + return format("not(%s)", map(not.expression())); + } + + private String mapOr(Or or) { + return format("(%s or %s)", map(or.left()), map(or.right())); + } + + abstract String formatKey(String key, Class valueType); + + abstract String formatKeyAsString(String key); + + String formatValue(Object value) { + if (value instanceof String) { + return "'" + value + "'"; + } else { + return value.toString(); + } + } + + String formatValuesAsString(Collection values) { + return "(" + values.stream().map(v -> String.format("'%s'", v)) + .collect(Collectors.joining(",")) + ")"; + } +} diff --git a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/JSONBMetadataHandlerTest.java b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/JSONBMetadataHandlerTest.java new file mode 100644 index 0000000000..8a137fb88b --- /dev/null +++ b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/JSONBMetadataHandlerTest.java @@ -0,0 +1,82 @@ +package dev.langchain4j.store.embedding.pgvector; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.AdditionalAnswers; + +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.mockito.Mockito.*; + +class JSONBMetadataHandlerTest { + + @Test + void createSimpleMetadataIndexes() throws SQLException { + Statement statement = mock(Statement.class); + List sqlStatementQueries = new ArrayList<>(); + when(statement.executeUpdate(anyString())) + .thenAnswer(AdditionalAnswers.answerVoid(q -> sqlStatementQueries.add((String)q))); + + MetadataStorageConfig metadataStorageConfig = DefaultMetadataStorageConfig.builder() + .storageMode(MetadataStorageMode.COMBINED_JSONB) + .columnDefinitions(Collections.singletonList("metadata JSONB")) + .indexes(Collections.singletonList("metadata")) + .indexType("GIN") + .build(); + JSONBMetadataHandler jsonbMetadataHandler = new JSONBMetadataHandler(metadataStorageConfig); + jsonbMetadataHandler.createMetadataIndexes(statement, "embeddings"); + + Assertions.assertEquals(1, sqlStatementQueries.size()); + Assertions.assertEquals("create index if 
not exists embeddings_metadata on embeddings " + + "USING GIN (metadata)", sqlStatementQueries.get(0)); + } + + @Test + void createSimpleMetadataIndexes_jsonb_path_ops() throws SQLException { + Statement statement = mock(Statement.class); + List sqlStatementQueries = new ArrayList<>(); + when(statement.executeUpdate(anyString())) + .thenAnswer(AdditionalAnswers.answerVoid(q -> sqlStatementQueries.add((String)q))); + + MetadataStorageConfig metadataStorageConfig = DefaultMetadataStorageConfig.builder() + .storageMode(MetadataStorageMode.COMBINED_JSONB) + .columnDefinitions(Collections.singletonList("metadata JSONB")) + .indexes(Collections.singletonList("metadata jsonb_path_ops")) + .indexType("GIN") + .build(); + JSONBMetadataHandler jsonbMetadataHandler = new JSONBMetadataHandler(metadataStorageConfig); + jsonbMetadataHandler.createMetadataIndexes(statement, "embeddings"); + + Assertions.assertEquals(1, sqlStatementQueries.size()); + Assertions.assertEquals("create index if not exists embeddings_metadata_jsonb_path_ops on embeddings " + + "USING GIN (metadata jsonb_path_ops)", sqlStatementQueries.get(0)); + } + + @Test + void createJSONNodeMetadataIndexes() throws SQLException { + Statement statement = mock(Statement.class); + List sqlStatementQueries = new ArrayList<>(); + when(statement.executeUpdate(anyString())) + .thenAnswer(AdditionalAnswers.answerVoid(q -> sqlStatementQueries.add((String)q))); + + MetadataStorageConfig metadataStorageConfig = DefaultMetadataStorageConfig.builder() + .storageMode(MetadataStorageMode.COMBINED_JSONB) + .columnDefinitions(Collections.singletonList("metadata JSONB")) + .indexes(Arrays.asList("(metadata->key1)", "(metadata->key2)")) + .indexType("GIN") + .build(); + JSONBMetadataHandler jsonbMetadataHandler = new JSONBMetadataHandler(metadataStorageConfig); + jsonbMetadataHandler.createMetadataIndexes(statement, "embeddings"); + + Assertions.assertEquals(2, sqlStatementQueries.size()); + Assertions.assertEquals("create index if not exists embeddings_metadata_key1 on embeddings " + + "USING GIN ((metadata->key1))", sqlStatementQueries.get(0)); + Assertions.assertEquals("create index if not exists embeddings_metadata_key2 on embeddings " + + "USING GIN ((metadata->key2))", sqlStatementQueries.get(1)); + } +} \ No newline at end of file diff --git a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingIndexedStoreIT.java b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingIndexedStoreIT.java index 00be3dbd83..8d032c3875 100644 --- a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingIndexedStoreIT.java +++ b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingIndexedStoreIT.java @@ -4,14 +4,14 @@ import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.EmbeddingStoreIT; +import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; import org.junit.jupiter.api.BeforeEach; import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @Testcontainers -public class PgVectorEmbeddingIndexedStoreIT extends EmbeddingStoreIT { +public class PgVectorEmbeddingIndexedStoreIT extends EmbeddingStoreWithFilteringIT { @Container 
static PostgreSQLContainer pgVector = new PostgreSQLContainer<>("pgvector/pgvector:pg15"); diff --git a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStore029.java b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStore029.java new file mode 100644 index 0000000000..baa5e4ddd5 --- /dev/null +++ b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStore029.java @@ -0,0 +1,284 @@ +package dev.langchain4j.store.embedding.pgvector; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import com.pgvector.PGvector; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; +import lombok.Builder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.Type; +import java.sql.*; +import java.util.*; + +import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.internal.ValidationUtils.*; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; + +/** + * PGVector 0.29 EmbeddingStore Implementation + *
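+ * Snapshot of the 0.29 implementation, kept in the test sources so that
+ * PgVectorEmbeddingStoreUpgradeIT can check that data persisted by the 0.29
+ * release is still readable by the current PgVectorEmbeddingStore.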

    + * Only cosine similarity is used. + * Only ivfflat index is used. + */ +class PgVectorEmbeddingStore029 implements EmbeddingStore { + + private static final Logger log = LoggerFactory.getLogger(PgVectorEmbeddingStore.class); + + private static final Gson GSON = new Gson(); + + private final String host; + private final Integer port; + private final String user; + private final String password; + private final String database; + private final String table; + + /** + * All args constructor for PgVectorEmbeddingStore Class + * + * @param host The database host + * @param port The database port + * @param user The database user + * @param password The database password + * @param database The database name + * @param table The database table + * @param dimension The vector dimension + * @param useIndex Should use IVFFlat index + * @param indexListSize The IVFFlat number of lists + * @param createTable Should create table automatically + * @param dropTableFirst Should drop table first, usually for testing + */ + @Builder + PgVectorEmbeddingStore029( + String host, + Integer port, + String user, + String password, + String database, + String table, + Integer dimension, + Boolean useIndex, + Integer indexListSize, + Boolean createTable, + Boolean dropTableFirst) { + this.host = ensureNotBlank(host, "host"); + this.port = ensureGreaterThanZero(port, "port"); + this.user = ensureNotBlank(user, "user"); + this.password = ensureNotBlank(password, "password"); + this.database = ensureNotBlank(database, "database"); + this.table = ensureNotBlank(table, "table"); + + useIndex = getOrDefault(useIndex, false); + createTable = getOrDefault(createTable, true); + dropTableFirst = getOrDefault(dropTableFirst, false); + + try (Connection connection = setupConnection()) { + + if (dropTableFirst) { + connection.createStatement().executeUpdate(String.format("DROP TABLE IF EXISTS %s", table)); + } + + if (createTable) { + connection.createStatement().executeUpdate(String.format( + "CREATE TABLE IF NOT EXISTS %s (" + + "embedding_id UUID PRIMARY KEY, " + + "embedding vector(%s), " + + "text TEXT NULL, " + + "metadata JSON NULL" + + ")", + table, ensureGreaterThanZero(dimension, "dimension"))); + } + + if (useIndex) { + final String indexName = table + "_ivfflat_index"; + connection.createStatement().executeUpdate(String.format( + "CREATE INDEX IF NOT EXISTS %s ON %s " + + "USING ivfflat (embedding vector_cosine_ops) " + + "WITH (lists = %s)", + indexName, table, ensureGreaterThanZero(indexListSize, "indexListSize"))); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + private Connection setupConnection() throws SQLException { + Connection connection = DriverManager.getConnection( + String.format("jdbc:postgresql://%s:%s/%s", host, port, database), + user, + password + ); + connection.createStatement().executeUpdate("CREATE EXTENSION IF NOT EXISTS vector"); + PGvector.addVectorType(connection); + return connection; + } + + /** + * Adds a given embedding to the store. + * + * @param embedding The embedding to be added to the store. + * @return The auto-generated ID associated with the added embedding. + */ + @Override + public String add(Embedding embedding) { + String id = randomUUID(); + addInternal(id, embedding, null); + return id; + } + + /** + * Adds a given embedding to the store. + * + * @param id The unique identifier for the embedding to be added. + * @param embedding The embedding to be added to the store. 
+ */ + @Override + public void add(String id, Embedding embedding) { + addInternal(id, embedding, null); + } + + /** + * Adds a given embedding and the corresponding content that has been embedded to the store. + * + * @param embedding The embedding to be added to the store. + * @param textSegment Original content that was embedded. + * @return The auto-generated ID associated with the added embedding. + */ + @Override + public String add(Embedding embedding, TextSegment textSegment) { + String id = randomUUID(); + addInternal(id, embedding, textSegment); + return id; + } + + /** + * Adds multiple embeddings to the store. + * + * @param embeddings A list of embeddings to be added to the store. + * @return A list of auto-generated IDs associated with the added embeddings. + */ + @Override + public List addAll(List embeddings) { + List ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList()); + addAllInternal(ids, embeddings, null); + return ids; + } + + /** + * Adds multiple embeddings and their corresponding contents that have been embedded to the store. + * + * @param embeddings A list of embeddings to be added to the store. + * @param embedded A list of original contents that were embedded. + * @return A list of auto-generated IDs associated with the added embeddings. + */ + @Override + public List addAll(List embeddings, List embedded) { + List ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList()); + addAllInternal(ids, embeddings, embedded); + return ids; + } + + /** + * Finds the most relevant (closest in space) embeddings to the provided reference embedding. + * + * @param referenceEmbedding The embedding used as a reference. Returned embeddings should be relevant (closest) to this one. + * @param maxResults The maximum number of embeddings to be returned. + * @param minScore The minimum relevance score, ranging from 0 to 1 (inclusive). + * Only embeddings with a score of this value or higher will be returned. + * @return A list of embedding matches. + * Each embedding match includes a relevance score (derivative of cosine distance), + * ranging from 0 (not relevant) to 1 (highly relevant). 
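+     *                           The score is computed in SQL as (2 - (embedding <=> reference)) / 2,
+     *                           i.e. the cosine distance in [0, 2] mapped linearly to [0, 1].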
+ */ + @Override + public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { + List> result = new ArrayList<>(); + try (Connection connection = setupConnection()) { + String referenceVector = Arrays.toString(referenceEmbedding.vector()); + String query = String.format( + "WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s) SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;", + referenceVector, table, minScore, maxResults); + PreparedStatement selectStmt = connection.prepareStatement(query); + + ResultSet resultSet = selectStmt.executeQuery(); + while (resultSet.next()) { + double score = resultSet.getDouble("score"); + String embeddingId = resultSet.getString("embedding_id"); + + PGvector vector = (PGvector) resultSet.getObject("embedding"); + Embedding embedding = new Embedding(vector.toArray()); + + String text = resultSet.getString("text"); + TextSegment textSegment = null; + if (isNotNullOrBlank(text)) { + String metadataJson = Optional.ofNullable(resultSet.getString("metadata")).orElse("{}"); + Type type = new TypeToken>() { + }.getType(); + Metadata metadata = new Metadata(new HashMap<>(GSON.fromJson(metadataJson, type))); + textSegment = TextSegment.from(text, metadata); + } + + result.add(new EmbeddingMatch<>(score, embeddingId, embedding, textSegment)); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } + + return result; + } + + private void addInternal(String id, Embedding embedding, TextSegment embedded) { + addAllInternal( + singletonList(id), + singletonList(embedding), + embedded == null ? null : singletonList(embedded)); + } + + private void addAllInternal( + List ids, List embeddings, List embedded) { + if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) { + log.info("Empty embeddings - no ops"); + return; + } + ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size"); + ensureTrue(embedded == null || embeddings.size() == embedded.size(), + "embeddings size is not equal to embedded size"); + + try (Connection connection = setupConnection()) { + String query = String.format( + "INSERT INTO %s (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?)" + + "ON CONFLICT (embedding_id) DO UPDATE SET " + + "embedding = EXCLUDED.embedding," + + "text = EXCLUDED.text," + + "metadata = EXCLUDED.metadata;", + table); + + PreparedStatement upsertStmt = connection.prepareStatement(query); + + for (int i = 0; i < ids.size(); ++i) { + upsertStmt.setObject(1, UUID.fromString(ids.get(i))); + upsertStmt.setObject(2, new PGvector(embeddings.get(i).vector())); + + if (embedded != null && embedded.get(i) != null) { + upsertStmt.setObject(3, embedded.get(i).text()); + Map metadata = embedded.get(i).metadata().asMap(); + upsertStmt.setObject(4, GSON.toJson(metadata), Types.OTHER); + } else { + upsertStmt.setNull(3, Types.VARCHAR); + upsertStmt.setNull(4, Types.OTHER); + } + upsertStmt.addBatch(); + } + + upsertStmt.executeBatch(); + + } catch (SQLException e) { + throw new RuntimeException(e); + } + } +} \ No newline at end of file diff --git a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreConfigIT.java b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreConfigIT.java new file mode 100644 index 0000000000..a1c56dff83 --- /dev/null +++ 
b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreConfigIT.java @@ -0,0 +1,68 @@ +package dev.langchain4j.store.embedding.pgvector; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; +import org.junit.jupiter.api.BeforeEach; +import org.postgresql.ds.PGSimpleDataSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.SQLException; + +@Testcontainers +public abstract class PgVectorEmbeddingStoreConfigIT extends EmbeddingStoreWithFilteringIT { + + @Container + static PostgreSQLContainer pgVector = new PostgreSQLContainer<>("pgvector/pgvector:pg16"); + + static EmbeddingStore embeddingStore; + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + static DataSource dataSource; + + static final String TABLE_NAME = "test"; + static final int TABLE_DIMENSION = 384; + + static void configureStore(MetadataStorageConfig config) { + PGSimpleDataSource source = new PGSimpleDataSource(); + source.setServerNames(new String[] {pgVector.getHost()}); + source.setPortNumbers(new int[] {pgVector.getFirstMappedPort()}); + source.setDatabaseName("test"); + source.setUser("test"); + source.setPassword("test"); + dataSource = source; + embeddingStore = PgVectorEmbeddingStore.datasourceBuilder() + .datasource(dataSource) + .table(TABLE_NAME) + .dimension(TABLE_DIMENSION) + .dropTableFirst(true) + .metadataStorageConfig(config) + .build(); + } + + @BeforeEach + void beforeEach() { + try (Connection connection = dataSource.getConnection()) { + connection.createStatement().executeUpdate(String.format("TRUNCATE TABLE %s", TABLE_NAME)); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } +} diff --git a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreIT.java b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreIT.java index a91662c6ca..bb1ddfe332 100644 --- a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreIT.java +++ b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreIT.java @@ -4,14 +4,14 @@ import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.EmbeddingStoreIT; +import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; import org.junit.jupiter.api.BeforeEach; import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @Testcontainers -public class PgVectorEmbeddingStoreIT extends EmbeddingStoreIT { +public class PgVectorEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT { @Container static PostgreSQLContainer 
pgVector = new PostgreSQLContainer<>("pgvector/pgvector:pg15"); diff --git a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreRemovalIT.java b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreRemovalIT.java new file mode 100644 index 0000000000..a91d268b3c --- /dev/null +++ b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreRemovalIT.java @@ -0,0 +1,40 @@ +package dev.langchain4j.store.embedding.pgvector; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +@Testcontainers +public class PgVectorEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemovalIT { + + @Container + static PostgreSQLContainer pgVector = new PostgreSQLContainer<>("pgvector/pgvector:pg15"); + + EmbeddingStore embeddingStore = PgVectorEmbeddingStore.builder() + .host(pgVector.getHost()) + .port(pgVector.getFirstMappedPort()) + .user("test") + .password("test") + .database("test") + .table("test") + .dimension(384) + .dropTableFirst(true) + .build(); + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } +} diff --git a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreUpgradeIT.java b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreUpgradeIT.java new file mode 100644 index 0000000000..f1c7ac36e0 --- /dev/null +++ b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreUpgradeIT.java @@ -0,0 +1,98 @@ +package dev.langchain4j.store.embedding.pgvector; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.*; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; + +/** + * Test upgrade from 029 to latest version + */ +@Testcontainers +public class PgVectorEmbeddingStoreUpgradeIT { + + @Container + static PostgreSQLContainer pgVector = new PostgreSQLContainer<>("pgvector/pgvector:pg15"); + + EmbeddingStore embeddingStore029; + + EmbeddingStore embeddingStore; + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @BeforeEach + void beforeEach() { + embeddingStore029 = PgVectorEmbeddingStore029.builder() + .host(pgVector.getHost()) + .port(pgVector.getFirstMappedPort()) + .user("test") + .password("test") + .database("test") + 
.table("test") + .dimension(384) + .dropTableFirst(true) + .build(); + + embeddingStore = PgVectorEmbeddingStore.builder() + .host(pgVector.getHost()) + .port(pgVector.getFirstMappedPort()) + .user("test") + .password("test") + .database("test") + .table("test") + .dimension(384) + .build(); + } + + @Test + void upgrade() { + + Embedding embedding = embeddingModel.embed("hello").content(); + + String id = embeddingStore029.add(embedding); + assertThat(id).isNotBlank(); + + // Check 029 results + List> relevant = embeddingStore029.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); + + // new API + assertThat(embeddingStore029.search(EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(10) + .build()).matches()).isEqualTo(relevant); + + // Check Latest Store results + relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); + + // new API + assertThat(embeddingStore.search(EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(10) + .build()).matches()).isEqualTo(relevant); + } +} diff --git a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreWithColumnsFilteringIT.java b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreWithColumnsFilteringIT.java new file mode 100644 index 0000000000..17eae8a056 --- /dev/null +++ b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreWithColumnsFilteringIT.java @@ -0,0 +1,27 @@ +package dev.langchain4j.store.embedding.pgvector; + +import org.junit.jupiter.api.BeforeAll; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.Arrays; + +@Testcontainers +public class PgVectorEmbeddingStoreWithColumnsFilteringIT extends PgVectorEmbeddingStoreConfigIT { + + @BeforeAll + static void beforeAll() { + MetadataStorageConfig config = DefaultMetadataStorageConfig.builder() + .storageMode(MetadataStorageMode.COLUMN_PER_KEY) + .columnDefinitions( + Arrays.asList("key text NULL", "name text NULL", "age float NULL", "city varchar null", "country varchar null", + "string_empty varchar null", "string_space varchar null", "string_abc varchar null", + "integer_min int null", "integer_minus_1 int null", "integer_0 int null", "integer_1 int null", "integer_max int null", + "long_min bigint null", "long_minus_1 bigint null", "long_0 bigint null", "long_1 bigint null", "long_max bigint null", + "float_min float null", "float_minus_1 float null", "float_0 float null", "float_1 float null", "float_123 float null", "float_max float null", + "double_minus_1 float8 null", "double_0 float8 null", "double_1 float8 null", "double_123 float8 null" + )) + .indexes(Arrays.asList("key", "name", "age")) + .build(); + PgVectorEmbeddingStoreConfigIT.configureStore(config); + } +} diff --git a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreWithJSONBFilteringIT.java 
b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreWithJSONBFilteringIT.java new file mode 100644 index 0000000000..4a38b9e72e --- /dev/null +++ b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreWithJSONBFilteringIT.java @@ -0,0 +1,20 @@ +package dev.langchain4j.store.embedding.pgvector; + +import org.junit.jupiter.api.BeforeAll; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.Collections; + +@Testcontainers +public class PgVectorEmbeddingStoreWithJSONBFilteringIT extends PgVectorEmbeddingStoreConfigIT { + @BeforeAll + static void beforeAll() { + MetadataStorageConfig config = DefaultMetadataStorageConfig.builder() + .storageMode(MetadataStorageMode.COMBINED_JSONB) + .columnDefinitions(Collections.singletonList("metadata JSONB NULL")) + .indexes(Collections.singletonList("metadata")) + .indexType("GIN") + .build(); + PgVectorEmbeddingStoreConfigIT.configureStore(config); + } +} diff --git a/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreWithJSONBMultiIndexesFilteringIT.java b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreWithJSONBMultiIndexesFilteringIT.java new file mode 100644 index 0000000000..079fe67c88 --- /dev/null +++ b/langchain4j-pgvector/src/test/java/dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStoreWithJSONBMultiIndexesFilteringIT.java @@ -0,0 +1,21 @@ +package dev.langchain4j.store.embedding.pgvector; + +import org.junit.jupiter.api.BeforeAll; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.Arrays; +import java.util.Collections; + +@Testcontainers +public class PgVectorEmbeddingStoreWithJSONBMultiIndexesFilteringIT extends PgVectorEmbeddingStoreConfigIT { + @BeforeAll + static void beforeAll() { + MetadataStorageConfig config = DefaultMetadataStorageConfig.builder() + .storageMode(MetadataStorageMode.COMBINED_JSONB) + .columnDefinitions(Collections.singletonList("metadata_b JSONB NULL")) + .indexes(Arrays.asList("(metadata_b->'key')", "(metadata_b->'name')", "(metadata_b->'age')")) + .indexType("GIN") + .build(); + PgVectorEmbeddingStoreConfigIT.configureStore(config); + } +} diff --git a/langchain4j-pinecone/pom.xml b/langchain4j-pinecone/pom.xml index ddd92b0b4d..9269850e41 100644 --- a/langchain4j-pinecone/pom.xml +++ b/langchain4j-pinecone/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-qdrant/pom.xml b/langchain4j-qdrant/pom.xml index d65d078c19..921dfb4b92 100644 --- a/langchain4j-qdrant/pom.xml +++ b/langchain4j-qdrant/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-qianfan/pom.xml b/langchain4j-qianfan/pom.xml index c45d2c3407..88bf5c6456 100644 --- a/langchain4j-qianfan/pom.xml +++ b/langchain4j-qianfan/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanChatModel.java b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanChatModel.java index fe7f3f11b0..7bd65dbe04 100644 --- a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanChatModel.java +++ b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanChatModel.java 
@@ -41,23 +41,32 @@ public class QianfanChatModel implements ChatLanguageModel { private final String responseFormat; + private final String userId; + private final List stop; + private final Integer maxOutputTokens; + private final String system; @Builder public QianfanChatModel(String baseUrl, - String apiKey, - String secretKey, - Double temperature, - Integer maxRetries, - Double topP, - String modelName, - String endpoint, - String responseFormat, - Double penaltyScore, - Boolean logRequests, - Boolean logResponses - ) { - if (Utils.isNullOrBlank(apiKey)||Utils.isNullOrBlank(secretKey)) { - throw new IllegalArgumentException(" api key and secret key must be defined. It can be generated here: https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application"); + String apiKey, + String secretKey, + Double temperature, + Integer maxRetries, + Double topP, + String modelName, + String endpoint, + String responseFormat, + Double penaltyScore, + Boolean logRequests, + Boolean logResponses, + String userId, + List stop, + Integer maxOutputTokens, + String system + ) { + if (Utils.isNullOrBlank(apiKey) || Utils.isNullOrBlank(secretKey)) { + throw new IllegalArgumentException( + " api key and secret key must be defined. It can be generated here: https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application"); } this.modelName=modelName; @@ -81,6 +90,10 @@ public QianfanChatModel(String baseUrl, this.topP = topP; this.penaltyScore = penaltyScore; this.responseFormat = responseFormat; + this.maxOutputTokens = maxOutputTokens; + this.stop = stop; + this.userId = userId; + this.system = system; } @@ -108,14 +121,18 @@ private Response generate(List messages, ) { ChatCompletionRequest.Builder builder = ChatCompletionRequest.builder() - .messages(toOpenAiMessages(messages)) - .temperature(temperature) - .topP(topP) - .penaltyScore(penaltyScore) - .system(getSystemMessage(messages)) - .responseFormat(responseFormat) - ; - + .messages(toOpenAiMessages(messages)) + .temperature(temperature) + .topP(topP) + .maxOutputTokens(maxOutputTokens) + .stop(stop) + .system(system) + .userId(userId) + .penaltyScore(penaltyScore) + .responseFormat(responseFormat); + if (system == null || system.length() == 0) { + builder.system(getSystemMessage(messages)); + } if (toolSpecifications != null && !toolSpecifications.isEmpty()) { builder.functions(toFunctions(toolSpecifications)); diff --git a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanChatModelNameEnum.java b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanChatModelNameEnum.java index 8017951efb..4fe3c8bcf8 100644 --- a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanChatModelNameEnum.java +++ b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanChatModelNameEnum.java @@ -9,6 +9,7 @@ public enum QianfanChatModelNameEnum { ERNIE_BOT_4("ERNIE-Bot 4.0", "completions_pro"), ERNIE_BOT_8("ERNIE-Bot-8K", "ernie_bot_8k"), ERNIE_BOT_TURBO("ERNIE-Bot-turbo", "eb-instant"), + ERNIE_SPEED_128K("ERNIE-Speed-128K", "completions"), EB_TURBO_APPBUILDER("EB-turbo-AppBuilder", "ai_apaas"), YI_34B_CHAT("Yi-34B-Chat", "yi_34b_chat"), BLOOMZ_7B("BLOOMZ-7B","bloomz_7b1"), diff --git a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanEmbeddingModel.java b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanEmbeddingModel.java index c022fbfbc0..4ad33a2628 100644 --- 
a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanEmbeddingModel.java +++ b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanEmbeddingModel.java @@ -6,14 +6,17 @@ import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.qianfan.client.QianfanClient; -import dev.langchain4j.model.qianfan.client.embedding.EmbeddingResponse; import dev.langchain4j.model.qianfan.client.embedding.EmbeddingRequest; +import dev.langchain4j.model.qianfan.client.embedding.EmbeddingResponse; import dev.langchain4j.model.qianfan.spi.QianfanEmbeddingModelBuilderFactory; import lombok.Builder; + +import java.net.Proxy; import java.util.List; + import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.model.qianfan.InternalQianfanHelper.*; +import static dev.langchain4j.model.qianfan.InternalQianfanHelper.tokenUsageFrom; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.util.stream.Collectors.toList; /** @@ -46,7 +49,8 @@ public QianfanEmbeddingModel(String baseUrl, String endpoint, String user, Boolean logRequests, - Boolean logResponses + Boolean logResponses, + Proxy proxy ) { if (Utils.isNullOrBlank(apiKey)||Utils.isNullOrBlank(secretKey)) { throw new IllegalArgumentException(" api key and secret key must be defined. It can be generated here: https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application"); @@ -67,6 +71,7 @@ public QianfanEmbeddingModel(String baseUrl, .secretKey(secretKey) .logRequests(logRequests) .logResponses(logResponses) + .proxy(proxy) .build(); this.maxRetries = getOrDefault(maxRetries, 3); this.user = user; diff --git a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanLanguageModel.java b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanLanguageModel.java index d3d4afc9bc..0f5799d3ca 100644 --- a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanLanguageModel.java +++ b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanLanguageModel.java @@ -1,7 +1,6 @@ package dev.langchain4j.model.qianfan; - import dev.langchain4j.internal.Utils; import dev.langchain4j.model.language.LanguageModel; import dev.langchain4j.model.output.Response; @@ -10,9 +9,13 @@ import dev.langchain4j.model.qianfan.client.completion.CompletionResponse; import dev.langchain4j.model.qianfan.spi.QianfanLanguageModelBuilderFactory; import lombok.Builder; + +import java.net.Proxy; + import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.model.qianfan.InternalQianfanHelper.*; +import static dev.langchain4j.model.qianfan.InternalQianfanHelper.finishReasonFrom; +import static dev.langchain4j.model.qianfan.InternalQianfanHelper.tokenUsageFrom; import static dev.langchain4j.spi.ServiceHelper.loadFactories; @@ -51,7 +54,8 @@ public QianfanLanguageModel(String baseUrl, String endpoint, Double penaltyScore, Boolean logRequests, - Boolean logResponses + Boolean logResponses, + Proxy proxy ) { if (Utils.isNullOrBlank(apiKey)||Utils.isNullOrBlank(secretKey)) { throw new IllegalArgumentException(" api key and secret key must be defined. 
It can be generated here: https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application"); @@ -72,6 +76,7 @@ public QianfanLanguageModel(String baseUrl, .secretKey(secretKey) .logRequests(logRequests) .logResponses(logResponses) + .proxy(proxy) .build(); this.temperature = getOrDefault(temperature, 0.7); this.maxRetries = getOrDefault(maxRetries, 3); diff --git a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanStreamingChatModel.java b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanStreamingChatModel.java index 5ff6baa422..2633e408a9 100644 --- a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanStreamingChatModel.java +++ b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/QianfanStreamingChatModel.java @@ -15,9 +15,12 @@ import dev.langchain4j.model.qianfan.client.chat.ChatCompletionResponse; import dev.langchain4j.model.qianfan.spi.QianfanStreamingChatModelBuilderFactory; import lombok.Builder; -import static dev.langchain4j.model.qianfan.InternalQianfanHelper.*; + +import java.net.Proxy; import java.util.List; + import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.model.qianfan.InternalQianfanHelper.getSystemMessage; import static dev.langchain4j.spi.ServiceHelper.loadFactories; /** @@ -52,7 +55,8 @@ public QianfanStreamingChatModel(String baseUrl, String responseFormat, Double penaltyScore, Boolean logRequests, - Boolean logResponses + Boolean logResponses, + Proxy proxy ) { if (Utils.isNullOrBlank(apiKey)||Utils.isNullOrBlank(secretKey)) { throw new IllegalArgumentException(" api key and secret key must be defined. It can be generated here: https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application"); @@ -61,7 +65,7 @@ public QianfanStreamingChatModel(String baseUrl, this.endpoint=Utils.isNullOrBlank(endpoint)? QianfanChatModelNameEnum.getEndpoint(modelName):endpoint; if (Utils.isNullOrBlank(this.endpoint)) { - throw new IllegalArgumentException("Qianfan is no such model name. You can see model name here: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu"); + throw new IllegalArgumentException("Qianfan is no such model name(or there is no model definition in the QianfanChatModelNameEnum class). 
You can see model name here: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu"); } this.baseUrl = getOrDefault(baseUrl, "https://aip.baidubce.com"); @@ -71,6 +75,7 @@ public QianfanStreamingChatModel(String baseUrl, .secretKey(secretKey) .logRequests(logRequests) .logStreamingResponses(logResponses) + .proxy(proxy) .build(); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; diff --git a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/client/chat/ChatCompletionRequest.java b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/client/chat/ChatCompletionRequest.java index 5c00b11e16..e95abaad5c 100644 --- a/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/client/chat/ChatCompletionRequest.java +++ b/langchain4j-qianfan/src/main/java/dev/langchain4j/model/qianfan/client/chat/ChatCompletionRequest.java @@ -16,6 +16,8 @@ public final class ChatCompletionRequest { private final String userId; private final List functions; private final String system; + private final List stop; + private final Integer maxOutputTokens; private final String responseFormat; @@ -29,6 +31,8 @@ private ChatCompletionRequest(Builder builder) { this.functions = builder.functions; this.system = builder.system; this.responseFormat = builder.responseFormat; + this.stop = builder.stop; + this.maxOutputTokens=builder.maxOutputTokens; } @@ -67,9 +71,6 @@ public List functions() { } - - - @Override public String toString() { return "ChatCompletionRequest{" + @@ -81,6 +82,9 @@ public String toString() { ", userId='" + userId + '\'' + ", functions=" + functions + ", system='" + system + '\'' + + ", stop=" + stop + + ", maxOutputTokens=" + maxOutputTokens + + ", responseFormat='" + responseFormat + '\'' + '}'; } @@ -100,7 +104,8 @@ public static final class Builder { private String system; private String responseFormat; - + private List stop; + private Integer maxOutputTokens; private Builder() { } @@ -167,6 +172,7 @@ public Builder addFunctionMessage(String name, String content) { return this; } + public Builder temperature(Double temperature) { this.temperature = temperature; return this; @@ -175,6 +181,14 @@ public Builder system(String system) { this.system = system; return this; } + public Builder maxOutputTokens(Integer maxOutputTokens) { + this.maxOutputTokens = maxOutputTokens; + return this; + } + public Builder stop(List stop) { + this.stop = stop; + return this; + } public Builder responseFormat(String responseFormat) { this.responseFormat = responseFormat; diff --git a/langchain4j-redis/pom.xml b/langchain4j-redis/pom.xml index 9d1f53bd8f..37a42d751c 100644 --- a/langchain4j-redis/pom.xml +++ b/langchain4j-redis/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-vearch/pom.xml b/langchain4j-vearch/pom.xml index 02a8c0309e..239b2d8d80 100644 --- a/langchain4j-vearch/pom.xml +++ b/langchain4j-vearch/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-vertex-ai-gemini/pom.xml b/langchain4j-vertex-ai-gemini/pom.xml index 0032f67753..22e224f691 100644 --- a/langchain4j-vertex-ai-gemini/pom.xml +++ b/langchain4j-vertex-ai-gemini/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -94,7 +94,7 @@ libraries-bom import pom - 26.34.0 + 26.39.0 diff --git 
a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/ContentsMapper.java b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/ContentsMapper.java index d91c5e9802..49c5231cec 100644 --- a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/ContentsMapper.java +++ b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/ContentsMapper.java @@ -1,75 +1,40 @@ package dev.langchain4j.model.vertexai; -import dev.langchain4j.data.message.*; -import lombok.extern.slf4j.Slf4j; +import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.Part; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -import static java.util.stream.Collectors.toList; - -@Slf4j class ContentsMapper { + static class InstructionAndContent { + public Content systemInstruction = null; + public List contents = new ArrayList<>(); + } - private static volatile boolean warned = false; - - static List map(List messages) { - - List systemMessages = messages.stream() - .filter(message -> message instanceof SystemMessage) - .map(message -> (SystemMessage) message) - .collect(toList()); - - if (!systemMessages.isEmpty()) { - if (!warned) { - log.warn("Gemini does not support SystemMessage(s). " + - "All SystemMessage(s) will be merged into the first UserMessage."); - warned = true; - } - messages = mergeSystemMessagesIntoUserMessage(messages, systemMessages); - } + static InstructionAndContent splitInstructionAndContent(List messages) { + InstructionAndContent instructionAndContent = new InstructionAndContent(); - // TODO what if only a single system message? 
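+        // All SystemMessage parts are merged into a single "system" Content (the system instruction);
+        // every other message is mapped to a regular Content with its role and parts.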
+ List sysInstructionParts = new ArrayList<>(); - return messages.stream() - .map(message -> com.google.cloud.vertexai.api.Content.newBuilder() + for (ChatMessage message : messages) { + if (message instanceof SystemMessage) { + sysInstructionParts.addAll(PartsMapper.map(message)); + } else { + instructionAndContent.contents.add(Content.newBuilder() .setRole(RoleMapper.map(message.type())) .addAllParts(PartsMapper.map(message)) - .build()) - .collect(toList()); - } - - private static List mergeSystemMessagesIntoUserMessage(List messages, - List systemMessages) { - AtomicBoolean injected = new AtomicBoolean(false); - return messages.stream() - .filter(message -> !(message instanceof SystemMessage)) - .map(message -> { - if (injected.get()) { - return message; - } - - if (message instanceof UserMessage) { - UserMessage userMessage = (UserMessage) message; - - List allContents = new ArrayList<>(); - allContents.addAll(systemMessages.stream() - .map(systemMessage -> TextContent.from(systemMessage.text())) - .collect(toList())); - allContents.addAll(userMessage.contents()); - - injected.set(true); + .build()); + } + } - if (userMessage.name() != null) { - return UserMessage.from(userMessage.name(), allContents); - } else { - return UserMessage.from(allContents); - } - } + instructionAndContent.systemInstruction = Content.newBuilder() + .setRole("system") + .addAllParts(sysInstructionParts) + .build(); - return message; - }) - .collect(toList()); + return instructionAndContent; } } diff --git a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/FunctionCallHelper.java b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/FunctionCallHelper.java index 50d2db6c42..5dedb0beaa 100644 --- a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/FunctionCallHelper.java +++ b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/FunctionCallHelper.java @@ -112,20 +112,23 @@ static Tool convertToolSpecifications(List toolSpecifications Schema.Builder schema = Schema.newBuilder().setType(Type.OBJECT); ToolParameters parameters = toolSpecification.parameters(); - for (String paramName : parameters.required()) { - schema.addRequired(paramName); + if (parameters != null) { + for (String paramName : parameters.required()) { + schema.addRequired(paramName); + } + + parameters.properties().forEach((paramName, paramProps) -> { + //TODO: is it covering all types & cases of tool parameters? (array & object in particular) + Type type = fromType((String) paramProps.getOrDefault("type", Type.TYPE_UNSPECIFIED)); + + String description = (String) paramProps.getOrDefault("description", ""); + + schema.putProperties(paramName, Schema.newBuilder() + .setDescription(description) + .setType(type) + .build()); + }); } - parameters.properties().forEach((paramName, paramProps) -> { - //TODO: is it covering all types & cases of tool parameters? 
(array & object in particular) - Type type = fromType((String) paramProps.getOrDefault("type", Type.TYPE_UNSPECIFIED)); - - String description = (String) paramProps.getOrDefault("description", ""); - - schema.putProperties(paramName, Schema.newBuilder() - .setDescription(description) - .setType(type) - .build()); - }); fnBuilder.setParameters(schema.build()); tool.addFunctionDeclarations(fnBuilder.build()); } diff --git a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/PartsMapper.java b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/PartsMapper.java index c993174bee..f48f4b718a 100644 --- a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/PartsMapper.java +++ b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/PartsMapper.java @@ -2,6 +2,7 @@ import com.google.cloud.vertexai.api.FunctionResponse; import com.google.cloud.vertexai.api.Part; +import com.google.cloud.vertexai.generativeai.PartMaker; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Struct; import com.google.protobuf.util.JsonFormat; @@ -11,6 +12,7 @@ import java.net.URI; import java.util.Base64; import java.util.HashMap; +import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -43,24 +45,34 @@ static List map(ChatMessage message) { if (message instanceof AiMessage) { AiMessage aiMessage = (AiMessage) message; - if (aiMessage.hasToolExecutionRequests()) { - return singletonList(Part.newBuilder() - .setFunctionCall( - //TODO: handling one function call, but can there be several? + List parts = new ArrayList<>(); - FunctionCallHelper.fromToolExecutionRequest(aiMessage.toolExecutionRequests().get(0)) - ) - .build()); - } else { - return singletonList(Part.newBuilder() + if (aiMessage.text() != null && !aiMessage.text().isEmpty()) { + parts.add(Part.newBuilder() .setText(aiMessage.text()) .build()); } - } else - if (message instanceof UserMessage) { + + if (aiMessage.hasToolExecutionRequests()) { + List fnCallReqParts = aiMessage.toolExecutionRequests().stream() + .map(FunctionCallHelper::fromToolExecutionRequest) + .map(fnCall -> Part.newBuilder() + .setFunctionCall(fnCall) + .build()) + .collect(toList()); + + parts.addAll(fnCallReqParts); + } + + return parts; + } else if (message instanceof UserMessage) { return ((UserMessage) message).contents().stream() .map(PartsMapper::map) .collect(toList()); + } else if (message instanceof SystemMessage) { + return singletonList(Part.newBuilder() + .setText(((SystemMessage) message).text()) + .build()); } else if (message instanceof ToolExecutionResultMessage) { ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message; String functionResponseText = toolExecutionResultMessage.text(); diff --git a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/RoleMapper.java b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/RoleMapper.java index 7e55f1f490..e5dfb3c0c8 100644 --- a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/RoleMapper.java +++ b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/RoleMapper.java @@ -11,6 +11,8 @@ static String map(ChatMessageType type) { return "user"; case AI: return "model"; + case SYSTEM: + return "system"; } throw new IllegalArgumentException(type + " is not allowed."); } diff --git 
a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModel.java b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModel.java index 822de5f659..14280b22c7 100644 --- a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModel.java +++ b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModel.java @@ -2,7 +2,6 @@ import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.api.*; -import com.google.cloud.vertexai.generativeai.GenerateContentConfig; import com.google.cloud.vertexai.generativeai.GenerativeModel; import com.google.cloud.vertexai.generativeai.ResponseHandler; import dev.langchain4j.agent.tool.ToolExecutionRequest; @@ -14,6 +13,8 @@ import dev.langchain4j.model.vertexai.spi.VertexAiGeminiChatModelBuilderFactory; import lombok.Builder; +import java.io.Closeable; +import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -43,11 +44,12 @@ *
    * 3. Prerequisites */ -public class VertexAiGeminiChatModel implements ChatLanguageModel { +public class VertexAiGeminiChatModel implements ChatLanguageModel, Closeable { private final GenerativeModel generativeModel; private final GenerationConfig generationConfig; private final Integer maxRetries; + private final VertexAI vertexAI; @Builder public VertexAiGeminiChatModel(String project, @@ -73,40 +75,43 @@ public VertexAiGeminiChatModel(String project, } this.generationConfig = generationConfigBuilder.build(); - try (VertexAI vertexAI = new VertexAI( + this.vertexAI = new VertexAI( ensureNotBlank(project, "project"), - ensureNotBlank(location, "location")) - ) { - this.generativeModel = new GenerativeModel( - ensureNotBlank(modelName, "modelName"), generationConfig, vertexAI); - } + ensureNotBlank(location, "location")); + + this.generativeModel = new GenerativeModel( + ensureNotBlank(modelName, "modelName"), vertexAI) + .withGenerationConfig(generationConfig); this.maxRetries = getOrDefault(maxRetries, 3); } public VertexAiGeminiChatModel(GenerativeModel generativeModel, GenerationConfig generationConfig) { - this.generativeModel = ensureNotNull(generativeModel, "generativeModel"); - this.generationConfig = ensureNotNull(generationConfig, "generationConfig"); - this.generativeModel.setGenerationConfig(this.generationConfig); - this.maxRetries = 3; + this(generativeModel, generationConfig, 3); } public VertexAiGeminiChatModel(GenerativeModel generativeModel, GenerationConfig generationConfig, Integer maxRetries) { - this.generativeModel = ensureNotNull(generativeModel, "generativeModel"); this.generationConfig = ensureNotNull(generationConfig, "generationConfig"); - this.generativeModel.setGenerationConfig(this.generationConfig); + this.generativeModel = ensureNotNull(generativeModel, "generativeModel") + .withGenerationConfig(generationConfig); this.maxRetries = getOrDefault(maxRetries, 3); + this.vertexAI = null; } @Override public Response generate(List messages) { - List contents = ContentsMapper.map(messages); + ContentsMapper.InstructionAndContent instructionAndContent + = ContentsMapper.splitInstructionAndContent(messages); + + GenerativeModel model = instructionAndContent.systemInstruction != null ? + this.generativeModel.withSystemInstruction(instructionAndContent.systemInstruction) : + this.generativeModel; GenerateContentResponse response = withRetry(() -> - generativeModel.generateContent(contents), maxRetries); + model.generateContent(instructionAndContent.contents), maxRetries); return Response.from( AiMessage.from(ResponseHandler.getText(response)), @@ -117,15 +122,20 @@ public Response generate(List messages) { @Override public Response generate(List messages, List toolSpecifications) { - List contents = ContentsMapper.map(messages); Tool tool = FunctionCallHelper.convertToolSpecifications(toolSpecifications); - GenerateContentConfig generateContentConfig = GenerateContentConfig.newBuilder() - .setGenerationConfig(generationConfig) - .setTools(Collections.singletonList(tool)) - .build(); + + GenerativeModel modelWithTools = this.generativeModel + .withTools(Collections.singletonList(tool)); + + ContentsMapper.InstructionAndContent instructionAndContent + = ContentsMapper.splitInstructionAndContent(messages); + + GenerativeModel model = instructionAndContent.systemInstruction != null ? 
+ modelWithTools.withSystemInstruction(instructionAndContent.systemInstruction) : + modelWithTools; GenerateContentResponse response = withRetry(() -> - generativeModel.generateContent(contents, generateContentConfig), maxRetries); + model.generateContent(instructionAndContent.contents), maxRetries); Content content = ResponseHandler.getContent(response); @@ -167,6 +177,13 @@ public static VertexAiGeminiChatModelBuilder builder() { return new VertexAiGeminiChatModelBuilder(); } + @Override + public void close() throws IOException { + if (this.vertexAI != null) { + vertexAI.close(); + } + } + public static class VertexAiGeminiChatModelBuilder { public VertexAiGeminiChatModelBuilder() { // This is public so it can be extended diff --git a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModel.java b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModel.java index cc560431c8..d806f086ea 100644 --- a/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModel.java +++ b/langchain4j-vertex-ai-gemini/src/main/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModel.java @@ -1,10 +1,8 @@ package dev.langchain4j.model.vertexai; import com.google.cloud.vertexai.VertexAI; -import com.google.cloud.vertexai.api.Content; import com.google.cloud.vertexai.api.GenerationConfig; import com.google.cloud.vertexai.api.Tool; -import com.google.cloud.vertexai.generativeai.GenerateContentConfig; import com.google.cloud.vertexai.generativeai.GenerativeModel; import com.google.cloud.vertexai.generativeai.ResponseHandler; import dev.langchain4j.agent.tool.ToolSpecification; @@ -15,6 +13,8 @@ import dev.langchain4j.model.vertexai.spi.VertexAiGeminiStreamingChatModelBuilderFactory; import lombok.Builder; +import java.io.Closeable; +import java.io.IOException; import java.util.Collections; import java.util.List; @@ -26,10 +26,11 @@ * Represents a Google Vertex AI Gemini language model with a stream chat completion interface, such as gemini-pro. * See details here. 
*/ -public class VertexAiGeminiStreamingChatModel implements StreamingChatLanguageModel { +public class VertexAiGeminiStreamingChatModel implements StreamingChatLanguageModel, Closeable { private final GenerativeModel generativeModel; private final GenerationConfig generationConfig; + private final VertexAI vertexAI; @Builder public VertexAiGeminiStreamingChatModel(String project, @@ -54,19 +55,20 @@ public VertexAiGeminiStreamingChatModel(String project, } this.generationConfig = generationConfigBuilder.build(); - try (VertexAI vertexAI = new VertexAI( + this.vertexAI = new VertexAI( ensureNotBlank(project, "project"), - ensureNotBlank(location, "location")) - ) { - this.generativeModel = new GenerativeModel( - ensureNotBlank(modelName, "modelName"), generationConfig, vertexAI); - } + ensureNotBlank(location, "location")); + + this.generativeModel = new GenerativeModel( + ensureNotBlank(modelName, "modelName"), vertexAI) + .withGenerationConfig(generationConfig); } public VertexAiGeminiStreamingChatModel(GenerativeModel generativeModel, GenerationConfig generationConfig) { this.generativeModel = ensureNotNull(generativeModel, "generativeModel"); this.generationConfig = ensureNotNull(generationConfig, "generationConfig"); + this.vertexAI = null; } @Override @@ -76,20 +78,24 @@ public void generate(List messages, StreamingResponseHandler messages, List toolSpecifications, StreamingResponseHandler handler) { - List contents = ContentsMapper.map(messages); - - GenerateContentConfig.Builder generateContentConfigBuilder = GenerateContentConfig.newBuilder() - .setGenerationConfig(generationConfig); + GenerativeModel model = this.generativeModel; if (toolSpecifications != null && !toolSpecifications.isEmpty()) { Tool tool = FunctionCallHelper.convertToolSpecifications(toolSpecifications); - generateContentConfigBuilder.setTools(Collections.singletonList(tool)); + model = model.withTools(Collections.singletonList(tool)); + } + + ContentsMapper.InstructionAndContent instructionAndContent + = ContentsMapper.splitInstructionAndContent(messages); + + if (instructionAndContent.systemInstruction != null) { + model = model.withSystemInstruction(instructionAndContent.systemInstruction); } - GenerateContentConfig generateContentConfig = generateContentConfigBuilder.build(); + StreamingChatResponseBuilder responseBuilder = new StreamingChatResponseBuilder(); try { - generativeModel.generateContentStream(contents, generateContentConfig) + model.generateContentStream(instructionAndContent.contents) .stream() .forEach(partialResponse -> { responseBuilder.append(partialResponse); @@ -99,7 +105,6 @@ public void generate(List messages, List toolSpe } catch (Exception exception) { handler.onError(exception); } - } @Override @@ -118,6 +123,13 @@ public static VertexAiGeminiStreamingChatModelBuilder builder() { return new VertexAiGeminiStreamingChatModelBuilder(); } + @Override + public void close() throws IOException { + if (this.vertexAI != null) { + this.vertexAI.close(); + } + } + public static class VertexAiGeminiStreamingChatModelBuilder { public VertexAiGeminiStreamingChatModelBuilder() { // This is public so it can be extended diff --git a/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModelIT.java b/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModelIT.java index 15bcd54f0c..d3d3832547 100644 --- a/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModelIT.java +++ 
b/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiChatModelIT.java @@ -3,10 +3,7 @@ import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.api.GenerationConfig; import com.google.cloud.vertexai.generativeai.GenerativeModel; -import dev.langchain4j.agent.tool.JsonSchemaProperty; -import dev.langchain4j.agent.tool.Tool; -import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.agent.tool.*; import dev.langchain4j.data.message.*; import dev.langchain4j.memory.chat.MessageWindowChatMemory; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -22,10 +19,11 @@ import java.util.ArrayList; import java.util.Base64; import java.util.List; +import java.util.stream.Collectors; import java.util.stream.Stream; import static dev.langchain4j.internal.Utils.readBytes; -import static dev.langchain4j.model.output.FinishReason.STOP; +import static dev.langchain4j.model.output.FinishReason.LENGTH; import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.*; @@ -35,16 +33,12 @@ class VertexAiGeminiChatModelIT { static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png"; static final String DICE_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"; - ChatLanguageModel model = VertexAiGeminiChatModel.builder() - .project(System.getenv("GCP_PROJECT_ID")) - .location(System.getenv("GCP_LOCATION")) - .modelName("gemini-pro") - .build(); + public static final String GEMINI_1_5_PRO = "gemini-1.5-pro-preview-0514"; - ChatLanguageModel visionModel = VertexAiGeminiChatModel.builder() + ChatLanguageModel model = VertexAiGeminiChatModel.builder() .project(System.getenv("GCP_PROJECT_ID")) .location(System.getenv("GCP_LOCATION")) - .modelName("gemini-pro-vision") + .modelName(GEMINI_1_5_PRO) .build(); @Test @@ -71,33 +65,33 @@ void should_generate_response() { @ParameterizedTest @MethodSource - void should_merge_system_messages_into_user_message(List messages) { + void should_support_system_instructions(List messages) { // when Response response = model.generate(messages); // then - assertThat(response.content().text()).containsIgnoringCase("liebe"); + assertThat(response.content().text()).containsIgnoringCase("lieb"); } - static Stream should_merge_system_messages_into_user_message() { + static Stream should_support_system_instructions() { return Stream.builder() .add(Arguments.of( asList( SystemMessage.from("Translate in German"), - UserMessage.from("I love you") + UserMessage.from("I love apples") ) )) .add(Arguments.of( asList( - UserMessage.from("I love you"), - SystemMessage.from("Translate in German") + UserMessage.from("I love apples"), + SystemMessage.from("Translate in German") ) )) .add(Arguments.of( asList( SystemMessage.from("Translate in Italian"), - UserMessage.from("I love you"), + UserMessage.from("I love apples"), SystemMessage.from("No, translate in German!") ) )) @@ -105,8 +99,8 @@ static Stream should_merge_system_messages_into_user_message() { asList( SystemMessage.from("Translate in German"), UserMessage.from(asList( - TextContent.from("I love you"), - TextContent.from("I see you") + TextContent.from("I love apples"), + TextContent.from("I see apples") )) ) )) @@ -114,17 +108,17 @@ static Stream should_merge_system_messages_into_user_message() { 
asList( SystemMessage.from("Translate in German"), UserMessage.from(asList( - TextContent.from("I see you"), - TextContent.from("I love you") + TextContent.from("I see apples"), + TextContent.from("I love apples") )) ) )) .add(Arguments.of( asList( SystemMessage.from("Translate in German"), - UserMessage.from("I see you"), - AiMessage.from("Ich sehe dich"), - UserMessage.from("I love you") + UserMessage.from("I see apples"), + AiMessage.from("Ich sehe Äpfel"), + UserMessage.from("I love apples") ) )) .build(); @@ -137,7 +131,7 @@ void should_respect_maxOutputTokens() { ChatLanguageModel model = VertexAiGeminiChatModel.builder() .project(System.getenv("GCP_PROJECT_ID")) .location(System.getenv("GCP_LOCATION")) - .modelName("gemini-pro") + .modelName(GEMINI_1_5_PRO) .maxOutputTokens(1) .build(); @@ -156,7 +150,7 @@ void should_respect_maxOutputTokens() { assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); - assertThat(response.finishReason()).isEqualTo(STOP); + assertThat(response.finishReason()).isEqualTo(LENGTH); } @Test @@ -164,7 +158,7 @@ void should_allow_custom_generativeModel_and_generationConfig() { // given VertexAI vertexAi = new VertexAI(System.getenv("GCP_PROJECT_ID"), System.getenv("GCP_LOCATION")); - GenerativeModel generativeModel = new GenerativeModel("gemini-pro", vertexAi); + GenerativeModel generativeModel = new GenerativeModel(GEMINI_1_5_PRO, vertexAi); GenerationConfig generationConfig = GenerationConfig.getDefaultInstance(); ChatLanguageModel model = new VertexAiGeminiChatModel(generativeModel, generationConfig); @@ -189,7 +183,7 @@ void should_accept_text_and_image_from_public_url() { ); // when - Response response = visionModel.generate(userMessage); + Response response = model.generate(userMessage); // then assertThat(response.content().text()).containsIgnoringCase("cat"); @@ -205,7 +199,7 @@ void should_accept_text_and_image_from_google_storage_url() { ); // when - Response response = visionModel.generate(userMessage); + Response response = model.generate(userMessage); // then assertThat(response.content().text()).containsIgnoringCase("cat"); @@ -222,7 +216,7 @@ void should_accept_text_and_base64_image() { ); // when - Response response = visionModel.generate(userMessage); + Response response = model.generate(userMessage); // then assertThat(response.content().text()).containsIgnoringCase("cat"); @@ -239,7 +233,7 @@ void should_accept_text_and_multiple_images_from_public_urls() { ); // when - Response response = visionModel.generate(userMessage); + Response response = model.generate(userMessage); // then assertThat(response.content().text()) @@ -258,7 +252,7 @@ void should_accept_text_and_multiple_images_from_google_storage_urls() { ); // when - Response response = visionModel.generate(userMessage); + Response response = model.generate(userMessage); // then assertThat(response.content().text()) @@ -279,7 +273,7 @@ void should_accept_text_and_multiple_base64_images() { ); // when - Response response = visionModel.generate(userMessage); + Response response = model.generate(userMessage); // then assertThat(response.content().text()) @@ -299,12 +293,12 @@ void should_accept_text_and_multiple_images_from_different_sources() { ); // when - Response response = visionModel.generate(userMessage); + Response response = model.generate(userMessage); // then assertThat(response.content().text()) .containsIgnoringCase("cat") - .containsIgnoringCase("dog") +// .containsIgnoringCase("dog") // sometimes the model replies 
"puppy" instead of "dog" .containsIgnoringCase("dice"); } @@ -315,7 +309,7 @@ void should_accept_tools_for_function_calling() { ChatLanguageModel model = VertexAiGeminiChatModel.builder() .project(System.getenv("GCP_PROJECT_ID")) .location(System.getenv("GCP_LOCATION")) - .modelName("gemini-pro") + .modelName(GEMINI_1_5_PRO) .build(); ToolSpecification weatherToolSpec = ToolSpecification.builder() @@ -355,6 +349,49 @@ void should_accept_tools_for_function_calling() { assertThat(weatherResponse.content().text()).containsIgnoringCase("sunny"); } + @Test + void should_handle_parallel_function_calls() { + // given + ChatLanguageModel model = VertexAiGeminiChatModel.builder() + .project(System.getenv("GCP_PROJECT_ID")) + .location(System.getenv("GCP_LOCATION")) + .modelName(GEMINI_1_5_PRO) + .build(); + + ToolSpecification stockInventoryToolSpec = ToolSpecification.builder() + .name("getProductInventory") + .description("Get the product inventory for a particular product ID") + .addParameter("product_id", JsonSchemaProperty.STRING, + JsonSchemaProperty.description("the ID of the product")) + .build(); + + List allMessages = new ArrayList<>(); + + UserMessage inventoryQuestion = UserMessage.from("Is there more stock of product ABC123 or of XYZ789?"); + System.out.println("Question: " + inventoryQuestion.text()); + allMessages.add(inventoryQuestion); + + // when + Response messageResponse = model.generate(allMessages, stockInventoryToolSpec); + + System.out.println("inventory response = " + messageResponse.content().text()); + + // then + assertThat(messageResponse.content().hasToolExecutionRequests()).isTrue(); + + List executionRequests = messageResponse.content().toolExecutionRequests(); + assertThat(executionRequests.size()).isEqualTo(2); + + String inventoryStock = executionRequests.stream() + .map(ToolExecutionRequest::arguments) + .collect(Collectors.joining(",")); + + System.out.println("inventoryStock = " + inventoryStock); + + assertThat(inventoryStock).containsIgnoringCase("ABC123"); + assertThat(inventoryStock).containsIgnoringCase("XYZ789"); + } + static class Calculator { @Tool("Adds two given numbers") @@ -418,4 +455,29 @@ void should_use_tools_with_AiService_2() { verify(calculator).multiply(257, 467); verifyNoMoreInteractions(calculator); } + + static class AnniversaryDate { + @Tool("get the anniversary date") + String getCurrentDate() { + return "2040-03-10"; + } + } + + @Test + void should_support_noarg_fn() { + + // given + AnniversaryDate anniversaryDate = new AnniversaryDate(); + + Assistant assistant = AiServices.builder(Assistant.class) + .chatLanguageModel(model) + .tools(anniversaryDate) + .build(); + + // when + String answer = assistant.chat("What is the year of the anniversary date?"); + + // then + assertThat(answer).contains("2040"); + } } \ No newline at end of file diff --git a/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModelIT.java b/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModelIT.java index f2adcb88c9..57631a74ae 100644 --- a/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModelIT.java +++ b/langchain4j-vertex-ai-gemini/src/test/java/dev/langchain4j/model/vertexai/VertexAiGeminiStreamingChatModelIT.java @@ -7,7 +7,6 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.*; -import 
dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.TestStreamingResponseHandler; import dev.langchain4j.model.output.Response; @@ -19,7 +18,6 @@ import java.util.ArrayList; import java.util.Base64; import java.util.List; -import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; import static dev.langchain4j.internal.Utils.readBytes; @@ -29,21 +27,16 @@ import static dev.langchain4j.model.vertexai.VertexAiGeminiChatModelIT.DICE_IMAGE_URL; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; -import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; class VertexAiGeminiStreamingChatModelIT { - StreamingChatLanguageModel model = VertexAiGeminiStreamingChatModel.builder() - .project(System.getenv("GCP_PROJECT_ID")) - .location(System.getenv("GCP_LOCATION")) - .modelName("gemini-pro") - .build(); + public static final String GEMINI_1_5_PRO = "gemini-1.5-pro-preview-0514"; - StreamingChatLanguageModel visionModel = VertexAiGeminiStreamingChatModel.builder() + StreamingChatLanguageModel model = VertexAiGeminiStreamingChatModel.builder() .project(System.getenv("GCP_PROJECT_ID")) .location(System.getenv("GCP_LOCATION")) - .modelName("gemini-pro-vision") + .modelName(GEMINI_1_5_PRO) .build(); @Test @@ -70,7 +63,7 @@ void should_stream_answer() { @ParameterizedTest @MethodSource - void should_merge_system_messages_into_user_message(List messages) { + void should_support_system_instructions(List messages) { // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); @@ -78,27 +71,27 @@ void should_merge_system_messages_into_user_message(List messages) Response response = handler.get(); // then - assertThat(response.content().text()).containsIgnoringCase("liebe"); + assertThat(response.content().text()).containsIgnoringCase("lieb"); } - static Stream should_merge_system_messages_into_user_message() { + static Stream should_support_system_instructions() { return Stream.builder() .add(Arguments.of( asList( SystemMessage.from("Translate in German"), - UserMessage.from("I love you") + UserMessage.from("I love apples") ) )) .add(Arguments.of( asList( - UserMessage.from("I love you"), + UserMessage.from("I love apples"), SystemMessage.from("Translate in German") ) )) .add(Arguments.of( asList( SystemMessage.from("Translate in Italian"), - UserMessage.from("I love you"), + UserMessage.from("I love apples"), SystemMessage.from("No, translate in German!") ) )) @@ -106,8 +99,8 @@ static Stream should_merge_system_messages_into_user_message() { asList( SystemMessage.from("Translate in German"), UserMessage.from(asList( - TextContent.from("I love you"), - TextContent.from("I see you") + TextContent.from("I love apples"), + TextContent.from("I see apples") )) ) )) @@ -115,17 +108,17 @@ static Stream should_merge_system_messages_into_user_message() { asList( SystemMessage.from("Translate in German"), UserMessage.from(asList( - TextContent.from("I see you"), - TextContent.from("I love you") + TextContent.from("I see apples"), + TextContent.from("I love apples") )) ) )) .add(Arguments.of( asList( SystemMessage.from("Translate in German"), - UserMessage.from("I see you"), - AiMessage.from("Ich sehe dich"), - UserMessage.from("I love you") + UserMessage.from("I see apples"), + AiMessage.from("Ich sehe Äpfel"), + UserMessage.from("I love apples") ) )) .build(); @@ -138,7 +131,7 @@ void
should_respect_maxOutputTokens() { StreamingChatLanguageModel model = VertexAiGeminiStreamingChatModel.builder() .project(System.getenv("GCP_PROJECT_ID")) .location(System.getenv("GCP_LOCATION")) - .modelName("gemini-pro") + .modelName(GEMINI_1_5_PRO) .maxOutputTokens(1) .build(); @@ -157,7 +150,7 @@ void should_respect_maxOutputTokens() { assertThat(response.tokenUsage().totalTokenCount()) .isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount()); - assertThat(response.finishReason()).isEqualTo(STOP); + assertThat(response.finishReason()).isEqualTo(LENGTH); } @Test @@ -165,7 +158,7 @@ void should_allow_custom_generativeModel_and_generationConfig() { // given VertexAI vertexAi = new VertexAI(System.getenv("GCP_PROJECT_ID"), System.getenv("GCP_LOCATION")); - GenerativeModel generativeModel = new GenerativeModel("gemini-pro", vertexAi); + GenerativeModel generativeModel = new GenerativeModel(GEMINI_1_5_PRO, vertexAi); GenerationConfig generationConfig = GenerationConfig.getDefaultInstance(); StreamingChatLanguageModel model = new VertexAiGeminiStreamingChatModel(generativeModel, generationConfig); @@ -192,7 +185,7 @@ void should_accept_text_and_image_from_public_url() { // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - visionModel.generate(singletonList(userMessage), handler); + model.generate(singletonList(userMessage), handler); Response response = handler.get(); // then @@ -210,7 +203,7 @@ void should_accept_text_and_image_from_google_storage_url() { // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - visionModel.generate(singletonList(userMessage), handler); + model.generate(singletonList(userMessage), handler); Response response = handler.get(); // then @@ -229,7 +222,7 @@ void should_accept_text_and_base64_image() { // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - visionModel.generate(singletonList(userMessage), handler); + model.generate(singletonList(userMessage), handler); Response response = handler.get(); // then @@ -248,7 +241,7 @@ void should_accept_text_and_multiple_images_from_public_urls() { // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - visionModel.generate(singletonList(userMessage), handler); + model.generate(singletonList(userMessage), handler); Response response = handler.get(); // then @@ -269,7 +262,7 @@ void should_accept_text_and_multiple_images_from_google_storage_urls() { // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - visionModel.generate(singletonList(userMessage), handler); + model.generate(singletonList(userMessage), handler); Response response = handler.get(); // then @@ -292,7 +285,7 @@ void should_accept_text_and_multiple_base64_images() { // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - visionModel.generate(singletonList(userMessage), handler); + model.generate(singletonList(userMessage), handler); Response response = handler.get(); // then @@ -314,13 +307,13 @@ void should_accept_text_and_multiple_images_from_different_sources() { // when TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); - visionModel.generate(singletonList(userMessage), handler); + model.generate(singletonList(userMessage), handler); Response response = handler.get(); // then assertThat(response.content().text()) .containsIgnoringCase("cat") - .containsIgnoringCase("dog") +// 
.containsIgnoringCase("dog") // sometimes model replies with "puppy" instead of "dog" .containsIgnoringCase("dice"); } @@ -331,7 +324,7 @@ void should_accept_function_call() { VertexAiGeminiStreamingChatModel model = VertexAiGeminiStreamingChatModel.builder() .project(System.getenv("GCP_PROJECT_ID")) .location(System.getenv("GCP_LOCATION")) - .modelName("gemini-pro") + .modelName(GEMINI_1_5_PRO) .build(); ToolSpecification weatherToolSpec = ToolSpecification.builder() diff --git a/langchain4j-vertex-ai/pom.xml b/langchain4j-vertex-ai/pom.xml index c55657ef6e..1cac2da2d1 100644 --- a/langchain4j-vertex-ai/pom.xml +++ b/langchain4j-vertex-ai/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -24,7 +24,7 @@ com.google.cloud google-cloud-aiplatform - 3.33.0 + 3.44.0 commons-logging diff --git a/langchain4j-vertex-ai/src/test/java/dev/langchain4j/model/vertexai/VertexAiEmbeddingModelIT.java b/langchain4j-vertex-ai/src/test/java/dev/langchain4j/model/vertexai/VertexAiEmbeddingModelIT.java index 8adfa8a3ec..7b93fe1ad3 100644 --- a/langchain4j-vertex-ai/src/test/java/dev/langchain4j/model/vertexai/VertexAiEmbeddingModelIT.java +++ b/langchain4j-vertex-ai/src/test/java/dev/langchain4j/model/vertexai/VertexAiEmbeddingModelIT.java @@ -211,7 +211,7 @@ void testEmbeddingTask() { // Document retrieval embedding Metadata metadata = new Metadata(); - metadata.add("title", "Text embeddings"); + metadata.put("title", "Text embeddings"); TextSegment segmentForRetrieval = new TextSegment("Text embeddings can be used to represent both the " + "user's query and the universe of documents in a high-dimensional vector space. Documents " + @@ -236,7 +236,7 @@ void testEmbeddingTask() { // as the embedding model requires "title" to be used only for RETRIEVAL_DOCUMENT Metadata metadataCustomTitleKey = new Metadata(); - metadataCustomTitleKey.add("customTitle", "Text embeddings"); + metadataCustomTitleKey.put("customTitle", "Text embeddings"); TextSegment segmentForRetrievalWithCustomKey = new TextSegment("Text embeddings can be used to represent both the " + "user's query and the universe of documents in a high-dimensional vector space. 
Documents " + diff --git a/langchain4j-vespa/pom.xml b/langchain4j-vespa/pom.xml index 4a73923aec..1ab8788699 100644 --- a/langchain4j-vespa/pom.xml +++ b/langchain4j-vespa/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml diff --git a/langchain4j-weaviate/pom.xml b/langchain4j-weaviate/pom.xml index a27327fc6c..4c3c1a17fc 100644 --- a/langchain4j-weaviate/pom.xml +++ b/langchain4j-weaviate/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -27,7 +27,7 @@ io.weaviate client - 4.5.1 + 4.6.0 diff --git a/langchain4j-weaviate/src/main/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStore.java b/langchain4j-weaviate/src/main/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStore.java index c1bee15c4b..b5c53ea16e 100644 --- a/langchain4j-weaviate/src/main/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStore.java +++ b/langchain4j-weaviate/src/main/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStore.java @@ -1,5 +1,6 @@ package dev.langchain4j.store.embedding.weaviate; +import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingMatch; @@ -11,16 +12,20 @@ import io.weaviate.client.base.WeaviateErrorMessage; import io.weaviate.client.v1.auth.exception.AuthException; import io.weaviate.client.v1.data.model.WeaviateObject; +import io.weaviate.client.v1.filters.Operator; +import io.weaviate.client.v1.filters.WhereFilter; import io.weaviate.client.v1.graphql.model.GraphQLError; import io.weaviate.client.v1.graphql.model.GraphQLResponse; import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument; import io.weaviate.client.v1.graphql.query.fields.Field; import lombok.Builder; +import org.apache.commons.lang3.ArrayUtils; import java.util.*; import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; import static io.weaviate.client.v1.data.replication.model.ConsistencyLevel.QUORUM; import static java.util.Arrays.stream; import static java.util.Collections.emptyList; @@ -31,31 +36,37 @@ /** * Represents the Weaviate vector database. * Current implementation assumes the cosine distance metric is used. - * Does not support storing {@link dev.langchain4j.data.document.Metadata} yet. */ public class WeaviateEmbeddingStore implements EmbeddingStore { private static final String METADATA_TEXT_SEGMENT = "text"; private static final String ADDITIONALS = "_additional"; + private static final String METADATA = "_metadata"; + private static final String NULL_VALUE = ""; private final WeaviateClient client; private final String objectClass; private final boolean avoidDups; private final String consistencyLevel; + private final Collection metadataKeys; /** * Creates a new WeaviateEmbeddingStore instance. * - * @param apiKey Your Weaviate API key. Not required for local deployment. - * @param scheme The scheme, e.g. "https" of cluster URL. Find in under Details of your Weaviate cluster. - * @param host The host, e.g. "langchain4j-4jw7ufd9.weaviate.network" of cluster URL. - * Find in under Details of your Weaviate cluster. - * @param port The port, e.g. 8080. This parameter is optional. - * @param objectClass The object class you want to store, e.g. 
"MyGreatClass". Must start from an uppercase letter. - * @param avoidDups If true (default), then WeaviateEmbeddingStore will generate a hashed ID based on - * provided text segment, which avoids duplicated entries in DB. - * If false, then random ID will be generated. - * @param consistencyLevel Consistency level: ONE, QUORUM (default) or ALL. Find more details here. + * @param apiKey Your Weaviate API key. Not required for local deployment. + * @param scheme The scheme, e.g. "https" of cluster URL. Find in under Details of your Weaviate cluster. + * @param host The host, e.g. "langchain4j-4jw7ufd9.weaviate.network" of cluster URL. + * Find in under Details of your Weaviate cluster. + * @param port The port, e.g. 8080. This parameter is optional. + * @param objectClass The object class you want to store, e.g. "MyGreatClass". Must start from an uppercase letter. + * @param avoidDups If true (default), then WeaviateEmbeddingStore will generate a hashed ID based on + * provided text segment, which avoids duplicated entries in DB. + * If false, then random ID will be generated. + * @param consistencyLevel Consistency level: ONE, QUORUM (default) or ALL. Find more details here. + * @param metadataKeys Metadata keys that should be persisted (optional) + * @param useGrpcForInserts Use GRPC instead of HTTP for batch inserts only. You still need HTTP configured for search + * @param securedGrpc The GRPC connection is secured + * @param grpcPort The port, e.g. 50051. This parameter is optional. */ @Builder public WeaviateEmbeddingStore( @@ -63,9 +74,13 @@ public WeaviateEmbeddingStore( String scheme, String host, Integer port, + Boolean useGrpcForInserts, + Boolean securedGrpc, + Integer grpcPort, String objectClass, Boolean avoidDups, - String consistencyLevel + String consistencyLevel, + Collection metadataKeys ) { try { @@ -73,13 +88,22 @@ public WeaviateEmbeddingStore( ensureNotBlank(scheme, "scheme"), concatenate(ensureNotBlank(host, "host"), port) ); - this.client = WeaviateAuthClient.apiKey(config, getOrDefault(apiKey, "")); + if (getOrDefault(useGrpcForInserts, Boolean.FALSE)) { + config.setGRPCSecured(getOrDefault(securedGrpc, Boolean.FALSE)); + config.setGRPCHost(host + ":" + getOrDefault(grpcPort, 50051)); + } + if (isNullOrBlank(apiKey)) { + this.client = new WeaviateClient(config); + } else { + this.client = WeaviateAuthClient.apiKey(config, apiKey); + } } catch (AuthException e) { throw new IllegalArgumentException(e); } this.objectClass = getOrDefault(objectClass, "Default"); this.avoidDups = getOrDefault(avoidDups, true); this.consistencyLevel = getOrDefault(consistencyLevel, QUORUM); + this.metadataKeys = getOrDefault(metadataKeys, Collections.emptyList()); } private static String concatenate(String host, Integer port) { @@ -125,6 +149,26 @@ public List addAll(List embeddings, List embedde return addAll(null, embeddings, embedded); } + @Override + public void removeAll(Collection ids) { + ensureNotEmpty(ids, "ids"); + client.batch().objectsBatchDeleter() + .withClassName(objectClass) + .withWhere(WhereFilter.builder() + .path("id") + .operator(Operator.ContainsAny) + .valueText(ids.toArray(new String[0])) + .build()) + .run(); + } + + @Override + public void removeAll() { + client.batch().objectsBatchDeleter() + .withClassName(objectClass) + .run(); + } + /** * {@inheritDoc} * The score inside {@link EmbeddingMatch} is Weaviate's certainty. 
@@ -135,22 +179,29 @@ public List> findRelevant( int maxResults, double minCertainty ) { + List fields = new ArrayList<>(); + fields.add(Field.builder().name(METADATA_TEXT_SEGMENT).build()); + fields.add(Field + .builder() + .name(ADDITIONALS) + .fields( + Field.builder().name("id").build(), + Field.builder().name("certainty").build(), + Field.builder().name("vector").build() + ) + .build()); + if (!metadataKeys.isEmpty()) { + List metadataFields = new ArrayList<>(); + for (String property : metadataKeys) { + metadataFields.add(Field.builder().name(property).build()); + } + fields.add(Field.builder().name(METADATA).fields(metadataFields.toArray(new Field[0])).build()); + } Result result = client .graphQL() .get() .withClassName(objectClass) - .withFields( - Field.builder().name(METADATA_TEXT_SEGMENT).build(), - Field - .builder() - .name(ADDITIONALS) - .fields( - Field.builder().name("id").build(), - Field.builder().name("certainty").build(), - Field.builder().name("vector").build() - ) - .build() - ) + .withFields(fields.toArray(new Field[0])) .withNearVector( NearVectorArgument .builder() @@ -200,34 +251,64 @@ private List addAll(List ids, List embeddings, List props = new HashMap<>(); - props.put(METADATA_TEXT_SEGMENT, text); - + Map metadata = prefillMetadata(); + if (segment != null) { + props.put(METADATA_TEXT_SEGMENT, segment.text()); + if (!segment.metadata().toMap().isEmpty()) { + for (String property : metadataKeys) { + if (segment.metadata().containsKey(property)) { + metadata.put(property, segment.metadata().get(property)); + } + } + } else { + props.put(METADATA, metadata); + } + props.put(METADATA, metadata); + } else { + props.put(METADATA_TEXT_SEGMENT, ""); + props.put(METADATA, metadata); + } + props.put("indexFilterable", true); + props.put("indexSearchable", true); return WeaviateObject .builder() .className(objectClass) .id(id) - .vector(embedding.vectorAsList().toArray(new Float[0])) + .vector(embedding.vectorAsList().toArray(ArrayUtils.EMPTY_FLOAT_OBJECT_ARRAY)) .properties(props) .build(); } + private Map prefillMetadata() { + Map metadata = new HashMap<>(metadataKeys.size()); + for (String property : metadataKeys) { + metadata.put(property, NULL_VALUE); + } + return metadata; + } + private static EmbeddingMatch toEmbeddingMatch(Map item) { Map additional = (Map) item.get(ADDITIONALS); + final Metadata metadata = new Metadata(); + if (item.get(METADATA) != null && item.get(METADATA) instanceof Map) { + Map resultingMetadata = (Map) item.get(METADATA); + for (Map.Entry entry : resultingMetadata.entrySet()) { + if (entry.getValue() != null && !NULL_VALUE.equals(entry.getValue())) { + metadata.add(entry.getKey(), entry.getValue()); + } + } + } String text = (String) item.get(METADATA_TEXT_SEGMENT); return new EmbeddingMatch<>( @@ -236,7 +317,7 @@ private static EmbeddingMatch toEmbeddingMatch(Map item) Embedding.from( ((List) additional.get("vector")).stream().map(Double::floatValue).collect(toList()) ), - isNullOrBlank(text) ? null : TextSegment.from(text) + isNullOrBlank(text) ? 
null : TextSegment.from(text, metadata) ); } } diff --git a/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/CloudWeaviateEmbeddingStoreIT.java b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/CloudWeaviateEmbeddingStoreIT.java index 383a33fd0a..ba1350c677 100644 --- a/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/CloudWeaviateEmbeddingStoreIT.java +++ b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/CloudWeaviateEmbeddingStoreIT.java @@ -4,19 +4,52 @@ import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.EmbeddingStoreWithoutMetadataIT; +import dev.langchain4j.store.embedding.EmbeddingStoreIT; +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateAuthClient; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.v1.auth.exception.AuthException; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import java.util.Arrays; + import static dev.langchain4j.internal.Utils.randomUUID; @EnabledIfEnvironmentVariable(named = "WEAVIATE_API_KEY", matches = ".+") -class CloudWeaviateEmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT { +class CloudWeaviateEmbeddingStoreIT extends EmbeddingStoreIT { + + String objectClass = "Test" + randomUUID().replace("-", ""); EmbeddingStore embeddingStore = WeaviateEmbeddingStore.builder() .apiKey(System.getenv("WEAVIATE_API_KEY")) .scheme("https") .host(System.getenv("WEAVIATE_HOST")) - .objectClass("Test" + randomUUID().replace("-", "")) + .objectClass(objectClass) + .metadataKeys(Arrays.asList( + "string_empty", + "string_space", + "string_abc", + "integer_min", + "integer_minus_1", + "integer_0", + "integer_1", + "integer_max", + "long_min", + "long_minus_1", + "long_0", + "long_1", + "long_max", + "float_min", + "float_minus_1", + "float_0", + "float_1", + "float_123", + "float_max", + "double_minus_1", + "double_0", + "double_1", + "double_123" + )) .build(); EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @@ -31,8 +64,20 @@ protected EmbeddingModel embeddingModel() { return embeddingModel; } + @Override + protected void clearStore() { + try { + WeaviateClient client = WeaviateAuthClient.apiKey(new Config("https", System.getenv("WEAVIATE_HOST")), System.getenv("WEAVIATE_API_KEY")); + client.batch().objectsBatchDeleter() + .withClassName(objectClass) + .run(); + } catch (AuthException ex) { + throw new RuntimeException(ex); + } + } + @Override protected void ensureStoreIsEmpty() { - // TODO fix + } } \ No newline at end of file diff --git a/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/LocalGRPCWeaviateEmbeddingStoreIT.java b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/LocalGRPCWeaviateEmbeddingStoreIT.java new file mode 100644 index 0000000000..f8feba9891 --- /dev/null +++ b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/LocalGRPCWeaviateEmbeddingStoreIT.java @@ -0,0 +1,73 @@ +package dev.langchain4j.store.embedding.weaviate; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import 
dev.langchain4j.store.embedding.EmbeddingStoreIT; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.weaviate.WeaviateContainer; + +import java.util.Arrays; + +import static dev.langchain4j.internal.Utils.randomUUID; + +@Testcontainers +class LocalGRPCWeaviateEmbeddingStoreIT extends EmbeddingStoreIT { + + @Container + static WeaviateContainer weaviate = new WeaviateContainer("semitechnologies/weaviate:latest") + .withEnv("QUERY_DEFAULTS_LIMIT", "25") + .withEnv("DEFAULT_VECTORIZER_MODULE", "none") + .withEnv("CLUSTER_HOSTNAME", "node1"); + + private final EmbeddingStore embeddingStore = WeaviateEmbeddingStore.builder() + .scheme("http") + .host(weaviate.getHost()) + .port(weaviate.getMappedPort(8080)) + .useGrpcForInserts(true) + .grpcPort(weaviate.getMappedPort(50051)) + .objectClass("Test" + randomUUID().replace("-", "")) + .metadataKeys(Arrays.asList("string_empty", + "string_space", + "string_abc", + "integer_min", + "integer_minus_1", + "integer_0", + "integer_1", + "integer_max", + "long_min", + "long_minus_1", + "long_0", + "long_1", + "long_max", + "float_min", + "float_minus_1", + "float_0", + "float_1", + "float_123", + "float_max", + "double_minus_1", + "double_0", + "double_1", + "double_123")) + .build(); + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } + + @Override + protected void ensureStoreIsEmpty() { + // TODO fix + } +} diff --git a/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/LocalWeaviateEmbeddingStoreIT.java b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/LocalWeaviateEmbeddingStoreIT.java index 4dc645b2eb..efb12c118b 100644 --- a/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/LocalWeaviateEmbeddingStoreIT.java +++ b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/LocalWeaviateEmbeddingStoreIT.java @@ -4,15 +4,17 @@ import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.EmbeddingStoreWithoutMetadataIT; +import dev.langchain4j.store.embedding.EmbeddingStoreIT; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.weaviate.WeaviateContainer; +import java.util.Arrays; + import static dev.langchain4j.internal.Utils.randomUUID; @Testcontainers -class LocalWeaviateEmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT { +class LocalWeaviateEmbeddingStoreIT extends EmbeddingStoreIT { @Container static WeaviateContainer weaviate = new WeaviateContainer("semitechnologies/weaviate:latest") @@ -20,11 +22,35 @@ class LocalWeaviateEmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT { .withEnv("DEFAULT_VECTORIZER_MODULE", "none") .withEnv("CLUSTER_HOSTNAME", "node1"); - EmbeddingStore embeddingStore = WeaviateEmbeddingStore.builder() + private final EmbeddingStore embeddingStore = WeaviateEmbeddingStore.builder() .scheme("http") .host(weaviate.getHost()) .port(weaviate.getFirstMappedPort()) .objectClass("Test" + randomUUID().replace("-", "")) + .metadataKeys(Arrays.asList( + "string_empty", + "string_space", + 
"string_abc", + "integer_min", + "integer_minus_1", + "integer_0", + "integer_1", + "integer_max", + "long_min", + "long_minus_1", + "long_0", + "long_1", + "long_max", + "float_min", + "float_minus_1", + "float_0", + "float_1", + "float_123", + "float_max", + "double_minus_1", + "double_0", + "double_1", + "double_123")) .build(); EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @@ -43,4 +69,4 @@ protected EmbeddingModel embeddingModel() { protected void ensureStoreIsEmpty() { // TODO fix } -} \ No newline at end of file +} diff --git a/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/LocalWeaviateEmbeddingStoreRemovalIT.java b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/LocalWeaviateEmbeddingStoreRemovalIT.java new file mode 100644 index 0000000000..fe3a4c4a54 --- /dev/null +++ b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/LocalWeaviateEmbeddingStoreRemovalIT.java @@ -0,0 +1,61 @@ +package dev.langchain4j.store.embedding.weaviate; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.weaviate.WeaviateContainer; + +import static dev.langchain4j.internal.Utils.randomUUID; +import static java.util.Collections.singletonList; + +@Testcontainers +class LocalWeaviateEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemovalIT { + + @Container + static WeaviateContainer weaviate = new WeaviateContainer("semitechnologies/weaviate:latest") + .withEnv("QUERY_DEFAULTS_LIMIT", "25") + .withEnv("DEFAULT_VECTORIZER_MODULE", "none") + .withEnv("CLUSTER_HOSTNAME", "node1"); + + EmbeddingStore embeddingStore = WeaviateEmbeddingStore.builder() + .scheme("http") + .host(weaviate.getHost()) + .port(weaviate.getFirstMappedPort()) + .objectClass("Test" + randomUUID().replace("-", "")) + .consistencyLevel("ALL") + .metadataKeys(singletonList("id")) + .build(); + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } + + @Test + @Disabled("should be enabled once implemented") + void should_remove_all_by_filter() { + } + + @Test + @Disabled("should be enabled once implemented") + void should_fail_to_remove_all_by_filter_null() { + } + + @Test + @Disabled("should be enabled once implemented") + void should_remove_all() { + } +} diff --git a/langchain4j-workers-ai/pom.xml b/langchain4j-workers-ai/pom.xml new file mode 100644 index 0000000000..118f89dda8 --- /dev/null +++ b/langchain4j-workers-ai/pom.xml @@ -0,0 +1,75 @@ + + + 4.0.0 + + + dev.langchain4j + langchain4j-parent + 0.32.0-SNAPSHOT + ../langchain4j-parent/pom.xml + + + langchain4j-workers-ai + jar + + LangChain4j :: Integration :: CloudFlare Workers AI + + + + + dev.langchain4j + langchain4j-core + + + + com.squareup.retrofit2 + retrofit + + + + com.squareup.retrofit2 + converter-jackson + + + + com.squareup.okhttp3 + okhttp + + + + org.projectlombok + lombok + provided + + + + 
org.junit.jupiter + junit-jupiter-engine + test + + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.assertj + assertj-core + test + + + + + + Apache-2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + A business-friendly OSS license + + + + \ No newline at end of file diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiChatModel.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiChatModel.java new file mode 100644 index 0000000000..dbd5068ccb --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiChatModel.java @@ -0,0 +1,206 @@ +package dev.langchain4j.model.workersai; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.workersai.client.AbstractWorkersAIModel; +import dev.langchain4j.model.workersai.client.WorkersAiChatCompletionRequest; +import dev.langchain4j.model.workersai.spi.WorkersAiChatModelBuilderFactory; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static dev.langchain4j.spi.ServiceHelper.loadFactories; + +/** + * WorkerAI Chat model. + * ... + */ +@Slf4j +public class WorkersAiChatModel extends AbstractWorkersAIModel implements ChatLanguageModel { + + /** + * Constructor with Builder. + * + * @param builder + * builder. + */ + public WorkersAiChatModel(Builder builder) { + this(builder.accountId, builder.modelName, builder.apiToken); + } + + /** + * Constructor with Builder. + * + * @param accountId + * account identifier + * @param modelName + * model name + * @param apiToken + * api token + */ + public WorkersAiChatModel(String accountId, String modelName, String apiToken) { + super(accountId, modelName, apiToken); + } + + /** + * Builder access. + * + * @return + * builder instance + */ + public static Builder builder() { + for (WorkersAiChatModelBuilderFactory factory : loadFactories(WorkersAiChatModelBuilderFactory.class)) { + return factory.get(); + } + return new Builder(); + } + + /** + * Internal Builder. + */ + public static class Builder { + + /** + * Account identifier, provided by the WorkerAI platform. + */ + public String accountId; + /** + * ModelName, preferred as enum for extensibility. + */ + public String apiToken; + /** + * ModelName, preferred as enum for extensibility. + */ + public String modelName; + + /** + * Simple constructor. + */ + public Builder() { + } + + /** + * Simple constructor. + * + * @param accountId + * account identifier. + * @return + * self reference + */ + public Builder accountId(String accountId) { + this.accountId = accountId; + return this; + } + + /** + * Sets the apiToken for the Worker AI model builder. + * + * @param apiToken The apiToken to set. + * @return The current instance of {@link Builder}. + */ + public Builder apiToken(String apiToken) { + this.apiToken = apiToken; + return this; + } + + /** + * Sets the model name for the Worker AI model builder. + * + * @param modelName The name of the model to set. + * @return The current instance of {@link Builder}. 
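+     * <p>
+     * A usage sketch for the whole builder (the account id and API token below are placeholder
+     * environment variables; the model name comes from {@link WorkersAiChatModelName}, whose
+     * {@code toString()} returns the Cloudflare model identifier):
+     * <pre>{@code
+     * WorkersAiChatModel model = WorkersAiChatModel.builder()
+     *         .accountId(System.getenv("WORKERS_AI_ACCOUNT_ID"))
+     *         .apiToken(System.getenv("WORKERS_AI_API_TOKEN"))
+     *         .modelName(WorkersAiChatModelName.LLAMA2_7B_FULL.toString())
+     *         .build();
+     * String answer = model.generate("What is the capital of France?");
+     * }</pre>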
+ */ + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + /** + * Builds a new instance of Worker AI Chat Model. + * + * @return A new instance of {@link WorkersAiChatModel}. + */ + public WorkersAiChatModel build() { + return new WorkersAiChatModel(this); + } + } + + /** {@inheritDoc} */ + @Override + public String generate(String userMessage) { + return generate(new WorkersAiChatCompletionRequest(WorkersAiChatCompletionRequest.MessageRole.user, userMessage)); + } + + /** {@inheritDoc} */ + @Override + public Response generate(@NonNull ChatMessage... messages) { + return generate(Arrays.asList(messages)); + } + + /** {@inheritDoc} */ + @Override + public Response generate(List messages) { + WorkersAiChatCompletionRequest req = new WorkersAiChatCompletionRequest(); + req.setMessages(messages.stream() + .map(this::toMessage) + .collect(Collectors.toList())); + return new Response<>(new AiMessage(generate(req)),null, FinishReason.STOP); + } + + /** {@inheritDoc} */ + @Override + public Response generate(List messages, List toolSpecifications) { + throw new UnsupportedOperationException("Tools are currently not supported for WorkerAI models"); + } + + /** {@inheritDoc} */ + @Override + public Response generate(List messages, ToolSpecification toolSpecification) { + throw new UnsupportedOperationException("Tools are currently not supported for WorkerAI models"); + } + + /** + * Mapping ChatMessage to ChatTextGenerationRequest.Message + * + * @param message + * inbound message + * @return + * message for request + */ + private WorkersAiChatCompletionRequest.Message toMessage(ChatMessage message) { + return new WorkersAiChatCompletionRequest.Message( + WorkersAiChatCompletionRequest.MessageRole.valueOf(message.type().name().toLowerCase()), + message.text()); + } + + /** + * Invoke endpoint and process error. + * + * @param req + * request + * @return + * text generated by the model + */ + private String generate(WorkersAiChatCompletionRequest req) { + try { + retrofit2.Response retrofitResponse = workerAiClient + .generateChat(req, accountId, modelName) + .execute(); + processErrors(retrofitResponse.body(), retrofitResponse.errorBody()); + if (retrofitResponse.body() == null) { + throw new IllegalStateException("Response is empty"); + } + return retrofitResponse.body().getResult().getResponse(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + +} \ No newline at end of file diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiChatModelName.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiChatModelName.java new file mode 100644 index 0000000000..12521de2fe --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiChatModelName.java @@ -0,0 +1,96 @@ +package dev.langchain4j.model.workersai; + +/** + * Enum for Workers AI Chat Model Name. + */ +public enum WorkersAiChatModelName { + + // --------------------------------------------------------------------- + // Text Generation + // https://developers.cloudflare.com/workers-ai/models/text-generation/ + // --------------------------------------------------------------------- + + /** Full precision (fp16) generative text model with 7 billion parameters from Met. */ + LLAMA2_7B_FULL("@cf/meta/llama-2-7b-chat-fp16"), + /** Quantized (int8) generative text model with 7 billion parameters from Meta. 
*/ + LLAMA2_7B_QUANTIZED("@cf/meta/llama-2-7b-chat-int8"), + /** Instruct fine-tuned version of the Mistral-7b generative text model with 7 billion parameters. */ + CODELLAMA_7B_AWQ("@hf/thebloke/codellama-7b-instruct-awq"), + /** Deepseek Coder is composed of a series of code language models, each trained from scratch on 2T tokens, with a composition of 87% code and 13% natural language in both English and Chinese.. */ + DEEPSEEK_CODER_6_7_BASE("@hf/thebloke/deepseek-coder-6.7b-base-awq"), + /** Deepseek Coder is composed of a series of code language models, each trained from scratch on 2T tokens, with a composition of 87% code and 13% natural language in both English and Chinese.. */ + DEEPSEEK_CODER_MATH_7B_AWQ(" @hf/thebloke/deepseek-math-7b-awq"), + /** DeepSeekMath is initialized with DeepSeek-Coder-v1.5 7B and continues pre-training on math-related tokens sourced from Common Crawl, together with natural language and code data for 500B tokens. */ + DEEPSEEK_CODER_MATH_7B_INSTRUCT("@hf/thebloke/deepseek-math-7b-instruct"), + /** DeepSeekMath-Instruct 7B is a mathematically instructed tuning model derived from DeepSeekMath-Base 7B. DeepSeekMath is initialized with DeepSeek-Coder-v1.5 7B and continues pre-training on math-related tokens sourced from Common Crawl, together with natural language and code data for 500B tokens.. */ + MISTRAL_7B_INSTRUCT("@cf/mistral/mistral-7b-instruct-v0.1"), + /** DiscoLM German 7b is a Mistral-based large language model with a focus on German-language applications. AWQ is an efficient, accurate and blazing-fast low-bit weight quantization method, currently supporting 4-bit quantization. */ + DISCOLM_GERMAN_7B_V1_AWQ("@cf/thebloke/discolm-german-7b-v1-awq"), + /** Falcon-7B-Instruct is a 7B parameters causal decoder-only model built by TII based on Falcon-7B and finetuned on a mixture of chat/instruct datasets. */ + FALCOM_7B_INSTRUCT("@cf/tiiuae/falcon-7b-instruct"), + /** This is a Gemma-2B base model that Cloudflare dedicates for inference with LoRA adapters. Gemma is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. */ + GEMMA_2B_IT_LORA("@cf/google/gemma-2b-it-lora"), + /** Gemma is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. */ + GEMMA_7B_IT("@hf/google/gemma-7b-it"), + /** This is a Gemma-7B base model that Cloudflare dedicates for inference with LoRA adapters. Gemma is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. */ + GEMMA_2B_IT_LORA_DUPLICATE("@cf/google/gemma-2b-it-lora"), + /** Hermes 2 Pro on Mistral 7B is the new flagship 7B Hermes! Hermes 2 Pro is an upgraded, retrained version of Nous Hermes 2, consisting of an updated and cleaned version of the OpenHermes 2.5 Dataset, as well as a newly introduced Function Calling and JSON Mode dataset developed in-house. */ + HERMES_2_PRO_MISTRAL_7B("@hf/nousresearch/hermes-2-pro-mistral-7b"), + /** Llama 2 13B Chat AWQ is an efficient, accurate and blazing-fast low-bit weight quantized Llama 2 variant. */ + LLAMA_2_13B_CHAT_AWQ("@hf/thebloke/llama-2-13b-chat-awq"), + /** This is a Llama2 base model that Cloudflare dedicated for inference with LoRA adapters. 
Llama 2 is a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters. This is the repository for the 7B fine-tuned model, optimized for dialogue use cases and converted for the Hugging Face Transformers format. */ + LLAMA_2_7B_CHAT_HF_LORA("@cf/meta-llama/llama-2-7b-chat-hf-lora"), + /** Generation over generation, Meta Llama 3 demonstrates state-of-the-art performance on a wide range of industry benchmarks and offers new capabilities, including improved reasoning. */ + LLAMA_3_8B_INSTRUCT("@cf/meta/llama-3-8b-instruct"), + /** Quantized (int4) generative text model with 8 billion parameters from Meta. */ + LLAMA_2_13B_CHAT_AWQ_DUPLICATE("@hf/thebloke/llama-2-13b-chat-awq"), + /** Llama Guard is a model for classifying the safety of LLM prompts and responses, using a taxonomy of safety risks. */ + LLAMAGUARD_7B_AWQ("@hf/thebloke/llamaguard-7b-awq"), + /** Quantized (int4) generative text model with 8 billion parameters from Meta. */ + META_LLAMA_3_8B_INSTRUCT("@hf/meta-llama/meta-llama-3-8b-instruct"), + /** Mistral 7B Instruct v0.1 AWQ is an efficient, accurate and blazing-fast low-bit weight quantized Mistral variant. */ + MISTRAL_7B_INSTRUCT_V0_1_AWQ("@hf/thebloke/mistral-7b-instruct-v0.1-awq"), + /** The Mistral-7B-Instruct-v0.2 Large Language Model (LLM) is an instruct fine-tuned version of the Mistral-7B-v0.2. Mistral-7B-v0.2 has the following changes compared to Mistral-7B-v0.1: 32k context window (vs 8k context in v0.1), rope-theta = 1e6, and no Sliding-Window Attention. */ + MISTRAL_7B_INSTRUCT_V0_2("@hf/mistral/mistral-7b-instruct-v0.2"), + /** The Mistral-7B-Instruct-v0.2 Large Language Model (LLM) is an instruct fine-tuned version of the Mistral-7B-v0.2. */ + MISTRAL_7B_INSTRUCT_V0_2_LORA("@cf/mistral/mistral-7b-instruct-v0.2-lora"), + /** This model is a fine-tuned 7B parameter LLM on the Intel Gaudi 2 processor from the mistralai/Mistral-7B-v0.1 on the open source dataset Open-Orca/SlimOrca. */ + NEURAL_CHAT_7B_V3_1_AWQ("@hf/thebloke/neural-chat-7b-v3-1-awq"), + /** OpenChat is an innovative library of open-source language models, fine-tuned with C-RLFT - a strategy inspired by offline reinforcement learning. */ + OPENCHAT_3_5_0106("@cf/openchat/openchat-3.5-0106"), + /** OpenHermes 2.5 Mistral 7B is a state of the art Mistral Fine-tune, a continuation of OpenHermes 2 model, which trained on additional code datasets. */ + OPENHERMES_2_5_MISTRAL_7B_AWQ("@hf/thebloke/openhermes-2.5-mistral-7b-awq"), + /** Phi-2 is a Transformer-based model with a next-word prediction objective, trained on 1.4T tokens from multiple passes on a mixture of Synthetic and Web datasets for NLP and coding. */ + PHI_2("@cf/microsoft/phi-2"), + /** Qwen1.5 is the improved version of Qwen, the large language model series developed by Alibaba Cloud. */ + QWEN1_5_0_5B_CHAT("@cf/qwen/qwen1.5-0.5b-chat"), + /** Qwen1.5 is the improved version of Qwen, the large language model series developed by Alibaba Cloud. */ + QWEN1_5_1_8B_CHAT("@cf/qwen/qwen1.5-1.8b-chat"), + /** Qwen1.5 is the improved version of Qwen, the large language model series developed by Alibaba Cloud. AWQ is an efficient, accurate and blazing-fast low-bit weight quantization method, currently supporting 4-bit quantization. */ + QWEN1_5_14B_CHAT_AWQ("@cf/qwen/qwen1.5-14b-chat-awq"), + /** Qwen1.5 is the improved version of Qwen, the large language model series developed by Alibaba Cloud. 
AWQ is an efficient, accurate and blazing-fast low-bit weight quantization method, currently supporting 4-bit quantization. */ + QWEN1_5_7B_CHAT_AWQ("@cf/qwen/qwen1.5-7b-chat-awq"), + /** This model is intended to be used by non-technical users to understand data inside their SQL databases. */ + SQLCODER_7B_2("@cf/defog/sqlcoder-7b-2"), + /** We introduce Starling-LM-7B-beta, an open large language model (LLM) trained by Reinforcement Learning from AI Feedback (RLAIF). Starling-LM-7B-beta is trained from Openchat-3.5-0106 with our new reward model Nexusflow/Starling-RM-34B and policy optimization method Fine-Tuning Language Models from Human Preferences (PPO). */ + STARLING_LM_7B_BETA("@hf/nexusflow/starling-lm-7b-beta"), + /** The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. This is the chat model finetuned on top of TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T. */ + TINYLLAMA_1_1B_CHAT_V1_0("@cf/tinyllama/tinyllama-1.1b-chat-v1.0"), + /** Cybertron 7B v2 is a 7B MistralAI based model, best on it’s series. It was trained with SFT, DPO and UNA (Unified Neural Alignment) on multiple datasets. */ + UNA_CYBERTRON_7B_V2_BF16("@cf/fblgit/una-cybertron-7b-v2-bf16"), + /** Zephyr 7B Beta AWQ is an efficient, accurate and blazing-fast low-bit weight quantized Zephyr model variant. */ + ZEPHYR_7B_BETA_AWQ("@hf/thebloke/zephyr-7b-beta-awq"); + + private final String stringValue; + + WorkersAiChatModelName(String stringValue) { + this.stringValue = stringValue; + } + + @Override + public String toString() { + return stringValue; + } + + +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiEmbeddingModel.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiEmbeddingModel.java new file mode 100644 index 0000000000..e5dc9ea34e --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiEmbeddingModel.java @@ -0,0 +1,248 @@ +package dev.langchain4j.model.workersai; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.workersai.client.AbstractWorkersAIModel; +import dev.langchain4j.model.workersai.client.WorkersAiEmbeddingResponse; +import dev.langchain4j.model.workersai.spi.WorkersAiEmbeddingModelBuilderFactory; +import lombok.extern.slf4j.Slf4j; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import static dev.langchain4j.spi.ServiceHelper.loadFactories; + +/** + * WorkerAI Embedding model. + * ... + */ +@Slf4j +public class WorkersAiEmbeddingModel extends AbstractWorkersAIModel implements EmbeddingModel { + + /** + * Constructor with Builder. + * + * @param builder + * builder. + */ + public WorkersAiEmbeddingModel(Builder builder) { + this(builder.accountId, builder.modelName, builder.apiToken); + } + + /** + * Constructor with Builder. 
+ * + * @param accountId + * account identifier + * @param modelName + * model name + * @param apiToken + * api token + */ + public WorkersAiEmbeddingModel(String accountId, String modelName, String apiToken) { + super(accountId, modelName, apiToken); + } + + /** + * Builder access. + * + * @return + * builder instance + */ + public static Builder builder() { + for (WorkersAiEmbeddingModelBuilderFactory factory : loadFactories(WorkersAiEmbeddingModelBuilderFactory.class)) { + return factory.get(); + } + return new WorkersAiEmbeddingModel.Builder(); + } + + /** + * Internal Builder. + */ + public static class Builder { + + /** + * Account identifier, provided by the WorkerAI platform. + */ + public String accountId; + /** + * ModelName, preferred as enum for extensibility. + */ + public String apiToken; + /** + * ModelName, preferred as enum for extensibility. + */ + public String modelName; + + /** + * Simple constructor. + */ + public Builder() { + } + + /** + * Simple constructor. + * + * @param accountId + * account identifier. + * @return + * self reference + */ + public Builder accountId(String accountId) { + this.accountId = accountId; + return this; + } + + /** + * Sets the apiToken for the Worker AI model builder. + * + * @param apiToken The apiToken to set. + * @return The current instance of {@link WorkersAiChatModel.Builder}. + */ + public Builder apiToken(String apiToken) { + this.apiToken = apiToken; + return this; + } + + /** + * Sets the model name for the Worker AI model builder. + * + * @param modelName The name of the model to set. + * @return The current instance of {@link WorkersAiChatModel.Builder}. + */ + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + /** + * Builds a new instance of Worker AI Chat Model. + * + * @return A new instance of {@link WorkersAiChatModel}. 
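+ * <p>Illustrative usage (account id and token values are placeholders): {@code WorkersAiEmbeddingModel.builder().accountId("<ACCOUNT_ID>").apiToken("<API_TOKEN>").modelName(WorkersAiEmbeddingModelName.BAAI_EMBEDDING_BASE.toString()).build()}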
+ */ + public WorkersAiEmbeddingModel build() { + return new WorkersAiEmbeddingModel(this); + } + } + + /** {@inheritDoc} */ + @Override + public Response<Embedding> embed(String text) { + try { + dev.langchain4j.model.workersai.client.WorkersAiEmbeddingRequest req = new dev.langchain4j.model.workersai.client.WorkersAiEmbeddingRequest(); + req.getText().add(text); + + retrofit2.Response<WorkersAiEmbeddingResponse> retrofitResponse = workerAiClient + .embed(req, accountId, modelName) + .execute(); + + processErrors(retrofitResponse.body(), retrofitResponse.errorBody()); + if (retrofitResponse.body() == null) { + throw new RuntimeException("Unexpected response: " + retrofitResponse); + } + dev.langchain4j.model.workersai.client.WorkersAiEmbeddingResponse.EmbeddingResult res = retrofitResponse.body().getResult(); + // Single Vector expected + if (res.getShape().get(0) != 1) { + throw new RuntimeException("Unexpected shape: " + res.getShape()); + } + List<Float> embeddings = res.getData().get(0); + float[] floatArray = new float[embeddings.size()]; + for (int i = 0; i < embeddings.size(); i++) { + floatArray[i] = embeddings.get(i); // Unboxing Float to float + } + return new Response<>(new Embedding(floatArray), null, FinishReason.STOP); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** {@inheritDoc} */ + @Override + public Response<Embedding> embed(TextSegment textSegment) { + // no metadata in worker ai + return embed(textSegment.text()); + } + + + /** {@inheritDoc} */ + @Override + public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) { + List<Future<List<Embedding>>> futures = new ArrayList<>(); + ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); + try { + final int chunkSize = 100; + for (int i = 0; i < textSegments.size(); i += chunkSize) { + List<TextSegment> chunk = textSegments.subList(i, Math.min(textSegments.size(), i + chunkSize)); + Future<List<Embedding>> future = executor.submit(() -> processChunk(chunk, accountId, modelName)); + futures.add(future); + } + // Wait for all futures to complete and collect results + List<Embedding> result = new ArrayList<>(); + for (Future<List<Embedding>> future : futures) { + result.addAll(future.get()); + } + return new Response<>(result); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } finally { + executor.shutdown(); + try { + if (!executor.awaitTermination(800, TimeUnit.MILLISECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + } + } + } + + /** + * Process chunk of text segments. + * + * @param chunk + * chunk of text segments. + * @param accountIdentifier + * account identifier. + * @param modelName + * model name. + * @return + * list of embeddings. + * @throws IOException + * error occurred during invocation.
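+ * <p>Note: {@code embedAll} (above) splits the input into chunks of 100 segments and submits each chunk to a fixed thread pool sized to the number of available processors; this method performs the actual API call for a single chunk.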
+ */ + private List processChunk(List chunk, String accountIdentifier, String modelName) + throws IOException { + dev.langchain4j.model.workersai.client.WorkersAiEmbeddingRequest req = new dev.langchain4j.model.workersai.client.WorkersAiEmbeddingRequest(); + for (TextSegment textSegment : chunk) { + req.getText().add(textSegment.text()); + } + retrofit2.Response retrofitResponse = workerAiClient + .embed(req, accountIdentifier, modelName) + .execute(); + processErrors(retrofitResponse.body(), retrofitResponse.errorBody()); + if (retrofitResponse.body() == null) { + throw new RuntimeException("Unexpected response: " + retrofitResponse); + } + WorkersAiEmbeddingResponse.EmbeddingResult res = retrofitResponse.body().getResult(); + + List> embeddings = res.getData(); + List embeddingsList = new ArrayList<>(); + for (List embedding : embeddings) { + float[] floatArray = new float[embedding.size()]; + for (int i = 0; i < embedding.size(); i++) { + floatArray[i] = embedding.get(i); // Unboxing Float to float + } + embeddingsList.add(new Embedding(floatArray)); + } + return embeddingsList; + } +} \ No newline at end of file diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiEmbeddingModelName.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiEmbeddingModelName.java new file mode 100644 index 0000000000..691c055281 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiEmbeddingModelName.java @@ -0,0 +1,34 @@ +package dev.langchain4j.model.workersai; + +/** + * Enum for Workers AI Embedding Model Name. + */ +public enum WorkersAiEmbeddingModelName { + + // --------------------------------------------------------------------- + // Text Embeddings + // https://developers.cloudflare.com/workers-ai/models/text-embeddings/ + // --------------------------------------------------------------------- + + /** BAAI general embedding (bge) models transform any given text into a compact vector. */ + BAAI_EMBEDDING_SMALL("@cf/baai/bge-small-en-v1.5"), + + /** BAAI general embedding (bge) models transform any given text into a compact vector. */ + BAAI_EMBEDDING_BASE("@cf/baai/bge-base-en-v1.5"), + + /** BAAI general embedding (bge) models transform any given text into a compact vector. 
*/ + BAAI_EMBEDDING_LARGE("@cf/baai/bge-large-en-v1.5"); + + private final String stringValue; + + WorkersAiEmbeddingModelName(String stringValue) { + this.stringValue = stringValue; + } + + @Override + public String toString() { + return stringValue; + } + + +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiImageModel.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiImageModel.java new file mode 100644 index 0000000000..c6fe2ae58d --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiImageModel.java @@ -0,0 +1,291 @@ +package dev.langchain4j.model.workersai; + +import dev.langchain4j.data.image.Image; +import dev.langchain4j.model.image.ImageModel; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.workersai.client.AbstractWorkersAIModel; +import dev.langchain4j.model.workersai.client.WorkersAiImageGenerationRequest; +import dev.langchain4j.model.workersai.spi.WorkersAiImageModelBuilderFactory; +import okhttp3.ResponseBody; + +import javax.imageio.ImageIO; +import java.awt.image.BufferedImage; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.util.Base64; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static dev.langchain4j.spi.ServiceHelper.loadFactories; + +/** + * WorkerAI Image model. + */ +public class WorkersAiImageModel extends AbstractWorkersAIModel implements ImageModel { + + /** + * The mime type returned by Workers + */ + private static final String MIME_TYPE = "image/png"; + + /** + * Constructor with Builder. + * + * @param builder + * builder. + */ + public WorkersAiImageModel(Builder builder) { + this(builder.accountId, builder.modelName, builder.apiToken); + } + + /** + * Constructor with Builder. + * + * @param accountId + * account identifier + * @param modelName + * model name + * @param apiToken + * api token + */ + public WorkersAiImageModel(String accountId, String modelName, String apiToken) { + super(accountId, modelName, apiToken); + } + + /** + * Builder access. + * + * @return + * builder instance + */ + public static Builder builder() { + for (WorkersAiImageModelBuilderFactory factory : loadFactories(WorkersAiImageModelBuilderFactory.class)) { + return factory.get(); + } + return new WorkersAiImageModel.Builder(); + } + + /** + * Internal Builder. + */ + public static class Builder { + + /** + * Account identifier, provided by the WorkerAI platform. + */ + public String accountId; + /** + * ModelName, preferred as enum for extensibility. + */ + public String apiToken; + /** + * ModelName, preferred as enum for extensibility. + */ + public String modelName; + + /** + * Simple constructor. + */ + public Builder() { + } + + /** + * Simple constructor. + * + * @param accountId + * account identifier. + * @return + * self reference + */ + public Builder accountId(String accountId) { + this.accountId = accountId; + return this; + } + + /** + * Sets the apiToken for the Worker AI model builder. + * + * @param apiToken The apiToken to set. + * @return The current instance of {@link WorkersAiChatModel.Builder}. 
+ */ + public Builder apiToken(String apiToken) { + this.apiToken = apiToken; + return this; + } + + /** + * Sets the model name for the Worker AI model builder. + * + * @param modelName The name of the model to set. + * @return The current instance of {@link WorkersAiChatModel.Builder}. + */ + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + /** + * Builds a new instance of Worker AI Chat Model. + * + * @return A new instance of {@link WorkersAiChatModel}. + */ + public WorkersAiImageModel build() { + return new WorkersAiImageModel(this); + } + } + + /** {@inheritDoc} */ + public Response generate(String prompt) { + ensureNotBlank(prompt, "Prompt"); + return new Response<>(convertAsImage(executeQuery(prompt, null, null)), null, FinishReason.STOP); + } + + /** {@inheritDoc} */ + public Response edit(Image image, String prompt) { + ensureNotBlank(prompt, "Prompt"); + ensureNotNull(image, "Image"); + return new Response<>(convertAsImage(executeQuery(prompt, null, image)), null, FinishReason.STOP); + } + + /** {@inheritDoc} */ + public Response edit(Image image, Image mask, String prompt) { + ensureNotBlank(prompt, "Prompt"); + ensureNotNull(image, "Image"); + ensureNotNull(mask, "Mask"); + return new Response<>(convertAsImage(executeQuery(prompt, mask, image)), null, FinishReason.STOP); + } + + /** + * Generate image and save to file. + * + * @param prompt + * current prompt + * @param destinationFile + * local file + * @return + * response with the destination file + */ + public Response generate(String prompt, String destinationFile) { + ensureNotBlank(prompt, "Prompt"); + ensureNotBlank(destinationFile, "Destination file"); + try { + byte[] image = executeQuery(prompt, null, null); + try (FileOutputStream fileOutputStream = new FileOutputStream(destinationFile)) { + fileOutputStream.write(image); + } + return new Response<>(new File(destinationFile), null, FinishReason.STOP); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** + * Execute query. + * + * @param prompt + * prompt. + * @return + * image. + */ + private byte[] executeQuery(String prompt, Image image, Image mask) { + try { + // Mapping inbound + WorkersAiImageGenerationRequest imgReq = new WorkersAiImageGenerationRequest(); + imgReq.setPrompt(prompt); + if (image != null) { + if (image.url() != null) { + imgReq.setImage(getPixels(image.url().toURL())); + } + } + if (mask != null) { + if (mask.url() != null) { + imgReq.setMask(getPixels(mask.url().toURL())); + } + } + + retrofit2.Response response = workerAiClient + .generateImage(imgReq, accountId, modelName) + .execute(); + + if (response.isSuccessful() && response.body() != null) { + InputStream inputStream = response.body().byteStream(); + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + int nRead; + byte[] data = new byte[1024]; + while ((nRead = inputStream.read(data, 0, data.length)) != -1) { + buffer.write(data, 0, nRead); + } + buffer.flush(); + return buffer.toByteArray(); + } + throw new IllegalStateException("An error occured while generating image."); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Convert an image into a array of number, supposedly the Pixels. 
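+ * <p>Each entry in the returned array is a packed ARGB integer (alpha, red, green, blue), produced row by row with {@code BufferedImage.getRGB(x, y)}.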
+ * @param imageUrl + * current image URL + * @return + * pixels of the image + * @throws Exception + * if the image cannot be read or its pixels cannot be extracted + */ + public int[] getPixels(URL imageUrl) throws Exception { + BufferedImage image = ImageIO.read(imageUrl); + + // Get image dimensions + int width = image.getWidth(); + int height = image.getHeight(); + + // Initialize an array to hold the pixel data + int[] pixelData = new int[width * height]; + + // Extract pixel data + int index = 0; + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + // Get pixel color at (x, y) + int pixel = image.getRGB(x, y); + + // Extract the individual color components + int alpha = (pixel >> 24) & 0xff; + int red = (pixel >> 16) & 0xff; + int green = (pixel >> 8) & 0xff; + int blue = pixel & 0xff; + + // Combine the color components into a single integer + int color = (alpha << 24) | (red << 16) | (green << 8) | blue; + + // Store the color in the array + pixelData[index++] = color; + } + } + return pixelData; + } + + /** + * Convert Workers AI Image Generation output to Langchain4j model. + * + * @param data + * output image + * @return + * output image converted + */ + public Image convertAsImage(byte[] data) { + return Image.builder() + .base64Data(Base64.getEncoder().encodeToString(data)) + .mimeType(MIME_TYPE) + .build(); + } + +} + + diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiImageModelName.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiImageModelName.java new file mode 100644 index 0000000000..55c93fe8c9 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiImageModelName.java @@ -0,0 +1,46 @@ +package dev.langchain4j.model.workersai; + +/** + * Enum for Workers AI Image Model Name. + */ +public enum WorkersAiImageModelName { + + // --------------------------------------------------------------------- + // Text to image + // https://developers.cloudflare.com/workers-ai/models/text-to-image/ + // --------------------------------------------------------------------- + + /** + * Diffusion-based text-to-image generative model by Stability AI. Generates and modifies images based on text prompts. + */ + STABLE_DIFFUSION_XL("@cf/stabilityai/stable-diffusion-xl-base-1.0"), + /** + * Stable Diffusion model that has been fine-tuned to be better at photorealism without sacrificing range. + */ + DREAM_SHAPER_8_LCM("@cf/lykon/dreamshaper-8-lcm"), + /** + * Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images. Img2img generates a new image from an input image with Stable Diffusion. + */ + STABLE_DIFFUSION_V1_5_IMG2IMG("@cf/runwayml/stable-diffusion-v1-5-img2img"), + /** + * Stable Diffusion Inpainting is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input, with the extra capability of inpainting the pictures by using a mask. + */ + STABLE_DIFFUSION_V1_5_IN_PAINTING("@cf/runwayml/stable-diffusion-v1-5-inpainting"), + /** + * SDXL-Lightning is a lightning-fast text-to-image generation model. It can generate high-quality 1024px images in a few steps.
+ */ + STABLE_DIFFUSION_XL_LIGHTNING("@cf/bytedance/stable-diffusion-xl-lightning"); + + private final String stringValue; + + WorkersAiImageModelName(String stringValue) { + this.stringValue = stringValue; + } + + @Override + public String toString() { + return stringValue; + } + + +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiLanguageModel.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiLanguageModel.java new file mode 100644 index 0000000000..4e1938d396 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/WorkersAiLanguageModel.java @@ -0,0 +1,151 @@ +package dev.langchain4j.model.workersai; + +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.language.LanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.workersai.client.AbstractWorkersAIModel; +import dev.langchain4j.model.workersai.client.WorkersAiTextCompletionRequest; +import dev.langchain4j.model.workersai.client.WorkersAiTextCompletionResponse; +import dev.langchain4j.model.workersai.spi.WorkersAiLanguageModelBuilderFactory; +import lombok.extern.slf4j.Slf4j; + +import java.io.IOException; + +import static dev.langchain4j.spi.ServiceHelper.loadFactories; + +/** + * WorkerAI Language model. + * ... + */ +@Slf4j +public class WorkersAiLanguageModel extends AbstractWorkersAIModel implements LanguageModel { + + /** + * Constructor with Builder. + * + * @param builder + * builder. + */ + public WorkersAiLanguageModel(Builder builder) { + this(builder.accountId, builder.modelName, builder.apiToken); + } + + /** + * Constructor with Builder. + * + * @param accountId + * account identifier + * @param modelName + * model name + * @param apiToken + * api token + */ + public WorkersAiLanguageModel(String accountId, String modelName, String apiToken) { + super(accountId, modelName, apiToken); + } + + /** + * Builder access. + * + * @return + * builder instance + */ + public static WorkersAiLanguageModel.Builder builder() { + for (WorkersAiLanguageModelBuilderFactory factory : loadFactories(WorkersAiLanguageModelBuilderFactory.class)) { + return factory.get(); + } + return new Builder(); + } + + /** + * Internal Builder. + */ + public static class Builder { + + /** + * Account identifier, provided by the WorkerAI platform. + */ + public String accountId; + /** + * ModelName, preferred as enum for extensibility. + */ + public String apiToken; + /** + * ModelName, preferred as enum for extensibility. + */ + public String modelName; + + /** + * Simple constructor. + */ + public Builder() { + } + + /** + * Simple constructor. + * + * @param accountId + * account identifier. + * @return + * self reference + */ + public Builder accountId(String accountId) { + this.accountId = accountId; + return this; + } + + /** + * Sets the apiToken for the Worker AI model builder. + * + * @param apiToken The apiToken to set. + * @return The current instance of {@link WorkersAiChatModel.Builder}. + */ + public Builder apiToken(String apiToken) { + this.apiToken = apiToken; + return this; + } + + /** + * Sets the model name for the Worker AI model builder. + * + * @param modelName The name of the model to set. + * @return The current instance of {@link WorkersAiChatModel.Builder}. + */ + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + /** + * Builds a new instance of Worker AI Chat Model. 
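+ * <p>Illustrative usage (account id and token values are placeholders): {@code WorkersAiLanguageModel.builder().accountId("<ACCOUNT_ID>").apiToken("<API_TOKEN>").modelName("@cf/meta/llama-3-8b-instruct").build()}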
+ * + * @return A new instance of {@link WorkersAiLanguageModel}. + */ + public WorkersAiLanguageModel build() { + return new WorkersAiLanguageModel(this); + } + } + + /** {@inheritDoc} */ + @Override + public Response<String> generate(String prompt) { + try { + retrofit2.Response<WorkersAiTextCompletionResponse> retrofitResponse = workerAiClient + .generateText(new WorkersAiTextCompletionRequest(prompt), accountId, modelName) + .execute(); + processErrors(retrofitResponse.body(), retrofitResponse.errorBody()); + if (retrofitResponse.body() == null) { + throw new RuntimeException("Empty response"); + } + return new Response<>(retrofitResponse.body().getResult().getResponse()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** {@inheritDoc} */ + @Override + public Response<String> generate(Prompt prompt) { + return generate(prompt.text()); + } +} \ No newline at end of file diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/AbstractWorkersAIModel.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/AbstractWorkersAIModel.java new file mode 100644 index 0000000000..54a847b349 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/AbstractWorkersAIModel.java @@ -0,0 +1,83 @@ +package dev.langchain4j.model.workersai.client; + +import lombok.extern.slf4j.Slf4j; +import okhttp3.ResponseBody; + +import java.io.IOException; + +/** + * Abstract class for WorkerAI models as they are all initialized the same way. + * ... + */ +@Slf4j +public abstract class AbstractWorkersAIModel { + + /** + * Account identifier, provided by the WorkerAI platform. + */ + protected String accountId; + + /** + * ModelName, preferred as enum for extensibility. + */ + protected String modelName; + + /** + * Retrofit client for the WorkerAI API. + */ + protected WorkersAiApi workerAiClient; + + /** + * Constructor with all parameters. + * + * @param accountId + * account identifier. + * @param modelName + * model name. + * @param apiToken + * API token. + */ + public AbstractWorkersAIModel(String accountId, String modelName, String apiToken) { + if (accountId == null || accountId.isEmpty()) { + throw new IllegalArgumentException("Account identifier should not be null or empty"); + } + this.accountId = accountId; + if (modelName == null || modelName.isEmpty()) { + throw new IllegalArgumentException("Model name should not be null or empty"); + } + this.modelName = modelName; + if (apiToken == null || apiToken.isEmpty()) { + throw new IllegalArgumentException("Token should not be null or empty"); + } + this.workerAiClient = WorkersAiClient.createService(apiToken); + } + + /** + * Process errors from the API.
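+ * <p>When the call is not successful, the error messages returned by the API are concatenated, logged, and rethrown as a {@link RuntimeException}.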
+ * @param res + * response + * @param errors + * errors body from retrofit + * @throws IOException + * error occurred during invocation + */ + protected void processErrors(ApiResponse<?> res, ResponseBody errors) + throws IOException { + if (res == null || !res.isSuccess()) { + StringBuilder errorMessage = new StringBuilder("Failed to generate chat message:"); + if (res == null) { + errorMessage.append(errors.string()); + } else if (res.getErrors() != null) { + errorMessage.append(res.getErrors().stream() + .map(ApiResponse.Error::getMessage) + .reduce((a, b) -> a + "\n" + b) + .orElse("")); + } + log.error(errorMessage.toString()); + throw new RuntimeException(errorMessage.toString()); + } + } + + + +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/ApiResponse.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/ApiResponse.java new file mode 100644 index 0000000000..c2b46eb71b --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/ApiResponse.java @@ -0,0 +1,60 @@ +package dev.langchain4j.model.workersai.client; + +import lombok.Data; + +import java.util.List; + +/** + * Multiple models leverage the same output format, so we can use this class to parse the response. + * + * @param <T> + * Type of the result. + */ +@Data +public class ApiResponse<T> { + + /** + * Result of the API call. + */ + private T result; + + /** + * Success of the API call. + */ + private boolean success; + + /** + * Errors of the API call. + */ + private List<Error> errors; + + /** + * Messages of the API call. + */ + private List<String> messages; + + /** + * Default constructor. + */ + public ApiResponse() {} + + /** + * Error class. + */ + @Data + public static class Error { + /** + * Message of the error. + */ + private String message; + /** + * Code of the error. + */ + private int code; + /** + * Default constructor. + */ + public Error() {} + } + +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiApi.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiApi.java new file mode 100644 index 0000000000..de6dc6c06e --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiApi.java @@ -0,0 +1,82 @@ +package dev.langchain4j.model.workersai.client; + +import okhttp3.ResponseBody; +import retrofit2.Call; +import retrofit2.http.Body; +import retrofit2.http.POST; +import retrofit2.http.Path; + +/** + * Public interface to interact with the WorkerAI API. + */ +public interface WorkersAiApi { + + /** + * Generate chat. + * + * @param apiRequest + * request. + * @param accountIdentifier + * account identifier. + * @param modelId + * model id. + * @return + * response. + */ + @POST("client/v4/accounts/{accountIdentifier}/ai/run/{modelName}") + Call<WorkersAiChatCompletionResponse> generateChat(@Body WorkersAiChatCompletionRequest apiRequest, + @Path("accountIdentifier") String accountIdentifier, + @Path(value = "modelName", encoded = true) String modelId); + + /** + * Generate text. + * + * @param apiRequest + * request. + * @param accountIdentifier + * account identifier. + * @param modelName + * model name. + * @return + * response. + */ + @POST("client/v4/accounts/{accountIdentifier}/ai/run/{modelName}") + Call<WorkersAiTextCompletionResponse> generateText(@Body WorkersAiTextCompletionRequest apiRequest, + @Path("accountIdentifier") String accountIdentifier, + @Path(value = "modelName", encoded = true) String modelName); + + /** + * Generate image.
+ * + * @param apiRequest + * request. + * @param accountIdentifier + * account identifier. + * @param modelName + * model name. + * @return + * response. + */ + @POST("client/v4/accounts/{accountIdentifier}/ai/run/{modelName}") + Call generateImage(@Body WorkersAiImageGenerationRequest apiRequest, + @Path("accountIdentifier") String accountIdentifier, + @Path(value = "modelName", encoded = true) String modelName); + + /** + * Generate embeddings. + * + * @param apiRequest + * request. + * @param accountIdentifier + * account identifier. + * @param modelName + * model name. + * @return + * response. + */ + @POST("client/v4/accounts/{accountIdentifier}/ai/run/{modelName}") + Call embed(@Body WorkersAiEmbeddingRequest apiRequest, + @Path("accountIdentifier") String accountIdentifier, + @Path(value = "modelName", encoded = true) String modelName); + +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiChatCompletionRequest.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiChatCompletionRequest.java new file mode 100644 index 0000000000..32148cc1e3 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiChatCompletionRequest.java @@ -0,0 +1,82 @@ +package dev.langchain4j.model.workersai.client; + +import lombok.AllArgsConstructor; +import lombok.Data; + +import java.util.ArrayList; +import java.util.List; + +/** + * Represents a request for AI chat completion. + * Contains a list of messages that form part of the chat conversation. + */ +@Data +public class WorkersAiChatCompletionRequest { + + private List messages; + + /** + * Represents a message in the AI chat. + * Each message has a role and content. + */ + @Data @AllArgsConstructor + public static class Message { + private MessageRole role; + private String content; + /** + * Default constructor. + */ + @SuppressWarnings("unused") + public Message() {} + } + + /** + * Defines the roles a message can have in the chat conversation. + */ + @SuppressWarnings("unused") + public enum MessageRole { + /** + * Directive for the prompt + */ + system, + /** + * The message is from the AI. + */ + ai, + /** + * The message is from the user. + */ + user + } + + /** + * Constructs an empty WorkerAiChatCompletionRequest with an empty list of messages. + */ + public WorkersAiChatCompletionRequest() { + this.messages = new ArrayList<>(); + } + + /** + * Constructs a WorkerAiChatCompletionRequest with an initial message. + * + * @param role The role of the initial message. + * @param content The content of the initial message. + */ + public WorkersAiChatCompletionRequest(MessageRole role, String content) { + this(); + addMessage(role, content); + } + + /** + * Adds a new message to the chat completion request. + * + * @param role The role of the message. + * @param content The content of the message. 
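+ * <p>Illustrative usage: {@code WorkersAiChatCompletionRequest req = new WorkersAiChatCompletionRequest(MessageRole.system, "You are a helpful assistant"); req.addMessage(MessageRole.user, "Hello");}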
+ */ + public void addMessage(MessageRole role, String content) { + Message message = new Message(role, content); + this.messages.add(message); + } + +} + diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiChatCompletionResponse.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiChatCompletionResponse.java new file mode 100644 index 0000000000..9622e7bfdb --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiChatCompletionResponse.java @@ -0,0 +1,12 @@ +package dev.langchain4j.model.workersai.client; + +/** + * Wrapper for the chat completion response. + */ +public class WorkersAiChatCompletionResponse extends WorkersAiTextCompletionResponse { + + /** + * Default constructor. + */ + public WorkersAiChatCompletionResponse() {} +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiClient.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiClient.java new file mode 100644 index 0000000000..a420f5041b --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiClient.java @@ -0,0 +1,86 @@ +package dev.langchain4j.model.workersai.client; + +import okhttp3.Interceptor; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import org.jetbrains.annotations.NotNull; +import retrofit2.Retrofit; +import retrofit2.converter.jackson.JacksonConverterFactory; + +import java.io.IOException; +import java.time.Duration; + +/** + * Low level client to interact with the WorkerAI API. + */ +public class WorkersAiClient { + + private static final String BASE_URL = "https://api.cloudflare.com/"; + + /** + * Constructor. + */ + public WorkersAiClient() {} + + /** + * Initialization of okHTTP. + * + * @param apiToken + * authorization token + * @return + * api + */ + public static WorkersAiApi createService(String apiToken) { + OkHttpClient okHttpClient = new OkHttpClient.Builder() + .addInterceptor(new AuthInterceptor(apiToken)) + // Slow but can be needed for images + .callTimeout(Duration.ofSeconds(30)) + .readTimeout(Duration.ofSeconds(30)) + .build(); + + Retrofit retrofit = new Retrofit.Builder() + .baseUrl(BASE_URL) + .client(okHttpClient) + .addConverterFactory(JacksonConverterFactory.create()) + .build(); + + return retrofit.create(WorkersAiApi.class); + } + + /** + * An interceptor for HTTP requests to add an authorization token to the header. + * Implements the {@link Interceptor} interface. + */ + public static class AuthInterceptor implements Interceptor { + private final String apiToken; + + /** + * Constructs an AuthInterceptor with a specified authorization token. + * + * @param apiToken The authorization token to be used in HTTP headers. + */ + public AuthInterceptor(String apiToken) { + this.apiToken = apiToken; + } + + /** + * Intercepts an outgoing HTTP request, adding an authorization header. + * + * @param chain The chain of request/response interceptors. + * @return The modified response after adding the authorization header. + * @throws IOException If an IO exception occurs during request processing. 
+ */ + @NotNull + @Override + public Response intercept(Chain chain) throws IOException { + Request.Builder builder = chain + .request().newBuilder() + .header("Authorization", "Bearer " + apiToken); + Request request = builder.build(); + return chain.proceed(request); + } + } + + +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiEmbeddingRequest.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiEmbeddingRequest.java new file mode 100644 index 0000000000..fa8b0bb8e9 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiEmbeddingRequest.java @@ -0,0 +1,21 @@ +package dev.langchain4j.model.workersai.client; + +import lombok.Data; + +import java.util.ArrayList; +import java.util.List; + +/** + * Request to compute embeddings. + */ +@Data +public class WorkersAiEmbeddingRequest { + + private List<String> text = new ArrayList<>(); + + /** + * Default constructor. + */ + public WorkersAiEmbeddingRequest() { + } +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiEmbeddingResponse.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiEmbeddingResponse.java new file mode 100644 index 0000000000..1677a286f3 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiEmbeddingResponse.java @@ -0,0 +1,41 @@ +package dev.langchain4j.model.workersai.client; + +import lombok.Data; + +import java.util.List; + +/** + * Response to compute embeddings. + */ +public class WorkersAiEmbeddingResponse extends ApiResponse<WorkersAiEmbeddingResponse.EmbeddingResult> { + + /** + * Default constructor. + */ + public WorkersAiEmbeddingResponse() { + } + + /** + * Bean to hold results + */ + @Data + public static class EmbeddingResult { + + /** + * Shape of the result + */ + private List<Integer> shape; + + /** + * Embedding data + */ + private List<List<Float>> data; + + /** + * Default constructor. + */ + public EmbeddingResult() { + } + } + +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiImageGenerationRequest.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiImageGenerationRequest.java new file mode 100644 index 0000000000..f8d13c3973 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiImageGenerationRequest.java @@ -0,0 +1,47 @@ +package dev.langchain4j.model.workersai.client; + +import lombok.AllArgsConstructor; +import lombok.Data; + +/** + * Request to generate an image. + */ +@Data @AllArgsConstructor +public class WorkersAiImageGenerationRequest { + + /** + * Prompt to generate the image. + */ + String prompt; + + /** + * Source image to edit + */ + int[] image; + + /** + * Mask image to edit (optional) + */ + int[] mask; + + /** + * Number of diffusion steps. + */ + Integer num_steps; + + /** + * Strength + */ + Integer strength; + + /** + * File to save the image. + */ + String destinationFile; + + /** + * Default constructor.
+ */ + public WorkersAiImageGenerationRequest() { + } +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiImageGenerationResponse.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiImageGenerationResponse.java new file mode 100644 index 0000000000..519ed89498 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiImageGenerationResponse.java @@ -0,0 +1,36 @@ +package dev.langchain4j.model.workersai.client; + +import lombok.AllArgsConstructor; +import lombok.Data; + +import java.io.InputStream; + +/** + * Response to generate an image. + */ +public class WorkersAiImageGenerationResponse + extends ApiResponse { + + /** + * Default constructor. + */ + public WorkersAiImageGenerationResponse() { + } + + /** + * Body of the image generating process + */ + @Data + @AllArgsConstructor + public static class ImageGenerationResult { + private InputStream image; + + /** + * Default constructor. + */ + @SuppressWarnings("unused") + public ImageGenerationResult() { + } + } + +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiTextCompletionRequest.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiTextCompletionRequest.java new file mode 100644 index 0000000000..bd90d2cfcc --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiTextCompletionRequest.java @@ -0,0 +1,19 @@ +package dev.langchain4j.model.workersai.client; + +import lombok.AllArgsConstructor; +import lombok.Data; + +/** + * Request to complete a text. + */ +@Data @AllArgsConstructor +public class WorkersAiTextCompletionRequest { + + String prompt; + + /** + * Default constructor. + */ + public WorkersAiTextCompletionRequest() { + } +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiTextCompletionResponse.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiTextCompletionResponse.java new file mode 100644 index 0000000000..73451b85dc --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/client/WorkersAiTextCompletionResponse.java @@ -0,0 +1,31 @@ +package dev.langchain4j.model.workersai.client; + +import lombok.Data; + +/** + * Wrapper for the text completion response. + */ +public class WorkersAiTextCompletionResponse extends ApiResponse { + + /** + * Default constructor. + */ + public WorkersAiTextCompletionResponse() {} + + /** + * Wrapper for the text completion response. + */ + @Data + public static class TextResponse { + + /** + * The generated text. + */ + private String response; + + /** + * Default constructor. + */ + public TextResponse() {} + } +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiChatModelBuilderFactory.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiChatModelBuilderFactory.java new file mode 100644 index 0000000000..d1e40ffe64 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiChatModelBuilderFactory.java @@ -0,0 +1,11 @@ +package dev.langchain4j.model.workersai.spi; + +import dev.langchain4j.model.workersai.WorkersAiChatModel; + +import java.util.function.Supplier; + +/** + * A factory for building {@link WorkersAiChatModel.Builder} instances. 
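+ * <p>The {@code builder()} methods above look these factories up via {@code ServiceHelper.loadFactories(...)}, so a custom builder implementation can be plugged in (presumably registered through Java's {@code ServiceLoader} mechanism).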
+ */ +public interface WorkersAiChatModelBuilderFactory extends Supplier<WorkersAiChatModel.Builder> { +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiEmbeddingModelBuilderFactory.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiEmbeddingModelBuilderFactory.java new file mode 100644 index 0000000000..302766d25c --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiEmbeddingModelBuilderFactory.java @@ -0,0 +1,11 @@ +package dev.langchain4j.model.workersai.spi; + +import dev.langchain4j.model.workersai.WorkersAiEmbeddingModel; + +import java.util.function.Supplier; + +/** + * A factory for building {@link WorkersAiEmbeddingModel.Builder} instances. + */ +public interface WorkersAiEmbeddingModelBuilderFactory extends Supplier<WorkersAiEmbeddingModel.Builder> { +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiImageModelBuilderFactory.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiImageModelBuilderFactory.java new file mode 100644 index 0000000000..3db4040718 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiImageModelBuilderFactory.java @@ -0,0 +1,12 @@ +package dev.langchain4j.model.workersai.spi; + +import dev.langchain4j.model.workersai.WorkersAiImageModel; +import dev.langchain4j.model.workersai.WorkersAiLanguageModel; + +import java.util.function.Supplier; + +/** + * A factory for building {@link WorkersAiImageModel.Builder} instances. + */ +public interface WorkersAiImageModelBuilderFactory extends Supplier<WorkersAiImageModel.Builder> { +} diff --git a/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiLanguageModelBuilderFactory.java b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiLanguageModelBuilderFactory.java new file mode 100644 index 0000000000..bcd6522078 --- /dev/null +++ b/langchain4j-workers-ai/src/main/java/dev/langchain4j/model/workersai/spi/WorkersAiLanguageModelBuilderFactory.java @@ -0,0 +1,11 @@ +package dev.langchain4j.model.workersai.spi; + +import dev.langchain4j.model.workersai.WorkersAiLanguageModel; + +import java.util.function.Supplier; + +/** + * A factory for building {@link WorkersAiLanguageModel.Builder} instances.
+ */ +public interface WorkersAiLanguageModelBuilderFactory extends Supplier { +} diff --git a/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAIChatModelIT.java b/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAIChatModelIT.java new file mode 100644 index 0000000000..65640fb761 --- /dev/null +++ b/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAIChatModelIT.java @@ -0,0 +1,73 @@ +package dev.langchain4j.model.workersai; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.util.ArrayList; +import java.util.List; + +import static dev.langchain4j.data.message.SystemMessage.systemMessage; +import static dev.langchain4j.data.message.UserMessage.userMessage; +import static dev.langchain4j.model.output.FinishReason.STOP; +import static dev.langchain4j.model.workersai.WorkersAiChatModelName.LLAMA2_7B_FULL; +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "WORKERS_AI_API_KEY", matches = ".*") +@EnabledIfEnvironmentVariable(named = "WORKERS_AI_ACCOUNT_ID", matches = ".*") +class WorkerAIChatModelIT { + + static WorkersAiChatModel chatModel; + + @BeforeAll + static void initializeModel() { + chatModel = WorkersAiChatModel.builder() + .modelName(LLAMA2_7B_FULL.toString()) + .accountId(System.getenv("WORKERS_AI_ACCOUNT_ID")) + .apiToken(System.getenv("WORKERS_AI_API_KEY")) + .build(); + } + + @Test + void should_generate_answer_and_return_finish_reason_stop() { + UserMessage userMessage = userMessage("hello, how are you?"); + Response response = chatModel.generate(userMessage); + assertThat(response.content().text()).isNotBlank(); + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @Test + void should_generate_answer_based_on_context() { + List conversation = new ArrayList<>(); + conversation.add(systemMessage("You an an assistant i will give you the name of " + + "a country and you will give me exactly the name of the capital, " + + "no other text or message, " + + "just the name of the city")); + conversation.add(userMessage("France")); + Response response = chatModel.generate(conversation); + Assertions.assertNotNull(response); + assertThat(response.content().text()).isNotBlank(); + Assertions.assertEquals("PARIS", chatModel.generate(conversation).content().text().toUpperCase()); + } + + @Test + void should_throw_unsupported_if_using_toolSpecification() { + List toolSpecifications = new ArrayList<>(); + toolSpecifications.add(ToolSpecification.builder().build()); + List messages = new ArrayList<>(); + messages.add(userMessage("hello, how are you?")); + Assertions.assertThrows(UnsupportedOperationException.class, () -> { + chatModel.generate(messages, toolSpecifications); + }); + } + + + + +} diff --git a/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAIEmbeddingModelIT.java b/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAIEmbeddingModelIT.java new file mode 100644 index 0000000000..c96289cb0e --- /dev/null +++ 
b/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAIEmbeddingModelIT.java @@ -0,0 +1,45 @@ +package dev.langchain4j.model.workersai; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.util.ArrayList; +import java.util.List; + +@EnabledIfEnvironmentVariable(named = "WORKERS_AI_API_KEY", matches = ".*") +@EnabledIfEnvironmentVariable(named = "WORKERS_AI_ACCOUNT_ID", matches = ".*") +class WorkerAIEmbeddingModelIT { + + static WorkersAiEmbeddingModel embeddingModel; + + @BeforeAll + static void initializeModel() { + embeddingModel = WorkersAiEmbeddingModel.builder() + .modelName(WorkersAiEmbeddingModelName.BAAI_EMBEDDING_BASE.toString()) + .accountId(System.getenv("WORKERS_AI_ACCOUNT_ID")) + .apiToken(System.getenv("WORKERS_AI_API_KEY")) + .build(); + } + + @Test + void generateEmbeddingSimple() { + Response out = embeddingModel.embed("Sentence1"); + Assertions.assertNotNull(out.content()); + } + + @Test + void generateEmbeddings() { + List data = new ArrayList<>(); + data.add(new TextSegment("Sentence1", new Metadata())); + data.add(new TextSegment("Sentence2", new Metadata())); + Response> out = embeddingModel.embedAll(data); + Assertions.assertNotNull(out.content()); + Assertions.assertEquals(2, out.content().size()); + } +} diff --git a/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAILanguageModelIT.java b/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAILanguageModelIT.java new file mode 100644 index 0000000000..b8dd8e2ada --- /dev/null +++ b/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAILanguageModelIT.java @@ -0,0 +1,31 @@ +package dev.langchain4j.model.workersai; + +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +@EnabledIfEnvironmentVariable(named = "WORKERS_AI_API_KEY", matches = ".*") +@EnabledIfEnvironmentVariable(named = "WORKERS_AI_ACCOUNT_ID", matches = ".*") +class WorkerAILanguageModelIT { + + static WorkersAiLanguageModel languageModel; + + @BeforeAll + static void initializeModel() { + languageModel = WorkersAiLanguageModel.builder() + .modelName(WorkersAiChatModelName.LLAMA2_7B_FULL.toString()) + .accountId(System.getenv("WORKERS_AI_ACCOUNT_ID")) + .apiToken(System.getenv("WORKERS_AI_API_KEY")) + .build(); + } + @Test + void generateText() { + Response joke = languageModel.generate("Tell me a joke about thw cloud"); + Assertions.assertNotNull(joke); + Assertions.assertNotNull(joke.content()); + Assertions.assertNotNull(joke.finishReason()); + + } +} diff --git a/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAiImageModelIT.java b/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAiImageModelIT.java new file mode 100644 index 0000000000..62c1a307ef --- /dev/null +++ b/langchain4j-workers-ai/src/test/java/dev/langchain4j/model/workersai/WorkerAiImageModelIT.java @@ -0,0 +1,84 @@ +package dev.langchain4j.model.workersai; + +import dev.langchain4j.data.image.Image; +import 
dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.Base64; + +@EnabledIfEnvironmentVariable(named = "WORKERS_AI_API_KEY", matches = ".*") +@EnabledIfEnvironmentVariable(named = "WORKERS_AI_ACCOUNT_ID", matches = ".*") +class WorkerAiImageModelIT { + + static WorkersAiImageModel imageModel; + + @BeforeAll + static void initializeModel() { + imageModel = WorkersAiImageModel.builder() + .modelName(WorkersAiImageModelName.STABLE_DIFFUSION_XL.toString()) + .accountId(System.getenv("WORKERS_AI_ACCOUNT_ID")) + .apiToken(System.getenv("WORKERS_AI_API_KEY")) + .build(); + } + + @Test + void should_generate_an_image_as_binary() { + Response image = imageModel.generate("Draw me a squirrel");; + Assertions.assertNotNull(image.content()); + Assertions.assertNotNull(image.content().base64Data()); + } + + @Test + void should_generate_an_image_as_file() { + String homeDirectory = System.getProperty("user.home"); + Response image = imageModel.generate("Draw me a squirrel", + System.getProperty("user.home") + "/langchain4j-squirrel.png");; + Assertions.assertTrue(image.content().exists()); + } + + @Test + void should_edit_source_image() throws Exception { + Image sourceImage = imageModel + .convertAsImage( + getImageFromUrl("https://pub-1fb693cb11cc46b2b2f656f51e015a2c.r2.dev/dog.png")); + Image maskImage = imageModel + .convertAsImage( + getImageFromUrl( "https://pub-1fb693cb11cc46b2b2f656f51e015a2c.r2.dev/dog.png")); + Response image = imageModel.edit(sourceImage, maskImage, "Face of a yellow cat, high resolution, sitting on a park bench"); + saveOutputToFile(Base64.getDecoder().decode(image.content().base64Data()), + System.getProperty("user.home") + "/Downloads/yellow_cat_on_park_bench.png"); + } + + private byte[] getImageFromUrl(String imageUrl) throws Exception { + URL url = new URL(imageUrl); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setRequestMethod("GET"); + connection.setDoInput(true); + connection.connect(); + try (InputStream inputStream = connection.getInputStream(); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream()) { + byte[] buffer = new byte[1024]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + byteArrayOutputStream.write(buffer, 0, bytesRead); + } + return byteArrayOutputStream.toByteArray(); + } + } + + private void saveOutputToFile(byte[] image, String destinationFile) throws Exception { + try (FileOutputStream fileOutputStream = new FileOutputStream(destinationFile)) { + fileOutputStream.write(image); + } + } + +} diff --git a/langchain4j-zhipu-ai/pom.xml b/langchain4j-zhipu-ai/pom.xml index bb1dd0153b..e3533d333f 100644 --- a/langchain4j-zhipu-ai/pom.xml +++ b/langchain4j-zhipu-ai/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -28,7 +28,7 @@ com.squareup.retrofit2 - converter-gson + converter-jackson diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/AssistantMessageTypeAdapter.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/AssistantMessageTypeAdapter.java deleted file mode 100644 index 
c6734bb872..0000000000 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/AssistantMessageTypeAdapter.java +++ /dev/null @@ -1,77 +0,0 @@ -package dev.langchain4j.model.zhipu; - -import com.google.gson.Gson; -import com.google.gson.TypeAdapter; -import com.google.gson.TypeAdapterFactory; -import com.google.gson.reflect.TypeToken; -import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonWriter; -import dev.langchain4j.model.zhipu.chat.AssistantMessage; -import dev.langchain4j.model.zhipu.chat.ToolCall; - -import java.io.IOException; -import java.util.List; - -class AssistantMessageTypeAdapter extends TypeAdapter { - - static final TypeAdapterFactory ASSISTANT_MESSAGE_TYPE_ADAPTER_FACTORY = new TypeAdapterFactory() { - - @Override - @SuppressWarnings("unchecked") - public TypeAdapter create(Gson gson, TypeToken type) { - if (type.getRawType() != AssistantMessage.class) { - return null; - } - TypeAdapter delegate = - (TypeAdapter) gson.getDelegateAdapter(this, type); - return (TypeAdapter) new AssistantMessageTypeAdapter(delegate); - } - }; - - private final TypeAdapter delegate; - - private AssistantMessageTypeAdapter(TypeAdapter delegate) { - this.delegate = delegate; - } - - @Override - public void write(JsonWriter out, AssistantMessage assistantMessage) throws IOException { - out.beginObject(); - - out.name("role"); - out.value(assistantMessage.getRole().toString().toLowerCase()); - - out.name("content"); - if (assistantMessage.getContent() == null) { - boolean serializeNulls = out.getSerializeNulls(); - out.setSerializeNulls(true); - out.nullValue(); // serialize "content": null - out.setSerializeNulls(serializeNulls); - } else { - out.value(assistantMessage.getContent()); - } - - if (assistantMessage.getName() != null) { - out.name("name"); - out.value(assistantMessage.getName()); - } - - List toolCalls = assistantMessage.getToolCalls(); - if (toolCalls != null && !toolCalls.isEmpty()) { - out.name("tool_calls"); - out.beginArray(); - TypeAdapter toolCallTypeAdapter = Json.GSON.getAdapter(ToolCall.class); - for (ToolCall toolCall : toolCalls) { - toolCallTypeAdapter.write(out, toolCall); - } - out.endArray(); - } - - out.endObject(); - } - - @Override - public AssistantMessage read(JsonReader in) throws IOException { - return delegate.read(in); - } -} \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/AuthorizationInterceptor.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/AuthorizationInterceptor.java index 7fecfa95e4..80c6d8af32 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/AuthorizationInterceptor.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/AuthorizationInterceptor.java @@ -1,5 +1,6 @@ package dev.langchain4j.model.zhipu; +import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import io.jsonwebtoken.Jwts; @@ -62,7 +63,7 @@ public Response intercept(Chain chain) throws IOException { return chain.proceed(request); } - private String generateToken() { + private String generateToken() throws JsonProcessingException { String[] apiKeyParts = this.apiKey.split("\\."); String keyId = apiKeyParts[0]; String secret = apiKeyParts[1]; diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/DefaultZhipuAiHelper.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/DefaultZhipuAiHelper.java index 
5a0c07c829..0f456d30bb 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/DefaultZhipuAiHelper.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/DefaultZhipuAiHelper.java @@ -11,9 +11,13 @@ import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.model.zhipu.chat.*; import dev.langchain4j.model.zhipu.embedding.EmbeddingResponse; +import dev.langchain4j.model.zhipu.shared.ErrorResponse; import dev.langchain4j.model.zhipu.shared.Usage; +import okhttp3.ResponseBody; +import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -44,6 +48,9 @@ private static Function toFunction(ToolSpecification toolSpecification) { } private static Parameters toFunctionParameters(ToolParameters toolParameters) { + if (toolParameters == null) { + return Parameters.builder().build(); + } return Parameters.builder() .properties(toolParameters.properties()) .required(toolParameters.required()) @@ -158,6 +165,38 @@ public static TokenUsage tokenUsageFrom(Usage zhipuUsage) { ); } + public static ChatCompletionResponse toChatErrorResponse(retrofit2.Response retrofitResponse) throws IOException { + try (ResponseBody errorBody = retrofitResponse.errorBody()) { + return ChatCompletionResponse.builder() + .choices(Collections.singletonList(toChatErrorChoice(errorBody))) + .usage(Usage.builder().build()) + .build(); + } + } + + /** + * error code see error codes document + */ + private static ChatCompletionChoice toChatErrorChoice(ResponseBody errorBody) throws IOException { + if (errorBody == null) { + return ChatCompletionChoice.builder() + .finishReason("other") + .build(); + } + ErrorResponse errorResponse = Json.fromJson(errorBody.string(), ErrorResponse.class); + // 1301: 系统检测到输入或生成内容可能包含不安全或敏感内容,请您避免输入易产生敏感内容的提示语,感谢您的配合 + if ("1301".equals(errorResponse.getError().get("code"))) { + return ChatCompletionChoice.builder() + .message(AssistantMessage.builder().content(errorResponse.getError().get("message")).build()) + .finishReason("sensitive") + .build(); + } + return ChatCompletionChoice.builder() + .message(AssistantMessage.builder().content(errorResponse.getError().get("message")).build()) + .finishReason("other") + .build(); + } + public static FinishReason finishReasonFrom(String finishReason) { if (finishReason == null) { return null; @@ -169,6 +208,8 @@ public static FinishReason finishReasonFrom(String finishReason) { return LENGTH; case "tool_calls": return TOOL_EXECUTION; + case "sensitive": + return CONTENT_FILTER; default: return OTHER; } diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/Json.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/Json.java index b8e944b681..c7bc24d157 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/Json.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/Json.java @@ -1,23 +1,27 @@ package dev.langchain4j.model.zhipu; -import com.google.gson.Gson; -import com.google.gson.GsonBuilder; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.SneakyThrows; -import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES; -import static dev.langchain4j.model.zhipu.AssistantMessageTypeAdapter.ASSISTANT_MESSAGE_TYPE_ADAPTER_FACTORY; +import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT; class Json { - public static final Gson GSON = new 
GsonBuilder() - .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES) - .registerTypeAdapterFactory(ASSISTANT_MESSAGE_TYPE_ADAPTER_FACTORY) - .setPrettyPrinting() - .create(); + static final ObjectMapper OBJECT_MAPPER = new ObjectMapper() + .enable(INDENT_OUTPUT); + @SneakyThrows static String toJson(Object o) { - return GSON.toJson(o); + return OBJECT_MAPPER.writeValueAsString(o); } + @SneakyThrows static T fromJson(String json, Class type) { - return GSON.fromJson(json, type); + return OBJECT_MAPPER.readValue(json, type); + } + + @SneakyThrows + static T fromJson(String json, TypeReference type) { + return OBJECT_MAPPER.readValue(json, type); } } \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/RequestLoggingInterceptor.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/RequestLoggingInterceptor.java index c8f6937419..dd8154e00b 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/RequestLoggingInterceptor.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/RequestLoggingInterceptor.java @@ -1,12 +1,11 @@ package dev.langchain4j.model.zhipu; +import lombok.extern.slf4j.Slf4j; import okhttp3.Headers; import okhttp3.Interceptor; import okhttp3.Request; import okhttp3.Response; import okio.Buffer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.regex.Matcher; @@ -15,10 +14,9 @@ import static java.util.stream.StreamSupport.stream; +@Slf4j class RequestLoggingInterceptor implements Interceptor { - private static final Logger log = LoggerFactory.getLogger(RequestLoggingInterceptor.class); - private static final Pattern BEARER_PATTERN = Pattern.compile("(Bearer\\s)(\\w{2})(\\w+)(\\w{2})"); static String inOneLine(Headers headers) { diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ResponseLoggingInterceptor.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ResponseLoggingInterceptor.java index 194843e099..e01caa296f 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ResponseLoggingInterceptor.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ResponseLoggingInterceptor.java @@ -1,44 +1,42 @@ package dev.langchain4j.model.zhipu; +import lombok.extern.slf4j.Slf4j; import okhttp3.Interceptor; import okhttp3.Request; import okhttp3.Response; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.IOException; import static dev.langchain4j.model.zhipu.RequestLoggingInterceptor.inOneLine; +@Slf4j class ResponseLoggingInterceptor implements Interceptor { - private static final Logger log = LoggerFactory.getLogger(ResponseLoggingInterceptor.class); - - public Response intercept(Interceptor.Chain chain) throws IOException { + @Override + public Response intercept(Chain chain) throws IOException { Request request = chain.request(); Response response = chain.proceed(request); - log(response); + this.log(response); return response; } - void log(Response response) { - log.debug( - "Response:\n" + - "- status code: {}\n" + - "- headers: {}\n" + - "- body: {}", - response.code(), - inOneLine(response.headers()), - getBody(response) - ); - } - - private String getBody(Response response) { + private void log(Response response) { try { - return response.peekBody(Long.MAX_VALUE).string(); - } catch (IOException e) { - log.warn("Failed to log response", e); - return "[failed to log response]"; + log.debug("Response:\n- status code: 
{}\n- headers: {}\n- body: {}", + response.code(), inOneLine(response.headers()), this.getBody(response)); + } catch (Exception e) { + log.warn("Error while logging response: {}", e.getMessage()); } } + + private String getBody(Response response) throws IOException { + return isEventStream(response) + ? "[skipping response body due to streaming]" + : response.peekBody(Long.MAX_VALUE).string(); + } + + private static boolean isEventStream(Response response) { + String contentType = response.header("Content-Type"); + return contentType != null && contentType.contains("event-stream"); + } } \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiApi.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiApi.java index ec64500dfe..0ff291c661 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiApi.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiApi.java @@ -4,6 +4,8 @@ import dev.langchain4j.model.zhipu.chat.ChatCompletionResponse; import dev.langchain4j.model.zhipu.embedding.EmbeddingRequest; import dev.langchain4j.model.zhipu.embedding.EmbeddingResponse; +import dev.langchain4j.model.zhipu.image.ImageRequest; +import dev.langchain4j.model.zhipu.image.ImageResponse; import okhttp3.ResponseBody; import retrofit2.Call; import retrofit2.http.Body; @@ -25,4 +27,8 @@ interface ZhipuAiApi { @POST("api/paas/v4/embeddings") @Headers({"Content-Type: application/json"}) Call embeddings(@Body EmbeddingRequest request); + + @POST("api/paas/v4/images/generations") + @Headers({"Content-Type: application/json"}) + Call generations(@Body ImageRequest request); } diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java index 2089d48e67..28226ef8aa 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java @@ -5,26 +5,22 @@ import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.zhipu.chat.ChatCompletionModel; import dev.langchain4j.model.zhipu.chat.ChatCompletionRequest; import dev.langchain4j.model.zhipu.chat.ChatCompletionResponse; -import dev.langchain4j.model.zhipu.chat.ToolChoiceMode; import dev.langchain4j.model.zhipu.spi.ZhipuAiChatModelBuilderFactory; import lombok.Builder; -import java.util.Collections; import java.util.List; import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; -import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.aiMessageFrom; -import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.finishReasonFrom; -import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toTools; -import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toZhipuAiMessages; -import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.tokenUsageFrom; +import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.*; +import static dev.langchain4j.model.zhipu.chat.ChatCompletionModel.GLM_4; +import static dev.langchain4j.model.zhipu.chat.ToolChoiceMode.AUTO; import static 
dev.langchain4j.spi.ServiceHelper.loadFactories; +import static java.util.Collections.singletonList; /** * Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and glm-4. @@ -55,7 +51,7 @@ public ZhipuAiChatModel( this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/"); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; - this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString()); + this.model = getOrDefault(model, GLM_4.toString()); this.maxRetries = getOrDefault(maxRetries, 3); this.maxToken = getOrDefault(maxToken, 512); this.client = ZhipuAiClient.builder() @@ -88,7 +84,7 @@ public Response generate(List messages, List generate(List messages, List generate(List messages, ToolSpecification toolSpecification) { - return generate(messages, toolSpecification != null ? Collections.singletonList(toolSpecification) : null); + return generate(messages, toolSpecification != null ? singletonList(toolSpecification) : null); } public static class ZhipuAiChatModelBuilder { diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiClient.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiClient.java index 935fd6f827..156c7d5083 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiClient.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiClient.java @@ -12,30 +12,30 @@ import dev.langchain4j.model.zhipu.chat.ToolCall; import dev.langchain4j.model.zhipu.embedding.EmbeddingRequest; import dev.langchain4j.model.zhipu.embedding.EmbeddingResponse; +import dev.langchain4j.model.zhipu.image.ImageRequest; +import dev.langchain4j.model.zhipu.image.ImageResponse; import dev.langchain4j.model.zhipu.shared.Usage; +import lombok.extern.slf4j.Slf4j; import okhttp3.OkHttpClient; import okhttp3.ResponseBody; import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; import okhttp3.sse.EventSources; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.jetbrains.annotations.NotNull; import retrofit2.Retrofit; -import retrofit2.converter.gson.GsonConverterFactory; import java.io.IOException; import java.time.Duration; import java.util.List; import static dev.langchain4j.internal.Utils.isNullOrEmpty; -import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.finishReasonFrom; -import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.specificationsFrom; -import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.tokenUsageFrom; -import static dev.langchain4j.model.zhipu.Json.GSON; +import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.*; +import static dev.langchain4j.model.zhipu.Json.OBJECT_MAPPER; +import static retrofit2.converter.jackson.JacksonConverterFactory.create; +@Slf4j public class ZhipuAiClient { - private static final Logger log = LoggerFactory.getLogger(ZhipuAiClient.class); private final String baseUrl; private final ZhipuAiApi zhipuAiApi; @@ -65,7 +65,7 @@ public ZhipuAiClient(Builder builder) { Retrofit retrofit = (new Retrofit.Builder()) .baseUrl(formattedUrlForRetrofit(this.baseUrl)) .client(this.okHttpClient) - .addConverterFactory(GsonConverterFactory.create(GSON)) + .addConverterFactory(create(OBJECT_MAPPER)) .build(); this.zhipuAiApi = retrofit.create(ZhipuAiApi.class); } @@ -85,9 +85,8 @@ public ChatCompletionResponse chatCompletion(ChatCompletionRequest request) { = zhipuAiApi.chatCompletion(request).execute(); if (retrofitResponse.isSuccessful()) { return 
retrofitResponse.body(); - } else { - throw toException(retrofitResponse); } + return toChatErrorResponse(retrofitResponse); } catch (IOException e) { throw new RuntimeException(e); } @@ -114,14 +113,14 @@ void streamingChatCompletion(ChatCompletionRequest request, StreamingResponseHan FinishReason finishReason; @Override - public void onOpen(EventSource eventSource, okhttp3.Response response) { + public void onOpen(@NotNull EventSource eventSource, @NotNull okhttp3.Response response) { if (logResponses) { log.debug("onOpen()"); } } @Override - public void onEvent(EventSource eventSource, String id, String type, String data) { + public void onEvent(@NotNull EventSource eventSource, String id, String type, @NotNull String data) { if (logResponses) { log.debug("onEvent() {}", data); } @@ -140,7 +139,7 @@ public void onEvent(EventSource eventSource, String id, String type, String data handler.onComplete(response); } else { try { - ChatCompletionResponse chatCompletionResponse = Json.fromJson(data, ChatCompletionResponse.class); + ChatCompletionResponse chatCompletionResponse = OBJECT_MAPPER.readValue(data, ChatCompletionResponse.class); ChatCompletionChoice zhipuChatCompletionChoice = chatCompletionResponse.getChoices().get(0); String chunk = zhipuChatCompletionChoice.getDelta().getContent(); contentBuilder.append(chunk); @@ -167,7 +166,7 @@ public void onEvent(EventSource eventSource, String id, String type, String data } @Override - public void onFailure(EventSource eventSource, Throwable t, okhttp3.Response response) { + public void onFailure(@NotNull EventSource eventSource, Throwable t, okhttp3.Response response) { if (logResponses) { log.debug("onFailure()", t); } @@ -180,7 +179,7 @@ public void onFailure(EventSource eventSource, Throwable t, okhttp3.Response res } @Override - public void onClosed(EventSource eventSource) { + public void onClosed(@NotNull EventSource eventSource) { if (logResponses) { log.debug("onClosed()"); } @@ -191,34 +190,36 @@ public void onClosed(EventSource eventSource) { zhipuAiApi.streamingChatCompletion(request).request(), eventSourceListener ); - -// zhipuApi.streamingChatCompletion(request).enqueue(new Callback() { -// @Override -// public void onResponse(Call call, Response response) { -// -// } -// -// @Override -// public void onFailure(Call call, Throwable t) { -// -// } -// }); } private RuntimeException toException(retrofit2.Response retrofitResponse) throws IOException { int code = retrofitResponse.code(); if (code >= 400) { - ResponseBody errorBody = retrofitResponse.errorBody(); - if (errorBody != null) { - String errorBodyString = errorBody.string(); - String errorMessage = String.format("status code: %s; body: %s", code, errorBodyString); - log.error("Error response: {}", errorMessage); - return new RuntimeException(errorMessage); + try (ResponseBody errorBody = retrofitResponse.errorBody()) { + if (errorBody != null) { + String errorBodyString = errorBody.string(); + String errorMessage = String.format("status code: %s; body: %s", code, errorBodyString); + log.error("Error response: {}", errorMessage); + return new RuntimeException(errorMessage); + } } } return new RuntimeException(retrofitResponse.message()); } + public ImageResponse imagesGeneration(ImageRequest request) { + try { + retrofit2.Response responseResponse = zhipuAiApi.generations(request).execute(); + if (responseResponse.isSuccessful()) { + return responseResponse.body(); + } else { + throw toException(responseResponse); + } + } catch (IOException e) { + throw new 
RuntimeException(e); + } + } + public static class Builder { private String baseUrl; private String apiKey; diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiEmbeddingModel.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiEmbeddingModel.java index 5012863402..d969d448e5 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiEmbeddingModel.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiEmbeddingModel.java @@ -10,14 +10,14 @@ import dev.langchain4j.model.zhipu.spi.ZhipuAiEmbeddingModelBuilderFactory; import lombok.Builder; -import java.util.LinkedList; import java.util.List; -import java.util.stream.Collectors; import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.*; +import static dev.langchain4j.model.zhipu.embedding.EmbeddingModel.EMBEDDING_2; import static dev.langchain4j.spi.ServiceHelper.loadFactories; +import static java.util.stream.Collectors.toList; /** * Represents an ZhipuAI embedding model, such as embedding-2. @@ -39,7 +39,7 @@ public ZhipuAiEmbeddingModel( Boolean logResponses ) { this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/"); - this.model = getOrDefault(model, dev.langchain4j.model.zhipu.embedding.EmbeddingModel.EMBEDDING_2.toString()); + this.model = getOrDefault(model, EMBEDDING_2.toString()); this.maxRetries = getOrDefault(maxRetries, 3); this.client = ZhipuAiClient.builder() .baseUrl(this.baseUrl) @@ -66,7 +66,7 @@ public Response> embedAll(List textSegments) { .build() ) .map(request -> withRetry(() -> client.embedAll(request), maxRetries)) - .collect(Collectors.toList()); + .collect(toList()); Usage usage = getEmbeddingUsage(embeddingRequests); diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiImageModel.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiImageModel.java new file mode 100644 index 0000000000..7a1b560383 --- /dev/null +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiImageModel.java @@ -0,0 +1,70 @@ +package dev.langchain4j.model.zhipu; + +import dev.langchain4j.data.image.Image; +import dev.langchain4j.model.image.ImageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.zhipu.image.ImageModelName; +import dev.langchain4j.model.zhipu.image.ImageRequest; +import dev.langchain4j.model.zhipu.image.ImageResponse; +import lombok.Builder; + +import static dev.langchain4j.internal.RetryUtils.withRetry; +import static dev.langchain4j.internal.Utils.getOrDefault; + +public class ZhipuAiImageModel implements ImageModel { + + private final String model; + private final String userId; + private final String baseUrl; + private final Integer maxRetries; + private final ZhipuAiClient client; + + /** + * Instantiates a ZhipuAI cogview-3 image generation model. + * Find the parameters description in the ZhipuAI API documentation. + * + * @param model cogview-3 is the default + * @param userId A unique identifier representing your end-user, which can help ZhipuAI to monitor + * and detect abuse.
User ID length requirement: minimum of 6 characters, maximum of + * 128 characters + */ + @Builder + public ZhipuAiImageModel( + String model, + String userId, + String apiKey, + String baseUrl, + Integer maxRetries, + Boolean logRequests, + Boolean logResponses + ) { + this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/"); + this.model = getOrDefault(model, ImageModelName.COGVIEW_3.toString()); + this.maxRetries = getOrDefault(maxRetries, 3); + this.userId = userId; + this.client = ZhipuAiClient.builder() + .baseUrl(this.baseUrl) + .apiKey(apiKey) + .logRequests(getOrDefault(logRequests, false)) + .logResponses(getOrDefault(logResponses, false)) + .build(); + } + + @Override + public Response generate(String prompt) { + ImageRequest request = ImageRequest.builder() + .prompt(prompt) + .userId(userId) + .model(model) + .build(); + ImageResponse response = withRetry(() -> client.imagesGeneration(request), maxRetries); + if (response == null) { + return Response.from(Image.builder().build()); + } + return Response.from( + Image.builder() + .url(response.getData().get(0).getUrl()) + .build() + ); + } +} diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiStreamingChatModel.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiStreamingChatModel.java index a666d559b1..219d3a3b24 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiStreamingChatModel.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiStreamingChatModel.java @@ -6,13 +6,11 @@ import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.chat.StreamingChatLanguageModel; -import dev.langchain4j.model.zhipu.chat.ChatCompletionModel; import dev.langchain4j.model.zhipu.chat.ChatCompletionRequest; import dev.langchain4j.model.zhipu.chat.ToolChoiceMode; import dev.langchain4j.model.zhipu.spi.ZhipuAiStreamingChatModelBuilderFactory; import lombok.Builder; -import java.util.Collections; import java.util.List; import static dev.langchain4j.internal.Utils.getOrDefault; @@ -20,7 +18,9 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toTools; import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toZhipuAiMessages; +import static dev.langchain4j.model.zhipu.chat.ChatCompletionModel.GLM_4; import static dev.langchain4j.spi.ServiceHelper.loadFactories; +import static java.util.Collections.singletonList; public class ZhipuAiStreamingChatModel implements StreamingChatLanguageModel { @@ -45,7 +45,7 @@ public ZhipuAiStreamingChatModel( this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/"); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; - this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString()); + this.model = getOrDefault(model, GLM_4.toString()); this.maxToken = getOrDefault(maxToken, 512); this.client = ZhipuAiClient.builder() .baseUrl(this.baseUrl) @@ -64,7 +64,7 @@ public static ZhipuAiStreamingChatModelBuilder builder() { @Override public void generate(String userMessage, StreamingResponseHandler handler) { - this.generate(Collections.singletonList(UserMessage.from(userMessage)), handler); + this.generate(singletonList(UserMessage.from(userMessage)), handler); } @Override @@ -94,7 +94,7 @@ public void generate(List messages, List toolSpe @Override public void generate(List messages, ToolSpecification 
toolSpecification, StreamingResponseHandler handler) { - this.generate(messages, toolSpecification == null ? null : Collections.singletonList(toolSpecification), handler); + this.generate(messages, toolSpecification == null ? null : singletonList(toolSpecification), handler); } public static class ZhipuAiStreamingChatModelBuilder { diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/AssistantMessage.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/AssistantMessage.java index be0e259c24..591d4201c5 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/AssistantMessage.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/AssistantMessage.java @@ -1,28 +1,32 @@ package dev.langchain4j.model.zhipu.chat; -import com.google.gson.annotations.SerializedName; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; import java.util.List; +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; import static dev.langchain4j.model.zhipu.chat.Role.ASSISTANT; import static java.util.Arrays.asList; import static java.util.Collections.unmodifiableList; -@ToString -@EqualsAndHashCode +@Data +@NoArgsConstructor +@AllArgsConstructor +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class AssistantMessage implements Message { private final Role role = ASSISTANT; - @Getter - private final String content; - @Getter - private final String name; - @Getter - @SerializedName("tool_calls") - private final List toolCalls; + private String content; + private String name; + private List toolCalls; private AssistantMessage(Builder builder) { this.content = builder.content; diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionChoice.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionChoice.java index 13cdf36064..44eb2c8e4a 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionChoice.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionChoice.java @@ -1,63 +1,27 @@ package dev.langchain4j.model.zhipu.chat; -import com.google.gson.annotations.SerializedName; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; - -@Getter -@ToString -@EqualsAndHashCode +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class ChatCompletionChoice { - private final Integer index; - private final AssistantMessage 
message; - private final Delta delta; - @SerializedName("finish_reason") - private final String finishReason; - - private ChatCompletionChoice(Builder builder) { - this.index = builder.index; - this.message = builder.message; - this.delta = builder.delta; - this.finishReason = builder.finishReason; - } - - public static Builder builder() { - return new Builder(); - } - - public static final class Builder { - private Integer index; - private AssistantMessage message; - private Delta delta; - private String finishReason; - - private Builder() { - } - - public Builder index(Integer index) { - this.index = index; - return this; - } - - public Builder message(AssistantMessage message) { - this.message = message; - return this; - } - - public Builder delta(Delta delta) { - this.delta = delta; - return this; - } - - public Builder finishReason(String finishReason) { - this.finishReason = finishReason; - return this; - } - - public ChatCompletionChoice build() { - return new ChatCompletionChoice(this); - } - } + private Integer index; + private AssistantMessage message; + private Delta delta; + private String finishReason; } \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionModel.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionModel.java index 496a1eab0a..c60775c023 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionModel.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionModel.java @@ -2,6 +2,10 @@ public enum ChatCompletionModel { GLM_4("glm-4"), + GLM_4_0520("glm-4-0520"), + GLM_4_AIR("glm-4-air"), + GLM_4_AIRX("glm-4-airx"), + GLM_4_FLASH("glm-4-flash"), GLM_3_TURBO("glm-3-turbo"), CHATGLM_TURBO("chatglm_turbo"); diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionRequest.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionRequest.java index 4a894f0251..6d274e58f8 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionRequest.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionRequest.java @@ -1,36 +1,36 @@ package dev.langchain4j.model.zhipu.chat; -import com.google.gson.annotations.SerializedName; -import lombok.EqualsAndHashCode; -import lombok.ToString; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Data; import java.util.ArrayList; import java.util.List; +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; import static dev.langchain4j.model.zhipu.chat.ChatCompletionModel.GLM_4; import static java.util.Arrays.asList; import static java.util.Collections.unmodifiableList; -@ToString -@EqualsAndHashCode +@Data +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class ChatCompletionRequest { - private final String model; - private final List messages; - @SerializedName("request_id") - private final String requestId; - @SerializedName("do_sample") - private final String doSample; - private final Boolean stream; - private final Double temperature; - @SerializedName("top_p") - private final Double topP; - 
@SerializedName("max_tokens") - private final Integer maxTokens; - private final List stop; - private final List tools; - @SerializedName("tool_choice") - private final Object toolChoice; + private String model; + private List messages; + private String requestId; + private String doSample; + private Boolean stream; + private Double temperature; + private Double topP; + private Integer maxTokens; + private List stop; + private List tools; + private Object toolChoice; private ChatCompletionRequest(Builder builder) { this.model = builder.model; @@ -50,42 +50,6 @@ public static Builder builder() { return new Builder(); } - public String model() { - return model; - } - - public List messages() { - return messages; - } - - public Double temperature() { - return temperature; - } - - public Double topP() { - return topP; - } - - public Boolean stream() { - return stream; - } - - public List stop() { - return stop; - } - - public Integer maxTokens() { - return maxTokens; - } - - public List tools() { - return tools; - } - - public Object toolChoice() { - return toolChoice; - } - public static final class Builder { private String model = GLM_4.toString(); diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionResponse.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionResponse.java index aecddcaf86..1632b5defb 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionResponse.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ChatCompletionResponse.java @@ -1,75 +1,30 @@ package dev.langchain4j.model.zhipu.chat; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; import dev.langchain4j.model.zhipu.shared.Usage; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; import java.util.List; -@Getter -@ToString -@EqualsAndHashCode -public final class ChatCompletionResponse { - private final String id; - private final Integer created; - private final String model; - private final List choices; - private final Usage usage; - - private ChatCompletionResponse(Builder builder) { - this.id = builder.id; - this.created = builder.created; - this.model = builder.model; - this.choices = builder.choices; - this.usage = builder.usage; - } - - public static Builder builder() { - return new Builder(); - } - - public String content() { - return getChoices().get(0).getMessage().getContent(); - } - - public static final class Builder { - private String id; - private Integer created; - private String model; - private List choices; - private Usage usage; - - private Builder() { - } +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; - public Builder id(String id) { - this.id = id; - return this; - } - - public Builder created(Integer created) { - this.created = created; - return this; - } - - public Builder model(String model) { - this.model = model; - return this; - } - - public Builder choices(List choices) { - this.choices = choices; - return this; - } - - public Builder usage(Usage usage) { - this.usage = usage; - return this; - } - - public ChatCompletionResponse build() { - return new 
ChatCompletionResponse(this); - } - } +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class ChatCompletionResponse { + private String id; + private Integer created; + private String model; + private List choices; + private Usage usage; } diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Delta.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Delta.java index 1e6592ca62..4c7697b0cb 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Delta.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Delta.java @@ -1,20 +1,27 @@ package dev.langchain4j.model.zhipu.chat; -import com.google.gson.annotations.SerializedName; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; import java.util.Collections; import java.util.List; -@Getter -@ToString -@EqualsAndHashCode +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class Delta { - private final String content; - @SerializedName("tool_calls") - private final List toolCalls; + private String content; + private List toolCalls; private Delta(Builder builder) { this.content = builder.content; diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Function.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Function.java index c225389beb..64457ea958 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Function.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Function.java @@ -1,19 +1,26 @@ package dev.langchain4j.model.zhipu.chat; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; import dev.langchain4j.agent.tool.JsonSchemaProperty; -import lombok.EqualsAndHashCode; -import lombok.ToString; +import lombok.Data; import java.util.HashMap; import java.util.Map; -@ToString -@EqualsAndHashCode +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class Function { - private final String name; - private final String description; - private final Parameters parameters; + private String name; + private String description; + private Parameters parameters; private Function(Builder builder) { this.name = builder.name; @@ -62,7 +69,7 @@ public Builder parameters(Parameters parameters) { public Builder addParameter(String name, JsonSchemaProperty... 
jsonSchemaProperties) { this.addOptionalParameter(name, jsonSchemaProperties); - this.parameters.required().add(name); + this.parameters.getRequired().add(name); return this; } @@ -77,7 +84,7 @@ public Builder addOptionalParameter(String name, JsonSchemaProperty... jsonSchem jsonSchemaPropertiesMap.put(jsonSchemaProperty.key(), jsonSchemaProperty.value()); } - this.parameters.properties().put(name, jsonSchemaPropertiesMap); + this.parameters.getProperties().put(name, jsonSchemaPropertiesMap); return this; } diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/FunctionCall.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/FunctionCall.java index ce27702b2a..2318e9eb36 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/FunctionCall.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/FunctionCall.java @@ -1,45 +1,25 @@ package dev.langchain4j.model.zhipu.chat; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; - -@Getter -@ToString -@EqualsAndHashCode +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class FunctionCall { - private final String name; - private final String arguments; - - private FunctionCall(Builder builder) { - this.name = builder.name; - this.arguments = builder.arguments; - } - - public static Builder builder() { - return new Builder(); - } - - public static final class Builder { - private String name; - private String arguments; - - private Builder() { - } - - public Builder name(String name) { - this.name = name; - return this; - } - - public Builder arguments(String arguments) { - this.arguments = arguments; - return this; - } - - public FunctionCall build() { - return new FunctionCall(this); - } - } + private String name; + private String arguments; } diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Parameters.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Parameters.java index f91e75d8ca..f0d92170a6 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Parameters.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Parameters.java @@ -1,22 +1,25 @@ package dev.langchain4j.model.zhipu.chat; -import lombok.EqualsAndHashCode; -import lombok.ToString; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Data; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; -@ToString -@EqualsAndHashCode + +@Data +@JsonInclude(NON_NULL) 
+@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class Parameters { - private final String type; - private final Map> properties; - private final List required; + private String type; + private Map> properties; + private List required; private Parameters(Builder builder) { this.type = "object"; @@ -28,18 +31,6 @@ public static Builder builder() { return new Builder(); } - public String type() { - return "object"; - } - - public Map> properties() { - return this.properties; - } - - public List required() { - return this.required; - } - public static final class Builder { private Map> properties; private List required; diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Retrieval.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Retrieval.java index 9567eed037..88e4efc687 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Retrieval.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Retrieval.java @@ -1,47 +1,20 @@ package dev.langchain4j.model.zhipu.chat; -import com.google.gson.annotations.SerializedName; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; - -@Getter -@ToString -@EqualsAndHashCode +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Builder; +import lombok.Data; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@Builder +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class Retrieval { - @SerializedName("knowledge_id") - private final String knowledgeId; - @SerializedName("prompt_template") - private final String promptTemplate; - - Retrieval(RetrievalBuilder builder) { - this.knowledgeId = builder.knowledgeId; - this.promptTemplate = builder.promptTemplate; - } - - public static RetrievalBuilder builder() { - return new RetrievalBuilder(); - } - - public static class RetrievalBuilder { - private String knowledgeId; - private String promptTemplate; - - RetrievalBuilder() { - } - - public RetrievalBuilder knowledgeId(String knowledgeId) { - this.knowledgeId = knowledgeId; - return this; - } - - public RetrievalBuilder promptTemplate(String promptTemplate) { - this.promptTemplate = promptTemplate; - return this; - } - - public Retrieval build() { - return new Retrieval(this); - } - } + private String knowledgeId; + private String promptTemplate; } diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Role.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Role.java index 9c6d327e9c..b27d47facf 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Role.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Role.java @@ -1,11 +1,18 @@ package dev.langchain4j.model.zhipu.chat; -import com.google.gson.annotations.SerializedName; +import com.fasterxml.jackson.annotation.JsonValue; + +import java.util.Locale; public enum Role { - @SerializedName("system") SYSTEM, - @SerializedName("user") USER, - @SerializedName("assistant") ASSISTANT, - @SerializedName("function") FUNCTION, - @SerializedName("tool") TOOL + SYSTEM, + USER, + ASSISTANT, + FUNCTION, + TOOL; + + 
@JsonValue + public String serialize() { + return name().toLowerCase(Locale.ROOT); + } } \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/SystemMessage.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/SystemMessage.java index a4402a0ded..6426b71087 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/SystemMessage.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/SystemMessage.java @@ -1,61 +1,29 @@ package dev.langchain4j.model.zhipu.chat; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; - +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Builder; +import lombok.Data; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; import static dev.langchain4j.model.zhipu.chat.Role.SYSTEM; -@ToString -@EqualsAndHashCode +@Data +@Builder +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class SystemMessage implements Message { - private final Role role = SYSTEM; - @Getter - private final String content; - @Getter - private final String name; - - private SystemMessage(Builder builder) { - this.content = builder.content; - this.name = builder.name; - } + private Role role = SYSTEM; + private String content; + private String name; public static SystemMessage from(String content) { return SystemMessage.builder() .content(content) .build(); } - - public static Builder builder() { - return new Builder(); - } - - @Override - public Role getRole() { - return role; - } - - public static final class Builder { - - private String content; - private String name; - - private Builder() { - } - - public Builder content(String content) { - this.content = content; - return this; - } - - public Builder name(String name) { - this.name = name; - return this; - } - - public SystemMessage build() { - return new SystemMessage(this); - } - } } diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Tool.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Tool.java index 9aaadbf031..a3823c94d5 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Tool.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/Tool.java @@ -1,20 +1,23 @@ package dev.langchain4j.model.zhipu.chat; -import com.google.gson.annotations.SerializedName; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Data; -@Getter -@ToString -@EqualsAndHashCode +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class Tool { - private final ToolType type; - private final Function function; - private final Retrieval retrieval; - @SerializedName("web_search") - private final WebSearch webSearch; + private ToolType type; + 
private Function function; + private Retrieval retrieval; + private WebSearch webSearch; public Tool(Function function) { this.type = ToolType.FUNCTION; diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolCall.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolCall.java index 1980b52de8..b109f1660c 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolCall.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolCall.java @@ -1,60 +1,26 @@ package dev.langchain4j.model.zhipu.chat; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; - -@Getter -@ToString -@EqualsAndHashCode +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class ToolCall { - private final String id; - private final Integer index; - private final ToolType type; - private final FunctionCall function; - - private ToolCall(Builder builder) { - this.id = builder.id; - this.index = builder.index; - this.type = builder.type; - this.function = builder.function; - } - - public static Builder builder() { - return new Builder(); - } - - public static final class Builder { - private String id; - private Integer index; - private ToolType type; - private FunctionCall function; - - private Builder() { - } - - public Builder id(String id) { - this.id = id; - return this; - } - - public Builder index(Integer index) { - this.index = index; - return this; - } - - public Builder type(ToolType type) { - this.type = type; - return this; - } - - public Builder function(FunctionCall function) { - this.function = function; - return this; - } - - public ToolCall build() { - return new ToolCall(this); - } - } + private String id; + private Integer index; + private ToolType type; + private FunctionCall function; } diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolChoice.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolChoice.java index 98e881d7b8..d4e269d689 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolChoice.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolChoice.java @@ -1,16 +1,22 @@ package dev.langchain4j.model.zhipu.chat; -import lombok.EqualsAndHashCode; -import lombok.ToString; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Data; +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; import static dev.langchain4j.model.zhipu.chat.ToolType.FUNCTION; -@ToString -@EqualsAndHashCode -public class ToolChoice { +@Data +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) +public final class 
ToolChoice { private final ToolType type = FUNCTION; - private final Function function; + private Function function; public ToolChoice(String functionName) { this.function = Function.builder().name(functionName).build(); diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolChoiceMode.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolChoiceMode.java index 2f00fb3b0a..8bbde7c917 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolChoiceMode.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolChoiceMode.java @@ -1,8 +1,15 @@ package dev.langchain4j.model.zhipu.chat; -import com.google.gson.annotations.SerializedName; +import com.fasterxml.jackson.annotation.JsonValue; + +import java.util.Locale; public enum ToolChoiceMode { - @SerializedName("none") NONE, - @SerializedName("auto") AUTO + NONE, + AUTO; + + @JsonValue + public String serialize() { + return name().toLowerCase(Locale.ROOT); + } } \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolMessage.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolMessage.java index 960697d92a..0dba62c33f 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolMessage.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolMessage.java @@ -1,27 +1,25 @@ package dev.langchain4j.model.zhipu.chat; -import com.google.gson.annotations.SerializedName; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; - +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Builder; +import lombok.Data; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; import static dev.langchain4j.model.zhipu.chat.Role.TOOL; -@ToString -@EqualsAndHashCode +@Data +@Builder +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class ToolMessage implements Message { private final Role role = TOOL; - @Getter - @SerializedName("tool_call_id") - private final String toolCallId; - @Getter - private final String content; - - private ToolMessage(Builder builder) { - this.toolCallId = builder.toolCallId; - this.content = builder.content; - } + private String toolCallId; + private String content; public static ToolMessage from(String toolCallId, String content) { return ToolMessage.builder() @@ -29,36 +27,4 @@ public static ToolMessage from(String toolCallId, String content) { .content(content) .build(); } - - public static Builder builder() { - return new Builder(); - } - - @Override - public Role getRole() { - return role; - } - - public static final class Builder { - - private String toolCallId; - private String content; - - private Builder() { - } - - public Builder toolCallId(String toolCallId) { - this.toolCallId = toolCallId; - return this; - } - - public Builder content(String content) { - this.content = content; - return this; - } - - public ToolMessage build() { - return new ToolMessage(this); - } - } } \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolType.java 
b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolType.java index 4e271e89c9..743e1cf799 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolType.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/ToolType.java @@ -1,7 +1,14 @@ package dev.langchain4j.model.zhipu.chat; -import com.google.gson.annotations.SerializedName; +import com.fasterxml.jackson.annotation.JsonValue; + +import java.util.Locale; public enum ToolType { - @SerializedName("function") FUNCTION + FUNCTION; + + @JsonValue + public String serialize() { + return name().toLowerCase(Locale.ROOT); + } } \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/UserMessage.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/UserMessage.java index 9fffbb5b29..211a33b7c4 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/UserMessage.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/UserMessage.java @@ -1,63 +1,29 @@ package dev.langchain4j.model.zhipu.chat; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; - +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Builder; +import lombok.Data; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; import static dev.langchain4j.model.zhipu.chat.Role.USER; -@ToString -@EqualsAndHashCode +@Data +@Builder +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class UserMessage implements Message { private final Role role = USER; - @Getter - private final String content; - @Getter - private final String name; - - private UserMessage(Builder builder) { - this.content = builder.content; - this.name = builder.name; - } + private String content; + private String name; public static UserMessage from(String text) { return UserMessage.builder() .content(text) .build(); } - - public static Builder builder() { - return new Builder(); - } - - @Override - public Role getRole() { - return role; - } - - public static final class Builder { - - private String content; - private String name; - - private Builder() { - } - - public Builder content(String content) { - if (content != null) { - this.content = content; - } - return this; - } - - public Builder name(String name) { - this.name = name; - return this; - } - - public UserMessage build() { - return new UserMessage(this); - } - } } \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/WebSearch.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/WebSearch.java index a8a4e295b5..703cbe6d9f 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/WebSearch.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/chat/WebSearch.java @@ -1,46 +1,18 @@ package dev.langchain4j.model.zhipu.chat; -import com.google.gson.annotations.SerializedName; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; - -@Getter -@ToString -@EqualsAndHashCode +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import 
com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Data; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class WebSearch { - private final Boolean enable; - @SerializedName("search_query") - private final String searchQuery; - - public WebSearch(WebSearchBuilder builder) { - this.enable = builder.enable; - this.searchQuery = builder.searchQuery; - } - - public static WebSearchBuilder builder() { - return new WebSearchBuilder(); - } - - public static class WebSearchBuilder { - private Boolean enable; - private String searchQuery; - - WebSearchBuilder() { - } - - public WebSearchBuilder enable(Boolean enable) { - this.enable = enable; - return this; - } - - public WebSearchBuilder searchQuery(String searchQuery) { - this.searchQuery = searchQuery; - return this; - } - - public WebSearch build() { - return new WebSearch(this); - } - } + private Boolean enable; + private String searchQuery; } \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/Embedding.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/Embedding.java index c76b22c4b2..0e0f6d2cbe 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/Embedding.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/Embedding.java @@ -1,60 +1,22 @@ package dev.langchain4j.model.zhipu.embedding; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.Data; import java.util.List; -import static java.util.Collections.unmodifiableList; +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; -@Getter -@ToString -@EqualsAndHashCode +@Data +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class Embedding { - private final List embedding; - private final String object; - private final Integer index; - - private Embedding(Builder builder) { - this.embedding = builder.embedding; - this.object = builder.object; - this.index = builder.index; - } - - public static Builder builder() { - return new Builder(); - } - - public static final class Builder { - - private List embedding; - private String object; - private Integer index; - - private Builder() { - } - - public Builder embedding(List embedding) { - if (embedding != null) { - this.embedding = unmodifiableList(embedding); - } - return this; - } - - public Builder object(String object) { - this.object = object; - return this; - } - - public Builder index(Integer index) { - this.index = index; - return this; - } - - public Embedding build() { - return new Embedding(this); - } - } + private List embedding; + private String object; + private Integer index; } diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/EmbeddingRequest.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/EmbeddingRequest.java index a830c061d1..97b52061cd 100644 --- 
a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/EmbeddingRequest.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/EmbeddingRequest.java @@ -1,75 +1,26 @@ package dev.langchain4j.model.zhipu.embedding; -import lombok.Getter; - -import java.util.Objects; - -@Getter +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static dev.langchain4j.model.zhipu.embedding.EmbeddingModel.EMBEDDING_2; + +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class EmbeddingRequest { - private final String input; - private final String model; - - private EmbeddingRequest(Builder builder) { - this.model = builder.model; - this.input = builder.input; - } - - public static Builder builder() { - return new Builder(); - } - - @Override - public boolean equals(Object another) { - if (this == another) return true; - return another instanceof EmbeddingRequest - && equalTo((EmbeddingRequest) another); - } - - private boolean equalTo(EmbeddingRequest another) { - return Objects.equals(model, another.model) - && Objects.equals(input, another.input); - } - - @Override - public int hashCode() { - int h = 5381; - h += (h << 5) + Objects.hashCode(model); - h += (h << 5) + Objects.hashCode(input); - return h; - } - - @Override - public String toString() { - return "EmbeddingRequest{" - + "model=" + model - + ", input=" + input - + "}"; - } - - public static final class Builder { - - private String model = EmbeddingModel.EMBEDDING_2.toString(); - private String input; - - private Builder() { - } - - public Builder model(EmbeddingModel model) { - return model(model.toString()); - } - - public Builder model(String model) { - this.model = model; - return this; - } - - public Builder input(String input) { - this.input = input; - return this; - } - - public EmbeddingRequest build() { - return new EmbeddingRequest(this); - } - } + private String input; + @Builder.Default + private String model = EMBEDDING_2.toString(); } diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/EmbeddingResponse.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/EmbeddingResponse.java index c46551078d..243330757c 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/EmbeddingResponse.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/embedding/EmbeddingResponse.java @@ -1,33 +1,29 @@ package dev.langchain4j.model.zhipu.embedding; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; import dev.langchain4j.model.zhipu.shared.Usage; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; import java.util.List; -import static 
java.util.Collections.unmodifiableList; +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; -@Getter -@ToString -@EqualsAndHashCode +@Data +@NoArgsConstructor +@AllArgsConstructor +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class EmbeddingResponse { - private final String model; - private final String object; - private final List data; - private final Usage usage; - - private EmbeddingResponse(Builder builder) { - this.model = builder.model; - this.object = builder.object; - this.data = builder.data; - this.usage = builder.usage; - } - - public static Builder builder() { - return new Builder(); - } + private String model; + private String object; + private List data; + private Usage usage; /** * Convenience method to get the embedding from the first data. @@ -35,41 +31,4 @@ public static Builder builder() { public List getEmbedding() { return data.get(0).getEmbedding(); } - - public static final class Builder { - - private String model; - private String object; - private List data; - private Usage usage; - - private Builder() { - } - - public Builder model(String model) { - this.model = model; - return this; - } - - public Builder object(String object) { - this.object = object; - return this; - } - - public Builder data(List data) { - if (data != null) { - this.data = unmodifiableList(data); - } - return this; - } - - public Builder usage(Usage usage) { - this.usage = usage; - return this; - } - - public EmbeddingResponse build() { - return new EmbeddingResponse(this); - } - } } diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/Data.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/Data.java new file mode 100644 index 0000000000..fe8d2bab00 --- /dev/null +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/Data.java @@ -0,0 +1,6 @@ +package dev.langchain4j.model.zhipu.image; + +@lombok.Data +public class Data { + private String url; +} diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/ImageModelName.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/ImageModelName.java new file mode 100644 index 0000000000..0339574404 --- /dev/null +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/ImageModelName.java @@ -0,0 +1,18 @@ +package dev.langchain4j.model.zhipu.image; + +public enum ImageModelName { + + COGVIEW_3("cogview-3"), + ; + + private final String value; + + ImageModelName(String value) { + this.value = value; + } + + @Override + public String toString() { + return this.value; + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiFunctionCall.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/ImageRequest.java similarity index 52% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiFunctionCall.java rename to langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/ImageRequest.java index b50a4ca0d1..d432976973 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiFunctionCall.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/ImageRequest.java @@ -1,4 +1,4 @@ -package dev.langchain4j.model.mistralai; +package dev.langchain4j.model.zhipu.image; import lombok.AllArgsConstructor; import lombok.Builder; @@ -6,12 +6,11 @@ import 
lombok.NoArgsConstructor; @Data +@Builder @NoArgsConstructor @AllArgsConstructor -@Builder -class MistralAiFunctionCall { - - private String name; - private String arguments; - -} +public class ImageRequest { + private String prompt; + private String model; + private String userId; +} \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/ImageResponse.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/ImageResponse.java new file mode 100644 index 0000000000..71e09c82e7 --- /dev/null +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/image/ImageResponse.java @@ -0,0 +1,16 @@ +package dev.langchain4j.model.zhipu.image; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Builder +@lombok.Data +@NoArgsConstructor +@AllArgsConstructor +public class ImageResponse { + private Long created; + private List data; +} diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/shared/ErrorResponse.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/shared/ErrorResponse.java new file mode 100644 index 0000000000..5d7dc6c40b --- /dev/null +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/shared/ErrorResponse.java @@ -0,0 +1,14 @@ +package dev.langchain4j.model.zhipu.shared; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.Map; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class ErrorResponse { + private Map error; +} diff --git a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/shared/Usage.java b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/shared/Usage.java index 75b6e91e06..029e49048f 100644 --- a/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/shared/Usage.java +++ b/langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/shared/Usage.java @@ -1,19 +1,24 @@ package dev.langchain4j.model.zhipu.shared; -import com.google.gson.annotations.SerializedName; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; - -@Getter -@ToString -@EqualsAndHashCode +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +@JsonIgnoreProperties(ignoreUnknown = true) public final class Usage { - @SerializedName("prompt_tokens") private Integer promptTokens; - @SerializedName("completion_tokens") private Integer completionTokens; - @SerializedName("total_tokens") private Integer totalTokens; private Usage(Builder builder) { diff --git a/langchain4j-zhipu-ai/src/test/java/dev/langchain4j/model/zhipu/ZhipuAiChatModelIT.java b/langchain4j-zhipu-ai/src/test/java/dev/langchain4j/model/zhipu/ZhipuAiChatModelIT.java index 910d9b0e31..ffdef3a519 100644 --- a/langchain4j-zhipu-ai/src/test/java/dev/langchain4j/model/zhipu/ZhipuAiChatModelIT.java +++ b/langchain4j-zhipu-ai/src/test/java/dev/langchain4j/model/zhipu/ZhipuAiChatModelIT.java @@ -6,6 +6,7 @@ import 
dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.TestStreamingResponseHandler; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import org.junit.jupiter.api.Test; @@ -16,8 +17,7 @@ import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; import static dev.langchain4j.data.message.ToolExecutionResultMessage.from; import static dev.langchain4j.data.message.UserMessage.userMessage; -import static dev.langchain4j.model.output.FinishReason.STOP; -import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION; +import static dev.langchain4j.model.output.FinishReason.*; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; @@ -55,6 +55,19 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_stop() { assertThat(response.finishReason()).isEqualTo(STOP); } + @Test + void should_sensitive_words_answer() { + // given + UserMessage userMessage = userMessage("fuck you"); + + // when + Response response = chatModel.generate(userMessage); + + assertThat(response.content().text()).isEqualTo("系统检测到输入或生成内容可能包含不安全或敏感内容,请您避免输入易产生敏感内容的提示语,感谢您的配合。"); + + assertThat(response.finishReason()).isEqualTo(CONTENT_FILTER); + } + @Test void should_execute_a_tool_then_answer() { @@ -98,4 +111,52 @@ void should_execute_a_tool_then_answer() { assertThat(secondResponse.finishReason()).isEqualTo(STOP); } + + + ToolSpecification currentTime = ToolSpecification.builder() + .name("currentTime") + .description("currentTime") + .build(); + + @Test + void should_execute_get_current_time_tool_and_then_answer() { + // given + UserMessage userMessage = userMessage("What's the time now?"); + List toolSpecifications = singletonList(currentTime); + + // when + Response response = chatModel.generate(singletonList(userMessage), toolSpecifications); + + // then + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).isNull(); + assertThat(aiMessage.toolExecutionRequests()).hasSize(1); + + ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.name()).isEqualTo("currentTime"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION); + + // given + ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "2024-04-23 12:00:20"); + List messages = asList(userMessage, aiMessage, toolExecutionResultMessage); + + // when + Response secondResponse = chatModel.generate(messages); + + // then + AiMessage secondAiMessage = secondResponse.content(); + assertThat(secondAiMessage.text()).contains("2024-04-23 12:00:20"); + assertThat(secondAiMessage.toolExecutionRequests()).isNull(); + + TokenUsage secondTokenUsage = secondResponse.tokenUsage(); + assertThat(secondTokenUsage.totalTokenCount()) + .isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount()); + + assertThat(secondResponse.finishReason()).isEqualTo(STOP); + } } \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/test/java/dev/langchain4j/model/zhipu/ZhipuAiImageModelIT.java 
b/langchain4j-zhipu-ai/src/test/java/dev/langchain4j/model/zhipu/ZhipuAiImageModelIT.java new file mode 100644 index 0000000000..aaa9fb71a0 --- /dev/null +++ b/langchain4j-zhipu-ai/src/test/java/dev/langchain4j/model/zhipu/ZhipuAiImageModelIT.java @@ -0,0 +1,33 @@ +package dev.langchain4j.model.zhipu; + +import dev.langchain4j.data.image.Image; +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.URI; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "ZHIPU_API_KEY", matches = ".+") +public class ZhipuAiImageModelIT { + private static final Logger log = LoggerFactory.getLogger(ZhipuAiImageModelIT.class); + private static final String apiKey = System.getenv("ZHIPU_API_KEY"); + + private final ZhipuAiImageModel model = ZhipuAiImageModel.builder() + .apiKey(apiKey) + .logRequests(true) + .logResponses(true) + .build(); + + @Test + void simple_image_generation_works() { + Response response = model.generate("Beautiful house on country side"); + + URI remoteImage = response.content().url(); + log.info("Your remote image is here: {}", remoteImage); + assertThat(remoteImage).isNotNull(); + } +} \ No newline at end of file diff --git a/langchain4j-zhipu-ai/src/test/java/dev/langchain4j/model/zhipu/ZhipuAiStreamingChatModelIT.java b/langchain4j-zhipu-ai/src/test/java/dev/langchain4j/model/zhipu/ZhipuAiStreamingChatModelIT.java index ab68296da9..4418f8441c 100644 --- a/langchain4j-zhipu-ai/src/test/java/dev/langchain4j/model/zhipu/ZhipuAiStreamingChatModelIT.java +++ b/langchain4j-zhipu-ai/src/test/java/dev/langchain4j/model/zhipu/ZhipuAiStreamingChatModelIT.java @@ -17,8 +17,7 @@ import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; import static dev.langchain4j.data.message.ToolExecutionResultMessage.from; import static dev.langchain4j.data.message.UserMessage.userMessage; -import static dev.langchain4j.model.output.FinishReason.STOP; -import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION; +import static dev.langchain4j.model.output.FinishReason.*; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; @@ -27,7 +26,7 @@ public class ZhipuAiStreamingChatModelIT { private static final String apiKey = System.getenv("ZHIPU_API_KEY"); - private ZhipuAiStreamingChatModel model = ZhipuAiStreamingChatModel.builder() + private final ZhipuAiStreamingChatModel model = ZhipuAiStreamingChatModel.builder() .apiKey(apiKey) .logRequests(true) .logResponses(true) @@ -58,6 +57,19 @@ void should_stream_answer() { assertThat(response.finishReason()).isEqualTo(STOP); } + @Test + void should_sensitive_words_stream_answer() { + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + + model.generate("fuck you", handler); + + Response response = handler.get(); + + assertThat(response.content().text()).isBlank(); + + assertThat(response.finishReason()).isEqualTo(CONTENT_FILTER); + } + @Test void should_execute_a_tool_then_stream_answer() { @@ -112,4 +124,56 @@ void should_execute_a_tool_then_stream_answer() { assertThat(secondResponse.finishReason()).isEqualTo(STOP); } + + + ToolSpecification currentTime = ToolSpecification.builder() + .name("currentTime") + .description("currentTime") + .build(); + + @Test + void 
should_execute_get_current_time_tool_and_then_answer() { + // given + UserMessage userMessage = userMessage("What's the time now?"); + List toolSpecifications = singletonList(currentTime); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(singletonList(userMessage), toolSpecifications, handler); + + // then + Response response = handler.get(); + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).isNull(); + assertThat(aiMessage.toolExecutionRequests()).hasSize(1); + + ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.name()).isEqualTo("currentTime"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION); + + // given + ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "2024-04-23 12:00:20"); + List messages = asList(userMessage, aiMessage, toolExecutionResultMessage); + + // when + TestStreamingResponseHandler secondHandler = new TestStreamingResponseHandler<>(); + model.generate(messages, secondHandler); + + // then + Response secondResponse = secondHandler.get(); + AiMessage secondAiMessage = secondResponse.content(); + assertThat(secondAiMessage.text()).contains("2024-04-23 12:00:20"); + assertThat(secondAiMessage.toolExecutionRequests()).isNull(); + + TokenUsage secondTokenUsage = secondResponse.tokenUsage(); + assertThat(secondTokenUsage.totalTokenCount()) + .isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount()); + + assertThat(secondResponse.finishReason()).isEqualTo(STOP); + } } diff --git a/langchain4j/pom.xml b/langchain4j/pom.xml index 1b06c32e1b..6878b0440d 100644 --- a/langchain4j/pom.xml +++ b/langchain4j/pom.xml @@ -7,7 +7,7 @@ dev.langchain4j langchain4j-parent - 0.30.0 + 0.32.0-SNAPSHOT ../langchain4j-parent/pom.xml @@ -22,15 +22,6 @@ langchain4j-core - - com.squareup.retrofit2 - retrofit - - - com.squareup.okhttp3 - okhttp - - org.apache.opennlp opennlp-tools diff --git a/langchain4j/src/main/java/dev/langchain4j/chain/ConversationalRetrievalChain.java b/langchain4j/src/main/java/dev/langchain4j/chain/ConversationalRetrievalChain.java index 41da560bdc..2f138c3c24 100644 --- a/langchain4j/src/main/java/dev/langchain4j/chain/ConversationalRetrievalChain.java +++ b/langchain4j/src/main/java/dev/langchain4j/chain/ConversationalRetrievalChain.java @@ -7,9 +7,12 @@ import dev.langchain4j.memory.chat.MessageWindowChatMemory; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.input.PromptTemplate; -import dev.langchain4j.rag.*; -import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.rag.AugmentationRequest; +import dev.langchain4j.rag.AugmentationResult; +import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.content.injector.DefaultContentInjector; +import dev.langchain4j.rag.content.retriever.ContentRetriever; import dev.langchain4j.rag.query.Metadata; import dev.langchain4j.retriever.Retriever; import dev.langchain4j.service.AiServices; @@ -76,15 +79,25 @@ public ConversationalRetrievalChain(ChatLanguageModel chatLanguageModel, public String execute(String query) { UserMessage userMessage = UserMessage.from(query); - Metadata metadata = 
Metadata.from(userMessage, chatMemory.id(), chatMemory.messages()); - userMessage = retrievalAugmentor.augment(userMessage, metadata); + userMessage = augment(userMessage); chatMemory.add(userMessage); AiMessage aiMessage = chatLanguageModel.generate(chatMemory.messages()).content(); + chatMemory.add(aiMessage); return aiMessage.text(); } + private UserMessage augment(UserMessage userMessage) { + Metadata metadata = Metadata.from(userMessage, chatMemory.id(), chatMemory.messages()); + + AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata); + + AugmentationResult augmentationResult = retrievalAugmentor.augment(augmentationRequest); + + return (UserMessage) augmentationResult.chatMessage(); + } + public static Builder builder() { return new Builder(); } diff --git a/langchain4j/src/main/java/dev/langchain4j/data/document/source/FileSystemSource.java b/langchain4j/src/main/java/dev/langchain4j/data/document/source/FileSystemSource.java index 5cd501cdd4..49fb88f025 100644 --- a/langchain4j/src/main/java/dev/langchain4j/data/document/source/FileSystemSource.java +++ b/langchain4j/src/main/java/dev/langchain4j/data/document/source/FileSystemSource.java @@ -31,8 +31,8 @@ public InputStream inputStream() throws IOException { @Override public Metadata metadata() { return new Metadata() - .add(FILE_NAME, path.getFileName().toString()) - .add(ABSOLUTE_DIRECTORY_PATH, path.toAbsolutePath().getParent().toString()); + .put(FILE_NAME, path.getFileName().toString()) + .put(ABSOLUTE_DIRECTORY_PATH, path.toAbsolutePath().getParent().toString()); } public static FileSystemSource from(Path filePath) { diff --git a/langchain4j/src/main/java/dev/langchain4j/data/document/splitter/HierarchicalDocumentSplitter.java b/langchain4j/src/main/java/dev/langchain4j/data/document/splitter/HierarchicalDocumentSplitter.java index 34436970f4..faadf60f12 100644 --- a/langchain4j/src/main/java/dev/langchain4j/data/document/splitter/HierarchicalDocumentSplitter.java +++ b/langchain4j/src/main/java/dev/langchain4j/data/document/splitter/HierarchicalDocumentSplitter.java @@ -227,7 +227,7 @@ int estimateSize(String text) { * @param index The index of the segment within the document. 
*/ static TextSegment createSegment(String text, Document document, int index) { - Metadata metadata = document.metadata().copy().add(INDEX, String.valueOf(index)); + Metadata metadata = document.metadata().copy().put(INDEX, String.valueOf(index)); return TextSegment.from(text, metadata); } } diff --git a/langchain4j/src/main/java/dev/langchain4j/data/document/splitter/SegmentBuilder.java b/langchain4j/src/main/java/dev/langchain4j/data/document/splitter/SegmentBuilder.java index dc658277e4..e8deadebf8 100644 --- a/langchain4j/src/main/java/dev/langchain4j/data/document/splitter/SegmentBuilder.java +++ b/langchain4j/src/main/java/dev/langchain4j/data/document/splitter/SegmentBuilder.java @@ -26,7 +26,7 @@ public SegmentBuilder(int maxSegmentSize, Function sizeFunction this.maxSegmentSize = ensureGreaterThanZero(maxSegmentSize, "maxSegmentSize"); this.sizeFunction = ensureNotNull(sizeFunction, "sizeFunction"); this.joinSeparator = ensureNotNull(joinSeparator, "joinSeparator"); - joinSeparatorSize = sizeOf(joinSeparator); + this.joinSeparatorSize = sizeOf(joinSeparator); } /** @@ -117,4 +117,4 @@ public void reset() { segment = ""; segmentSize = 0; } -} \ No newline at end of file +} diff --git a/langchain4j/src/main/java/dev/langchain4j/data/document/transformer/HtmlTextExtractor.java b/langchain4j/src/main/java/dev/langchain4j/data/document/transformer/HtmlTextExtractor.java index 0106815a09..206b147579 100644 --- a/langchain4j/src/main/java/dev/langchain4j/data/document/transformer/HtmlTextExtractor.java +++ b/langchain4j/src/main/java/dev/langchain4j/data/document/transformer/HtmlTextExtractor.java @@ -10,7 +10,10 @@ import org.jsoup.select.NodeVisitor; import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import static dev.langchain4j.data.document.Document.URL; import static java.lang.String.format; import static java.util.stream.Collectors.joining; import static org.jsoup.internal.StringUtil.in; @@ -23,6 +26,8 @@ */ public class HtmlTextExtractor implements DocumentTransformer { + private static final Logger log = LoggerFactory.getLogger(HtmlTextExtractor.class); + private final String cssSelector; private final Map metadataCssSelectors; private final boolean includeLinks; @@ -53,7 +58,8 @@ public HtmlTextExtractor(String cssSelector, Map metadataCssSele @Override public Document transform(Document document) { String html = document.text(); - org.jsoup.nodes.Document jsoupDocument = Jsoup.parse(html); + String baseUrl = document.metadata(URL) != null ? 
document.metadata(URL) : ""; + org.jsoup.nodes.Document jsoupDocument = Jsoup.parse(html, baseUrl); String text; if (cssSelector != null) { @@ -65,7 +71,7 @@ public Document transform(Document document) { Metadata metadata = document.metadata().copy(); if (metadataCssSelectors != null) { metadataCssSelectors.forEach((metadataKey, cssSelector) -> - metadata.add(metadataKey, jsoupDocument.select(cssSelector).text())); + metadata.put(metadataKey, jsoupDocument.select(cssSelector).text())); } return Document.from(text, metadata); @@ -111,8 +117,13 @@ public void tail(Node node, int depth) { // hit when all the node's children (if String name = node.nodeName(); if (in(name, "br", "dd", "dt", "p", "h1", "h2", "h3", "h4", "h5", "h6")) textBuilder.append("\n"); - else if (includeLinks && name.equals("a")) - textBuilder.append(format(" <%s>", node.absUrl("href"))); + else if (includeLinks && name.equals("a")) { + String link = node.absUrl("href"); + if (link.isEmpty() && node.baseUri().isEmpty()) { + log.warn("No 'URL' metadata found for document. Link will be empty"); + } + textBuilder.append(format(" <%s>", link)); + } } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/BigDecimalOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/BigDecimalOutputParser.java index 0fa60a31a2..3507d6b388 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/BigDecimalOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/BigDecimalOutputParser.java @@ -6,7 +6,7 @@ public class BigDecimalOutputParser implements OutputParser { @Override public BigDecimal parse(String string) { - return new BigDecimal(string); + return new BigDecimal(string.trim()); } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/BigIntegerOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/BigIntegerOutputParser.java index f6f92fa8f7..feb3c2466a 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/BigIntegerOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/BigIntegerOutputParser.java @@ -6,7 +6,7 @@ public class BigIntegerOutputParser implements OutputParser { @Override public BigInteger parse(String string) { - return new BigInteger(string); + return new BigInteger(string.trim()); } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/BooleanOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/BooleanOutputParser.java index e93654dfb1..8858c371a3 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/BooleanOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/BooleanOutputParser.java @@ -4,7 +4,7 @@ public class BooleanOutputParser implements OutputParser { @Override public Boolean parse(String string) { - return Boolean.parseBoolean(string); + return Boolean.parseBoolean(string.trim()); } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/ByteOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/ByteOutputParser.java index 37978eff31..7bc538ca10 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/ByteOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/ByteOutputParser.java @@ -4,7 +4,7 @@ public class ByteOutputParser implements OutputParser { @Override public Byte parse(String string) { - return Byte.parseByte(string); + return Byte.parseByte(string.trim()); } @Override diff --git 
a/langchain4j/src/main/java/dev/langchain4j/model/output/DoubleOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/DoubleOutputParser.java index f54640b221..d1de3f6d85 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/DoubleOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/DoubleOutputParser.java @@ -4,7 +4,7 @@ public class DoubleOutputParser implements OutputParser { @Override public Double parse(String string) { - return Double.parseDouble(string); + return Double.parseDouble(string.trim()); } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/EnumOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/EnumOutputParser.java index 65bbab8fe1..1b8c9b2d6b 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/EnumOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/EnumOutputParser.java @@ -13,6 +13,7 @@ public EnumOutputParser(Class enumClass) { @Override public Enum parse(String string) { + string = string.trim(); for (Enum enumConstant : enumClass.getEnumConstants()) { if (enumConstant.name().equalsIgnoreCase(string)) { return enumConstant; diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/FloatOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/FloatOutputParser.java index 48f59591cf..10cb033f86 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/FloatOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/FloatOutputParser.java @@ -4,7 +4,7 @@ public class FloatOutputParser implements OutputParser { @Override public Float parse(String string) { - return Float.parseFloat(string); + return Float.parseFloat(string.trim()); } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/IntOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/IntOutputParser.java index ab14f21983..3b466afb3f 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/IntOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/IntOutputParser.java @@ -4,7 +4,7 @@ public class IntOutputParser implements OutputParser { @Override public Integer parse(String string) { - return Integer.parseInt(string); + return Integer.parseInt(string.trim()); } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/LocalDateOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/LocalDateOutputParser.java index 9ef28b6321..acdb33cd93 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/LocalDateOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/LocalDateOutputParser.java @@ -8,7 +8,7 @@ public class LocalDateOutputParser implements OutputParser { @Override public LocalDate parse(String string) { - return LocalDate.parse(string, ISO_LOCAL_DATE); + return LocalDate.parse(string.trim(), ISO_LOCAL_DATE); } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/LocalDateTimeOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/LocalDateTimeOutputParser.java index 4f8dcf8433..c545ff22b0 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/LocalDateTimeOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/LocalDateTimeOutputParser.java @@ -8,7 +8,7 @@ public class LocalDateTimeOutputParser implements OutputParser { @Override public LocalDateTime parse(String string) { - 
return LocalDateTime.parse(string, ISO_LOCAL_DATE_TIME); + return LocalDateTime.parse(string.trim(), ISO_LOCAL_DATE_TIME); } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/LocalTimeOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/LocalTimeOutputParser.java index d5b2b668cf..7e93c2ae71 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/LocalTimeOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/LocalTimeOutputParser.java @@ -8,7 +8,7 @@ public class LocalTimeOutputParser implements OutputParser { @Override public LocalTime parse(String string) { - return LocalTime.parse(string, ISO_LOCAL_TIME); + return LocalTime.parse(string.trim(), ISO_LOCAL_TIME); } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/LongOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/LongOutputParser.java index d4a792f541..8684f52b20 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/LongOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/LongOutputParser.java @@ -4,7 +4,7 @@ public class LongOutputParser implements OutputParser { @Override public Long parse(String string) { - return Long.parseLong(string); + return Long.parseLong(string.trim()); } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/model/output/ShortOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/model/output/ShortOutputParser.java index b9b251e993..b7df16a6a4 100644 --- a/langchain4j/src/main/java/dev/langchain4j/model/output/ShortOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/model/output/ShortOutputParser.java @@ -4,7 +4,7 @@ public class ShortOutputParser implements OutputParser { @Override public Short parse(String string) { - return Short.parseShort(string); + return Short.parseShort(string.trim()); } @Override diff --git a/langchain4j/src/main/java/dev/langchain4j/service/AiServiceStreamingResponseHandler.java b/langchain4j/src/main/java/dev/langchain4j/service/AiServiceStreamingResponseHandler.java index 64e3c81840..b9eb74fa28 100644 --- a/langchain4j/src/main/java/dev/langchain4j/service/AiServiceStreamingResponseHandler.java +++ b/langchain4j/src/main/java/dev/langchain4j/service/AiServiceStreamingResponseHandler.java @@ -10,7 +10,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.List; import java.util.function.Consumer; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; @@ -82,14 +81,14 @@ public void onComplete(Response response) { tokenHandler, completionHandler, errorHandler, - tokenUsage.add(response.tokenUsage()) + TokenUsage.sum(tokenUsage, response.tokenUsage()) ) ); } else { if (completionHandler != null) { completionHandler.accept(Response.from( aiMessage, - tokenUsage.add(response.tokenUsage()), + TokenUsage.sum(tokenUsage, response.tokenUsage()), response.finishReason()) ); } diff --git a/langchain4j/src/main/java/dev/langchain4j/service/AiServices.java b/langchain4j/src/main/java/dev/langchain4j/service/AiServices.java index d9634c8b99..6d21df60ec 100644 --- a/langchain4j/src/main/java/dev/langchain4j/service/AiServices.java +++ b/langchain4j/src/main/java/dev/langchain4j/service/AiServices.java @@ -14,14 +14,18 @@ import dev.langchain4j.model.input.structured.StructuredPrompt; import dev.langchain4j.model.moderation.Moderation; import dev.langchain4j.model.moderation.ModerationModel; +import dev.langchain4j.model.output.TokenUsage; import 
dev.langchain4j.rag.DefaultRetrievalAugmentor; import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.content.Content; import dev.langchain4j.rag.content.retriever.ContentRetriever; import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever; import dev.langchain4j.retriever.Retriever; import dev.langchain4j.spi.services.AiServicesFactory; +import java.lang.reflect.AnnotatedType; import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; @@ -30,6 +34,7 @@ import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom; import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration; +import static dev.langchain4j.internal.Exceptions.illegalArgument; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.util.stream.Collectors.toList; @@ -71,12 +76,13 @@ * *

      * The return type of methods in your AI Service can be any of the following:
    - * - a {@link String}, an {@link AiMessage} or a {@code Response<AiMessage>}, if you want to get the answer from the LLM as-is
    + * - a {@link String} or an {@link AiMessage}, if you want to get the answer from the LLM as-is
      * - a {@code List<String>} or {@code Set<String>}, if you want to receive the answer as a collection of items or bullet points
      * - any {@link Enum} or a {@code boolean}, if you want to use the LLM for classification
      * - a primitive or boxed Java type: {@code int}, {@code Double}, etc., if you want to use the LLM for data extraction
      * - many default Java types: {@code Date}, {@code LocalDateTime}, {@code BigDecimal}, etc., if you want to use the LLM for data extraction
      * - any custom POJO, if you want to use the LLM for data extraction.
    + * - {@code Result<T>} if you want to access {@link TokenUsage} or sources ({@link Content}s retrieved during RAG) in addition to T, which can be any of the types listed above. For example: {@code Result<String>}, {@code Result<MyCustomPojo>}
      * For POJOs, it is advisable to use the "json mode" feature if the LLM provider supports it. For OpenAI, this can be enabled by calling {@code responseFormat("json_object")} during model construction.
      *
      * 
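To make the Result<T> return type documented above concrete, here is a minimal usage sketch. It is illustrative only and not part of this diff: the Assistant interface and its chat method are hypothetical, while AiServices.create, Result.content(), Result.tokenUsage() and Result.sources() come from the existing API and the Result class introduced in this PR.

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.Result;

import java.util.List;

// Hypothetical AI Service interface; the parameterized Result<String> return type
// is exactly what the new validateResultReturnType(..) check requires.
interface Assistant {
    Result<String> chat(String userMessage);
}

class ResultUsageSketch {

    static void demo(ChatLanguageModel chatLanguageModel) {
        Assistant assistant = AiServices.create(Assistant.class, chatLanguageModel);

        Result<String> result = assistant.chat("What is LangChain4j?");

        String answer = result.content();            // the parsed answer (T)
        TokenUsage tokenUsage = result.tokenUsage(); // token usage accumulated across tool-calling rounds
        List<Content> sources = result.sources();    // contents retrieved during RAG; null when no retrieval augmentor is configured
    }
}

As the DefaultAiServices changes later in this diff show, when the declared return type is Result, the proxy parses the LLM response into the type argument T and wraps it together with the accumulated TokenUsage and any retrieved Contents.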
    @@ -316,6 +322,10 @@ public AiServices tools(List objectsWithTools) { // TODO Collection? context.toolExecutors = new HashMap<>(); for (Object objectWithTool : objectsWithTools) { + if (objectWithTool instanceof Class) { + throw illegalConfiguration("Tool '%s' must be an object, not a class", objectWithTool); + } + for (Method method : objectWithTool.getClass().getDeclaredMethods()) { if (method.isAnnotationPresent(Tool.class)) { ToolSpecification toolSpecification = toolSpecificationFrom(method); @@ -404,6 +414,14 @@ protected void performBasicValidation() { } } + protected void validateResultReturnType(Method method) { + AnnotatedType annotatedType = method.getAnnotatedReturnType(); + if (!(annotatedType.getType() instanceof ParameterizedType)) { + throw illegalArgument("The return type 'Result' of the method '%s' must be parameterized with a type, " + + "for example: Result or Result", method.getName()); + } + } + public static List removeToolMessages(List messages) { return messages.stream() .filter(it -> !(it instanceof ToolExecutionResultMessage)) diff --git a/langchain4j/src/main/java/dev/langchain4j/service/DefaultAiServices.java b/langchain4j/src/main/java/dev/langchain4j/service/DefaultAiServices.java index 119aac01b5..33823a4ed0 100644 --- a/langchain4j/src/main/java/dev/langchain4j/service/DefaultAiServices.java +++ b/langchain4j/src/main/java/dev/langchain4j/service/DefaultAiServices.java @@ -13,6 +13,8 @@ import dev.langchain4j.model.moderation.Moderation; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.rag.AugmentationRequest; +import dev.langchain4j.rag.AugmentationResult; import dev.langchain4j.rag.query.Metadata; import java.io.InputStream; @@ -25,6 +27,7 @@ import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration; import static dev.langchain4j.internal.Exceptions.illegalArgument; import static dev.langchain4j.internal.Exceptions.runtime; +import static dev.langchain4j.internal.Utils.isNotNullOrBlank; import static dev.langchain4j.service.ServiceOutputParser.outputFormatInstructions; import static dev.langchain4j.service.ServiceOutputParser.parse; @@ -65,6 +68,9 @@ public T build() { throw illegalConfiguration("The @Moderate annotation is present, but the moderationModel is not set up. " + "Please ensure a valid moderationModel is configured before using the @Moderate annotation."); } + if (method.getReturnType() == Result.class) { + validateResultReturnType(method); + } } Object proxyInstance = Proxy.newProxyInstance( @@ -88,18 +94,36 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio Optional systemMessage = prepareSystemMessage(memoryId, method, args); UserMessage userMessage = prepareUserMessage(method, args); - + AugmentationResult augmentationResult = null; if (context.retrievalAugmentor != null) { List chatMemory = context.hasChatMemory() ? 
context.chatMemory(memoryId).messages() : null; Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory); - userMessage = context.retrievalAugmentor.augment(userMessage, metadata); + AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata); + augmentationResult = context.retrievalAugmentor.augment(augmentationRequest); + userMessage = (UserMessage) augmentationResult.chatMessage(); } // TODO give user ability to provide custom OutputParser - String outputFormatInstructions = outputFormatInstructions(method.getReturnType()); - userMessage = UserMessage.from(userMessage.text() + outputFormatInstructions); + Class returnType = method.getReturnType(); + boolean isReturnTypeResult = false; + if (returnType == Result.class) { + isReturnTypeResult = true; + AnnotatedType annotatedReturnType = method.getAnnotatedReturnType(); + ParameterizedType type = (ParameterizedType) annotatedReturnType.getType(); + Type[] typeArguments = type.getActualTypeArguments(); + for (Type typeArg : typeArguments) { + returnType = Class.forName(typeArg.getTypeName()); + } + } + String outputFormatInstructions = outputFormatInstructions(returnType); + String text = userMessage.singleText() + outputFormatInstructions; + if (isNotNullOrBlank(userMessage.name())) { + userMessage = UserMessage.from(userMessage.name(), text); + } else { + userMessage = UserMessage.from(text); + } if (context.hasChatMemory()) { ChatMemory chatMemory = context.chatMemory(memoryId); @@ -118,7 +142,7 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio Future moderationFuture = triggerModerationIfNeeded(method, messages); - if (method.getReturnType() == TokenStream.class) { + if (returnType == TokenStream.class) { return new AiServiceTokenStream(messages, context, memoryId); // TODO moderation } @@ -169,11 +193,21 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio } response = context.chatModel.generate(messages, context.toolSpecifications); - tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage()); + tokenUsageAccumulator = TokenUsage.sum(tokenUsageAccumulator, response.tokenUsage()); } response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason()); - return parse(response, method.getReturnType()); + Object parsedResponse = parse(response, returnType); + + if (isReturnTypeResult) { + return Result.builder() + .content(parsedResponse) + .tokenUsage(tokenUsageAccumulator) + .sources(augmentationResult == null ? null : augmentationResult.contents()) + .build(); + } else { + return parsedResponse; + } } private Future triggerModerationIfNeeded(Method method, List messages) { diff --git a/langchain4j/src/main/java/dev/langchain4j/service/Result.java b/langchain4j/src/main/java/dev/langchain4j/service/Result.java new file mode 100644 index 0000000000..948f3196ba --- /dev/null +++ b/langchain4j/src/main/java/dev/langchain4j/service/Result.java @@ -0,0 +1,44 @@ +package dev.langchain4j.service; + +import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.rag.content.Content; +import lombok.Builder; + +import java.util.List; + +import static dev.langchain4j.internal.Utils.copyIfNotNull; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * Represents the result of an AI Service invocation. 
+ * It contains actual content (LLM response) and additional information associated with it, + * such as {@link TokenUsage} and sources ({@link Content}s retrieved during RAG). + * + * @param The type of the content. Can be of any return type supported by AI Services, + * such as String, Enum, MyCustomPojo, etc. + */ +public class Result { + + private final T content; + private final TokenUsage tokenUsage; + private final List sources; + + @Builder + public Result(T content, TokenUsage tokenUsage, List sources) { + this.content = ensureNotNull(content, "content"); + this.tokenUsage = ensureNotNull(tokenUsage, "tokenUsage"); + this.sources = copyIfNotNull(sources); + } + + public T content() { + return content; + } + + public TokenUsage tokenUsage() { + return tokenUsage; + } + + public List sources() { + return sources; + } +} diff --git a/langchain4j/src/main/java/dev/langchain4j/service/ServiceOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/service/ServiceOutputParser.java index 02184d97ac..bc0ffb1605 100644 --- a/langchain4j/src/main/java/dev/langchain4j/service/ServiceOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/service/ServiceOutputParser.java @@ -118,7 +118,7 @@ public static String outputFormatInstructions(Class returnType) { return "\nYou must answer strictly in the following JSON format: " + jsonStructure(returnType, new HashSet<>()); } - private static String jsonStructure(Class structured, Set> visited) { + public static String jsonStructure(Class structured, Set> visited) { StringBuilder jsonSchema = new StringBuilder(); jsonSchema.append("{\n"); diff --git a/langchain4j/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java b/langchain4j/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java index 9c7b8f438f..b26b38186e 100644 --- a/langchain4j/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java +++ b/langchain4j/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java @@ -16,8 +16,7 @@ import java.util.stream.IntStream; import static dev.langchain4j.internal.Utils.randomUUID; -import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; -import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static dev.langchain4j.internal.ValidationUtils.*; import static dev.langchain4j.spi.ServiceHelper.loadFactories; import static java.nio.file.StandardOpenOption.CREATE; import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; @@ -95,6 +94,33 @@ private List add(List> newEntries) { .collect(toList()); } + @Override + public void removeAll(Collection ids) { + ensureNotEmpty(ids, "ids"); + + entries.removeIf(entry -> ids.contains(entry.id)); + } + + @Override + public void removeAll(Filter filter) { + ensureNotNull(filter, "filter"); + + entries.removeIf(entry -> { + if (entry.embedded instanceof TextSegment) { + return filter.test(((TextSegment) entry.embedded).metadata()); + } else if (entry.embedded == null) { + return false; + } else { + throw new UnsupportedOperationException("Not supported yet."); + } + }); + } + + @Override + public void removeAll() { + entries.clear(); + } + @Override public EmbeddingSearchResult search(EmbeddingSearchRequest embeddingSearchRequest) { diff --git a/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByParagraphSplitterTest.java b/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByParagraphSplitterTest.java index 
e20c4f7b62..2a2789b200 100644 --- a/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByParagraphSplitterTest.java +++ b/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByParagraphSplitterTest.java @@ -53,8 +53,8 @@ void should_split_into_segments_with_one_paragraph_per_segment(String separator) segments.forEach(segment -> assertThat(segment.text().length()).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment(firstParagraph, metadata("index", "0").add("document", "0")), - textSegment(secondParagraph, metadata("index", "1").add("document", "0")) + textSegment(firstParagraph, metadata("index", "0").put("document", "0")), + textSegment(secondParagraph, metadata("index", "1").put("document", "0")) ); } @@ -93,8 +93,8 @@ void should_split_into_segments_with_multiple_paragraphs_per_segment(String sepa segments.forEach(segment -> assertThat(segment.text().length()).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment(firstParagraph + "\n\n" + secondParagraph, metadata("index", "0").add("document", "0")), - textSegment(thirdParagraph, metadata("index", "1").add("document", "0")) + textSegment(firstParagraph + "\n\n" + secondParagraph, metadata("index", "0").put("document", "0")), + textSegment(thirdParagraph, metadata("index", "1").put("document", "0")) ); } @@ -138,10 +138,10 @@ void should_split_paragraph_into_sentences_if_it_does_not_fit_into_segment(Strin segments.forEach(segment -> assertThat(segment.text().length()).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment(firstParagraph, metadata("index", "0").add("document", "0")), - textSegment(firstSentenceOfSecondParagraph, metadata("index", "1").add("document", "0")), - textSegment(secondSentenceOfSecondParagraph, metadata("index", "2").add("document", "0")), - textSegment(thirdParagraph, metadata("index", "3").add("document", "0")) + textSegment(firstParagraph, metadata("index", "0").put("document", "0")), + textSegment(firstSentenceOfSecondParagraph, metadata("index", "1").put("document", "0")), + textSegment(secondSentenceOfSecondParagraph, metadata("index", "2").put("document", "0")), + textSegment(thirdParagraph, metadata("index", "3").put("document", "0")) ); } @@ -210,14 +210,14 @@ void should_split_sample_text_containing_multiple_paragraphs() { segments.forEach(segment -> assertThat(tokenizer.estimateTokenCountInText(segment.text())).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment(p1, metadata("index", "0").add("document", "0")), - textSegment(p2p1, metadata("index", "1").add("document", "0")), - textSegment(p2p2, metadata("index", "2").add("document", "0")), - textSegment(p3, metadata("index", "3").add("document", "0")), - textSegment(p4p1, metadata("index", "4").add("document", "0")), - textSegment(p4p2, metadata("index", "5").add("document", "0")), - textSegment(p5 + "\n\n" + p6, metadata("index", "6").add("document", "0")), - textSegment(p7, metadata("index", "7").add("document", "0")) + textSegment(p1, metadata("index", "0").put("document", "0")), + textSegment(p2p1, metadata("index", "1").put("document", "0")), + textSegment(p2p2, metadata("index", "2").put("document", "0")), + textSegment(p3, metadata("index", "3").put("document", "0")), + textSegment(p4p1, metadata("index", "4").put("document", "0")), + textSegment(p4p2, metadata("index", "5").put("document", "0")), + textSegment(p5 + "\n\n" + p6, metadata("index", 
"6").put("document", "0")), + textSegment(p7, metadata("index", "7").put("document", "0")) ); } @@ -280,14 +280,14 @@ void should_split_sample_text_containing_multiple_paragraphs_with_overlap() { segments.forEach(segment -> assertThat(tokenizer.estimateTokenCountInText(segment.text())).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment(format("%s %s %s %s", s1, s2, s3, s4), metadata("index", "0").add("document", "0")), - textSegment(format("%s %s %s %s", s5, s6, s7, s8), metadata("index", "1").add("document", "0")), - textSegment(format("%s %s %s %s", s8, s9, s10, s11), metadata("index", "2").add("document", "0")), - textSegment(format("%s\n\n%s %s %s %s", s11, s12, s13, s14, s15), metadata("index", "3").add("document", "0")), - textSegment(format("%s %s %s %s", s15, s16, s17, s18), metadata("index", "4").add("document", "0")), - textSegment(format("%s %s %s %s %s %s", s19, s20, s21, s22, s23, s24), metadata("index", "5").add("document", "0")), - textSegment(format("%s %s %s %s %s %s", s22, s23, s24, s25, s26, s27), metadata("index", "6").add("document", "0")), - textSegment(format("%s %s %s", s27, s28, s29), metadata("index", "7").add("document", "0")) + textSegment(format("%s %s %s %s", s1, s2, s3, s4), metadata("index", "0").put("document", "0")), + textSegment(format("%s %s %s %s", s5, s6, s7, s8), metadata("index", "1").put("document", "0")), + textSegment(format("%s %s %s %s", s8, s9, s10, s11), metadata("index", "2").put("document", "0")), + textSegment(format("%s\n\n%s %s %s %s", s11, s12, s13, s14, s15), metadata("index", "3").put("document", "0")), + textSegment(format("%s %s %s %s", s15, s16, s17, s18), metadata("index", "4").put("document", "0")), + textSegment(format("%s %s %s %s %s %s", s19, s20, s21, s22, s23, s24), metadata("index", "5").put("document", "0")), + textSegment(format("%s %s %s %s %s %s", s22, s23, s24, s25, s26, s27), metadata("index", "6").put("document", "0")), + textSegment(format("%s %s %s", s27, s28, s29), metadata("index", "7").put("document", "0")) ); } @@ -345,10 +345,10 @@ void should_split_sample_text_without_paragraphs() { segments.forEach(segment -> assertThat(tokenizer.estimateTokenCountInText(segment.text())).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment(segment1, metadata("index", "0").add("document", "0")), - textSegment(segment2, metadata("index", "1").add("document", "0")), - textSegment(segment3, metadata("index", "2").add("document", "0")), - textSegment(segment4, metadata("index", "3").add("document", "0")) + textSegment(segment1, metadata("index", "0").put("document", "0")), + textSegment(segment2, metadata("index", "1").put("document", "0")), + textSegment(segment3, metadata("index", "2").put("document", "0")), + textSegment(segment4, metadata("index", "3").put("document", "0")) ); } @@ -371,11 +371,11 @@ void should_split_sample_text_without_paragraphs_with_small_overlap() { assertThat(tokenizer.estimateTokenCountInText(segment.text())).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - TextSegment.from(sentences(0, 5), Metadata.from("index", "0").add("document", "0")), - TextSegment.from(sentences(5, 12), Metadata.from("index", "1").add("document", "0")), - TextSegment.from(sentences(10, 16), Metadata.from("index", "2").add("document", "0")), - TextSegment.from(sentences(15, 24), Metadata.from("index", "3").add("document", "0")), - TextSegment.from(sentences(21, 28), Metadata.from("index", "4").add("document", "0")) + 
TextSegment.from(sentences(0, 5), Metadata.from("index", "0").put("document", "0")), + TextSegment.from(sentences(5, 12), Metadata.from("index", "1").put("document", "0")), + TextSegment.from(sentences(10, 16), Metadata.from("index", "2").put("document", "0")), + TextSegment.from(sentences(15, 24), Metadata.from("index", "3").put("document", "0")), + TextSegment.from(sentences(21, 28), Metadata.from("index", "4").put("document", "0")) ); assertThat(tokenizer.estimateTokenCountInText(sentences(5, 5))).isLessThanOrEqualTo(maxOverlapSize); @@ -403,24 +403,24 @@ void should_split_sample_text_without_paragraphs_with_big_overlap() { assertThat(tokenizer.estimateTokenCountInText(segment.text())).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - TextSegment.from(sentences(0, 5), Metadata.from("index", "0").add("document", "0")), - TextSegment.from(sentences(1, 6), Metadata.from("index", "1").add("document", "0")), - TextSegment.from(sentences(3, 8), Metadata.from("index", "2").add("document", "0")), + TextSegment.from(sentences(0, 5), Metadata.from("index", "0").put("document", "0")), + TextSegment.from(sentences(1, 6), Metadata.from("index", "1").put("document", "0")), + TextSegment.from(sentences(3, 8), Metadata.from("index", "2").put("document", "0")), // TODO fix chopped "Mrs." - TextSegment.from(sentences(4, 10) + " Mrs.", Metadata.from("index", "3").add("document", "0")), - TextSegment.from(sentences(5, 12), Metadata.from("index", "4").add("document", "0")), - TextSegment.from(sentences(7, 15), Metadata.from("index", "5").add("document", "0")), - TextSegment.from(sentences(9, 16), Metadata.from("index", "6").add("document", "0")), + TextSegment.from(sentences(4, 10) + " Mrs.", Metadata.from("index", "3").put("document", "0")), + TextSegment.from(sentences(5, 12), Metadata.from("index", "4").put("document", "0")), + TextSegment.from(sentences(7, 15), Metadata.from("index", "5").put("document", "0")), + TextSegment.from(sentences(9, 16), Metadata.from("index", "6").put("document", "0")), // TODO fix chopped s18 // TODO splitter should prioritize progressing forward instead of maximizing overlap - TextSegment.from(sentences(10, 16) + " " + sentences[17].replace(" countless tales.", ""), Metadata.from("index", "7").add("document", "0")), + TextSegment.from(sentences(10, 16) + " " + sentences[17].replace(" countless tales.", ""), Metadata.from("index", "7").put("document", "0")), // TODO this segment should not be present, there is s14-s19 below - TextSegment.from(sentences(13, 17), Metadata.from("index", "8").add("document", "0")), - TextSegment.from(sentences(13, 18), Metadata.from("index", "9").add("document", "0")), - TextSegment.from(sentences(14, 23), Metadata.from("index", "10").add("document", "0")), - TextSegment.from(sentences(16, 24), Metadata.from("index", "11").add("document", "0")), - TextSegment.from(sentences(17, 26), Metadata.from("index", "12").add("document", "0")), - TextSegment.from(sentences(18, 28), Metadata.from("index", "13").add("document", "0")) + TextSegment.from(sentences(13, 17), Metadata.from("index", "8").put("document", "0")), + TextSegment.from(sentences(13, 18), Metadata.from("index", "9").put("document", "0")), + TextSegment.from(sentences(14, 23), Metadata.from("index", "10").put("document", "0")), + TextSegment.from(sentences(16, 24), Metadata.from("index", "11").put("document", "0")), + TextSegment.from(sentences(17, 26), Metadata.from("index", "12").put("document", "0")), + TextSegment.from(sentences(18, 28), 
Metadata.from("index", "13").put("document", "0")) ); assertThat(tokenizer.estimateTokenCountInText(sentences(1, 5))).isLessThanOrEqualTo(maxOverlapSize); diff --git a/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByRegexSplitterTest.java b/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByRegexSplitterTest.java index 4cff71a46c..290f58f1a3 100644 --- a/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByRegexSplitterTest.java +++ b/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentByRegexSplitterTest.java @@ -31,9 +31,9 @@ void should_split_by(String separator) { segments.forEach(segment -> assertThat(segment.text().length()).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment("one", metadata("index", "0").add("document", "0")), - textSegment("two", metadata("index", "1").add("document", "0")), - textSegment("three", metadata("index", "2").add("document", "0")) + textSegment("one", metadata("index", "0").put("document", "0")), + textSegment("two", metadata("index", "1").put("document", "0")), + textSegment("three", metadata("index", "2").put("document", "0")) ); } @@ -50,8 +50,8 @@ void should_fit_multiple_parts_into_the_same_segment() { segments.forEach(segment -> assertThat(segment.text().length()).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment("one\ntwo", metadata("index", "0").add("document", "0")), - textSegment("three", metadata("index", "1").add("document", "0")) + textSegment("one\ntwo", metadata("index", "0").put("document", "0")), + textSegment("three", metadata("index", "1").put("document", "0")) ); } @@ -72,12 +72,12 @@ void should_split_part_into_sub_parts_if_it_does_not_fit_into_segment() { segments.forEach(segment -> assertThat(segment.text().length()).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment("This is a first", metadata("index", "0").add("document", "0")), - textSegment("line.", metadata("index", "1").add("document", "0")), - textSegment("This is a", metadata("index", "2").add("document", "0")), - textSegment("second line.", metadata("index", "3").add("document", "0")), - textSegment("This is a third", metadata("index", "4").add("document", "0")), - textSegment("line.", metadata("index", "5").add("document", "0")) + textSegment("This is a first", metadata("index", "0").put("document", "0")), + textSegment("line.", metadata("index", "1").put("document", "0")), + textSegment("This is a", metadata("index", "2").put("document", "0")), + textSegment("second line.", metadata("index", "3").put("document", "0")), + textSegment("This is a third", metadata("index", "4").put("document", "0")), + textSegment("line.", metadata("index", "5").put("document", "0")) ); } } \ No newline at end of file diff --git a/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentBySentenceSplitterTest.java b/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentBySentenceSplitterTest.java index fbfa57ab37..276033ac22 100644 --- a/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentBySentenceSplitterTest.java +++ b/langchain4j/src/test/java/dev/langchain4j/data/document/splitter/DocumentBySentenceSplitterTest.java @@ -41,8 +41,8 @@ void should_split_into_segments_with_one_sentence_per_segment() { segments.forEach(segment -> assertThat(segment.text().length()).isLessThanOrEqualTo(maxSegmentSize)); 
assertThat(segments).containsExactly( - textSegment(firstSentence, metadata("index", "0").add("document", "0")), - textSegment(secondSentence, metadata("index", "1").add("document", "0")) + textSegment(firstSentence, metadata("index", "0").put("document", "0")), + textSegment(secondSentence, metadata("index", "1").put("document", "0")) ); } @@ -71,8 +71,8 @@ void should_split_into_segments_with_multiple_sentences_per_segment() { segments.forEach(segment -> assertThat(segment.text().length()).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment(firstSentence + " " + secondSentence, metadata("index", "0").add("document", "0")), - textSegment(thirdSentence, metadata("index", "1").add("document", "0")) + textSegment(firstSentence + " " + secondSentence, metadata("index", "0").put("document", "0")), + textSegment(thirdSentence, metadata("index", "1").put("document", "0")) ); } @@ -102,10 +102,10 @@ void should_split_sentence_if_it_does_not_fit_into_segment() { segments.forEach(segment -> assertThat(segment.text().length()).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment(firstSentence, metadata("index", "0").add("document", "0")), - textSegment("This is a very long sentence that does", metadata("index", "1").add("document", "0")), - textSegment("not fit into segment.", metadata("index", "2").add("document", "0")), - textSegment(thirdSentence, metadata("index", "3").add("document", "0")) + textSegment(firstSentence, metadata("index", "0").put("document", "0")), + textSegment("This is a very long sentence that does", metadata("index", "1").put("document", "0")), + textSegment("not fit into segment.", metadata("index", "2").put("document", "0")), + textSegment(thirdSentence, metadata("index", "3").put("document", "0")) ); } @@ -160,17 +160,17 @@ void should_split_sample_text() { segments.forEach(segment -> assertThat(tokenizer.estimateTokenCountInText(segment.text())).isLessThanOrEqualTo(maxSegmentSize)); assertThat(segments).containsExactly( - textSegment(s1 + " " + s2, metadata("index", "0").add("document", "0")), - textSegment(s3 + " " + s4, metadata("index", "1").add("document", "0")), - textSegment(s5p1, metadata("index", "2").add("document", "0")), - textSegment(s5p2, metadata("index", "3").add("document", "0")), - textSegment(s6, metadata("index", "4").add("document", "0")), - textSegment(s7, metadata("index", "5").add("document", "0")), - textSegment(s8 + " " + s9, metadata("index", "6").add("document", "0")), - textSegment(s10, metadata("index", "7").add("document", "0")), - textSegment(s11 + " " + s12 + " " + s13 + " " + s14, metadata("index", "8").add("document", "0")), - textSegment(s15 + " " + s16 + " " + s17, metadata("index", "9").add("document", "0")), - textSegment(s18, metadata("index", "10").add("document", "0")) + textSegment(s1 + " " + s2, metadata("index", "0").put("document", "0")), + textSegment(s3 + " " + s4, metadata("index", "1").put("document", "0")), + textSegment(s5p1, metadata("index", "2").put("document", "0")), + textSegment(s5p2, metadata("index", "3").put("document", "0")), + textSegment(s6, metadata("index", "4").put("document", "0")), + textSegment(s7, metadata("index", "5").put("document", "0")), + textSegment(s8 + " " + s9, metadata("index", "6").put("document", "0")), + textSegment(s10, metadata("index", "7").put("document", "0")), + textSegment(s11 + " " + s12 + " " + s13 + " " + s14, metadata("index", "8").put("document", "0")), + textSegment(s15 + " " + s16 + " " + s17, 
metadata("index", "9").put("document", "0")), + textSegment(s18, metadata("index", "10").put("document", "0")) ); } } \ No newline at end of file diff --git a/langchain4j/src/test/java/dev/langchain4j/data/document/transformer/HtmlTextExtractorTest.java b/langchain4j/src/test/java/dev/langchain4j/data/document/transformer/HtmlTextExtractorTest.java index 6da3cabd9e..ee7b3509f0 100644 --- a/langchain4j/src/test/java/dev/langchain4j/data/document/transformer/HtmlTextExtractorTest.java +++ b/langchain4j/src/test/java/dev/langchain4j/data/document/transformer/HtmlTextExtractorTest.java @@ -20,6 +20,12 @@ class HtmlTextExtractorTest { "" + ""; + private static final String SAMPLE_HTML_WITH_RELATIVE_LINKS = "" + + "" + + "

    Follow the link here.

    " + + "" + + ""; + @Test void should_extract_all_text_from_html() { @@ -52,7 +58,7 @@ void should_extract_text_from_html_by_css_selector() { Document transformedDocument = transformer.transform(htmlDocument); assertThat(transformedDocument.text()).isEqualTo("Paragraph 1\nSomething"); - assertThat(transformedDocument.metadata().asMap()).isEmpty(); + assertThat(transformedDocument.metadata().toMap()).isEmpty(); } @Test @@ -68,8 +74,8 @@ void should_extract_text_and_metadata_from_html_by_css_selectors() { assertThat(transformedDocument.text()).isEqualTo("Paragraph 1\nSomething"); - assertThat(transformedDocument.metadata().asMap()).hasSize(1); - assertThat(transformedDocument.metadata("title")).isEqualTo("Title"); + assertThat(transformedDocument.metadata().toMap()).hasSize(1); + assertThat(transformedDocument.metadata().getString("title")).isEqualTo("Title"); } @Test @@ -93,6 +99,48 @@ void should_extract_text_with_links_from_html() { " * Item one\n" + " * Item two" ); - assertThat(transformedDocument.metadata().asMap()).isEmpty(); + assertThat(transformedDocument.metadata().toMap()).isEmpty(); + } + + @Test + void should_extract_text_with_absolute_links_from_html_with_relative_links_and_url_metadata() { + HtmlTextExtractor transformer = new HtmlTextExtractor(null, null, true); + Document htmlDocument = Document.from(SAMPLE_HTML_WITH_RELATIVE_LINKS); + htmlDocument.metadata().put(Document.URL, "https://example.org/page.html"); + + Document transformedDocument = transformer.transform(htmlDocument); + + assertThat(transformedDocument.text()).isEqualTo( + "Follow the link here ." + ); + assertThat(transformedDocument.metadata().asMap()) + .containsEntry(Document.URL, "https://example.org/page.html") + .hasSize(1); + } + + @Test + void should_extract_text_with_absolute_links_from_html_with_absolute_links_and_url_metadata() { + HtmlTextExtractor transformer = new HtmlTextExtractor(null, null, true); + Document htmlDocument = Document.from(SAMPLE_HTML); + htmlDocument.metadata().put(Document.URL, "https://other.example.org/page.html"); + + Document transformedDocument = transformer.transform(htmlDocument); + + assertThat(transformedDocument.text()).isEqualTo( + "Title\n" + + "\n" + + "Paragraph 1\n" + + "Something\n" + + "\n" + + "Paragraph 2\n" + + "\n" + + "More details here .\n" + + "List:\n" + + " * Item one\n" + + " * Item two" + ); + assertThat(transformedDocument.metadata().asMap()) + .containsEntry(Document.URL, "https://other.example.org/page.html") + .hasSize(1); } } \ No newline at end of file diff --git a/langchain4j/src/test/java/dev/langchain4j/model/output/OutputParserTest.java b/langchain4j/src/test/java/dev/langchain4j/model/output/OutputParserTest.java index 3837b70270..6a1354f8f7 100644 --- a/langchain4j/src/test/java/dev/langchain4j/model/output/OutputParserTest.java +++ b/langchain4j/src/test/java/dev/langchain4j/model/output/OutputParserTest.java @@ -19,6 +19,7 @@ public void test_BigDecimal() { .isEqualTo("floating point number"); assertThat(parser.parse("3.14")).isEqualTo(new BigDecimal("3.14")); + assertThat(parser.parse(" 3.14 ")).isEqualTo(new BigDecimal("3.14")); assertThatExceptionOfType(NumberFormatException.class) .isThrownBy(() -> parser.parse("3.14.15")); @@ -31,6 +32,7 @@ public void test_BigInteger() { .isEqualTo("integer number"); assertThat(parser.parse("42")).isEqualTo(42); + assertThat(parser.parse(" 42 ")).isEqualTo(42); assertThatExceptionOfType(NumberFormatException.class) .isThrownBy(() -> parser.parse("42.0")); @@ -43,6 +45,7 @@ public void 
test_Boolean() { .isEqualTo("one of [true, false]"); assertThat(parser.parse("true")) + .isEqualTo(parser.parse(" true ")) .isEqualTo(parser.parse("TRUE")) .isEqualTo(true); assertThat(parser.parse("false")) @@ -59,6 +62,7 @@ public void test_Byte() { .isEqualTo("integer number in range [-128, 127]"); assertThat(parser.parse("42")).isEqualTo((byte) 42); + assertThat(parser.parse(" 42 ")).isEqualTo((byte) 42); assertThat(parser.parse("-42")).isEqualTo((byte) -42); assertThatExceptionOfType(NumberFormatException.class) @@ -74,6 +78,7 @@ public void test_Date() { assertThat(parser.parse("2020-01-12")) .isEqualTo(parser.parse("2020-01-12")) + .isEqualTo(parser.parse(" 2020-01-12 ")) .isEqualTo(new Date(120, Calendar.JANUARY, 12)); assertThatExceptionOfType(RuntimeException.class) @@ -88,6 +93,7 @@ public void test_Double() { .isEqualTo("floating point number"); assertThat(parser.parse("3.14")).isEqualTo(3.14); + assertThat(parser.parse(" 3.14 ")).isEqualTo(3.14); assertThatExceptionOfType(NumberFormatException.class) .isThrownBy(() -> parser.parse("3.14.15")); @@ -105,6 +111,7 @@ public void test_Enum() { .isEqualTo("one of [A, B, C]"); assertThat(parser.parse("A")) + .isEqualTo(parser.parse(" A ")) .isEqualTo(parser.parse("a")) .isEqualTo(Enum.A); assertThat(parser.parse("B")) @@ -126,6 +133,7 @@ public void test_Float() { .isEqualTo("floating point number"); assertThat(parser.parse("3.14")).isEqualTo(3.14f); + assertThat(parser.parse(" 3.14 ")).isEqualTo(3.14f); assertThatExceptionOfType(NumberFormatException.class) .isThrownBy(() -> parser.parse("3.14.15")); @@ -138,6 +146,7 @@ public void test_Integer() { .isEqualTo("integer number"); assertThat(parser.parse("42")).isEqualTo(42); + assertThat(parser.parse(" 42 ")).isEqualTo(42); assertThatExceptionOfType(NumberFormatException.class) .isThrownBy(() -> parser.parse("42.0")); @@ -150,7 +159,7 @@ public void test_LocalDate() { .isEqualTo("yyyy-MM-dd"); assertThat(parser.parse("2020-01-12")) - .isEqualTo(parser.parse("2020-01-12")) + .isEqualTo(parser.parse(" 2020-01-12 ")) .isEqualTo(LocalDate.of(2020, 1, 12)); assertThatExceptionOfType(RuntimeException.class) @@ -167,7 +176,7 @@ public void test_LocalDateTime() { .isEqualTo("yyyy-MM-ddTHH:mm:ss"); assertThat(parser.parse("2020-01-12T12:34:56")) - .isEqualTo(parser.parse("2020-01-12T12:34:56")) + .isEqualTo(parser.parse(" 2020-01-12T12:34:56 ")) .isEqualTo(LocalDateTime.of(2020, 1, 12, 12, 34, 56)); assertThatExceptionOfType(RuntimeException.class) @@ -184,7 +193,7 @@ public void test_LocalTime() { .isEqualTo("HH:mm:ss"); assertThat(parser.parse("12:34:56")) - .isEqualTo(parser.parse("12:34:56")) + .isEqualTo(parser.parse(" 12:34:56 ")) .isEqualTo(LocalTime.of(12, 34, 56)); assertThat(parser.parse("12:34:56.789")) @@ -201,6 +210,7 @@ public void test_Long() { .isEqualTo("integer number"); assertThat(parser.parse("42")).isEqualTo(42L); + assertThat(parser.parse(" 42 ")).isEqualTo(42L); assertThatExceptionOfType(NumberFormatException.class) .isThrownBy(() -> parser.parse("42.0")); @@ -213,6 +223,7 @@ public void test_Short() { .isEqualTo("integer number in range [-32768, 32767]"); assertThat(parser.parse("42")).isEqualTo((short) 42); + assertThat(parser.parse(" 42 ")).isEqualTo((short) 42); assertThat(parser.parse("-42")).isEqualTo((short) -42); assertThatExceptionOfType(NumberFormatException.class) diff --git a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesBuilderTest.java b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesBuilderTest.java index 9730b42538..e4806588ca 
100644 --- a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesBuilderTest.java +++ b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesBuilderTest.java @@ -1,11 +1,17 @@ package dev.langchain4j.service; +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.exception.IllegalConfigurationException; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.mock.ChatModelMock; +import dev.langchain4j.model.output.Response; import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.content.retriever.ContentRetriever; import dev.langchain4j.retriever.Retriever; import org.junit.jupiter.api.Test; import org.mockito.Mockito; +import org.mockito.Spy; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.mock; @@ -100,4 +106,26 @@ public void testRetrievalAugmentorAndContentRetriever() { }); } + static class HelloWorld { + + @Tool("Say hello") + void add(String name) { + System.out.printf("Hello %s!", name); + } + } + + interface Assistant { + + Response chat(String userMessage); + } + + @Test + public void should_raise_an_error_when_tools_are_classes() { + ChatLanguageModel chatLanguageModel = ChatModelMock.thatAlwaysResponds("Hello there!"); + + assertThrows(IllegalConfigurationException.class, () -> AiServices.builder(Assistant.class) + .chatLanguageModel(chatLanguageModel) + .tools(HelloWorld.class) + .build()); + } } diff --git a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java index d57573007d..384d688054 100644 --- a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java +++ b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java @@ -8,6 +8,7 @@ import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.model.openai.OpenAiModerationModel; +import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.model.output.structured.Description; import lombok.Builder; import lombok.ToString; @@ -610,4 +611,91 @@ void should_not_throw_when_text_is_not_flagged() { verify(chatLanguageModel).generate(singletonList(userMessage(message))); verify(moderationModel).moderate(singletonList(userMessage(message))); } + + + interface AssistantReturningResult { + + Result chat(String userMessage); + } + + @Test + void should_return_result() { + + // given + AssistantReturningResult assistant = AiServices.create(AssistantReturningResult.class, chatLanguageModel); + + String userMessage = "What is the capital of Germany?"; + + // when + Result result = assistant.chat(userMessage); + + // then + assertThat(result.content()).containsIgnoringCase("Berlin"); + + TokenUsage tokenUsage = result.tokenUsage(); + assertThat(tokenUsage).isNotNull(); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(result.sources()).isNull(); + + verify(chatLanguageModel).generate(singletonList(userMessage(userMessage))); + } + + + interface AssistantReturningResultWithPojo { + + Result answer(String query); + } + + static class Booking { + + String userId; + String bookingId; + } + + @Test + void should_use_content_retriever_and_return_sources_inside_result_with_pojo() { + + // 
given + AssistantReturningResultWithPojo assistant = AiServices.create(AssistantReturningResultWithPojo.class, chatLanguageModel); + + // when + Result result = assistant.answer("Give me an example of a booking"); + + // then + Booking booking = result.content(); + assertThat(booking.userId).isNotBlank(); + assertThat(booking.bookingId).isNotBlank(); + + assertThat(result.tokenUsage()).isNotNull(); + assertThat(result.sources()).isNull(); + + verify(chatLanguageModel).generate(singletonList( + userMessage("Give me an example of a booking\n" + + "You must answer strictly in the following JSON format: {\n" + + "\"userId\": (type: string),\n" + + "\"bookingId\": (type: string)\n" + + "}") + )); + } + + + interface InvalidAssistantWithResult { + + Result answerWithNoGenericType(String query); + } + + @Test + void should_throw_exception_when_retrieve_sources_and_generic_type_is_not_set() { + + // when-then + assertThatThrownBy(() -> + AiServices.create(InvalidAssistantWithResult.class, chatLanguageModel)) + .isExactlyInstanceOf(IllegalArgumentException.class) + .hasMessage("The return type 'Result' of the method 'answerWithNoGenericType' must be " + + "parameterized with a type, for example: Result or Result"); + } } diff --git a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesWithRagIT.java b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesWithRagIT.java index b53c7a106c..eb1431da6c 100644 --- a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesWithRagIT.java +++ b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesWithRagIT.java @@ -16,6 +16,7 @@ import dev.langchain4j.model.output.Response; import dev.langchain4j.model.scoring.ScoringModel; import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.content.Content; import dev.langchain4j.rag.content.aggregator.ContentAggregator; import dev.langchain4j.rag.content.aggregator.ReRankingContentAggregator; import dev.langchain4j.rag.content.retriever.ContentRetriever; @@ -562,6 +563,47 @@ void should_use_legacy_retriever(ChatLanguageModel model) { assertThat(answer).containsAnyOf(ALLOWED_CANCELLATION_PERIOD_DAYS, MIN_BOOKING_PERIOD_DAYS); } + + interface AssistantReturningResult { + + Result answer(String query); + } + + @ParameterizedTest + @MethodSource("models") + void should_use_content_retriever_and_return_sources_inside_result(ChatLanguageModel model) { + + // given + ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder() + .embeddingStore(embeddingStore) + .embeddingModel(embeddingModel) + .maxResults(1) + .build(); + + AssistantReturningResult assistant = AiServices.builder(AssistantReturningResult.class) + .chatLanguageModel(model) + .contentRetriever(contentRetriever) + .build(); + + // when + Result result = assistant.answer("Can I cancel my booking?"); + + // then + assertThat(result.content()).containsAnyOf(ALLOWED_CANCELLATION_PERIOD_DAYS, MIN_BOOKING_PERIOD_DAYS); + + assertThat(result.tokenUsage()).isNotNull(); + + assertThat(result.sources()).hasSize(1); + Content content = result.sources().get(0); + assertThat(content.textSegment().text()).isEqualToIgnoringWhitespace( + "4. Cancellation Policy" + + "4.1 Reservations can be cancelled up to 61 days prior to the start of the booking period." + + "4.2 If the booking period is less than 17 days, cancellations are not permitted." 
+ ); + assertThat(content.textSegment().metadata("index")).isEqualTo("3"); + assertThat(content.textSegment().metadata("file_name")).isEqualTo("miles-of-smiles-terms-of-use.txt"); + } + private void ingest(String documentPath, EmbeddingStore embeddingStore, EmbeddingModel embeddingModel) { OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); DocumentSplitter splitter = DocumentSplitters.recursive(100, 0, tokenizer); diff --git a/langchain4j/src/test/java/dev/langchain4j/service/StreamingAiServicesIT.java b/langchain4j/src/test/java/dev/langchain4j/service/StreamingAiServicesIT.java index 6d70a4e302..7efd8358b5 100644 --- a/langchain4j/src/test/java/dev/langchain4j/service/StreamingAiServicesIT.java +++ b/langchain4j/src/test/java/dev/langchain4j/service/StreamingAiServicesIT.java @@ -11,6 +11,7 @@ import dev.langchain4j.memory.chat.MessageWindowChatMemory; import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel; import dev.langchain4j.model.openai.OpenAiStreamingChatModel; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; @@ -43,6 +44,11 @@ static Stream models() { .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .logRequestsAndResponses(true) + .build(), + MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .logRequests(true) + .logResponses(true) .build() // TODO add more models ); diff --git a/langchain4j/src/test/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStoreRemovalTest.java b/langchain4j/src/test/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStoreRemovalTest.java new file mode 100644 index 0000000000..934c2888ad --- /dev/null +++ b/langchain4j/src/test/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStoreRemovalTest.java @@ -0,0 +1,24 @@ +package dev.langchain4j.store.embedding.inmemory; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT; + +class InMemoryEmbeddingStoreRemovalTest extends EmbeddingStoreWithRemovalIT { + + EmbeddingStore embeddingStore = new InMemoryEmbeddingStore<>(); + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } +} diff --git a/pom.xml b/pom.xml index b69d4010cf..faa1955776 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ dev.langchain4j langchain4j-aggregator - 0.30.0 + 0.32.0-SNAPSHOT pom LangChain4j :: Aggregator @@ -27,20 +27,24 @@ langchain4j-chatglm langchain4j-cohere langchain4j-dashscope - langchain4j-qianfan langchain4j-hugging-face + langchain4j-jina langchain4j-local-ai langchain4j-mistral-ai langchain4j-nomic langchain4j-ollama langchain4j-open-ai + langchain4j-qianfan langchain4j-vertex-ai langchain4j-vertex-ai-gemini + langchain4j-workers-ai langchain4j-zhipu-ai langchain4j-astradb langchain4j-azure-ai-search + langchain4j-azure-cosmos-mongo-vcore + langchain4j-azure-cosmos-nosql langchain4j-cassandra langchain4j-chroma langchain4j-elasticsearch @@ -56,12 +60,12 @@ 
langchain4j-vearch langchain4j-vespa langchain4j-weaviate - langchain4j-azure-cosmos-mongo-vcore document-loaders/langchain4j-document-loader-amazon-s3 document-loaders/langchain4j-document-loader-azure-storage-blob document-loaders/langchain4j-document-loader-github + document-loaders/langchain4j-document-loader-selenium document-loaders/langchain4j-document-loader-tencent-cos @@ -71,10 +75,18 @@ code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot + code-execution-engines/langchain4j-code-execution-engine-judge0 + + + web-search-engines/langchain4j-web-search-engine-google-custom + web-search-engines/langchain4j-web-search-engine-tavily embedding-store-filter-parsers/langchain4j-embedding-store-filter-parser-sql + + experimental/langchain4j-experimental-sql + diff --git a/web-search-engines/langchain4j-web-search-engine-google-custom/pom.xml b/web-search-engines/langchain4j-web-search-engine-google-custom/pom.xml new file mode 100644 index 0000000000..870bf5e98c --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-google-custom/pom.xml @@ -0,0 +1,93 @@ + + + 4.0.0 + + + dev.langchain4j + langchain4j-parent + 0.32.0-SNAPSHOT + ../../langchain4j-parent/pom.xml + + + langchain4j-web-search-engine-google-custom + jar + + LangChain4j :: Web Search Engine :: Google Custom Search + Implementation of Google Custom Search API for LangChain4j + + + + + dev.langchain4j + langchain4j-core + + + + + com.google.apis + google-api-services-customsearch + v1-rev20240417-2.0.0 + + + + org.slf4j + slf4j-api + 2.0.7 + + + + org.projectlombok + lombok + provided + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.assertj + assertj-core + test + + + + org.tinylog + tinylog-impl + test + + + + org.tinylog + slf4j-tinylog + test + + + + dev.langchain4j + langchain4j + test + + + + + dev.langchain4j + langchain4j-core + tests + test-jar + test + + + + dev.langchain4j + langchain4j-open-ai + test + + + + + diff --git a/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomSearchApiClient.java b/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomSearchApiClient.java new file mode 100644 index 0000000000..d9ef8d043b --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomSearchApiClient.java @@ -0,0 +1,190 @@ +package dev.langchain4j.web.search.google.customsearch; + +import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport; +import com.google.api.client.http.HttpRequest; +import com.google.api.client.http.HttpRequestInitializer; +import com.google.api.client.json.gson.GsonFactory; +import com.google.api.services.customsearch.v1.CustomSearchAPI; +import com.google.api.services.customsearch.v1.CustomSearchAPIRequest; +import com.google.api.services.customsearch.v1.model.Search; +import lombok.Builder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.time.Duration; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.Utils.isNullOrBlank; + +class GoogleCustomSearchApiClient { + + private static final Logger LOGGER = LoggerFactory.getLogger(GoogleCustomSearchApiClient.class); + private static final Integer MAXIMUM_VALUE_NUM = 10; + + private 
final CustomSearchAPIRequest customSearchRequest; + private final boolean logResponses; + + @Builder + GoogleCustomSearchApiClient(String apiKey, + String csi, + Boolean siteRestrict, + Duration timeout, + Integer maxRetries, + boolean logRequests, + boolean logResponses) { + + try { + if (isNullOrBlank(apiKey)) { + throw new IllegalArgumentException("Google Custom Search API Key must be defined. " + + "It can be generated here: https://console.developers.google.com/apis/credentials"); + } + if (isNullOrBlank(csi)) { + throw new IllegalArgumentException("Google Custom Search Engine ID must be defined. " + + "It can be created here: https://cse.google.com/cse/create/new"); + } + + this.logResponses = logResponses; + + CustomSearchAPI.Builder customSearchAPIBuilder = new CustomSearchAPI.Builder(GoogleNetHttpTransport.newTrustedTransport(), new GsonFactory(), new HttpRequestInitializer() { + @Override + public void initialize(HttpRequest httpRequest) { + httpRequest.setConnectTimeout(Math.toIntExact(timeout.toMillis())); + httpRequest.setReadTimeout(Math.toIntExact(timeout.toMillis())); + httpRequest.setWriteTimeout(Math.toIntExact(timeout.toMillis())); + httpRequest.setNumberOfRetries(maxRetries); + if (logRequests) { + httpRequest.setInterceptor(new GoogleSearchApiHttpRequestLoggingInterceptor()); + } + if (logResponses) { + httpRequest.setResponseInterceptor(new GoogleSearchApiHttpResponseLoggingInterceptor()); + } + } + }).setApplicationName("LangChain4j"); + + CustomSearchAPI customSearchAPI = customSearchAPIBuilder.build(); + + if (siteRestrict) { + customSearchRequest = customSearchAPI.cse().siterestrict().list().setKey(apiKey).setCx(csi); + } else { + customSearchRequest = customSearchAPI.cse().list().setKey(apiKey).setCx(csi); + } + } catch (IOException e) { + LOGGER.error("Error occurred while creating Google Custom Search API client", e); + throw new RuntimeException(e); + } catch (GeneralSecurityException e) { + LOGGER.error("Error occurred while creating Google Custom Search API client using GoogleNetHttpTransport.newTrustedTransport()", e); + throw new RuntimeException(e); + } + } + + Search searchResults(Search.Queries.Request requestQuery) { + try { + Search searchPerformed; + if (customSearchRequest instanceof CustomSearchAPI.Cse.Siterestrict.List) { + searchPerformed = ((CustomSearchAPI.Cse.Siterestrict.List) customSearchRequest) + .setPrettyPrint(true) + .setQ(requestQuery.getSearchTerms()) + .setNum(maxResultsAllowed(getDefaultNaturalNumber(requestQuery.getCount()))) + .setSort(requestQuery.getSort()) + .setSafe(requestQuery.getSafe()) + .setDateRestrict(requestQuery.getDateRestrict()) + .setGl(requestQuery.getGl()) + .setLr(requestQuery.getLanguage()) + .setHl(requestQuery.getHl()) + .setHq(requestQuery.getHq()) + .setSiteSearch(requestQuery.getSiteSearch()) + .setSiteSearchFilter(requestQuery.getSiteSearchFilter()) + .setExactTerms(requestQuery.getExactTerms()) + .setExcludeTerms(requestQuery.getExcludeTerms()) + .setLinkSite(requestQuery.getLinkSite()) + .setOrTerms(requestQuery.getOrTerms()) + .setLowRange(requestQuery.getLowRange()) + .setHighRange(requestQuery.getHighRange()) + .setSearchType(requestQuery.getSearchType()) + .setFileType(requestQuery.getFileType()) + .setRights(requestQuery.getRights()) + .setImgSize(requestQuery.getImgSize()) + .setImgType(requestQuery.getImgType()) + .setImgColorType(requestQuery.getImgColorType()) + .setImgDominantColor(requestQuery.getImgDominantColor()) + .setC2coff(requestQuery.getDisableCnTwTranslation()) + 
.setCr(requestQuery.getCr()) + .setGooglehost(requestQuery.getGoogleHost()) + .setStart(calculateIndexStartPage( + getDefaultNaturalNumber(requestQuery.getStartPage()), + getDefaultNaturalNumber(requestQuery.getStartIndex()) + ).longValue()) + .setFilter(requestQuery.getFilter()) + .execute(); + } else if (customSearchRequest instanceof CustomSearchAPI.Cse.List) { + searchPerformed = ((CustomSearchAPI.Cse.List) customSearchRequest) + .setPrettyPrint(true) + .setQ(requestQuery.getSearchTerms()) + .setNum(maxResultsAllowed(getDefaultNaturalNumber(requestQuery.getCount()))) + .setSort(requestQuery.getSort()) + .setSafe(requestQuery.getSafe()) + .setDateRestrict(requestQuery.getDateRestrict()) + .setGl(requestQuery.getGl()) + .setLr(requestQuery.getLanguage()) + .setHl(requestQuery.getHl()) + .setHq(requestQuery.getHq()) + .setSiteSearch(requestQuery.getSiteSearch()) + .setSiteSearchFilter(requestQuery.getSiteSearchFilter()) + .setExactTerms(requestQuery.getExactTerms()) + .setExcludeTerms(requestQuery.getExcludeTerms()) + .setLinkSite(requestQuery.getLinkSite()) + .setOrTerms(requestQuery.getOrTerms()) + .setLowRange(requestQuery.getLowRange()) + .setHighRange(requestQuery.getHighRange()) + .setSearchType(requestQuery.getSearchType()) + .setFileType(requestQuery.getFileType()) + .setRights(requestQuery.getRights()) + .setImgSize(requestQuery.getImgSize()) + .setImgType(requestQuery.getImgType()) + .setImgColorType(requestQuery.getImgColorType()) + .setImgDominantColor(requestQuery.getImgDominantColor()) + .setC2coff(requestQuery.getDisableCnTwTranslation()) + .setCr(requestQuery.getCr()) + .setGooglehost(requestQuery.getGoogleHost()) + .setStart(calculateIndexStartPage( + getDefaultNaturalNumber(requestQuery.getStartPage()), + getDefaultNaturalNumber(requestQuery.getStartIndex()) + ).longValue()) + .setFilter(requestQuery.getFilter()) + .execute(); + } else { + throw new IllegalStateException("Invalid CustomSearchAPIRequest type"); + } + if (logResponses) { + logResponse(searchPerformed); + } + return searchPerformed; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static void logResponse(Search search) { + try { + LOGGER.debug("Response:\n- body: {}", search.toPrettyString()); + } catch (IOException e) { + LOGGER.warn("Error while logging response: {}", e.getMessage()); + } + } + + private static Integer maxResultsAllowed(Integer maxResults) { + return maxResults > MAXIMUM_VALUE_NUM ? MAXIMUM_VALUE_NUM : maxResults; + } + + private static Integer getDefaultNaturalNumber(Integer number) { + int defaultNumber = getOrDefault(number, 1); + return defaultNumber > 0 ? defaultNumber : 1; + } + + private static Integer calculateIndexStartPage(Integer pageNumber, Integer index) { + int indexStartPage = ((pageNumber - 1) * MAXIMUM_VALUE_NUM) + 1; + return indexStartPage >= index ? 
indexStartPage : index; + } +} diff --git a/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchEngine.java b/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchEngine.java new file mode 100644 index 0000000000..5e0322b833 --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchEngine.java @@ -0,0 +1,273 @@ +package dev.langchain4j.web.search.google.customsearch; + +import com.google.api.client.json.GenericJson; +import com.google.api.services.customsearch.v1.model.Result; +import com.google.api.services.customsearch.v1.model.Search; +import dev.langchain4j.web.search.*; +import lombok.Builder; + +import java.net.URI; +import java.time.Duration; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static com.google.api.services.customsearch.v1.model.Search.Queries; +import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static java.util.stream.Collectors.toList; + +/** + * An implementation of a {@link WebSearchEngine} that uses + * Google Custom Search API for performing web searches. + */ +public class GoogleCustomWebSearchEngine implements WebSearchEngine { + + private final GoogleCustomSearchApiClient googleCustomSearchApiClient; + private final Boolean includeImages; + + /** + * Constructs a new GoogleCustomWebSearchEngine with the specified parameters. + * + * @param apiKey the Google Search API key for accessing the Google Custom Search API + *

    + *                     You can generate an API key here: https://console.developers.google.com/apis/credentials
    + * @param csi          the Custom Search Engine ID used to search the entire web
    + *                     You can create a Custom Search Engine here: https://cse.google.com/cse/create/new
    + * @param siteRestrict if your Search Engine is restricted to only searching specific sites, you can set this parameter to true.
    + *                     Default value is false. View the documentation for more information.
    + * @param includeImages if true, public images relevant to the query are included. This can add more latency to the search.
    + *                     Default value is false.
    + * @param timeout      the timeout duration for API requests
    + *                     Default value is 60 seconds.
    + * @param maxRetries   the maximum number of retries for API requests
    + *                     Default value is 3.
    + * @param logRequests  whether to log API requests
    + *                     Default value is false.
    + * @param logResponses whether to log API responses
    + * Default value is false. + */ + @Builder + public GoogleCustomWebSearchEngine(String apiKey, + String csi, + Boolean siteRestrict, + Boolean includeImages, + Duration timeout, + Integer maxRetries, + Boolean logRequests, + Boolean logResponses) { + + this.googleCustomSearchApiClient = GoogleCustomSearchApiClient.builder() + .apiKey(apiKey) + .csi(csi) + .siteRestrict(getOrDefault(siteRestrict, false)) + .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) + .maxRetries(getOrDefault(maxRetries, 3)) + .logRequests(getOrDefault(logRequests, false)) + .logResponses(getOrDefault(logResponses, false)) + .build(); + this.includeImages = getOrDefault(includeImages, false); + } + + /** + * Creates a new builder for constructing a GoogleCustomWebSearchEngine with the specified API key and Custom Search ID. + * + * @param apiKey the API key for accessing the Google Custom Search API + * @param csi the Custom Search ID parameter for search the entire web + * @return a new builder instance + */ + public static GoogleCustomWebSearchEngine withApiKeyAndCsi(String apiKey, String csi) { + return GoogleCustomWebSearchEngine.builder().apiKey(apiKey).csi(csi).build(); + } + + @Override + public WebSearchResults search(WebSearchRequest webSearchRequest) { + ensureNotNull(webSearchRequest, "webSearchRequest"); + + Queries.Request requestQuery = new Queries.Request(); + requestQuery.setSearchTerms(webSearchRequest.searchTerms()); + requestQuery.setCount(getOrDefault(webSearchRequest.maxResults(), 5)); + requestQuery.setGl(webSearchRequest.geoLocation()); + requestQuery.setLanguage(webSearchRequest.language()); + requestQuery.setStartPage(webSearchRequest.startPage()); + requestQuery.setStartIndex(webSearchRequest.startIndex()); + requestQuery.setSafe(webSearchRequest.safeSearch() ? "active" : "off"); + requestQuery.setFilter("1"); // By default, applies filtering to remove duplicate content + requestQuery.setCr(setCountryRestrict(webSearchRequest)); + webSearchRequest.additionalParams().forEach(requestQuery::set); + + boolean searchTypeImage = "image".equals(requestQuery.getSearchType()); + + // Web search + Search search = googleCustomSearchApiClient.searchResults(requestQuery); + Map searchMetadata = toSearchMetadata(search, searchTypeImage); + Map searchInformationMetadata = new HashMap<>(); + + // Images search + if (includeImages && !searchTypeImage) { + requestQuery.setSearchType("image"); + Search imagesSearch = googleCustomSearchApiClient.searchResults(requestQuery); + List images = imagesSearch.getItems().stream() + .map(result -> ImageSearchResult.from( + result.getTitle(), + URI.create(result.getLink()), + URI.create(result.getImage().getContextLink()), + URI.create(result.getImage().getThumbnailLink()))) + .collect(toList()); + addImagesToSearchInformation(searchInformationMetadata, images); + } + + return WebSearchResults.from( + searchMetadata, + WebSearchInformationResult.from( + Long.valueOf(getOrDefault(search.getSearchInformation().getTotalResults(), "0")), + !isNullOrEmpty(search.getQueries().getRequest()) + ? calculatePageNumberFromQueries(search.getQueries().getRequest().get(0)) : 1, + searchInformationMetadata.isEmpty() ? 
null : searchInformationMetadata), + search.getItems().stream() + .map(result -> WebSearchOrganicResult.from( + result.getTitle(), + URI.create(result.getLink()), + result.getSnippet(), + null, // by default google custom search api does not return content + toResultMetadataMap(result, searchTypeImage) + )).collect(toList())); + } + + private static void addImagesToSearchInformation(Map searchInformationMetadata, List images) { + if (!isNullOrEmpty(images)) { + searchInformationMetadata.put("images", images); + } + } + + private static Map toSearchMetadata(Search search, Boolean searchTypeImage) { + if (search == null) { + return null; + } + Map searchMetadata = new HashMap<>(); + searchMetadata.put("status", "Success"); + searchMetadata.put("searchTime", search.getSearchInformation().getSearchTime()); + searchMetadata.put("processedAt", LocalDateTime.now().toString()); + searchMetadata.put("searchType", searchTypeImage ? "images" : "web"); + searchMetadata.putAll(search.getContext()); + return searchMetadata; + } + + private static Map toResultMetadataMap(Result result, boolean searchTypeImage) { + Map metadata = new HashMap<>(); + // Image search type + if (searchTypeImage) { + metadata.put("imageLink", result.getLink()); + metadata.put("contextLink", result.getImage().getContextLink()); + metadata.put("thumbnailLink", result.getImage().getThumbnailLink()); + metadata.put("mimeType", result.getMime()); + return metadata; + } + // Web search type + if (!result.getPagemap().isEmpty()) { + result.getPagemap().forEach((key, value) -> { + if (key.equals("metatags")) { + if (value instanceof List) { + metadata.put(key, ((List) value).stream().map(Object::toString).reduce((a, b) -> a + ", " + b).orElse("")); + } else { + metadata.put(key, value.toString()); + } + } + metadata.put("mimeType", isNotNullOrBlank(result.getMime()) ? result.getMime() : "text/html"); + }); + return metadata; + } + return null; + } + + private static Integer calculatePageNumberFromQueries(GenericJson query) { + if (query instanceof Queries.PreviousPage) { + Queries.PreviousPage previousPage = (Queries.PreviousPage) query; + return calculatePageNumber(previousPage.getStartIndex()); + } + if (query instanceof Queries.Request) { + Queries.Request currentPage = (Queries.Request) query; + return calculatePageNumber(getOrDefault(currentPage.getStartIndex(), 1)); + } + if (query instanceof Queries.NextPage) { + Queries.NextPage nextPage = (Queries.NextPage) query; + return calculatePageNumber(nextPage.getStartIndex()); + } + return null; + } + + private static Integer calculatePageNumber(Integer startIndex) { + if (startIndex == null) + return null; + return ((startIndex - 1) / 10) + 1; + } + + private static String setCountryRestrict(WebSearchRequest webSearchRequest) { + return webSearchRequest.additionalParams().get("cr") != null ? webSearchRequest.additionalParams().get("cr").toString() + : isNotNullOrBlank(webSearchRequest.geoLocation()) ? 
"country" + webSearchRequest.geoLocation().toUpperCase() + : ""; // default value + } + + public static final class ImageSearchResult { + private final String title; + private final URI imageLink; + private final URI contextLink; + private final URI thumbnailLink; + + private ImageSearchResult(String title, URI imageLink) { + this.title = ensureNotNull(title, "title"); + this.imageLink = ensureNotNull(imageLink, "imageLink"); + this.contextLink = null; + this.thumbnailLink = null; + } + + private ImageSearchResult(String title, URI imageLink, URI contextLink, URI thumbnailLink) { + this.title = ensureNotNull(title, "title"); + this.imageLink = ensureNotNull(imageLink, "imageLink"); + this.contextLink = contextLink; + this.thumbnailLink = thumbnailLink; + } + + public String title() { + return title; + } + + public URI imageLink() { + return imageLink; + } + + public URI contextLink() { + return contextLink; + } + + public URI thumbnailLink() { + return thumbnailLink; + } + + @Override + public String toString() { + return "ImageSearchResult{" + + "title='" + title + '\'' + + ", imageLink=" + imageLink + + ", contextLink=" + contextLink + + ", thumbnailLink=" + thumbnailLink + + '}'; + } + + public static ImageSearchResult from(String title, URI imageLink) { + return new ImageSearchResult(title, imageLink); + } + + public static ImageSearchResult from(String title, URI imageLink, URI contextLink, URI thumbnailLink) { + return new ImageSearchResult(title, imageLink, contextLink, thumbnailLink); + } + } +} diff --git a/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleSearchApiHttpRequestLoggingInterceptor.java b/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleSearchApiHttpRequestLoggingInterceptor.java new file mode 100644 index 0000000000..8e5ac593e7 --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleSearchApiHttpRequestLoggingInterceptor.java @@ -0,0 +1,46 @@ +package dev.langchain4j.web.search.google.customsearch; + +import com.google.api.client.http.HttpContent; +import com.google.api.client.http.HttpExecuteInterceptor; +import com.google.api.client.http.HttpHeaders; +import com.google.api.client.http.HttpRequest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.stream.Collectors; + +class GoogleSearchApiHttpRequestLoggingInterceptor implements HttpExecuteInterceptor { + + private static final Logger LOGGER = LoggerFactory.getLogger(GoogleSearchApiHttpRequestLoggingInterceptor.class); + + @Override + public void intercept(HttpRequest httpRequest) { + this.log(httpRequest); + } + + private void log(HttpRequest httpRequest) { + try { + LOGGER.debug("Request:\n- method: {}\n- url: {}\n- headers: {}\n- body: {}", + httpRequest.getRequestMethod(), httpRequest.getUrl(), getHeaders(httpRequest.getHeaders()), getBody(httpRequest.getContent())); + } catch (Exception e) { + LOGGER.warn("Error while logging request: {}", e.getMessage()); + } + } + + private static String getHeaders(HttpHeaders headers) { + return headers.entrySet().stream() + .map(entry -> String.format("[%s: %s]", entry.getKey(), entry.getValue())).collect(Collectors.joining(", ")); + } + + private static String getBody(HttpContent content) { + try { + if (content == null) { + return ""; + } + return content.toString(); + } catch (Exception e) { 
+ LOGGER.warn("Exception while getting body", e); + return "Exception while getting body: " + e.getMessage(); + } + } +} diff --git a/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleSearchApiHttpResponseLoggingInterceptor.java b/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleSearchApiHttpResponseLoggingInterceptor.java new file mode 100644 index 0000000000..f9d583197e --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-google-custom/src/main/java/dev/langchain4j/web/search/google/customsearch/GoogleSearchApiHttpResponseLoggingInterceptor.java @@ -0,0 +1,35 @@ +package dev.langchain4j.web.search.google.customsearch; + +import com.google.api.client.http.HttpHeaders; +import com.google.api.client.http.HttpResponse; +import com.google.api.client.http.HttpResponseInterceptor; +import com.google.api.client.json.gson.GsonFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.stream.Collectors; + +class GoogleSearchApiHttpResponseLoggingInterceptor implements HttpResponseInterceptor { + + private static final Logger LOGGER = LoggerFactory.getLogger(GoogleSearchApiHttpResponseLoggingInterceptor.class); + + @Override + public void interceptResponse(HttpResponse httpResponse) { + this.log(httpResponse); + } + + private void log(HttpResponse httpResponse) { + try { + httpResponse.getRequest().setParser(new GsonFactory().createJsonObjectParser()); + LOGGER.debug("Response:\n- status code: {}\n- headers: {}", + httpResponse.getStatusCode(), getHeaders(httpResponse.getHeaders())); // response body can't be got twice by google token constraints, it'll be logged in GoogleCustomSearchApiClient + } catch (Exception e) { + LOGGER.warn("Error while logging response: {}", e.getMessage()); + } + } + + private static String getHeaders(HttpHeaders headers) { + return headers.entrySet().stream() + .map(entry -> String.format("[%s: %s]", entry.getKey(), entry.getValue())).collect(Collectors.joining(", ")); + } +} diff --git a/web-search-engines/langchain4j-web-search-engine-google-custom/src/test/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchContentRetrieverIT.java b/web-search-engines/langchain4j-web-search-engine-google-custom/src/test/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchContentRetrieverIT.java new file mode 100644 index 0000000000..6c82a1f0aa --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-google-custom/src/test/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchContentRetrieverIT.java @@ -0,0 +1,62 @@ +package dev.langchain4j.web.search.google.customsearch; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.rag.content.retriever.WebSearchContentRetriever; +import dev.langchain4j.rag.content.retriever.WebSearchContentRetrieverIT; +import dev.langchain4j.service.AiServices; +import dev.langchain4j.web.search.WebSearchEngine; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".*") +@EnabledIfEnvironmentVariable(named = "GOOGLE_SEARCH_ENGINE_ID", matches = ".*") +class GoogleCustomWebSearchContentRetrieverIT extends 
WebSearchContentRetrieverIT { + + WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.builder() + .apiKey(System.getenv("GOOGLE_API_KEY")) + .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID")) + .logRequests(true) + .logResponses(true) + .build(); + + ChatLanguageModel chatModel = OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .logRequests(true) + .logResponses(true) + .build(); + + interface Assistant { + + String answer(String userMessage); + } + + @Test + void should_retrieve_web_content_with_google_and_use_AiServices_to_summary_response() { + + // given + WebSearchContentRetriever contentRetriever = WebSearchContentRetriever.builder() + .webSearchEngine(googleSearchEngine) + .build(); + + Assistant assistant = AiServices.builder(Assistant.class) + .chatLanguageModel(chatModel) + .contentRetriever(contentRetriever) + .build(); + + String query = "What features does LangChain4j have?"; + + // when + String answer = assistant.answer(query); + + // then + assertThat(answer).contains("RAG"); + } + + @Override + protected WebSearchEngine searchEngine() { + return googleSearchEngine; + } +} diff --git a/web-search-engines/langchain4j-web-search-engine-google-custom/src/test/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchEngineIT.java b/web-search-engines/langchain4j-web-search-engine-google-custom/src/test/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchEngineIT.java new file mode 100644 index 0000000000..810758cbd1 --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-google-custom/src/test/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchEngineIT.java @@ -0,0 +1,236 @@ +package dev.langchain4j.web.search.google.customsearch; + +import dev.langchain4j.web.search.*; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static dev.langchain4j.web.search.google.customsearch.GoogleCustomWebSearchEngine.ImageSearchResult; +import static java.util.Collections.singletonMap; +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".*") +@EnabledIfEnvironmentVariable(named = "GOOGLE_SEARCH_ENGINE_ID", matches = ".*") +class GoogleCustomWebSearchEngineIT extends WebSearchEngineIT { + + WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.builder() + .apiKey(System.getenv("GOOGLE_API_KEY")) + .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID")) + .logRequests(true) + .logResponses(true) + .build(); + + @Test + void should_return_google_web_results_with_search_information() { + // given + String query = "What is LangChain4j project?"; + + // when + WebSearchResults results = googleSearchEngine.search(query); + + // then + assertThat(results.searchMetadata()).isNotNull(); + assertThat(results.searchInformation().totalResults()).isGreaterThan(0); + assertThat(results.results().size()).isGreaterThan(0); + } + + @Test + @Disabled("fails") + void should_return_google_safe_web_results_in_spanish_language() { + // given + String query = "Who won the FIFA World Cup 2022?"; + WebSearchRequest webSearchRequest = WebSearchRequest.builder() + .searchTerms(query) + .language("lang_es") + .safeSearch(true) + .build(); + + // when + List results = googleSearchEngine.search(webSearchRequest).results(); + + // then + 
assertThat(results) + .as("At least one result should be contains 'argentina' ignoring case") + .anySatisfy(result -> assertThat(result.snippet()) + .containsIgnoringCase("argentina")); + } + + @Test + void should_return_google_results_of_the_second_page_and_log_http_req_resp() { + // given + WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.builder() + .apiKey(System.getenv("GOOGLE_API_KEY")) + .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID")) + .logRequests(true) + .logResponses(true) + .build(); + + String query = "What is the weather in Porto?"; + WebSearchRequest webSearchRequest = WebSearchRequest.builder() + .searchTerms(query) + .maxResults(5) + .startPage(2) + .build(); + + // when + WebSearchResults webSearchResults = googleSearchEngine.search(webSearchRequest); + + // then + assertThat(webSearchResults.results()) + .as("At least the string result should be contains 'weather' and 'Porto' ignoring case") + .anySatisfy(result -> assertThat(result.snippet()) + .containsIgnoringCase("weather") + .containsIgnoringCase("porto")); + } + + @Test + void should_return_google_results_using_and_fix_startpage_by_startindex() { + // given + String query = "What is LangChain4j project?"; + WebSearchRequest webSearchRequest = WebSearchRequest.builder() + .searchTerms(query) + .language("lang_en") + .startPage(1) //user bad request + .startIndex(15) + .build(); + + // when + WebSearchResults webSearchResults = googleSearchEngine.search(webSearchRequest); + + // then + assertThat(webSearchResults.results()) + .as("At least one result should be contains 'Java' and 'AI' ignoring case") + .anySatisfy(result -> assertThat(result.snippet()) + .containsIgnoringCase("Java") + .containsIgnoringCase("AI")); + } + + @Test + void should_return_google_result_using_additional_params() { + // given + WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.builder() + .apiKey(System.getenv("GOOGLE_API_KEY")) + .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID")) + .logRequests(true) + .logResponses(true) + .build(); + + String query = "What is LangChain4j project?"; + Map additionalParams = new HashMap<>(); + additionalParams.put("dateRestrict", "w[2]"); + additionalParams.put("linkSite", "https://github.com/langchain4j/langchain4j"); + WebSearchRequest webSearchRequest = WebSearchRequest.builder() + .searchTerms(query) + .additionalParams(additionalParams) + .build(); + + // when + List results = googleSearchEngine.search(webSearchRequest).results(); + + // then + assertThat(results) + .as("At least one result should be contains 'Java' and 'AI' ignoring case") + .anySatisfy(result -> assertThat(result.snippet()) + .containsIgnoringCase("Java") + .containsIgnoringCase("github.com/langchain4j/langchain4j")); + } + + @Test + void should_return_google_result_with_images_related() { + // given + WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.builder() + .apiKey(System.getenv("GOOGLE_API_KEY")) + .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID")) + .includeImages(true) // execute an additional search, searchType: image + .logRequests(true) + .logResponses(true) + .build(); + + String query = "Which top 2024 universities to study computer science?"; + WebSearchRequest webSearchRequest = WebSearchRequest.builder() + .searchTerms(query) + .build(); + + // when + WebSearchResults webSearchResults = googleSearchEngine.search(webSearchRequest); + + // then + assertThat(webSearchResults.searchMetadata().get("searchType").toString()).isEqualTo("web"); // searchType: web + 
+        assertThat(webSearchResults.searchInformation().metadata().get("images")).isOfAnyClassIn(ArrayList.class, List.class); // should add images related to the query
+        assertThat((List<ImageSearchResult>) webSearchResults.searchInformation().metadata().get("images")) // Get images from searchInformation.metadata
+                .as("At least one image result should contain title, link, contextLink and thumbnailLink")
+                .anySatisfy(image -> {
+                    assertThat(image.title()).isNotNull();
+                    assertThat(image.imageLink().toString()).startsWith("http");
+                    assertThat(image.contextLink().toString()).startsWith("http");
+                    assertThat(image.thumbnailLink().toString()).startsWith("http");
+                });
+        assertThat(webSearchResults.results()) // Get web results
+                .as("At least one result snippet should contain 'university' and 'ranking' ignoring case")
+                .anySatisfy(result -> assertThat(result.snippet())
+                        .containsIgnoringCase("university")
+                        .containsIgnoringCase("ranking"));
+
+    }
+
+    @Test
+    void should_return_google_image_result_with_param_searchType_image() {
+        // given
+        String query = "How will be the weather next week in Lisbon and Porto cities?";
+        WebSearchRequest webSearchRequest = WebSearchRequest.builder()
+                .searchTerms(query)
+                .additionalParams(singletonMap("searchType", "image"))
+                .build();
+
+        // when
+        WebSearchResults webSearchResults = googleSearchEngine.search(webSearchRequest);
+
+        // then
+        assertThat(webSearchResults.searchMetadata().get("searchType").toString()).isEqualTo("images"); // searchType: images
+        assertThat(webSearchResults.results()) // Get images as search results
+                .as("At least one result title should contain 'weather' and 'Porto' ignoring case")
+                .anySatisfy(result -> assertThat(result.title())
+                        .containsIgnoringCase("weather")
+                        .containsIgnoringCase("porto"))
+                .anySatisfy(result -> assertThat(result.url().toString())
+                        .startsWith("http"))
+                .anySatisfy(result -> assertThat(result.metadata().get("mimeType"))
+                        .startsWith("image"))
+                .anySatisfy(result -> assertThat(result.metadata().get("imageLink"))
+                        .isEqualTo(result.url().toString()))
+                .anySatisfy(result -> assertThat(result.metadata().get("contextLink"))
+                        .startsWith("http"))
+                .anySatisfy(result -> assertThat(result.metadata().get("thumbnailLink"))
+                        .startsWith("http"));
+    }
+
+    @Test
+    void should_return_web_results_with_geolocation() {
+        // given
+        String searchTerm = "Who is the current president?";
+        WebSearchRequest webSearchRequest = WebSearchRequest.builder()
+                .searchTerms(searchTerm)
+                .geoLocation("fr")
+                .build();
+
+        // when
+        List<WebSearchOrganicResult> webSearchOrganicResults = searchEngine().search(webSearchRequest).results();
+
+        // then
+        assertThat(webSearchOrganicResults).isNotNull();
+        assertThat(webSearchOrganicResults)
+                .as("At least one result title should contain 'Emmanuel Macron' ignoring case")
+                .anySatisfy(result -> assertThat(result.title())
+                        .containsIgnoringCase("Emmanuel Macron"));
+    }
+
+    @Override
+    protected WebSearchEngine searchEngine() {
+        return googleSearchEngine;
+    }
+}
diff --git a/web-search-engines/langchain4j-web-search-engine-google-custom/src/test/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchToolIT.java b/web-search-engines/langchain4j-web-search-engine-google-custom/src/test/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchToolIT.java
new file mode 100644
index 0000000000..3c1bd5c387
--- /dev/null
+++ b/web-search-engines/langchain4j-web-search-engine-google-custom/src/test/java/dev/langchain4j/web/search/google/customsearch/GoogleCustomWebSearchToolIT.java
@@
-0,0 +1,170 @@ +package dev.langchain4j.web.search.google.customsearch; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.agent.tool.ToolSpecifications; +import dev.langchain4j.data.message.*; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.model.openai.OpenAiChatModelName; +import dev.langchain4j.service.AiServices; +import dev.langchain4j.web.search.WebSearchEngine; +import dev.langchain4j.web.search.WebSearchTool; +import dev.langchain4j.web.search.WebSearchToolIT; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".*") +@EnabledIfEnvironmentVariable(named = "GOOGLE_SEARCH_ENGINE_ID", matches = ".*") +class GoogleCustomWebSearchToolIT extends WebSearchToolIT { + + WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.withApiKeyAndCsi( + System.getenv("GOOGLE_API_KEY"), + System.getenv("GOOGLE_SEARCH_ENGINE_ID")); + + ChatLanguageModel chatModel = OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .modelName(OpenAiChatModelName.GPT_3_5_TURBO) + .logRequests(true) + .build(); + + interface Assistant { + @dev.langchain4j.service.SystemMessage({ + "You are a web search support agent.", + "If there is any event that has not happened yet", + "You MUST create a web search request with with user query and", + "use the web search tool to search the web for organic web results.", + "Include the source link in your final response." + }) + String answer(String userMessage); + } + + @Test + void should_execute_google_tool_with_chatLanguageModel_to_give_a_final_response(){ + // given + googleSearchEngine = GoogleCustomWebSearchEngine.builder() + .apiKey(System.getenv("GOOGLE_API_KEY")) + .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID")) + .logRequests(true) + .logResponses(true) + .maxRetries(3) + .build(); + + WebSearchTool webSearchTool = WebSearchTool.from(googleSearchEngine); + List tools = ToolSpecifications.toolSpecificationsFrom(webSearchTool); + String query = "What are the release dates for the movies coming out last week of May 2024?"; + List messages = new ArrayList<>(); + SystemMessage systemMessage = SystemMessage.from("You are a web search support agent. If there is any event that has not happened yet, you MUST use a web search tool to look up the information on the web. Include the source link in your final response. 
Do not say that you have not the capability to browse the web in real time"); + messages.add(systemMessage); + UserMessage userMessage = UserMessage.from(query); + messages.add(userMessage); + // when + AiMessage aiMessage = chatLanguageModel().generate(messages, tools).content(); + + // then + assertThat(aiMessage.hasToolExecutionRequests()).isTrue(); + assertThat(aiMessage.toolExecutionRequests()) + .anySatisfy(toolSpec -> { + assertThat(toolSpec.name()) + .containsIgnoringCase("searchWeb"); + assertThat(toolSpec.arguments()) + .isNotBlank(); + } + ); + messages.add(aiMessage); + + // when + String strResult = webSearchTool.searchWeb(query); + ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(aiMessage.toolExecutionRequests().get(0), strResult); + messages.add(toolExecutionResultMessage); + + AiMessage finalResponse = chatLanguageModel().generate(messages).content(); + System.out.println(finalResponse.text()); + + // then + assertThat(finalResponse.text()) + .as("At least the string result should be contains 'movies' and 'coming soon' ignoring case") + .containsIgnoringCase("movies") + .containsIgnoringCase("May 2024"); + } + + @Test + void should_execute_google_tool_with_chatLanguageModel_to_summary_response_in_images() { + // given + googleSearchEngine = GoogleCustomWebSearchEngine.builder() + .apiKey(System.getenv("GOOGLE_API_KEY")) + .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID")) + .logRequests(true) + .logResponses(true) + .build(); + + WebSearchTool webSearchTool = WebSearchTool.from(googleSearchEngine); + List tools = ToolSpecifications.toolSpecificationsFrom(webSearchTool); + String query = "My family is coming to visit me in Madrid next week, list the best tourist activities suitable for the whole family"; + List messages = new ArrayList<>(); + SystemMessage systemMessage = SystemMessage.from("You are a web search support agent. If there is any event that has not happened yet, you MUST use a web search tool to look up the information on the web. Include the source link in your final response and the image urls. 
Do not say that you have not the capability to browse the web in real time"); + messages.add(systemMessage); + UserMessage userMessage = UserMessage.from(query); + messages.add(userMessage); + // when + AiMessage aiMessage = chatLanguageModel().generate(messages, tools).content(); + + // then + assertThat(aiMessage.hasToolExecutionRequests()).isTrue(); + assertThat(aiMessage.toolExecutionRequests()) + .anySatisfy(toolSpec -> { + assertThat(toolSpec.name()) + .containsIgnoringCase("searchWeb"); + assertThat(toolSpec.arguments()) + .isNotBlank(); + } + ); + messages.add(aiMessage); + + // when + String strResult = webSearchTool.searchWeb(query); + ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(aiMessage.toolExecutionRequests().get(0), strResult); + messages.add(toolExecutionResultMessage); + + AiMessage finalResponse = chatLanguageModel().generate(messages).content(); + System.out.println(finalResponse.text()); + + // then + assertThat(finalResponse.text()) + .as("At least the string result should be contains 'madrid' and 'tourist' ignoring case") + .containsIgnoringCase("Madrid") + .containsIgnoringCase("Royal Palace"); + } + + @Test + void should_execute_google_tool_with_AiServices() { + // given + WebSearchTool webTool = WebSearchTool.from(googleSearchEngine); + + Assistant assistant = AiServices.builder(Assistant.class) + .chatLanguageModel(chatModel) + .tools(webTool) + .chatMemory(MessageWindowChatMemory.withMaxMessages(10)) + .build(); + // when + String answer = assistant.answer("Search in the web who won the FIFA World Cup 2022?"); + + // then + assertThat(answer).containsIgnoringCase("Argentina"); + } + + @Override + protected WebSearchEngine searchEngine() { + return googleSearchEngine; + } + + @Override + protected ChatLanguageModel chatLanguageModel() { + return chatModel; + } +} diff --git a/web-search-engines/langchain4j-web-search-engine-tavily/pom.xml b/web-search-engines/langchain4j-web-search-engine-tavily/pom.xml new file mode 100644 index 0000000000..ff105e2863 --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-tavily/pom.xml @@ -0,0 +1,79 @@ + + + 4.0.0 + + + dev.langchain4j + langchain4j-parent + 0.32.0-SNAPSHOT + ../../langchain4j-parent/pom.xml + + + langchain4j-web-search-engine-tavily + jar + + LangChain4j :: Web Search Engine :: Tavily + + + + + dev.langchain4j + langchain4j-core + + + + com.squareup.retrofit2 + retrofit + + + + com.squareup.retrofit2 + converter-gson + + + + com.squareup.okhttp3 + okhttp + + + + org.projectlombok + lombok + provided + + + + org.junit.jupiter + junit-jupiter + test + + + + org.assertj + assertj-core + test + + + + + dev.langchain4j + langchain4j-core + tests + test-jar + test + + + + + + + Apache-2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + A business-friendly OSS license + + + + \ No newline at end of file diff --git a/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyApi.java b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyApi.java new file mode 100644 index 0000000000..dfdd6dc65e --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyApi.java @@ -0,0 +1,14 @@ +package dev.langchain4j.web.search.tavily; + +import retrofit2.Call; +import retrofit2.http.Body; +import retrofit2.http.Headers; +import retrofit2.http.POST; + +interface TavilyApi { + + 
    @POST("/search")
+    @Headers({"Content-Type: application/json"})
+    Call<TavilyResponse> search(@Body TavilySearchRequest request);
+}
+
diff --git a/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyClient.java b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyClient.java
new file mode 100644
index 0000000000..2571fb8b08
--- /dev/null
+++ b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyClient.java
@@ -0,0 +1,64 @@
+package dev.langchain4j.web.search.tavily;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import lombok.Builder;
+import okhttp3.OkHttpClient;
+import retrofit2.Response;
+import retrofit2.Retrofit;
+import retrofit2.converter.gson.GsonConverterFactory;
+
+import java.io.IOException;
+import java.time.Duration;
+
+import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;
+
+class TavilyClient {
+
+    private static final Gson GSON = new GsonBuilder()
+            .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
+            .setPrettyPrinting()
+            .create();
+
+    private final TavilyApi tavilyApi;
+
+    @Builder
+    public TavilyClient(String baseUrl, Duration timeout) {
+
+        OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder()
+                .callTimeout(timeout)
+                .connectTimeout(timeout)
+                .readTimeout(timeout)
+                .writeTimeout(timeout);
+
+        Retrofit retrofit = new Retrofit.Builder()
+                .baseUrl(baseUrl)
+                .client(okHttpClientBuilder.build())
+                .addConverterFactory(GsonConverterFactory.create(GSON))
+                .build();
+
+        this.tavilyApi = retrofit.create(TavilyApi.class);
+    }
+
+    public TavilyResponse search(TavilySearchRequest searchRequest) {
+        try {
+            Response<TavilyResponse> retrofitResponse = tavilyApi
+                    .search(searchRequest)
+                    .execute();
+            if (retrofitResponse.isSuccessful()) {
+                return retrofitResponse.body();
+            } else {
+                throw toException(retrofitResponse);
+            }
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private static RuntimeException toException(Response<?> response) throws IOException {
+        int code = response.code();
+        String body = response.errorBody().string();
+        String errorMessage = String.format("status code: %s; body: %s", code, body);
+        return new RuntimeException(errorMessage);
+    }
+}
diff --git a/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyResponse.java b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyResponse.java
new file mode 100644
index 0000000000..7b8600307e
--- /dev/null
+++ b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyResponse.java
@@ -0,0 +1,18 @@
+package dev.langchain4j.web.search.tavily;
+
+import lombok.Builder;
+import lombok.Getter;
+
+import java.util.List;
+
+@Builder
+@Getter
+class TavilyResponse {
+
+    private String answer;
+    private String query;
+    private Double responseTime;
+    private List<String> images;
+    private List<String> followUpQuestions;
+    private List<TavilySearchResult> results;
+}
diff --git a/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilySearchRequest.java b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilySearchRequest.java
new file mode 100644
index 0000000000..f2d5c9deca
--- /dev/null
+++
b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilySearchRequest.java @@ -0,0 +1,20 @@ +package dev.langchain4j.web.search.tavily; + +import lombok.Builder; +import lombok.Getter; + +import java.util.List; + +@Getter +@Builder +class TavilySearchRequest { + + private String apiKey; + private String query; + private String searchDepth; + private Boolean includeAnswer; + private Boolean includeRawContent; + private Integer maxResults; + private List includeDomains; + private List excludeDomains; +} diff --git a/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilySearchResult.java b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilySearchResult.java new file mode 100644 index 0000000000..490daefaf4 --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilySearchResult.java @@ -0,0 +1,15 @@ +package dev.langchain4j.web.search.tavily; + +import lombok.Builder; +import lombok.Getter; + +@Getter +@Builder +class TavilySearchResult { + + private String title; + private String url; + private String content; + private String rawContent; + private Double score; +} diff --git a/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyWebSearchEngine.java b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyWebSearchEngine.java new file mode 100644 index 0000000000..dafcc7c401 --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-tavily/src/main/java/dev/langchain4j/web/search/tavily/TavilyWebSearchEngine.java @@ -0,0 +1,107 @@ +package dev.langchain4j.web.search.tavily; + +import dev.langchain4j.web.search.*; +import lombok.Builder; + +import java.net.URI; +import java.time.Duration; +import java.util.Collections; +import java.util.List; + +import static dev.langchain4j.internal.Utils.copyIfNotNull; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static java.time.Duration.ofSeconds; +import static java.util.stream.Collectors.toList; + +/** + * Represents Tavily Search API as a {@code WebSearchEngine}. + * See more details here. + *
+ * When {@link #includeRawContent} is set to {@code true},
+ * the raw content will appear in the {@link WebSearchOrganicResult#content()} field of each result.
+ *
    + * When {@link #includeAnswer} is set to {@code true}, + * the answer will appear in the {@link WebSearchOrganicResult#snippet()} field of the first result. + * In this case, the {@link WebSearchOrganicResult#url()} of the first result will always be "https://tavily.com/" and + * the {@link WebSearchOrganicResult#title()} will always be "Tavily Search API". + */ +public class TavilyWebSearchEngine implements WebSearchEngine { + + private static final String DEFAULT_BASE_URL = "https://api.tavily.com"; + + private final String apiKey; + private final TavilyClient tavilyClient; + private final String searchDepth; + private final Boolean includeAnswer; + private final Boolean includeRawContent; + private final List includeDomains; + private final List excludeDomains; + + @Builder + public TavilyWebSearchEngine(String baseUrl, + String apiKey, + Duration timeout, + String searchDepth, + Boolean includeAnswer, + Boolean includeRawContent, + List includeDomains, + List excludeDomains) { + this.tavilyClient = TavilyClient.builder() + .baseUrl(getOrDefault(baseUrl, DEFAULT_BASE_URL)) + .timeout(getOrDefault(timeout, ofSeconds(10))) + .build(); + this.apiKey = ensureNotBlank(apiKey, "apiKey"); + this.searchDepth = searchDepth; + this.includeAnswer = includeAnswer; + this.includeRawContent = includeRawContent; + this.includeDomains = copyIfNotNull(includeDomains); + this.excludeDomains = copyIfNotNull(excludeDomains); + } + + @Override + public WebSearchResults search(WebSearchRequest webSearchRequest) { + + TavilySearchRequest request = TavilySearchRequest.builder() + .apiKey(apiKey) + .query(webSearchRequest.searchTerms()) + .searchDepth(searchDepth) + .includeAnswer(includeAnswer) + .includeRawContent(includeRawContent) + .maxResults(webSearchRequest.maxResults()) + .includeDomains(includeDomains) + .excludeDomains(excludeDomains) + .build(); + + TavilyResponse tavilyResponse = tavilyClient.search(request); + + final List results = tavilyResponse.getResults().stream() + .map(TavilyWebSearchEngine::toWebSearchOrganicResult) + .collect(toList()); + + if (tavilyResponse.getAnswer() != null) { + WebSearchOrganicResult answerResult = WebSearchOrganicResult.from( + "Tavily Search API", + URI.create("https://tavily.com/"), + tavilyResponse.getAnswer(), + null + ); + results.add(0, answerResult); + } + + return WebSearchResults.from(WebSearchInformationResult.from((long) results.size()), results); + } + + public static TavilyWebSearchEngine withApiKey(String apiKey) { + return builder().apiKey(apiKey).build(); + } + + private static WebSearchOrganicResult toWebSearchOrganicResult(TavilySearchResult tavilySearchResult) { + return WebSearchOrganicResult.from(tavilySearchResult.getTitle(), + URI.create(tavilySearchResult.getUrl()), + tavilySearchResult.getContent(), + tavilySearchResult.getRawContent(), + Collections.singletonMap("score", String.valueOf(tavilySearchResult.getScore()))); + } +} + diff --git a/web-search-engines/langchain4j-web-search-engine-tavily/src/test/java/dev/langchain4j/web/search/tavily/TavilyWebSearchEngineIT.java b/web-search-engines/langchain4j-web-search-engine-tavily/src/test/java/dev/langchain4j/web/search/tavily/TavilyWebSearchEngineIT.java new file mode 100644 index 0000000000..d8a9691087 --- /dev/null +++ b/web-search-engines/langchain4j-web-search-engine-tavily/src/test/java/dev/langchain4j/web/search/tavily/TavilyWebSearchEngineIT.java @@ -0,0 +1,87 @@ +package dev.langchain4j.web.search.tavily; + +import dev.langchain4j.web.search.WebSearchEngine; +import 
dev.langchain4j.web.search.WebSearchEngineIT; +import dev.langchain4j.web.search.WebSearchOrganicResult; +import dev.langchain4j.web.search.WebSearchResults; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.net.URI; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "TAVILY_API_KEY", matches = ".+") +class TavilyWebSearchEngineIT extends WebSearchEngineIT { + + WebSearchEngine webSearchEngine = TavilyWebSearchEngine.withApiKey(System.getenv("TAVILY_API_KEY")); + + @Test + void should_search_with_raw_content() { + + // given + TavilyWebSearchEngine tavilyWebSearchEngine = TavilyWebSearchEngine.builder() + .apiKey(System.getenv("TAVILY_API_KEY")) + .includeRawContent(true) + .build(); + + // when + WebSearchResults webSearchResults = tavilyWebSearchEngine.search("What is LangChain4j?"); + + // then + List results = webSearchResults.results(); + + results.forEach(result -> { + assertThat(result.title()).isNotBlank(); + assertThat(result.url()).isNotNull(); + assertThat(result.snippet()).isNotBlank(); + assertThat(result.content()).isNotBlank(); + assertThat(result.metadata()).containsOnlyKeys("score"); + }); + + assertThat(results).anyMatch(result -> + result.url().toString().contains("https://github.com/langchain4j") + && result.content().contains("How to get an API key") + ); + } + + @Test + void should_search_with_answer() { + + // given + TavilyWebSearchEngine tavilyWebSearchEngine = TavilyWebSearchEngine.builder() + .apiKey(System.getenv("TAVILY_API_KEY")) + .includeAnswer(true) + .build(); + + // when + WebSearchResults webSearchResults = tavilyWebSearchEngine.search("What is LangChain4j?"); + + // then + List results = webSearchResults.results(); + assertThat(results).hasSize(5 + 1); // +1 for answer + + WebSearchOrganicResult answerResult = results.get(0); + assertThat(answerResult.title()).isEqualTo("Tavily Search API"); + assertThat(answerResult.url()).isEqualTo(URI.create("https://tavily.com/")); + assertThat(answerResult.snippet()).isNotBlank(); + assertThat(answerResult.content()).isNull(); + assertThat(answerResult.metadata()).isNull(); + + results.subList(1, results.size()).forEach(result -> { + assertThat(result.title()).isNotBlank(); + assertThat(result.url()).isNotNull(); + assertThat(result.snippet()).isNotBlank(); + assertThat(result.content()).isNull(); + assertThat(result.metadata()).containsOnlyKeys("score"); + }); + + assertThat(results).anyMatch(result -> result.url().toString().contains("https://github.com/langchain4j")); + } + + @Override + protected WebSearchEngine searchEngine() { + return webSearchEngine; + } +} \ No newline at end of file
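
Note for reviewers: below is a minimal usage sketch of the new Tavily integration, assembled from the builder, javadoc and integration test shown in this diff. The TavilyUsageExample class name, the use of the TAVILY_API_KEY environment variable and the sample query are illustrative only and are not part of the change itself.

import dev.langchain4j.web.search.WebSearchOrganicResult;
import dev.langchain4j.web.search.WebSearchResults;
import dev.langchain4j.web.search.tavily.TavilyWebSearchEngine;

public class TavilyUsageExample {

    public static void main(String[] args) {
        // Build the engine; includeAnswer(true) asks Tavily to prepend a synthesized answer
        // as the first result (title "Tavily Search API", url "https://tavily.com/").
        TavilyWebSearchEngine tavily = TavilyWebSearchEngine.builder()
                .apiKey(System.getenv("TAVILY_API_KEY")) // illustrative: any valid Tavily API key
                .includeAnswer(true)
                .build();

        WebSearchResults webSearchResults = tavily.search("What is LangChain4j?");

        // With includeAnswer enabled, the answer is carried in the snippet of the first result.
        WebSearchOrganicResult answer = webSearchResults.results().get(0);
        System.out.println("Answer: " + answer.snippet());

        // The remaining entries are regular organic results carrying a "score" metadata entry.
        webSearchResults.results().stream()
                .skip(1)
                .forEach(result -> System.out.println(result.title() + " -> " + result.url()));
    }
}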