diff --git a/authentik/enterprise/providers/ssf/migrations/0005_alter_stream_events_requested_streamevent.py b/authentik/enterprise/providers/ssf/migrations/0005_alter_stream_events_requested_streamevent.py new file mode 100644 index 0000000000000..83e25c2867884 --- /dev/null +++ b/authentik/enterprise/providers/ssf/migrations/0005_alter_stream_events_requested_streamevent.py @@ -0,0 +1,79 @@ +# Generated by Django 5.0.10 on 2024-12-11 18:33 + +import django.contrib.postgres.fields +import django.db.models.deletion +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_providers_ssf", "0004_stream_aud_alter_stream_events_requested"), + ] + + operations = [ + migrations.AlterField( + model_name="stream", + name="events_requested", + field=django.contrib.postgres.fields.ArrayField( + base_field=models.TextField( + choices=[ + ( + "https://schemas.openid.net/secevent/caep/event-type/session-revoked", + "Caep Session Revoked", + ), + ( + "https://schemas.openid.net/secevent/caep/event-type/credential-change", + "Caep Credential Change", + ), + ( + "https://schemas.openid.net/secevent/ssf/event-type/verification", + "Set Verification", + ), + ] + ), + default=list, + size=None, + ), + ), + migrations.CreateModel( + name="StreamEvent", + fields=[ + ( + "uuid", + models.UUIDField( + default=uuid.uuid4, editable=False, primary_key=True, serialize=False + ), + ), + ("status", models.TextField(choices=[("pending", "Pending"), ("sent", "Sent")])), + ( + "type", + models.TextField( + choices=[ + ( + "https://schemas.openid.net/secevent/caep/event-type/session-revoked", + "Caep Session Revoked", + ), + ( + "https://schemas.openid.net/secevent/caep/event-type/credential-change", + "Caep Credential Change", + ), + ( + "https://schemas.openid.net/secevent/ssf/event-type/verification", + "Set Verification", + ), + ] + ), + ), + ("payload", models.JSONField(default=dict)), + ( + "stream", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="authentik_providers_ssf.stream", + ), + ), + ], + ), + ] diff --git a/authentik/enterprise/providers/ssf/models.py b/authentik/enterprise/providers/ssf/models.py index 501f158d7ce33..d2c8a8600fa71 100644 --- a/authentik/enterprise/providers/ssf/models.py +++ b/authentik/enterprise/providers/ssf/models.py @@ -1,3 +1,4 @@ +from datetime import datetime from functools import cached_property from uuid import uuid4 @@ -6,7 +7,9 @@ from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes from django.contrib.postgres.fields import ArrayField from django.db import models +from django.http import HttpRequest from django.templatetags.static import static +from django.urls import reverse from django.utils.translation import gettext_lazy as _ from jwt import encode @@ -30,6 +33,13 @@ class DeliveryMethods(models.TextChoices): RISC_POLL = "https://schemas.openid.net/secevent/risc/delivery-method/poll" +class SSFEventStatus(models.TextChoices): + """SSF Event status""" + + PENDING = "pending" + SENT = "sent" + + class SSFProvider(BackchannelProvider): """Shared Signals Framework""" @@ -102,6 +112,35 @@ class Stream(models.Model): def __str__(self) -> str: return "SSF Stream" + def new_event( + self, type: EventTypes, request: HttpRequest, event_data: dict, **kwargs + ) -> "StreamEvent": + """Create a new SSF event""" + jti = uuid4() + evt = StreamEvent( + uuid=jti, + stream=self, + type=type, + payload={ + "jti": jti.hex, + "aud": self.aud, + "iat": int(datetime.now().timestamp()), + "iss": request.build_absolute_uri( + reverse( + "authentik_providers_ssf:configuration", + kwargs={ + "application_slug": self.provider.application.slug, + "provider": self.provider.pk, + }, + ) + ), + "events": {type: event_data}, + **kwargs, + }, + ) + evt.save() + return evt + def encode(self, data: dict) -> str: headers = {} if self.provider.signing_key: @@ -117,6 +156,23 @@ class UserStreamSubject(models.Model): def __str__(self) -> str: return f"Stream subject {self.stream_id} to {self.user_id}" + class StreamEvent(models.Model): + """Single stream event to be sent""" uuid = models.UUIDField(default=uuid4, primary_key=True, editable=False) + + stream = models.ForeignKey(Stream, on_delete=models.CASCADE) + status = models.TextField(choices=SSFEventStatus.choices) + + type = models.TextField(choices=EventTypes.choices) + payload = models.JSONField(default=dict) + + def __str__(self): + return f"Stream event {self.type}" + + def queue(self): + """Queue event to be sent""" + from authentik.enterprise.providers.ssf.tasks import send_single_ssf_event + + return send_single_ssf_event.delay(str(self.stream.uuid), str(self.uuid)) diff --git a/authentik/enterprise/providers/ssf/signals.py b/authentik/enterprise/providers/ssf/signals.py index 5132f7984cc71..73c054743953f 100644 --- a/authentik/enterprise/providers/ssf/signals.py +++ b/authentik/enterprise/providers/ssf/signals.py @@ -1,6 +1,3 @@ -from datetime import datetime -from uuid import uuid4 - from django.contrib.auth.signals import user_logged_out from django.db.models import Model from django.db.models.signals import post_save @@ -17,9 +14,8 @@ from authentik.enterprise.providers.ssf.models import ( EventTypes, SSFProvider, - Stream, ) -from authentik.enterprise.providers.ssf.tasks import send_single_ssf_event, send_ssf_event +from authentik.enterprise.providers.ssf.tasks import send_ssf_event from authentik.events.middleware import audit_ignore from authentik.events.utils import get_user @@ -53,29 +49,6 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created: instance.save() -@receiver(post_save, sender=Stream) -def ssf_stream_post_create(sender: type[Model], instance: Stream, created: bool, **_): - """Send a verification event when a stream is created""" - if not created: - return - send_single_ssf_event.delay( - str(instance.uuid), - { - "jti": uuid4().hex, - # TODO: Figure out how to get iss - "iss": "https://ak.beryju.dev/.well-known/ssf-configuration/abm-ssf/8", - "aud": instance.aud, - "iat": int(datetime.now().timestamp()), - "sub_id": {"format": "opaque", "id": str(instance.uuid)}, - "events": { - "https://schemas.openid.net/secevent/ssf/event-type/verification": { - "state": None, - } - }, - }, - ) - - @receiver(user_logged_out) def user_logged_out_session(sender, request: HttpRequest, user: User, **_): send_ssf_event.delay( diff --git a/authentik/enterprise/providers/ssf/tasks.py b/authentik/enterprise/providers/ssf/tasks.py index 07cdb4c06255a..30938eb4c00b6 100644 --- a/authentik/enterprise/providers/ssf/tasks.py +++ b/authentik/enterprise/providers/ssf/tasks.py @@ -1,7 +1,13 @@ from celery import group from requests.exceptions import RequestException -from authentik.enterprise.providers.ssf.models import DeliveryMethods, EventTypes, Stream +from authentik.enterprise.providers.ssf.models import ( + DeliveryMethods, + EventTypes, + SSFEventStatus, + Stream, + StreamEvent, +) from authentik.lib.utils.http import get_http_session from authentik.root.celery import CELERY_APP @@ -12,28 +18,34 @@ def send_ssf_event(event_type: EventTypes, data: dict): tasks = [] for stream in Stream.objects.filter(events_requested__in=[event_type]): - tasks.append(send_single_ssf_event.si(str(stream.uuid), data)) + event = stream.new_event( + type=event_type, + ) + tasks.append(send_single_ssf_event.si(str(stream.uuid), str(event.id))) main_task = group(*tasks) main_task() @CELERY_APP.task(bind=True, autoretry=True, autoretry_for=(RequestException,), retry_backoff=True) -def send_single_ssf_event(self, stream_id: str, data: dict): +def send_single_ssf_event(self, stream_id: str, evt_id: str): stream = Stream.objects.filter(pk=stream_id).first() if not stream: return + event = StreamEvent.objects.filter(pk=evt_id).first() + if not event: + return + if event.status == SSFEventStatus.SENT: + return if stream.delivery_method == DeliveryMethods.RISC_PUSH: - ssf_push_request.delay(stream_id, data) + ssf_push_request(stream_id, event) + event.status = SSFEventStatus.SENT + event.save() -@CELERY_APP.task(bind=True, autoretry=True, autoretry_for=(RequestException,), retry_backoff=True) -def ssf_push_request(self, stream_id: str, data: dict): - stream = Stream.objects.filter(pk=stream_id).first() - if not stream: - return +def ssf_push_request(event: StreamEvent): response = session.post( - stream.endpoint_url, - data=stream.encode(data), + event.stream.endpoint_url, + data=event.stream.encode(event.data), headers={"Content-Type": "application/secevent+jwt", "Accept": "application/json"}, ) response.raise_for_status() diff --git a/authentik/enterprise/providers/ssf/views/stream.py b/authentik/enterprise/providers/ssf/views/stream.py index c5aef348af1a5..cf292d1818ebf 100644 --- a/authentik/enterprise/providers/ssf/views/stream.py +++ b/authentik/enterprise/providers/ssf/views/stream.py @@ -18,7 +18,6 @@ class StreamDeliverySerializer(PassiveSerializer): class StreamSerializer(ModelSerializer): - delivery = StreamDeliverySerializer() events_requested = ListField( child=ChoiceField(choices=[(x.value, x.value) for x in EventTypes]) @@ -49,7 +48,6 @@ class Meta: class StreamResponseSerializer(PassiveSerializer): - stream_id = CharField(source="pk") iss = SerializerMethodField() aud = ListField(child=CharField()) @@ -88,7 +86,15 @@ class StreamView(SSFView): def post(self, request: Request, *args, **kwargs) -> Response: stream = StreamSerializer(data=request.data) stream.is_valid(raise_exception=True) - instance = stream.save(provider=self.provider) + instance: Stream = stream.save(provider=self.provider) + instance.new_event( + EventTypes.SET_VERIFICATION, + request, + { + "state": None, + }, + sub_id={"format": "opaque", "id": str(instance.uuid)}, + ).queue() response = StreamResponseSerializer(instance=instance, context={"request": request}).data return Response(response, status=201)