Skip to content

Commit

Permalink
fix(deps): update rust crate text-splitter to 0.13 (#152)
Browse files Browse the repository at this point in the history
* fix(deps): update rust crate text-splitter to 0.13

* Update for text-splitter 0.12 chunk config

* Add deprecated warning for get_tokenizer_from_str on splitters

---------

Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
  • Loading branch information
renovate[bot] and benbrandt authored May 17, 2024
1 parent 0e91fdc commit 72dbfc2
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 59 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pgvector = { version = "0.3.2", features = [
"postgres",
"sqlx",
], optional = true }
text-splitter = { version = "0.11", features = ["tiktoken-rs", "markdown"] }
text-splitter = { version = "0.13", features = ["tiktoken-rs", "markdown"] }
surrealdb = { version = "1.4.2", optional = true, default-features = false }
csv = "1.3.0"
urlencoding = "2.1.3"
Expand Down
10 changes: 10 additions & 0 deletions src/text_splitter/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use text_splitter::ChunkConfigError;
use thiserror::Error;

#[derive(Error, Debug)]
Expand All @@ -17,6 +18,15 @@ pub enum TextSplitterError {
#[error("Tokenizer creation failed due to invalid model")]
InvalidModel,

#[error("Invalid chunk overlap and size")]
InvalidSplitterOptions,

#[error("Error: {0}")]
OtherError(String),
}

impl From<ChunkConfigError> for TextSplitterError {
fn from(_: ChunkConfigError) -> Self {
Self::InvalidSplitterOptions
}
}
39 changes: 10 additions & 29 deletions src/text_splitter/markdown_splitter.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use async_trait::async_trait;
use tiktoken_rs::{get_bpe_from_model, get_bpe_from_tokenizer, tokenizer::Tokenizer, CoreBPE};
use text_splitter::ChunkConfig;
use tiktoken_rs::tokenizer::Tokenizer;

use super::{SplitterOptions, TextSplitter, TextSplitterError};

// Markdown-aware text splitter configured via `SplitterOptions`.
// NOTE(review): this span is a rendered diff hunk — the first four fields
// (chunk_size, model_name, encoding_name, trim_chunks) are the deleted
// pre-0.13 lines and `splitter_options` is the added replacement; after
// this commit the struct holds only `splitter_options`. Confirm against
// the raw file before editing.
pub struct MarkdownSplitter {
    chunk_size: usize,
    model_name: String,
    encoding_name: String,
    trim_chunks: bool,
    splitter_options: SplitterOptions,
}

impl Default for MarkdownSplitter {
Expand All @@ -19,13 +17,11 @@ impl Default for MarkdownSplitter {
impl MarkdownSplitter {
pub fn new(options: SplitterOptions) -> MarkdownSplitter {
MarkdownSplitter {
chunk_size: options.chunk_size,
model_name: options.model_name,
encoding_name: options.encoding_name,
trim_chunks: options.trim_chunks,
splitter_options: options,
}
}

#[deprecated = "Use `SplitterOptions::get_tokenizer_from_str` instead"]
pub fn get_tokenizer_from_str(&self, s: &str) -> Option<Tokenizer> {
match s.to_lowercase().as_str() {
"cl100k_base" => Some(Tokenizer::Cl100kBase),
Expand All @@ -36,30 +32,15 @@ impl MarkdownSplitter {
_ => None,
}
}

/// Split `text` into chunks of at most `self.chunk_size` tokens, measured
/// with the supplied tiktoken BPE, honoring Markdown structure.
/// NOTE(review): this is the pre-0.13 `text-splitter` API
/// (`with_trim_chunks` + two-argument `chunks`) that this commit deletes.
fn split(&self, text: &str, tokenizer: CoreBPE) -> Vec<String> {
    // `trim_chunks` controls whether surrounding whitespace is stripped
    // from each emitted chunk.
    let splitter =
        text_splitter::MarkdownSplitter::new(tokenizer).with_trim_chunks(self.trim_chunks);
    splitter
        .chunks(text, self.chunk_size)
        .map(|x| x.to_string())
        .collect()
}
}

#[async_trait]
impl TextSplitter for MarkdownSplitter {
async fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError> {
let tk = if !self.encoding_name.is_empty() {
let tokenizer = self
.get_tokenizer_from_str(&self.encoding_name)
.ok_or(TextSplitterError::TokenizerNotFound)?;

get_bpe_from_tokenizer(tokenizer).map_err(|_| TextSplitterError::InvalidTokenizer)?
} else {
get_bpe_from_model(&self.model_name).map_err(|_| TextSplitterError::InvalidModel)?
};
let text = self.split(text, tk);
Ok(text)
let chunk_config = ChunkConfig::try_from(&self.splitter_options)?;
Ok(text_splitter::MarkdownSplitter::new(chunk_config)
.chunks(text)
.map(|x| x.to_string())
.collect())
}
}
44 changes: 44 additions & 0 deletions src/text_splitter/options.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
use text_splitter::ChunkConfig;
use tiktoken_rs::{get_bpe_from_model, get_bpe_from_tokenizer, tokenizer::Tokenizer, CoreBPE};

use super::TextSplitterError;

// Options is a struct that contains options for a text splitter.
#[derive(Debug, Clone)]
pub struct SplitterOptions {
pub chunk_size: usize,
pub chunk_overlap: usize,
pub model_name: String,
pub encoding_name: String,
pub trim_chunks: bool,
Expand All @@ -16,6 +23,7 @@ impl SplitterOptions {
pub fn new() -> Self {
SplitterOptions {
chunk_size: 512,
chunk_overlap: 0,
model_name: String::from("gpt-3.5-turbo"),
encoding_name: String::from("cl100k_base"),
trim_chunks: false,
Expand All @@ -30,6 +38,11 @@ impl SplitterOptions {
self
}

pub fn with_chunk_overlap(mut self, chunk_overlap: usize) -> Self {
self.chunk_overlap = chunk_overlap;
self
}

pub fn with_model_name(mut self, model_name: &str) -> Self {
self.model_name = String::from(model_name);
self
Expand All @@ -44,4 +57,35 @@ impl SplitterOptions {
self.trim_chunks = trim_chunks;
self
}

/// Resolve a tiktoken tokenizer from its encoding name
/// (case-insensitive). Returns `None` for unrecognized names.
pub fn get_tokenizer_from_str(s: &str) -> Option<Tokenizer> {
    // Normalize once so callers may pass e.g. "Cl100k_Base".
    let name = s.to_lowercase();
    match name.as_str() {
        "cl100k_base" => Some(Tokenizer::Cl100kBase),
        "gpt2" => Some(Tokenizer::Gpt2),
        "p50k_base" => Some(Tokenizer::P50kBase),
        "p50k_edit" => Some(Tokenizer::P50kEdit),
        "r50k_base" => Some(Tokenizer::R50kBase),
        _ => None,
    }
}
}

/// Translate user-facing `SplitterOptions` into a `text-splitter` 0.13
/// `ChunkConfig` whose sizer is a tiktoken `CoreBPE`.
///
/// Resolution order matches the options' precedence: a non-empty
/// `encoding_name` wins; otherwise the BPE is derived from `model_name`.
///
/// # Errors
/// - `TokenizerNotFound` if `encoding_name` is set but unrecognized.
/// - `InvalidTokenizer` if the BPE for that encoding cannot be built.
/// - `InvalidModel` if `model_name` is not a known tiktoken model.
/// - `InvalidSplitterOptions` if `chunk_overlap` is rejected by
///   `ChunkConfig::with_overlap` (via the `From<ChunkConfigError>` impl).
impl TryFrom<&SplitterOptions> for ChunkConfig<CoreBPE> {
    type Error = TextSplitterError;

    fn try_from(options: &SplitterOptions) -> Result<Self, Self::Error> {
        let sizer = match options.encoding_name.as_str() {
            // No encoding configured: fall back to the model name.
            "" => get_bpe_from_model(&options.model_name)
                .map_err(|_| TextSplitterError::InvalidModel)?,
            encoding => {
                let tokenizer = SplitterOptions::get_tokenizer_from_str(encoding)
                    .ok_or(TextSplitterError::TokenizerNotFound)?;
                get_bpe_from_tokenizer(tokenizer)
                    .map_err(|_| TextSplitterError::InvalidTokenizer)?
            }
        };

        let config = ChunkConfig::new(options.chunk_size)
            .with_sizer(sizer)
            .with_trim(options.trim_chunks)
            .with_overlap(options.chunk_overlap)?;
        Ok(config)
    }
}
39 changes: 10 additions & 29 deletions src/text_splitter/token_splitter.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use async_trait::async_trait;
use tiktoken_rs::{get_bpe_from_model, get_bpe_from_tokenizer, tokenizer::Tokenizer, CoreBPE};
use text_splitter::ChunkConfig;
use tiktoken_rs::tokenizer::Tokenizer;

use super::{SplitterOptions, TextSplitter, TextSplitterError};

// Token-count-based text splitter configured via `SplitterOptions`.
// NOTE(review): this span is a rendered diff hunk — the first four fields
// (chunk_size, model_name, encoding_name, trim_chunks) are the deleted
// pre-0.13 lines and `splitter_options` is the added replacement; after
// this commit the struct holds only `splitter_options`. Confirm against
// the raw file before editing.
#[derive(Debug, Clone)]
pub struct TokenSplitter {
    chunk_size: usize,
    model_name: String,
    encoding_name: String,
    trim_chunks: bool,
    splitter_options: SplitterOptions,
}

impl Default for TokenSplitter {
Expand All @@ -20,13 +18,11 @@ impl Default for TokenSplitter {
impl TokenSplitter {
pub fn new(options: SplitterOptions) -> TokenSplitter {
TokenSplitter {
chunk_size: options.chunk_size,
model_name: options.model_name,
encoding_name: options.encoding_name,
trim_chunks: options.trim_chunks,
splitter_options: options,
}
}

#[deprecated = "Use `SplitterOptions::get_tokenizer_from_str` instead"]
pub fn get_tokenizer_from_str(&self, s: &str) -> Option<Tokenizer> {
match s.to_lowercase().as_str() {
"cl100k_base" => Some(Tokenizer::Cl100kBase),
Expand All @@ -37,30 +33,15 @@ impl TokenSplitter {
_ => None,
}
}

/// Split `text` into chunks of at most `self.chunk_size` tokens, measured
/// with the supplied tiktoken BPE, breaking on plain-text boundaries.
/// NOTE(review): this is the pre-0.13 `text-splitter` API
/// (`with_trim_chunks` + two-argument `chunks`) that this commit deletes.
fn split(&self, text: &str, tokenizer: CoreBPE) -> Vec<String> {
    // `trim_chunks` controls whether surrounding whitespace is stripped
    // from each emitted chunk.
    let splitter =
        text_splitter::TextSplitter::new(tokenizer).with_trim_chunks(self.trim_chunks);
    splitter
        .chunks(text, self.chunk_size)
        .map(|x| x.to_string())
        .collect()
}
}

#[async_trait]
impl TextSplitter for TokenSplitter {
async fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError> {
let tk = if !self.encoding_name.is_empty() {
let tokenizer = self
.get_tokenizer_from_str(&self.encoding_name)
.ok_or(TextSplitterError::TokenizerNotFound)?;

get_bpe_from_tokenizer(tokenizer).map_err(|_| TextSplitterError::InvalidTokenizer)?
} else {
get_bpe_from_model(&self.model_name).map_err(|_| TextSplitterError::InvalidModel)?
};
let text = self.split(text, tk);
Ok(text)
let chunk_config = ChunkConfig::try_from(&self.splitter_options)?;
Ok(text_splitter::TextSplitter::new(chunk_config)
.chunks(text)
.map(|x| x.to_string())
.collect())
}
}

0 comments on commit 72dbfc2

Please sign in to comment.