From 91dad3c8083acd1d1d1c12eefd21c7d6b05ef914 Mon Sep 17 00:00:00 2001 From: Diego Ardila Date: Tue, 29 Jun 2021 19:00:26 -0700 Subject: [PATCH 1/3] Tests pass locally --- nucleus/__init__.py | 25 ++++++++++++++++++++----- nucleus/annotation.py | 2 +- nucleus/dataset.py | 4 ++-- nucleus/model.py | 19 ++++++++++++++++++- tests/test_dataset.py | 7 +++++++ tests/test_models.py | 4 ++++ 6 files changed, 52 insertions(+), 9 deletions(-) diff --git a/nucleus/__init__.py b/nucleus/__init__.py index 71aa6402..ad439937 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -181,11 +181,13 @@ def list_models(self) -> List[Model]: return [ Model( - model["id"], - model["name"], - model["ref_id"], - model["metadata"], - self, + model_id=model["id"], + name=model["name"], + reference_id=model["ref_id"], + metadata=None + if model["metadata"] == {} + else model["metadata"], + client=self, ) for model in model_objects["models"] ] @@ -231,6 +233,19 @@ def get_dataset(self, dataset_id: str) -> Dataset: """ return Dataset(dataset_id, self) + def get_model(self, model_id: str) -> Model: + """ + Fetched a model for a given id + :param model_id: internally controlled dataset_id + :return: model + """ + payload = self.make_request( + payload={}, + route=f"model/{model_id}", + requests_command=requests.get, + ) + return Model.from_json(payload=payload, client=self) + def get_model_run(self, model_run_id: str, dataset_id: str) -> ModelRun: """ Fetches a model_run for given id diff --git a/nucleus/annotation.py b/nucleus/annotation.py index 0e27f6ef..c36ddac1 100644 --- a/nucleus/annotation.py +++ b/nucleus/annotation.py @@ -314,5 +314,5 @@ def check_all_annotation_paths_remote( if hasattr(annotation, MASK_URL_KEY): if is_local_path(getattr(annotation, MASK_URL_KEY)): raise ValueError( - f"Found an annotation with a local path, which cannot be uploaded asynchronously. Use a remote path instead. {annotation}" + f"Found an annotation with a local path, which is not currently supported. Use a remote path instead. {annotation}" ) diff --git a/nucleus/dataset.py b/nucleus/dataset.py index 756c0619..285b51d1 100644 --- a/nucleus/dataset.py +++ b/nucleus/dataset.py @@ -171,9 +171,9 @@ def annotate( if any((isinstance(ann, CuboidAnnotation) for ann in annotations)): raise NotImplementedError("Cuboid annotations not yet supported") - if asynchronous: - check_all_annotation_paths_remote(annotations) + check_all_annotation_paths_remote(annotations) + if asynchronous: request_id = serialize_and_write_to_presigned_url( annotations, self.id, self._client ) diff --git a/nucleus/model.py b/nucleus/model.py index 5db899c0..dcf44857 100644 --- a/nucleus/model.py +++ b/nucleus/model.py @@ -32,11 +32,28 @@ def __repr__(self): return f"Model(model_id='{self.id}', name='{self.name}', reference_id='{self.reference_id}', metadata={self.metadata}, client={self._client})" def __eq__(self, other): - return self.id == other.id + return ( + (self.id == other.id) + and (self.name == other.name) + and (self.metadata == other.metadata) + and (self._client == other._client) + ) def __hash__(self): return hash(self.id) + @classmethod + def from_json(cls, payload: dict, client): + return cls( + model_id=payload["id"], + name=payload["name"], + reference_id=payload["ref_id"], + metadata=None + if payload["metadata"] == {} + else payload["metadata"], + client=client, + ) + def create_run( self, name: str, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 6d0675b9..83df93c0 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -76,6 +76,13 @@ def dataset(CLIENT): assert response == {"message": "Beginning dataset deletion..."} +def test_upload_nonsense(dataset): + response = dataset.append( + [DatasetItem(image_location="https://fake.com/image.jpeg")] + ) + print(response) + + def make_dataset_items(): ds_items_with_metadata = [] for i, url in enumerate(TEST_IMG_URLS): diff --git a/tests/test_models.py b/tests/test_models.py index c9e040d9..58fd0612 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -64,6 +64,10 @@ def test_model_creation_and_listing(CLIENT, dataset): # List the models ms = CLIENT.list_models() + # Get a model + m = CLIENT.get_model(model.id) + assert m == model + assert model in ms assert list(set(ms) - set(models_before))[0] == model From 3ebde810099745e93e3c324e27b767e6b2288f09 Mon Sep 17 00:00:00 2001 From: Diego Ardila Date: Wed, 30 Jun 2021 06:36:49 -0700 Subject: [PATCH 2/3] remove irrelevant test --- tests/test_dataset.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 83df93c0..6d0675b9 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -76,13 +76,6 @@ def dataset(CLIENT): assert response == {"message": "Beginning dataset deletion..."} -def test_upload_nonsense(dataset): - response = dataset.append( - [DatasetItem(image_location="https://fake.com/image.jpeg")] - ) - print(response) - - def make_dataset_items(): ds_items_with_metadata = [] for i, url in enumerate(TEST_IMG_URLS): From 70d15309d4723733695b83028f0a61df209e9fe5 Mon Sep 17 00:00:00 2001 From: Diego Ardila Date: Fri, 2 Jul 2021 16:05:09 -0700 Subject: [PATCH 3/3] Review comments + bump version --- nucleus/__init__.py | 4 +--- nucleus/annotation.py | 2 +- nucleus/dataset.py | 4 ++-- nucleus/model.py | 4 +--- nucleus/model_run.py | 4 ++-- pyproject.toml | 2 +- 6 files changed, 8 insertions(+), 12 deletions(-) diff --git a/nucleus/__init__.py b/nucleus/__init__.py index ad439937..eed31275 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -184,9 +184,7 @@ def list_models(self) -> List[Model]: model_id=model["id"], name=model["name"], reference_id=model["ref_id"], - metadata=None - if model["metadata"] == {} - else model["metadata"], + metadata=model["metadata"] or None, client=self, ) for model in model_objects["models"] diff --git a/nucleus/annotation.py b/nucleus/annotation.py index c36ddac1..05f0a602 100644 --- a/nucleus/annotation.py +++ b/nucleus/annotation.py @@ -307,7 +307,7 @@ def to_payload(self) -> dict: } -def check_all_annotation_paths_remote( +def check_all_mask_paths_remote( annotations: Sequence[Union[Annotation]], ): for annotation in annotations: diff --git a/nucleus/dataset.py b/nucleus/dataset.py index 285b51d1..449dbfc3 100644 --- a/nucleus/dataset.py +++ b/nucleus/dataset.py @@ -13,7 +13,7 @@ from .annotation import ( Annotation, CuboidAnnotation, - check_all_annotation_paths_remote, + check_all_mask_paths_remote, ) from .constants import ( DATASET_ITEM_IDS_KEY, @@ -171,7 +171,7 @@ def annotate( if any((isinstance(ann, CuboidAnnotation) for ann in annotations)): raise NotImplementedError("Cuboid annotations not yet supported") - check_all_annotation_paths_remote(annotations) + check_all_mask_paths_remote(annotations) if asynchronous: request_id = serialize_and_write_to_presigned_url( diff --git a/nucleus/model.py b/nucleus/model.py index dcf44857..ef68da05 100644 --- a/nucleus/model.py +++ b/nucleus/model.py @@ -48,9 +48,7 @@ def from_json(cls, payload: dict, client): model_id=payload["id"], name=payload["name"], reference_id=payload["ref_id"], - metadata=None - if payload["metadata"] == {} - else payload["metadata"], + metadata=payload["metadata"] or None, client=client, ) diff --git a/nucleus/model_run.py b/nucleus/model_run.py index 90b95bfb..8442cbf1 100644 --- a/nucleus/model_run.py +++ b/nucleus/model_run.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional, Type, Union -from nucleus.annotation import check_all_annotation_paths_remote +from nucleus.annotation import check_all_mask_paths_remote from nucleus.job import AsyncJob from nucleus.utils import serialize_and_write_to_presigned_url @@ -106,7 +106,7 @@ def predict( } """ if asynchronous: - check_all_annotation_paths_remote(annotations) + check_all_mask_paths_remote(annotations) request_id = serialize_and_write_to_presigned_url( annotations, self._dataset_id, self._client diff --git a/pyproject.toml b/pyproject.toml index 9b2e1128..77d2f1ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ exclude = ''' [tool.poetry] name = "scale-nucleus" -version = "0.1.11" +version = "0.1.13" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "]