Skip to content

Commit

Permalink
Add type hints to management (#497)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjmcgrath authored Jul 24, 2023
2 parents f902e76 + 2bc0c98 commit d2ab498
Show file tree
Hide file tree
Showing 39 changed files with 926 additions and 552 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ repos:
rev: v3.3.1
hooks:
- id: pyupgrade
args: [--keep-runtime-typing]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
Expand Down
2 changes: 1 addition & 1 deletion auth0/management/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
if is_async_available():
from .async_auth0 import AsyncAuth0 as Auth0
else: # pragma: no cover
from .auth0 import Auth0
from .auth0 import Auth0 # type: ignore[assignment]

__all__ = (
"Auth0",
Expand Down
77 changes: 45 additions & 32 deletions auth0/management/actions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from ..rest import RestClient
from __future__ import annotations

from typing import Any

from ..rest import RestClient, RestClientOptions
from ..types import TimeoutType


class Actions:
Expand All @@ -17,28 +22,31 @@ class Actions:
both values separately or a float to set both to it.
(defaults to 5.0 for both)
rest_options (RestClientOptions): Pass an instance of
protocol (str, optional): Protocol to use when making requests.
(defaults to "https")
rest_options (RestClientOptions, optional): Pass an instance of
RestClientOptions to configure additional RestClient
options, such as rate-limit retries.
(defaults to None)
"""

def __init__(
self,
domain,
token,
telemetry=True,
timeout=5.0,
protocol="https",
rest_options=None,
):
domain: str,
token: str,
telemetry: bool = True,
timeout: TimeoutType = 5.0,
protocol: str = "https",
rest_options: RestClientOptions | None = None,
) -> None:
self.domain = domain
self.protocol = protocol
self.client = RestClient(
jwt=token, telemetry=telemetry, timeout=timeout, options=rest_options
)

def _url(self, *args):
def _url(self, *args: str | None) -> str:
url = f"{self.protocol}://{self.domain}/api/v2/actions"
for p in args:
if p is not None:
Expand All @@ -47,13 +55,13 @@ def _url(self, *args):

def get_actions(
self,
trigger_id=None,
action_name=None,
deployed=None,
installed=False,
page=None,
per_page=None,
):
trigger_id: str | None = None,
action_name: str | None = None,
deployed: bool | None = None,
installed: bool = False,
page: int | None = None,
per_page: int | None = None,
) -> Any:
"""Get all actions.
Args:
Expand All @@ -77,21 +85,20 @@ def get_actions(
See: https://auth0.com/docs/api/management/v2#!/Actions/get_actions
"""

if deployed is not None:
deployed = str(deployed).lower()
deployed_str = str(deployed).lower() if deployed is not None else None

params = {
"triggerId": trigger_id,
"actionName": action_name,
"deployed": deployed,
"deployed": deployed_str,
"installed": str(installed).lower(),
"page": page,
"per_page": per_page,
}

return self.client.get(self._url("actions"), params=params)

def create_action(self, body):
def create_action(self, body: dict[str, Any]) -> dict[str, Any]:
"""Create a new action.
Args:
Expand All @@ -102,7 +109,7 @@ def create_action(self, body):

return self.client.post(self._url("actions"), data=body)

def update_action(self, id, body):
def update_action(self, id: str, body: dict[str, Any]) -> dict[str, Any]:
"""Updates an action.
Args:
Expand All @@ -115,7 +122,7 @@ def update_action(self, id, body):

return self.client.patch(self._url("actions", id), data=body)

def get_action(self, id):
def get_action(self, id: str) -> dict[str, Any]:
"""Retrieves an action by its ID.
Args:
Expand All @@ -127,7 +134,7 @@ def get_action(self, id):

return self.client.get(self._url("actions", id), params=params)

def delete_action(self, id, force=False):
def delete_action(self, id: str, force: bool = False) -> Any:
"""Deletes an action and all of its associated versions.
Args:
Expand All @@ -142,7 +149,7 @@ def delete_action(self, id, force=False):

return self.client.delete(self._url("actions", id), params=params)

def get_triggers(self):
def get_triggers(self) -> dict[str, Any]:
"""Retrieve the set of triggers currently available within actions.
See: https://auth0.com/docs/api/management/v2#!/Actions/get_triggers
Expand All @@ -151,7 +158,7 @@ def get_triggers(self):

return self.client.get(self._url("triggers"), params=params)

def get_execution(self, id):
def get_execution(self, id: str) -> dict[str, Any]:
"""Get information about a specific execution of a trigger.
Args:
Expand All @@ -163,7 +170,9 @@ def get_execution(self, id):

return self.client.get(self._url("executions", id), params=params)

def get_action_versions(self, id, page=None, per_page=None):
def get_action_versions(
self, id: str, page: int | None = None, per_page: int | None = None
) -> dict[str, Any]:
"""Get all of an action's versions.
Args:
Expand All @@ -181,7 +190,9 @@ def get_action_versions(self, id, page=None, per_page=None):

return self.client.get(self._url("actions", id, "versions"), params=params)

def get_trigger_bindings(self, id, page=None, per_page=None):
def get_trigger_bindings(
self, id: str, page: int | None = None, per_page: int | None = None
) -> dict[str, Any]:
"""Get the actions that are bound to a trigger.
Args:
Expand All @@ -198,7 +209,7 @@ def get_trigger_bindings(self, id, page=None, per_page=None):
params = {"page": page, "per_page": per_page}
return self.client.get(self._url("triggers", id, "bindings"), params=params)

def get_action_version(self, action_id, version_id):
def get_action_version(self, action_id: str, version_id: str) -> dict[str, Any]:
"""Retrieve a specific version of an action.
Args:
Expand All @@ -214,7 +225,7 @@ def get_action_version(self, action_id, version_id):
self._url("actions", action_id, "versions", version_id), params=params
)

def deploy_action(self, id):
def deploy_action(self, id: str) -> dict[str, Any]:
"""Deploy an action.
Args:
Expand All @@ -224,7 +235,9 @@ def deploy_action(self, id):
"""
return self.client.post(self._url("actions", id, "deploy"))

def rollback_action_version(self, action_id, version_id):
def rollback_action_version(
self, action_id: str, version_id: str
) -> dict[str, Any]:
"""Roll back to a previous version of an action.
Args:
Expand All @@ -238,7 +251,7 @@ def rollback_action_version(self, action_id, version_id):
self._url("actions", action_id, "versions", version_id, "deploy"), data={}
)

def update_trigger_bindings(self, id, body):
def update_trigger_bindings(self, id: str, body: dict[str, Any]) -> dict[str, Any]:
"""Update a trigger's bindings.
Args:
Expand Down
24 changes: 20 additions & 4 deletions auth0/management/async_auth0.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import aiohttp

from ..asyncify import asyncify
from .auth0 import Auth0

if TYPE_CHECKING:
from types import TracebackType

from auth0.rest import RestClientOptions


class AsyncAuth0:
"""Provides easy access to all endpoint classes
Expand All @@ -18,7 +27,9 @@ class AsyncAuth0:
(defaults to None)
"""

def __init__(self, domain, token, rest_options=None):
def __init__(
self, domain: str, token: str, rest_options: RestClientOptions | None = None
) -> None:
self._services = []
for name, attr in vars(Auth0(domain, token, rest_options=rest_options)).items():
cls = asyncify(attr.__class__)
Expand All @@ -30,7 +41,7 @@ def __init__(self, domain, token, rest_options=None):
service,
)

def set_session(self, session):
def set_session(self, session: aiohttp.ClientSession) -> None:
"""Set Client Session to improve performance by reusing session.
Args:
Expand All @@ -41,11 +52,16 @@ def set_session(self, session):
for service in self._services:
service.set_session(self._session)

async def __aenter__(self):
async def __aenter__(self) -> AsyncAuth0:
"""Automatically create and set session within context manager."""
self.set_session(aiohttp.ClientSession())
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Automatically close session within context manager."""
await self._session.close()
40 changes: 25 additions & 15 deletions auth0/management/attack_protection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from ..rest import RestClient
from __future__ import annotations

from typing import Any

from ..rest import RestClient, RestClientOptions
from ..types import TimeoutType


class AttackProtection:
Expand All @@ -17,6 +22,9 @@ class AttackProtection:
both values separately or a float to set both to it.
(defaults to 5.0 for both)
protocol (str, optional): Protocol to use when making requests.
(defaults to "https")
rest_options (RestClientOptions): Pass an instance of
RestClientOptions to configure additional RestClient
options, such as rate-limit retries.
Expand All @@ -25,25 +33,25 @@ class AttackProtection:

def __init__(
self,
domain,
token,
telemetry=True,
timeout=5.0,
protocol="https",
rest_options=None,
):
domain: str,
token: str,
telemetry: bool = True,
timeout: TimeoutType = 5.0,
protocol: str = "https",
rest_options: RestClientOptions | None = None,
) -> None:
self.domain = domain
self.protocol = protocol
self.client = RestClient(
jwt=token, telemetry=telemetry, timeout=timeout, options=rest_options
)

