Skip to content

Commit

Permalink
added test + lint + codecov
Browse files Browse the repository at this point in the history
  • Loading branch information
simjak committed Dec 13, 2023
1 parent b743510 commit 671cd07
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 58 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep
lint lint_diff:
poetry run black $(PYTHON_FILES) --check
poetry run ruff .
poetry run mypy $(PYTHON_FILES)

test:
poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=100
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<img alt="" src="https://img.shields.io/github/repo-size/aurelio-labs/semantic-router" />
<img alt="GitHub Issues" src="https://img.shields.io/github/issues/aurelio-labs/semantic-router" />
<img alt="GitHub Pull Requests" src="https://img.shields.io/github/issues-pr/aurelio-labs/semantic-router" />
<img src="https://codecov.io/gh/aurelio-labs/semantic-router/graph/badge.svg?token=H8OOMV2TUF" />
<img alt="Github License" src="https://img.shields.io/badge/License-MIT-yellow.svg" />
</p>

Expand Down
99 changes: 58 additions & 41 deletions coverage.xml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<?xml version="1.0" ?>
<coverage version="7.3.2" timestamp="1702462041712" lines-valid="317" lines-covered="317" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
<coverage version="7.3.2" timestamp="1702463592393" lines-valid="334" lines-covered="334" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
<!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.2 -->
<!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd -->
<sources>
Expand All @@ -22,8 +22,8 @@
<line number="1" hits="1"/>
<line number="2" hits="1"/>
<line number="3" hits="1"/>
<line number="4" hits="1"/>
<line number="6" hits="1"/>
<line number="5" hits="1"/>
<line number="11" hits="1"/>
<line number="12" hits="1"/>
<line number="15" hits="1"/>
<line number="16" hits="1"/>
Expand Down Expand Up @@ -102,10 +102,13 @@
<line number="131" hits="1"/>
<line number="132" hits="1"/>
<line number="135" hits="1"/>
<line number="137" hits="1"/>
<line number="136" hits="1"/>
<line number="138" hits="1"/>
<line number="139" hits="1"/>
<line number="141" hits="1"/>
<line number="142" hits="1"/>
<line number="143" hits="1"/>
<line number="145" hits="1"/>
</lines>
</class>
<class name="layer.py" filename="layer.py" complexity="0" line-rate="1" branch-rate="0">
Expand All @@ -115,68 +118,73 @@
<line number="3" hits="1"/>
<line number="8" hits="1"/>
<line number="9" hits="1"/>
<line number="12" hits="1"/>
<line number="10" hits="1"/>
<line number="13" hits="1"/>
<line number="14" hits="1"/>
<line number="15" hits="1"/>
<line number="17" hits="1"/>
<line number="16" hits="1"/>
<line number="18" hits="1"/>
<line number="20" hits="1"/>
<line number="19" hits="1"/>
<line number="21" hits="1"/>
<line number="22" hits="1"/>
<line number="23" hits="1"/>
<line number="25" hits="1"/>
<line number="27" hits="1"/>
<line number="29" hits="1"/>
<line number="31" hits="1"/>
<line number="24" hits="1"/>
<line number="26" hits="1"/>
<line number="28" hits="1"/>
<line number="30" hits="1"/>
<line number="32" hits="1"/>
<line number="33" hits="1"/>
<line number="34" hits="1"/>
<line number="35" hits="1"/>
<line number="36" hits="1"/>
<line number="38" hits="1"/>
<line number="40" hits="1"/>
<line number="42" hits="1"/>
<line number="45" hits="1"/>
<line number="37" hits="1"/>
<line number="39" hits="1"/>
<line number="41" hits="1"/>
<line number="43" hits="1"/>
<line number="46" hits="1"/>
<line number="48" hits="1"/>
<line number="47" hits="1"/>
<line number="49" hits="1"/>
<line number="51" hits="1"/>
<line number="50" hits="1"/>
<line number="52" hits="1"/>
<line number="54" hits="1"/>
<line number="53" hits="1"/>
<line number="55" hits="1"/>
<line number="57" hits="1"/>
<line number="59" hits="1"/>
<line number="62" hits="1"/>
<line number="65" hits="1"/>
<line number="56" hits="1"/>
<line number="58" hits="1"/>
<line number="60" hits="1"/>
<line number="63" hits="1"/>
<line number="66" hits="1"/>
<line number="67" hits="1"/>
<line number="74" hits="1"/>
<line number="68" hits="1"/>
<line number="75" hits="1"/>
<line number="81" hits="1"/>
<line number="86" hits="1"/>
<line number="76" hits="1"/>
<line number="82" hits="1"/>
<line number="87" hits="1"/>
<line number="89" hits="1"/>
<line number="91" hits="1"/>
<line number="88" hits="1"/>
<line number="90" hits="1"/>
<line number="92" hits="1"/>
<line number="94" hits="1"/>
<line number="93" hits="1"/>
<line number="95" hits="1"/>
<line number="97" hits="1"/>
<line number="96" hits="1"/>
<line number="98" hits="1"/>
<line number="99" hits="1"/>
<line number="100" hits="1"/>
<line number="101" hits="1"/>
<line number="102" hits="1"/>
<line number="103" hits="1"/>
<line number="104" hits="1"/>
<line number="105" hits="1"/>
<line number="106" hits="1"/>
<line number="107" hits="1"/>
<line number="110" hits="1"/>
<line number="111" hits="1"/>
<line number="114" hits="1"/>
<line number="109" hits="1"/>
<line number="112" hits="1"/>
<line number="113" hits="1"/>
<line number="116" hits="1"/>
<line number="117" hits="1"/>
<line number="118" hits="1"/>
<line number="119" hits="1"/>
<line number="120" hits="1"/>
<line number="122" hits="1"/>
<line number="123" hits="1"/>
<line number="124" hits="1"/>
<line number="126" hits="1"/>
</lines>
</class>
<class name="linear.py" filename="linear.py" complexity="0" line-rate="1" branch-rate="0">
Expand Down Expand Up @@ -271,31 +279,40 @@
<lines>
<line number="1" hits="1"/>
<line number="3" hits="1"/>
<line number="6" hits="1"/>
<line number="7" hits="1"/>
<line number="5" hits="1"/>
<line number="8" hits="1"/>
<line number="9" hits="1"/>
<line number="10" hits="1"/>
<line number="11" hits="1"/>
<line number="12" hits="1"/>
<line number="13" hits="1"/>
<line number="14" hits="1"/>
<line number="16" hits="1"/>
<line number="17" hits="1"/>
<line number="18" hits="1"/>
<line number="19" hits="1"/>
<line number="20" hits="1"/>
<line number="21" hits="1"/>
<line number="22" hits="1"/>
<line number="23" hits="1"/>
<line number="24" hits="1"/>
<line number="25" hits="1"/>
<line number="26" hits="1"/>
<line number="27" hits="1"/>
<line number="28" hits="1"/>
<line number="29" hits="1"/>
<line number="30" hits="1"/>
<line number="31" hits="1"/>
<line number="32" hits="1"/>
<line number="33" hits="1"/>
<line number="34" hits="1"/>
<line number="35" hits="1"/>
<line number="36" hits="1"/>
<line number="37" hits="1"/>
<line number="38" hits="1"/>
<line number="39" hits="1"/>
<line number="40" hits="1"/>
<line number="41" hits="1"/>
<line number="42" hits="1"/>
<line number="44" hits="1"/>
<line number="45" hits="1"/>
<line number="46" hits="1"/>
<line number="47" hits="1"/>
</lines>
</class>
<class name="cohere.py" filename="encoders/cohere.py" complexity="0" line-rate="1" branch-rate="0">
Expand Down
49 changes: 48 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ pytest = "^7.4.3"
pytest-mock = "^3.12.0"
pytest-cov = "^4.1.0"
pytest-xdist = "^3.5.0"
mypy = "^1.7.1"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.ruff.per-file-ignores]
"*.ipynb" = ["E402"]

