Skip to content

Commit

Permalink
refactor: Made auth_headers and auth_params of `APIAuthenticatorB…
Browse files Browse the repository at this point in the history
…ase` simple instance attributes instead of decorated properties (#2596)

Closes #925
  • Loading branch information
edgarrmondragon authored Aug 9, 2024
1 parent e0efe9c commit 5fc0f75
Showing 1 changed file with 34 additions and 44 deletions.
78 changes: 34 additions & 44 deletions singer_sdk/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@ def __call__(cls, *args: t.Any, **kwargs: t.Any) -> t.Any: # noqa: ANN401


class APIAuthenticatorBase:
"""Base class for offloading API auth."""
"""Base class for offloading API auth.
Attributes:
auth_headers: HTTP headers for authentication.
auth_params: URL query parameters for authentication.
"""

def __init__(self, stream: RESTStream) -> None:
"""Init authenticator.
Expand All @@ -89,8 +94,8 @@ def __init__(self, stream: RESTStream) -> None:
"""
self.tap_name: str = stream.tap_name
self._config: dict[str, t.Any] = dict(stream.config)
self._auth_headers: dict[str, t.Any] = {}
self._auth_params: dict[str, t.Any] = {}
self.auth_headers: dict[str, t.Any] = {}
self.auth_params: dict[str, t.Any] = {}
self.logger: logging.Logger = stream.logger

@property
Expand All @@ -102,24 +107,6 @@ def config(self) -> t.Mapping[str, t.Any]:
"""
return MappingProxyType(self._config)

@property
def auth_headers(self) -> dict:
"""Get headers.
Returns:
HTTP headers for authentication.
"""
return self._auth_headers or {}

@property
def auth_params(self) -> dict:
"""Get query parameters.
Returns:
URL query parameters for authentication.
"""
return self._auth_params or {}

def authenticate_request(
self,
request: requests.PreparedRequest,
Expand Down Expand Up @@ -177,10 +164,10 @@ def __init__(
auth_headers: Authentication headers.
"""
super().__init__(stream=stream)
if self._auth_headers is None:
self._auth_headers = {}
if self.auth_headers is None:
self.auth_headers = {}
if auth_headers:
self._auth_headers.update(auth_headers)
self.auth_headers.update(auth_headers)


class APIKeyAuthenticator(APIAuthenticatorBase):
Expand Down Expand Up @@ -218,13 +205,13 @@ def __init__(
raise ValueError(msg)

if location == "header":
if self._auth_headers is None:
self._auth_headers = {}
self._auth_headers.update(auth_credentials)
if self.auth_headers is None:
self.auth_headers = {}
self.auth_headers.update(auth_credentials)
elif location == "params":
if self._auth_params is None:
self._auth_params = {}
self._auth_params.update(auth_credentials)
if self.auth_params is None:
self.auth_params = {}
self.auth_params.update(auth_credentials)

@classmethod
def create_for_stream(
Expand Down Expand Up @@ -267,9 +254,9 @@ def __init__(self, stream: RESTStream, token: str) -> None:
super().__init__(stream=stream)
auth_credentials = {"Authorization": f"Bearer {token}"}

if self._auth_headers is None:
self._auth_headers = {}
self._auth_headers.update(auth_credentials)
if self.auth_headers is None:
self.auth_headers = {}
self.auth_headers.update(auth_credentials)

@classmethod
def create_for_stream(
Expand Down Expand Up @@ -326,9 +313,9 @@ def __init__(
auth_token = base64.b64encode(credentials).decode("ascii")
auth_credentials = {"Authorization": f"Basic {auth_token}"}

if self._auth_headers is None:
self._auth_headers = {}
self._auth_headers.update(auth_credentials)
if self.auth_headers is None:
self.auth_headers = {}
self.auth_headers.update(auth_credentials)

@classmethod
def create_for_stream(
Expand Down Expand Up @@ -383,20 +370,23 @@ def __init__(
self.last_refreshed: datetime.datetime | None = None
self.expires_in: int | None = None

@property
def auth_headers(self) -> dict:
"""Return a dictionary of auth headers to be applied.
def authenticate_request(
self,
request: requests.PreparedRequest,
) -> requests.PreparedRequest:
"""Authenticate an OAuth request.
These will be merged with any `http_headers` specified in the stream.
Args:
request: A :class:`requests.PreparedRequest` object.
Returns:
HTTP headers for authentication.
The authenticated request object.
"""
if not self.is_token_valid():
self.update_access_token()
result = super().auth_headers
result["Authorization"] = f"Bearer {self.access_token}"
return result

self.auth_headers["Authorization"] = f"Bearer {self.access_token}"
return super().authenticate_request(request)

@property
def auth_endpoint(self) -> str:
Expand Down

0 comments on commit 5fc0f75

Please sign in to comment.