Commit

Raise error if encountered in chat completion SSE stream (#2558)
Wauplin authored Sep 20, 2024
1 parent 64bcff5 commit c0fd4e0
Showing 2 changed files with 37 additions and 3 deletions.
6 changes: 6 additions & 0 deletions src/huggingface_hub/inference/_common.py
@@ -350,6 +350,12 @@ def _format_chat_completion_stream_output(
     # Decode payload
     payload = byte_payload.decode("utf-8")
     json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+
+    # Either an error is returned
+    if json_payload.get("error") is not None:
+        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
+
+    # Or parse token payload
     return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
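
For context: a chat-completion SSE stream carries one `data: {...}` JSON line per chunk, and a TGI server reports mid-stream failures as an in-band {"error": ..., "error_type": ...} payload rather than an HTTP error. Previously such a payload fell through to ChatCompletionStreamOutput parsing and failed with an unrelated message (issue #2514); with this change it is raised as the matching TextGenerationError subclass. A minimal sketch of the new behavior, reusing the payloads from the regression test below; it assumes a huggingface_hub checkout that contains this commit, and note that _stream_chat_completion_response is a private helper:

    from huggingface_hub.errors import ValidationError
    from huggingface_hub.inference._common import _stream_chat_completion_response

    # One valid chunk followed by an in-band TGI error payload.
    stream = [
        b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,'
        b'"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e",'
        b'"choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},'
        b'"logprobs":null,"finish_reason":null}]}',
        b'data: {"error":"Input validation error: ...","error_type":"validation"}',
    ]

    try:
        for chunk in _stream_chat_completion_response(stream):
            print(chunk.choices[0].delta.content)  # prints "Both" for the first chunk
    except ValidationError as err:
        # error_type "validation" is mapped to ValidationError by _parse_text_generation_error
        print("stream aborted:", err)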


34 changes: 31 additions & 3 deletions tests/test_inference_client.py
@@ -47,7 +47,7 @@
     hf_hub_download,
 )
 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, MAIN_INFERENCE_API_FRAMEWORKS
-from huggingface_hub.errors import HfHubHTTPError
+from huggingface_hub.errors import HfHubHTTPError, ValidationError
 from huggingface_hub.inference._client import _open_as_binary
 from huggingface_hub.inference._common import (
     _stream_chat_completion_response,
@@ -919,7 +919,14 @@ def test_model_and_base_url_mutually_exclusive(self):
             InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", base_url="http://127.0.0.1:8000")


-@pytest.mark.parametrize("stop_signal", [b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] "])
+@pytest.mark.parametrize(
+    "stop_signal",
+    [
+        b"data: [DONE]",
+        b"data: [DONE]\n",
+        b"data: [DONE] ",
+    ],
+)
 def test_stream_text_generation_response(stop_signal: bytes):
     data = [
         b'data: {"index":1,"token":{"id":4560,"text":" trying","logprob":-2.078125,"special":false},"generated_text":null,"details":null}',
@@ -935,7 +942,14 @@ def test_stream_text_generation_response(stop_signal: bytes):
     assert output == [" trying", " to"]


-@pytest.mark.parametrize("stop_signal", [b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] "])
+@pytest.mark.parametrize(
+    "stop_signal",
+    [
+        b"data: [DONE]",
+        b"data: [DONE]\n",
+        b"data: [DONE] ",
+    ],
+)
 def test_stream_chat_completion_response(stop_signal: bytes):
     data = [
         b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}',
@@ -952,6 +966,20 @@ def test_stream_chat_completion_response(stop_signal: bytes):
     assert output[1].choices[0].delta.content == " Rust"


+def test_chat_completion_error_in_stream():
+    """
+    Regression test for https://github.com/huggingface/huggingface_hub/issues/2514.
+    When an error is encountered in the stream, it should raise a TextGenerationError (e.g. a ValidationError).
+    """
+    data = [
+        b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}',
+        b'data: {"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 6 `inputs` tokens and 4091 `max_new_tokens`","error_type":"validation"}',
+    ]
+    with pytest.raises(ValidationError):
+        for token in _stream_chat_completion_response(data):
+            pass
+
+
 INFERENCE_API_URL = "https://api-inference.huggingface.co/models"
 INFERENCE_ENDPOINT_URL = "https://rur2d6yoccusjxgn.us-east-1.aws.endpoints.huggingface.cloud"  # example
 LOCAL_TGI_URL = "http://0.0.0.0:8080"
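
At the client level, the practical effect is that an in-stream server error now surfaces as the matching TextGenerationError subclass instead of a confusing parsing failure. A hedged sketch of catching it (the model name is taken from the test above, the oversized max_tokens is illustrative, and a given server may reject such a request up front with an HTTP error rather than mid-stream):

    from huggingface_hub import InferenceClient
    from huggingface_hub.errors import ValidationError

    client = InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")

    try:
        # stream=True yields ChatCompletionStreamOutput chunks as they arrive
        for chunk in client.chat_completion(
            messages=[{"role": "user", "content": "Hello"}],
            max_tokens=4091,  # deliberately large, mirroring the validation error above
            stream=True,
        ):
            print(chunk.choices[0].delta.content or "", end="")
    except ValidationError as err:
        print(f"\nRequest rejected mid-stream: {err}")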
