Skip to content

Commit

Permalink
Refactor authentication in Blueapi RestClient and service
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Oct 14, 2024
1 parent 5a638c0 commit 3a6db83
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 119 deletions.
16 changes: 12 additions & 4 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from blueapi.client.rest import BlueskyRemoteControlError
from blueapi.config import ApplicationConfig, ConfigLoader
from blueapi.core import DataEvent
from blueapi.service.authentication import Authenticator
from blueapi.service.authentication import TokenManager
from blueapi.service.main import start
from blueapi.service.openapi import (
DOCS_SCHEMA_LOCATION,
Expand Down Expand Up @@ -333,6 +333,14 @@ def scratch(obj: dict) -> None:


@main.command(name="login")
def login() -> None:
auth = Authenticator()
auth.start_device_flow()
@click.pass_obj
def login(obj: dict) -> None:
config: ApplicationConfig = obj["config"]
print(config)
if config.cliAuth is not None and config.oauth is not None:
print("Logging in")
print(config)
auth = TokenManager(cliAuth=config.cliAuth, oauth=config.oauth)
auth.start_device_flow()
else:
print("Please provide configuration to login!")
69 changes: 68 additions & 1 deletion src/blueapi/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import os
from collections.abc import Mapping
from enum import Enum
from pathlib import Path
from typing import Any, Generic, Literal, TypeVar

import requests
import yaml
from bluesky_stomp.models import BasicAuthentication
from pydantic import BaseModel, Field, TypeAdapter, ValidationError
from pydantic import (
BaseModel,
Field,
Secret,
TypeAdapter,
ValidationError,
field_validator,
)

from blueapi.utils import BlueapiBaseModel, InvalidConfigError

Expand Down Expand Up @@ -77,6 +86,61 @@ class ScratchConfig(BlueapiBaseModel):
repositories: list[ScratchRepository] = Field(default_factory=list)


class OauthConfig(BlueapiBaseModel):
oidc_config_url: str = Field(
description="URL to fetch OIDC config from the provider"
)
# Initialized post-init
device_auth_url: str = ""
pkce_auth_url: str = ""
token_url: str = ""
issuer: str = ""
jwks_uri: str = ""

def model_post_init(self, __context: Any) -> None:
response = requests.get(self.oidc_config_url)
response.raise_for_status()
config_data = response.json()

self.device_auth_url = config_data.get("device_authorization_endpoint")
self.pkce_auth_url = config_data.get("authorization_endpoint")
self.token_url = config_data.get("token_endpoint")
self.issuer = config_data.get("issuer")
self.jwks_uri = config_data.get("jwks_uri")
# post this we need to check if all the values are present
if any(
(
self.device_auth_url == "",
self.pkce_auth_url == "",
self.token_url == "",
self.issuer == "",
self.jwks_uri == "",
)
):
raise ValueError("OIDC config is missing required fields")


class SwaggerAuthConfig(BlueapiBaseModel):
client_id: str = Field(description="Client ID for PKCE client")
client_secret: Secret[str] = Field(
description="Password to verify PKCE client's identity"
)
client_audience: str

@field_validator("client_secret", mode="before")
@classmethod
def get_from_env(cls, v: str):
if v.startswith("${") and v.endswith("}"):
return os.environ[v.removeprefix("${").removesuffix("}").upper()]
return v


class CLIAuthConfig(BlueapiBaseModel):
client_id: str = Field(description="Client ID for CLI client")
client_audience: str = Field(description="Audience for CLI client")
token_file_path: str = "~/token"


class ApplicationConfig(BlueapiBaseModel):
"""
Config for the worker application as a whole. Root of
Expand All @@ -88,6 +152,9 @@ class ApplicationConfig(BlueapiBaseModel):
logging: LoggingConfig = Field(default_factory=LoggingConfig)
api: RestConfig = Field(default_factory=RestConfig)
scratch: ScratchConfig | None = None
oauth: OauthConfig | None = None
cliAuth: CLIAuthConfig | None = None
swaggerAuth: SwaggerAuthConfig | None = None

def __eq__(self, other: object) -> bool:
if isinstance(other, ApplicationConfig):
Expand Down
164 changes: 50 additions & 114 deletions src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,33 @@

import jwt
import requests
from dotenv import load_dotenv

from blueapi.config import (
CLIAuthConfig,
OauthConfig,
SwaggerAuthConfig,
)


class AuthenticationType(Enum):
DEVICE = "device"
PKCE = "pkce"


class TokenManager:
"""
TokenManager class handles the token verification and refreshing.
Attributes:
client_id (str): The client ID for the authentication.
token_url (str): The URL to obtain the token.
audience (list[str]): The audience for the authentication.
issuer (str): The issuer of the token.
jwks_client (jwt.PyJWKClient): The JWKS client for verifying tokens.
token_file_path (str): The file path to save the token.
token (None | dict[str, Any]): The token dictionary.
"""

# Will move this to a computed field in ApplicationConfig
# Get the OpenID Connect configuration and configure the JWKS client
oidc_config = requests.get(
"https://authn.diamond.ac.uk/realms/master/.well-known/openid-configuration"
).json()
jwks_uri = oidc_config["jwks_uri"]
issuer = oidc_config["issuer"]
token_url = oidc_config["token_endpoint"]
jwks_client = jwt.PyJWKClient(jwks_uri)
audience = "blueapi"

class Authenticator:
def __init__(
self,
client_id: str,
token_file_path: str = "token",
) -> None:
self.client_id = client_id
self.token_file_path = token_file_path
self.token: None | dict[str, Any] = None
self.load_token()
oauth: OauthConfig,
authentorConfig: CLIAuthConfig | SwaggerAuthConfig,
):
self.oauth: OauthConfig = oauth
self.authentorConfig: CLIAuthConfig | SwaggerAuthConfig = authentorConfig

@classmethod
def verify_token(
cls, token: str, verify_expiration: bool = True
self, token: str, verify_expiration: bool = True
) -> tuple[bool, Exception | None]:
try:
decode = cls.decode_jwt(token, verify_expiration)
decode = self.decode_jwt(token, verify_expiration)
if decode:
return (True, None)
except jwt.PyJWTError as e:
Expand All @@ -65,39 +43,48 @@ def verify_token(

return (False, Exception("Invalid token"))

@classmethod
def decode_jwt(cls, token: str, verify_expiration: bool = True):
signing_key = cls.jwks_client.get_signing_key_from_jwt(token)
def decode_jwt(self, token: str, verify_expiration: bool = True):
signing_key = jwt.PyJWKClient(self.oauth.jwks_uri).get_signing_key_from_jwt(
token
)
decode = jwt.decode(
token,
signing_key.key,
algorithms=["RS256"],
options={"verify_exp": verify_expiration},
verify=True,
audience=cls.audience,
issuer=cls.issuer,
audience=self.authentorConfig.client_audience,
issuer=self.oauth.issuer,
leeway=5,
)
return decode

@classmethod
def userInfo(cls, token: str) -> tuple[str | None, str | None]:
def userInfo(self, token: str) -> tuple[str | None, str | None]:
try:
decode = cls.decode_jwt(token)
decode = self.decode_jwt(token)
if decode:
return (decode["name"], decode["fedid"])
else:
return (None, None)
except jwt.PyJWTError as _:
return (None, None)


class TokenManager:
def __init__(self, oauth: OauthConfig, cliAuth: CLIAuthConfig) -> None:
self.oauth = oauth
self.cliAuth = cliAuth
self.token = None
self.authenticator = Authenticator(self.oauth, self.cliAuth)
self.load_token()

def refresh_auth_token(self) -> bool:
if self.token:
response = requests.post(
self.token_url,
self.oauth.token_url,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": self.client_id,
"client_id": self.cliAuth.client_id,
"grant_type": "refresh_token",
"refresh_token": self.token["refresh_token"],
},
Expand All @@ -114,76 +101,25 @@ def save_token(self, token: dict[str, Any]) -> None:
token_json = json.dumps(token)
token_bytes = token_json.encode("utf-8")
token_base64 = base64.b64encode(token_bytes)
with open(self.token_file_path, "wb") as token_file:
with open(os.path.expanduser(self.cliAuth.token_file_path), "wb") as token_file:
token_file.write(token_base64)

def load_token(self) -> None:
if not os.path.exists(self.token_file_path):
if not os.path.exists(self.cliAuth.token_file_path):
return None
with open(self.token_file_path, "rb") as token_file:
with open(os.path.expanduser(self.cliAuth.token_file_path), "rb") as token_file:
token_base64 = token_file.read()
token_bytes = base64.b64decode(token_base64)
token_json = token_bytes.decode("utf-8")
self.token = json.loads(token_json)


class Authenticator:
"""
Authenticator class handles the authentication process using either
device code flow or PKCE flow.
Attributes:
client_id (str): The client ID for the authentication.
authentication_url (str): The URL for authentication.
audience (list[str]): The audience for the authentication.
token_manager (TokenManager): The TokenManager instance.
"""

def __init__(
self,
authentication_type: AuthenticationType = AuthenticationType.DEVICE,
token_file_path: str = "token",
) -> None:
load_dotenv()
if authentication_type == AuthenticationType.DEVICE:
self.client_id: str = os.getenv("DEVICE_CLIENT_ID", "")
self.authentication_url: str = os.getenv("DEVICE_AUTHENTICATION_URL", "")
self.audience: list[str] = os.getenv("DEVICE_AUDIENCES", "").split(" ")
else:
self.client_secret: str = os.getenv("PKCE_CLIENT_SECRET", "")
if self.client_secret == "":
raise Exception("Missing environment variables")
self.client_id = os.getenv("PKCE_CLIENT_ID", "")
self.authentication_url = os.getenv("PKCE_AUTHENTICATION_URL", "")
self.audience = os.getenv("PKCE_AUDIENCES", "").split(" ")

self.token_url: str = os.getenv("TOKEN_URL", "")
self.openid_config: str = os.getenv("OPEN_ID_CONFIG", "")
self.issuer = os.getenv("ISSUER")
if any(
[
self.client_id == "",
self.authentication_url == "",
self.audience == "",
self.token_url == "",
self.openid_config == "",
self.issuer == "",
]
):
raise Exception("Missing environment variables")

self.token_manager = TokenManager(
client_id=self.client_id,
token_file_path=token_file_path,
)

def get_device_code(self):
response = requests.post(
self.token_url,
self.oauth.token_url,
data={
"client_id": self.client_id,
"client_id": self.cliAuth.client_id,
"scope": "openid profile offline_access",
"audience": self.audience,
"audience": self.cliAuth.client_audience,
},
)
response_data = response.json()
Expand All @@ -198,12 +134,12 @@ def poll_for_token(
too_late = time.time() + timeout
while time.time() < too_late:
response = requests.post(
self.token_url,
self.oauth.token_url,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code,
"client_id": self.client_id,
"client_id": self.cliAuth.client_id,
},
)
if response.status_code == HTTPStatus.OK:
Expand All @@ -215,21 +151,21 @@ def poll_for_token(
raise TimeoutError("Polling timed out")

def start_device_flow(self) -> None:
if self.token_manager.token:
valid_token, exception = self.token_manager.verify_token(
self.token_manager.token["access_token"]
if self.token:
valid_token, exception = self.authenticator.verify_token(
self.token["access_token"]
)
if valid_token:
print("Token verified")
return
elif isinstance(exception, jwt.ExpiredSignatureError):
if self.token_manager.refresh_auth_token():
if self.refresh_auth_token():
return

response = requests.post(
self.authentication_url,
self.oauth.device_auth_url,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={"client_id": self.client_id},
data={"client_id": self.cliAuth.client_id},
)

if response.status_code == HTTPStatus.OK:
Expand All @@ -242,17 +178,17 @@ def start_device_flow(self) -> None:
auth_token_json = self.poll_for_token(device_code)
if auth_token_json:
print(auth_token_json)
verify, exception = TokenManager.verify_token(
verify, exception = self.authenticator.verify_token(
auth_token_json["access_token"]
)
if verify:
print("Token verified")
self.token_manager.save_token(auth_token_json)
self.save_token(auth_token_json)
else:
print("Unauthorized access")
return
else:
print("Unauthorized access")
return
userName, fedid = TokenManager.userInfo(auth_token_json["access_token"])
userName, fedid = self.authenticator.userInfo(auth_token_json["access_token"])
print(f"Logged in as {userName} with fed-id {fedid}")

0 comments on commit 3a6db83

Please sign in to comment.