Merge pull request #69 from scaleapi/da/predictions_async
Async prediction upload
ardila authored May 26, 2021
2 parents 37afc9f + 0ce4ca2 commit b664115
Showing 12 changed files with 211 additions and 83 deletions.
104 changes: 52 additions & 52 deletions nucleus/__init__.py
@@ -50,89 +50,83 @@
geometry | dict | Representation of the bounding box in the Box2DGeometry format.\n
metadata | dict | An arbitrary metadata blob for the annotation.\n
"""
__version__ = "0.1.0"

import json
import logging
import warnings
import os
from typing import List, Union, Dict, Callable, Any, Optional

import tqdm
import tqdm.notebook as tqdm_notebook
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import grequests
import pkg_resources
import requests
import tqdm
import tqdm.notebook as tqdm_notebook
from requests.adapters import HTTPAdapter

# pylint: disable=E1101
# TODO: refactor to reduce this file to under 1000 lines.
# pylint: disable=C0302
from requests.packages.urllib3.util.retry import Retry

from .constants import REFERENCE_IDS_KEY, DATASET_ITEM_IDS_KEY, UPDATE_KEY
from .dataset import Dataset
from .dataset_item import DatasetItem
from .annotation import (
BoxAnnotation,
PolygonAnnotation,
SegmentationAnnotation,
Segment,
)
from .prediction import (
BoxPrediction,
PolygonPrediction,
SegmentationPrediction,
)
from .model_run import ModelRun
from .slice import Slice
from .upload_response import UploadResponse
from .payload_constructor import (
construct_append_payload,
construct_annotation_payload,
construct_model_creation_payload,
construct_box_predictions_payload,
construct_segmentation_payload,
SegmentationAnnotation,
)
from .constants import (
NUCLEUS_ENDPOINT,
ANNOTATION_METADATA_SCHEMA_KEY,
ANNOTATIONS_IGNORED_KEY,
ANNOTATIONS_PROCESSED_KEY,
AUTOTAGS_KEY,
DATASET_ID_KEY,
DATASET_ITEM_IDS_KEY,
DEFAULT_NETWORK_TIMEOUT_SEC,
ERRORS_KEY,
EMBEDDINGS_URL_KEY,
ERROR_ITEMS,
ERROR_PAYLOAD,
ITEMS_KEY,
ITEM_KEY,
ERRORS_KEY,
IMAGE_KEY,
IMAGE_URL_KEY,
DATASET_ID_KEY,
ITEM_METADATA_SCHEMA_KEY,
ITEMS_KEY,
MODEL_RUN_ID_KEY,
DATASET_ITEM_ID_KEY,
SLICE_ID_KEY,
ANNOTATIONS_PROCESSED_KEY,
ANNOTATIONS_IGNORED_KEY,
PREDICTIONS_PROCESSED_KEY,
NAME_KEY,
NUCLEUS_ENDPOINT,
PREDICTIONS_IGNORED_KEY,
PREDICTIONS_PROCESSED_KEY,
REFERENCE_IDS_KEY,
SLICE_ID_KEY,
STATUS_CODE_KEY,
SUCCESS_STATUS_CODES,
DATASET_NAME_KEY,
DATASET_MODEL_RUNS_KEY,
DATASET_SLICES_KEY,
DATASET_LENGTH_KEY,
NAME_KEY,
ANNOTATIONS_KEY,
AUTOTAGS_KEY,
ANNOTATION_METADATA_SCHEMA_KEY,
ITEM_METADATA_SCHEMA_KEY,
EMBEDDINGS_URL_KEY,
UPDATE_KEY,
)
from .model import Model
from .dataset import Dataset
from .dataset_item import DatasetItem
from .errors import (
DatasetItemRetrievalError,
ModelCreationError,
ModelRunCreationError,
DatasetItemRetrievalError,
NotFoundError,
NucleusAPIError,
)
from .model import Model
from .model_run import ModelRun
from .payload_constructor import (
construct_annotation_payload,
construct_append_payload,
construct_box_predictions_payload,
construct_model_creation_payload,
construct_segmentation_payload,
)
from .prediction import (
BoxPrediction,
PolygonPrediction,
SegmentationPrediction,
)
from .slice import Slice
from .upload_response import UploadResponse

__version__ = pkg_resources.get_distribution("scale-nucleus").version

logger = logging.getLogger(__name__)
logging.basicConfig()
@@ -158,6 +152,8 @@ def __init__(
self.endpoint = os.environ.get(
"NUCLEUS_ENDPOINT", NUCLEUS_ENDPOINT
)
else:
self.endpoint = endpoint
self._use_notebook = use_notebook
if use_notebook:
self.tqdm_bar = tqdm_notebook.tqdm
@@ -230,13 +226,15 @@ def get_dataset(self, dataset_id: str) -> Dataset:
"""
return Dataset(dataset_id, self)

def get_model_run(self, model_run_id: str) -> ModelRun:
def get_model_run(self, model_run_id: str, dataset_id: str) -> ModelRun:
"""
Fetches a model_run for given id
:param model_run_id: internally controlled model_run_id
:param dataset_id: the dataset id which may determine the prediction schema
for this model run if present on the dataset.
:return: model_run
"""
return ModelRun(model_run_id, self)
return ModelRun(model_run_id, dataset_id, self)

def delete_model_run(self, model_run_id: str):
"""
@@ -674,7 +672,9 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
if response.get(STATUS_CODE_KEY, None):
raise ModelRunCreationError(response.get("error"))

return ModelRun(response[MODEL_RUN_ID_KEY], self)
return ModelRun(
response[MODEL_RUN_ID_KEY], dataset_id=dataset_id, client=self
)

def predict(
self,
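
Taken together, the __init__.py changes make model runs dataset-aware on the client. A minimal usage sketch of the new signatures, assuming the usual NucleusClient entry point (the key and ids below are placeholders):

import nucleus

client = nucleus.NucleusClient("YOUR_API_KEY")  # placeholder API key
# get_model_run now also takes the dataset id, since the dataset may
# carry the prediction schema for the run (placeholder ids):
model_run = client.get_model_run("run_abc", "ds_xyz")
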
16 changes: 15 additions & 1 deletion nucleus/annotation.py
@@ -1,7 +1,8 @@
import json
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from nucleus.dataset_item import is_local_path

from .constants import (
ANNOTATION_ID_KEY,
@@ -13,6 +14,7 @@
INDEX_KEY,
ITEM_ID_KEY,
LABEL_KEY,
MASK_TYPE,
MASK_URL_KEY,
METADATA_KEY,
POLYGON_TYPE,
@@ -108,6 +110,7 @@ def from_json(cls, payload: dict):

def to_payload(self) -> dict:
payload = {
TYPE_KEY: MASK_TYPE,
MASK_URL_KEY: self.mask_url,
ANNOTATIONS_KEY: [ann.to_payload() for ann in self.annotations],
ANNOTATION_ID_KEY: self.annotation_id,
@@ -206,3 +209,14 @@ def to_payload(self) -> dict:
ANNOTATION_ID_KEY: self.annotation_id,
METADATA_KEY: self.metadata,
}


def check_all_annotation_paths_remote(
annotations: Sequence[Union[Annotation]],
):
for annotation in annotations:
if hasattr(annotation, MASK_URL_KEY):
if is_local_path(getattr(annotation, MASK_URL_KEY)):
raise ValueError(
f"Found an annotation with a local path, which cannot be uploaded asynchronously. Use a remote path instead. {annotation}"
)
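
The new check_all_annotation_paths_remote helper guards async uploads: any annotation whose mask_url points at a local file fails fast, before anything is sent. A tiny sketch of the is_local_path semantics it relies on (assumed from the name and the usage above):

from nucleus.dataset_item import is_local_path

# Remote URLs pass the guard; filesystem paths do not (assumed semantics):
assert not is_local_path("s3://bucket/masks/0001.png")
assert is_local_path("/home/user/masks/0001.png")
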
1 change: 1 addition & 0 deletions nucleus/constants.py
@@ -5,6 +5,7 @@
ANNOTATION_METADATA_SCHEMA_KEY = "annotation_metadata_schema"
BOX_TYPE = "box"
POLYGON_TYPE = "polygon"
MASK_TYPE = "mask"
SEGMENTATION_TYPE = "segmentation"
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
ANNOTATION_UPDATE_KEY = "update"
11 changes: 2 additions & 9 deletions nucleus/dataset.py
@@ -1,4 +1,3 @@
import uuid
from typing import Any, Dict, List, Optional, Union

import requests
@@ -199,14 +198,8 @@ def append(

if asynchronous:
check_all_paths_remote(dataset_items)
request_id = uuid.uuid4().hex
response = self._client.make_request(
payload={},
route=f"dataset/{self.id}/signedUrl/{request_id}",
requests_command=requests.get,
)
serialize_and_write_to_presigned_url(
dataset_items, response["signed_url"]
request_id = serialize_and_write_to_presigned_url(
dataset_items, self.id, self._client
)
response = self._client.make_request(
payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
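
The asynchronous append path now delegates the request-id and presigned-URL handshake to the shared helper in nucleus/utils.py, so the caller-facing call is unchanged. A sketch, assuming the async branch returns a job handle (the return is not visible in this hunk):

# Every item must reference a remote URL (check_all_paths_remote above):
job = dataset.append(dataset_items, asynchronous=True)
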
3 changes: 2 additions & 1 deletion nucleus/model.py
@@ -45,6 +45,7 @@ def create_run(
Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
],
metadata: Optional[Dict] = None,
asynchronous: bool = False,
) -> ModelRun:
payload: dict = {
NAME_KEY: name,
@@ -56,6 +57,6 @@
dataset.id, payload
)

model_run.predict(predictions)
model_run.predict(predictions, asynchronous=asynchronous)

return model_run
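
create_run simply threads the new flag through to ModelRun.predict. A sketch of the call, with argument names assumed from the surrounding payload construction:

model_run = model.create_run(
    name="my-run",            # placeholder run name
    dataset=dataset,
    predictions=predictions,
    asynchronous=True,        # new in this PR
)
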
35 changes: 29 additions & 6 deletions nucleus/model_run.py
@@ -1,10 +1,18 @@
from typing import Dict, Optional, List, Union, Type
from typing import Dict, List, Optional, Type, Union

from nucleus.annotation import check_all_annotation_paths_remote
from nucleus.job import AsyncJob
from nucleus.utils import serialize_and_write_to_presigned_url

from .constants import (
ANNOTATIONS_KEY,
DEFAULT_ANNOTATION_UPDATE_MODE,
BOX_TYPE,
DEFAULT_ANNOTATION_UPDATE_MODE,
JOB_ID_KEY,
POLYGON_TYPE,
REQUEST_ID_KEY,
SEGMENTATION_TYPE,
UPDATE_KEY,
)
from .prediction import (
BoxPrediction,
@@ -19,12 +27,13 @@ class ModelRun:
Having an open model run is a prerequisite for uploading predictions to your dataset.
"""

def __init__(self, model_run_id: str, client):
def __init__(self, model_run_id: str, dataset_id: str, client):
self.model_run_id = model_run_id
self._client = client
self._dataset_id = dataset_id

def __repr__(self):
return f"ModelRun(model_run_id='{self.model_run_id}', client={self._client})"
return f"ModelRun(model_run_id='{self.model_run_id}', dataset_id='{self._dataset_id}', client={self._client})"

def __eq__(self, other):
if self.model_run_id == other.model_run_id:
@@ -84,7 +93,8 @@ def predict(
Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
],
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
) -> dict:
asynchronous: bool = False,
) -> Union[dict, AsyncJob]:
"""
Uploads model outputs as predictions for a model_run. Returns info about the upload.
:param annotations: List[Union[BoxPrediction, PolygonPrediction]],
@@ -95,7 +105,20 @@
"predictions_ignored": int,
}
"""
return self._client.predict(self.model_run_id, annotations, update)
if asynchronous:
check_all_annotation_paths_remote(annotations)

request_id = serialize_and_write_to_presigned_url(
annotations, self._dataset_id, self._client
)
response = self._client.make_request(
payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
route=f"modelRun/{self.model_run_id}/predict?async=1",
)

return AsyncJob(response[JOB_ID_KEY], self._client)
else:
return self._client.predict(self.model_run_id, annotations, update)

def iloc(self, i: int):
"""
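
This is the heart of the PR: predict gains an asynchronous branch that checks every prediction for remote paths, serializes the batch to a presigned URL, and starts a server-side job via the ?async=1 route, returning an AsyncJob instead of a response dict. A usage sketch — only job.id and job.errors() are visible in this diff, so nothing else about AsyncJob's interface is assumed:

job = model_run.predict(predictions, asynchronous=True)
print(job.id)        # server-side job identifier
print(job.errors())  # per-item failure messages, as exercised in the tests below
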
15 changes: 13 additions & 2 deletions nucleus/utils.py
@@ -2,6 +2,7 @@


import io
import uuid
from typing import IO, Dict, List, Sequence, Union

import requests
@@ -104,9 +105,19 @@ def upload_to_presigned_url(presigned_url: str, file_pointer: IO):


def serialize_and_write_to_presigned_url(
upload_units: Sequence[Union[DatasetItem, Annotation]], presigned_url
upload_units: Sequence[Union["DatasetItem", Annotation]],
dataset_id: str,
client,
):
request_id = uuid.uuid4().hex
response = client.make_request(
payload={},
route=f"dataset/{dataset_id}/signedUrl/{request_id}",
requests_command=requests.get,
)

strio = io.StringIO()
serialize_and_write(upload_units, strio)
strio.seek(0)
upload_to_presigned_url(presigned_url, strio)
upload_to_presigned_url(response["signed_url"], strio)
return request_id
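
The helper now owns the whole handshake: it mints a request id, fetches a signed URL from the dataset/{dataset_id}/signedUrl/{request_id} route, uploads the serialized payload, and returns the id that the async endpoints use to find it. A caller sketch, mirroring model_run.predict above (the literal "request_id" and "update" strings stand in for REQUEST_ID_KEY and UPDATE_KEY, whose values are not shown in this diff):

request_id = serialize_and_write_to_presigned_url(annotations, dataset_id, client)
response = client.make_request(
    payload={"request_id": request_id, "update": False},
    route=f"modelRun/{model_run_id}/predict?async=1",
)
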
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -21,7 +21,7 @@ exclude = '''

[tool.poetry]
name = "scale-nucleus"
version = "0.1.4"
version = "0.1.5"
description = "The official Python client library for Nucleus, the Data Platform for AI"
license = "MIT"
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]
10 changes: 5 additions & 5 deletions tests/helpers.py
@@ -13,11 +13,11 @@


TEST_IMG_URLS = [
"http://farm1.staticflickr.com/107/309278012_7a1f67deaa_z.jpg",
"http://farm9.staticflickr.com/8001/7679588594_4e51b76472_z.jpg",
"http://farm6.staticflickr.com/5295/5465771966_76f9773af1_z.jpg",
"http://farm4.staticflickr.com/3449/4002348519_8ddfa4f2fb_z.jpg",
"http://farm1.staticflickr.com/6/7617223_d84fcbce0e_z.jpg",
"https://homepages.cae.wisc.edu/~ece533/images/airplane.png",
"https://homepages.cae.wisc.edu/~ece533/images/arctichare.png",
"https://homepages.cae.wisc.edu/~ece533/images/baboon.png",
"https://homepages.cae.wisc.edu/~ece533/images/barbara.png",
"https://homepages.cae.wisc.edu/~ece533/images/cat.png",
]

TEST_DATASET_ITEMS = [
10 changes: 5 additions & 5 deletions tests/test_dataset.py
@@ -190,11 +190,11 @@ def test_dataset_append_async_with_1_bad_url(dataset: Dataset):
"started_image_processing": f"Dataset: {dataset.id}, Job: {job.id}",
},
}
assert job.errors() == [
"One or more of the images you attempted to upload did not process correctly. Please see the status for an overview and the errors for more detailed messages.",
# Todo: figure out why this error isn't propagating from image upload.
'Failure when processing the image "https://looks.ok.but.is.not.accessible": {}',
]
# The error is fairly detailed and subject to change. What's important is we surface which URLs failed.
assert (
'Failure when processing the image "https://looks.ok.but.is.not.accessible"'
in str(job.errors())
)


def test_dataset_list_autotags(CLIENT, dataset):
Expand Down
(Diffs for the remaining 2 of the 12 changed files did not load.)
