-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(tools): add text2speech module (#107)
* feat(tools): add text2speech module Add new text2speech module under tools with OpenAI as a speech conversion tool. This change includes the creation of text2speech module, OpenAI client along with models and speech storage interface. With this addition, support for converting text to speech using OpenAI has been introduced. * chore: add 'with_config' method to Text2SpeechOpenAI client * chore: add examples * fix: Correct typo in method name in Text2SpeechOpenAI client.rs * refactor: Update Text2SpeechOpenAI client.rs and mod.rs files with new changes, deleting unesasary structs and leaving the asyncopenai ones * feat: Modify with_model and with_voice functions to accept enum types in text2speech openai client The with_model and with_voice function parameters in the text2speech openai client have been changed to accept enum types. This change improves type safety and readability.
- Loading branch information
1 parent
5745688
commit dc9bee7
Showing
6 changed files
with
166 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
use async_trait::async_trait; | ||
use langchain_rust::tools::{SpeechStorage, Text2SpeechOpenAI, Tool}; | ||
|
||
/// Example storage backend with no state; it exists only to demonstrate
/// implementing `SpeechStorage`.
struct XStorage;
|
||
// You can save the result to S3 or any other storage by implementing `SpeechStorage`. | ||
|
||
#[async_trait] | ||
impl SpeechStorage for XStorage { | ||
async fn save(&self, path: &str, _data: &[u8]) -> Result<String, Box<dyn std::error::Error>> { | ||
println!("Saving to: {}", path); | ||
Ok(path.to_string()) | ||
} | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() { | ||
let openai = Text2SpeechOpenAI::default().with_path("./data/audio.mp3"); | ||
// .with_storage(XStorage {}); | ||
|
||
let path = openai.call("Hi, My name is Luis").await.unwrap(); | ||
println!("Path: {}", path); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,6 @@ pub use serpapi::*; | |
|
||
mod command_executor; | ||
pub use command_executor::*; | ||
|
||
mod text2speech; | ||
pub use text2speech::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
mod openai; | ||
pub use openai::*; | ||
|
||
mod speech_storage; | ||
pub use speech_storage::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
use std::{error::Error, sync::Arc}; | ||
|
||
use async_openai::types::CreateSpeechRequestArgs; | ||
use async_openai::Client; | ||
pub use async_openai::{ | ||
config::{Config, OpenAIConfig}, | ||
types::{SpeechModel, SpeechResponseFormat, Voice}, | ||
}; | ||
use async_trait::async_trait; | ||
use serde_json::Value; | ||
|
||
use crate::tools::{SpeechStorage, Tool}; | ||
|
||
#[derive(Clone)] | ||
pub struct Text2SpeechOpenAI<C: Config> { | ||
config: C, | ||
model: SpeechModel, | ||
voice: Voice, | ||
storage: Option<Arc<dyn SpeechStorage>>, | ||
response_format: SpeechResponseFormat, | ||
path: String, | ||
} | ||
|
||
impl<C: Config> Text2SpeechOpenAI<C> { | ||
pub fn new(config: C) -> Self { | ||
Self { | ||
config, | ||
model: SpeechModel::Tts1, | ||
voice: Voice::Alloy, | ||
storage: None, | ||
response_format: SpeechResponseFormat::Mp3, | ||
path: "./data/audio.mp3".to_string(), | ||
} | ||
} | ||
|
||
pub fn with_model(mut self, model: SpeechModel) -> Self { | ||
self.model = model; | ||
self | ||
} | ||
|
||
pub fn with_voice(mut self, voice: Voice) -> Self { | ||
self.voice = voice; | ||
self | ||
} | ||
|
||
pub fn with_storage<SS: SpeechStorage + 'static>(mut self, storage: SS) -> Self { | ||
self.storage = Some(Arc::new(storage)); | ||
self | ||
} | ||
|
||
pub fn with_response_format(mut self, response_format: SpeechResponseFormat) -> Self { | ||
self.response_format = response_format; | ||
self | ||
} | ||
|
||
pub fn with_path<S: Into<String>>(mut self, path: S) -> Self { | ||
self.path = path.into(); | ||
self | ||
} | ||
|
||
pub fn with_config(mut self, config: C) -> Self { | ||
self.config = config; | ||
self | ||
} | ||
} | ||
|
||
impl Default for Text2SpeechOpenAI<OpenAIConfig> { | ||
fn default() -> Self { | ||
Self::new(OpenAIConfig::default()) | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl<C: Config + Send + Sync> Tool for Text2SpeechOpenAI<C> { | ||
fn name(&self) -> String { | ||
"Text2SpeechOpenAI".to_string() | ||
} | ||
|
||
fn description(&self) -> String { | ||
format!( | ||
r#"A wrapper around OpenAI Text2Speech. " | ||
"Useful for when you need to convert text to speech. " | ||
"It supports multiple languages, including English, German, Polish, " | ||
"Spanish, Italian, French, Portuguese""# | ||
) | ||
} | ||
|
||
async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> { | ||
let input = input.as_str().ok_or("Invalid input")?; | ||
let client = Client::new(); | ||
let response_format: SpeechResponseFormat = self.response_format.clone().into(); | ||
|
||
let request = CreateSpeechRequestArgs::default() | ||
.input(input) | ||
.voice(self.voice.clone()) | ||
.response_format(response_format) | ||
.model(self.model.clone()) | ||
.build()?; | ||
|
||
let response = client.audio().speech(request).await?; | ||
|
||
if self.storage.is_some() { | ||
let storage = self.storage.as_ref().unwrap(); //safe to unwrap | ||
let data = response.bytes; | ||
return storage.save(&self.path, &data).await; | ||
} else { | ||
response.save(&self.path).await?; | ||
} | ||
|
||
Ok(self.path.clone()) | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use crate::tools::{Text2SpeechOpenAI, Tool};

    // NOTE(review): presumably ignored by default because it sends a real
    // request to the OpenAI API — confirm before enabling in CI.
    #[tokio::test]
    #[ignore]
    async fn openai_speech2text_tool() {
        let tool = Text2SpeechOpenAI::default();
        let output = tool.call("Hola como estas").await.unwrap();
        println!("{}", output);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
mod client; | ||
pub use client::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
use std::error::Error; | ||
|
||
use async_trait::async_trait; | ||
|
||
#[async_trait] | ||
pub trait SpeechStorage: Send + Sync { | ||
async fn save(&self, key: &str, data: &[u8]) -> Result<String, Box<dyn Error>>; | ||
} |