diff --git a/HISTORY.md b/HISTORY.md index e7c8d61021..d536440ddc 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,12 @@ Release History =============== +3.7.1 (2024-06-??) +------------------ + +**Fixed** +- auth argument not accepting a function according to static type checkers. (#133) + 3.7.0 (2024-06-24) ------------------ diff --git a/src/niquests/_typing.py b/src/niquests/_typing.py index 6a880f07a3..b301e74f82 100644 --- a/src/niquests/_typing.py +++ b/src/niquests/_typing.py @@ -26,10 +26,11 @@ from .auth import AuthBase from .structures import CaseInsensitiveDict +if typing.TYPE_CHECKING: + from .models import PreparedRequest + #: (Restricted) list of http verb that we natively support and understand. -HttpMethodType: typing.TypeAlias = ( - str # todo: have typing.Literal when ready to drop Python 3.7 -) +HttpMethodType: typing.TypeAlias = str #: List of formats accepted for URL queries parameters. (e.g. /?param1=a¶m2=b) QueryParameterType: typing.TypeAlias = typing.Union[ typing.List[typing.Tuple[str, typing.Union[str, typing.List[str]]]], @@ -89,6 +90,7 @@ typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]], str, AuthBase, + typing.Callable[["PreparedRequest"], "PreparedRequest"], ] #: Map for each protocol (http, https) associated proxy to be used. ProxyType: typing.TypeAlias = typing.Dict[str, str] diff --git a/src/niquests/models.py b/src/niquests/models.py index 513e4b1bd4..9f2b0ff40e 100644 --- a/src/niquests/models.py +++ b/src/niquests/models.py @@ -323,6 +323,8 @@ def __init__(self) -> None: self.ocsp_verified: bool | None = None #: upload progress if any. self.upload_progress: TransferProgress | None = None + #: internal usage only. warn us that we should re-compute content-length and await auth() outside of PreparedRequest. + self._asynchronous_auth: bool = False @property def oheaders(self) -> Headers: @@ -636,7 +638,11 @@ def prepare_auth(self, auth: HttpAuthenticationType | None, url: str = "") -> No "Unexpected non-callable authentication. Did you pass unsupported tuple to auth argument?" ) - if not asyncio.iscoroutinefunction(auth.__call__): + self._asynchronous_auth = hasattr( + auth, "__call__" + ) and asyncio.iscoroutinefunction(auth.__call__) + + if not self._asynchronous_auth: # Allow auth to make its changes. r = auth(self)