diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 282b59dec9f..2599b34d708 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3709,6 +3709,8 @@ Additional sub-options for this setting include: Required if `enabled` is set to true. * `subject_claim`: Name of the claim containing a unique identifier for the user. Optional, defaults to `sub`. +* `display_name_claim`: Name of the claim containing the display name for the user. Optional. + If provided, the display name will be set to the value of this claim upon first login. * `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the "iss" claim will be required and validated for all JSON web tokens. * `audiences`: A list of audiences to validate the "aud" claim against. Optional. @@ -3723,6 +3725,7 @@ jwt_config: secret: "provided-by-your-issuer" algorithm: "provided-by-your-issuer" subject_claim: "name_of_claim" + display_name_claim: "name_of_claim" issuer: "provided-by-your-issuer" audiences: - "provided-by-your-issuer" diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py index b41f2dc08f3..5c76551f334 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt.py @@ -38,6 +38,7 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.jwt_algorithm = jwt_config["algorithm"] self.jwt_subject_claim = jwt_config.get("subject_claim", "sub") + self.jwt_display_name_claim = jwt_config.get("display_name_claim") # The issuer and audiences are optional, if provided, it is asserted # that the claims exist on the JWT. @@ -49,5 +50,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.jwt_secret = None self.jwt_algorithm = None self.jwt_subject_claim = None + self.jwt_display_name_claim = None self.jwt_issuer = None self.jwt_audiences = None diff --git a/synapse/handlers/jwt.py b/synapse/handlers/jwt.py index 5fa7a305add..f284d6324f7 100644 --- a/synapse/handlers/jwt.py +++ b/synapse/handlers/jwt.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple from authlib.jose import JsonWebToken, JWTClaims from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError @@ -36,11 +36,12 @@ def __init__(self, hs: "HomeServer"): self.jwt_secret = hs.config.jwt.jwt_secret self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim + self.jwt_display_name_claim = hs.config.jwt.jwt_display_name_claim self.jwt_algorithm = hs.config.jwt.jwt_algorithm self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_audiences = hs.config.jwt.jwt_audiences - def validate_login(self, login_submission: JsonDict) -> str: + def validate_login(self, login_submission: JsonDict) -> Tuple[str, Optional[str]]: """ Authenticates the user for the /login API @@ -49,7 +50,8 @@ def validate_login(self, login_submission: JsonDict) -> str: (including 'type' and other relevant fields) Returns: - The user ID that is logging in. + A tuple of (user_id, display_name) of the user that is logging in. + If the JWT does not contain a display name, the second element of the tuple will be None. Raises: LoginError if there was an authentication problem. @@ -109,4 +111,10 @@ def validate_login(self, login_submission: JsonDict) -> str: if user is None: raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) - return UserID(user, self.hs.hostname).to_string() + default_display_name = None + if self.jwt_display_name_claim: + display_name_claim = claims.get(self.jwt_display_name_claim, None) + if display_name_claim is not None: + default_display_name = display_name_claim + + return UserID(user, self.hs.hostname).to_string(), default_display_name diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 03b1e7edc49..3271b02d40e 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -363,6 +363,7 @@ async def _complete_login( login_submission: JsonDict, callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, create_non_existent_users: bool = False, + default_display_name: Optional[str] = None, ratelimit: bool = True, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, @@ -410,7 +411,8 @@ async def _complete_login( canonical_uid = await self.auth_handler.check_user_exists(user_id) if not canonical_uid: canonical_uid = await self.registration_handler.register_user( - localpart=UserID.from_string(user_id).localpart + localpart=UserID.from_string(user_id).localpart, + default_display_name=default_display_name, ) user_id = canonical_uid @@ -546,11 +548,14 @@ async def _do_jwt_login( Returns: The body of the JSON response. """ - user_id = self.hs.get_jwt_handler().validate_login(login_submission) + user_id, default_display_name = self.hs.get_jwt_handler().validate_login( + login_submission + ) return await self._complete_login( user_id, login_submission, create_non_existent_users=True, + default_display_name=default_display_name, should_issue_refresh_token=should_issue_refresh_token, request_info=request_info, )