Skip to content

Commit

Permalink
Add predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
olokobayusuf committed May 29, 2023
1 parent 79750d1 commit c1350d6
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 5 deletions.
2 changes: 2 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
## 0.0.5
+ Added `Prediction` class for making predictions.
+ Added `fxn predict` CLI command for makong predictions.
+ Updated `Predictor.create` method `type` argument to be optional. Cloud predictors are now the default.

## 0.0.4
Expand Down
2 changes: 1 addition & 1 deletion fxn/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .dtype import Dtype
from .feature import Feature
from .featureinput import FeatureInput
#from .prediction import Prediction # INCOMPLETE
from .prediction import CloudPrediction, EdgePrediction, Prediction
from .predictor import Acceleration, AccessMode, Parameter, Predictor, PredictorStatus, PredictorType, Signature
from .profile import Profile
from .storage import Storage, UploadType
Expand Down
5 changes: 4 additions & 1 deletion fxn/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ def query (query: str, variables: dict=None, access_key: str=None) -> dict:
fxn.api_url,
json={ "query": query, "variables": variables },
headers=headers
).json()
)
# Check
response.raise_for_status()
response = response.json()
# Check error
if "errors" in response:
raise RuntimeError(response["errors"][0]["message"])
Expand Down
186 changes: 186 additions & 0 deletions fxn/api/prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#
# Function
# Copyright © 2023 NatML Inc. All Rights Reserved.
#

from __future__ import annotations
from dataclasses import asdict, dataclass
from io import BytesIO
from numpy import frombuffer
from PIL import Image
from platform import system
from requests import get
from typing import Any, Dict, List, Union
from urllib.request import urlopen

from .api import query
from .dtype import Dtype
from .feature import Feature
from .featureinput import FeatureInput
from .predictor import PredictorType

@dataclass(frozen=True)
class Prediction:
"""
Prediction.
Members:
id (str): Prediction ID.
tag (str): Predictor tag.
type (PredictorType): Prediction type.
created (str): Date created.
"""
id: str
tag: str
type: PredictorType
created: str
FIELDS = f"""
id
tag
type
created
... on CloudPrediction {{
results {{
data
type
shape
stringValue
listValue
dictValue
}}
latency
error
logs
}}
"""
RAW_FIELDS = f"""
id
tag
type
created
... on CloudPrediction {{
results {{
data
type
shape
}}
latency
error
logs
}}
"""

@classmethod
def create (
cls,
tag: str,
*features: List[FeatureInput],
data_url_limit: int=None,
raw_outputs: bool=False,
access_key: str=None,
**inputs: Dict[str, Any],
) -> Union[CloudPrediction, EdgePrediction]:
"""
Create a prediction.
Parameters:
tag (str): Predictor tag.
features (list): Input features. Only applies to `CLOUD` predictions.
data_url_limit (int): Return a data URL if a given output feature is smaller than this limit in bytes. Only applies to `CLOUD` predictions.
raw_outputs (bool): Skip parsing output features into Pythonic data types.
access_key (str): Function access key.
inputs (dict): Input features. Only applies to `CLOUD` predictions.
Returns:
CloudPrediction | EdgePrediction: Created prediction.
"""
# Collect input features
input_features = list(features) + [FeatureInput.from_value(value, name) for name, value in inputs.items()]
input_features = [asdict(feature) for feature in input_features]
# Query
response = query(f"""
mutation ($input: CreatePredictionInput!) {{
createPrediction (input: $input) {{
{cls.RAW_FIELDS if raw_outputs else cls.FIELDS}
}}
}}""",
{ "input": { "tag": tag, "client": _get_client(), "inputs": input_features, "dataUrlLimit": data_url_limit } },
access_key=access_key
)
# Check
prediction = response["createPrediction"]
if not prediction:
return None
# Parse results
if "results" in prediction and not raw_outputs:
prediction["results"] = [_parse_output_feature(feature) for feature in prediction["results"]]
# Create
prediction = CloudPrediction(**prediction) if prediction["type"] == PredictorType.Cloud else EdgePrediction(**prediction)
# Return
return prediction

