Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add SpaCy as a pre-splitter for Rolling Window - Fix #193 #204

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,21 @@ black = "^23.12.1"
colorama = "^0.4.6"
pinecone-client = {version="^3.0.0", optional = true}
regex = "^2023.12.25"
spacy = { version = "^3.0", optional = true }
torchvision = { version = "^0.17.0", optional = true}
pillow = { version= "^10.2.0", optional = true}
tiktoken = "^0.6.0"
matplotlib = { version="^3.8.3", optional = true}
qdrant-client = {version="^1.8.0", optional = true}


[tool.poetry.extras]
hybrid = ["pinecone-text"]
fastembed = ["fastembed"]
local = ["torch", "transformers", "llama-cpp-python"]
pinecone = ["pinecone-client"]
vision = ["torch", "torchvision", "transformers", "pillow"]
processing = ["matplotlib"]
processing = ["matplotlib", "spacy"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]

Expand Down
22 changes: 19 additions & 3 deletions semantic_router/splitters/rolling_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from semantic_router.encoders.base import BaseEncoder
from semantic_router.schema import DocumentSplit
from semantic_router.splitters.base import BaseSplitter
from semantic_router.splitters.utils import split_to_sentences, tiktoken_length
from semantic_router.splitters.utils import split_to_sentences, split_to_sentences_spacy, tiktoken_length
from semantic_router.utils.logger import logger


Expand Down Expand Up @@ -39,6 +39,8 @@ class RollingWindowSplitter(BaseSplitter):
def __init__(
self,
encoder: BaseEncoder,
pre_splitter: str = "regex",
spacy_model: str = "en_core_web_sm",
threshold_adjustment=0.01,
dynamic_threshold: bool = True,
window_size=5,
Expand All @@ -51,6 +53,8 @@ def __init__(
super().__init__(name=name, encoder=encoder)
self.calculated_threshold: float
self.encoder = encoder
self.pre_splitter = pre_splitter
self.spacy_model = spacy_model
self.threshold_adjustment = threshold_adjustment
self.dynamic_threshold = dynamic_threshold
self.window_size = window_size
Expand Down Expand Up @@ -79,7 +83,14 @@ def __call__(self, docs: List[str]) -> List[DocumentSplit]:
f"of {self.max_split_tokens}. "
"Splitting to sentences before semantically splitting."
)
docs = split_to_sentences(docs[0])
try:
if self.pre_splitter == "spacy":
docs = split_to_sentences_spacy(docs[0], self.spacy_model)
elif self.pre_splitter == "regex":
docs = split_to_sentences(docs[0])
except Exception as e:
logger.error(f"Error splitting document to sentences: {e}")
raise
encoded_docs = self._encode_documents(docs)
similarities = self._calculate_similarity_scores(encoded_docs)
if self.dynamic_threshold:
Expand Down Expand Up @@ -401,7 +412,12 @@ def plot_sentence_similarity_scores(
sentence after a similarity score below
a specified threshold.
"""
sentences = [sentence for doc in docs for sentence in split_to_sentences(doc)]
if self.pre_splitter == "spacy":
sentences = [sentence for doc in docs for sentence in split_to_sentences_spacy(doc)]
elif self.pre_splitter == "regex":
sentences = [sentence for doc in docs for sentence in split_to_sentences(doc)]
else:
raise ValueError("Invalid pre_splitter value. Supported values are 'spacy' and 'regex'.")
encoded_sentences = self._encode_documents(sentences)
similarity_scores = []

Expand Down
38 changes: 38 additions & 0 deletions semantic_router/splitters/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import regex
import tiktoken

from semantic_router.utils.logger import logger


def split_to_sentences(text: str) -> list[str]:
"""
Expand Down Expand Up @@ -57,6 +59,42 @@ def split_to_sentences(text: str) -> list[str]:
return sentences


def split_to_sentences_spacy(
    text: str, spacy_model: str = "en_core_web_sm"
) -> list[str]:
    """
    Use SpaCy to split a given text into sentences. Supported languages: English.

    Args:
        text (str): The text to split into sentences.
        spacy_model (str): Name of the SpaCy pipeline to use. Downloaded
            automatically on first use if not already installed.

    Returns:
        list: A list of stripped sentence strings extracted from the text.
        Returns an empty list if SpaCy is not installed.
    """

    # SpaCy is an optional dependency (the `semantic-router[processing]` extra),
    # so import lazily and degrade gracefully if it is missing.
    try:
        import spacy
    except ImportError:
        logger.warning(
            "SpaCy is not installed. Please `pip install "
            "semantic-router[processing]`."
        )
        # Return an empty list (not None) to honor the declared return type
        # and keep callers that iterate the result from crashing.
        return []

    # Load the requested model, downloading it on demand if it is not
    # installed. Keep the loaded pipeline — no need to load it twice.
    try:
        nlp = spacy.load(spacy_model)
    except OSError:
        logger.info(f"Spacy model '{spacy_model}' not found, downloading...")
        from spacy.cli import download

        download(spacy_model)
        logger.info(f"Downloaded and installed model '{spacy_model}'.")
        # Load the model that was actually requested/downloaded (the
        # original hard-coded "en_core_web_sm" here, ignoring `spacy_model`).
        nlp = spacy.load(spacy_model)

    doc = nlp(text)
    sentences = [sentence.text.strip() for sentence in doc.sents]
    return sentences


def tiktoken_length(text: str) -> int:
tokenizer = tiktoken.get_encoding("cl100k_base")
tokens = tokenizer.encode(text, disallowed_special=())
Expand Down
Loading