diff --git a/docs/quickstart.rst b/docs/quickstart.rst index ab50736..72e0fdd 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -163,7 +163,8 @@ The name of the claim that is used for the ``User.username`` property can be configured via the admin (**Username claim**). By default, the username is derived from the ``sub`` claim that is returned by the OIDC provider. -If the desired claim is nested in one or more objects, its path can be specified with dots, e.g.: +If the desired claim is nested in one or more objects, you can specify the segments +of the path: .. code-block:: json @@ -175,10 +176,13 @@ If the desired claim is nested in one or more objects, its path can be specified } } -Can be retrieved by setting the username claim to ``some.nested.claim`` +Can be retrieved by setting the username claim (array field) to: -.. note:: - The username claim does not support claims that have dots in their name, it cannot be configured to retrieve the following claim for instance: +- some +- nested +- claim + +If the claim has dots in it, you can specify those in a segment: .. code-block:: json @@ -186,6 +190,10 @@ Can be retrieved by setting the username claim to ``some.nested.claim`` "some.dotted.claim": "foo" } +can be retrieved with: + +- some.dotted.claim + User profile ------------ @@ -254,4 +262,4 @@ and ``OIDCAuthenticationBackend.config_class`` to be this new class. .. _mozilla-django-oidc settings documentation: https://mozilla-django-oidc.readthedocs.io/en/stable/settings.html -.. _OIDC spec: https://openid.net/specs/openid-connect-discovery-1_0.html#WellKnownRegistry \ No newline at end of file +.. _OIDC spec: https://openid.net/specs/openid-connect-discovery-1_0.html#WellKnownRegistry diff --git a/mozilla_django_oidc_db/backends.py b/mozilla_django_oidc_db/backends.py index 3744eb3..3e40537 100644 --- a/mozilla_django_oidc_db/backends.py +++ b/mozilla_django_oidc_db/backends.py @@ -7,7 +7,7 @@ from django.core.exceptions import ObjectDoesNotExist import requests -from glom import glom +from glom import Path, glom from mozilla_django_oidc.auth import ( OIDCAuthenticationBackend as _OIDCAuthenticationBackend, ) @@ -22,6 +22,12 @@ T = TypeVar("T", bound=OpenIDConnectConfig) +class MissingIdentifierClaim(Exception): + def __init__(self, claim_bits: list[str], *args, **kwargs): + self.claim_bits = claim_bits + super().__init__(*args, **kwargs) + + class OIDCAuthenticationBackend( GetAttributeMixin, SoloConfigMixin[T], _OIDCAuthenticationBackend ): @@ -31,7 +37,7 @@ class OIDCAuthenticationBackend( """ config_identifier_field = "username_claim" - sensitive_claim_names = [] + sensitive_claim_names: list[list[str]] = [] def __init__(self, *args, **kwargs): # django-stubs returns AbstractBaseUser, but we depend on properties of @@ -46,27 +52,26 @@ def __init__(self, *args, **kwargs): # to avoid a large number of `OpenIDConnectConfig.get_solo` calls when # `OIDCAuthenticationBackend.__init__` is called for permission checks - def retrieve_identifier_claim(self, claims: dict) -> str: - # NOTE: this does not support the extraction of claims that contain dots "." in - # their name (e.g. {"foo.bar": "baz"}) - identifier_claim_name = getattr(self.config, self.config_identifier_field) - unique_id = glom(claims, identifier_claim_name, default="") + def retrieve_identifier_claim( + self, claims: dict, raise_on_empty: bool = False + ) -> str: + claim_bits = getattr(self.config, self.config_identifier_field) + unique_id = glom(claims, Path(*claim_bits), default="") + if raise_on_empty and not unique_id: + raise MissingIdentifierClaim(claim_bits=claim_bits) return unique_id - def get_sensitive_claims_names(self) -> list: + def get_sensitive_claims_names(self) -> list[list[str]]: """ Defines the claims that should be obfuscated before logging claims. - Nested claims can be specified by using a dotted path (e.g. "foo.bar.baz") - NOTE: this does not support claim names that have dots in them, so the following - claim cannot be marked as a sensitive claim - - { - "foo.bar": "baz" - } + Nested claims are represented with a path of bits (e.g. ["foo", "bar", "baz"]). + Claims with dots in them are supported, e.g. ["foo.bar"]. """ - identifier_claim_name = getattr(self.config, self.config_identifier_field) - return [identifier_claim_name] + self.sensitive_claim_names + identifier_claim_bits: list[str] = getattr( + self.config, self.config_identifier_field + ) + return [identifier_claim_bits] + self.sensitive_claim_names def get_userinfo(self, access_token, id_token, payload): """ @@ -132,8 +137,8 @@ def get_user_instance_values(self, claims) -> dict[str, Any]: Map the names and values of the claims to the fields of the User model """ return { - model_field: glom(claims, claims_field, default="") - for model_field, claims_field in self.config.claim_mapping.items() + model_field: glom(claims, Path(*claim_bits), default="") + for model_field, claim_bits in self.config.claim_mapping.items() } def create_user(self, claims): @@ -169,11 +174,14 @@ def verify_claims(self, claims) -> bool: logger.debug("OIDC claims received: %s", obfuscated_claims) - identifier_claim_name = getattr(self.config, self.config_identifier_field) - if not glom(claims, identifier_claim_name, default=""): + # check if we have an identifier + try: + self.retrieve_identifier_claim(claims, raise_on_empty=True) + except MissingIdentifierClaim as exc: logger.error( - "%s not in OIDC claims, cannot proceed with authentication", - identifier_claim_name, + "'%s' not in OIDC claims, cannot proceed with authentication", + " > ".join(exc.claim_bits), + exc_info=exc, ) return False return True @@ -199,76 +207,79 @@ def update_user(self, user, claims): return user + def _retrieve_groups_claim(self, claims: dict[str, Any]) -> list[str]: + groups_claim_bits = self.config.groups_claim + return glom(claims, Path(*groups_claim_bits), default=[]) + def update_user_superuser_status(self, user, claims) -> None: """ Assigns superuser status to the user if the user is a member of at least one specific group. Superuser status is explicitly removed if the user is not or no longer member of at least one of these groups. """ - groups_claim = self.config.groups_claim # can't do an isinstance check here superuser_group_names = cast(list[str], self.config.superuser_group_names) if not superuser_group_names: return - claim_groups = glom(claims, groups_claim, default=[]) + claim_groups = self._retrieve_groups_claim(claims) if set(superuser_group_names) & set(claim_groups): user.is_superuser = True else: user.is_superuser = False user.save() - def update_user_groups(self, user, claims): + def update_user_groups(self, user, claims) -> None: """ Updates user group memberships based on the group_claim setting. Copied and modified from: https://github.com/snok/django-auth-adfs/blob/master/django_auth_adfs/backend.py """ - groups_claim = self.config.groups_claim - - if groups_claim: - # Update the user's group memberships - django_groups = [group.name for group in user.groups.all()] - claim_groups = glom(claims, groups_claim, default=[]) - if claim_groups: - if not isinstance(claim_groups, list): - claim_groups = [ + group_claim_bits: list[str] = self.config.groups_claim + if not group_claim_bits: + return + + claim_groups = self._retrieve_groups_claim(claims) + + # Update the user's group memberships + django_groups = [group.name for group in user.groups.all()] + if claim_groups: + if not isinstance(claim_groups, list): + claim_groups = [ + claim_groups, + ] + else: + logger.debug( + "The configured groups claim '%s' was not found in the access token", + " > ".join(group_claim_bits), + ) + claim_groups = [] + if sorted(claim_groups) != sorted(django_groups): + existing_groups = list( + Group.objects.filter(name__in=claim_groups).iterator() + ) + existing_group_names = frozenset(group.name for group in existing_groups) + new_groups = [] + if self.config.sync_groups: + # Only sync groups that match the supplied glob pattern + new_groups = [ + Group.objects.get_or_create(name=name)[0] + for name in fnmatch.filter( claim_groups, - ] + self.config.sync_groups_glob_pattern, + ) + if name not in existing_group_names + ] else: - logger.debug( - "The configured groups claim '%s' was not found in the access token", - groups_claim, - ) - claim_groups = [] - if sorted(claim_groups) != sorted(django_groups): - existing_groups = list( - Group.objects.filter(name__in=claim_groups).iterator() - ) - existing_group_names = frozenset( - group.name for group in existing_groups - ) - new_groups = [] - if self.config.sync_groups: - # Only sync groups that match the supplied glob pattern - new_groups = [ - Group.objects.get_or_create(name=name)[0] - for name in fnmatch.filter( - claim_groups, - self.config.sync_groups_glob_pattern, - ) - if name not in existing_group_names - ] - else: - for name in claim_groups: - if name not in existing_group_names: - try: - group = Group.objects.get(name=name) - new_groups.append(group) - except ObjectDoesNotExist: - pass - user.groups.set(existing_groups + new_groups) + for name in claim_groups: + if name not in existing_group_names: + try: + group = Group.objects.get(name=name) + new_groups.append(group) + except ObjectDoesNotExist: + pass + user.groups.set(existing_groups + new_groups) def update_user_default_groups(self, user): """ diff --git a/mozilla_django_oidc_db/fields.py b/mozilla_django_oidc_db/fields.py new file mode 100644 index 0000000..d4e0e86 --- /dev/null +++ b/mozilla_django_oidc_db/fields.py @@ -0,0 +1,16 @@ +from django.db import models +from django.utils.translation import gettext_lazy as _ + +from django_jsonform.models.fields import ArrayField + + +class ClaimField(ArrayField): + """ + A field to store a path to claims holding the desired value(s). + + Each item is a segment in the path from the root to leaf for nested claims. + """ + + def __init__(self, *args, **kwargs): + kwargs["base_field"] = models.CharField(_("claim path segment"), max_length=50) + super().__init__(*args, **kwargs) diff --git a/mozilla_django_oidc_db/migrations/0001_initial_to_v015.py b/mozilla_django_oidc_db/migrations/0001_initial_to_v015.py new file mode 100644 index 0000000..5413e21 --- /dev/null +++ b/mozilla_django_oidc_db/migrations/0001_initial_to_v015.py @@ -0,0 +1,291 @@ +# Generated by Django 4.2.9 on 2024-05-01 15:32 + +from django.db import migrations, models + +import django_jsonform.models.fields + +import mozilla_django_oidc_db.models + + +class Migration(migrations.Migration): + + replaces = [ + ("mozilla_django_oidc_db", "0001_initial"), + ( + "mozilla_django_oidc_db", + "0002_openidconnectconfig_oidc_op_discovery_endpoint", + ), + ("mozilla_django_oidc_db", "0003_auto_20210719_0803"), + ("mozilla_django_oidc_db", "0004_auto_20210812_1044"), + ("mozilla_django_oidc_db", "0005_openidconnectconfig_sync_groups_glob_pattern"), + ("mozilla_django_oidc_db", "0006_openidconnectconfig_unique_id_claim"), + ("mozilla_django_oidc_db", "0007_auto_20220307_1128"), + ("mozilla_django_oidc_db", "0008_auto_20220422_0849"), + ("mozilla_django_oidc_db", "0009_openidconnectconfig_default_groups"), + ("mozilla_django_oidc_db", "0010_openidconnectconfig_userinfo_claims_source"), + ( + "mozilla_django_oidc_db", + "0011_alter_openidconnectconfig_userinfo_claims_source", + ), + ("mozilla_django_oidc_db", "0012_openidconnectconfig_superuser_group_names"), + ("mozilla_django_oidc_db", "0012_alter_openidconnectconfig_sync_groups"), + ("mozilla_django_oidc_db", "0013_merge_20231221_1529"), + ("mozilla_django_oidc_db", "0014_alter_openidconnectconfig_groups_claim"), + ( + "mozilla_django_oidc_db", + "0015_openidconnectconfig_oidc_token_use_basic_auth", + ), + ] + + dependencies = [ + ("auth", "0001_initial"), + ] + + operations = [ + migrations.CreateModel( + name="OpenIDConnectConfig", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "enabled", + models.BooleanField( + default=False, + help_text="Indicates whether OpenID Connect for authentication/authorization is enabled", + verbose_name="enable", + ), + ), + ( + "oidc_rp_client_id", + models.CharField( + help_text="OpenID Connect client ID provided by the OIDC Provider", + max_length=1000, + verbose_name="OpenID Connect client ID", + ), + ), + ( + "oidc_rp_client_secret", + models.CharField( + help_text="OpenID Connect secret provided by the OIDC Provider", + max_length=1000, + verbose_name="OpenID Connect secret", + ), + ), + ( + "oidc_rp_sign_algo", + models.CharField( + default="HS256", + help_text="Algorithm the Identity Provider uses to sign ID tokens", + max_length=50, + verbose_name="OpenID sign algorithm", + ), + ), + ( + "oidc_rp_scopes_list", + django_jsonform.models.fields.ArrayField( + base_field=models.CharField( + max_length=50, verbose_name="OpenID Connect scope" + ), + blank=True, + default=mozilla_django_oidc_db.models.get_default_scopes, + help_text="OpenID Connect scopes that are requested during login", + size=None, + verbose_name="OpenID Connect scopes", + ), + ), + ( + "oidc_op_jwks_endpoint", + models.URLField( + blank=True, + help_text="URL of your OpenID Connect provider JSON Web Key Set endpoint. Required if `RS256` is used as signing algorithm.", + max_length=1000, + verbose_name="JSON Web Key Set endpoint", + ), + ), + ( + "oidc_op_authorization_endpoint", + models.URLField( + help_text="URL of your OpenID Connect provider authorization endpoint", + max_length=1000, + verbose_name="Authorization endpoint", + ), + ), + ( + "oidc_op_token_endpoint", + models.URLField( + help_text="URL of your OpenID Connect provider token endpoint", + max_length=1000, + verbose_name="Token endpoint", + ), + ), + ( + "oidc_op_user_endpoint", + models.URLField( + help_text="URL of your OpenID Connect provider userinfo endpoint", + max_length=1000, + verbose_name="User endpoint", + ), + ), + ( + "oidc_rp_idp_sign_key", + models.CharField( + blank=True, + help_text="Key the Identity Provider uses to sign ID tokens in the case of an RSA sign algorithm. Should be the signing key in PEM or DER format.", + max_length=1000, + verbose_name="Sign key", + ), + ), + ( + "oidc_op_discovery_endpoint", + models.URLField( + blank=True, + help_text="URL of your OpenID Connect provider discovery endpoint ending with a slash (`.well-known/...` will be added automatically). If this is provided, the remaining endpoints can be omitted, as they will be derived from this endpoint.", + max_length=1000, + verbose_name="Discovery endpoint", + ), + ), + ( + "claim_mapping", + models.JSONField( + default=mozilla_django_oidc_db.models.get_claim_mapping, + help_text="Mapping from user-model fields to OIDC claims", + verbose_name="claim mapping", + ), + ), + ( + "groups_claim", + models.CharField( + blank=True, + default="roles", + help_text="The name of the OIDC claim that holds the values to map to local user groups.", + max_length=50, + verbose_name="groups claim", + ), + ), + ( + "make_users_staff", + models.BooleanField( + default=False, + help_text="Users will be flagged as being a staff user automatically. This allows users to login to the admin interface. By default they have no permissions, even if they are staff.", + verbose_name="make users staff", + ), + ), + ( + "sync_groups", + models.BooleanField( + default=True, + help_text="If checked, local user groups will be created for group names present in the groups claim, if they do not exist yet locally.", + verbose_name="Create local user groups if they do not exist yet", + ), + ), + ( + "sync_groups_glob_pattern", + models.CharField( + default="*", + help_text="The glob pattern that groups must match to be synchronized to the local database.", + max_length=255, + verbose_name="groups glob pattern", + ), + ), + ( + "username_claim", + models.CharField( + default="sub", + help_text="The name of the OIDC claim that is used as the username", + max_length=50, + verbose_name="username claim", + ), + ), + ( + "oidc_exempt_urls", + django_jsonform.models.fields.ArrayField( + base_field=models.CharField( + max_length=1000, verbose_name="Exempt URL" + ), + blank=True, + default=list, + help_text="This is a list of absolute url paths, regular expressions for url paths, or Django view names. This plus the mozilla-django-oidc urls are exempted from the session renewal by the SessionRefresh middleware.", + size=None, + verbose_name="URLs exempt from session renewal", + ), + ), + ( + "oidc_nonce_size", + models.PositiveIntegerField( + default=32, + help_text="Sets the length of the random string used for OpenID Connect nonce verification", + verbose_name="Nonce size", + ), + ), + ( + "oidc_state_size", + models.PositiveIntegerField( + default=32, + help_text="Sets the length of the random string used for OpenID Connect state verification", + verbose_name="State size", + ), + ), + ( + "oidc_use_nonce", + models.BooleanField( + default=True, + help_text="Controls whether the OpenID Connect client uses nonce verification", + verbose_name="Use nonce", + ), + ), + ( + "default_groups", + models.ManyToManyField( + blank=True, + help_text="The default groups to which every user logging in with OIDC will be assigned", + to="auth.group", + verbose_name="default groups", + ), + ), + ( + "userinfo_claims_source", + models.CharField( + choices=[ + ("userinfo_endpoint", "Userinfo endpoint"), + ("id_token", "ID token"), + ], + default="userinfo_endpoint", + help_text="Indicates the source from which the user information claims should be extracted.", + max_length=100, + verbose_name="user information claims extracted from", + ), + ), + ( + "superuser_group_names", + django_jsonform.models.fields.ArrayField( + base_field=models.CharField( + max_length=50, verbose_name="Superuser group name" + ), + blank=True, + default=list, + help_text="If any of these group names are present in the claims upon login, the user will be marked as a superuser. If none of these groups are present the user will lose superuser permissions.", + size=None, + verbose_name="Superuser group names", + ), + ), + ( + "oidc_token_use_basic_auth", + models.BooleanField( + default=False, + help_text="If enabled, the client ID and secret are sent in the HTTP Basic auth header when obtaining the access token. Otherwise, they are sent in the request body.", + verbose_name="Use Basic auth for token endpoint", + ), + ), + ], + options={ + "verbose_name": "OpenID Connect configuration", + }, + ), + ] diff --git a/mozilla_django_oidc_db/migrations/0002_migrate_to_claim_field.py b/mozilla_django_oidc_db/migrations/0002_migrate_to_claim_field.py new file mode 100644 index 0000000..51df481 --- /dev/null +++ b/mozilla_django_oidc_db/migrations/0002_migrate_to_claim_field.py @@ -0,0 +1,115 @@ +# Generated by Django 4.2.9 on 2024-05-01 16:10 + +from django.conf import settings +from django.core.cache import caches +from django.db import migrations, models, transaction + +import mozilla_django_oidc_db.fields +import mozilla_django_oidc_db.models +import mozilla_django_oidc_db.settings as oidc_settings + + +def flush_cache(): + cache_name = getattr( + settings, + "MOZILLA_DJANGO_OIDC_DB_CACHE", + oidc_settings.MOZILLA_DJANGO_OIDC_DB_CACHE, + ) + if not cache_name: + return + caches[cache_name].clear() + + +def forward(config) -> None: + config.new_username_claim = config.username_claim.split(".") + config.new_groups_claim = config.groups_claim.split(".") + config.claim_mapping = { + key: value.split(".") for key, value in config.claim_mapping.items() + } + + +def reverse(config) -> None: + config.username_claim = ".".join(config.new_username_claim) + config.groups_claim = ".".join(config.new_groups_claim) + config.claim_mapping = { + key: ".".join(value) for key, value in config.claim_mapping.items() + } + + +def action_factory(transformer): + def _run_python_action(apps, _) -> None: + OpenIDConnectConfig = apps.get_model( + "mozilla_django_oidc_db", "OpenIDConnectConfig" + ) + + # Solo model, so there's only ever one instance + config = OpenIDConnectConfig.objects.first() + if config is None: + return + + transformer(config) + + config.save() + transaction.on_commit(flush_cache) + + return _run_python_action + + +copy_forward = action_factory(transformer=forward) +copy_reverse = action_factory(transformer=reverse) + + +class Migration(migrations.Migration): + + dependencies = [ + ("mozilla_django_oidc_db", "0001_initial_to_v015"), + ] + + operations = [ + migrations.AddField( + model_name="openidconnectconfig", + name="new_groups_claim", + field=mozilla_django_oidc_db.fields.ClaimField( + base_field=models.CharField( + max_length=50, verbose_name="claim path segment" + ), + blank=True, + default=mozilla_django_oidc_db.models.get_default_groups_claim, + help_text="The name of the OIDC claim that holds the values to map to local user groups.", + size=None, + verbose_name="groups claim", + ), + ), + migrations.AddField( + model_name="openidconnectconfig", + name="new_username_claim", + field=mozilla_django_oidc_db.fields.ClaimField( + base_field=models.CharField( + max_length=50, verbose_name="claim path segment" + ), + default=mozilla_django_oidc_db.models.get_default_username_claim, + help_text="The name of the OIDC claim that is used as the username", + size=None, + verbose_name="username claim", + ), + ), + migrations.RunPython(copy_forward, copy_reverse), + migrations.RemoveField( + model_name="openidconnectconfig", + name="groups_claim", + ), + migrations.RemoveField( + model_name="openidconnectconfig", + name="username_claim", + ), + migrations.RenameField( + model_name="openidconnectconfig", + old_name="new_groups_claim", + new_name="groups_claim", + ), + migrations.RenameField( + model_name="openidconnectconfig", + old_name="new_username_claim", + new_name="username_claim", + ), + ] diff --git a/mozilla_django_oidc_db/models.py b/mozilla_django_oidc_db/models.py index a768f7b..7027de3 100644 --- a/mozilla_django_oidc_db/models.py +++ b/mozilla_django_oidc_db/models.py @@ -1,5 +1,3 @@ -from typing import Dict, List - from django.conf import settings from django.contrib.auth import get_user_model from django.contrib.auth.models import Group @@ -14,6 +12,7 @@ import mozilla_django_oidc_db.settings as oidc_settings from .compat import classproperty +from .fields import ClaimField class UserInformationClaimsSources(models.TextChoices): @@ -21,23 +20,31 @@ class UserInformationClaimsSources(models.TextChoices): id_token = "id_token", _("ID token") -def get_default_scopes() -> List[str]: +def get_default_scopes() -> list[str]: """ Returns the default scopes to request for OpenID Connect logins """ return ["openid", "email", "profile"] -def get_claim_mapping() -> Dict[str, str]: +def get_claim_mapping() -> dict[str, list[str]]: # Map (some) claim names from https://openid.net/specs/openid-connect-core-1_0.html#Claims # to corresponding field names on the User model return { - "email": "email", - "first_name": "given_name", - "last_name": "family_name", + "email": ["email"], + "first_name": ["given_name"], + "last_name": ["family_name"], } +def get_default_username_claim() -> list[str]: + return ["sub"] + + +def get_default_groups_claim() -> list[str]: + return ["roles"] + + class CachingMixin: @classmethod def clear_cache(cls): @@ -248,26 +255,26 @@ class OpenIDConnectConfig(CachingMixin, OpenIDConnectConfigBase): Configuration for authentication/authorization via OpenID connect """ - username_claim = models.CharField( - _("username claim"), - max_length=50, - default="sub", + username_claim = ClaimField( + verbose_name=_("username claim"), + default=get_default_username_claim, help_text=_("The name of the OIDC claim that is used as the username"), ) + claim_mapping = models.JSONField( _("claim mapping"), default=get_claim_mapping, help_text=("Mapping from user-model fields to OIDC claims"), ) - groups_claim = models.CharField( - _("groups claim"), - max_length=50, - default="roles", + groups_claim = ClaimField( + verbose_name=_("groups claim"), + default=get_default_groups_claim, help_text=_( "The name of the OIDC claim that holds the values to map to local user groups." ), blank=True, ) + sync_groups = models.BooleanField( _("Create local user groups if they do not exist yet"), default=True, diff --git a/mozilla_django_oidc_db/utils.py b/mozilla_django_oidc_db/utils.py index a92793a..8a16aa4 100644 --- a/mozilla_django_oidc_db/utils.py +++ b/mozilla_django_oidc_db/utils.py @@ -1,7 +1,7 @@ from copy import deepcopy -from typing import Any, List +from typing import Any -from glom import assign, glom +from glom import Path, assign, glom from requests.utils import _parse_content_type_header # type: ignore @@ -21,15 +21,15 @@ def obfuscate_claim_value(value: Any) -> str: return "".join([x if i > threshold else "*" for i, x in enumerate(value)]) -def obfuscate_claims(claims: dict, claims_to_obfuscate: List[str]) -> dict: +def obfuscate_claims(claims: dict, claims_to_obfuscate: list[list[str]]) -> dict: """ Obfuscates the specified claims in the specified claims dict """ copied_claims = deepcopy(claims) - for claim_name in claims_to_obfuscate: - # NOTE: this does not support claim names that have dots in them - claim_value = glom(copied_claims, claim_name) - assign(copied_claims, claim_name, obfuscate_claim_value(claim_value)) + for claim_bits in claims_to_obfuscate: + claim_path = Path(*claim_bits) + claim_value = glom(copied_claims, claim_path) + assign(copied_claims, claim_path, obfuscate_claim_value(claim_value)) return copied_claims diff --git a/tests/test_admin_form.py b/tests/test_admin_form.py index f44310d..48a0283 100644 --- a/tests/test_admin_form.py +++ b/tests/test_admin_form.py @@ -22,9 +22,9 @@ def test_derive_endpoints_success(): "oidc_rp_sign_algo": "RS256", "oidc_op_discovery_endpoint": "http://discovery-endpoint.nl/", "claim_mapping": get_claim_mapping(), - "groups_claim": "roles", + "groups_claim": ["roles"], "sync_groups_glob_pattern": "*", - "username_claim": "sub", + "username_claim": ["sub"], "oidc_nonce_size": 32, "oidc_state_size": 32, "userinfo_claims_source": UserInformationClaimsSources.id_token, @@ -71,9 +71,9 @@ def test_derive_endpoints_extra_field(): "oidc_rp_sign_algo": "RS256", "oidc_op_discovery_endpoint": "http://discovery-endpoint.nl/", "claim_mapping": get_claim_mapping(), - "groups_claim": "roles", + "groups_claim": ["roles"], "sync_groups_glob_pattern": "*", - "username_claim": "sub", + "username_claim": ["sub"], "oidc_nonce_size": 32, "oidc_state_size": 32, "userinfo_claims_source": UserInformationClaimsSources.id_token, @@ -119,9 +119,9 @@ def test_derive_endpoints_request_error(*m): "oidc_rp_sign_algo": "RS256", "oidc_op_discovery_endpoint": "http://discovery-endpoint.nl", "claim_mapping": get_claim_mapping(), - "groups_claim": "roles", + "groups_claim": ["roles"], "sync_groups_glob_pattern": "*", - "username_claim": "sub", + "username_claim": ["sub"], "oidc_nonce_size": 32, "oidc_state_size": 32, "userinfo_claims_source": UserInformationClaimsSources.id_token, @@ -145,9 +145,9 @@ def test_derive_endpoints_json_error(*m): "oidc_rp_sign_algo": "RS256", "oidc_op_discovery_endpoint": "http://discovery-endpoint.nl", "claim_mapping": get_claim_mapping(), - "groups_claim": "roles", + "groups_claim": ["roles"], "sync_groups_glob_pattern": "*", - "username_claim": "sub", + "username_claim": ["sub"], "oidc_nonce_size": 32, "oidc_state_size": 32, "userinfo_claims_source": UserInformationClaimsSources.id_token, @@ -169,9 +169,9 @@ def test_no_discovery_endpoint_other_fields_required(): "oidc_rp_client_secret": "secret", "oidc_rp_sign_algo": "RS256", "claim_mapping": get_claim_mapping(), - "groups_claim": "roles", + "groups_claim": ["roles"], "sync_groups_glob_pattern": "*", - "username_claim": "sub", + "username_claim": ["sub"], "oidc_nonce_size": 32, "oidc_state_size": 32, "userinfo_claims_source": UserInformationClaimsSources.id_token, diff --git a/tests/test_backend.py b/tests/test_backend.py index 6e56b36..8dbfe51 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -25,18 +25,20 @@ def test_backend_authenticate_oidc_not_enabled(mock_get_solo): @patch("mozilla_django_oidc_db.models.OpenIDConnectConfig.get_solo") def test_backend_get_sensitive_claims(mock_get_solo): - mock_get_solo.return_value = OpenIDConnectConfig(enabled=True, username_claim="sub") + mock_get_solo.return_value = OpenIDConnectConfig( + enabled=True, username_claim=["sub"] + ) class CustomOIDCBackend(OIDCAuthenticationBackend): - sensitive_claim_names = ["sensitive_claim1", "sensitive_claim2"] + sensitive_claim_names = [["sensitive_claim1"], ["sensitive_claim2"]] backend = CustomOIDCBackend() # Only the sensitive claims + the identifier claim should be obfuscated assert backend.get_sensitive_claims_names() == [ - "sub", - "sensitive_claim1", - "sensitive_claim2", + ["sub"], + ["sensitive_claim1"], + ["sensitive_claim2"], ] @@ -68,9 +70,9 @@ def test_backend_get_user_instance_values(mock_get_solo): def test_backend_get_user_instance_values_nested_claims(mock_get_solo): mock_get_solo.return_value = OpenIDConnectConfig( claim_mapping={ - "email": "user_info.email", - "first_name": "user_info.given_name", - "last_name": "user_info.family_name", + "email": ["user_info", "email"], + "first_name": ["user_info", "given_name"], + "last_name": ["user_info", "family_name"], } ) @@ -94,6 +96,41 @@ def test_backend_get_user_instance_values_nested_claims(mock_get_solo): } +@patch("mozilla_django_oidc_db.models.OpenIDConnectConfig.get_solo") +def test_backend_supports_dots_in_claim_names(mock_get_solo, django_user_model): + user = django_user_model.objects.create_user(username="dummy", password="dummy") + mock_get_solo.return_value = OpenIDConnectConfig( + username_claim=["ns1.sub"], + groups_claim=["ns1.groups"], + claim_mapping={ + "email": ["user_info.email"], + "first_name": ["user_info.given_name"], + "last_name": ["user_info.family_name"], + }, + ) + + claims = { + "ns1.sub": "123456", + "ns1.groups": ["aaaa"], + "user_info.email": "admin@localhost", + "user_info.given_name": "John", + "user_info.family_name": "Doe", + } + + backend = OIDCAuthenticationBackend() + + assert backend.retrieve_identifier_claim(claims) == "123456" + assert backend.get_user_instance_values(claims) == { + "email": "admin@localhost", + "first_name": "John", + "last_name": "Doe", + } + + backend.update_user_groups(user, claims) + group_names = user.groups.values_list("name", flat=True) + assert list(group_names) == ["aaaa"] + + @pytest.mark.django_db @patch("mozilla_django_oidc_db.models.OpenIDConnectConfig.get_solo") def test_backend_create_user(mock_get_solo): @@ -110,8 +147,6 @@ def test_backend_create_user(mock_get_solo): oidc_op_user_endpoint="http://some.endpoint/v1/user", ) - User = get_user_model() - claims = { "sub": "123456", "email": "admin@localhost", @@ -144,11 +179,9 @@ def test_backend_create_user_different_username_claim(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - username_claim="upn", + username_claim=["upn"], ) - User = get_user_model() - claims = { "sub": "123456", "upn": "admin", @@ -181,7 +214,7 @@ def test_backend_filter_users(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - username_claim="sub", + username_claim=["sub"], ) User = get_user_model() @@ -229,7 +262,7 @@ def test_backend_filter_users_different_username_claim(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - username_claim="upn", + username_claim=["upn"], ) User = get_user_model() @@ -278,7 +311,7 @@ def test_backend_update_user(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - username_claim="sub", + username_claim=["sub"], ) User = get_user_model() @@ -329,7 +362,7 @@ def test_backend_create_user_sync_all_groups(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - groups_claim="roles", + groups_claim=["roles"], sync_groups=True, sync_groups_glob_pattern="*", ) @@ -374,7 +407,7 @@ def test_backend_create_user_no_groups_sync_without_groups_claim(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - groups_claim="", + groups_claim=[], sync_groups=True, sync_groups_glob_pattern="*", ) @@ -415,7 +448,7 @@ def test_backend_create_user_sync_groups_according_to_pattern(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - groups_claim="roles", + groups_claim=["roles"], sync_groups=True, sync_groups_glob_pattern="group*", ) @@ -451,7 +484,7 @@ def test_backend_create_user_sync_all_groups_nested_groups_claim(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - groups_claim="nested_object.roles", + groups_claim=["nested_object", "roles"], sync_groups=True, sync_groups_glob_pattern="*", ) @@ -496,7 +529,7 @@ def test_backend_create_user_sync_all_groups_and_default_groups(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - groups_claim="roles", + groups_claim=["roles"], sync_groups=True, sync_groups_glob_pattern="*", ) @@ -545,7 +578,7 @@ def test_backend_create_user_sync_groups_according_to_pattern_and_default_groups oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - groups_claim="roles", + groups_claim=["roles"], sync_groups=True, sync_groups_glob_pattern="group*", ) @@ -587,13 +620,13 @@ def test_backend_create_user_with_profile_settings(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - groups_claim="roles", + groups_claim=["roles"], sync_groups=True, claim_mapping={ - "first_name": "given_name", - "last_name": "family_name", - "email": "email", - "is_superuser": "is_god", + "first_name": ["given_name"], + "last_name": ["family_name"], + "email": ["email"], + "is_superuser": ["is_god"], }, sync_groups_glob_pattern="*", make_users_staff=True, @@ -668,7 +701,7 @@ def test_backend_update_user_superuser(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - groups_claim="roles", + groups_claim=["roles"], sync_groups=False, superuser_group_names=["superuser"], ) @@ -705,7 +738,7 @@ def test_backend_update_user_remove_superuser(mock_get_solo): oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - groups_claim="roles", + groups_claim=["roles"], sync_groups=False, superuser_group_names=["superuser"], ) @@ -753,7 +786,7 @@ def test_backend_update_user_no_superuser_group_names( oidc_op_authorization_endpoint="http://some.endpoint/v1/auth", oidc_op_token_endpoint="http://some.endpoint/v1/token", oidc_op_user_endpoint="http://some.endpoint/v1/user", - groups_claim="roles", + groups_claim=["roles"], sync_groups=False, superuser_group_names=[], ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0fcd0c6..05fba34 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,6 +18,7 @@ def test_obfuscate_non_string(): def test_obfuscate_nested(): claims = { "foo": "not_obfuscated", + "foo.bar": "obfuscated", "some": { "nested": { "claim": "obfuscated", @@ -29,9 +30,14 @@ def test_obfuscate_nested(): "bar": "obfuscated", }, } - claims_to_obfuscate = ["some.nested.claim", "object"] + claims_to_obfuscate = [ + ["foo.bar"], + ["some", "nested", "claim"], + ["object"], + ] expected_result = { "foo": "not_obfuscated", + "foo.bar": "********ed", "some": {"nested": {"claim": "********ed", "claim2": "not_obfuscated"}}, "object": {"foo": "********ed", "bar": "********ed"}, }