Skip to content

Commit

Permalink
Merge pull request #82 from scaleapi/da-get-model
Browse files Browse the repository at this point in the history
Get model + better error messages for segmentation upload
  • Loading branch information
ardila authored Jul 3, 2021
2 parents c25d271 + 1d6e1fe commit 906060f
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 16 deletions.
23 changes: 18 additions & 5 deletions nucleus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,11 @@ def list_models(self) -> List[Model]:

return [
Model(
model["id"],
model["name"],
model["ref_id"],
model["metadata"],
self,
model_id=model["id"],
name=model["name"],
reference_id=model["ref_id"],
metadata=model["metadata"] or None,
client=self,
)
for model in model_objects["models"]
]
Expand Down Expand Up @@ -231,6 +231,19 @@ def get_dataset(self, dataset_id: str) -> Dataset:
"""
return Dataset(dataset_id, self)

def get_model(self, model_id: str) -> Model:
"""
Fetched a model for a given id
:param model_id: internally controlled dataset_id
:return: model
"""
payload = self.make_request(
payload={},
route=f"model/{model_id}",
requests_command=requests.get,
)
return Model.from_json(payload=payload, client=self)

def get_model_run(self, model_run_id: str, dataset_id: str) -> ModelRun:
"""
Fetches a model_run for given id
Expand Down
4 changes: 2 additions & 2 deletions nucleus/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,12 @@ def to_payload(self) -> dict:
}


def check_all_annotation_paths_remote(
def check_all_mask_paths_remote(
annotations: Sequence[Union[Annotation]],
):
for annotation in annotations:
if hasattr(annotation, MASK_URL_KEY):
if is_local_path(getattr(annotation, MASK_URL_KEY)):
raise ValueError(
f"Found an annotation with a local path, which cannot be uploaded asynchronously. Use a remote path instead. {annotation}"
f"Found an annotation with a local path, which is not currently supported. Use a remote path instead. {annotation}"
)
6 changes: 3 additions & 3 deletions nucleus/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .annotation import (
Annotation,
CuboidAnnotation,
check_all_annotation_paths_remote,
check_all_mask_paths_remote,
)
from .constants import (
DATASET_ITEM_IDS_KEY,
Expand Down Expand Up @@ -171,9 +171,9 @@ def annotate(
if any((isinstance(ann, CuboidAnnotation) for ann in annotations)):
raise NotImplementedError("Cuboid annotations not yet supported")

if asynchronous:
check_all_annotation_paths_remote(annotations)
check_all_mask_paths_remote(annotations)

if asynchronous:
request_id = serialize_and_write_to_presigned_url(
annotations, self.id, self._client
)
Expand Down
17 changes: 16 additions & 1 deletion nucleus/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,26 @@ def __repr__(self):
return f"Model(model_id='{self.id}', name='{self.name}', reference_id='{self.reference_id}', metadata={self.metadata}, client={self._client})"

def __eq__(self, other):
return self.id == other.id
return (
(self.id == other.id)
and (self.name == other.name)
and (self.metadata == other.metadata)
and (self._client == other._client)
)

def __hash__(self):
return hash(self.id)

@classmethod
def from_json(cls, payload: dict, client):
return cls(
model_id=payload["id"],
name=payload["name"],
reference_id=payload["ref_id"],
metadata=payload["metadata"] or None,
client=client,
)

def create_run(
self,
name: str,
Expand Down
6 changes: 2 additions & 4 deletions nucleus/model_run.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import Dict, List, Optional, Type, Union

import requests

from nucleus.annotation import check_all_annotation_paths_remote
from nucleus.annotation import check_all_mask_paths_remote
from nucleus.job import AsyncJob
from nucleus.utils import serialize_and_write_to_presigned_url

Expand Down Expand Up @@ -108,7 +106,7 @@ def predict(
}
"""
if asynchronous:
check_all_annotation_paths_remote(annotations)
check_all_mask_paths_remote(annotations)

request_id = serialize_and_write_to_presigned_url(
annotations, self._dataset_id, self._client
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ exclude = '''

[tool.poetry]
name = "scale-nucleus"
version = "0.1.12"
version = "0.1.13"
description = "The official Python client library for Nucleus, the Data Platform for AI"
license = "MIT"
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]
Expand Down
4 changes: 4 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def test_model_creation_and_listing(CLIENT, dataset):
# List the models
ms = CLIENT.list_models()

# Get a model
m = CLIENT.get_model(model.id)
assert m == model

assert model in ms
assert list(set(ms) - set(models_before))[0] == model

Expand Down

0 comments on commit 906060f

Please sign in to comment.