Describe the bug
When running a SageMaker container with the default PyTorch inference Transformer and specifying a UTF-8 Content-Type (application/json, text/csv), the TorchServe inference.py implementation throws an error during de-serialization within input_fn. This is because input_fn expects input_data to be a bytes-like object, but the Transformer has already decoded it to a str. The NumpyDeserializer does support de-serializing from UTF-8 content types, but that code path is effectively unreachable for input processing (it can still be reached for output) without overriding the default Inference Handler / Handler Service / Transformer (a transformer can't be specified if input_fn is specified).
The TorchServe inference.py script was implemented in #4662
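As a point of reference, a user-supplied input_fn can sidestep the crash by normalizing the payload back to bytes before deserializing. This is a minimal sketch, not the default script's implementation, and it assumes only JSON and NPY inputs matter:

import io
import json

import numpy as np

def input_fn(input_data, content_type):
    # The Transformer may have already decoded a UTF-8 payload to str;
    # normalize back to bytes so both branches below can proceed.
    if isinstance(input_data, str):
        input_data = input_data.encode('utf-8')
    if content_type == 'application/json':
        return np.array(json.loads(input_data))
    if content_type == 'application/x-npy':
        return np.load(io.BytesIO(input_data))
    raise ValueError(f'Unsupported Content-Type: {content_type}')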
With Python clients, or when using the Predictor class from the SageMaker SDK, this is easily worked around. However, when making predictions from other languages, such as Java, it is much more difficult: a JSON representation of the inference input cannot be provided, and custom serialization to match the NPY format would be necessary.
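For a sense of what that entails, below is a sketch of the NPY v1.0 byte layout a non-Python client would have to reproduce for a 2-D float64 array (the helper name is ours, and a real client would also have to handle other dtypes and shapes):

import struct

def to_npy_bytes(rows):
    # NPY v1.0: magic + version, little-endian header length, ASCII header
    # dict padded with spaces and terminated by '\n', then raw row-major data.
    header = "{'descr': '<f8', 'fortran_order': False, 'shape': (%d, %d), }" % (
        len(rows), len(rows[0]))
    header += ' ' * (63 - (10 + len(header)) % 64) + '\n'
    body = b''.join(struct.pack('<d', v) for row in rows for v in row)
    return (b'\x93NUMPY' + b'\x01\x00'
            + struct.pack('<H', len(header)) + header.encode('latin1') + body)

A payload built this way round-trips through np.load, but it is far more effort than json.dumps.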
This is just one example use case; the issue may apply to other inputs beyond NumPy arrays / scikit-learn algorithms. Ownership of a fix could lie in the SageMaker Python SDK, in the sagemaker-pytorch-inference repository, or elsewhere. A change to any of these components risks impacting production behavior that clients may rely on.
To reproduce
import io
import json
import os

import boto3
import mlflow
import numpy as np
import pandas as pd
from mlflow import MlflowClient
from mlflow.models import infer_signature
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
X, y = datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
params = {
    "solver": "lbfgs",
    "max_iter": 1000,
    "multi_class": "auto",
    "random_state": 8888
}
lr = LogisticRegression(**params)
lr.fit(X_train, y_train)
y_pred = lr.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
mlflow.set_tracking_uri(os.environ['MLFLOW_URI'])
with mlflow.start_run() as run:
    mlflow.log_params(params)
    mlflow.log_metric('accuracy', accuracy)
    mlflow.set_tag('Training Info', 'Basic LR model for iris data')
    signature = infer_signature(X_train, lr.predict(X_train))
    model_info = mlflow.sklearn.log_model(
        sk_model=lr,
        artifact_path='sklearn-model',
        signature=signature,
        input_example=X_train,
        registered_model_name='tracking-quickstart'
    )
model_uri = f'runs:/{run.info.run_id}/sklearn-model'
schema_builder = SchemaBuilder(sample_input=X_train, sample_output=y_pred)
model_builder = ModelBuilder(
    mode=Mode.SAGEMAKER_ENDPOINT,
    schema_builder=schema_builder,
    role_arn=os.environ['ROLE_ARN'],
    model_metadata={
        "MLFLOW_MODEL_PATH": model_uri,
        "MLFLOW_TRACKING_ARN": os.environ['MLFLOW_TRACKING_SERVER_ARN']
    }
)
model = model_builder.build()
predictor = model.deploy(initial_instance_count=1, instance_type="ml.t2.medium")
predictor.predict(X_test) # works as expected
sagemaker_runtime_client = boto3.client('sagemaker-runtime')
# works as expected:
buffer = io.BytesIO()
np.save(buffer, X_test)
invoke_response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    Body=buffer.getvalue(),
    ContentType='application/x-npy'
)
predictions = np.load(io.BytesIO(invoke_response['Body'].read()))
# does not work as expected:
json_body = json.dumps(X_test.tolist()).encode('utf-8')
invoke_response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    Body=json_body,
    ContentType='application/json'
)
ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received server error (500) from primary with message "<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">
<title>500 Internal Server Error</title>
<h1>Internal Server Error</h1>
<p>The server encountered an internal error and was unable to complete your request. Either the server is overloaded or there is an error in the application.</p>
". See https://us-west-2.console.aws.amazon.com/cloudwatch/home?region=us-west-2#logEventViewer:group=###REDACTED### in account ###REDACTED### for more information.
Along similar lines (I can open a separate issue if applicable): requesting the response as JSON via the Accept header does not appear to work either. This is perhaps expected, though it only becomes evident when attempting to de-serialize the returned stream:
invoke_response_json_resp = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    Body=buffer.getvalue(),
    ContentType='application/x-npy',
    Accept='application/json'
)
np.load(io.BytesIO(invoke_response_json_resp['Body'].read()))  # evidently the response stream is not JSON
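The ContentType field in the invoke_endpoint response makes this concrete; repeating the request and checking it shows the body is still NPY:

resp = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    Body=buffer.getvalue(),
    ContentType='application/x-npy',
    Accept='application/json'
)
print(resp['ContentType'])  # reports what the server actually returned
np.load(io.BytesIO(resp['Body'].read()))  # loads fine, i.e. the body is NPY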
In general, the error messaging during serialization/de-serialization is unhelpful/misleading, as it suggests the (de-)serialization failed for pickled data, which is not always the case.
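As an illustration of what clearer messaging could look like, a wrapper that surfaces the content type and the Python type of the payload would make a str-vs-bytes mismatch obvious at a glance (a hypothetical sketch, not the current script):

def deserialize_or_raise(deserialize, input_data, content_type):
    # Hypothetical helper: wrap any deserializer and re-raise with context.
    try:
        return deserialize(input_data, content_type)
    except Exception as e:
        raise ValueError(
            f'Failed to deserialize request (Content-Type={content_type!r}, '
            f'payload type={type(input_data).__name__})'
        ) from e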
Expected behavior
I expect to be able to invoke the SageMaker endpoint with a JSON-serialized NumPy array and receive an NPY response.
Screenshots or logs
2024-09-11T23:52:57.493Z IP - - [11/Sep/2024:23:52:55 +0000] "POST /invocations HTTP/1.1" 200 368 "-" "AHC/2.0"
2024-09-11T23:55:30.402Z 2024-09-11 23:55:30,305 ERROR - inference - Exception on /invocations [POST]
2024-09-11T23:55:30.402Z Traceback (most recent call last):
File "/opt/ml/code/inference.py", line 74, in input_fn
io.BytesIO(input_data), content_type[0]
2024-09-11T23:55:30.402Z TypeError: a bytes-like object is required, not 'str'
2024-09-11T23:55:30.402Z The above exception was the direct cause of the following exception:
2024-09-11T23:55:30.402Z Traceback (most recent call last):
File "/miniconda3/lib/python3.8/site-packages/sagemaker_containers/_functions.py", line 93, in wrapper
return fn(*args, **kwargs)
File "/opt/ml/code/inference.py", line 77, in input_fn
raise Exception("Encountered error in deserialize_request.") from e
2024-09-11T23:55:30.402Z Exception: Encountered error in deserialize_request.
2024-09-11T23:55:30.402Z IP - - [11/Sep/2024:23:55:30 +0000] "POST /invocations HTTP/1.1" 500 290 "-" "AHC/2.0"
2024-09-11T23:55:30.402Z During handling of the above exception, another exception occurred:
2024-09-11T23:55:30.402Z Traceback (most recent call last):
File "/miniconda3/lib/python3.8/site-packages/flask/app.py", line 2446, in wsgi_app
response = self.full_dispatch_request()
File "/miniconda3/lib/python3.8/site-packages/flask/app.py", line 1951, in full_dispatch_request
rv = self.handle_user_exception(e)
File "/miniconda3/lib/python3.8/site-packages/flask/app.py", line 1820, in handle_user_exception
reraise(exc_type, exc_value, tb)
File "/miniconda3/lib/python3.8/site-packages/flask/_compat.py", line 39, in reraise
raise value
File "/miniconda3/lib/python3.8/site-packages/flask/app.py", line 1949, in full_dispatch_request
rv = self.dispatch_request()
File "/miniconda3/lib/python3.8/site-packages/flask/app.py", line 1935, in dispatch_request
return self.view_functions[rule.endpoint](**req.view_args)
File "/miniconda3/lib/python3.8/site-packages/sagemaker_containers/_transformer.py", line 199, in transform
result = self._transform_fn(
File "/miniconda3/lib/python3.8/site-packages/sagemaker_containers/_transformer.py", line 227, in _default_transform_fn
data = self._input_fn(content, content_type)
File "/miniconda3/lib/python3.8/site-packages/sagemaker_containers/_functions.py", line 95, in wrapper
six.reraise(error_class, error_class(e), sys.exc_info()[2])
File "/miniconda3/lib/python3.8/site-packages/six.py", line 702, in reraise
raise value.with_traceback(tb)
File "/miniconda3/lib/python3.8/site-packages/sagemaker_containers/_functions.py", line 93, in wrapper
return fn(*args, **kwargs)
File "/opt/ml/code/inference.py", line 77, in input_fn
raise Exception("Encountered error in deserialize_request.") from e
2024-09-11T23:55:32.657Z sagemaker_containers._errors.ClientError: Encountered error in deserialize_request.
System information
A description of your system. Please provide:
SageMaker Python SDK version: 2.231.0
Framework name (eg. PyTorch) or algorithm (eg. KMeans): PyTorch / SKLearn LogisticRegression
Framework version: 1.2.1
Python version: 3.8.17
CPU or GPU: CPU
Custom Docker image (Y/N): N
Additional context
Relevant guides / documentation used to generate example: