Skip to content

Commit

Permalink
Merge pull request #36 from akshayballal95/main
Browse files Browse the repository at this point in the history
Add Whisper-OpenAI and Whisper-Jina
  • Loading branch information
akshayballal95 authored Jul 16, 2024
2 parents 4a24a77 + 01da61c commit dae8308
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 224 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
[package]
name = "embed_anything"

version = "0.1.20"

version = "0.1.21"
edition = "2021"

[dependencies]
Expand Down
2 changes: 1 addition & 1 deletion examples/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ fn main() {

let embeddings = embed_file(audio_path.to_str().unwrap(), "Whisper-Bert").unwrap();

println!("{:?}", embeddings);
// println!("{:?}", embeddings);
}
3 changes: 3 additions & 0 deletions examples/web.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Example script: embed the contents of a web page with the "Bert" model.
import embed_anything

# Arguments are the page URL and the embedder name; the result is kept in `data`.
data = embed_anything.embed_webpage("https://www.akshaymakes.com/", "Bert")
214 changes: 0 additions & 214 deletions examples/zeroshot.ipynb

This file was deleted.

24 changes: 22 additions & 2 deletions python/embed_anything/embed_anything.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def embed_file(file_path: str, embeder: str) -> list[EmbedData]:
- Text -> "OpenAI", "Bert"
- Image -> "Clip"
- Audio -> "Whisper-Bert"
- Audio -> "Whisper-Bert", "Whisper-OpenAI", "Whisper-Jina"
Args:
file_path: The path to the file to embed.
Expand All @@ -65,6 +65,13 @@ def embed_file(file_path: str, embeder: str) -> list[EmbedData]:
Returns:
A list of EmbedData objects.
Example:
```python
import embed_anything
data = embed_anything.embed_file("test_files/test.pdf", embeder="Bert")
```
"""

def embed_directory(file_path: str, embeder: str) -> list[EmbedData]:
Expand All @@ -76,7 +83,20 @@ def embed_directory(file_path: str, embeder: str) -> list[EmbedData]:
embeder: The name of the embedding model to use. Choose between "OpenAI" and "Bert"
Returns:
- A list of EmbedData objects.
A list of EmbedData objects.
"""

def embed_webpage(url: str, embeder: str) -> list[EmbedData]:
"""
Embeds the webpage at the given URL and returns a list of EmbedData objects.
Args:
url: The URL of the webpage to embed.
embeder: The name of the embedding model to use. Choose between "OpenAI", "Jina", "Bert"
Returns:
A list of EmbedData objects
"""

class EmbedData:
Expand Down
55 changes: 54 additions & 1 deletion src/embedding_model/jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ extern crate accelerate_src;

use std::collections::HashMap;

use super::embed::{Embed, EmbedData, TextEmbed};
use crate::file_processor::audio::audio_processor::Segment;

use super::embed::{AudioEmbed, Embed, EmbedData, TextEmbed};
use anyhow::Error as E;
use candle_core::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};
Expand Down Expand Up @@ -98,8 +100,49 @@ impl JinaEmbeder {
.collect::<Vec<_>>();
Ok(final_embeddings)
}

/// Embeds the transcribed text of each audio segment with the Jina model.
///
/// The text of every segment is tokenized as one batch, run through the
/// model, averaged over the token dimension and L2-normalised. Each
/// resulting vector is returned together with the segment text and a
/// metadata map holding the keys `start`, `end` (start + duration) and
/// `file_name`.
///
/// # Arguments
///
/// * `segments` - Transcribed audio segments; one embedding is produced per segment.
/// * `audio_file` - Path of the audio file the segments came from (stored in the metadata).
///
/// # Errors
///
/// Returns an error instead of panicking when tokenization, the model
/// forward pass or any tensor operation fails.
fn embed_audio<T: AsRef<std::path::Path>>(
    &self,
    segments: Vec<Segment>,
    audio_file: T,
) -> Result<Vec<EmbedData>, anyhow::Error> {
    // Nothing to embed: skip tokenization and the forward pass altogether.
    if segments.is_empty() {
        return Ok(Vec::new());
    }

    let text_batch = segments
        .iter()
        .map(|segment| segment.dr.text.clone())
        .collect::<Vec<String>>();

    let token_ids = self.tokenize_batch(&text_batch, &self.model.device)?;
    let embeddings = self.model.forward(&token_ids)?;
    let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
    // Average over the token axis (dim 1), then normalise along dim 1.
    let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
    let embeddings = normalize_l2(&embeddings)?;
    let encodings = embeddings.to_vec2::<f32>()?;

    // Computed once; a non-UTF-8 path is converted lossily rather than panicking.
    let file_name = audio_file.as_ref().to_string_lossy().into_owned();

    let final_embeddings = encodings
        .into_iter()
        .zip(segments.iter().zip(text_batch))
        .map(|(embedding, (segment, text))| {
            let mut metadata = HashMap::new();
            metadata.insert("start".to_string(), segment.start.to_string());
            metadata.insert(
                "end".to_string(),
                (segment.start + segment.duration).to_string(),
            );
            metadata.insert("file_name".to_string(), file_name.clone());
            EmbedData::new(embedding, Some(text), Some(metadata))
        })
        .collect::<Vec<_>>();
    Ok(final_embeddings)
}
}


impl Embed for JinaEmbeder {
fn embed(
&self,
Expand All @@ -120,6 +163,16 @@ impl TextEmbed for JinaEmbeder {
}
}

impl AudioEmbed for JinaEmbeder {
fn embed_audio<T: AsRef<std::path::Path>>(
&self,
segments: Vec<Segment>,
audio_file: T,
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed_audio(segments, audio_file)
}
}

/// Divides `v` by its L2 norm taken along dimension 1.
///
/// The norm is computed as the square root of the sum of squares over
/// dimension 1 (kept as a size-1 dimension) and broadcast back over `v`.
pub fn normalize_l2(v: &Tensor) -> candle_core::Result<Tensor> {
    let squared = v.sqr()?;
    let norms = squared.sum_keepdim(1)?.sqrt()?;
    v.broadcast_div(&norms)
}
76 changes: 74 additions & 2 deletions src/embedding_model/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use reqwest::Client;
use serde::Deserialize;
use serde_json::json;

use crate::embedding_model::embed::{EmbedData, EmbedResponse};
use crate::{embedding_model::embed::{EmbedData, EmbedResponse}, file_processor::audio::audio_processor::Segment};

use super::embed::{Embed, TextEmbed};
use super::embed::{AudioEmbed, Embed, TextEmbed};

/// Represents an OpenAIEmbeder struct that contains the URL and API key for making requests to the OpenAI API.
#[derive(Deserialize, Debug)]
Expand Down Expand Up @@ -91,6 +91,78 @@ impl OpenAIEmbeder {

Ok(emb_data)
}

fn embed_audio<T: AsRef<std::path::Path>>(
&self,
segments: Vec<Segment>,
audio_file: T,
) -> Result<Vec<EmbedData>, anyhow::Error> {

let client = Client::new();
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.build()
.unwrap();

let text_batch = segments
.iter()
.map(|segment| segment.dr.text.clone())
.collect::<Vec<String>>();

let data = runtime.block_on(async {
let response = client
.post(&self.url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&json!({
"input": text_batch,
"model": "text-embedding-3-small",
}))
.send()
.await
.unwrap();

let data = response.json::<EmbedResponse>().await.unwrap();
println!("{:?}", data.usage);
data
});

let encodings = data.data.iter().map(|data| data.embedding.clone()).collect::<Vec<_>>();

let emb_data = encodings
.iter()
.enumerate()
.map(|(i, data)| {
let mut metadata = HashMap::new();
metadata.insert("start".to_string(), segments[i].start.to_string());
metadata.insert(
"end".to_string(),
(segments[i].start + segments[i].duration).to_string(),
);
metadata.insert(
"file_name".to_string(),
(audio_file.as_ref().to_str().unwrap()).to_string(),
);
EmbedData::new(data.to_vec(), Some(text_batch[i].clone()), Some(metadata))
})
.collect::<Vec<_>>();

Ok(emb_data)



}

}

impl AudioEmbed for OpenAIEmbeder {
fn embed_audio<T: AsRef<std::path::Path>>(
&self,
segments: Vec<Segment>,
audio_file: T,
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed_audio(segments, audio_file)
}
}

#[cfg(test)]
Expand Down
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ pub fn embed_file(file_name: &str, embeder: &str) -> PyResult<Vec<EmbedData>> {
"Bert" => emb_text(file_name, Embeder::Bert(embedding_model::bert::BertEmbeder::default()))?,
"Clip" => vec![emb_image(file_name, embedding_model::clip::ClipEmbeder::default())?],
"Whisper-Bert" => emb_audio(file_name, embedding_model::bert::BertEmbeder::default())?,
"Whisper-OpenAI"=> emb_audio(file_name, embedding_model::openai::OpenAIEmbeder::default())?,
"Whisper-Jina" => emb_audio(file_name, embedding_model::jina::JinaEmbeder::default())?,
_ => {
return Err(PyValueError::new_err(
"Invalid embedding model. Choose between OpenAI and Bert for text files and Clip for image files.",
Expand Down Expand Up @@ -204,7 +206,7 @@ pub fn embed_directory(
/// };
/// ```
#[pyfunction]
pub fn emb_webpage(url: String, embeder: &str) -> PyResult<Vec<EmbedData>> {
pub fn embed_webpage(url: String, embeder: &str) -> PyResult<Vec<EmbedData>> {
let website_processor = file_processor::website_processor::WebsiteProcessor::new();
let runtime = Builder::new_current_thread().enable_all().build().unwrap();
let webpage = runtime
Expand Down Expand Up @@ -236,7 +238,7 @@ fn embed_anything(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(embed_file, m)?)?;
m.add_function(wrap_pyfunction!(embed_directory, m)?)?;
m.add_function(wrap_pyfunction!(embed_query, m)?)?;
m.add_function(wrap_pyfunction!(emb_webpage, m)?)?;
m.add_function(wrap_pyfunction!(embed_webpage, m)?)?;
m.add_class::<embedding_model::embed::EmbedData>()?;

Ok(())
Expand Down
Empty file added ~bashrc
Empty file.

0 comments on commit dae8308

Please sign in to comment.