Skip to content

Commit

Permalink
hf encoder tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ashraq1455 committed Jan 9, 2024
1 parent 9bfba6b commit 7edeb32
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions tests/unit/encoders/test_huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
import numpy as np
from semantic_router.encoders.huggingface import HuggingFaceEncoder


class TestHuggingFaceEncoder:
    """Unit tests exercising the callable interface of ``HuggingFaceEncoder``."""

    def test_huggingface_encoder(self):
        """Encoding a list of documents yields one non-empty list per document."""
        test_docs = ["This is a test", "This is another test"]
        hf_encoder = HuggingFaceEncoder()

        vectors = hf_encoder(test_docs)

        assert isinstance(vectors, list)
        assert len(vectors) == len(test_docs)
        # Check the element type for every vector before checking sizes, so a
        # wrongly-typed element is reported as a type failure.
        for vector in vectors:
            assert isinstance(vector, list)
        for vector in vectors:
            assert len(vector) > 0

    def test_huggingface_encoder_normalized_embeddings(self):
        """Normalized output has unit L2 norm and equals raw output / its norm."""
        docs = ["This is a test document.", "Another test document."]
        hf_encoder = HuggingFaceEncoder()

        raw_vectors = hf_encoder(docs, normalize_embeddings=False)
        unit_vectors = hf_encoder(docs, normalize_embeddings=True)
        assert len(raw_vectors) == len(unit_vectors)

        for raw_vec, unit_vec in zip(raw_vectors, unit_vectors):
            raw_norm = np.linalg.norm(raw_vec, ord=2)
            unit_norm = np.linalg.norm(unit_vec, ord=2)

            # The normalized embedding must have an L2 norm of approximately 1.
            assert np.isclose(unit_norm, 1.0)

            # The normalized embedding must match the raw embedding rescaled
            # by its own L2 norm, within a small tolerance.
            expected_unit_vec = np.divide(raw_vec, raw_norm)
            np.testing.assert_allclose(
                unit_vec,
                expected_unit_vec,
                rtol=1e-5,
                atol=1e-5,
            )

    def test_huggingface_encoder_invalid_pooling_strategy(self):
        """An unrecognised ``pooling_strategy`` is rejected with ``ValueError``."""
        docs = ["This is a test document.", "Another test document."]
        hf_encoder = HuggingFaceEncoder()

        with pytest.raises(ValueError):
            hf_encoder(docs, pooling_strategy="invalid_strategy")

0 comments on commit 7edeb32

Please sign in to comment.