Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: oauth user case sensitivity #4207

Merged
merged 7 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions api/custom_auth/oauth/serializers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from abc import abstractmethod

from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.signals import user_logged_in
from django.db.models import F
from rest_framework import serializers
from rest_framework.authtoken.models import Token
from rest_framework.exceptions import PermissionDenied

from organisations.invites.models import Invite
from users.auth_type import AuthType
from users.models import SignUpType

from ..constants import USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE
Expand All @@ -30,6 +34,9 @@ class OAuthLoginSerializer(serializers.Serializer):
write_only=True,
)

auth_type: AuthType | None = None
user_model_id_attribute: str = "id"

class Meta:
abstract = True

Expand All @@ -53,8 +60,28 @@ def create(self, validated_data):
return Token.objects.get_or_create(user=user)[0]

def _get_user(self, user_data: dict):
email = user_data.get("email")
existing_user = UserModel.objects.filter(email=email).first()
email: str = user_data.pop("email")

# There are a number of scenarios that we're catering for in this
# query:
# 1. A new user arriving, and immediately authenticating with
# the given social auth method.
# 2. A user that has previously authenticated with method A is now
# authenticating with method B. Using the `email__iexact` means
# that we'll always retrieve the user that already authenticated
# with A.
# 3. A user that (prior to the case sensitivity fix) authenticated
# with multiple methods and ended up with duplicate user accounts.
# Since it's difficult for us to know which user account they are
# using as their primary, we order by the method they are currently
# authenticating with and grab the first one in the list.
existing_user = (
UserModel.objects.filter(email__iexact=email)
.order_by(
F(self.get_user_model_id_attribute()).desc(nulls_last=True),
)
.first()
)

if not existing_user:
sign_up_type = self.validated_data.get("sign_up_type")
Expand All @@ -65,20 +92,50 @@ def _get_user(self, user_data: dict):
):
raise PermissionDenied(USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE)

return UserModel.objects.create(**user_data, sign_up_type=sign_up_type)
return UserModel.objects.create(
**user_data, email=email.lower(), sign_up_type=sign_up_type
)
elif existing_user.auth_type != self.get_auth_type().value:
# In this scenario, we're seeing a user that had previously
# authenticated with another authentication method and is now
# authenticating with a new OAuth provider.
user_model_id_attribute = self.get_user_model_id_attribute()
setattr(
existing_user,
user_model_id_attribute,
user_data[user_model_id_attribute],
)
existing_user.save()

return existing_user

@abstractmethod
def get_user_info(self):
raise NotImplementedError("`get_user_info()` must be implemented.")

def get_auth_type(self) -> AuthType:
if not self.auth_type: # pragma: no cover
raise NotImplementedError(
"`auth_type` must be set, or `get_auth_type()` must be implemented."
)
return self.auth_type

def get_user_model_id_attribute(self) -> str:
return self.user_model_id_attribute
matthewelwell marked this conversation as resolved.
Show resolved Hide resolved


class GoogleLoginSerializer(OAuthLoginSerializer):
auth_type = AuthType.GOOGLE
user_model_id_attribute = "google_user_id"

def get_user_info(self):
return get_user_info(self.validated_data["access_token"])


class GithubLoginSerializer(OAuthLoginSerializer):
auth_type = AuthType.GITHUB
user_model_id_attribute = "github_user_id"

def get_user_info(self):
github_user = GithubUser(code=self.validated_data["access_token"])
return github_user.get_user_info()
157 changes: 155 additions & 2 deletions api/tests/unit/custom_auth/oauth/test_unit_oauth_views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from unittest import mock

from django.db.models import Model
from django.test import override_settings
from django.urls import reverse
from pytest_mock import MockerFixture
from rest_framework import status
from rest_framework.test import APIClient