@dataclass(frozen=True)
class CloudPrediction (Prediction):
"""
Cloud prediction.
Members:
results (list): Prediction results.
latency (float): Prediction latency in milliseconds.
error (str): Prediction error. This is `null` if the prediction completed successfully.
logs (str): Prediction logs.
"""
results: List[Feature] = None
latency: float = None
error: str = None
logs: str = None

@dataclass(frozen=True)
class EdgePrediction (Prediction):
"""
Edge prediction
"""

def _parse_output_feature (feature: dict) -> Union[Feature, str, float, int, bool, Image.Image, list, dict]:
data, type, shape = feature["data"], feature["type"], feature["shape"]
# Handle image
if type == Dtype.image:
return Image.open(_download_feature_data(data))
# Handle non-numeric scalars
values = [feature.get(key, None) for key in ["stringValue", "listValue", "dictValue"]]
scalar = next((value for value in values if value is not None), None)
if scalar is not None:
return scalar
# Handle ndarray
ARRAY_TYPES = [
Dtype.int8, Dtype.int16, Dtype.int32, Dtype.int64,
Dtype.uint8, Dtype.uint16, Dtype.uint32, Dtype.uint64,
Dtype.float16, Dtype.float32, Dtype.float64, Dtype.bool
]
if type in ARRAY_TYPES:
# Create array
array = frombuffer(_download_feature_data(data).getbuffer(), dtype=type).reshape(shape)
return array if len(shape) > 0 else array.item()
# Handle generic feature
feature = Feature(**feature)
return feature

def _download_feature_data (url: str) -> BytesIO:
# Check if data URL
if url.startswith("data:"):
with urlopen(url) as response:
return BytesIO(response.read())
# Download
response = get(url)
result = BytesIO(response.content)
return result

def _get_client () -> str:
id = system()
if id == "Darwin":
return "macos"
if id == "Linux":
return "linux"
if id == "Windows":
return "windows"
raise RuntimeError(f"Function cannot make predictions on the {id} platform")
70 changes: 67 additions & 3 deletions fxn/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,75 @@
# Copyright © 2023 NatML Inc. All Rights Reserved.
#

from typer import Argument, Context, Option, Typer
from dataclasses import asdict
from numpy import ndarray
from pathlib import Path
from PIL import Image
from rich import print_json
from tempfile import mkstemp
from typer import Argument, Context, Option

def predict ( # INCOMPLETE
from ..api import Prediction
from .auth import get_access_key

def predict (
tag: str = Argument(..., help="Predictor tag."),
raw_outputs: bool = Option(False, "--raw-outputs", help="Generate raw output features instead of parsing."),
context: Context = 0
):
print("Predict!")
# Predict
inputs = { context.args[i].replace("-", ""): _parse_value(context.args[i+1]) for i in range(0, len(context.args), 2) }
prediction = Prediction.create(tag, **inputs, raw_outputs=raw_outputs, access_key=get_access_key())
# Parse results
if hasattr(prediction, "results"):
images = [feature for feature in prediction.results if isinstance(feature, Image.Image)]
results = [_serialize_feature(feature) for feature in prediction.results]
object.__setattr__(prediction, "results", results)
# Print
print_json(data=asdict(prediction))
# Show images
for image in images:
image.show()

def _parse_value (value: str):
"""
Parse a value from a CLI argument.
Parameters:
value (str): CLI input argument.
Returns:
bool | int | float | str | Path: Parsed value.
"""
# Boolean
if value == "true":
return True
if value == "false":
return False
# Integer
try:
return int(value)
except ValueError:
pass
# Float
try:
return float(value)
except ValueError:
pass
# File
if value.startswith("@"):
return Path(value[1:])
# String
return value

def _serialize_feature (feature):
# Convert ndarray to list
if isinstance(feature, ndarray):
return feature.tolist()
# Write image
if isinstance(feature, Image.Image):
_, path = mkstemp(suffix=".png" if feature.mode == "RGBA" else ".jpg")
feature.save(path)
return path
# Return
return feature

0 comments on commit c1350d6

Please sign in to comment.