-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature: Add amazon opensearch serverless vector store (#118)
- Loading branch information
Showing
7 changed files
with
484 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
// To run this example execute: cargo run --example vector_store_opensearch --features opensearch | ||
|
||
use aws_config::SdkConfig; | ||
use langchain_rust::vectorstore::{VecStoreOptions, VectorStore}; | ||
#[cfg(feature = "opensearch")] | ||
use langchain_rust::{ | ||
embedding::openai::openai_embedder::OpenAiEmbedder, schemas::Document, | ||
vectorstore::opensearch::Store, vectorstore::opensearch::*, | ||
}; | ||
use serde_json::json; | ||
use std::collections::HashMap; | ||
use std::error::Error; | ||
use std::io::Write; | ||
use url::Url; | ||
|
||
/// Example entry point: indexes a few documents into OpenSearch and runs one
/// similarity search for a query typed on stdin.
///
/// Every fallible step uses `unwrap()`, so any connection, indexing or search
/// failure aborts the example with a panic.
#[cfg(feature = "opensearch")]
#[tokio::main]
async fn main() {
    // Initialize Embedder
    let embedder = OpenAiEmbedder::default();

    /* In this example we use an opensearch instance running on localhost (docker):
    docker run --rm -it -p 9200:9200 -p 9600:9600 \
        -e "discovery.type=single-node" \
        -e "node.name=localhost" \
        -e "OPENSEARCH_INITIAL_ADMIN_PASSWORD=ZxPZ3cZL0ky1bzVu+~N" \
        opensearchproject/opensearch:latest
    */

    let opensearch_host = "https://localhost:9200";
    let opensearch_index = "test";

    let url = Url::parse(opensearch_host).unwrap();
    let conn_pool = SingleNodeConnectionPool::new(url);
    let transport = TransportBuilder::new(conn_pool)
        .disable_proxy()
        .auth(Credentials::Basic(
            "admin".to_string(),
            "ZxPZ3cZL0ky1bzVu+~N".to_string(),
        ))
        // Certificate validation is switched off for the local docker instance above.
        .cert_validation(CertificateValidation::None)
        .build()
        .unwrap();
    let client = OpenSearch::new(transport);

    // We could also use an AOSS instance:
    // std::env::set_var("AWS_PROFILE", "xxx");
    // use aws_config::BehaviorVersion;
    // let sdk_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
    // let aoss_host = "https://blahblah.eu-central-1.aoss.amazonaws.com/";
    // let client = build_aoss_client(&sdk_config, aoss_host).unwrap();

    let store = StoreBuilder::new()
        .embedder(embedder)
        .index(opensearch_index)
        .content_field("the_content_field")
        .vector_field("the_vector_field")
        .client(client)
        .build()
        .await
        .unwrap();

    // The result of the delete is deliberately ignored (presumably because the
    // index may not exist yet on a first run).
    let _ = store.delete_index().await;
    store.create_index().await.unwrap();
    let added_ids = add_documents_to_index(&store).await.unwrap();
    for id in added_ids {
        println!("added document with id: {id}");
    }
    // it can take a while before the documents are actually available in the index...

    // Ask for user input
    print!("Query> ");
    std::io::stdout().flush().unwrap();
    let mut query = String::new();
    std::io::stdin().read_line(&mut query).unwrap();

    // `read_line` keeps the trailing line terminator; strip it so the text that
    // is searched for is exactly what the user typed.
    let query = query.trim();
    if query.is_empty() {
        println!("No query entered.");
        return;
    }

    let results = store
        .similarity_search(query, 2, &VecStoreOptions::default())
        .await
        .unwrap();

    if results.is_empty() {
        println!("No results found.");
        return;
    }

    for doc in &results {
        println!("Document: {}", doc.page_content);
    }
}
|
||
/// Adds four sample documents to `store` and returns the ids reported by
/// `add_documents`.
///
/// Only compiled with the `opensearch` feature, because `Store` and `Document`
/// are only imported when that feature is enabled.
///
/// # Errors
///
/// Propagates any error returned by `store.add_documents`.
#[cfg(feature = "opensearch")]
async fn add_documents_to_index(store: &Store) -> Result<Vec<String>, Box<dyn Error>> {
    let docs = vec![
        // The only sample document that carries metadata.
        Document::new(
            "langchain-rust is a port of the langchain python library to rust and was written in 2024.",
        )
        .with_metadata(HashMap::from([("source".to_string(), json!("cli"))])),
        Document::new(
            "langchaingo is a port of the langchain python library to go language and was written in 2023.",
        ),
        Document::new(
            "Capital of United States of America (USA) is Washington D.C. and the capital of France is Paris.",
        ),
        Document::new("Capital of France is Paris."),
    ];

    // Reuse the embedder the store was built with for this call.
    let opts = VecStoreOptions {
        name_space: None,
        score_threshold: None,
        filters: None,
        embedder: Some(store.embedder.clone()),
    };

    let ids = store.add_documents(&docs, &opts).await?;

    Ok(ids)
}
|
||
/// Builds an `OpenSearch` client for the endpoint `host`, authenticating with
/// credentials converted from `sdk_config` and using the service name `"aoss"`.
///
/// Only compiled with the `opensearch` feature, because the OpenSearch types it
/// uses are only imported when that feature is enabled.
///
/// # Errors
///
/// Returns an error if `host` is not a valid URL, if `sdk_config` cannot be
/// converted into credentials, or if the transport cannot be built.
#[cfg(feature = "opensearch")]
// Only referenced from the commented-out AOSS variant in `main`.
#[allow(dead_code)]
fn build_aoss_client(sdk_config: &SdkConfig, host: &str) -> Result<OpenSearch, Box<dyn Error>> {
    let opensearch_url = Url::parse(host)?;
    let conn_pool = SingleNodeConnectionPool::new(opensearch_url);

    let transport = TransportBuilder::new(conn_pool)
        .auth(sdk_config.try_into()?)
        .service_name("aoss")
        .build()?;

    Ok(OpenSearch::new(transport))
}
|
||
/// Fallback entry point used when the `opensearch` feature is disabled: prints
/// how to run the example with the feature enabled.
#[cfg(not(feature = "opensearch"))]
fn main() {
    const USAGE: [&str; 3] = [
        "This example requires the 'opensearch' feature to be enabled.",
        "Please run the command as follows:",
        "cargo run --example vector_store_opensearch --features opensearch",
    ];

    for line in USAGE {
        println!("{line}");
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
use crate::embedding::Embedder; | ||
use crate::vectorstore::opensearch::Store; | ||
use opensearch::OpenSearch; | ||
use std::error::Error; | ||
use std::sync::Arc; | ||
|
||
pub struct StoreBuilder { | ||
client: Option<OpenSearch>, | ||
embedder: Option<Arc<dyn Embedder>>, | ||
k: i32, | ||
index: Option<String>, | ||
vector_field: String, | ||
content_field: String, | ||
} | ||
|
||
impl StoreBuilder { | ||
// Returns a new StoreBuilder instance with default values for each option | ||
pub fn new() -> Self { | ||
StoreBuilder { | ||
client: None, | ||
embedder: None, | ||
k: 2, | ||
index: None, | ||
vector_field: "vector_field".to_string(), | ||
content_field: "page_content".to_string(), | ||
} | ||
} | ||
|
||
pub fn client(mut self, client: OpenSearch) -> Self { | ||
self.client = Some(client); | ||
self | ||
} | ||
|
||
pub fn embedder<E: Embedder + 'static>(mut self, embedder: E) -> Self { | ||
self.embedder = Some(Arc::new(embedder)); | ||
self | ||
} | ||
|
||
pub fn k(mut self, k: i32) -> Self { | ||
self.k = k; | ||
self | ||
} | ||
|
||
pub fn index(mut self, index: &str) -> Self { | ||
self.index = Some(index.to_string()); | ||
self | ||
} | ||
|
||
pub fn vector_field(mut self, vector_field: &str) -> Self { | ||
self.vector_field = vector_field.to_string(); | ||
self | ||
} | ||
|
||
pub fn content_field(mut self, content_field: &str) -> Self { | ||
self.content_field = content_field.to_string(); | ||
self | ||
} | ||
|
||
// Finalize the builder and construct the Store object | ||
pub async fn build(self) -> Result<Store, Box<dyn Error>> { | ||
if self.client.is_none() { | ||
return Err("Client is required".into()); | ||
} | ||
|
||
if self.embedder.is_none() { | ||
return Err("Embedder is required".into()); | ||
} | ||
|
||
if self.index.is_none() { | ||
return Err("Index is required".into()); | ||
} | ||
|
||
Ok(Store { | ||
client: self.client.unwrap(), | ||
embedder: self.embedder.unwrap(), | ||
k: self.k, | ||
index: self.index.unwrap(), | ||
vector_field: self.vector_field, | ||
content_field: self.content_field, | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
mod builder; | ||
mod opensearch; | ||
|
||
pub use builder::*; | ||
pub use opensearch::*; |
Oops, something went wrong.