diff --git a/Cargo.toml b/Cargo.toml index 3825227e..53364c1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ tokio-stream = "0.1.15" secrecy = "0.8.0" readability = "0.3.0" url = "2.5.0" +fastembed = "3.3.0" [features] default = [] diff --git a/examples/fastembed.rs b/examples/fastembed.rs new file mode 100644 index 00000000..c10105d8 --- /dev/null +++ b/examples/fastembed.rs @@ -0,0 +1,31 @@ +use langchain_rust::embedding::{Embedder, EmbeddingModel, FastEmbed, InitOptions, TextEmbedding}; + +#[tokio::main] +async fn main() { + //Default + let fastembed = FastEmbed::try_new().unwrap(); + let embeddings = fastembed + .embed_documents(&["hello world".to_string(), "foo bar".to_string()]) + .await + .unwrap(); + + println!("Len: {:?}", embeddings.len()); + + //With custom model + + let model = TextEmbedding::try_new(InitOptions { + model_name: EmbeddingModel::AllMiniLML6V2, + show_download_progress: true, + ..Default::default() + }) + .unwrap(); + + let fastembed = FastEmbed::from(model); + + fastembed + .embed_documents(&["hello world".to_string(), "foo bar".to_string()]) + .await + .unwrap(); + + println!("Len: {:?}", embeddings.len()); +} diff --git a/src/embedding/error.rs b/src/embedding/error.rs index e2c8ff48..e2aa4cdf 100644 --- a/src/embedding/error.rs +++ b/src/embedding/error.rs @@ -18,4 +18,7 @@ pub enum EmbedderError { status_code: StatusCode, error_message: String, }, + + #[error("FastEmbed error: {0}")] + FastEmbedError(String), } diff --git a/src/embedding/fastembed/fastembed.rs b/src/embedding/fastembed/fastembed.rs new file mode 100644 index 00000000..eed20401 --- /dev/null +++ b/src/embedding/fastembed/fastembed.rs @@ -0,0 +1,76 @@ +use async_trait::async_trait; +pub use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; + +use crate::embedding::{Embedder, EmbedderError}; + +pub struct FastEmbed { + model: TextEmbedding, + batch_size: Option, +} + +impl FastEmbed { + pub fn try_new() -> Result { + Ok(Self { + model: TextEmbedding::try_new(Default::default()) + .map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?, + batch_size: None, + }) + } + + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = Some(batch_size); + self + } +} + +impl From for FastEmbed { + fn from(model: TextEmbedding) -> Self { + Self { + model, + batch_size: None, + } + } +} + +#[async_trait] +impl Embedder for FastEmbed { + async fn embed_documents(&self, documents: &[String]) -> Result>, EmbedderError> { + let embeddings = self + .model + .embed(documents.to_vec(), self.batch_size) + .map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?; + + Ok(embeddings + .into_iter() + .map(|inner_vec| { + inner_vec + .into_iter() + .map(|x| x as f64) + .collect::>() + }) + .collect::>>()) + } + + async fn embed_query(&self, text: &str) -> Result, EmbedderError> { + let embedding = self + .model + .embed(vec![text], self.batch_size) + .map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?; + + Ok(embedding[0].iter().map(|x| *x as f64).collect()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[tokio::test] + async fn test_fastembed() { + let fastembed = FastEmbed::try_new().unwrap(); + let embeddings = fastembed + .embed_documents(&["hello world".to_string(), "foo bar".to_string()]) + .await + .unwrap(); + assert_eq!(embeddings.len(), 2); + } +} diff --git a/src/embedding/fastembed/mod.rs b/src/embedding/fastembed/mod.rs new file mode 100644 index 00000000..05c6fae8 --- /dev/null +++ b/src/embedding/fastembed/mod.rs @@ -0,0 +1,2 @@ +mod fastembed; +pub use fastembed::*; diff --git a/src/embedding/mod.rs b/src/embedding/mod.rs index e68f0243..da49278d 100644 --- a/src/embedding/mod.rs +++ b/src/embedding/mod.rs @@ -4,3 +4,6 @@ mod error; pub mod ollama; pub mod openai; pub use error::*; + +mod fastembed; +pub use fastembed::*;