Provide an example of CSV loading
cescoffier committed Nov 24, 2023
1 parent 71e181a commit 22c98b5
Showing 19 changed files with 2,266 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/modules/ROOT/nav.adoc
@@ -15,6 +15,7 @@
** xref:redis-store.adoc[Redis Store]
** xref:chroma-store.adoc[Chroma Store]
** xref:in-process-embedding.adoc[In-Process Embeddings]
** xref:csv.adoc[Loading CSV files]
* Advanced topics
** xref:fault-tolerance.adoc[Fault Tolerance]
129 changes: 129 additions & 0 deletions docs/modules/ROOT/pages/csv.adoc
@@ -0,0 +1,129 @@
= Loading CSV files

include::./includes/attributes.adoc[]

When working with Retrieval Augmented Generation (RAG), it is often necessary to load tabular data, such as a CSV file. This guide provides recommendations for loading CSV files so that their content can be retrieved by a RAG pipeline.

When loading a CSV file, the process involves:

1. Transforming each row into a *document*.
2. Ingesting the set of documents using an appropriate *document splitter*.
3. Storing the documents in the database.

You can find a complete example in the https://github.com/quarkiverse/quarkus-langchain4j/tree/main/samples/csv-chatbot[GitHub repository].

== From CSV to Documents

There are multiple ways to load CSV files in Java. In this example, we use Apache Commons CSV, which requires the following dependency:

[source,xml]
----
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
<version>1.10.0</version>
</dependency>
----

You can choose a different library; the APIs are similar enough.

Once you have the dependency, load the CSV and process the rows:

[source, java]
----
/**
 * The CSV file to load.
 */
@ConfigProperty(name = "csv.file")
File file;

/**
 * The CSV file headers.
 * Some libraries provide an API to extract them.
 */
@ConfigProperty(name = "csv.headers")
List<String> headers;

/**
 * Ingest the CSV file.
 * This method is executed when the application starts.
 */
public void ingest(@Observes StartupEvent event) throws IOException {
    // Configure the CSV format.
    CSVFormat csvFormat = CSVFormat.DEFAULT.builder()
            .setHeader(headers.toArray(new String[0]))
            .setSkipHeaderRecord(true)
            .build();
    // This will be the resulting list of documents:
    List<Document> documents = new ArrayList<>();
    try (Reader reader = new FileReader(file)) {
        // Generate one document per row, using the specified syntax.
        Iterable<CSVRecord> records = csvFormat.parse(reader);
        int i = 1;
        for (CSVRecord record : records) {
            Map<String, String> metadata = new HashMap<>();
            metadata.put("source", file.getAbsolutePath());
            metadata.put("row", String.valueOf(i++));
            StringBuilder content = new StringBuilder();
            for (String header : headers) {
                // Include all headers in the metadata.
                metadata.put(header, record.get(header));
                content.append(header).append(": ").append(record.get(header)).append("\n");
            }
            documents.add(new Document(content.toString(), Metadata.from(metadata)));
        }
        // ...
    }
----
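
To make the text of each generated document concrete, here is a minimal sketch that reproduces the `header: value` rendering from the loop above for a single row. It uses only the JDK, not Commons CSV or LangChain4j; the class name `RowRenderer` and the sample headers and values are hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class RowRenderer {

    /** Renders one row as "header: value" lines, following the header order. */
    public static String render(List<String> headers, Map<String, String> row) {
        StringBuilder content = new StringBuilder();
        for (String header : headers) {
            content.append(header).append(": ").append(row.get(header)).append("\n");
        }
        return content.toString();
    }

    public static void main(String[] args) {
        // Hypothetical row, standing in for a CSVRecord.
        Map<String, String> row = new LinkedHashMap<>();
        row.put("title", "Alien");
        row.put("year", "1979");
        System.out.print(render(List.of("title", "year"), row));
    }
}
```

Keeping the header name next to each value gives the embedding model some context about what each value means, which is presumably why the example uses this layout rather than the raw comma-separated line.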

== Ingesting the Documents

Once you have the list of documents, they need to be ingested. For this, use a *document splitter*. We recommend the `recursive` splitter, which divides each document into chunks of at most a given size (500 characters with no overlap in the example below). While it may not be the most suitable splitter for your use case, it serves as a good starting point.

[source, java]
----
var ingestor = EmbeddingStoreIngestor.builder()
        .embeddingStore(store) // Injected
        .embeddingModel(embeddingModel) // Injected
        .documentSplitter(recursive(500, 0))
        .build();
ingestor.ingest(documents);
----
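
The two arguments passed to `recursive(500, 0)` are a maximum segment size and an overlap. The sketch below illustrates what these two numbers mean using a deliberately naive character-based chunker. It is not the library's algorithm (the recursive splitter also tries to respect text boundaries), and the class name `NaiveChunker` is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

public class NaiveChunker {

    /**
     * Splits text into chunks of at most maxSize characters.
     * Consecutive chunks share `overlap` characters.
     */
    public static List<String> split(String text, int maxSize, int overlap) {
        if (maxSize <= 0 || overlap < 0 || overlap >= maxSize) {
            throw new IllegalArgumentException("require 0 <= overlap < maxSize");
        }
        List<String> chunks = new ArrayList<>();
        int step = maxSize - overlap;
        for (int start = 0; start < text.length(); start += step) {
            int end = Math.min(start + maxSize, text.length());
            chunks.add(text.substring(start, end));
            if (end == text.length()) {
                break; // The last chunk has been emitted.
            }
        }
        return chunks;
    }

    public static void main(String[] args) {
        System.out.println(split("abcdefghij", 4, 0)); // [abcd, efgh, ij]
        System.out.println(split("abcdefghij", 4, 1)); // [abcd, defg, ghij]
    }
}
```

With one short document per CSV row, most rows are likely to fit in a single chunk, so an overlap of 0 is a reasonable default; a non-zero overlap mainly helps when long rows are cut in the middle of a field.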

== Implementing the Retriever

With the documents ingested, you can now implement the retriever:

[source, java]
----
package io.quarkiverse.langchain4j.sample.chatbot;

import java.util.List;

import jakarta.enterprise.context.ApplicationScoped;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.retriever.Retriever;
import io.quarkiverse.langchain4j.redis.RedisEmbeddingStore;

@ApplicationScoped
public class RetrieverExample implements Retriever<TextSegment> {

    private final EmbeddingStoreRetriever retriever;

    RetrieverExample(RedisEmbeddingStore store, EmbeddingModel model) {
        // Limit the number of documents to avoid exceeding the context size.
        retriever = EmbeddingStoreRetriever.from(store, model, 10);
    }

    @Override
    public List<TextSegment> findRelevant(String s) {
        return retriever.findRelevant(s);
    }
}
----
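
Conceptually, the retriever embeds the question and returns the stored segments whose embeddings are closest to it, capped at the configured maximum (10 above). The following JDK-only sketch illustrates that idea with an in-memory map and cosine similarity. Both are assumptions made for illustration: the actual similarity metric and search are handled by the embedding store, and the class name `TopKSketch` and the sample vectors are hypothetical.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class TopKSketch {

    /** Cosine similarity between two vectors of the same length. */
    public static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /** Returns the keys of the maxResults entries most similar to the query. */
    public static List<String> findRelevant(double[] query, Map<String, double[]> store, int maxResults) {
        List<Map.Entry<String, double[]>> entries = new ArrayList<>(store.entrySet());
        // Sort by descending similarity to the query.
        entries.sort(Comparator.comparingDouble(
                (Map.Entry<String, double[]> e) -> cosine(query, e.getValue())).reversed());
        List<String> result = new ArrayList<>();
        for (Map.Entry<String, double[]> e : entries.subList(0, Math.min(maxResults, entries.size()))) {
            result.add(e.getKey());
        }
        return result;
    }

    public static void main(String[] args) {
        // Hypothetical two-dimensional "embeddings".
        Map<String, double[]> store = new LinkedHashMap<>();
        store.put("a", new double[] {1, 0});
        store.put("b", new double[] {0, 1});
        store.put("c", new double[] {1, 1});
        System.out.println(findRelevant(new double[] {1, 0.1}, store, 2)); // [a, c]
    }
}
```

The cap matters because every returned segment is added to the prompt; a lower value leaves more room in the context window for the conversation itself.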

1 change: 1 addition & 0 deletions pom.xml
@@ -105,6 +105,7 @@
<module>samples/review-triage</module>
<module>samples/fraud-detection</module>
<module>samples/chatbot</module>
<module>samples/csv-chatbot</module>
</modules>
</profile>
</profiles>
160 changes: 160 additions & 0 deletions samples/csv-chatbot/pom.xml
@@ -0,0 +1,160 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-parent</artifactId>
<version>999-SNAPSHOT</version>
<relativePath>../..</relativePath>
</parent>

<artifactId>quarkus-langchain4j-sample-csv-chatbot</artifactId>
<name>Quarkus langchain4j - Sample - Chatbot &amp; RAG loading a CSV file</name>

<dependencies>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-resteasy-reactive-jackson</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-websockets</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-openai</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-redis</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
<version>1.10.0</version>
</dependency>


<!-- UI -->
<dependency>
<groupId>org.mvnpm</groupId>
<artifactId>importmap</artifactId>
<version>1.0.8</version>
</dependency>
<dependency>
<groupId>org.mvnpm.at.mvnpm</groupId>
<artifactId>vaadin-webcomponents</artifactId>
<version>24.2.1</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.mvnpm</groupId>
<artifactId>es-module-shims</artifactId>
<scope>runtime</scope>
<version>1.8.2</version>
</dependency>
<dependency>
<groupId>org.mvnpm</groupId>
<artifactId>wc-chatbot</artifactId>
<version>0.1.2</version>
<scope>runtime</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-maven-plugin</artifactId>
<version>${quarkus.version}</version>
<executions>
<execution>
<goals>
<goal>build</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.2.2</version>
<configuration>
<systemPropertyVariables>
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
<maven.home>${maven.home}</maven.home>
</systemPropertyVariables>
</configuration>
</plugin>
</plugins>
</build>

<profiles>
<profile>
<id>native</id>
<activation>
<property>
<name>native</name>
</property>
</activation>
<build>
<plugins>
<plugin>
<artifactId>maven-failsafe-plugin</artifactId>
<version>3.2.2</version>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
<configuration>
<systemPropertyVariables>
<native.image.path>
${project.build.directory}/${project.build.finalName}-runner
</native.image.path>
<java.util.logging.manager>org.jboss.logmanager.LogManager
</java.util.logging.manager>
<maven.home>${maven.home}</maven.home>
</systemPropertyVariables>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
<properties>
<quarkus.package.type>native</quarkus.package.type>
</properties>
</profile>

<profile>
<id>mvnpm</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<repositories>
<repository>
<id>central</id>
<name>central</name>
<url>https://repo.maven.apache.org/maven2</url>
</repository>
<repository>
<snapshots>
<enabled>false</enabled>
</snapshots>
<id>mvnpm.org</id>
<name>mvnpm</name>
<url>https://repo.mvnpm.org/maven2</url>
</repository>
</repositories>
</profile>
</profiles>

</project>
@@ -0,0 +1,50 @@
package io.quarkiverse.langchain4j.sample.chatbot;

import java.io.IOException;

import jakarta.inject.Inject;
import jakarta.websocket.*;
import jakarta.websocket.server.ServerEndpoint;

import io.smallrye.mutiny.infrastructure.Infrastructure;

@ServerEndpoint("/chatbot")
public class ChatBotWebSocket {

    @Inject
    MovieMuse bot;

    @Inject
    ChatMemoryBean chatMemoryBean;

    @OnOpen
    public void onOpen(Session session) {
        Infrastructure.getDefaultExecutor().execute(() -> {
            String response = bot.chat(session, "hello");
            try {
                session.getBasicRemote().sendText(response);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        });
    }

    @OnClose
    void onClose(Session session) {
        chatMemoryBean.clear(session);
    }

    @OnMessage
    public void onMessage(String message, Session session) {
        Infrastructure.getDefaultExecutor().execute(() -> {
            String response = bot.chat(session, message);
            try {
                session.getBasicRemote().sendText(response);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        });

    }

}
@@ -0,0 +1,28 @@
package io.quarkiverse.langchain4j.sample.chatbot;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import jakarta.enterprise.context.ApplicationScoped;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;

@ApplicationScoped
public class ChatMemoryBean implements ChatMemoryProvider {

    private final Map<Object, ChatMemory> memories = new ConcurrentHashMap<>();

    @Override
    public ChatMemory get(Object memoryId) {
        return memories.computeIfAbsent(memoryId, id -> MessageWindowChatMemory.builder()
                .maxMessages(10)
                .id(memoryId)
                .build());
    }

    public void clear(Object session) {
        memories.remove(session);
    }
}