Skip to content

Commit

Permalink
providers/oauth2: Add provider federation between OAuth2 Providers (#…
Browse files Browse the repository at this point in the history
…12083)

* rename + add field

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* initial implementation

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* refactor

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* rework source cc tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* update docs

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* re-migrate

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix a

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* Apply suggestions from code review

Co-authored-by: Tana M Berry <tanamarieberry@yahoo.com>
Signed-off-by: Jens L. <jens@beryju.org>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Signed-off-by: Jens L. <jens@beryju.org>
Co-authored-by: Tana M Berry <tanamarieberry@yahoo.com>
  • Loading branch information
BeryJu and tanberry authored Dec 3, 2024
1 parent 4aeb7c8 commit 19488b7
Show file tree
Hide file tree
Showing 17 changed files with 509 additions and 63 deletions.
8 changes: 7 additions & 1 deletion authentik/blueprints/v1/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@
from authentik.outposts.models import OutpostServiceConnection
from authentik.policies.models import Policy, PolicyBindingModel
from authentik.policies.reputation.models import Reputation
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
from authentik.providers.oauth2.models import (
AccessToken,
AuthorizationCode,
DeviceToken,
RefreshToken,
)
from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser
from authentik.rbac.models import Role
from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser
Expand Down Expand Up @@ -125,6 +130,7 @@ def excluded_models() -> list[type[Model]]:
MicrosoftEntraProviderGroup,
EndpointDevice,
EndpointDeviceConnection,
DeviceToken,
)


Expand Down
3 changes: 2 additions & 1 deletion authentik/providers/oauth2/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class Meta:
"sub_mode",
"property_mappings",
"issuer_mode",
"jwks_sources",
"jwt_federation_sources",
"jwt_federation_providers",
]
extra_kwargs = ProviderSerializer.Meta.extra_kwargs

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Generated by Django 5.0.9 on 2024-11-22 14:25

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("authentik_providers_oauth2", "0024_remove_oauth2provider_redirect_uris_and_more"),
]

operations = [
migrations.RenameField(
model_name="oauth2provider",
old_name="jwks_sources",
new_name="jwt_federation_sources",
),
migrations.AddField(
model_name="oauth2provider",
name="jwt_federation_providers",
field=models.ManyToManyField(
blank=True, default=None, to="authentik_providers_oauth2.oauth2provider"
),
),
]
3 changes: 2 additions & 1 deletion authentik/providers/oauth2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class OAuth2Provider(WebfingerProvider, Provider):
related_name="oauth2provider_encryption_key_set",
)

jwks_sources = models.ManyToManyField(
jwt_federation_sources = models.ManyToManyField(
OAuthSource,
verbose_name=_(
"Any JWT signed by the JWK of the selected source can be used to authenticate."
Expand All @@ -253,6 +253,7 @@ class OAuth2Provider(WebfingerProvider, Provider):
default=None,
blank=True,
)
jwt_federation_providers = models.ManyToManyField("OAuth2Provider", blank=True, default=None)

@cached_property
def jwt_key(self) -> tuple[str | PrivateKeyTypes, str]:
Expand Down
228 changes: 228 additions & 0 deletions authentik/providers/oauth2/tests/test_token_cc_jwt_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
"""Test token view"""

from datetime import datetime, timedelta
from json import loads

from django.test import RequestFactory
from django.urls import reverse
from django.utils.timezone import now
from jwt import decode

from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group
from authentik.core.tests.utils import create_test_cert, create_test_flow, create_test_user
from authentik.lib.generators import generate_id
from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.constants import (
GRANT_TYPE_CLIENT_CREDENTIALS,
SCOPE_OPENID,
SCOPE_OPENID_EMAIL,
SCOPE_OPENID_PROFILE,
TOKEN_TYPE,
)
from authentik.providers.oauth2.models import (
AccessToken,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
ScopeMapping,
)
from authentik.providers.oauth2.tests.utils import OAuthTestCase


class TestTokenClientCredentialsJWTProvider(OAuthTestCase):
"""Test token (client_credentials, with JWT) view"""

@apply_blueprint("system/providers-oauth2.yaml")
def setUp(self) -> None:
super().setUp()
self.factory = RequestFactory()
self.other_cert = create_test_cert()
self.cert = create_test_cert()

self.other_provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
signing_key=self.other_cert,
)
self.other_provider.property_mappings.set(ScopeMapping.objects.all())
self.app = Application.objects.create(
name=generate_id(), slug=generate_id(), provider=self.other_provider
)

self.provider: OAuth2Provider = OAuth2Provider.objects.create(
name="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=self.cert,
)
self.provider.jwt_federation_providers.add(self.other_provider)
self.provider.property_mappings.set(ScopeMapping.objects.all())
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)

