Skip to content

Commit

Permalink
Minor fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
olokobayusuf committed Apr 3, 2024
1 parent 5bca134 commit e1c7ebd
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions fxn/services/prediction/service.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -44,7 +44,9 @@ def create (
inputs: Dict[str, Union[ndarray, str, float, int, bool, List, Dict[str, Any], Path, Image.Image, Value]] = None,
raw_outputs: bool=False,
return_binary_path: bool=True,
data_url_limit: int=None
data_url_limit: int=None,
client_id: str=None,
configuration_id: str=None
) -> Prediction:
"""
Create a prediction.
Expand All @@ -55,6 +57,8 @@ def create (
raw_outputs (bool): Skip converting output values into Pythonic types. This only applies to `CLOUD` predictions.
return_binary_path (bool): Write binary values to file and return a `Path` instead of returning `BytesIO` instance.
data_url_limit (int): Return a data URL if a given output value is smaller than this size in bytes. This only applies to `CLOUD` predictions.
client_id (str): Function client identifier. Specify this to override the current client identifier.
configuration_id (str): Configuration identifier. Specify this to override the current client configuration identifier.
Returns:
Prediction: Created prediction.
Expand All @@ -66,13 +70,13 @@ def create (
key = uuid4().hex
values = { name: self.to_value(value, name, key=key).model_dump(mode="json") for name, value in inputs.items() } if inputs is not None else { }
# Query
response = post( # INCOMPLETE # Configuration token
response = post(
f"{self.client.api_url}/predict/{tag}?rawOutputs=true&dataUrlLimit={data_url_limit}",
json=values,
headers={
"Authorization": f"Bearer {self.client.access_key}",
"fxn-client": self.__get_client_id(),
"fxn-configuration-token": self.__get_configuration_token()
"fxn-client": client_id if client_id is not None else self.__get_client_id(),
"fxn-configuration-token": configuration_id if configuration_id is not None else self.__get_configuration_id()
}
)
# Check
Expand All @@ -84,12 +88,14 @@ def create (
raise RuntimeError(error)
# Parse prediction
prediction = self.__parse_prediction(prediction, raw_outputs=raw_outputs, return_binary_path=return_binary_path)
# Check edge prediction
if prediction.type != PredictorType.Edge or raw_outputs:
return prediction
# Load edge predictor
predictor = self.__load(prediction)
self.__cache[tag] = predictor
# Create edge prediction
if prediction.type == PredictorType.Edge:
predictor = self.__load(prediction)
self.__cache[tag] = predictor
prediction = self.__predict(tag=tag, predictor=predictor, inputs=inputs) if inputs is not None else prediction
# Return
prediction = self.__predict(tag=tag, predictor=predictor, inputs=inputs) if inputs is not None else prediction
return prediction

async def stream (
Expand All @@ -100,6 +106,8 @@ async def stream (
raw_outputs: bool=False,
return_binary_path: bool=True,
data_url_limit: int=None,
client_id: str=None,
configuration_id: str=None
) -> AsyncIterator[Prediction]:
"""
Create a streaming prediction.
Expand All @@ -112,6 +120,8 @@ async def stream (
raw_outputs (bool): Skip converting output values into Pythonic types. This only applies to `CLOUD` predictions.
return_binary_path (bool): Write binary values to file and return a `Path` instead of returning `BytesIO` instance.
data_url_limit (int): Return a data URL if a given output value is smaller than this size in bytes. This only applies to `CLOUD` predictions.
client_id (str): Function client identifier. Specify this to override the current client identifier.
configuration_id (str): Configuration identifier. Specify this to override the current client configuration identifier.
Returns:
Prediction: Created prediction.
Expand All @@ -122,14 +132,14 @@ async def stream (
return
# Serialize inputs
key = uuid4().hex
values = { name: self.to_value(value, name, key=key).model_dump(mode="json") for name, value in inputs.items() } # INCOMPLETE # values
values = { name: self.to_value(value, name, key=key).model_dump(mode="json") for name, value in inputs.items() }
# Request
url = f"{self.client.api_url}/predict/{tag}?stream=true&rawOutputs=true&dataUrlLimit={data_url_limit}"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.client.access_key}",
"fxn-client": self.__get_client_id(),
"fxn-configuration-token": self.__get_configuration_token()
"fxn-client": client_id if client_id is not None else self.__get_client_id(),
"fxn-configuration-token": configuration_id if configuration_id is not None else self.__get_configuration_id()
}
async with ClientSession(headers=headers) as session:
async with session.post(url, data=dumps(values)) as response:
Expand All @@ -143,12 +153,14 @@ async def stream (
raise RuntimeError(error)
# Parse prediction
prediction = self.__parse_prediction(prediction, raw_outputs=raw_outputs, return_binary_path=return_binary_path)
# Create edge prediction
if prediction.type == PredictorType.Edge:
predictor = self.__load(prediction)
self.__cache[tag] = predictor
prediction = self.__predict(tag=tag, predictor=predictor, inputs=inputs) if inputs is not None else prediction
# Yield
# Check edge prediction
if prediction.type != PredictorType.Edge or raw_outputs:
return prediction
# Load edge predictor
predictor = self.__load(prediction)
self.__cache[tag] = predictor
# Create prediction
prediction = self.__predict(tag=tag, predictor=predictor, inputs=inputs) if inputs is not None else prediction
yield prediction

def to_object (
Expand Down Expand Up @@ -303,14 +315,14 @@ def __get_client_id (self) -> str:
return f"windows:{machine()}"
raise RuntimeError(f"Function cannot make predictions on the {id} platform")

def __get_configuration_token (self) -> Optional[str]:
def __get_configuration_id (self) -> Optional[str]:
# Check
if not self.__fxnc:
return None
# Get
buffer = create_string_buffer(2048)
status = self.__fxnc.FXNConfigurationGetUniqueID(buffer, len(buffer))
assert status.value == FXNStatus.OK, f"Failed to create prediction configuration token with status: {status.value}"
assert status.value == FXNStatus.OK, f"Failed to create prediction configuration identifier with status: {status.value}"
uid = buffer.value.decode("utf-8")
# Return
return uid
Expand Down

0 comments on commit e1c7ebd

Please sign in to comment.