Skip to content

Commit

Permalink
added type checking and other changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Oct 16, 2024
1 parent e4918cc commit e1a450f
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 86 deletions.
4 changes: 2 additions & 2 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def login(obj: dict) -> None:
cliAuthConfig: CLIAuthConfig = config.cliAuth
oauthConfig: OauthConfig = config.oauth
print("Logging in")
auth = TokenManager(oauth=oauthConfig, cliAuth=cliAuthConfig)
auth: TokenManager = TokenManager(oauth=oauthConfig, cliAuth=cliAuthConfig)
auth.start_device_flow()
else:
print("Please provide configuration to login!")
Expand All @@ -352,7 +352,7 @@ def logout(obj: dict) -> None:
if config.cliAuth and config.oauth:
oauthConfig: OauthConfig = config.oauth
cliAuthConfig: CLIAuthConfig = config.cliAuth
auth = TokenManager(cliAuth=cliAuthConfig, oauth=oauthConfig)
auth: TokenManager = TokenManager(cliAuth=cliAuthConfig, oauth=oauthConfig)
auth.logout()
print("Logged out")
else:
Expand Down
4 changes: 3 additions & 1 deletion src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def __init__(

@classmethod
def from_config(cls, config: ApplicationConfig) -> "BlueapiClient":
rest = BlueapiRestClient(config.api, config.oauth, config.cliAuth)
rest: BlueapiRestClient = BlueapiRestClient(
config.api, config.oauth, config.cliAuth
)
if config.stomp is not None:
template = StompClient.for_broker(
broker=Broker(
Expand Down
16 changes: 8 additions & 8 deletions src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
cliAuthConfig: CLIAuthConfig | None = None,
) -> None:
self._config = config or RestConfig()
self._tokenHandler = None
self._tokenHandler: TokenManager | None = None
if authConfig and cliAuthConfig:
self._tokenHandler = TokenManager(authConfig, cliAuthConfig)

Expand Down Expand Up @@ -137,7 +137,7 @@ def _request_and_deserialize(
get_exception: Callable[[requests.Response], Exception | None] = _exception,
) -> T:
url = self._url(suffix)
headers = {
headers: dict[str, str] = {
"content-type": "application/json; charset=UTF-8",
}
if (
Expand All @@ -146,15 +146,15 @@ def _request_and_deserialize(
and self._tokenHandler.token["access_token"]
):
try:
access_token = self._tokenHandler.token["access_token"]
self._tokenHandler.authenticator.verify_token(access_token)
headers["Authorization"] = f"Bearer {access_token}"
auth_token: str = self._tokenHandler.token["access_token"]
self._tokenHandler.authenticator.verify_token(auth_token)
headers["Authorization"] = f"Bearer {auth_token}"
except jwt.ExpiredSignatureError:
if self._tokenHandler.refresh_auth_token():
access_token = self._tokenHandler.token["access_token"]
access_token: str = self._tokenHandler.token["access_token"]
headers["Authorization"] = f"Bearer {access_token}"
except Exception as e:
raise Exception from e
except Exception:
pass

Check warning on line 157 in src/blueapi/client/rest.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/client/rest.py#L148-L157

Added lines #L148 - L157 were not covered by tests
if data:
response = requests.request(method, url, json=data, headers=headers)
else:
Expand Down
43 changes: 24 additions & 19 deletions src/blueapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,28 +97,33 @@ class OauthConfig(BlueapiBaseModel):
refresh_url: str = ""

def model_post_init(self, __context: Any) -> None:
response = requests.get(self.oidc_config_url)
response: requests.Response = requests.get(self.oidc_config_url)
response.raise_for_status()
config_data = response.json()

self.device_auth_url = config_data.get("device_authorization_endpoint")
self.pkce_auth_url = config_data.get("authorization_endpoint")
self.token_url = config_data.get("token_endpoint")
self.issuer = config_data.get("issuer")
self.jwks_uri = config_data.get("jwks_uri")
self.logout_url = config_data.get("end_session_endpoint")
self.refresh_url = config_data.get("end_session_endpoint")
config_data: dict[str, str] = response.json()

device_auth_url: str | None = config_data.get("device_authorization_endpoint")
pkce_auth_url: str | None = config_data.get("authorization_endpoint")
token_url: str | None = config_data.get("token_endpoint")
issuer: str | None = config_data.get("issuer")
jwks_uri: str | None = config_data.get("jwks_uri")
logout_url: str | None = config_data.get("end_session_endpoint")
# post this we need to check if all the values are present
if any(
(
self.device_auth_url == "",
self.pkce_auth_url == "",
self.token_url == "",
self.issuer == "",
self.jwks_uri == "",
self.logout_url == "",
)
if (
device_auth_url
and pkce_auth_url
and token_url
and issuer
and jwks_uri
and logout_url
):
self.device_auth_url = device_auth_url
self.pkce_auth_url = pkce_auth_url
self.token_url = token_url
self.issuer = issuer
self.jwks_uri = jwks_uri
self.logout_url = logout_url
self.refresh_url = token_url
else:
raise ValueError("OIDC config is missing required fields")


Expand Down
46 changes: 25 additions & 21 deletions src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def decode_jwt(self, token: str, verify_expiration: bool = True) -> dict[str, st
signing_key = jwt.PyJWKClient(self.oauth.jwks_uri).get_signing_key_from_jwt(
token
)
decode = jwt.decode(
decode: dict[str, str] = jwt.decode(
token,
signing_key.key,
algorithms=["RS256"],
Expand All @@ -51,8 +51,8 @@ def decode_jwt(self, token: str, verify_expiration: bool = True) -> dict[str, st
return decode

def print_user_info(self, token: str) -> None:
decode = self.decode_jwt(token)
print(f'Logged in as {decode["name"]} with fed-id {decode["fedid"]}')
decode: dict[str, str] = self.decode_jwt(token)
print(f'Logged in as {decode.get("name")} with fed-id {decode.get("fedid")}')


class TokenManager:
Expand Down Expand Up @@ -80,43 +80,44 @@ def refresh_auth_token(self) -> bool:
)
if response.status_code == HTTPStatus.OK:
self.save_token(response.json())
self.load_token()
return True
return False

def save_token(self, token: dict[str, Any]) -> None:
token_json = json.dumps(token)
token_bytes = token_json.encode("utf-8")
token_base64 = base64.b64encode(token_bytes)
token_json: str = json.dumps(token)
token_bytes: bytes = token_json.encode("utf-8")
token_base64: bytes = base64.b64encode(token_bytes)
with open(os.path.expanduser(self.cliAuth.token_file_path), "wb") as token_file:
token_file.write(token_base64)

def load_token(self) -> None:
if not os.path.exists(os.path.expanduser(self.cliAuth.token_file_path)):
return
with open(os.path.expanduser(self.cliAuth.token_file_path), "rb") as token_file:
token_base64 = token_file.read()
token_bytes = base64.b64decode(token_base64)
token_json = token_bytes.decode("utf-8")
token_base64: bytes = token_file.read()
token_bytes: bytes = base64.b64decode(token_base64)
token_json: str = token_bytes.decode("utf-8")
self.token = json.loads(token_json)

def get_device_code(self):
response = requests.post(
response: requests.Response = requests.post(
self.oauth.token_url,
data={
"client_id": self.cliAuth.client_id,
"scope": "openid profile offline_access",
"audience": self.cliAuth.client_audience,
},
)
response_data = response.json()
response_data: dict[str, str] = response.json()
if response.status_code == 200:
return response_data["device_code"]
raise Exception("Failed to get device code.")

Check warning on line 115 in src/blueapi/service/authentication.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/authentication.py#L115

Added line #L115 was not covered by tests

def poll_for_token(
self, device_code: str, timeout: float = 30, polling_interval: float = 0.5
) -> dict[str, Any]:
too_late = time.time() + timeout
too_late: float = time.time() + timeout
while time.time() < too_late:
response = requests.post(
self.oauth.token_url,
Expand All @@ -138,12 +139,17 @@ def poll_for_token(
def start_device_flow(self) -> None:
if self.token:
try:
self.authenticator.verify_token(self.token["access_token"])
is_token_vaild: bool = self.authenticator.verify_token(
self.token["access_token"]
)
if is_token_vaild:
self.load_token()
self.authenticator.print_user_info(self.token["access_token"])
return

Check warning on line 148 in src/blueapi/service/authentication.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/authentication.py#L145-L148

Added lines #L145 - L148 were not covered by tests
except jwt.ExpiredSignatureError:
self.refresh_auth_token()
self.load_token()
self.authenticator.print_user_info(self.token["access_token"])
return
if self.refresh_auth_token():
self.authenticator.print_user_info(self.token["access_token"])
return

response: requests.Response = requests.post(
self.oauth.device_auth_url,
Expand All @@ -159,11 +165,9 @@ def start_device_flow(self) -> None:
f"{response_json['verification_uri_complete']}"
)
auth_token_json: dict[str, Any] = self.poll_for_token(device_code)
valid_token = self.authenticator.verify_token(
valid_token: bool = self.authenticator.verify_token(
auth_token_json["access_token"]
)
if valid_token:
self.save_token(auth_token_json)
self.load_token()
if self.token:
self.authenticator.print_user_info(self.token["access_token"])
self.authenticator.print_user_info(auth_token_json["access_token"])
19 changes: 0 additions & 19 deletions tests/unit_tests/client/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,22 +101,3 @@ def test_refresh_if_signature_expired(rest_with_auth: BlueapiRestClient):
result = rest_with_auth.get_plans()
# Add assertions as needed
assert result == PlanResponse(plans=[PlanModel.from_plan(plan)])


@responses.activate
def test_verify_token_ignore_other_exceptions(rest_with_auth: BlueapiRestClient):
plan = Plan(name="my-plan", model=MyModel)
responses.add(
responses.GET,
"http://localhost:8000/plans",
json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(),
status=200,
)
with (
patch("blueapi.service.Authenticator.verify_token") as mock_verify_token,
):
# Mock the verify_token function to return True (indicating a valid token)
mock_verify_token.side_effect = Exception
result = rest_with_auth.get_plans()
# Add assertions as needed
assert result == PlanResponse(plans=[PlanModel.from_plan(plan)])
67 changes: 52 additions & 15 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import base64
import json
from datetime import datetime, timedelta

import jwt
import matplotlib
Expand Down Expand Up @@ -659,15 +658,9 @@ def valid_auth_config(tmp_path: Path) -> str:

@responses.activate
def test_login_success(runner: CliRunner, valid_auth_config: str):
payload: dict[str, Any] = {
"kid": "Key_identifier",
"sub": "1234567890", # Subject (usually user ID)
payload: dict[str, str] = {
"name": "John Doe",
"fedid": "jd1",
"iat": datetime.now(), # Issued at time
"exp": datetime.now() + timedelta(hours=1), # Expiration time
"aud": "client_audience",
"iss": "https://example.com/",
}

mock_json_responses: list[dict[str, str]] = [
Expand Down Expand Up @@ -732,15 +725,8 @@ def test_login_with_refresh_token(
).decode("utf-8")
)
payload: dict[str, Any] = {
"kid": "Key_identifier",
"sub": "1234567890", # Subject (usually user ID)
"name": "John Doe",
"fedid": "jd1",
"iat": datetime.now(), # Issued at time
"exp": datetime.now() + timedelta(hours=1), # Expiration time
"aud": "client_audience",
"iss": "https://example.com/",
"refresh_token": "refresh_token",
}

mock_json_responses: list[dict[str, str]] = [
Expand Down Expand Up @@ -782,6 +768,57 @@ def test_login_with_refresh_token(
assert result.exit_code == 0


@responses.activate
def test_login_edge_cases(runner: CliRunner, valid_auth_config: str, tmp_path: Path):
with open(tmp_path / "token", "w") as token_file:
# base64 encoded token
token_file.write(
base64.b64encode(
b'{"access_token":"token","refresh_token":"refresh_token"}'
).decode("utf-8")
)

mock_json_responses: list[dict[str, str]] = [
{
"device_authorization_endpoint": "https://example.com/device_authorization",
"authorization_endpoint": "https://example.com/authorization",
"token_endpoint": "https://example.com/token",
"issuer": "https://example.com",
"jwks_uri": "https://example.com/realms/master/protocol/openid-connect/certs",
"end_session_endpoint": "https://example.com/logout",
},
{
"details": "not found",
},
]
with responses.RequestsMock(assert_all_requests_are_fired=True) as requests_mock:
requests_mock.add(
requests_mock.GET,
"https://auth.example.com/realms/sample/.well-known/openid-configuration",
json=mock_json_responses[0],
status=200,
)
requests_mock.add(
requests_mock.POST,
"https://example.com/token",
json=mock_json_responses[1],
status=400,
)
requests_mock.add(
requests_mock.POST,
"https://example.com/device_authorization",
json=mock_json_responses[1],
status=400,
)
with (
patch("blueapi.service.Authenticator.decode_jwt") as mock_decode,
):
mock_decode.side_effect = jwt.ExpiredSignatureError
result = runner.invoke(main, ["-c", valid_auth_config, "login"])
assert "Logging in\n" == result.output
assert result.exit_code == 0


@responses.activate
def test_logout_success(runner: CliRunner, valid_auth_config: str, tmp_path: Path):
with open(tmp_path / "token", "w") as token_file:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def test_oauth_config_model_post_init(mock_get):
assert oauth_config.issuer == mock_response["issuer"]
assert oauth_config.jwks_uri == mock_response["jwks_uri"]
assert oauth_config.logout_url == mock_response["end_session_endpoint"]
assert oauth_config.refresh_url == mock_response["end_session_endpoint"]
assert oauth_config.refresh_url == mock_response["token_endpoint"]


@mock.patch("requests.get")
Expand Down

0 comments on commit e1a450f

Please sign in to comment.