def test_invalid_type(self):
"""test invalid type"""
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "foo",
"client_assertion": "foo.bar",
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_invalid_jwt(self):
"""test invalid JWT"""
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": "foo.bar",
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_invalid_signature(self):
"""test invalid JWT"""
token = self.provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
}
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": token + "foo",
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_invalid_expired(self):
"""test invalid JWT"""
token = self.provider.encode(
{
"sub": "foo",
"exp": datetime.now() - timedelta(hours=2),
}
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": token,
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_invalid_no_app(self):
"""test invalid JWT"""
self.app.provider = None
self.app.save()
token = self.provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
}
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": token,
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_invalid_access_denied(self):
"""test invalid JWT"""
group = Group.objects.create(name="foo")
PolicyBinding.objects.create(
group=group,
target=self.app,
order=0,
)
token = self.provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
}
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": token,
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_successful(self):
"""test successful"""
user = create_test_user()
token = self.other_provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
}
)
AccessToken.objects.create(
provider=self.other_provider,
token=token,
user=user,
auth_time=now(),
)

response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": token,
},
)
self.assertEqual(response.status_code, 200)
body = loads(response.content.decode())
self.assertEqual(body["token_type"], TOKEN_TYPE)
_, alg = self.provider.jwt_key
jwt = decode(
body["access_token"],
key=self.provider.signing_key.public_key,
algorithms=[alg],
audience=self.provider.client_id,
)
self.assertEqual(jwt["given_name"], user.name)
self.assertEqual(jwt["preferred_username"], user.username)
21 changes: 14 additions & 7 deletions authentik/providers/oauth2/tests/test_token_cc_jwt_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,16 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase):
def setUp(self) -> None:
super().setUp()
self.factory = RequestFactory()
self.other_cert = create_test_cert()
# Provider used as a helper to sign JWTs with the same key as the OAuth source has
self.helper_provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
signing_key=self.other_cert,
)
self.cert = create_test_cert()

jwk = JWKSView().get_jwk_for_key(self.cert, "sig")
jwk = JWKSView().get_jwk_for_key(self.other_cert, "sig")
self.source: OAuthSource = OAuthSource.objects.create(
name=generate_id(),
slug=generate_id(),
Expand All @@ -62,7 +69,7 @@ def setUp(self) -> None:
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=self.cert,
)
self.provider.jwks_sources.add(self.source)
self.provider.jwt_federation_sources.add(self.source)
self.provider.property_mappings.set(ScopeMapping.objects.all())
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)

Expand Down Expand Up @@ -100,7 +107,7 @@ def test_invalid_jwt(self):

def test_invalid_signature(self):
"""test invalid JWT"""
token = self.provider.encode(
token = self.helper_provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
Expand All @@ -122,7 +129,7 @@ def test_invalid_signature(self):

def test_invalid_expired(self):
"""test invalid JWT"""
token = self.provider.encode(
token = self.helper_provider.encode(
{
"sub": "foo",
"exp": datetime.now() - timedelta(hours=2),
Expand All @@ -146,7 +153,7 @@ def test_invalid_no_app(self):
"""test invalid JWT"""
self.app.provider = None
self.app.save()
token = self.provider.encode(
token = self.helper_provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
Expand Down Expand Up @@ -174,7 +181,7 @@ def test_invalid_access_denied(self):
target=self.app,
order=0,
)
token = self.provider.encode(
token = self.helper_provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
Expand All @@ -196,7 +203,7 @@ def test_invalid_access_denied(self):

def test_successful(self):
"""test successful"""
token = self.provider.encode(
token = self.helper_provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
Expand Down
2 changes: 1 addition & 1 deletion authentik/providers/oauth2/views/device_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def validate_code(self, code: int) -> HttpResponse | None:


class OAuthDeviceCodeStage(ChallengeStageView):
"""Flow challenge for users to enter device codes"""
"""Flow challenge for users to enter device code"""

response_class = OAuthDeviceCodeChallengeResponse

Expand Down
Loading

0 comments on commit 19488b7

Please sign in to comment.