From c0fd4e0f7519a4e3659c836081cc7e38c0d14b35 Mon Sep 17 00:00:00 2001
From: Lucain
Date: Fri, 20 Sep 2024 18:04:51 +0200
Subject: [PATCH] Raise error if encountered in chat completion SSE stream
 (#2558)

---
 src/huggingface_hub/inference/_common.py |  6 +++++
 tests/test_inference_client.py           | 34 +++++++++++++++++++++---
 2 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py
index 3870fcddeb..a92d8fad4a 100644
--- a/src/huggingface_hub/inference/_common.py
+++ b/src/huggingface_hub/inference/_common.py
@@ -350,6 +350,12 @@ def _format_chat_completion_stream_output(
     # Decode payload
     payload = byte_payload.decode("utf-8")
     json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+
+    # Either an error is returned
+    if json_payload.get("error") is not None:
+        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
+
+    # Or parse token payload
     return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
 
 
diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py
index 001495ad8b..1772df8df6 100644
--- a/tests/test_inference_client.py
+++ b/tests/test_inference_client.py
@@ -47,7 +47,7 @@
     hf_hub_download,
 )
 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, MAIN_INFERENCE_API_FRAMEWORKS
-from huggingface_hub.errors import HfHubHTTPError
+from huggingface_hub.errors import HfHubHTTPError, ValidationError
 from huggingface_hub.inference._client import _open_as_binary
 from huggingface_hub.inference._common import (
     _stream_chat_completion_response,
@@ -919,7 +919,14 @@ def test_model_and_base_url_mutually_exclusive(self):
         InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", base_url="http://127.0.0.1:8000")
 
 
-@pytest.mark.parametrize("stop_signal", [b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] "])
+@pytest.mark.parametrize(
+    "stop_signal",
+    [
+        b"data: [DONE]",
+        b"data: [DONE]\n",
+        b"data: [DONE] ",
+    ],
+)
 def test_stream_text_generation_response(stop_signal: bytes):
     data = [
         b'data: {"index":1,"token":{"id":4560,"text":" trying","logprob":-2.078125,"special":false},"generated_text":null,"details":null}',
@@ -935,7 +942,14 @@ def test_stream_text_generation_response(stop_signal: bytes):
     assert output == [" trying", " to"]
 
 
-@pytest.mark.parametrize("stop_signal", [b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] "])
+@pytest.mark.parametrize(
+    "stop_signal",
+    [
+        b"data: [DONE]",
+        b"data: [DONE]\n",
+        b"data: [DONE] ",
+    ],
+)
 def test_stream_chat_completion_response(stop_signal: bytes):
     data = [
         b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}',
@@ -952,6 +966,20 @@ def test_stream_chat_completion_response(stop_signal: bytes):
     assert output[1].choices[0].delta.content == " Rust"
 
 
+def test_chat_completion_error_in_stream():
+    """
+    Regression test for https://github.com/huggingface/huggingface_hub/issues/2514.
+    When an error is encountered in the stream, it should raise a TextGenerationError (e.g. a ValidationError).
+    """
+    data = [
+        b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}',
+        b'data: {"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 6 `inputs` tokens and 4091 `max_new_tokens`","error_type":"validation"}',
+    ]
+    with pytest.raises(ValidationError):
+        for token in _stream_chat_completion_response(data):
+            pass
+
+
 INFERENCE_API_URL = "https://api-inference.huggingface.co/models"
 INFERENCE_ENDPOINT_URL = "https://rur2d6yoccusjxgn.us-east-1.aws.endpoints.huggingface.cloud"  # example
 LOCAL_TGI_URL = "http://0.0.0.0:8080"