From 3a6db8309e3e4ee9d674eb4c05d96276f5402940 Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Mon, 14 Oct 2024 09:23:04 +0100 Subject: [PATCH] Refactor authentication in Blueapi RestClient and service --- src/blueapi/cli/cli.py | 16 ++- src/blueapi/config.py | 69 ++++++++++- src/blueapi/service/authentication.py | 164 ++++++++------------------ 3 files changed, 130 insertions(+), 119 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 4f3c26321..fd8e10348 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -19,7 +19,7 @@ from blueapi.client.rest import BlueskyRemoteControlError from blueapi.config import ApplicationConfig, ConfigLoader from blueapi.core import DataEvent -from blueapi.service.authentication import Authenticator +from blueapi.service.authentication import TokenManager from blueapi.service.main import start from blueapi.service.openapi import ( DOCS_SCHEMA_LOCATION, @@ -333,6 +333,14 @@ def scratch(obj: dict) -> None: @main.command(name="login") -def login() -> None: - auth = Authenticator() - auth.start_device_flow() +@click.pass_obj +def login(obj: dict) -> None: + config: ApplicationConfig = obj["config"] + print(config) + if config.cliAuth is not None and config.oauth is not None: + print("Logging in") + print(config) + auth = TokenManager(cliAuth=config.cliAuth, oauth=config.oauth) + auth.start_device_flow() + else: + print("Please provide configuration to login!") diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 3502590ba..ce7fa6e04 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -1,11 +1,20 @@ +import os from collections.abc import Mapping from enum import Enum from pathlib import Path from typing import Any, Generic, Literal, TypeVar +import requests import yaml from bluesky_stomp.models import BasicAuthentication -from pydantic import BaseModel, Field, TypeAdapter, ValidationError +from pydantic import ( + BaseModel, + Field, + Secret, + TypeAdapter, + ValidationError, + field_validator, +) from blueapi.utils import BlueapiBaseModel, InvalidConfigError @@ -77,6 +86,61 @@ class ScratchConfig(BlueapiBaseModel): repositories: list[ScratchRepository] = Field(default_factory=list) +class OauthConfig(BlueapiBaseModel): + oidc_config_url: str = Field( + description="URL to fetch OIDC config from the provider" + ) + # Initialized post-init + device_auth_url: str = "" + pkce_auth_url: str = "" + token_url: str = "" + issuer: str = "" + jwks_uri: str = "" + + def model_post_init(self, __context: Any) -> None: + response = requests.get(self.oidc_config_url) + response.raise_for_status() + config_data = response.json() + + self.device_auth_url = config_data.get("device_authorization_endpoint") + self.pkce_auth_url = config_data.get("authorization_endpoint") + self.token_url = config_data.get("token_endpoint") + self.issuer = config_data.get("issuer") + self.jwks_uri = config_data.get("jwks_uri") + # post this we need to check if all the values are present + if any( + ( + self.device_auth_url == "", + self.pkce_auth_url == "", + self.token_url == "", + self.issuer == "", + self.jwks_uri == "", + ) + ): + raise ValueError("OIDC config is missing required fields") + + +class SwaggerAuthConfig(BlueapiBaseModel): + client_id: str = Field(description="Client ID for PKCE client") + client_secret: Secret[str] = Field( + description="Password to verify PKCE client's identity" + ) + client_audience: str + + @field_validator("client_secret", mode="before") + @classmethod + def get_from_env(cls, v: str): + if v.startswith("${") and v.endswith("}"): + return os.environ[v.removeprefix("${").removesuffix("}").upper()] + return v + + +class CLIAuthConfig(BlueapiBaseModel): + client_id: str = Field(description="Client ID for CLI client") + client_audience: str = Field(description="Audience for CLI client") + token_file_path: str = "~/token" + + class ApplicationConfig(BlueapiBaseModel): """ Config for the worker application as a whole. Root of @@ -88,6 +152,9 @@ class ApplicationConfig(BlueapiBaseModel): logging: LoggingConfig = Field(default_factory=LoggingConfig) api: RestConfig = Field(default_factory=RestConfig) scratch: ScratchConfig | None = None + oauth: OauthConfig | None = None + cliAuth: CLIAuthConfig | None = None + swaggerAuth: SwaggerAuthConfig | None = None def __eq__(self, other: object) -> bool: if isinstance(other, ApplicationConfig): diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index 31821be9e..fe371dae4 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -8,7 +8,12 @@ import jwt import requests -from dotenv import load_dotenv + +from blueapi.config import ( + CLIAuthConfig, + OauthConfig, + SwaggerAuthConfig, +) class AuthenticationType(Enum): @@ -16,47 +21,20 @@ class AuthenticationType(Enum): PKCE = "pkce" -class TokenManager: - """ - TokenManager class handles the token verification and refreshing. - - Attributes: - client_id (str): The client ID for the authentication. - token_url (str): The URL to obtain the token. - audience (list[str]): The audience for the authentication. - issuer (str): The issuer of the token. - jwks_client (jwt.PyJWKClient): The JWKS client for verifying tokens. - token_file_path (str): The file path to save the token. - token (None | dict[str, Any]): The token dictionary. - """ - - # Will move this to a computed field in ApplicationConfig - # Get the OpenID Connect configuration and configure the JWKS client - oidc_config = requests.get( - "https://authn.diamond.ac.uk/realms/master/.well-known/openid-configuration" - ).json() - jwks_uri = oidc_config["jwks_uri"] - issuer = oidc_config["issuer"] - token_url = oidc_config["token_endpoint"] - jwks_client = jwt.PyJWKClient(jwks_uri) - audience = "blueapi" - +class Authenticator: def __init__( self, - client_id: str, - token_file_path: str = "token", - ) -> None: - self.client_id = client_id - self.token_file_path = token_file_path - self.token: None | dict[str, Any] = None - self.load_token() + oauth: OauthConfig, + authentorConfig: CLIAuthConfig | SwaggerAuthConfig, + ): + self.oauth: OauthConfig = oauth + self.authentorConfig: CLIAuthConfig | SwaggerAuthConfig = authentorConfig - @classmethod def verify_token( - cls, token: str, verify_expiration: bool = True + self, token: str, verify_expiration: bool = True ) -> tuple[bool, Exception | None]: try: - decode = cls.decode_jwt(token, verify_expiration) + decode = self.decode_jwt(token, verify_expiration) if decode: return (True, None) except jwt.PyJWTError as e: @@ -65,25 +43,25 @@ def verify_token( return (False, Exception("Invalid token")) - @classmethod - def decode_jwt(cls, token: str, verify_expiration: bool = True): - signing_key = cls.jwks_client.get_signing_key_from_jwt(token) + def decode_jwt(self, token: str, verify_expiration: bool = True): + signing_key = jwt.PyJWKClient(self.oauth.jwks_uri).get_signing_key_from_jwt( + token + ) decode = jwt.decode( token, signing_key.key, algorithms=["RS256"], options={"verify_exp": verify_expiration}, verify=True, - audience=cls.audience, - issuer=cls.issuer, + audience=self.authentorConfig.client_audience, + issuer=self.oauth.issuer, leeway=5, ) return decode - @classmethod - def userInfo(cls, token: str) -> tuple[str | None, str | None]: + def userInfo(self, token: str) -> tuple[str | None, str | None]: try: - decode = cls.decode_jwt(token) + decode = self.decode_jwt(token) if decode: return (decode["name"], decode["fedid"]) else: @@ -91,13 +69,22 @@ def userInfo(cls, token: str) -> tuple[str | None, str | None]: except jwt.PyJWTError as _: return (None, None) + +class TokenManager: + def __init__(self, oauth: OauthConfig, cliAuth: CLIAuthConfig) -> None: + self.oauth = oauth + self.cliAuth = cliAuth + self.token = None + self.authenticator = Authenticator(self.oauth, self.cliAuth) + self.load_token() + def refresh_auth_token(self) -> bool: if self.token: response = requests.post( - self.token_url, + self.oauth.token_url, headers={"Content-Type": "application/x-www-form-urlencoded"}, data={ - "client_id": self.client_id, + "client_id": self.cliAuth.client_id, "grant_type": "refresh_token", "refresh_token": self.token["refresh_token"], }, @@ -114,76 +101,25 @@ def save_token(self, token: dict[str, Any]) -> None: token_json = json.dumps(token) token_bytes = token_json.encode("utf-8") token_base64 = base64.b64encode(token_bytes) - with open(self.token_file_path, "wb") as token_file: + with open(os.path.expanduser(self.cliAuth.token_file_path), "wb") as token_file: token_file.write(token_base64) def load_token(self) -> None: - if not os.path.exists(self.token_file_path): + if not os.path.exists(self.cliAuth.token_file_path): return None - with open(self.token_file_path, "rb") as token_file: + with open(os.path.expanduser(self.cliAuth.token_file_path), "rb") as token_file: token_base64 = token_file.read() token_bytes = base64.b64decode(token_base64) token_json = token_bytes.decode("utf-8") self.token = json.loads(token_json) - -class Authenticator: - """ - Authenticator class handles the authentication process using either - device code flow or PKCE flow. - - Attributes: - client_id (str): The client ID for the authentication. - authentication_url (str): The URL for authentication. - audience (list[str]): The audience for the authentication. - token_manager (TokenManager): The TokenManager instance. - """ - - def __init__( - self, - authentication_type: AuthenticationType = AuthenticationType.DEVICE, - token_file_path: str = "token", - ) -> None: - load_dotenv() - if authentication_type == AuthenticationType.DEVICE: - self.client_id: str = os.getenv("DEVICE_CLIENT_ID", "") - self.authentication_url: str = os.getenv("DEVICE_AUTHENTICATION_URL", "") - self.audience: list[str] = os.getenv("DEVICE_AUDIENCES", "").split(" ") - else: - self.client_secret: str = os.getenv("PKCE_CLIENT_SECRET", "") - if self.client_secret == "": - raise Exception("Missing environment variables") - self.client_id = os.getenv("PKCE_CLIENT_ID", "") - self.authentication_url = os.getenv("PKCE_AUTHENTICATION_URL", "") - self.audience = os.getenv("PKCE_AUDIENCES", "").split(" ") - - self.token_url: str = os.getenv("TOKEN_URL", "") - self.openid_config: str = os.getenv("OPEN_ID_CONFIG", "") - self.issuer = os.getenv("ISSUER") - if any( - [ - self.client_id == "", - self.authentication_url == "", - self.audience == "", - self.token_url == "", - self.openid_config == "", - self.issuer == "", - ] - ): - raise Exception("Missing environment variables") - - self.token_manager = TokenManager( - client_id=self.client_id, - token_file_path=token_file_path, - ) - def get_device_code(self): response = requests.post( - self.token_url, + self.oauth.token_url, data={ - "client_id": self.client_id, + "client_id": self.cliAuth.client_id, "scope": "openid profile offline_access", - "audience": self.audience, + "audience": self.cliAuth.client_audience, }, ) response_data = response.json() @@ -198,12 +134,12 @@ def poll_for_token( too_late = time.time() + timeout while time.time() < too_late: response = requests.post( - self.token_url, + self.oauth.token_url, headers={"Content-Type": "application/x-www-form-urlencoded"}, data={ "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": device_code, - "client_id": self.client_id, + "client_id": self.cliAuth.client_id, }, ) if response.status_code == HTTPStatus.OK: @@ -215,21 +151,21 @@ def poll_for_token( raise TimeoutError("Polling timed out") def start_device_flow(self) -> None: - if self.token_manager.token: - valid_token, exception = self.token_manager.verify_token( - self.token_manager.token["access_token"] + if self.token: + valid_token, exception = self.authenticator.verify_token( + self.token["access_token"] ) if valid_token: print("Token verified") return elif isinstance(exception, jwt.ExpiredSignatureError): - if self.token_manager.refresh_auth_token(): + if self.refresh_auth_token(): return response = requests.post( - self.authentication_url, + self.oauth.device_auth_url, headers={"Content-Type": "application/x-www-form-urlencoded"}, - data={"client_id": self.client_id}, + data={"client_id": self.cliAuth.client_id}, ) if response.status_code == HTTPStatus.OK: @@ -242,17 +178,17 @@ def start_device_flow(self) -> None: auth_token_json = self.poll_for_token(device_code) if auth_token_json: print(auth_token_json) - verify, exception = TokenManager.verify_token( + verify, exception = self.authenticator.verify_token( auth_token_json["access_token"] ) if verify: print("Token verified") - self.token_manager.save_token(auth_token_json) + self.save_token(auth_token_json) else: print("Unauthorized access") return else: print("Unauthorized access") return - userName, fedid = TokenManager.userInfo(auth_token_json["access_token"]) + userName, fedid = self.authenticator.userInfo(auth_token_json["access_token"]) print(f"Logged in as {userName} with fed-id {fedid}")