diff --git a/Cargo.toml b/Cargo.toml index 089bf577..9490583f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,8 @@ readability = "0.3.0" url = "2.5.0" fastembed = "3" gix = { version = "0.61.0", default-features = false, optional = true, features = ["parallel", "revision", "serde"] } +opensearch = { version = "2", optional = true, features = ["aws-auth"] } +aws-config = { version = "1.1", optional = true, features = ["behavior-version-latest"] } [features] default = [] @@ -50,6 +52,7 @@ postgres = ["pgvector", "sqlx", "uuid"] surrealdb = ["dep:surrealdb"] sqlite = ["sqlx"] git = ["gix"] +opensearch = ["dep:opensearch", "aws-config"] [dev-dependencies] tokio-test = "0.4.4" diff --git a/README.md b/README.md index b3e82042..420d7c47 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ This is the Rust language implementation of [LangChain](https://github.com/langc - VectorStores + - [x] [OpenSearch](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/vector_store_opensearch.rs) - [x] [Postgres](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/vector_store_postgres.rs) - [x] [Sqlite](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/vector_store_sqlite.rs) - [x] [SurrealDB](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/vector_store_surrealdb/src/main.rs) diff --git a/examples/vector_store_opensearch.rs b/examples/vector_store_opensearch.rs new file mode 100644 index 00000000..fe3f9d41 --- /dev/null +++ b/examples/vector_store_opensearch.rs @@ -0,0 +1,142 @@ +// To run this example execute: cargo run --example vector_store_opensearch --features opensearch + +#[cfg(feature = "opensearch")] use aws_config::SdkConfig; +use langchain_rust::vectorstore::{VecStoreOptions, VectorStore}; +#[cfg(feature = "opensearch")] +use langchain_rust::{ embedding::openai::openai_embedder::OpenAiEmbedder, schemas::Document, vectorstore::opensearch::Store, vectorstore::opensearch::*, }; +use serde_json::json; +use std::collections::HashMap; 
+use std::error::Error; +use std::io::Write; +use url::Url; + +#[cfg(feature = "opensearch")] +#[tokio::main] +async fn main() { + // Initialize Embedder + let embedder = OpenAiEmbedder::default(); + + /* In this example we use an opensearch instance running on localhost (docker): + + docker run --rm -it -p 9200:9200 -p 9600:9600 \ + -e "discovery.type=single-node" \ + -e "node.name=localhost" \ + -e "OPENSEARCH_INITIAL_ADMIN_PASSWORD=ZxPZ3cZL0ky1bzVu+~N" \ + opensearchproject/opensearch:latest + + */ + + let opensearch_host = "https://localhost:9200"; + let opensearch_index = "test"; + + let url = Url::parse(opensearch_host).unwrap(); + let conn_pool = SingleNodeConnectionPool::new(url); + let transport = TransportBuilder::new(conn_pool) + .disable_proxy() + .auth(Credentials::Basic( + "admin".to_string(), + "ZxPZ3cZL0ky1bzVu+~N".to_string(), + )) + .cert_validation(CertificateValidation::None) + .build() + .unwrap(); + let client = OpenSearch::new(transport); + + // We could also use an AOSS instance: + // std::env::set_var("AWS_PROFILE", "xxx"); + // use aws_config::BehaviorVersion; + // let sdk_config = aws_config::load_defaults(BehaviorVersion::latest()).await; + // let aoss_host = "https://blahblah.eu-central-1.aoss.amazonaws.com/"; + // let client = build_aoss_client(&sdk_config, aoss_host).unwrap(); + + let store = StoreBuilder::new() + .embedder(embedder) + .index(opensearch_index) + .content_field("the_content_field") + .vector_field("the_vector_field") + .client(client) + .build() + .await + .unwrap(); + + let _ = store.delete_index().await; + store.create_index().await.unwrap(); + let added_ids = add_documents_to_index(&store).await.unwrap(); + for id in added_ids { + println!("added document with id: {id}"); + } + // it can take a while before the documents are actually available in the index... 
+ + // Ask for user input + print!("Query> "); + std::io::stdout().flush().unwrap(); + let mut query = String::new(); + std::io::stdin().read_line(&mut query).unwrap(); + + let results = store + .similarity_search(&query, 2, &VecStoreOptions::default()) + .await + .unwrap(); + + if results.is_empty() { + println!("No results found."); + return; + } else { + results.iter().for_each(|r| { + println!("Document: {}", r.page_content); + }); + } +} + +#[cfg(feature = "opensearch")] async fn add_documents_to_index(store: &Store) -> Result<Vec<String>, Box<dyn Error>> { + let doc1 = Document::new( + "langchain-rust is a port of the langchain python library to rust and was written in 2024.", + ) + .with_metadata(HashMap::from([("source".to_string(), json!("cli"))])); + + let doc2 = Document::new( + "langchaingo is a port of the langchain python library to go language and was written in 2023." + ); + + let doc3 = Document::new( + "Capital of United States of America (USA) is Washington D.C. and the capital of France is Paris." + ); + + let doc4 = Document::new("Capital of France is Paris."); + + let opts = VecStoreOptions { + name_space: None, + score_threshold: None, + filters: None, + embedder: Some(store.embedder.clone()), + }; + + let result = store + .add_documents(&vec![doc1, doc2, doc3, doc4], &opts) + .await?; + + Ok(result) +} + +#[allow(dead_code)] +#[cfg(feature = "opensearch")] fn build_aoss_client(sdk_config: &SdkConfig, host: &str) -> Result<OpenSearch, Box<dyn Error>> { + let opensearch_url = Url::parse(host)?; + let conn_pool = SingleNodeConnectionPool::new(opensearch_url); + + let transport = TransportBuilder::new(conn_pool) + .auth(sdk_config.try_into()?) 
+ .service_name("aoss") + .build()?; + let client = OpenSearch::new(transport); + Ok(client) +} + +#[cfg(not(feature = "opensearch"))] +fn main() { + println!("This example requires the 'opensearch' feature to be enabled."); + println!("Please run the command as follows:"); + println!("cargo run --example vector_store_opensearch --features opensearch"); +} diff --git a/src/vectorstore/mod.rs b/src/vectorstore/mod.rs index 05ab03c0..67b923ce 100644 --- a/src/vectorstore/mod.rs +++ b/src/vectorstore/mod.rs @@ -9,6 +9,10 @@ pub mod sqlite_vss; #[cfg(feature = "surrealdb")] pub mod surrealdb; +#[cfg(feature = "opensearch")] +pub mod opensearch; + mod vectorstore; + pub use options::*; pub use vectorstore::*; diff --git a/src/vectorstore/opensearch/builder.rs b/src/vectorstore/opensearch/builder.rs new file mode 100644 index 00000000..a0fd5872 --- /dev/null +++ b/src/vectorstore/opensearch/builder.rs @@ -0,0 +1,82 @@ +use crate::embedding::Embedder; +use crate::vectorstore::opensearch::Store; +use opensearch::OpenSearch; +use std::error::Error; +use std::sync::Arc; + +pub struct StoreBuilder { + client: Option<OpenSearch>, + embedder: Option<Arc<dyn Embedder>>, + k: i32, + index: Option<String>, + vector_field: String, + content_field: String, +} + +impl StoreBuilder { + // Returns a new StoreBuilder instance with default values for each option + pub fn new() -> Self { + StoreBuilder { + client: None, + embedder: None, + k: 2, + index: None, + vector_field: "vector_field".to_string(), + content_field: "page_content".to_string(), + } + } + + pub fn client(mut self, client: OpenSearch) -> Self { + self.client = Some(client); + self + } + + pub fn embedder<E: Embedder + 'static>(mut self, embedder: E) -> Self { + self.embedder = Some(Arc::new(embedder)); + self + } + + pub fn k(mut self, k: i32) -> Self { + self.k = k; + self + } + + pub fn index(mut self, index: &str) -> Self { + self.index = Some(index.to_string()); + self + } + + pub fn vector_field(mut self, vector_field: &str) -> Self { + self.vector_field = 
vector_field.to_string(); + self + } + + pub fn content_field(mut self, content_field: &str) -> Self { + self.content_field = content_field.to_string(); + self + } + + // Finalize the builder and construct the Store object + pub async fn build(self) -> Result<Store, Box<dyn Error>> { + if self.client.is_none() { + return Err("Client is required".into()); + } + + if self.embedder.is_none() { + return Err("Embedder is required".into()); + } + + if self.index.is_none() { + return Err("Index is required".into()); + } + + Ok(Store { + client: self.client.unwrap(), + embedder: self.embedder.unwrap(), + k: self.k, + index: self.index.unwrap(), + vector_field: self.vector_field, + content_field: self.content_field, + }) + } +} diff --git a/src/vectorstore/opensearch/mod.rs b/src/vectorstore/opensearch/mod.rs new file mode 100644 index 00000000..9cede859 --- /dev/null +++ b/src/vectorstore/opensearch/mod.rs @@ -0,0 +1,5 @@ +mod builder; +mod opensearch; + +pub use builder::*; +pub use opensearch::*; diff --git a/src/vectorstore/opensearch/opensearch.rs b/src/vectorstore/opensearch/opensearch.rs new file mode 100644 index 00000000..77c052ca --- /dev/null +++ b/src/vectorstore/opensearch/opensearch.rs @@ -0,0 +1,247 @@ +use async_trait::async_trait; +use opensearch::http::request::JsonBody; +use opensearch::http::response::Response; +use opensearch::indices::{IndicesCreateParts, IndicesDeleteParts}; +use opensearch::{BulkParts, SearchParts}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::error::Error; +use std::sync::Arc; + +pub use opensearch::auth::Credentials; +pub use opensearch::cert::CertificateValidation; +pub use opensearch::http::transport::{SingleNodeConnectionPool, TransportBuilder}; +pub use opensearch::OpenSearch; + +use crate::{ embedding::embedder_trait::Embedder, schemas::Document, vectorstore::{VecStoreOptions, VectorStore}, }; + +pub struct Store { + pub client: OpenSearch, + pub embedder: Arc<dyn Embedder>, + pub k: i32, + pub index: String, + pub 
vector_field: String, + pub content_field: String, +} + +// https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/ +// https://opensearch.org/blog/efficient-filters-in-knn/ +// https://opensearch.org/docs/latest/clients/rust/ + +impl Store { + pub async fn delete_index(&self) -> Result<Response, Box<dyn Error>> { + let response = self + .client + .indices() + .delete(IndicesDeleteParts::Index(&[&self.index])) + .send() + .await?; + + let result = response.error_for_status_code().map_err(|e| Box::new(e))?; + + Ok(result) + } + + pub async fn create_index(&self) -> Result<Response, Box<dyn Error>> { + let body = json!({ + "settings": { + "index.knn": true, + "knn.algo_param": { + "ef_search": "512" + }, + }, + "mappings": { + "properties": { + &self.vector_field: { + "type": "knn_vector", + "dimension": 1536, + "method": { + "engine": "faiss", + "name": "hnsw", + "space_type": "l2", + "parameters": { + "ef_construction": 512, + "m": 16 + } + } + }, + &self.content_field: { + "type": "text" + }, + "metadata": { + "properties": { + "source": { + "type": "text", + } + } + } + } + } + }); + + let response = self + .client + .indices() + .create(IndicesCreateParts::Index(&self.index)) + .body(body) + .send() + .await?; + + let result = response.error_for_status_code().map_err(|e| Box::new(e))?; + + Ok(result) + } +} + +#[async_trait] +impl VectorStore for Store { + async fn add_documents( + &self, + docs: &[Document], + opt: &VecStoreOptions, + ) -> Result<Vec<String>, Box<dyn Error>> { + let texts: Vec<String> = docs.iter().map(|d| d.page_content.clone()).collect(); + let embedder = opt.embedder.as_ref().unwrap_or(&self.embedder); + let vectors = embedder.embed_documents(&texts).await?; + + if vectors.len() != docs.len() { + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Number of vectors and documents do not match", + ))); + } + + let mut body: Vec<JsonBody<Value>> = Vec::with_capacity(docs.len() * 2); + + for (doc, vector) in docs.iter().zip(vectors.iter()) { + let operation = json!({"index": {}}); + 
body.push(operation.into()); + + let document = json!({ + &self.content_field: doc.page_content, + "metadata": doc.metadata, + &self.vector_field: vector, + }); + body.push(document.into()); + } + + let response = self + .client + .bulk(BulkParts::Index(&self.index)) + .body(body) + .send() + .await? + .error_for_status_code() + .map_err(|e| Box::new(e))?; + + let response_body = response.json::<Value>().await?; + + let ids = response_body["items"] + .as_array() + .unwrap() + .iter() + .map(|item| serde_json::from_value::<String>(item["index"]["_id"].clone()).unwrap()) + .collect::<Vec<String>>(); + + Ok(ids) + } + + async fn similarity_search( + &self, + query: &str, + limit: usize, + opt: &VecStoreOptions, + ) -> Result<Vec<Document>, Box<dyn Error>> { + let query_vector = self.embedder.embed_query(query).await?; + let query = build_similarity_search_query( + query_vector, + &self.vector_field, + limit, + self.k, + opt.filters.clone(), + ); + + let response = self + .client + .search(SearchParts::Index(&[&self.index])) + .from(0) + .size(limit as i64) + .body(query) + .send() + .await?; + + let response_body = response.json::<Value>().await?; + + let aoss_documents = response_body["hits"]["hits"] + .as_array() + .unwrap() + .iter() + .map(|raw_value| { + serde_json::from_value::<HashMap<String, Value>>(raw_value.clone()).unwrap() + }) + .collect::<Vec<HashMap<String, Value>>>(); + + let documents = aoss_documents + .into_iter() + .map(|item| { + let page_content = + serde_json::from_value::<String>(item["_source"][&self.content_field].clone()) + .unwrap(); + let metadata = serde_json::from_value::<HashMap<String, Value>>( + item["_source"]["metadata"].clone(), + ) + .unwrap(); + let score = serde_json::from_value::<f64>(item["_score"].clone()).unwrap(); + Document { + page_content, + metadata, + score, + } + }) + .collect(); + + Ok(documents) + } +} + +fn build_similarity_search_query( + embedded_query: Vec<f64>, + vector_field: &str, + size: usize, + k: i32, + maybe_filter: Option<Value>, +) -> Value { + match maybe_filter { + Some(filter) => { + json!({ + "size": size, + "query": { + "knn": { + vector_field: { + "vector": 
embedded_query, + "k": k, + "filter": filter, + } + } + } + }) + } + None => { + json!({ + "size": size, + "query": { + "knn": { + vector_field: { + "vector": embedded_query, + "k": k, + } + } + } + }) + } + } +}