Skip to content

Commit

Permalink
Merge pull request #7 from google-marketing-solutions/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
maximilianw-google authored Aug 8, 2024
2 parents 8ccecfd + bb6ccbb commit 8fe1afd
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 8 deletions.
3 changes: 3 additions & 0 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ClassifyRequest(pydantic.BaseModel):

text: str | list[str] = None
media_uri: str | list[str] = None
embeddings: bool = False


class ClassifyResponse(pydantic.BaseModel):
Expand All @@ -50,6 +51,7 @@ class ClassifyResponse(pydantic.BaseModel):
media_uri: str | None = None
media_description: str | None = None
categories: list[dict[str, Union[str, float]]]
embedding: list[float] | None = None


class GenerateTaxonomyEmbeddingsRequest(pydantic.BaseModel):
Expand Down Expand Up @@ -119,6 +121,7 @@ def classify(
classify_results = services['classify_service'].classify(
request.text,
request.media_uri,
request.embeddings,
)
return classify_results
except Exception as e:
Expand Down
32 changes: 27 additions & 5 deletions api/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@


# Expected JSON body of a successful /classify call that requested embeddings:
# one result entry carrying the input text, its category similarities and the
# embedding vector.
_TEST_CLASSIFY_RESPONSE_SUCCESS = [
    {
        'text': 'foobar',
        'categories': [
            {'name': 'category_1', 'similarity': 0.98},
            {'name': 'category_2', 'similarity': 0.89},
        ],
        'embedding': [0.1, 0.2, 0.3],
    },
]

_TEST_CLASSIFY_RESPONSE_SUCCESS_NO_EMBEDDINGS = [
{
'text': 'foobar',
'categories': [
Expand Down Expand Up @@ -155,10 +166,13 @@ def test_classify_service(self):
{'name': 'category_1', 'similarity': 0.98},
{'name': 'category_2', 'similarity': 0.89},
],
'embedding': [0.1, 0.2, 0.3],
},
]
with testclient.TestClient(main.app) as client:
actual = client.post('/classify', json={'text': 'foobar'})
actual = client.post(
'/classify', json={'text': 'foobar', 'embeddings': True}
)

self.assertEqual(actual.status_code, 200)
self.assertListEqual(actual.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS)
Expand All @@ -171,12 +185,15 @@ def test_classify_service_use_vector_search(self):
{'name': 'category_1', 'similarity': 0.98},
{'name': 'category_2', 'similarity': 0.89},
],
'embedding': [0.1, 0.2, 0.3],
},
]
with testclient.TestClient(main.app) as client:
actual = client.post('/classify', json={'text': 'foobar'})
actual = client.post(
'/classify', json={'text': 'foobar', 'embeddings': True}
)
self.mock_classify_service.return_value.classify.assert_called_once_with(
'foobar', None
'foobar', None, True
)
self.assertEqual(actual.status_code, 200)
self.assertListEqual(actual.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS)
Expand All @@ -191,11 +208,14 @@ def test_classify_service_with_list(self):
],
},
]

with testclient.TestClient(main.app) as client:
response = client.post('/classify', json={'text': ['foobar']})

self.assertEqual(response.status_code, 200)
self.assertListEqual(response.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS)
self.assertListEqual(
response.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS_NO_EMBEDDINGS
)

