Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Token exchange to get refresh token #123

Merged
merged 13 commits into from
Jul 18, 2024
155 changes: 113 additions & 42 deletions egi_notebooks_hub/egiauthenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from urllib.parse import urlencode

import jwt
import jwt.exceptions
from jupyterhub import orm
from jupyterhub.handlers import BaseHandler
from oauthenticator.generic import GenericOAuthenticator
from tornado import web
Expand All @@ -18,49 +20,118 @@


class JWTHandler(BaseHandler):
async def get(self):
auth_header = self.request.headers.get("Authorization", "")
if auth_header:
try:
bearer, token = auth_header.split()
if bearer.lower() != "bearer":
self.log.debug("Unexpected authorization header format")
raise HTTPError(401)
except ValueError:
self.log.debug("Unexpected authorization header format")
raise HTTPError(401)
else:
self.log.debug("No authorization header")
raise HTTPError(401)
token_info = {
"access_token": token,
"token_type": "bearer",
async def exchange_for_refresh_token(self, access_token):
self.log.debug("Exchanging access token for refresh")
http_client = AsyncHTTPClient()
headers = {
"Accept": "application/json",
"User-Agent": "JupyterHub",
}
user = await self.login_user(token_info)
if user is None:
raise web.HTTPError(403, self.authenticator.custom_403_message)
auth_state = await user.get_auth_state()
if auth_state and "refresh_token" not in auth_state:
# TODO: decide how to deal with the refresh token
self.log.debug("Refresh token is not there...")

# extract from the jwt token (without verification!)
decoded_token = jwt.decode(token, options={"verify_signature": False})
# default: 1h token
expires_in = 3600
if "exp" in decoded_token and "iat" in decoded_token:
expires_in = decoded_token["exp"] - decoded_token["iat"]

# Possible optimisation here: instead of creating a new token every time,
# go through user.api_tokens and get one from there
api_token = user.new_api_token(
note="JWT auth token",
expires_in=expires_in,
# TODO: this may be tuned, but should be a post
# call with a body specifying the roles and scopes
# roles=token_roles,
# scopes=token_scopes,
body = urlencode(
dict(
grant_type="urn:ietf:params:oauth:grant-type:token-exchange",
requested_token_type="urn:ietf:params:oauth:token-type:refresh_token",
subject_token_type="urn:ietf:params:oauth:token-type:access_token",
subject_token=access_token,
# beware that this requires the "offline_access" or similar
# to be included, otherwise the refresh token will not be
# released. Also the access token must have this scope.
scope=" ".join(self.authenticator.scope),
)
)
req = HTTPRequest(
self.authenticator.token_url,
auth_username=self.authenticator.client_id,
auth_password=self.authenticator.client_secret,
headers=headers,
method="POST",
body=body,
)
try:
resp = await http_client.fetch(req)
except HTTPClientError as e:
self.log.warning(f"Unable to get refresh token: {e}")
if e.response:
self.log.debug(e.response.body)
return None
token_info = json.loads(resp.body.decode("utf8", "replace"))
if "refresh_token" in token_info:
return token_info.get("refresh_token")
# EOSC AAI returns the token into "access_token" field, so be it
return token_info.get("access_token", None)

async def _get_previous_hub_token(self, user, jwt_token):
if not user:
return None
auth_state = await user.get_auth_state()
if auth_state and auth_state.get("access_token", None) == jwt_token:
api_token = auth_state.get("jwt_api_token", None)
if api_token is None:
return None
orm_token = orm.APIToken.find(self.db, api_token)
if not orm_token or orm_token.expires_in <= 0:
return None
self.log.debug("Reusing previously available API token for this JWT")
return api_token

def _get_token(self):
jwt_token = self.get_auth_token()
if not jwt_token:
self.log.debug("No token found in header")
raise HTTPError(401)
try:
decoded_token = jwt.decode(
jwt_token,
options=dict(verify_signature=False, verify_exp=True),
)
except jwt.exceptions.InvalidTokenError as e:
self.log.debug(f"Invalid token {e}")
raise web.HTTPError(401)
return jwt_token, decoded_token

async def get(self):
user = None
jwt_token, decoded_token = self._get_token()
try:
username = self.authenticator.user_info_to_username(decoded_token)
user = self.find_user(username)
except ValueError as e:
self.log.debug(f"Unable to get username from token: {e}")
api_token = await self._get_previous_hub_token(user, jwt_token)
if not api_token:
self.log.debug("Authenticating user")
token_info = {
"access_token": jwt_token,
"token_type": "bearer",
}
user = await self.login_user(token_info)
if user is None:
raise web.HTTPError(403, self.authenticator.custom_403_message)
auth_state = await user.get_auth_state()
if auth_state and not auth_state.get("refresh_token", None):
self.log.debug("Refresh token is not available")
refresh_token = await self.exchange_for_refresh_token(jwt_token)
if refresh_token:
self.log.debug("Got refresh token from exchange")
auth_state["refresh_token"] = refresh_token

# default: 1h token
expires_in = 3600
if "exp" in decoded_token and "iat" in decoded_token:
expires_in = decoded_token["exp"] - decoded_token["iat"]

# Possible optimisation here: instead of creating a new token every time,
# go through user.api_tokens and get one from there
api_token = user.new_api_token(
note="JWT auth token",
expires_in=expires_in,
# TODO: this may be tuned, but should be a post
# call with a body specifying the roles and scopes
# roles=token_roles,
# scopes=token_scopes,
)
auth_state["jwt_api_token"] = api_token
await user.save_auth_state(auth_state)
self.finish({"token": api_token, "user": user.name})


Expand Down Expand Up @@ -301,7 +372,7 @@ class EOSCNodeAuthenticator(EGICheckinAuthenticator):
login_service = "EOSC AAI"

personal_project_re = Unicode(
r"^urn:geant:eosc-federation.eu:group:pp:Personal%20Project%20Name-(.*)$",
r"^urn:geant:eosc-federation.eu:group:(pp-.*)$",
config=True,
help="""Regular expression to match the personal groups.
If the regular expression contains a group and matches, it will be
Expand Down
6 changes: 5 additions & 1 deletion egi_notebooks_hub/services/api_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ async def api_wrapper(request: Request, svc_path: str):
status_code=exc.response.status_code, detail=exc.response.text
)
content = await request.body()
api_path = svc_path.removeprefix(settings.jupyterhub_service_prefix)
api_path = (
svc_path.removeprefix(settings.jupyterhub_service_prefix.rstrip("/"))
if svc_path
else ""
)
async with httpx.AsyncClient() as client:
# which headers do we need to preserve?
headers = dict(request.headers)
Expand Down