Skip to content

Commit

Permalink
Respect HTTPError spec (#1693)
Browse files Browse the repository at this point in the history
* Respect HTTPError spec

* revert
  • Loading branch information
Wauplin authored Sep 25, 2023
1 parent ff0465d commit d9a0eea
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 12 deletions.
4 changes: 3 additions & 1 deletion src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,9 @@ def whoami(self, token: Optional[str] = None) -> Dict:
raise HTTPError(
"Invalid user token. If you didn't pass a user token, make sure you "
"are properly logged in by executing `huggingface-cli login`, and "
"if you did pass a user token, double-check it's correct."
"if you did pass a user token, double-check it's correct.",
request=e.request,
response=e.response,
) from e
return r.json()

Expand Down
6 changes: 4 additions & 2 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def post(
)
except TimeoutError as error:
# Convert any `TimeoutError` to an `InferenceTimeoutError`
raise InferenceTimeoutError(f"Inference call timed out: {url}") from error
raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore

try:
hf_raise_for_status(response)
Expand All @@ -243,7 +243,9 @@ def post(
if timeout is not None and time.time() - t0 > timeout:
raise InferenceTimeoutError(
f"Model not loaded on the server: {url}. Please retry with a higher timeout (current:"
f" {self.timeout})."
f" {self.timeout}).",
request=error.request,
response=error.response,
) from error
# ...or wait 1s and retry
logger.info(f"Waiting for model to be loaded on the server: {error}")
Expand Down
6 changes: 4 additions & 2 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ async def post(
except asyncio.TimeoutError as error:
await client.close()
# Convert any `TimeoutError` to an `InferenceTimeoutError`
raise InferenceTimeoutError(f"Inference call timed out: {url}") from error
raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore
except aiohttp.ClientResponseError as error:
error.response_error_payload = response_error_payload
await client.close()
Expand All @@ -240,7 +240,9 @@ async def post(
if timeout is not None and time.time() - t0 > timeout:
raise InferenceTimeoutError(
f"Model not loaded on the server: {url}. Please retry with a higher timeout"
f" (current: {self.timeout})."
f" (current: {self.timeout}).",
request=error.request,
response=error.response,
) from error
# ...or wait 1s and retry
logger.info(f"Waiting for model to be loaded on the server: {error}")
Expand Down
8 changes: 4 additions & 4 deletions src/huggingface_hub/inference/_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,13 +468,13 @@ def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
# If error_type => more information than `hf_raise_for_status`
if error_type is not None:
if error_type == "generation":
raise GenerationError(message) from http_error
raise GenerationError(message) from http_error # type: ignore
if error_type == "incomplete_generation":
raise IncompleteGenerationError(message) from http_error
raise IncompleteGenerationError(message) from http_error # type: ignore
if error_type == "overloaded":
raise OverloadedError(message) from http_error
raise OverloadedError(message) from http_error # type: ignore
if error_type == "validation":
raise ValidationError(message) from http_error
raise ValidationError(message) from http_error # type: ignore

# Otherwise, fallback to default error
raise http_error
3 changes: 2 additions & 1 deletion src/huggingface_hub/utils/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def __init__(self, message: str, response: Optional[Response] = None):
request_id=self.request_id,
server_message=self.server_message,
),
response=response,
response=response, # type: ignore
request=response.request if response is not None else None, # type: ignore
)

def append_to_message(self, additional_message: str) -> None:
Expand Down
6 changes: 4 additions & 2 deletions utils/generate_async_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
except asyncio.TimeoutError as error:
await client.close()
# Convert any `TimeoutError` to an `InferenceTimeoutError`
raise InferenceTimeoutError(f"Inference call timed out: {url}") from error
raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore
except aiohttp.ClientResponseError as error:
error.response_error_payload = response_error_payload
await client.close()
Expand All @@ -216,7 +216,9 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
if timeout is not None and time.time() - t0 > timeout:
raise InferenceTimeoutError(
f"Model not loaded on the server: {url}. Please retry with a higher timeout"
f" (current: {self.timeout})."
f" (current: {self.timeout}).",
request=error.request,
response=error.response,
) from error
# ...or wait 1s and retry
logger.info(f"Waiting for model to be loaded on the server: {error}")
Expand Down

0 comments on commit d9a0eea

Please sign in to comment.