diff --git a/lightly/api/api_workflow_upload_embeddings.py b/lightly/api/api_workflow_upload_embeddings.py index 11ebca7f4..d01623a9c 100644 --- a/lightly/api/api_workflow_upload_embeddings.py +++ b/lightly/api/api_workflow_upload_embeddings.py @@ -41,7 +41,10 @@ def set_embedding_id_to_latest(self): self._embeddings_api.get_embeddings_by_dataset_id( dataset_id=self.dataset_id ) - self.embedding_id = embeddings_on_server[-1].id + if len(embeddings_on_server) == 0: + raise RuntimeError(f"There are no known embeddings for dataset_id {self.dataset_id}.") + # return first entry as the API returns newest first + self.embedding_id = embeddings_on_server[0].id def get_embedding_by_name( self, name: str, ignore_suffix: bool = True diff --git a/tests/api_workflow/mocked_api_workflow_client.py b/tests/api_workflow/mocked_api_workflow_client.py index ad7f2a7dd..14ecdc1b6 100644 --- a/tests/api_workflow/mocked_api_workflow_client.py +++ b/tests/api_workflow/mocked_api_workflow_client.py @@ -173,9 +173,9 @@ def __init__(self, api_client): self.embeddings = [ DatasetEmbeddingData( id="embedding_id_xyz", - name="embedding_name_xxyyzz", + name="embedding_newest", is_processed=True, - created_at=0, + created_at=1111111, ), DatasetEmbeddingData( id="embedding_id_xyz_2", diff --git a/tests/api_workflow/test_api_workflow_upload_embeddings.py b/tests/api_workflow/test_api_workflow_upload_embeddings.py index 30797770e..be8fc183d 100644 --- a/tests/api_workflow/test_api_workflow_upload_embeddings.py +++ b/tests/api_workflow/test_api_workflow_upload_embeddings.py @@ -89,6 +89,12 @@ def test_upload_comma_filenames(self): def test_set_embedding_id_default(self): self.api_workflow_client.set_embedding_id_to_latest() + self.assertEqual(self.api_workflow_client.embedding_id, 'embedding_id_xyz') + + def test_set_embedding_id_no_embeddings(self): + self.api_workflow_client._embeddings_api.embeddings = [] + with self.assertRaises(RuntimeError): + self.api_workflow_client.set_embedding_id_to_latest() def test_upload_existing_embedding(self):