Expand Down Expand Up @@ -103,7 +105,12 @@ def test_can_login_with_google_if_registration_disabled(
client = APIClient()

email = "test@example.com"
mock_get_user_info.return_value = {"email": email}
mock_get_user_info.return_value = {
"email": email,
"first_name": "John",
"last_name": "Smith",
"google_user_id": "abc123",
}
django_user_model.objects.create(email=email)

# When
Expand All @@ -126,7 +133,12 @@ def test_can_login_with_github_if_registration_disabled(
email = "test@example.com"
mock_github_user = mock.MagicMock()
MockGithubUser.return_value = mock_github_user
mock_github_user.get_user_info.return_value = {"email": email}
mock_github_user.get_user_info.return_value = {
"email": email,
"first_name": "John",
"last_name": "Smith",
"github_user_id": "abc123",
}
django_user_model.objects.create(email=email)

# When
Expand All @@ -135,3 +147,144 @@ def test_can_login_with_github_if_registration_disabled(
# Then
assert response.status_code == status.HTTP_200_OK
assert "key" in response.json()


def test_login_with_google_updates_existing_user_case_insensitive(
db: None,
django_user_model: type[Model],
mocker: MockerFixture,
api_client: APIClient,
) -> None:
# Given
email_lower = "test@example.com"
email_upper = email_lower.upper()
google_user_id = "abc123"

django_user_model.objects.create(email=email_lower)

mocker.patch(
"custom_auth.oauth.serializers.get_user_info",
return_value={
"email": email_upper,
"first_name": "John",
"last_name": "Smith",
"google_user_id": google_user_id,
},
)

url = reverse("api-v1:custom_auth:oauth:google-oauth-login")

# When
response = api_client.post(url, data={"access_token": "some-token"})

# Then
assert response.status_code == status.HTTP_200_OK

qs = django_user_model.objects.filter(email__iexact=email_lower)
assert qs.count() == 1

user = qs.first()
assert user.email == email_lower
assert user.google_user_id == google_user_id


def test_login_with_github_updates_existing_user_case_insensitive(
db: None,
django_user_model: type[Model],
mocker: MockerFixture,
api_client: APIClient,
) -> None:
# Given
email_lower = "test@example.com"
email_upper = email_lower.upper()
github_user_id = "abc123"

django_user_model.objects.create(email=email_lower)

mock_github_user = mock.MagicMock()
mocker.patch(
"custom_auth.oauth.serializers.GithubUser", return_value=mock_github_user
)
mock_github_user.get_user_info.return_value = {
"email": email_upper,
"first_name": "John",
"last_name": "Smith",
"github_user_id": github_user_id,
}

url = reverse("api-v1:custom_auth:oauth:github-oauth-login")

# When
response = api_client.post(url, data={"access_token": "some-token"})

# Then
assert response.status_code == status.HTTP_200_OK

qs = django_user_model.objects.filter(email__iexact=email_lower)
assert qs.count() == 1

user = qs.first()
assert user.email == email_lower
assert user.github_user_id == github_user_id


def test_user_with_duplicate_accounts_authenticates_as_the_correct_oauth_user(
db: None,
django_user_model: type[Model],
api_client: APIClient,
mocker: MockerFixture,
) -> None:
"""
Specific test to verify the correct behaviour for users affected by
https://github.com/Flagsmith/flagsmith/issues/4185.
"""

# Given
email_lower = "test@example.com"
email_upper = email_lower.upper()

github_user = django_user_model.objects.create(
email=email_lower, github_user_id="abc123"
)
google_user = django_user_model.objects.create(
email=email_upper, google_user_id="abc123"
)

mock_github_user = mock.MagicMock()
mocker.patch(
"custom_auth.oauth.serializers.GithubUser", return_value=mock_github_user
)
mock_github_user.get_user_info.return_value = {
"email": email_lower,
"first_name": "John",
"last_name": "Smith",
"github_user_id": github_user.github_user_id,
}

mocker.patch(
"custom_auth.oauth.serializers.get_user_info",
return_value={
"email": email_upper,
"first_name": "John",
"last_name": "Smith",
"google_user_id": google_user.google_user_id,
},
)

github_auth_url = reverse("api-v1:custom_auth:oauth:github-oauth-login")
google_auth_url = reverse("api-v1:custom_auth:oauth:google-oauth-login")

# When
auth_with_github_response = api_client.post(
github_auth_url, data={"access_token": "some-token"}
)
auth_with_google_response = api_client.post(
google_auth_url, data={"access_token": "some-token"}
)

# Then
github_auth_key = auth_with_github_response.json().get("key")
assert github_auth_key == github_user.auth_token.key

google_auth_key = auth_with_google_response.json().get("key")
assert google_auth_key == google_user.auth_token.key
Loading