Skip to content

Commit

Permalink
providers/oauth2: add initial JWE support (#11344)
Browse files Browse the repository at this point in the history
* providers/oauth2: add initial JWE support

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* re-migrate, only set id_token_encryption_* when encryption key is set

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add docs

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add jwks test with encryption

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
  • Loading branch information
BeryJu authored Oct 17, 2024
1 parent fc1f146 commit 47206d3
Show file tree
Hide file tree
Showing 18 changed files with 329 additions and 35 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"authn",
"entra",
"goauthentik",
"jwe",
"jwks",
"kubernetes",
"oidc",
Expand Down
1 change: 1 addition & 0 deletions authentik/providers/oauth2/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Meta:
"refresh_token_validity",
"include_claims_in_id_token",
"signing_key",
"encryption_key",
"redirect_uris",
"sub_mode",
"property_mappings",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Generated by Django 5.0.9 on 2024-10-16 14:53

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("authentik_crypto", "0004_alter_certificatekeypair_name"),
(
"authentik_providers_oauth2",
"0020_remove_accesstoken_authentik_p_token_4bc870_idx_and_more",
),
]

operations = [
migrations.AddField(
model_name="oauth2provider",
name="encryption_key",
field=models.ForeignKey(
help_text="Key used to encrypt the tokens. When set, tokens will be encrypted and returned as JWEs.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="oauth2provider_encryption_key_set",
to="authentik_crypto.certificatekeypair",
verbose_name="Encryption Key",
),
),
migrations.AlterField(
model_name="oauth2provider",
name="signing_key",
field=models.ForeignKey(
help_text="Key used to sign the tokens.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="oauth2provider_signing_key_set",
to="authentik_crypto.certificatekeypair",
verbose_name="Signing Key",
),
),
]
37 changes: 35 additions & 2 deletions authentik/providers/oauth2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from django.templatetags.static import static
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from jwcrypto.common import json_encode
from jwcrypto.jwe import JWE
from jwcrypto.jwk import JWK
from jwt import encode
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
Expand Down Expand Up @@ -206,9 +209,19 @@ class OAuth2Provider(WebfingerProvider, Provider):
verbose_name=_("Signing Key"),
on_delete=models.SET_NULL,
null=True,
help_text=_("Key used to sign the tokens."),
related_name="oauth2provider_signing_key_set",
)
encryption_key = models.ForeignKey(
CertificateKeyPair,
verbose_name=_("Encryption Key"),
on_delete=models.SET_NULL,
null=True,
help_text=_(
"Key used to sign the tokens. Only required when JWT Algorithm is set to RS256."
"Key used to encrypt the tokens. When set, "
"tokens will be encrypted and returned as JWEs."
),
related_name="oauth2provider_encryption_key_set",
)

jwks_sources = models.ManyToManyField(
Expand Down Expand Up @@ -287,7 +300,27 @@ def encode(self, payload: dict[str, Any]) -> str:
if self.signing_key:
headers["kid"] = self.signing_key.kid
key, alg = self.jwt_key
return encode(payload, key, algorithm=alg, headers=headers)
encoded = encode(payload, key, algorithm=alg, headers=headers)
if self.encryption_key:
return self.encrypt(encoded)
return encoded

def encrypt(self, raw: str) -> str:
"""Encrypt JWT"""
key = JWK.from_pem(self.encryption_key.certificate_data.encode())
jwe = JWE(
raw,
json_encode(
{
"alg": "RSA-OAEP-256",
"enc": "A256CBC-HS512",
"typ": "JWE",
"kid": self.encryption_key.kid,
}
),
)
jwe.add_recipient(key)
return jwe.serialize(compact=True)

def webfinger(self, resource: str, request: HttpRequest):
return {
Expand Down
67 changes: 67 additions & 0 deletions authentik/providers/oauth2/tests/test_authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,73 @@ def test_full_implicit(self):
delta=5,
)

@apply_blueprint("system/providers-oauth2.yaml")
def test_full_implicit_enc(self):
"""Test full authorization with encryption"""
flow = create_test_flow()
provider: OAuth2Provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris="http://localhost",
signing_key=self.keypair,
encryption_key=self.keypair,
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-email",
"goauthentik.io/providers/oauth2/scope-profile",
]
)
)
provider.property_mappings.add(
ScopeMapping.objects.create(
name=generate_id(), scope_name="test", expression="""return {"sub": "foo"}"""
)
)
Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
state = generate_id()
user = create_test_admin_user()
self.client.force_login(user)
with patch(
"authentik.providers.oauth2.id_token.get_login_event",
MagicMock(
return_value=Event(
action=EventAction.LOGIN,
context={PLAN_CONTEXT_METHOD: "password"},
created=now(),
)
),
):
# Step 1, initiate params and get redirect to flow
self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "id_token",
"client_id": "test",
"state": state,
"scope": "openid test",
"redirect_uri": "http://localhost",
"nonce": generate_id(),
},
)
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
self.assertEqual(response.status_code, 200)
token: AccessToken = AccessToken.objects.filter(user=user).first()
expires = timedelta_from_string(provider.access_token_validity).total_seconds()
jwt = self.validate_jwe(token, provider)
self.assertEqual(jwt["amr"], ["pwd"])
self.assertEqual(jwt["sub"], "foo")
self.assertAlmostEqual(
jwt["exp"] - now().timestamp(),
expires,
delta=5,
)