[tool.mypy]
ignore_missing_imports = true
2 changes: 1 addition & 1 deletion semantic_router/encoders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ class BaseEncoder(BaseModel):
class Config:
arbitrary_types_allowed = True

def __call__(self, docs: list[str]) -> list[float]:
def __call__(self, docs: list[str]) -> list[list[float]]:
raise NotImplementedError("Subclasses must implement this method")
25 changes: 16 additions & 9 deletions semantic_router/encoders/bm25.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,36 @@
from typing import Any

from pinecone_text.sparse import BM25Encoder as encoder

from semantic_router.encoders import BaseEncoder


class BM25Encoder(BaseEncoder):
model: encoder | None = None
model: Any | None = None
idx_mapping: dict[int, int] | None = None

def __init__(self, name: str = "bm25"):
super().__init__(name=name)
# initialize BM25 encoder with default params (trained on MSMarco)
self.model = encoder.default()
self.idx_mapping = {
idx: i
for i, idx in enumerate(self.model.get_params()["doc_freq"]["indices"])
}

params = self.model.get_params()
doc_freq = params["doc_freq"]
if isinstance(doc_freq, dict):
indices = doc_freq["indices"]
self.idx_mapping = {int(idx): i for i, idx in enumerate(indices)}
else:
raise TypeError("Expected a dictionary for 'doc_freq'")

