Skip to content

Commit

Permalink
feat: add FastEmbed text embedding support (#108)
Browse files Browse the repository at this point in the history
* feat: add FastEmbed text embedding support

This commit includes changes made to add support for FastEmbed text embedding. We've added a new dependency 'fastembed' in our Cargo.toml. New error message has been added in 'error.rs' to handle any FastEmbed related errors. Created a new struct 'FastEmbed' under 'fastembed.rs' that utilizes the FastEmbed library for text embedding. Also, added FastEmbed related modules and updated embedding/mod.rs accordingly. Tests for FastEmbed have been included as well.

* chore: add example
  • Loading branch information
Abraxas-365 authored Apr 6, 2024
1 parent 328f931 commit 5745688
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ tokio-stream = "0.1.15"
secrecy = "0.8.0"
readability = "0.3.0"
url = "2.5.0"
fastembed = "3.3.0"

[features]
default = []
Expand Down
31 changes: 31 additions & 0 deletions examples/fastembed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use langchain_rust::embedding::{Embedder, EmbeddingModel, FastEmbed, InitOptions, TextEmbedding};

#[tokio::main]
async fn main() {
//Default
let fastembed = FastEmbed::try_new().unwrap();
let embeddings = fastembed
.embed_documents(&["hello world".to_string(), "foo bar".to_string()])
.await
.unwrap();

println!("Len: {:?}", embeddings.len());

//With custom model

let model = TextEmbedding::try_new(InitOptions {
model_name: EmbeddingModel::AllMiniLML6V2,
show_download_progress: true,
..Default::default()
})
.unwrap();

let fastembed = FastEmbed::from(model);

fastembed
.embed_documents(&["hello world".to_string(), "foo bar".to_string()])
.await
.unwrap();

println!("Len: {:?}", embeddings.len());
}
3 changes: 3 additions & 0 deletions src/embedding/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,7 @@ pub enum EmbedderError {
status_code: StatusCode,
error_message: String,
},

#[error("FastEmbed error: {0}")]
FastEmbedError(String),
}
76 changes: 76 additions & 0 deletions src/embedding/fastembed/fastembed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use async_trait::async_trait;
pub use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

use crate::embedding::{Embedder, EmbedderError};

pub struct FastEmbed {
model: TextEmbedding,
batch_size: Option<usize>,
}

impl FastEmbed {
pub fn try_new() -> Result<Self, EmbedderError> {
Ok(Self {
model: TextEmbedding::try_new(Default::default())
.map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?,
batch_size: None,
})
}

pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = Some(batch_size);
self
}
}

impl From<TextEmbedding> for FastEmbed {
fn from(model: TextEmbedding) -> Self {
Self {
model,
batch_size: None,
}
}
}

#[async_trait]
impl Embedder for FastEmbed {
async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> {
let embeddings = self
.model
.embed(documents.to_vec(), self.batch_size)
.map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?;

Ok(embeddings
.into_iter()
.map(|inner_vec| {
inner_vec
.into_iter()
.map(|x| x as f64)
.collect::<Vec<f64>>()
})
.collect::<Vec<Vec<f64>>>())
}

async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> {
let embedding = self
.model
.embed(vec![text], self.batch_size)
.map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?;

Ok(embedding[0].iter().map(|x| *x as f64).collect())
}
}

#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_fastembed() {
let fastembed = FastEmbed::try_new().unwrap();
let embeddings = fastembed
.embed_documents(&["hello world".to_string(), "foo bar".to_string()])
.await
.unwrap();
assert_eq!(embeddings.len(), 2);
}
}
2 changes: 2 additions & 0 deletions src/embedding/fastembed/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mod fastembed;
pub use fastembed::*;
3 changes: 3 additions & 0 deletions src/embedding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ mod error;
pub mod ollama;
pub mod openai;
pub use error::*;

mod fastembed;
pub use fastembed::*;

0 comments on commit 5745688

Please sign in to comment.