-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add FastEmbed text embedding support (#108)
* feat: add FastEmbed text embedding support. This commit adds support for FastEmbed text embedding. It adds a new `fastembed` dependency to Cargo.toml, a new error variant in `error.rs` to handle FastEmbed-related errors, and a new `FastEmbed` struct in `fastembed.rs` that uses the FastEmbed library for text embedding. It also adds the FastEmbed modules and updates `embedding/mod.rs` accordingly. Tests for FastEmbed are included as well. * chore: add example
- Loading branch information
1 parent
328f931
commit 5745688
Showing
6 changed files
with
116 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
use langchain_rust::embedding::{Embedder, EmbeddingModel, FastEmbed, InitOptions, TextEmbedding}; | ||
|
||
#[tokio::main] | ||
async fn main() { | ||
//Default | ||
let fastembed = FastEmbed::try_new().unwrap(); | ||
let embeddings = fastembed | ||
.embed_documents(&["hello world".to_string(), "foo bar".to_string()]) | ||
.await | ||
.unwrap(); | ||
|
||
println!("Len: {:?}", embeddings.len()); | ||
|
||
//With custom model | ||
|
||
let model = TextEmbedding::try_new(InitOptions { | ||
model_name: EmbeddingModel::AllMiniLML6V2, | ||
show_download_progress: true, | ||
..Default::default() | ||
}) | ||
.unwrap(); | ||
|
||
let fastembed = FastEmbed::from(model); | ||
|
||
fastembed | ||
.embed_documents(&["hello world".to_string(), "foo bar".to_string()]) | ||
.await | ||
.unwrap(); | ||
|
||
println!("Len: {:?}", embeddings.len()); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
use async_trait::async_trait; | ||
pub use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; | ||
|
||
use crate::embedding::{Embedder, EmbedderError}; | ||
|
||
pub struct FastEmbed { | ||
model: TextEmbedding, | ||
batch_size: Option<usize>, | ||
} | ||
|
||
impl FastEmbed { | ||
pub fn try_new() -> Result<Self, EmbedderError> { | ||
Ok(Self { | ||
model: TextEmbedding::try_new(Default::default()) | ||
.map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?, | ||
batch_size: None, | ||
}) | ||
} | ||
|
||
pub fn with_batch_size(mut self, batch_size: usize) -> Self { | ||
self.batch_size = Some(batch_size); | ||
self | ||
} | ||
} | ||
|
||
impl From<TextEmbedding> for FastEmbed { | ||
fn from(model: TextEmbedding) -> Self { | ||
Self { | ||
model, | ||
batch_size: None, | ||
} | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl Embedder for FastEmbed { | ||
async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> { | ||
let embeddings = self | ||
.model | ||
.embed(documents.to_vec(), self.batch_size) | ||
.map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?; | ||
|
||
Ok(embeddings | ||
.into_iter() | ||
.map(|inner_vec| { | ||
inner_vec | ||
.into_iter() | ||
.map(|x| x as f64) | ||
.collect::<Vec<f64>>() | ||
}) | ||
.collect::<Vec<Vec<f64>>>()) | ||
} | ||
|
||
async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> { | ||
let embedding = self | ||
.model | ||
.embed(vec![text], self.batch_size) | ||
.map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?; | ||
|
||
Ok(embedding[0].iter().map(|x| *x as f64).collect()) | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding two documents with the default model yields two vectors.
    #[tokio::test]
    async fn test_fastembed() {
        let docs = ["hello world".to_string(), "foo bar".to_string()];
        let embedder = FastEmbed::try_new().unwrap();

        let vectors = embedder.embed_documents(&docs).await.unwrap();

        assert_eq!(vectors.len(), 2);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
mod fastembed; | ||
pub use fastembed::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,6 @@ mod error; | |
pub mod ollama; | ||
pub mod openai; | ||
pub use error::*; | ||
|
||
mod fastembed; | ||
pub use fastembed::*; |