Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add SpaCy as a pre-splitter for Rolling Window - Fix #193 #204

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ black = "^23.12.1"
colorama = "^0.4.6"
pinecone-client = {version="^3.0.0", optional = true}
regex = "^2023.12.25"
spacy = "^3.0"
klein-t marked this conversation as resolved.
Show resolved Hide resolved
torchvision = { version = "^0.17.0", optional = true}
pillow = { version= "^10.2.0", optional = true}
tiktoken = "^0.6.0"
matplotlib = { version="^3.8.3", optional = true}


[tool.poetry.extras]
hybrid = ["pinecone-text"]
fastembed = ["fastembed"]
Expand Down
23 changes: 20 additions & 3 deletions semantic_router/splitters/rolling_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from semantic_router.encoders.base import BaseEncoder
from semantic_router.schema import DocumentSplit
from semantic_router.splitters.base import BaseSplitter
from semantic_router.splitters.utils import split_to_sentences, tiktoken_length
from semantic_router.splitters.utils import split_to_sentences, split_to_sentences_spacy, check_and_download_spacy_model, tiktoken_length
from semantic_router.utils.logger import logger


Expand Down Expand Up @@ -39,6 +39,8 @@ class RollingWindowSplitter(BaseSplitter):
def __init__(
self,
encoder: BaseEncoder,
pre_splitter: str = "regex",
spacy_model: str = "en_core_web_sm",
threshold_adjustment=0.01,
dynamic_threshold: bool = True,
window_size=5,
Expand All @@ -51,6 +53,8 @@ def __init__(
super().__init__(name=name, encoder=encoder)
self.calculated_threshold: float
self.encoder = encoder
self.pre_splitter = pre_splitter
self.spacy_model = spacy_model
self.threshold_adjustment = threshold_adjustment
self.dynamic_threshold = dynamic_threshold
self.window_size = window_size
Expand Down Expand Up @@ -79,7 +83,15 @@ def __call__(self, docs: List[str]) -> List[DocumentSplit]:
f"of {self.max_split_tokens}. "
"Splitting to sentences before semantically splitting."
)
docs = split_to_sentences(docs[0])
try:
if self.pre_splitter == "spacy":
check_and_download_spacy_model(self.spacy_model)
docs = split_to_sentences_spacy(docs[0])
elif self.pre_splitter == "regex":
docs = split_to_sentences(docs[0])
except Exception as e:
logger.error(f"Error splitting document to sentences: {e}")
raise
encoded_docs = self._encode_documents(docs)
similarities = self._calculate_similarity_scores(encoded_docs)
if self.dynamic_threshold:
Expand Down Expand Up @@ -401,7 +413,12 @@ def plot_sentence_similarity_scores(
sentence after a similarity score below
a specified threshold.
"""
sentences = [sentence for doc in docs for sentence in split_to_sentences(doc)]
if self.pre_splitter == "spacy":
sentences = [sentence for doc in docs for sentence in split_to_sentences_spacy(doc)]
elif self.pre_splitter == "regex":
sentences = [sentence for doc in docs for sentence in split_to_sentences(doc)]
else:
raise ValueError("Invalid pre_splitter value. Supported values are 'spacy' and 'regex'.")
encoded_sentences = self._encode_documents(sentences)
similarity_scores = []

Expand Down
35 changes: 35 additions & 0 deletions semantic_router/splitters/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import regex
import spacy
klein-t marked this conversation as resolved.
Show resolved Hide resolved
import tiktoken


Expand Down Expand Up @@ -57,6 +58,40 @@ def split_to_sentences(text: str) -> list[str]:
return sentences


# Cache of loaded SpaCy pipelines, keyed by model name. spacy.load reads the
# model package from disk and is expensive; one load per model is enough.
_SPACY_PIPELINES: dict = {}


def split_to_sentences_spacy(
    text: str, spacy_model: str = "en_core_web_sm"
) -> list[str]:
    """
    Use SpaCy to split a given text into sentences.

    Args:
        text (str): The text to split into sentences.
        spacy_model (str): Name of the installed SpaCy pipeline to use.
            Defaults to 'en_core_web_sm' (English). Previously this was
            hard-coded, so a caller-configured model was silently ignored.

    Returns:
        list: A list of sentences extracted from the text, each stripped
        of surrounding whitespace.
    """
    nlp = _SPACY_PIPELINES.get(spacy_model)
    if nlp is None:
        # Raises OSError if the model is not installed; callers should run
        # check_and_download_spacy_model(spacy_model) first.
        nlp = spacy.load(spacy_model)
        _SPACY_PIPELINES[spacy_model] = nlp
    doc = nlp(text)
    return [sentence.text.strip() for sentence in doc.sents]

def check_and_download_spacy_model(model_name: str = "en_core_web_sm") -> None:
    """
    Check that the specified SpaCy language model is installed, and if not,
    attempt to download and install it.

    Args:
        model_name (str): The name of the SpaCy model to check and download.
            Defaults to 'en_core_web_sm'.

    Raises:
        RuntimeError: If the model is missing and the download fails
            (e.g. no network access), chained from the original error.
    """
    try:
        # spacy.load raises OSError when the model package is not installed.
        spacy.load(model_name)
    except OSError:
        print(f"Spacy model '{model_name}' not found, downloading...")
        # Deferred import: the CLI machinery is only needed on first run.
        from spacy.cli import download

        try:
            download(model_name)
        except Exception as e:
            # Fail loudly here rather than letting a later spacy.load crash
            # with a confusing "model not found" message.
            raise RuntimeError(
                f"Failed to download SpaCy model '{model_name}'"
            ) from e
        print(f"Downloaded and installed model '{model_name}'.")
    else:
        print(f"Spacy model '{model_name}' is installed.")


def tiktoken_length(text: str) -> int:
tokenizer = tiktoken.get_encoding("cl100k_base")
tokens = tokenizer.encode(text, disallowed_special=())
Expand Down
Loading