diff --git a/pyproject.toml b/pyproject.toml index 94db45e6..3bba5aab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,19 +32,21 @@ black = "^23.12.1" colorama = "^0.4.6" pinecone-client = {version="^3.0.0", optional = true} regex = "^2023.12.25" +spacy = { version = "^3.0", optional = true } torchvision = { version = "^0.17.0", optional = true} pillow = { version= "^10.2.0", optional = true} tiktoken = "^0.6.0" matplotlib = { version="^3.8.3", optional = true} qdrant-client = {version="^1.8.0", optional = true} + [tool.poetry.extras] hybrid = ["pinecone-text"] fastembed = ["fastembed"] local = ["torch", "transformers", "llama-cpp-python"] pinecone = ["pinecone-client"] vision = ["torch", "torchvision", "transformers", "pillow"] -processing = ["matplotlib"] +processing = ["matplotlib", "spacy"] mistralai = ["mistralai"] qdrant = ["qdrant-client"] diff --git a/semantic_router/splitters/rolling_window.py b/semantic_router/splitters/rolling_window.py index 092433fe..7f0eb1d6 100644 --- a/semantic_router/splitters/rolling_window.py +++ b/semantic_router/splitters/rolling_window.py @@ -6,7 +6,7 @@ from semantic_router.encoders.base import BaseEncoder from semantic_router.schema import DocumentSplit from semantic_router.splitters.base import BaseSplitter -from semantic_router.splitters.utils import split_to_sentences, tiktoken_length +from semantic_router.splitters.utils import split_to_sentences, split_to_sentences_spacy, tiktoken_length from semantic_router.utils.logger import logger @@ -39,6 +39,8 @@ class RollingWindowSplitter(BaseSplitter): def __init__( self, encoder: BaseEncoder, + pre_splitter: str = "regex", + spacy_model: str = "en_core_web_sm", threshold_adjustment=0.01, dynamic_threshold: bool = True, window_size=5, @@ -51,6 +53,8 @@ def __init__( super().__init__(name=name, encoder=encoder) self.calculated_threshold: float self.encoder = encoder + self.pre_splitter = pre_splitter + self.spacy_model = spacy_model self.threshold_adjustment = 
threshold_adjustment self.dynamic_threshold = dynamic_threshold self.window_size = window_size @@ -79,7 +83,16 @@ def __call__(self, docs: List[str]) -> List[DocumentSplit]: f"of {self.max_split_tokens}. " "Splitting to sentences before semantically splitting." ) - docs = split_to_sentences(docs[0]) + try: + if self.pre_splitter == "spacy": + docs = split_to_sentences_spacy(docs[0], self.spacy_model) + elif self.pre_splitter == "regex": + docs = split_to_sentences(docs[0]) + else: + raise ValueError("Invalid pre_splitter value. Supported values are 'spacy' and 'regex'.") + except Exception as e: + logger.error(f"Error splitting document to sentences: {e}") + raise encoded_docs = self._encode_documents(docs) similarities = self._calculate_similarity_scores(encoded_docs) if self.dynamic_threshold: @@ -401,7 +412,12 @@ def plot_sentence_similarity_scores( sentence after a similarity score below a specified threshold. """ - sentences = [sentence for doc in docs for sentence in split_to_sentences(doc)] + if self.pre_splitter == "spacy": + sentences = [sentence for doc in docs for sentence in split_to_sentences_spacy(doc, self.spacy_model)] + elif self.pre_splitter == "regex": + sentences = [sentence for doc in docs for sentence in split_to_sentences(doc)] + else: + raise ValueError("Invalid pre_splitter value. Supported values are 'spacy' and 'regex'.") encoded_sentences = self._encode_documents(sentences) similarity_scores = [] diff --git a/semantic_router/splitters/utils.py b/semantic_router/splitters/utils.py index 349c3eaa..c1492d2c 100644 --- a/semantic_router/splitters/utils.py +++ b/semantic_router/splitters/utils.py @@ -1,6 +1,8 @@ import regex import tiktoken +from semantic_router.utils.logger import logger + def split_to_sentences(text: str) -> list[str]: """ @@ -57,6 +59,42 @@ def split_to_sentences(text: str) -> list[str]: return sentences +def split_to_sentences_spacy(text: str, spacy_model: str = "en_core_web_sm") -> list[str]: + """ + Use SpaCy to split a given text into sentences. Supported languages: English. + + Args: + text (str): The text to split into sentences. 
+ + Returns: + list: A list of sentences extracted from the text. + """ + + # Check if SpaCy is installed + try: + import spacy + except ImportError: + logger.error( + "SpaCy is not installed. Please `pip install " + "semantic-router[processing]`." + ) + raise + + # Check if the SpaCy model is installed + try: + spacy.load(spacy_model) + except OSError: + logger.info(f"Spacy model '{spacy_model}' not found, downloading...") + from spacy.cli import download + download(spacy_model) + logger.info(f"Downloaded and installed model '{spacy_model}'.") + + nlp = spacy.load(spacy_model) + doc = nlp(text) + sentences = [sentence.text.strip() for sentence in doc.sents] + return sentences + + def tiktoken_length(text: str) -> int: tokenizer = tiktoken.get_encoding("cl100k_base") tokens = tokenizer.encode(text, disallowed_special=())