From c5d436077a2ce3da0ce6dbca1341de7392cd42d6 Mon Sep 17 00:00:00 2001 From: Maximilian Weiss Date: Wed, 7 Aug 2024 10:58:45 -0700 Subject: [PATCH 1/2] No public description PiperOrigin-RevId: 660456104 Change-Id: I2e99bf37488b229b664ff69dfbb29bb88c38f077 --- api/main.py | 2 ++ api/main_test.py | 32 ++++++++++++++++++---- api/services/classify_service.py | 14 ++++++++-- api/services/classify_service_test.py | 39 ++++++++++++++++++++++++++- 4 files changed, 79 insertions(+), 8 deletions(-) diff --git a/api/main.py b/api/main.py index 2ef47f8..d707589 100644 --- a/api/main.py +++ b/api/main.py @@ -41,6 +41,7 @@ class ClassifyRequest(pydantic.BaseModel): text: str | list[str] = None media_uri: str | list[str] = None + embeddings: bool = False class ClassifyResponse(pydantic.BaseModel): @@ -119,6 +120,7 @@ def classify( classify_results = services['classify_service'].classify( request.text, request.media_uri, + request.embeddings, ) return classify_results except Exception as e: diff --git a/api/main_test.py b/api/main_test.py index 43b109a..34fc25b 100644 --- a/api/main_test.py +++ b/api/main_test.py @@ -34,6 +34,17 @@ _TEST_CLASSIFY_RESPONSE_SUCCESS = [ + { + 'text': 'foobar', + 'categories': [ + {'name': 'category_1', 'similarity': 0.98}, + {'name': 'category_2', 'similarity': 0.89}, + ], + 'embedding': [0.1, 0.2, 0.3], + }, +] + +_TEST_CLASSIFY_RESPONSE_SUCCESS_NO_EMBEDDINGS = [ { 'text': 'foobar', 'categories': [ @@ -155,10 +166,13 @@ def test_classify_service(self): {'name': 'category_1', 'similarity': 0.98}, {'name': 'category_2', 'similarity': 0.89}, ], + 'embedding': [0.1, 0.2, 0.3], }, ] with testclient.TestClient(main.app) as client: - actual = client.post('/classify', json={'text': 'foobar'}) + actual = client.post( + '/classify', json={'text': 'foobar', 'embeddings': True} + ) self.assertEqual(actual.status_code, 200) self.assertListEqual(actual.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS) @@ -171,12 +185,15 @@ def test_classify_service_use_vector_search(self): {'name': 'category_1', 'similarity': 0.98}, {'name': 'category_2', 'similarity': 0.89}, ], + 'embedding': [0.1, 0.2, 0.3], }, ] with testclient.TestClient(main.app) as client: - actual = client.post('/classify', json={'text': 'foobar'}) + actual = client.post( + '/classify', json={'text': 'foobar', 'embeddings': True} + ) self.mock_classify_service.return_value.classify.assert_called_once_with( - 'foobar', None + 'foobar', None, True ) self.assertEqual(actual.status_code, 200) self.assertListEqual(actual.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS) @@ -191,11 +208,14 @@ def test_classify_service_with_list(self): ], }, ] + with testclient.TestClient(main.app) as client: response = client.post('/classify', json={'text': ['foobar']}) self.assertEqual(response.status_code, 200) - self.assertListEqual(response.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS) + self.assertListEqual( + response.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS_NO_EMBEDDINGS + ) def test_classify_service_with_medias(self): self.mock_classify_service.return_value.classify.return_value = [ @@ -211,7 +231,9 @@ def test_classify_service_with_medias(self): response = client.post('/classify', json={'media_uri': ['foobar']}) self.assertEqual(response.status_code, 200) - self.assertListEqual(response.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS) + self.assertListEqual( + response.json(), _TEST_CLASSIFY_RESPONSE_SUCCESS_NO_EMBEDDINGS + ) def test_classify_service_with_error(self): self.mock_classify_service.return_value.classify.return_value = ( diff --git a/api/services/classify_service.py b/api/services/classify_service.py index 911ebe1..caf42bf 100644 --- a/api/services/classify_service.py +++ b/api/services/classify_service.py @@ -29,6 +29,7 @@ class ClassifyResult: media_uri: Optional[str] = None media_description: Optional[str] = None categories: Optional[list[dict[str, Union[str, float]]]] = None + embedding: Optional[list[float]] = None ClassifyResults = list[ClassifyResult] @@ -73,6 +74,7 @@ def classify( self, text: Optional[Union[str, list[str]]] = None, media_uri: Optional[Union[str, list[str]]] = None, + embeddings: bool = False, ) -> ClassifyResults: """Gets the semantic similarty of the passed input relative to the taxonomy. @@ -82,6 +84,7 @@ def classify( Args: text: A string or a list of strings. media_uri: A file path or list of file paths. + embeddings: A boolean indicating whether to return the embeddings. Returns: A response object containing the text input elements as keys and their @@ -107,19 +110,21 @@ def classify( text_list, media_descriptions ) return self._find_nearest_neighbors_for_text( - text_embeddings, media_descriptions + text_embeddings, media_descriptions, embeddings ) def _find_nearest_neighbors_for_text( self, text_embeddings: dict[str, list[float]], media_descriptions: Optional[list[tuple[str, str]]] = None, + embeddings: bool = False, ) -> ClassifyResults: """Finds the nearest neighbors for text. Args: text_embeddings: An object containing embeddings for all text elements. media_descriptions: List of (media path, description) tuples. + embeddings: A boolean indicating whether to return the embeddings. Returns: A list of dict objects with text as the key and the similarities to @@ -146,10 +151,15 @@ def _find_nearest_neighbors_for_text( media_uri=text, categories=similar_categories, media_description=media_descriptions_dict[text], + embedding=text_embeddings[text] if embeddings else None, ) ) else: classify_results.append( - ClassifyResult(text=text, categories=similar_categories) + ClassifyResult( + text=text, + categories=similar_categories, + embedding=text_embeddings[text] if embeddings else None, + ) ) return classify_results diff --git a/api/services/classify_service_test.py b/api/services/classify_service_test.py index a2b983f..057a08e 100644 --- a/api/services/classify_service_test.py +++ b/api/services/classify_service_test.py @@ -69,6 +69,7 @@ def setUp(self): }, find_neighbors_result_count=2, generate_descriptions_from_medias_return_value={}, + embeddings=False, expected=[ classify_service_lib.ClassifyResult( text='fake_text_1', @@ -90,10 +91,45 @@ def setUp(self): ), ], ), + dict( + testcase_name='text_with_embeddings', + text_input=['fake_text_1', 'fake_text_2'], + media_input=None, + embeddings=True, + get_embeddings_batch_return_value={ + 'fake_text_1': [0.1, 0.2, 0.3], + 'fake_text_2': [0.1, 0.2, 0.3], + }, + find_neighbors_result_count=2, + generate_descriptions_from_medias_return_value={}, + expected=[ + classify_service_lib.ClassifyResult( + text='fake_text_1', + categories=[ + { + 'name': 'fake_id', + 'similarity': mock.ANY, + }, + ], + embedding=[0.1, 0.2, 0.3], + ), + classify_service_lib.ClassifyResult( + text='fake_text_2', + categories=[ + { + 'name': 'fake_id', + 'similarity': mock.ANY, + }, + ], + embedding=[0.1, 0.2, 0.3], + ), + ], + ), dict( testcase_name='text_and_media', text_input=['fake_text_1', 'fake_text_2'], media_input=['gs://fake/path_1.jpeg', 'gs://fake/path_2.jpeg'], + embeddings=False, get_embeddings_batch_return_value={ 'fake_text_1': [0.1, 0.2, 0.3], 'fake_text_2': [0.1, 0.2, 0.3], @@ -151,6 +187,7 @@ def test_classify( self, text_input, media_input, + embeddings, get_embeddings_batch_return_value, find_neighbors_result_count, generate_descriptions_from_medias_return_value, @@ -170,7 +207,7 @@ def test_classify( classifier = classify_service_lib.ClassifyService( self.postgres_client, self.vertex_client, self.ai_platform_client ) - actual = classifier.classify(text_input, media_input) + actual = classifier.classify(text_input, media_input, embeddings) self.assertEqual(actual, expected) def test_classify_without_media_text_input(self): From bb6ccbb81e221f8adc44f448d23dde62faa85d4b Mon Sep 17 00:00:00 2001 From: Maximilian Weiss Date: Wed, 7 Aug 2024 11:22:34 -0700 Subject: [PATCH 2/2] No public description PiperOrigin-RevId: 660466014 Change-Id: I5c8a8db6da3ebef3ebeb48165226292730dedb84 --- api/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/main.py b/api/main.py index d707589..2dd41a9 100644 --- a/api/main.py +++ b/api/main.py @@ -51,6 +51,7 @@ class ClassifyResponse(pydantic.BaseModel): media_uri: str | None = None media_description: str | None = None categories: list[dict[str, Union[str, float]]] + embedding: list[float] | None = None class GenerateTaxonomyEmbeddingsRequest(pydantic.BaseModel):