def _url(self, component):
def _url(self, component: str) -> str:
return "{}://{}/api/v2/attack-protection/{}".format(
self.protocol, self.domain, component
)

def get_breached_password_detection(self):
def get_breached_password_detection(self) -> dict[str, Any]:
"""Get breached password detection settings.
Returns the breached password detection settings.
Expand All @@ -53,7 +61,9 @@ def get_breached_password_detection(self):
url = self._url("breached-password-detection")
return self.client.get(url)

def update_breached_password_detection(self, body):
def update_breached_password_detection(
self, body: dict[str, Any]
) -> dict[str, Any]:
"""Update breached password detection settings.
Returns the breached password detection settings.
Expand All @@ -67,7 +77,7 @@ def update_breached_password_detection(self, body):
url = self._url("breached-password-detection")
return self.client.patch(url, data=body)

def get_brute_force_protection(self):
def get_brute_force_protection(self) -> dict[str, Any]:
"""Get the brute force configuration.
Returns the brute force configuration.
Expand All @@ -77,7 +87,7 @@ def get_brute_force_protection(self):
url = self._url("brute-force-protection")
return self.client.get(url)

def update_brute_force_protection(self, body):
def update_brute_force_protection(self, body: dict[str, Any]) -> dict[str, Any]:
"""Update the brute force configuration.
Returns the brute force configuration.
Expand All @@ -91,7 +101,7 @@ def update_brute_force_protection(self, body):
url = self._url("brute-force-protection")
return self.client.patch(url, data=body)

def get_suspicious_ip_throttling(self):
def get_suspicious_ip_throttling(self) -> dict[str, Any]:
"""Get the suspicious IP throttling configuration.
Returns the suspicious IP throttling configuration.
Expand All @@ -101,7 +111,7 @@ def get_suspicious_ip_throttling(self):
url = self._url("suspicious-ip-throttling")
return self.client.get(url)

def update_suspicious_ip_throttling(self, body):
def update_suspicious_ip_throttling(self, body: dict[str, Any]) -> dict[str, Any]:
"""Update the suspicious IP throttling configuration.
Returns the suspicious IP throttling configuration.
Expand Down
Loading

0 comments on commit d2ab498

Please sign in to comment.