diff --git a/Changelog.md b/Changelog.md index 83cc2f6..6c06a35 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,4 +1,6 @@ ## 0.0.5 ++ Added `Prediction` class for making predictions. ++ Added `fxn predict` CLI command for makong predictions. + Updated `Predictor.create` method `type` argument to be optional. Cloud predictors are now the default. ## 0.0.4 diff --git a/fxn/api/__init__.py b/fxn/api/__init__.py index 75c3e47..1a7b27a 100644 --- a/fxn/api/__init__.py +++ b/fxn/api/__init__.py @@ -6,7 +6,7 @@ from .dtype import Dtype from .feature import Feature from .featureinput import FeatureInput -#from .prediction import Prediction # INCOMPLETE +from .prediction import CloudPrediction, EdgePrediction, Prediction from .predictor import Acceleration, AccessMode, Parameter, Predictor, PredictorStatus, PredictorType, Signature from .profile import Profile from .storage import Storage, UploadType diff --git a/fxn/api/api.py b/fxn/api/api.py index 0ebd1c8..ccfaeab 100644 --- a/fxn/api/api.py +++ b/fxn/api/api.py @@ -25,7 +25,10 @@ def query (query: str, variables: dict=None, access_key: str=None) -> dict: fxn.api_url, json={ "query": query, "variables": variables }, headers=headers - ).json() + ) + # Check + response.raise_for_status() + response = response.json() # Check error if "errors" in response: raise RuntimeError(response["errors"][0]["message"]) diff --git a/fxn/api/prediction.py b/fxn/api/prediction.py index e69de29..91d98b3 100644 --- a/fxn/api/prediction.py +++ b/fxn/api/prediction.py @@ -0,0 +1,186 @@ +# +# Function +# Copyright © 2023 NatML Inc. All Rights Reserved. +# + +from __future__ import annotations +from dataclasses import asdict, dataclass +from io import BytesIO +from numpy import frombuffer +from PIL import Image +from platform import system +from requests import get +from typing import Any, Dict, List, Union +from urllib.request import urlopen + +from .api import query +from .dtype import Dtype +from .feature import Feature +from .featureinput import FeatureInput +from .predictor import PredictorType + +@dataclass(frozen=True) +class Prediction: + """ + Prediction. + + Members: + id (str): Prediction ID. + tag (str): Predictor tag. + type (PredictorType): Prediction type. + created (str): Date created. + """ + id: str + tag: str + type: PredictorType + created: str + FIELDS = f""" + id + tag + type + created + ... on CloudPrediction {{ + results {{ + data + type + shape + stringValue + listValue + dictValue + }} + latency + error + logs + }} + """ + RAW_FIELDS = f""" + id + tag + type + created + ... on CloudPrediction {{ + results {{ + data + type + shape + }} + latency + error + logs + }} + """ + + @classmethod + def create ( + cls, + tag: str, + *features: List[FeatureInput], + data_url_limit: int=None, + raw_outputs: bool=False, + access_key: str=None, + **inputs: Dict[str, Any], + ) -> Union[CloudPrediction, EdgePrediction]: + """ + Create a prediction. + + Parameters: + tag (str): Predictor tag. + features (list): Input features. Only applies to `CLOUD` predictions. + data_url_limit (int): Return a data URL if a given output feature is smaller than this limit in bytes. Only applies to `CLOUD` predictions. + raw_outputs (bool): Skip parsing output features into Pythonic data types. + access_key (str): Function access key. + inputs (dict): Input features. Only applies to `CLOUD` predictions. + + Returns: + CloudPrediction | EdgePrediction: Created prediction. + """ + # Collect input features + input_features = list(features) + [FeatureInput.from_value(value, name) for name, value in inputs.items()] + input_features = [asdict(feature) for feature in input_features] + # Query + response = query(f""" + mutation ($input: CreatePredictionInput!) {{ + createPrediction (input: $input) {{ + {cls.RAW_FIELDS if raw_outputs else cls.FIELDS} + }} + }}""", + { "input": { "tag": tag, "client": _get_client(), "inputs": input_features, "dataUrlLimit": data_url_limit } }, + access_key=access_key + ) + # Check + prediction = response["createPrediction"] + if not prediction: + return None + # Parse results + if "results" in prediction and not raw_outputs: + prediction["results"] = [_parse_output_feature(feature) for feature in prediction["results"]] + # Create + prediction = CloudPrediction(**prediction) if prediction["type"] == PredictorType.Cloud else EdgePrediction(**prediction) + # Return + return prediction + +@dataclass(frozen=True) +class CloudPrediction (Prediction): + """ + Cloud prediction. + + Members: + results (list): Prediction results. + latency (float): Prediction latency in milliseconds. + error (str): Prediction error. This is `null` if the prediction completed successfully. + logs (str): Prediction logs. + """ + results: List[Feature] = None + latency: float = None + error: str = None + logs: str = None + +@dataclass(frozen=True) +class EdgePrediction (Prediction): + """ + Edge prediction + """ + +def _parse_output_feature (feature: dict) -> Union[Feature, str, float, int, bool, Image.Image, list, dict]: + data, type, shape = feature["data"], feature["type"], feature["shape"] + # Handle image + if type == Dtype.image: + return Image.open(_download_feature_data(data)) + # Handle non-numeric scalars + values = [feature.get(key, None) for key in ["stringValue", "listValue", "dictValue"]] + scalar = next((value for value in values if value is not None), None) + if scalar is not None: + return scalar + # Handle ndarray + ARRAY_TYPES = [ + Dtype.int8, Dtype.int16, Dtype.int32, Dtype.int64, + Dtype.uint8, Dtype.uint16, Dtype.uint32, Dtype.uint64, + Dtype.float16, Dtype.float32, Dtype.float64, Dtype.bool + ] + if type in ARRAY_TYPES: + # Create array + array = frombuffer(_download_feature_data(data).getbuffer(), dtype=type).reshape(shape) + return array if len(shape) > 0 else array.item() + # Handle generic feature + feature = Feature(**feature) + return feature + +def _download_feature_data (url: str) -> BytesIO: + # Check if data URL + if url.startswith("data:"): + with urlopen(url) as response: + return BytesIO(response.read()) + # Download + response = get(url) + result = BytesIO(response.content) + return result + +def _get_client () -> str: + id = system() + if id == "Darwin": + return "macos" + if id == "Linux": + return "linux" + if id == "Windows": + return "windows" + raise RuntimeError(f"Function cannot make predictions on the {id} platform") \ No newline at end of file diff --git a/fxn/cli/predict.py b/fxn/cli/predict.py index 01ed03a..397a0c9 100644 --- a/fxn/cli/predict.py +++ b/fxn/cli/predict.py @@ -3,11 +3,75 @@ # Copyright © 2023 NatML Inc. All Rights Reserved. # -from typer import Argument, Context, Option, Typer +from dataclasses import asdict +from numpy import ndarray +from pathlib import Path +from PIL import Image +from rich import print_json +from tempfile import mkstemp +from typer import Argument, Context, Option -def predict ( # INCOMPLETE +from ..api import Prediction +from .auth import get_access_key + +def predict ( tag: str = Argument(..., help="Predictor tag."), raw_outputs: bool = Option(False, "--raw-outputs", help="Generate raw output features instead of parsing."), context: Context = 0 ): - print("Predict!") \ No newline at end of file + # Predict + inputs = { context.args[i].replace("-", ""): _parse_value(context.args[i+1]) for i in range(0, len(context.args), 2) } + prediction = Prediction.create(tag, **inputs, raw_outputs=raw_outputs, access_key=get_access_key()) + # Parse results + if hasattr(prediction, "results"): + images = [feature for feature in prediction.results if isinstance(feature, Image.Image)] + results = [_serialize_feature(feature) for feature in prediction.results] + object.__setattr__(prediction, "results", results) + # Print + print_json(data=asdict(prediction)) + # Show images + for image in images: + image.show() + +def _parse_value (value: str): + """ + Parse a value from a CLI argument. + + Parameters: + value (str): CLI input argument. + + Returns: + bool | int | float | str | Path: Parsed value. + """ + # Boolean + if value == "true": + return True + if value == "false": + return False + # Integer + try: + return int(value) + except ValueError: + pass + # Float + try: + return float(value) + except ValueError: + pass + # File + if value.startswith("@"): + return Path(value[1:]) + # String + return value + +def _serialize_feature (feature): + # Convert ndarray to list + if isinstance(feature, ndarray): + return feature.tolist() + # Write image + if isinstance(feature, Image.Image): + _, path = mkstemp(suffix=".png" if feature.mode == "RGBA" else ".jpg") + feature.save(path) + return path + # Return + return feature \ No newline at end of file