def test_full_fragment_code(self):
"""Test full authorization"""
flow = create_test_flow()
Expand Down
18 changes: 18 additions & 0 deletions authentik/providers/oauth2/tests/test_jwks.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,24 @@ def test_es256(self):
self.assertEqual(len(body["keys"]), 1)
PyJWKSet.from_dict(body)

def test_enc(self):
"""Test with JWE"""
provider = OAuth2Provider.objects.create(
name="test",
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid",
signing_key=create_test_cert(PrivateKeyAlg.ECDSA),
encryption_key=create_test_cert(PrivateKeyAlg.ECDSA),
)
app = Application.objects.create(name="test", slug="test", provider=provider)
response = self.client.get(
reverse("authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug})
)
body = json.loads(response.content.decode())
self.assertEqual(len(body["keys"]), 2)
PyJWKSet.from_dict(body)

def test_ecdsa_coords_mismatched(self):
"""Test JWKS request with ES256"""
cert = CertificateKeyPair.objects.create(
Expand Down
30 changes: 30 additions & 0 deletions authentik/providers/oauth2/tests/test_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,36 @@ def test_auth_code_view(self):
)
self.validate_jwt(access, provider)

def test_auth_code_enc(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid",
signing_key=self.keypair,
encryption_key=self.keypair,
)
# Needs to be assigned to an application for iss to be set
self.app.provider = provider
self.app.save()
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
user = create_test_admin_user()
code = AuthorizationCode.objects.create(
code="foobar", provider=provider, user=user, auth_time=timezone.now()
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
data={
"grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
"code": code.code,
"redirect_uri": "http://local.invalid",
},
HTTP_AUTHORIZATION=f"Basic {header}",
)
self.assertEqual(response.status_code, 200)
access: AccessToken = AccessToken.objects.filter(user=user, provider=provider).first()
self.validate_jwe(access, provider)

@apply_blueprint("system/providers-oauth2.yaml")
def test_refresh_token_view(self):
"""test request param"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def setUp(self) -> None:
self.factory = RequestFactory()
self.cert = create_test_cert()

jwk = JWKSView().get_jwk_for_key(self.cert)
jwk = JWKSView().get_jwk_for_key(self.cert, "sig")
self.source: OAuthSource = OAuthSource.objects.create(
name=generate_id(),
slug=generate_id(),
Expand Down
11 changes: 11 additions & 0 deletions authentik/providers/oauth2/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Any

from django.test import TestCase
from jwcrypto.jwe import JWE
from jwcrypto.jwk import JWK
from jwt import decode

from authentik.core.tests.utils import create_test_cert
Expand Down Expand Up @@ -32,6 +34,15 @@ def assert_non_none_or_unset(self, container: dict, key: str):
if key in container:
self.assertIsNotNone(container[key])

def validate_jwe(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
"""Validate JWEs"""
private_key = JWK.from_pem(provider.encryption_key.key_data.encode())

jwetoken = JWE()
jwetoken.deserialize(token.token, key=private_key)
token.token = jwetoken.payload.decode()
return self.validate_jwt(token, provider)

def validate_jwt(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
"""Validate that all required fields are set"""
key, alg = provider.jwt_key
Expand Down
55 changes: 33 additions & 22 deletions authentik/providers/oauth2/views/jwks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,36 +64,42 @@ def to_base64url_uint(val: int, min_length: int = 0) -> bytes:
class JWKSView(View):
"""Show RSA Key data for Provider"""

def get_jwk_for_key(self, key: CertificateKeyPair) -> dict | None:
def get_jwk_for_key(self, key: CertificateKeyPair, use: str) -> dict | None:
"""Convert a certificate-key pair into JWK"""
private_key = key.private_key
key_data = None
if not private_key:
return key_data

key_data = {}

if use == "sig":
if isinstance(private_key, RSAPrivateKey):
key_data["alg"] = JWTAlgorithms.RS256
elif isinstance(private_key, EllipticCurvePrivateKey):
key_data["alg"] = JWTAlgorithms.ES256
elif use == "enc":
key_data["alg"] = "RSA-OAEP-256"
key_data["enc"] = "A256CBC-HS512"

if isinstance(private_key, RSAPrivateKey):
public_key: RSAPublicKey = private_key.public_key()
public_numbers = public_key.public_numbers()
key_data = {
"kid": key.kid,
"kty": "RSA",
"alg": JWTAlgorithms.RS256,
"use": "sig",
"n": to_base64url_uint(public_numbers.n).decode(),
"e": to_base64url_uint(public_numbers.e).decode(),
}
key_data["kid"] = key.kid
key_data["kty"] = "RSA"
key_data["use"] = use
key_data["n"] = to_base64url_uint(public_numbers.n).decode()
key_data["e"] = to_base64url_uint(public_numbers.e).decode()
elif isinstance(private_key, EllipticCurvePrivateKey):
public_key: EllipticCurvePublicKey = private_key.public_key()
public_numbers = public_key.public_numbers()
curve_type = type(public_key.curve)
key_data = {
"kid": key.kid,
"kty": "EC",
"alg": JWTAlgorithms.ES256,
"use": "sig",
"x": to_base64url_uint(public_numbers.x, min_length_map[curve_type]).decode(),
"y": to_base64url_uint(public_numbers.y, min_length_map[curve_type]).decode(),
"crv": ec_crv_map.get(curve_type, public_key.curve.name),
}
key_data["kid"] = key.kid
key_data["kty"] = "EC"
key_data["use"] = use
key_data["x"] = to_base64url_uint(public_numbers.x, min_length_map[curve_type]).decode()
key_data["y"] = to_base64url_uint(public_numbers.y, min_length_map[curve_type]).decode()
key_data["crv"] = ec_crv_map.get(curve_type, public_key.curve.name)
else:
return key_data
key_data["x5c"] = [b64encode(key.certificate.public_bytes(Encoding.DER)).decode("utf-8")]
Expand All @@ -113,14 +119,19 @@ def get(self, request: HttpRequest, application_slug: str) -> HttpResponse:
"""Show JWK Key data for Provider"""
application = get_object_or_404(Application, slug=application_slug)
provider: OAuth2Provider = get_object_or_404(OAuth2Provider, pk=application.provider_id)
signing_key: CertificateKeyPair = provider.signing_key

response_data = {}

if signing_key:
jwk = self.get_jwk_for_key(signing_key)
if signing_key := provider.signing_key:
jwk = self.get_jwk_for_key(signing_key, "sig")
if jwk:
response_data.setdefault("keys", [])
response_data["keys"].append(jwk)
if encryption_key := provider.encryption_key:
jwk = self.get_jwk_for_key(encryption_key, "enc")
if jwk:
response_data["keys"] = [jwk]
response_data.setdefault("keys", [])
response_data["keys"].append(jwk)

response = JsonResponse(response_data)
response["Access-Control-Allow-Origin"] = "*"
Expand Down
6 changes: 5 additions & 1 deletion authentik/providers/oauth2/views/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_info(self, provider: OAuth2Provider) -> dict[str, Any]:
if SCOPE_OPENID not in scopes:
scopes.append(SCOPE_OPENID)
_, supported_alg = provider.jwt_key
return {
config = {
"issuer": provider.get_issuer(self.request),
"authorization_endpoint": self.request.build_absolute_uri(
reverse("authentik_providers_oauth2:authorize")
Expand Down Expand Up @@ -114,6 +114,10 @@ def get_info(self, provider: OAuth2Provider) -> dict[str, Any]:
"claims_parameter_supported": False,
"code_challenge_methods_supported": [PKCE_METHOD_PLAIN, PKCE_METHOD_S256],
}
if provider.encryption_key:
config["id_token_encryption_alg_values_supported"] = ["RSA-OAEP-256"]
config["id_token_encryption_enc_values_supported"] = ["A256CBC-HS512"]
return config

def get_claims(self, provider: OAuth2Provider) -> list[str]:
"""Get a list of supported claims based on configured scope mappings"""
Expand Down
Loading

0 comments on commit 47206d3

Please sign in to comment.