From 132616479d2184be6ab7917e6c11d7b8a84214bf Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Thu, 17 Oct 2024 13:22:09 +0100 Subject: [PATCH] Refactored code --- src/blueapi/cli/cli.py | 28 +- src/blueapi/client/client.py | 4 +- src/blueapi/client/rest.py | 28 +- src/blueapi/config.py | 13 +- src/blueapi/service/__init__.py | 4 +- src/blueapi/service/authentication.py | 156 +++++++---- src/blueapi/service/main.py | 4 +- tests/unit_tests/client/test_rest.py | 26 +- .../unit_tests/service/test_authentication.py | 263 ++++++++++-------- tests/unit_tests/test_cli.py | 4 +- tests/unit_tests/test_config.py | 6 +- 11 files changed, 311 insertions(+), 225 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index cbddf008a..4dfcf0f5d 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -17,9 +17,13 @@ from blueapi.client.client import BlueapiClient from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueskyRemoteControlError -from blueapi.config import ApplicationConfig, CLIAuthConfig, ConfigLoader, OauthConfig +from blueapi.config import ( + ApplicationConfig, + CLIClientConfig, + ConfigLoader, +) from blueapi.core import DataEvent -from blueapi.service.authentication import TokenManager +from blueapi.service.authentication import CLITokenManager, SessionManager from blueapi.service.main import start from blueapi.service.openapi import ( DOCS_SCHEMA_LOCATION, @@ -335,11 +339,13 @@ def scratch(obj: dict) -> None: @click.pass_obj def login(obj: dict) -> None: config: ApplicationConfig = obj["config"] - if config.cliAuth and config.oauth: - cliAuthConfig: CLIAuthConfig = config.cliAuth - oauthConfig: OauthConfig = config.oauth + if isinstance(config.oauth_client, CLIClientConfig) and config.oauth_server: print("Logging in") - auth: TokenManager = TokenManager(oauth=oauthConfig, cliAuth=cliAuthConfig) + auth: SessionManager = SessionManager( + server_config=config.oauth_server, + client_config=config.oauth_client, + token_manager=CLITokenManager(Path(config.oauth_client.token_file_path)), + ) auth.start_device_flow() else: print("Please provide configuration to login!") @@ -349,10 +355,12 @@ def login(obj: dict) -> None: @click.pass_obj def logout(obj: dict) -> None: config: ApplicationConfig = obj["config"] - if config.cliAuth and config.oauth: - oauthConfig: OauthConfig = config.oauth - cliAuthConfig: CLIAuthConfig = config.cliAuth - auth: TokenManager = TokenManager(cliAuth=cliAuthConfig, oauth=oauthConfig) + if isinstance(config.oauth_client, CLIClientConfig) and config.oauth_server: + auth: SessionManager = SessionManager( + server_config=config.oauth_server, + client_config=config.oauth_client, + token_manager=CLITokenManager(Path(config.oauth_client.token_file_path)), + ) auth.logout() print("Logged out") else: diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 53940af62..f3752cb3c 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -6,6 +6,7 @@ from blueapi.config import ApplicationConfig from blueapi.core.bluesky_types import DataEvent +from blueapi.service.authentication import SessionManager from blueapi.service.model import ( DeviceModel, DeviceResponse, @@ -40,7 +41,8 @@ def __init__( @classmethod def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": rest: BlueapiRestClient = BlueapiRestClient( - config.api, config.oauth, config.cliAuth + config.api, + SessionManager.from_config(config.oauth_server, config.oauth_client), ) if config.stomp is not None: template = StompClient.for_broker( diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index a1585009b..8b1f20977 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -5,8 +5,8 @@ import requests from pydantic import TypeAdapter -from blueapi.config import CLIAuthConfig, OauthConfig, RestConfig -from blueapi.service.authentication import TokenManager +from blueapi.config import RestConfig +from blueapi.service.authentication import SessionManager from blueapi.service.model import ( DeviceModel, DeviceResponse, @@ -43,13 +43,10 @@ class BlueapiRestClient: def __init__( self, config: RestConfig | None = None, - authConfig: OauthConfig | None = None, - cliAuthConfig: CLIAuthConfig | None = None, + session_manager: SessionManager | None = None, ) -> None: self._config = config or RestConfig() - self._tokenHandler: TokenManager | None = None - if authConfig and cliAuthConfig: - self._tokenHandler = TokenManager(authConfig, cliAuthConfig) + self._session_manager: SessionManager | None = session_manager def get_plans(self) -> PlanResponse: return self._request_and_deserialize("/plans", PlanResponse) @@ -140,19 +137,14 @@ def _request_and_deserialize( headers: dict[str, str] = { "content-type": "application/json; charset=UTF-8", } - if ( - self._tokenHandler - and self._tokenHandler.token - and self._tokenHandler.token["access_token"] - ): + if self._session_manager and (token := self._session_manager.get_token()): try: - auth_token: str = self._tokenHandler.token["access_token"] - self._tokenHandler.authenticator.verify_token(auth_token) - headers["Authorization"] = f"Bearer {auth_token}" + self._session_manager.authenticator.verify_token(token["access_token"]) + headers["Authorization"] = f"Bearer {token['access_token']}" except jwt.ExpiredSignatureError: - if self._tokenHandler.refresh_auth_token(): - access_token: str = self._tokenHandler.token["access_token"] - headers["Authorization"] = f"Bearer {access_token}" + if token := self._session_manager.refresh_auth_token(): + if token := self._session_manager.get_token(): + headers["Authorization"] = f"Bearer {token['access_token']}" except Exception: pass if data: diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 1b06a7083..378bdc33b 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -83,7 +83,7 @@ class ScratchConfig(BlueapiBaseModel): repositories: list[ScratchRepository] = Field(default_factory=list) -class OauthConfig(BlueapiBaseModel): +class OAuthServerConfig(BlueapiBaseModel): oidc_config_url: str = Field( description="URL to fetch OIDC config from the provider" ) @@ -125,13 +125,13 @@ def model_post_init(self, __context: Any) -> None: raise ValueError("OIDC config is missing required fields") -class BaseAuthConfig(BlueapiBaseModel): +class OAuthClientConfig(BlueapiBaseModel): client_id: str = Field(description="Client ID") client_audience: str = Field(description="Client Audience") -class CLIAuthConfig(BaseAuthConfig): - token_file_path: str = "~/token" +class CLIClientConfig(OAuthClientConfig): + token_file_path: Path | None = Path("~/token") class ApplicationConfig(BlueapiBaseModel): @@ -145,9 +145,8 @@ class ApplicationConfig(BlueapiBaseModel): logging: LoggingConfig = Field(default_factory=LoggingConfig) api: RestConfig = Field(default_factory=RestConfig) scratch: ScratchConfig | None = None - oauth: OauthConfig | None = None - cliAuth: CLIAuthConfig | None = None - swaggerAuth: BaseAuthConfig | None = None + oauth_server: OAuthServerConfig | None = None + oauth_client: CLIClientConfig | None = None def __eq__(self, other: object) -> bool: if isinstance(other, ApplicationConfig): diff --git a/src/blueapi/service/__init__.py b/src/blueapi/service/__init__.py index 4f08d98e7..ae9ffaf79 100644 --- a/src/blueapi/service/__init__.py +++ b/src/blueapi/service/__init__.py @@ -1,4 +1,4 @@ -from .authentication import Authenticator, TokenManager +from .authentication import Authenticator, SessionManager from .model import DeviceModel, PlanModel -__all__ = ["PlanModel", "DeviceModel", "Authenticator", "TokenManager"] +__all__ = ["PlanModel", "DeviceModel", "Authenticator", "SessionManager"] diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index deabf5a7f..3ebe9150a 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -2,17 +2,19 @@ import json import os import time +from abc import ABC, abstractmethod from enum import Enum from http import HTTPStatus +from pathlib import Path from typing import Any import jwt import requests from blueapi.config import ( - BaseAuthConfig, - CLIAuthConfig, - OauthConfig, + CLIClientConfig, + OAuthClientConfig, + OAuthServerConfig, ) @@ -24,28 +26,28 @@ class AuthenticationType(Enum): class Authenticator: def __init__( self, - oauth: OauthConfig, - baseAuthConfig: BaseAuthConfig, + server_config: OAuthServerConfig, + client_config: OAuthClientConfig, ): - self.oauth: OauthConfig = oauth - self.baseAuthConfig: BaseAuthConfig = baseAuthConfig + self._server_config: OAuthServerConfig = server_config + self._client_config: OAuthClientConfig = client_config def verify_token(self, token: str, verify_expiration: bool = True) -> bool: self.decode_jwt(token, verify_expiration) return True def decode_jwt(self, token: str, verify_expiration: bool = True) -> dict[str, str]: - signing_key = jwt.PyJWKClient(self.oauth.jwks_uri).get_signing_key_from_jwt( - token - ) + signing_key = jwt.PyJWKClient( + self._server_config.jwks_uri + ).get_signing_key_from_jwt(token) decode: dict[str, str] = jwt.decode( token, signing_key.key, algorithms=["RS256"], options={"verify_exp": verify_expiration}, verify=True, - audience=self.baseAuthConfig.client_audience, - issuer=self.oauth.issuer, + audience=self._client_config.client_audience, + issuer=self._server_config.issuer, leeway=5, ) return decode @@ -55,58 +57,97 @@ def print_user_info(self, token: str) -> None: print(f'Logged in as {decode.get("name")} with fed-id {decode.get("fedid")}') -class TokenManager: - def __init__(self, oauth: OauthConfig, cliAuth: CLIAuthConfig) -> None: - self.oauth: OauthConfig = oauth - self.cliAuth: CLIAuthConfig = cliAuth - self.token: Any = None - self.authenticator: Authenticator = Authenticator(self.oauth, self.cliAuth) - self.load_token() +class TokenManager(ABC): + @abstractmethod + def save_token(self, token: dict[str, Any]): ... + @abstractmethod + def load_token(token) -> dict[str, Any] | None: ... + @abstractmethod + def delete_token(self): ... - def logout(self) -> None: - if os.path.exists(os.path.expanduser(self.cliAuth.token_file_path)): - os.remove(os.path.expanduser(self.cliAuth.token_file_path)) - def refresh_auth_token(self) -> bool: - if self.token: - response = requests.post( - self.oauth.token_url, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - data={ - "client_id": self.cliAuth.client_id, - "grant_type": "refresh_token", - "refresh_token": self.token["refresh_token"], - }, - ) - if response.status_code == HTTPStatus.OK: - self.save_token(response.json()) - self.load_token() - return True - return False +class CLITokenManager(TokenManager): + def __init__(self, token_file_path: Path) -> None: + self._token_file_path: Path = token_file_path def save_token(self, token: dict[str, Any]) -> None: token_json: str = json.dumps(token) token_bytes: bytes = token_json.encode("utf-8") token_base64: bytes = base64.b64encode(token_bytes) - with open(os.path.expanduser(self.cliAuth.token_file_path), "wb") as token_file: + with open(os.path.expanduser(self._token_file_path), "wb") as token_file: token_file.write(token_base64) - def load_token(self) -> None: - if not os.path.exists(os.path.expanduser(self.cliAuth.token_file_path)): - return - with open(os.path.expanduser(self.cliAuth.token_file_path), "rb") as token_file: + def load_token(self) -> dict[str, Any] | None: + file_path = os.path.expanduser(self._token_file_path) + if not os.path.exists(file_path): + return None + with open(file_path, "rb") as token_file: token_base64: bytes = token_file.read() token_bytes: bytes = base64.b64decode(token_base64) token_json: str = token_bytes.decode("utf-8") - self.token = json.loads(token_json) + return json.loads(token_json) + + def delete_token(self) -> None: + if os.path.exists(os.path.expanduser(self._token_file_path)): + os.remove(os.path.expanduser(self._token_file_path)) + + +class SessionManager: + def __init__( + self, + server_config: OAuthServerConfig, + client_config: OAuthClientConfig, + token_manager: TokenManager, + ) -> None: + self._server_config: OAuthServerConfig = server_config + self._client_config: OAuthClientConfig = client_config + self.authenticator: Authenticator = Authenticator(server_config, client_config) + self._token_manager = token_manager + + @classmethod + def from_config( + cls, + server_config: OAuthServerConfig | None, + client_config: OAuthClientConfig | None, + ) -> "SessionManager": + if server_config and client_config: + if isinstance(client_config, CLIClientConfig): + return SessionManager( + server_config, + client_config, + CLITokenManager(Path(client_config.token_file_path)), # type: ignore + ) + # raise NotImplementedError("Only CLI client config is supported") + + def get_token(self) -> dict[str, Any] | None: + return self._token_manager.load_token() + + def logout(self) -> None: + self._token_manager.delete_token() + + def refresh_auth_token(self) -> dict[str, Any] | None: + if token := self._token_manager.load_token(): + response = requests.post( + self._server_config.token_url, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": self._client_config.client_id, + "grant_type": "refresh_token", + "refresh_token": token["refresh_token"], + }, + ) + if response.status_code == HTTPStatus.OK: + token = response.json() + self._token_manager.save_token(token) + return token def get_device_code(self): response: requests.Response = requests.post( - self.oauth.token_url, + self._server_config.token_url, data={ - "client_id": self.cliAuth.client_id, + "client_id": self._client_config.client_id, "scope": "openid profile offline_access", - "audience": self.cliAuth.client_audience, + "audience": self._client_config.client_audience, }, ) response_data: dict[str, str] = response.json() @@ -120,12 +161,12 @@ def poll_for_token( too_late: float = time.time() + timeout while time.time() < too_late: response = requests.post( - self.oauth.token_url, + self._server_config.token_url, headers={"Content-Type": "application/x-www-form-urlencoded"}, data={ "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": device_code, - "client_id": self.cliAuth.client_id, + "client_id": self._client_config.client_id, }, ) if response.status_code == HTTPStatus.OK: @@ -137,24 +178,23 @@ def poll_for_token( raise TimeoutError("Polling timed out") def start_device_flow(self) -> None: - if self.token: + if token := self._token_manager.load_token(): try: is_token_vaild: bool = self.authenticator.verify_token( - self.token["access_token"] + token["access_token"] ) if is_token_vaild: - self.load_token() - self.authenticator.print_user_info(self.token["access_token"]) + self.authenticator.print_user_info(token["access_token"]) return except jwt.ExpiredSignatureError: - if self.refresh_auth_token(): - self.authenticator.print_user_info(self.token["access_token"]) + if token := self.refresh_auth_token(): + self.authenticator.print_user_info(token["access_token"]) return response: requests.Response = requests.post( - self.oauth.device_auth_url, + self._server_config.device_auth_url, headers={"Content-Type": "application/x-www-form-urlencoded"}, - data={"client_id": self.cliAuth.client_id}, + data={"client_id": self._client_config.client_id}, ) if response.status_code == HTTPStatus.OK: @@ -169,5 +209,5 @@ def start_device_flow(self) -> None: auth_token_json["access_token"] ) if valid_token: - self.save_token(auth_token_json) + self._token_manager.save_token(auth_token_json) self.authenticator.print_user_info(auth_token_json["access_token"]) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 4d1f031f5..eefe42f6c 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -376,9 +376,9 @@ def start(config: ApplicationConfig): global AUTHENTICATOR app.state.config = config - if config.swaggerAuth and config.oauth: + if config.oauth_client and config.oauth_server: AUTHENTICATOR = Authenticator( - oauth=config.oauth, baseAuthConfig=config.swaggerAuth + server_config=config.oauth_server, client_config=config.oauth_client ) uvicorn.run(app, host=config.api.host, port=config.api.port) diff --git a/tests/unit_tests/client/test_rest.py b/tests/unit_tests/client/test_rest.py index 7dbab5eda..249c9eb7a 100644 --- a/tests/unit_tests/client/test_rest.py +++ b/tests/unit_tests/client/test_rest.py @@ -1,3 +1,5 @@ +import base64 +from pathlib import Path from unittest.mock import Mock, patch import jwt @@ -6,14 +8,15 @@ from pydantic import BaseModel from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError -from blueapi.config import CLIAuthConfig, OauthConfig +from blueapi.config import OAuthClientConfig, OAuthServerConfig from blueapi.core.bluesky_types import Plan +from blueapi.service.authentication import CLITokenManager, SessionManager from blueapi.service.model import PlanModel, PlanResponse @pytest.fixture @responses.activate -def rest() -> BlueapiRestClient: +def rest(tmp_path: Path) -> BlueapiRestClient: responses.add( responses.GET, "http://example.com", @@ -27,10 +30,19 @@ def rest() -> BlueapiRestClient: }, status=200, ) - return BlueapiRestClient( - cliAuthConfig=CLIAuthConfig(client_id="foo", client_audience="bar"), - authConfig=OauthConfig(oidc_config_url="http://example.com"), + with open(tmp_path / "token", "w") as token_file: + # base64 encoded token + token_file.write( + base64.b64encode( + b'{"access_token":"token","refresh_token":"refresh_token"}' + ).decode("utf-8") + ) + session_manager = SessionManager( + token_manager=CLITokenManager(tmp_path / "token"), + client_config=OAuthClientConfig(client_id="foo", client_audience="bar"), + server_config=OAuthServerConfig(oidc_config_url="http://example.com"), ) + return BlueapiRestClient(session_manager=session_manager) @pytest.mark.parametrize( @@ -88,7 +100,9 @@ def test_refresh_if_signature_expired(rest: BlueapiRestClient): ) with ( patch("blueapi.service.Authenticator.verify_token") as mock_verify_token, - patch("blueapi.service.TokenManager.refresh_auth_token") as mock_refresh_token, + patch( + "blueapi.service.SessionManager.refresh_auth_token" + ) as mock_refresh_token, ): mock_verify_token.side_effect = jwt.ExpiredSignatureError mock_refresh_token.return_value = True diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py index a4c402370..cd2a6f3fa 100644 --- a/tests/unit_tests/service/test_authentication.py +++ b/tests/unit_tests/service/test_authentication.py @@ -1,125 +1,156 @@ +import base64 import os from http import HTTPStatus -from unittest import TestCase, mock +from pathlib import Path +from unittest import mock import jwt import pytest from jwt import PyJWTError -from blueapi.config import BaseAuthConfig, CLIAuthConfig, OauthConfig -from blueapi.service.authentication import Authenticator, TokenManager - - -class TestAuthenticator(TestCase): - @mock.patch("requests.get") - def setUp(self, mock_requests_get): - mock_requests_get.return_value.status_code = 200 - mock_requests_get.return_value.json.return_value = { - "device_authorization_endpoint": "https://example.com/device_authorization", - "authorization_endpoint": "https://example.com/authorization", - "token_endpoint": "https://example.com/token", - "issuer": "https://example.com", - "jwks_uri": "https://example.com/.well-known/jwks.json", - "end_session_endpoint": "https://example.com/logout", - } - self.oauth_config = OauthConfig( - oidc_config_url="https://auth.example.com/realms/sample/.well-known/openid-configuration", - ) - self.base_auth_config = BaseAuthConfig( - client_id="example_client_id", client_audience="example_audience" - ) - self.authenticator = Authenticator(self.oauth_config, self.base_auth_config) - - @mock.patch("jwt.decode") - @mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") - def test_verify_token_valid(self, mock_get_signing_key, mock_decode): - decode_retun_value = {"token": "valid_token", "name": "John Doe"} - mock_decode.return_value = decode_retun_value - valid_token = self.authenticator.verify_token(decode_retun_value["token"]) - self.assertTrue(valid_token) - - @mock.patch("jwt.decode") - @mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") - def test_verify_token_invalid(self, mock_get_signing_key, mock_decode): - mock_decode.side_effect = jwt.ExpiredSignatureError - token = "invalid_token" - with pytest.raises(PyJWTError): - self.authenticator.verify_token(token) - - @mock.patch("jwt.decode") - @mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") - def test_user_info( - self, - mock_get_signing_key, - mock_decode, - ): - mock_decode.return_value = { - "name": "John Doe", - "fedid": "12345", - } - self.authenticator.print_user_info("valid_token") - - -class TestTokenManager(TestCase): - @mock.patch("requests.get") - def setUp(self, mock_requests_get): - mock_requests_get.return_value.status_code = 200 - mock_requests_get.return_value.json.return_value = { - "device_authorization_endpoint": "https://example.com/device_authorization", - "authorization_endpoint": "https://example.com/authorization", - "token_endpoint": "https://example.com/token", - "issuer": "https://example.com", - "jwks_uri": "https://example.com/.well-known/jwks.json", - "end_session_endpoint": "https://example.com/logout", - } - self.oauth_config = OauthConfig( - oidc_config_url="https://auth.example.com/realms/sample/.well-known/openid-configuration", - ) - self.cli_auth_config = CLIAuthConfig( - client_id="client_id", - client_audience="client_audience", - token_file_path="~/token", - ) - self.token_manager = TokenManager(self.oauth_config, self.cli_auth_config) - - @mock.patch("os.path.exists") - @mock.patch("os.remove") - def test_logout(self, mock_remove, mock_exists): - mock_exists.return_value = True - self.token_manager.logout() - mock_remove.assert_called_once_with( - os.path.expanduser(self.cli_auth_config.token_file_path) +from blueapi.config import CLIClientConfig, OAuthClientConfig, OAuthServerConfig +from blueapi.service.authentication import Authenticator, SessionManager + + +@pytest.fixture +def mock_client_config(tmp_path: Path) -> OAuthClientConfig: + return CLIClientConfig( + client_id="client_id", + client_audience="client_audience", + token_file_path=tmp_path / "token", + ) + + +@pytest.fixture +@mock.patch("requests.get") +def mock_server_config(mock_requests_get) -> OAuthServerConfig: + mock_requests_get.return_value.status_code = 200 + mock_requests_get.return_value.json.return_value = { + "device_authorization_endpoint": "https://example.com/device_authorization", + "authorization_endpoint": "https://example.com/authorization", + "token_endpoint": "https://example.com/token", + "issuer": "https://example.com", + "jwks_uri": "https://example.com/.well-known/jwks.json", + "end_session_endpoint": "https://example.com/logout", + } + return OAuthServerConfig( + oidc_config_url="https://auth.example.com/realms/sample/.well-known/openid-configuration", + ) + + +@pytest.fixture +def mock_session_manager(mock_client_config, mock_server_config) -> SessionManager: + session_manager = SessionManager.from_config(mock_server_config, mock_client_config) + return session_manager + + +@pytest.fixture +def mock_connected_client_config(mock_client_config: OAuthClientConfig): + assert isinstance(mock_client_config, CLIClientConfig) + with open(mock_client_config.token_file_path, "w") as token_file: + # base64 encoded token + token_file.write( + base64.b64encode( + b'{"access_token":"token","refresh_token":"refresh_token"}' + ).decode("utf-8") ) + return mock_client_config + + +@pytest.fixture +def mock_authenticator(mock_server_config, mock_client_config) -> Authenticator: + return Authenticator(mock_server_config, mock_client_config) + - @mock.patch("requests.post") - def test_refresh_auth_token(self, mock_post): - self.token_manager.token = {"refresh_token": "refresh_token"} - mock_post.return_value.status_code = HTTPStatus.OK - mock_post.return_value.json.return_value = {"access_token": "new_access_token"} - result = self.token_manager.refresh_auth_token() - self.assertTrue(result) - - @mock.patch("requests.post") - def test_get_device_code(self, mock_post): - mock_post.return_value.status_code = HTTPStatus.OK - mock_post.return_value.json.return_value = {"device_code": "device_code"} - device_code = self.token_manager.get_device_code() - self.assertEqual(device_code, "device_code") - - @mock.patch("requests.post") - def test_poll_for_token(self, mock_post): - mock_post.return_value.status_code = HTTPStatus.OK - mock_post.return_value.json.return_value = {"access_token": "access_token"} - device_code = "device_code" - token = self.token_manager.poll_for_token(device_code) - self.assertEqual(token, {"access_token": "access_token"}) - - @mock.patch("requests.post") - @mock.patch("time.sleep") - def test_poll_for_token_timeout(self, mock_sleep, mock_post): - mock_post.return_value.status_code = HTTPStatus.BAD_REQUEST - device_code = "device_code" - with self.assertRaises(TimeoutError): - self.token_manager.poll_for_token( - device_code, timeout=1, polling_interval=0.1 - ) +@mock.patch("jwt.decode") +@mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") +def test_verify_token_valid( + mock_get_signing_key, mock_decode, mock_authenticator: Authenticator +): + decode_retun_value = {"token": "valid_token", "name": "John Doe"} + mock_decode.return_value = decode_retun_value + valid_token = mock_authenticator.verify_token(decode_retun_value["token"]) + assert valid_token + + +@mock.patch("jwt.decode") +@mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") +def test_verify_token_invalid( + mock_get_signing_key, mock_decode, mock_authenticator: Authenticator +): + mock_decode.side_effect = jwt.ExpiredSignatureError + token = "invalid_token" + with pytest.raises(PyJWTError): + mock_authenticator.verify_token(token) + + +@mock.patch("jwt.decode") +@mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") +def test_user_info( + mock_get_signing_key, + mock_decode, + mock_authenticator: Authenticator, +): + mock_decode.return_value = { + "name": "John Doe", + "fedid": "12345", + } + mock_authenticator.print_user_info("valid_token") + + +def test_logout( + mock_session_manager: SessionManager, mock_connected_client_config: CLIClientConfig +): + assert os.path.exists(mock_connected_client_config.token_file_path) # type: ignore + mock_session_manager.logout() + assert not os.path.exists(mock_connected_client_config.token_file_path) # type: ignore + + +@mock.patch("requests.post") +def test_refresh_auth_token( + mock_post, + mock_session_manager: SessionManager, + mock_connected_client_config: OAuthClientConfig, +): + mock_post.return_value.status_code = HTTPStatus.OK + mock_post.return_value.json.return_value = {"access_token": "new_access_token"} + result = mock_session_manager.refresh_auth_token() + assert result == {"access_token": "new_access_token"} + + +@mock.patch("requests.post") +def test_get_device_code( + mock_post, + mock_session_manager: SessionManager, +): + mock_post.return_value.status_code = HTTPStatus.OK + mock_post.return_value.json.return_value = {"device_code": "device_code"} + device_code = mock_session_manager.get_device_code() + assert device_code == "device_code" + + +@mock.patch("requests.post") +def test_poll_for_token( + mock_post, + mock_session_manager: SessionManager, +): + mock_post.return_value.status_code = HTTPStatus.OK + mock_post.return_value.json.return_value = {"access_token": "access_token"} + device_code = "device_code" + token = mock_session_manager.poll_for_token(device_code) + assert token == {"access_token": "access_token"} + + +@mock.patch("requests.post") +@mock.patch("time.sleep") +def test_poll_for_token_timeout( + mock_sleep, + mock_post, + mock_session_manager: SessionManager, +): + mock_post.return_value.status_code = HTTPStatus.BAD_REQUEST + device_code = "device_code" + with pytest.raises(TimeoutError): + mock_session_manager.poll_for_token( + device_code, timeout=1, polling_interval=0.1 + ) diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 5008bdd1f..2eb17d3ef 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -644,9 +644,9 @@ def test_logout_missing_config(runner: CliRunner): @pytest.fixture def valid_auth_config(tmp_path: Path) -> str: config = f""" -oauth: +oauth_server: oidc_config_url: https://auth.example.com/realms/sample/.well-known/openid-configuration -cliAuth: +oauth_client: client_id: sample-cli client_audience: sample-account token_file_path: {tmp_path}/token diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index d4e8146a8..d4f63a130 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -7,7 +7,7 @@ from bluesky_stomp.models import BasicAuthentication from pydantic import BaseModel, Field -from blueapi.config import ConfigLoader, OauthConfig +from blueapi.config import ConfigLoader, OAuthServerConfig from blueapi.utils import InvalidConfigError @@ -164,7 +164,7 @@ def test_oauth_config_model_post_init(mock_get): mock_get.return_value.json.return_value = mock_response mock_get.return_value.raise_for_status = lambda: None - oauth_config = OauthConfig(oidc_config_url=oidc_config_url) + oauth_config = OAuthServerConfig(oidc_config_url=oidc_config_url) assert ( oauth_config.device_auth_url == mock_response["device_authorization_endpoint"] @@ -191,4 +191,4 @@ def test_oauth_config_model_post_init_missing_fields(mock_get): mock_get.return_value.json.return_value = mock_response mock_get.return_value.raise_for_status = lambda: None with pytest.raises(ValueError, match="OIDC config is missing required fields"): - OauthConfig(oidc_config_url=oidc_config_url) + OAuthServerConfig(oidc_config_url=oidc_config_url)