def __call__(self, docs: list[str]) -> list[list[float]]:
if self.model is None or self.idx_mapping is None:
raise ValueError("Model or index mapping is not initialized.")
if len(docs) == 1:
sparse_dicts = self.model.encode_queries(docs)
elif len(docs) > 1:
sparse_dicts = self.model.encode_documents(docs)
else:
raise ValueError("No documents to encode.")
# convert sparse dict to sparse vector

embeds = [[0.0] * len(self.idx_mapping)] * len(docs)
for i, output in enumerate(sparse_dicts):
indices = output["indices"]
Expand All @@ -32,9 +39,9 @@ def __call__(self, docs: list[str]) -> list[list[float]]:
if idx in self.idx_mapping:
position = self.idx_mapping[idx]
embeds[i][position] = val
else:
print(idx, "not in encoder.idx_mapping")
return embeds

def fit(self, docs: list[str]):
if self.model is None:
raise ValueError("Model is not initialized.")
self.model.fit(docs)
10 changes: 7 additions & 3 deletions semantic_router/hybrid_layer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
from numpy.linalg import norm
from tqdm.auto import tqdm
from semantic_router.utils.logger import logger

from semantic_router.encoders import (
BaseEncoder,
Expand All @@ -10,6 +9,7 @@
OpenAIEncoder,
)
from semantic_router.schema import Route
from semantic_router.utils.logger import logger


class HybridRouteLayer:
Expand Down Expand Up @@ -118,7 +118,7 @@ def _convex_scaling(self, dense: np.ndarray, sparse: np.ndarray):
return dense, sparse

def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float]]:
scores_by_class = {}
scores_by_class: dict[str, list[float]] = {}
for result in query_results:
score = result["score"]
route = result["route"]
Expand All @@ -132,7 +132,11 @@ def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float
top_class = max(total_scores, key=lambda x: total_scores[x], default=None)

# Return the top class and its associated scores
return str(top_class), scores_by_class.get(top_class, [])
if top_class is not None:
return str(top_class), scores_by_class.get(top_class, [])
else:
logger.warning("No classification found for semantic classifier.")
return "", []

def _pass_threshold(self, scores: list[float], threshold: float) -> bool:
if scores:
Expand Down
Loading

0 comments on commit 671cd07

Please sign in to comment.