diff --git a/Cargo.toml b/Cargo.toml
index dfd96854..6c044f4a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -43,7 +43,7 @@ pgvector = { version = "0.3.2", features = [
     "postgres",
     "sqlx",
 ], optional = true }
-text-splitter = { version = "0.11", features = ["tiktoken-rs", "markdown"] }
+text-splitter = { version = "0.13", features = ["tiktoken-rs", "markdown"] }
 surrealdb = { version = "1.4.2", optional = true, default-features = false }
 csv = "1.3.0"
 urlencoding = "2.1.3"
diff --git a/src/text_splitter/error.rs b/src/text_splitter/error.rs
index 926aad09..cf85bd9a 100644
--- a/src/text_splitter/error.rs
+++ b/src/text_splitter/error.rs
@@ -1,3 +1,4 @@
+use text_splitter::ChunkConfigError;
 use thiserror::Error;
 
 #[derive(Error, Debug)]
@@ -17,6 +18,15 @@ pub enum TextSplitterError {
     #[error("Tokenizer creation failed due to invalid model")]
     InvalidModel,
 
+    #[error("Invalid chunk overlap and size")]
+    InvalidSplitterOptions,
+
     #[error("Error: {0}")]
     OtherError(String),
 }
+
+impl From<ChunkConfigError> for TextSplitterError {
+    fn from(_: ChunkConfigError) -> Self {
+        Self::InvalidSplitterOptions
+    }
+}
diff --git a/src/text_splitter/markdown_splitter.rs b/src/text_splitter/markdown_splitter.rs
index 5a3c2c15..f954a75d 100644
--- a/src/text_splitter/markdown_splitter.rs
+++ b/src/text_splitter/markdown_splitter.rs
@@ -1,13 +1,11 @@
 use async_trait::async_trait;
-use tiktoken_rs::{get_bpe_from_model, get_bpe_from_tokenizer, tokenizer::Tokenizer, CoreBPE};
+use text_splitter::ChunkConfig;
+use tiktoken_rs::tokenizer::Tokenizer;
 
 use super::{SplitterOptions, TextSplitter, TextSplitterError};
 
 pub struct MarkdownSplitter {
-    chunk_size: usize,
-    model_name: String,
-    encoding_name: String,
-    trim_chunks: bool,
+    splitter_options: SplitterOptions,
 }
 
 impl Default for MarkdownSplitter {
@@ -19,13 +17,11 @@
 impl MarkdownSplitter {
     pub fn new(options: SplitterOptions) -> MarkdownSplitter {
         MarkdownSplitter {
-            chunk_size: options.chunk_size,
-            model_name: options.model_name,
-            encoding_name: options.encoding_name,
-            trim_chunks: options.trim_chunks,
+            splitter_options: options,
         }
     }
 
+    #[deprecated = "Use `SplitterOptions::get_tokenizer_from_str` instead"]
     pub fn get_tokenizer_from_str(&self, s: &str) -> Option<Tokenizer> {
         match s.to_lowercase().as_str() {
             "cl100k_base" => Some(Tokenizer::Cl100kBase),
@@ -36,30 +32,15 @@ impl MarkdownSplitter {
             _ => None,
         }
     }
-
-    fn split(&self, text: &str, tokenizer: CoreBPE) -> Vec<String> {
-        let splitter =
-            text_splitter::MarkdownSplitter::new(tokenizer).with_trim_chunks(self.trim_chunks);
-        splitter
-            .chunks(text, self.chunk_size)
-            .map(|x| x.to_string())
-            .collect()
-    }
 }
 
 #[async_trait]
 impl TextSplitter for MarkdownSplitter {
     async fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError> {
-        let tk = if !self.encoding_name.is_empty() {
-            let tokenizer = self
-                .get_tokenizer_from_str(&self.encoding_name)
-                .ok_or(TextSplitterError::TokenizerNotFound)?;
-
-            get_bpe_from_tokenizer(tokenizer).map_err(|_| TextSplitterError::InvalidTokenizer)?
-        } else {
-            get_bpe_from_model(&self.model_name).map_err(|_| TextSplitterError::InvalidModel)?
-        };
-        let text = self.split(text, tk);
-        Ok(text)
+        let chunk_config = ChunkConfig::try_from(&self.splitter_options)?;
+        Ok(text_splitter::MarkdownSplitter::new(chunk_config)
+            .chunks(text)
+            .map(|x| x.to_string())
+            .collect())
     }
 }
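Note on the 0.11 → 0.13 migration above: text-splitter 0.13 moves the chunk capacity, sizer, trim behavior, and the new overlap setting into a single `ChunkConfig`, where 0.11 took the sizer at construction and the size at each `chunks` call. A minimal standalone sketch of the new shape (the sample text and the choice of `cl100k_base` are illustrative only):

```rust
use text_splitter::{ChunkConfig, MarkdownSplitter};
use tiktoken_rs::cl100k_base;

fn main() {
    let text = "# Heading\n\nSome markdown to split.";
    // 0.13 style: capacity first, then sizer and trim on the config, replacing
    // 0.11's `MarkdownSplitter::new(sizer).with_trim_chunks(..).chunks(text, size)`.
    let config = ChunkConfig::new(512)
        .with_sizer(cl100k_base().expect("bundled encoding should load"))
        .with_trim(true);
    let chunks: Vec<&str> = MarkdownSplitter::new(config).chunks(text).collect();
    println!("{chunks:?}");
}
```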
diff --git a/src/text_splitter/options.rs b/src/text_splitter/options.rs
index 01609f6c..ccb585be 100644
--- a/src/text_splitter/options.rs
+++ b/src/text_splitter/options.rs
@@ -1,6 +1,13 @@
+use text_splitter::ChunkConfig;
+use tiktoken_rs::{get_bpe_from_model, get_bpe_from_tokenizer, tokenizer::Tokenizer, CoreBPE};
+
+use super::TextSplitterError;
+
 // Options is a struct that contains options for a text splitter.
+#[derive(Debug, Clone)]
 pub struct SplitterOptions {
     pub chunk_size: usize,
+    pub chunk_overlap: usize,
     pub model_name: String,
     pub encoding_name: String,
     pub trim_chunks: bool,
@@ -16,6 +23,7 @@ impl SplitterOptions {
     pub fn new() -> Self {
         SplitterOptions {
             chunk_size: 512,
+            chunk_overlap: 0,
             model_name: String::from("gpt-3.5-turbo"),
             encoding_name: String::from("cl100k_base"),
             trim_chunks: false,
@@ -30,6 +38,11 @@ impl SplitterOptions {
         self
     }
 
+    pub fn with_chunk_overlap(mut self, chunk_overlap: usize) -> Self {
+        self.chunk_overlap = chunk_overlap;
+        self
+    }
+
     pub fn with_model_name(mut self, model_name: &str) -> Self {
         self.model_name = String::from(model_name);
         self
     }
@@ -44,4 +57,35 @@ impl SplitterOptions {
         self.trim_chunks = trim_chunks;
         self
     }
+
+    pub fn get_tokenizer_from_str(s: &str) -> Option<Tokenizer> {
+        match s.to_lowercase().as_str() {
+            "cl100k_base" => Some(Tokenizer::Cl100kBase),
+            "p50k_base" => Some(Tokenizer::P50kBase),
+            "r50k_base" => Some(Tokenizer::R50kBase),
+            "p50k_edit" => Some(Tokenizer::P50kEdit),
+            "gpt2" => Some(Tokenizer::Gpt2),
+            _ => None,
+        }
+    }
+}
+
+impl TryFrom<&SplitterOptions> for ChunkConfig<CoreBPE> {
+    type Error = TextSplitterError;
+
+    fn try_from(options: &SplitterOptions) -> Result<Self, Self::Error> {
+        let tk = if !options.encoding_name.is_empty() {
+            let tokenizer = SplitterOptions::get_tokenizer_from_str(&options.encoding_name)
+                .ok_or(TextSplitterError::TokenizerNotFound)?;
+
+            get_bpe_from_tokenizer(tokenizer).map_err(|_| TextSplitterError::InvalidTokenizer)?
+        } else {
+            get_bpe_from_model(&options.model_name).map_err(|_| TextSplitterError::InvalidModel)?
+        };
+
+        Ok(ChunkConfig::new(options.chunk_size)
+            .with_sizer(tk)
+            .with_trim(options.trim_chunks)
+            .with_overlap(options.chunk_overlap)?)
+    }
 }
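The `TryFrom` impl above centralizes the tokenizer resolution and config validation that each splitter previously duplicated. A hedged sketch of a crate-internal test for it (module paths follow the files in this diff; that `ChunkConfig::with_overlap` rejects an overlap larger than the chunk capacity is an assumption about text-splitter 0.13's validation):

```rust
#[cfg(test)]
mod tests {
    use text_splitter::ChunkConfig;
    use tiktoken_rs::CoreBPE;

    use crate::text_splitter::{SplitterOptions, TextSplitterError};

    #[test]
    fn overlap_validation() {
        // An overlap below the default chunk capacity (512, per
        // SplitterOptions::new above) should convert cleanly.
        let ok = SplitterOptions::new().with_chunk_overlap(64);
        assert!(ChunkConfig::<CoreBPE>::try_from(&ok).is_ok());

        // An overlap larger than the capacity should fail `with_overlap`,
        // which the new `From<ChunkConfigError>` impl maps to
        // `TextSplitterError::InvalidSplitterOptions`.
        let bad = SplitterOptions::new().with_chunk_overlap(1024);
        assert!(matches!(
            ChunkConfig::<CoreBPE>::try_from(&bad),
            Err(TextSplitterError::InvalidSplitterOptions)
        ));
    }
}
```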
diff --git a/src/text_splitter/token_splitter.rs b/src/text_splitter/token_splitter.rs
index 96cc08b7..567964fa 100644
--- a/src/text_splitter/token_splitter.rs
+++ b/src/text_splitter/token_splitter.rs
@@ -1,14 +1,12 @@
 use async_trait::async_trait;
-use tiktoken_rs::{get_bpe_from_model, get_bpe_from_tokenizer, tokenizer::Tokenizer, CoreBPE};
+use text_splitter::ChunkConfig;
+use tiktoken_rs::tokenizer::Tokenizer;
 
 use super::{SplitterOptions, TextSplitter, TextSplitterError};
 
 #[derive(Debug, Clone)]
 pub struct TokenSplitter {
-    chunk_size: usize,
-    model_name: String,
-    encoding_name: String,
-    trim_chunks: bool,
+    splitter_options: SplitterOptions,
 }
 
 impl Default for TokenSplitter {
@@ -20,13 +18,11 @@
 impl TokenSplitter {
     pub fn new(options: SplitterOptions) -> TokenSplitter {
         TokenSplitter {
-            chunk_size: options.chunk_size,
-            model_name: options.model_name,
-            encoding_name: options.encoding_name,
-            trim_chunks: options.trim_chunks,
+            splitter_options: options,
         }
     }
 
+    #[deprecated = "Use `SplitterOptions::get_tokenizer_from_str` instead"]
     pub fn get_tokenizer_from_str(&self, s: &str) -> Option<Tokenizer> {
         match s.to_lowercase().as_str() {
             "cl100k_base" => Some(Tokenizer::Cl100kBase),
@@ -37,30 +33,15 @@ impl TokenSplitter {
             _ => None,
         }
     }
-
-    fn split(&self, text: &str, tokenizer: CoreBPE) -> Vec<String> {
-        let splitter =
-            text_splitter::TextSplitter::new(tokenizer).with_trim_chunks(self.trim_chunks);
-        splitter
-            .chunks(text, self.chunk_size)
-            .map(|x| x.to_string())
-            .collect()
-    }
 }
 
 #[async_trait]
 impl TextSplitter for TokenSplitter {
     async fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError> {
-        let tk = if !self.encoding_name.is_empty() {
-            let tokenizer = self
-                .get_tokenizer_from_str(&self.encoding_name)
-                .ok_or(TextSplitterError::TokenizerNotFound)?;
-
-            get_bpe_from_tokenizer(tokenizer).map_err(|_| TextSplitterError::InvalidTokenizer)?
-        } else {
-            get_bpe_from_model(&self.model_name).map_err(|_| TextSplitterError::InvalidModel)?
-        };
-        let text = self.split(text, tk);
-        Ok(text)
+        let chunk_config = ChunkConfig::try_from(&self.splitter_options)?;
+        Ok(text_splitter::TextSplitter::new(chunk_config)
+            .chunks(text)
+            .map(|x| x.to_string())
+            .collect())
     }
 }
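Taken together, callers now drive both splitters entirely through `SplitterOptions`, including the new overlap knob. A sketch of end-to-end usage (the `langchain_rust::text_splitter` re-export paths and the tokio runtime are assumptions; `split_text` comes from the `TextSplitter` trait above):

```rust
use langchain_rust::text_splitter::{SplitterOptions, TextSplitter, TokenSplitter};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 512-token chunks (the default capacity) that share 64 tokens of
    // overlap with their neighbors, trimmed of surrounding whitespace.
    let options = SplitterOptions::new()
        .with_chunk_overlap(64)
        .with_trim_chunks(true);

    let splitter = TokenSplitter::new(options);
    let chunks = splitter.split_text("some long document text ...").await?;
    println!("split into {} chunks", chunks.len());
    Ok(())
}
```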