diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index aec56e258..cbddf008a 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -339,7 +339,7 @@ def login(obj: dict) -> None: cliAuthConfig: CLIAuthConfig = config.cliAuth oauthConfig: OauthConfig = config.oauth print("Logging in") - auth = TokenManager(oauth=oauthConfig, cliAuth=cliAuthConfig) + auth: TokenManager = TokenManager(oauth=oauthConfig, cliAuth=cliAuthConfig) auth.start_device_flow() else: print("Please provide configuration to login!") @@ -352,7 +352,7 @@ def logout(obj: dict) -> None: if config.cliAuth and config.oauth: oauthConfig: OauthConfig = config.oauth cliAuthConfig: CLIAuthConfig = config.cliAuth - auth = TokenManager(cliAuth=cliAuthConfig, oauth=oauthConfig) + auth: TokenManager = TokenManager(cliAuth=cliAuthConfig, oauth=oauthConfig) auth.logout() print("Logged out") else: diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 9e3bdfa02..53940af62 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -39,7 +39,9 @@ def __init__( @classmethod def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": - rest = BlueapiRestClient(config.api, config.oauth, config.cliAuth) + rest: BlueapiRestClient = BlueapiRestClient( + config.api, config.oauth, config.cliAuth + ) if config.stomp is not None: template = StompClient.for_broker( broker=Broker( diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 32d49e999..a1585009b 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -47,7 +47,7 @@ def __init__( cliAuthConfig: CLIAuthConfig | None = None, ) -> None: self._config = config or RestConfig() - self._tokenHandler = None + self._tokenHandler: TokenManager | None = None if authConfig and cliAuthConfig: self._tokenHandler = TokenManager(authConfig, cliAuthConfig) @@ -137,7 +137,7 @@ def _request_and_deserialize( get_exception: Callable[[requests.Response], Exception | None] = _exception, ) -> T: url = self._url(suffix) - headers = { + headers: dict[str, str] = { "content-type": "application/json; charset=UTF-8", } if ( @@ -146,15 +146,15 @@ def _request_and_deserialize( and self._tokenHandler.token["access_token"] ): try: - access_token = self._tokenHandler.token["access_token"] - self._tokenHandler.authenticator.verify_token(access_token) - headers["Authorization"] = f"Bearer {access_token}" + auth_token: str = self._tokenHandler.token["access_token"] + self._tokenHandler.authenticator.verify_token(auth_token) + headers["Authorization"] = f"Bearer {auth_token}" except jwt.ExpiredSignatureError: if self._tokenHandler.refresh_auth_token(): - access_token = self._tokenHandler.token["access_token"] + access_token: str = self._tokenHandler.token["access_token"] headers["Authorization"] = f"Bearer {access_token}" - except Exception as e: - raise Exception from e + except Exception: + pass if data: response = requests.request(method, url, json=data, headers=headers) else: diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 27ea9171f..94ac16f94 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -97,28 +97,33 @@ class OauthConfig(BlueapiBaseModel): refresh_url: str = "" def model_post_init(self, __context: Any) -> None: - response = requests.get(self.oidc_config_url) + response: requests.Response = requests.get(self.oidc_config_url) response.raise_for_status() - config_data = response.json() - - self.device_auth_url = config_data.get("device_authorization_endpoint") - self.pkce_auth_url = config_data.get("authorization_endpoint") - self.token_url = config_data.get("token_endpoint") - self.issuer = config_data.get("issuer") - self.jwks_uri = config_data.get("jwks_uri") - self.logout_url = config_data.get("end_session_endpoint") - self.refresh_url = config_data.get("end_session_endpoint") + config_data: dict[str, str] = response.json() + + device_auth_url: str | None = config_data.get("device_authorization_endpoint") + pkce_auth_url: str | None = config_data.get("authorization_endpoint") + token_url: str | None = config_data.get("token_endpoint") + issuer: str | None = config_data.get("issuer") + jwks_uri: str | None = config_data.get("jwks_uri") + logout_url: str | None = config_data.get("end_session_endpoint") # post this we need to check if all the values are present - if any( - ( - self.device_auth_url == "", - self.pkce_auth_url == "", - self.token_url == "", - self.issuer == "", - self.jwks_uri == "", - self.logout_url == "", - ) + if ( + device_auth_url + and pkce_auth_url + and token_url + and issuer + and jwks_uri + and logout_url ): + self.device_auth_url = device_auth_url + self.pkce_auth_url = pkce_auth_url + self.token_url = token_url + self.issuer = issuer + self.jwks_uri = jwks_uri + self.logout_url = logout_url + self.refresh_url = token_url + else: raise ValueError("OIDC config is missing required fields") diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index 705029a37..deabf5a7f 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -38,7 +38,7 @@ def decode_jwt(self, token: str, verify_expiration: bool = True) -> dict[str, st signing_key = jwt.PyJWKClient(self.oauth.jwks_uri).get_signing_key_from_jwt( token ) - decode = jwt.decode( + decode: dict[str, str] = jwt.decode( token, signing_key.key, algorithms=["RS256"], @@ -51,8 +51,8 @@ def decode_jwt(self, token: str, verify_expiration: bool = True) -> dict[str, st return decode def print_user_info(self, token: str) -> None: - decode = self.decode_jwt(token) - print(f'Logged in as {decode["name"]} with fed-id {decode["fedid"]}') + decode: dict[str, str] = self.decode_jwt(token) + print(f'Logged in as {decode.get("name")} with fed-id {decode.get("fedid")}') class TokenManager: @@ -80,13 +80,14 @@ def refresh_auth_token(self) -> bool: ) if response.status_code == HTTPStatus.OK: self.save_token(response.json()) + self.load_token() return True return False def save_token(self, token: dict[str, Any]) -> None: - token_json = json.dumps(token) - token_bytes = token_json.encode("utf-8") - token_base64 = base64.b64encode(token_bytes) + token_json: str = json.dumps(token) + token_bytes: bytes = token_json.encode("utf-8") + token_base64: bytes = base64.b64encode(token_bytes) with open(os.path.expanduser(self.cliAuth.token_file_path), "wb") as token_file: token_file.write(token_base64) @@ -94,13 +95,13 @@ def load_token(self) -> None: if not os.path.exists(os.path.expanduser(self.cliAuth.token_file_path)): return with open(os.path.expanduser(self.cliAuth.token_file_path), "rb") as token_file: - token_base64 = token_file.read() - token_bytes = base64.b64decode(token_base64) - token_json = token_bytes.decode("utf-8") + token_base64: bytes = token_file.read() + token_bytes: bytes = base64.b64decode(token_base64) + token_json: str = token_bytes.decode("utf-8") self.token = json.loads(token_json) def get_device_code(self): - response = requests.post( + response: requests.Response = requests.post( self.oauth.token_url, data={ "client_id": self.cliAuth.client_id, @@ -108,7 +109,7 @@ def get_device_code(self): "audience": self.cliAuth.client_audience, }, ) - response_data = response.json() + response_data: dict[str, str] = response.json() if response.status_code == 200: return response_data["device_code"] raise Exception("Failed to get device code.") @@ -116,7 +117,7 @@ def get_device_code(self): def poll_for_token( self, device_code: str, timeout: float = 30, polling_interval: float = 0.5 ) -> dict[str, Any]: - too_late = time.time() + timeout + too_late: float = time.time() + timeout while time.time() < too_late: response = requests.post( self.oauth.token_url, @@ -138,12 +139,17 @@ def poll_for_token( def start_device_flow(self) -> None: if self.token: try: - self.authenticator.verify_token(self.token["access_token"]) + is_token_vaild: bool = self.authenticator.verify_token( + self.token["access_token"] + ) + if is_token_vaild: + self.load_token() + self.authenticator.print_user_info(self.token["access_token"]) + return except jwt.ExpiredSignatureError: - self.refresh_auth_token() - self.load_token() - self.authenticator.print_user_info(self.token["access_token"]) - return + if self.refresh_auth_token(): + self.authenticator.print_user_info(self.token["access_token"]) + return response: requests.Response = requests.post( self.oauth.device_auth_url, @@ -159,11 +165,9 @@ def start_device_flow(self) -> None: f"{response_json['verification_uri_complete']}" ) auth_token_json: dict[str, Any] = self.poll_for_token(device_code) - valid_token = self.authenticator.verify_token( + valid_token: bool = self.authenticator.verify_token( auth_token_json["access_token"] ) if valid_token: self.save_token(auth_token_json) - self.load_token() - if self.token: - self.authenticator.print_user_info(self.token["access_token"]) + self.authenticator.print_user_info(auth_token_json["access_token"]) diff --git a/tests/unit_tests/client/test_rest.py b/tests/unit_tests/client/test_rest.py index f32fbc41c..291730e46 100644 --- a/tests/unit_tests/client/test_rest.py +++ b/tests/unit_tests/client/test_rest.py @@ -101,22 +101,3 @@ def test_refresh_if_signature_expired(rest_with_auth: BlueapiRestClient): result = rest_with_auth.get_plans() # Add assertions as needed assert result == PlanResponse(plans=[PlanModel.from_plan(plan)]) - - -@responses.activate -def test_verify_token_ignore_other_exceptions(rest_with_auth: BlueapiRestClient): - plan = Plan(name="my-plan", model=MyModel) - responses.add( - responses.GET, - "http://localhost:8000/plans", - json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(), - status=200, - ) - with ( - patch("blueapi.service.Authenticator.verify_token") as mock_verify_token, - ): - # Mock the verify_token function to return True (indicating a valid token) - mock_verify_token.side_effect = Exception - result = rest_with_auth.get_plans() - # Add assertions as needed - assert result == PlanResponse(plans=[PlanModel.from_plan(plan)]) diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 1194b6f6a..3e8e49c50 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -1,6 +1,5 @@ import base64 import json -from datetime import datetime, timedelta import jwt import matplotlib @@ -659,15 +658,9 @@ def valid_auth_config(tmp_path: Path) -> str: @responses.activate def test_login_success(runner: CliRunner, valid_auth_config: str): - payload: dict[str, Any] = { - "kid": "Key_identifier", - "sub": "1234567890", # Subject (usually user ID) + payload: dict[str, str] = { "name": "John Doe", "fedid": "jd1", - "iat": datetime.now(), # Issued at time - "exp": datetime.now() + timedelta(hours=1), # Expiration time - "aud": "client_audience", - "iss": "https://example.com/", } mock_json_responses: list[dict[str, str]] = [ @@ -732,15 +725,8 @@ def test_login_with_refresh_token( ).decode("utf-8") ) payload: dict[str, Any] = { - "kid": "Key_identifier", - "sub": "1234567890", # Subject (usually user ID) "name": "John Doe", "fedid": "jd1", - "iat": datetime.now(), # Issued at time - "exp": datetime.now() + timedelta(hours=1), # Expiration time - "aud": "client_audience", - "iss": "https://example.com/", - "refresh_token": "refresh_token", } mock_json_responses: list[dict[str, str]] = [ @@ -782,6 +768,57 @@ def test_login_with_refresh_token( assert result.exit_code == 0 +@responses.activate +def test_login_edge_cases(runner: CliRunner, valid_auth_config: str, tmp_path: Path): + with open(tmp_path / "token", "w") as token_file: + # base64 encoded token + token_file.write( + base64.b64encode( + b'{"access_token":"token","refresh_token":"refresh_token"}' + ).decode("utf-8") + ) + + mock_json_responses: list[dict[str, str]] = [ + { + "device_authorization_endpoint": "https://example.com/device_authorization", + "authorization_endpoint": "https://example.com/authorization", + "token_endpoint": "https://example.com/token", + "issuer": "https://example.com", + "jwks_uri": "https://example.com/realms/master/protocol/openid-connect/certs", + "end_session_endpoint": "https://example.com/logout", + }, + { + "details": "not found", + }, + ] + with responses.RequestsMock(assert_all_requests_are_fired=True) as requests_mock: + requests_mock.add( + requests_mock.GET, + "https://auth.example.com/realms/sample/.well-known/openid-configuration", + json=mock_json_responses[0], + status=200, + ) + requests_mock.add( + requests_mock.POST, + "https://example.com/token", + json=mock_json_responses[1], + status=400, + ) + requests_mock.add( + requests_mock.POST, + "https://example.com/device_authorization", + json=mock_json_responses[1], + status=400, + ) + with ( + patch("blueapi.service.Authenticator.decode_jwt") as mock_decode, + ): + mock_decode.side_effect = jwt.ExpiredSignatureError + result = runner.invoke(main, ["-c", valid_auth_config, "login"]) + assert "Logging in\n" == result.output + assert result.exit_code == 0 + + @responses.activate def test_logout_success(runner: CliRunner, valid_auth_config: str, tmp_path: Path): with open(tmp_path / "token", "w") as token_file: diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 1c3962927..99bb8514f 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -174,7 +174,7 @@ def test_oauth_config_model_post_init(mock_get): assert oauth_config.issuer == mock_response["issuer"] assert oauth_config.jwks_uri == mock_response["jwks_uri"] assert oauth_config.logout_url == mock_response["end_session_endpoint"] - assert oauth_config.refresh_url == mock_response["end_session_endpoint"] + assert oauth_config.refresh_url == mock_response["token_endpoint"] @mock.patch("requests.get")