Merge pull request #67 from scaleapi/da/async_upload
Async item ingest
ardila authored May 12, 2021
2 parents 41e93ca + 854462d commit 37afc9f
Showing 16 changed files with 317 additions and 124 deletions.
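
The net effect of this PR is a second ingest path on Dataset.append: passing asynchronous=True uploads serialized items to a presigned URL and returns an AsyncJob instead of blocking. A minimal usage sketch, assuming a placeholder API key, dataset ID, and remote image paths:

import nucleus

client = nucleus.NucleusClient("YOUR_API_KEY")
dataset = client.get_dataset("YOUR_DATASET_ID")

items = [
    nucleus.DatasetItem(
        image_location="s3://your-bucket/img_0.jpg", reference_id="img_0"
    ),
    nucleus.DatasetItem(
        image_location="s3://your-bucket/img_1.jpg", reference_id="img_1"
    ),
]

# Synchronous path: blocks until ingestion finishes and returns the summary
# dict documented in the append docstring below.
response = dataset.append(items, update=True)

# Asynchronous path (new in this PR): requires all paths to be remote and
# returns an AsyncJob wrapping the server-side job_id. Its polling helpers
# live in nucleus/job.py, which is among the changed files but not shown here.
job = dataset.append(items, update=True, asynchronous=True)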
1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -14,6 +14,7 @@ jobs:
pip install --upgrade pip
pip install poetry
poetry install
- run:
name: Black Formatting Check # Only validation, without re-formatting
command: |
12 changes: 7 additions & 5 deletions nucleus/__init__.py
@@ -70,7 +70,7 @@
# pylint: disable=C0302
from requests.packages.urllib3.util.retry import Retry

from .constants import REFERENCE_IDS_KEY, DATASET_ITEM_IDS_KEY
from .constants import REFERENCE_IDS_KEY, DATASET_ITEM_IDS_KEY, UPDATE_KEY
from .dataset import Dataset
from .dataset_item import DatasetItem
from .annotation import (
@@ -123,7 +123,6 @@
AUTOTAGS_KEY,
ANNOTATION_METADATA_SCHEMA_KEY,
ITEM_METADATA_SCHEMA_KEY,
FORCE_KEY,
EMBEDDINGS_URL_KEY,
)
from .model import Model
@@ -151,11 +150,14 @@ def __init__(
self,
api_key: str,
use_notebook: bool = False,
endpoint=NUCLEUS_ENDPOINT,
endpoint: str = None,
):
self.api_key = api_key
self.tqdm_bar = tqdm.tqdm
self.endpoint = endpoint
if endpoint is None:
self.endpoint = os.environ.get(
"NUCLEUS_ENDPOINT", NUCLEUS_ENDPOINT
)
self._use_notebook = use_notebook
if use_notebook:
self.tqdm_bar = tqdm_notebook.tqdm
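
With this change the endpoint is resolved at construction time: an explicit endpoint argument wins, then the NUCLEUS_ENDPOINT environment variable, then the packaged default. A standalone sketch of that resolution order (resolve_endpoint is a hypothetical helper mirroring the constructor logic above):

import os

NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"  # default from nucleus/constants.py


def resolve_endpoint(endpoint: str = None) -> str:
    # Hypothetical helper: an explicit argument takes precedence, then the
    # NUCLEUS_ENDPOINT environment variable, then the packaged default.
    if endpoint is None:
        return os.environ.get("NUCLEUS_ENDPOINT", NUCLEUS_ENDPOINT)
    return endpoint


# Pointing the client at a local or staging deployment without touching call sites:
os.environ["NUCLEUS_ENDPOINT"] = "http://localhost:3000/v1/nucleus"
assert resolve_endpoint() == "http://localhost:3000/v1/nucleus"
assert resolve_endpoint("https://example.com/v1/nucleus") == "https://example.com/v1/nucleus"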
@@ -497,7 +499,7 @@ def exception_handler(request, exception):
items = payload[ITEMS_KEY]
payloads = [
# batch_size images per request
{ITEMS_KEY: items[i : i + batch_size], FORCE_KEY: update}
{ITEMS_KEY: items[i : i + batch_size], UPDATE_KEY: update}
for i in range(0, len(items), batch_size)
]

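The synchronous path still chunks large uploads into batch_size-sized requests; only the payload key changes from the removed FORCE_KEY to UPDATE_KEY. A self-contained sketch of the chunking (the 45 items and batch size of 20 are made up for illustration):

ITEMS_KEY = "items"
UPDATE_KEY = "update"

items = [{"image_url": f"s3://your-bucket/img_{i}.jpg"} for i in range(45)]
batch_size = 20
update = False

payloads = [
    # batch_size images per request
    {ITEMS_KEY: items[i : i + batch_size], UPDATE_KEY: update}
    for i in range(0, len(items), batch_size)
]

assert [len(p[ITEMS_KEY]) for p in payloads] == [20, 20, 5]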
106 changes: 54 additions & 52 deletions nucleus/constants.py
@@ -1,63 +1,65 @@
NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
DEFAULT_NETWORK_TIMEOUT_SEC = 120
ITEMS_KEY = "items"
ITEM_KEY = "item"
REFERENCE_ID_KEY = "reference_id"
REFERENCE_IDS_KEY = "reference_ids"
DATASET_ID_KEY = "dataset_id"
IMAGE_KEY = "image"
IMAGE_URL_KEY = "image_url"
NEW_ITEMS = "new_items"
UPDATED_ITEMS = "updated_items"
IGNORED_ITEMS = "ignored_items"
ERROR_ITEMS = "upload_errors"
ERROR_PAYLOAD = "error_payload"
ERROR_CODES = "error_codes"
ANNOTATIONS_IGNORED_KEY = "annotations_ignored"
ANNOTATIONS_KEY = "annotations"
ANNOTATION_ID_KEY = "annotation_id"
ANNOTATIONS_PROCESSED_KEY = "annotations_processed"
ANNOTATIONS_IGNORED_KEY = "annotations_ignored"
PREDICTIONS_PROCESSED_KEY = "predictions_processed"
PREDICTIONS_IGNORED_KEY = "predictions_ignored"
ANNOTATION_ID_KEY = "annotation_id"
ANNOTATION_METADATA_SCHEMA_KEY = "annotation_metadata_schema"
BOX_TYPE = "box"
POLYGON_TYPE = "polygon"
SEGMENTATION_TYPE = "segmentation"
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
ANNOTATION_UPDATE_KEY = "update"
DEFAULT_ANNOTATION_UPDATE_MODE = False
STATUS_CODE_KEY = "status_code"
STATUS_KEY = "status"
SUCCESS_STATUS_CODES = [200, 201, 202]
ERRORS_KEY = "errors"
MODEL_RUN_ID_KEY = "model_run_id"
MODEL_ID_KEY = "model_id"
DATASET_ITEM_ID_KEY = "dataset_item_id"
ITEM_ID_KEY = "item_id"
AUTOTAGS_KEY = "autotags"

CONFIDENCE_KEY = "confidence"
DATASET_ID_KEY = "dataset_id"
DATASET_ITEM_IDS_KEY = "dataset_item_ids"
SLICE_ID_KEY = "slice_id"
DATASET_NAME_KEY = "name"
DATASET_ITEM_ID_KEY = "dataset_item_id"
DATASET_LENGTH_KEY = "length"
DATASET_MODEL_RUNS_KEY = "model_run_ids"
DATASET_NAME_KEY = "name"
DATASET_SLICES_KEY = "slice_ids"
DATASET_LENGTH_KEY = "length"
FORCE_KEY = "update"
DEFAULT_ANNOTATION_UPDATE_MODE = False
DEFAULT_NETWORK_TIMEOUT_SEC = 120
EMBEDDINGS_URL_KEY = "embeddings_url"
ERRORS_KEY = "errors"
ERROR_CODES = "error_codes"
ERROR_ITEMS = "upload_errors"
ERROR_PAYLOAD = "error_payload"
GEOMETRY_KEY = "geometry"
HEIGHT_KEY = "height"
IGNORED_ITEMS = "ignored_items"
IMAGE_KEY = "image"
IMAGE_URL_KEY = "image_url"
INDEX_KEY = "index"
ITEMS_KEY = "items"
ITEM_ID_KEY = "item_id"
ITEM_KEY = "item"
ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
JOB_ID_KEY = "job_id"
LABEL_KEY = "label"
MASK_URL_KEY = "mask_url"
MESSAGE_KEY = "message"
METADATA_KEY = "metadata"
MODEL_ID_KEY = "model_id"
MODEL_RUN_ID_KEY = "model_run_id"
NAME_KEY = "name"
LABEL_KEY = "label"
CONFIDENCE_KEY = "confidence"
NEW_ITEMS = "new_items"
NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
ORIGINAL_IMAGE_URL_KEY = "original_image_url"
X_KEY = "x"
Y_KEY = "y"
WIDTH_KEY = "width"
HEIGHT_KEY = "height"
PREDICTIONS_IGNORED_KEY = "predictions_ignored"
PREDICTIONS_PROCESSED_KEY = "predictions_processed"
REFERENCE_IDS_KEY = "reference_ids"
REFERENCE_ID_KEY = "reference_id"
REQUEST_ID_KEY = "requestId"
SEGMENTATIONS_KEY = "segmentations"
SLICE_ID_KEY = "slice_id"
STATUS_CODE_KEY = "status_code"
STATUS_KEY = "status"
SUCCESS_STATUS_CODES = [200, 201, 202]
TYPE_KEY = "type"
UPDATED_ITEMS = "updated_items"
UPDATE_KEY = "update"
VERTICES_KEY = "vertices"
BOX_TYPE = "box"
POLYGON_TYPE = "polygon"
SEGMENTATION_TYPE = "segmentation"
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
GEOMETRY_KEY = "geometry"
AUTOTAGS_KEY = "autotags"
ANNOTATION_METADATA_SCHEMA_KEY = "annotation_metadata_schema"
ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
MASK_URL_KEY = "mask_url"
INDEX_KEY = "index"
SEGMENTATIONS_KEY = "segmentations"
EMBEDDINGS_URL_KEY = "embeddings_url"
JOB_ID_KEY = "job_id"
MESSAGE_KEY = "message"
WIDTH_KEY = "width"
X_KEY = "x"
Y_KEY = "y"
64 changes: 43 additions & 21 deletions nucleus/dataset.py
@@ -1,9 +1,13 @@
from collections import Counter
from typing import List, Dict, Any, Optional
import uuid
from typing import Any, Dict, List, Optional, Union

import requests

from nucleus.utils import format_dataset_item_response
from nucleus.job import AsyncJob
from nucleus.utils import (
format_dataset_item_response,
serialize_and_write_to_presigned_url,
)

from .annotation import Annotation
from .constants import (
@@ -15,8 +19,14 @@
DEFAULT_ANNOTATION_UPDATE_MODE,
NAME_KEY,
REFERENCE_IDS_KEY,
REQUEST_ID_KEY,
UPDATE_KEY,
)
from .dataset_item import (
DatasetItem,
check_all_paths_remote,
check_for_duplicate_reference_ids,
)
from .dataset_item import DatasetItem
from .payload_constructor import construct_model_run_creation_payload


@@ -27,7 +37,11 @@ class Dataset:
compare model performance on your data.
"""

def __init__(self, dataset_id: str, client):
def __init__(
self,
dataset_id: str,
client: "NucleusClient", # type:ignore # noqa: F821
):
self.id = dataset_id
self._client = client

@@ -161,16 +175,18 @@ def ingest_tasks(self, task_ids: dict):
def append(
self,
dataset_items: List[DatasetItem],
force: Optional[bool] = False,
update: Optional[bool] = False,
batch_size: Optional[int] = 20,
) -> dict:
asynchronous=False,
) -> Union[dict, AsyncJob]:
"""
Appends images with metadata (dataset items) to the dataset. Overwrites images on collision if forced.
Parameters:
:param dataset_items: items to upload
:param force: if True overwrites images on collision
:param update: if True overwrites images and metadata on collision
:param batch_size: batch parameter for long uploads
:param asynchronous: if True, return a job object representing the asynchronous ingestion job.
:return:
{
'dataset_id': str,
@@ -179,23 +195,29 @@ def append(
'ignored_items': int,
}
"""
ref_ids = []
for dataset_item in dataset_items:
if dataset_item.reference_id is not None:
ref_ids.append(dataset_item.reference_id)
if len(ref_ids) != len(set(ref_ids)):
duplicates = {
f"{key}": f"Count: {value}"
for key, value in Counter(ref_ids).items()
}
raise ValueError(
"Duplicate reference ids found among dataset_items: %s"
% duplicates
check_for_duplicate_reference_ids(dataset_items)

if asynchronous:
check_all_paths_remote(dataset_items)
request_id = uuid.uuid4().hex
response = self._client.make_request(
payload={},
route=f"dataset/{self.id}/signedUrl/{request_id}",
requests_command=requests.get,
)
serialize_and_write_to_presigned_url(
dataset_items, response["signed_url"]
)
response = self._client.make_request(
payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
route=f"dataset/{self.id}/append?async=1",
)
return AsyncJob(response["job_id"], self._client)

return self._client.populate_dataset(
self.id,
dataset_items,
force=force,
force=update,
batch_size=batch_size,
)

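The asynchronous branch above amounts to a three-step handshake: fetch a presigned URL keyed by a fresh request id, write the serialized items to it, then start the server-side job via the append?async=1 route. A rough sketch using requests directly; the basic-auth convention and the plain JSON body are assumptions, since serialize_and_write_to_presigned_url lives in nucleus/utils.py and the client's request plumbing is not shown in this excerpt:

import json
import uuid

import requests


def async_append_sketch(api_endpoint, api_key, dataset_id, dataset_items, update=False):
    # Hypothetical walkthrough of the wire protocol implied by the diff above.
    # 1. Ask the server for a presigned upload URL tied to a fresh request id.
    request_id = uuid.uuid4().hex
    signed = requests.get(
        f"{api_endpoint}/dataset/{dataset_id}/signedUrl/{request_id}",
        auth=(api_key, ""),  # assumed auth scheme
    ).json()

    # 2. Write the serialized items to the presigned URL. The real client
    #    delegates this to serialize_and_write_to_presigned_url; a plain JSON
    #    array of item payloads is assumed here.
    requests.put(
        signed["signed_url"],
        data=json.dumps([item.to_payload() for item in dataset_items]),
    )

    # 3. Kick off the server-side ingest job and hand back its id.
    response = requests.post(
        f"{api_endpoint}/dataset/{dataset_id}/append?async=1",
        json={"requestId": request_id, "update": update},
        auth=(api_key, ""),  # assumed auth scheme
    )
    return response.json()["job_id"]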
44 changes: 35 additions & 9 deletions nucleus/dataset_item.py
@@ -1,7 +1,9 @@
from collections import Counter
import json
import os.path
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Sequence
from urllib.parse import urlparse

from .constants import (
DATASET_ITEM_ID_KEY,
@@ -21,8 +23,7 @@ class DatasetItem:
metadata: Optional[dict] = None

def __post_init__(self):
self.image_url = self.image_location
self.local = self._is_local_path(self.image_location)
self.local = is_local_path(self.image_location)

@classmethod
def from_json(cls, payload: dict):
@@ -36,16 +37,12 @@ def from_json(cls, payload: dict):
metadata=payload.get(METADATA_KEY, {}),
)

def _is_local_path(self, path: str) -> bool:
path_components = [comp.lower() for comp in path.split("/")]
return path_components[0] not in {"https:", "http:", "s3:", "gs:"}

def local_file_exists(self):
return os.path.isfile(self.image_url)
return os.path.isfile(self.image_location)

def to_payload(self) -> dict:
payload = {
IMAGE_URL_KEY: self.image_url,
IMAGE_URL_KEY: self.image_location,
METADATA_KEY: self.metadata or {},
}
if self.reference_id:
@@ -56,3 +53,32 @@

def to_json(self) -> str:
return json.dumps(self.to_payload())


def is_local_path(path: str) -> bool:
return urlparse(path).scheme not in {"https", "http", "s3", "gs"}


def check_all_paths_remote(dataset_items: Sequence[DatasetItem]):
for item in dataset_items:
if is_local_path(item.image_location):
raise ValueError(
f"All paths must be remote, but {item.image_location} is either "
"local, or a remote URL type that is not supported."
)


def check_for_duplicate_reference_ids(dataset_items: Sequence[DatasetItem]):
ref_ids = []
for dataset_item in dataset_items:
if dataset_item.reference_id is not None:
ref_ids.append(dataset_item.reference_id)
if len(ref_ids) != len(set(ref_ids)):
duplicates = {
f"{key}": f"Count: {value}"
for key, value in Counter(ref_ids).items()
}
raise ValueError(
"Duplicate reference ids found among dataset_items: %s"
% duplicates
)
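
The new module-level helpers replace the inline duplicate check that used to live in Dataset.append and can be exercised on their own; the paths and reference ids below are made up:

from nucleus.dataset_item import (
    DatasetItem,
    check_all_paths_remote,
    check_for_duplicate_reference_ids,
    is_local_path,
)

assert not is_local_path("s3://your-bucket/img.jpg")
assert not is_local_path("https://example.com/img.jpg")
assert is_local_path("/home/user/img.jpg")  # no recognized remote scheme

remote_items = [
    DatasetItem(image_location="gs://your-bucket/a.jpg", reference_id="a"),
    DatasetItem(image_location="gs://your-bucket/b.jpg", reference_id="b"),
]
check_all_paths_remote(remote_items)             # passes silently
check_for_duplicate_reference_ids(remote_items)  # passes silently

duplicated = remote_items + [
    DatasetItem(image_location="gs://your-bucket/a2.jpg", reference_id="a"),
]
try:
    check_for_duplicate_reference_ids(duplicated)
except ValueError as err:
    # Duplicate reference ids found among dataset_items: {'a': 'Count: 2', 'b': 'Count: 1'}
    print(err)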
