from urllib.parse import urlencode

import jwt
import jwt.exceptions
from jupyterhub import orm
from jupyterhub.handlers import BaseHandler
from oauthenticator.generic import GenericOAuthenticator
from tornado import web


class JWTHandler(BaseHandler):
    """Handler that trades a bearer JWT for a JupyterHub API token.

    On ``GET`` the handler reads the caller's JWT, reuses the hub API token
    previously issued for that exact JWT when it is still usable, and
    otherwise logs the user in and issues a fresh API token.
    """

    async def exchange_for_refresh_token(self, access_token):
        """Ask the token endpoint to exchange ``access_token`` for a refresh token.

        A token-exchange request is POSTed to ``self.authenticator.token_url``
        using the authenticator's client id/secret as HTTP basic credentials.

        Returns:
            The value of the ``refresh_token`` field of the response when the
            field is present, otherwise the value of ``access_token`` (or
            ``None`` when that is missing too). ``None`` is also returned when
            the request fails or the response is not a JSON object; the
            exchange is best effort and never raises for those cases.
        """
        self.log.debug("Exchanging access token for refresh")
        http_client = AsyncHTTPClient()
        headers = {
            "Accept": "application/json",
            "User-Agent": "JupyterHub",
        }
        body = urlencode(
            dict(
                grant_type="urn:ietf:params:oauth:grant-type:token-exchange",
                requested_token_type="urn:ietf:params:oauth:token-type:refresh_token",
                subject_token_type="urn:ietf:params:oauth:token-type:access_token",
                subject_token=access_token,
                # beware that this requires the "offline_access" or similar
                # to be included, otherwise the refresh token will not be
                # released. Also the access token must have this scope.
                scope=" ".join(self.authenticator.scope),
            )
        )
        req = HTTPRequest(
            self.authenticator.token_url,
            auth_username=self.authenticator.client_id,
            auth_password=self.authenticator.client_secret,
            headers=headers,
            method="POST",
            body=body,
        )
        try:
            resp = await http_client.fetch(req)
        except HTTPClientError as e:
            self.log.warning(f"Unable to get refresh token: {e}")
            if e.response:
                self.log.debug(e.response.body)
            return None
        try:
            token_info = json.loads(resp.body.decode("utf8", "replace"))
        except ValueError as e:
            # A malformed body must not abort the login: the refresh token
            # is optional, so treat it like a failed exchange.
            self.log.warning(f"Unable to get refresh token: {e}")
            return None
        if not isinstance(token_info, dict):
            self.log.warning("Unable to get refresh token: unexpected response")
            return None
        if "refresh_token" in token_info:
            return token_info.get("refresh_token")
        # EOSC AAI returns the token into "access_token" field, so be it
        return token_info.get("access_token", None)

    async def _get_previous_hub_token(self, user, jwt_token):
        """Return the hub API token already issued for ``jwt_token``, if usable.

        The token is reused only when the user's stored auth state holds the
        very same JWT under ``access_token``, a ``jwt_api_token`` entry
        exists, and ``orm.APIToken.find`` still knows a token for it that has
        time left. In every other case ``None`` is returned.
        """
        if not user:
            return None
        auth_state = await user.get_auth_state()
        if not auth_state or auth_state.get("access_token", None) != jwt_token:
            return None
        api_token = auth_state.get("jwt_api_token", None)
        if api_token is None:
            return None
        orm_token = orm.APIToken.find(self.db, api_token)
        if not orm_token:
            return None
        remaining = orm_token.expires_in
        # ``expires_in`` is None for a token without expiry; comparing None
        # with 0 would raise TypeError, so only numeric values are checked.
        if remaining is not None and remaining <= 0:
            return None
        self.log.debug("Reusing previously available API token for this JWT")
        return api_token

    def _get_token(self):
        """Fetch the caller's JWT and decode its claims.

        Returns:
            A ``(jwt_token, decoded_token)`` tuple: the raw token string and
            the dict of claims.

        Raises:
            web.HTTPError: 401 when no token is supplied or when the token
                cannot be decoded (this includes an expired token, since
                ``verify_exp`` is on).
        """
        jwt_token = self.get_auth_token()
        if not jwt_token:
            self.log.debug("No token found in header")
            # web.HTTPError, as at every other raise site of this class.
            raise web.HTTPError(401)
        try:
            # The signature is deliberately NOT verified here; only the
            # structure and the expiry are checked.
            # NOTE(review): authenticity presumably comes from login_user()
            # or from the equality check against the stored token - confirm.
            decoded_token = jwt.decode(
                jwt_token,
                options=dict(verify_signature=False, verify_exp=True),
            )
        except jwt.exceptions.InvalidTokenError as e:
            self.log.debug(f"Invalid token {e}")
            raise web.HTTPError(401) from e
        return jwt_token, decoded_token

    async def get(self):
        """Reply with ``{"token": <hub API token>, "user": <user name>}``.

        Raises:
            web.HTTPError: 401 for a missing/invalid JWT (from
                ``_get_token``), 403 when ``login_user`` returns no user.
        """
        user = None
        jwt_token, decoded_token = self._get_token()
        try:
            username = self.authenticator.user_info_to_username(decoded_token)
            user = self.find_user(username)
        except ValueError as e:
            self.log.debug(f"Unable to get username from token: {e}")
        api_token = await self._get_previous_hub_token(user, jwt_token)
        if not api_token:
            self.log.debug("Authenticating user")
            token_info = {
                "access_token": jwt_token,
                "token_type": "bearer",
            }
            user = await self.login_user(token_info)
            if user is None:
                raise web.HTTPError(403, self.authenticator.custom_403_message)
            auth_state = await user.get_auth_state()
            if auth_state and not auth_state.get("refresh_token", None):
                self.log.debug("Refresh token is not available")
                refresh_token = await self.exchange_for_refresh_token(jwt_token)
                if refresh_token:
                    self.log.debug("Got refresh token from exchange")
                    auth_state["refresh_token"] = refresh_token

            # default: 1h token
            expires_in = 3600
            # NOTE(review): exp - iat is the JWT's total lifetime, not the
            # time it has left - confirm this is the intended hub token TTL.
            if "exp" in decoded_token and "iat" in decoded_token:
                expires_in = decoded_token["exp"] - decoded_token["iat"]

            # Possible optimisation here: instead of creating a new token every time,
            # go through user.api_tokens and get one from there
            api_token = user.new_api_token(
                note="JWT auth token",
                expires_in=expires_in,
                # TODO: this may be tuned, but should be a post
                # call with a body specifying the roles and scopes
                # roles=token_roles,
                # scopes=token_scopes,
            )
            if auth_state is not None:
                # Remember the issued token so that the next request carrying
                # the same JWT can reuse it.
                auth_state["jwt_api_token"] = api_token
                await user.save_auth_state(auth_state)
            else:
                # Without an auth state there is nowhere to keep the token;
                # subscripting None here used to fail the whole request.
                self.log.debug("No auth state available, API token not cached")
        self.finish({"token": api_token, "user": user.name})
If the regular expression contains a group and matches, it will be diff --git a/egi_notebooks_hub/services/api_wrapper.py b/egi_notebooks_hub/services/api_wrapper.py index 15ed9eb..f8cfeaa 100644 --- a/egi_notebooks_hub/services/api_wrapper.py +++ b/egi_notebooks_hub/services/api_wrapper.py @@ -58,7 +58,11 @@ async def api_wrapper(request: Request, svc_path: str): status_code=exc.response.status_code, detail=exc.response.text ) content = await request.body() - api_path = svc_path.removeprefix(settings.jupyterhub_service_prefix) + api_path = ( + svc_path.removeprefix(settings.jupyterhub_service_prefix.rstrip("/")) + if svc_path + else "" + ) async with httpx.AsyncClient() as client: # which headers do we need to preserve? headers = dict(request.headers)