def test_classify_service_with_medias(self):
self.mock_classify_service.return_value.classify.return_value = [
Expand All @@ -211,7 +231,9 @@ def test_classify_service_with_medias(self):
response = client.post('/classify', json={'media_uri': ['foobar']})

self.assertEqual(response.status_code, 200)
self.assertListEqual(response.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS)
self.assertListEqual(
response.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS_NO_EMBEDDINGS
)

def test_classify_service_with_error(self):
self.mock_classify_service.return_value.classify.return_value = (
Expand Down
14 changes: 12 additions & 2 deletions api/services/classify_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class ClassifyResult:
media_uri: Optional[str] = None
media_description: Optional[str] = None
categories: Optional[list[dict[str, Union[str, float]]]] = None
embedding: Optional[list[float]] = None


ClassifyResults = list[ClassifyResult]
Expand Down Expand Up @@ -73,6 +74,7 @@ def classify(
self,
text: Optional[Union[str, list[str]]] = None,
media_uri: Optional[Union[str, list[str]]] = None,
embeddings: bool = False,
) -> ClassifyResults:
"""Gets the semantic similarity of the passed input relative to the taxonomy.
Expand All @@ -82,6 +84,7 @@ def classify(
Args:
text: A string or a list of strings.
media_uri: A file path or list of file paths.
embeddings: A boolean indicating whether to return the embeddings.
Returns:
A response object containing the text input elements as keys and their
Expand All @@ -107,19 +110,21 @@ def classify(
text_list, media_descriptions
)
return self._find_nearest_neighbors_for_text(
text_embeddings, media_descriptions
text_embeddings, media_descriptions, embeddings
)

def _find_nearest_neighbors_for_text(
self,
text_embeddings: dict[str, list[float]],
media_descriptions: Optional[list[tuple[str, str]]] = None,
embeddings: bool = False,
) -> ClassifyResults:
"""Finds the nearest neighbors for text.
Args:
text_embeddings: An object containing embeddings for all text elements.
media_descriptions: List of (media path, description) tuples.
embeddings: A boolean indicating whether to return the embeddings.
Returns:
A list of dict objects with text as the key and the similarities to
Expand All @@ -146,10 +151,15 @@ def _find_nearest_neighbors_for_text(
media_uri=text,
categories=similar_categories,
media_description=media_descriptions_dict[text],
embedding=text_embeddings[text] if embeddings else None,
)
)
else:
classify_results.append(
ClassifyResult(text=text, categories=similar_categories)
ClassifyResult(
text=text,
categories=similar_categories,
embedding=text_embeddings[text] if embeddings else None,
)
)
return classify_results
39 changes: 38 additions & 1 deletion api/services/classify_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def setUp(self):
},
find_neighbors_result_count=2,
generate_descriptions_from_medias_return_value={},
embeddings=False,
expected=[
classify_service_lib.ClassifyResult(
text='fake_text_1',
Expand All @@ -90,10 +91,45 @@ def setUp(self):
),
],
),
dict(
testcase_name='text_with_embeddings',
text_input=['fake_text_1', 'fake_text_2'],
media_input=None,
embeddings=True,
get_embeddings_batch_return_value={
'fake_text_1': [0.1, 0.2, 0.3],
'fake_text_2': [0.1, 0.2, 0.3],
},
find_neighbors_result_count=2,
generate_descriptions_from_medias_return_value={},
expected=[
classify_service_lib.ClassifyResult(
text='fake_text_1',
categories=[
{
'name': 'fake_id',
'similarity': mock.ANY,
},
],
embedding=[0.1, 0.2, 0.3],
),
classify_service_lib.ClassifyResult(
text='fake_text_2',
categories=[
{
'name': 'fake_id',
'similarity': mock.ANY,
},
],
embedding=[0.1, 0.2, 0.3],
),
],
),
dict(
testcase_name='text_and_media',
text_input=['fake_text_1', 'fake_text_2'],
media_input=['gs://fake/path_1.jpeg', 'gs://fake/path_2.jpeg'],
embeddings=False,
get_embeddings_batch_return_value={
'fake_text_1': [0.1, 0.2, 0.3],
'fake_text_2': [0.1, 0.2, 0.3],
Expand Down Expand Up @@ -151,6 +187,7 @@ def test_classify(
self,
text_input,
media_input,
embeddings,
get_embeddings_batch_return_value,
find_neighbors_result_count,
generate_descriptions_from_medias_return_value,
Expand All @@ -170,7 +207,7 @@ def test_classify(
classifier = classify_service_lib.ClassifyService(
self.postgres_client, self.vertex_client, self.ai_platform_client
)
actual = classifier.classify(text_input, media_input)
actual = classifier.classify(text_input, media_input, embeddings)
self.assertEqual(actual, expected)

def test_classify_without_media_text_input(self):
Expand Down

0 comments on commit 8fe1afd

Please sign in to comment.