diff --git a/examples/leptos-actix/src/home.rs b/examples/leptos-actix/src/home.rs index 1ff4e47..f304ac3 100644 --- a/examples/leptos-actix/src/home.rs +++ b/examples/leptos-actix/src/home.rs @@ -20,7 +20,7 @@ pub fn HomePage() -> impl IntoView { {move || Suspend::new(async move { match user.await { Ok(user) => Either::Left(match user { Some(user) => Either::Left(view! { - {user.id} +

User ID: {user.id}

diff --git a/examples/leptos-axum/src/home.rs b/examples/leptos-axum/src/home.rs index 3119eb6..1aeb617 100644 --- a/examples/leptos-axum/src/home.rs +++ b/examples/leptos-axum/src/home.rs @@ -20,7 +20,7 @@ pub fn HomePage() -> impl IntoView { {move || Suspend::new(async move { match user.await { Ok(user) => Either::Left(match user { Some(user) => Either::Left(view! { - {user.id} +

User ID: {user.id}

diff --git a/packages/core/shield/src/session.rs b/packages/core/shield/src/session.rs index 689aa7e..0a7dd3d 100644 --- a/packages/core/shield/src/session.rs +++ b/packages/core/shield/src/session.rs @@ -43,10 +43,18 @@ impl Session { #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct SessionData { - pub user_id: Option, + pub authentication: Option, - // TODO: allow arbitrary data to be stored by providers? + // TODO: Allow arbitrary data to be stored by providers? pub csrf: Option, pub nonce: Option, pub verifier: Option, + pub oidc_connection_id: Option, +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct Authentication { + pub provider_id: String, + pub subprovider_id: Option, + pub user_id: String, } diff --git a/packages/core/shield/src/shield.rs b/packages/core/shield/src/shield.rs index b547e34..5604aa9 100644 --- a/packages/core/shield/src/shield.rs +++ b/packages/core/shield/src/shield.rs @@ -4,7 +4,7 @@ use futures::future::try_join_all; use tracing::debug; use crate::{ - error::{ProviderError, ShieldError}, + error::{ProviderError, SessionError, ShieldError}, provider::{Provider, Subprovider, SubproviderVisualisation}, request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, response::Response, @@ -107,19 +107,39 @@ impl Shield { provider.sign_in_callback(request, session).await } - pub async fn sign_out( - &self, - request: SignOutRequest, - session: Session, - ) -> Result { - debug!("sign out {:?}", request); + pub async fn sign_out(&self, session: Session) -> Result { + debug!("sign out"); - let provider = match self.providers.get(&request.provider_id) { - Some(provider) => provider, - None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()), + let authenticated = { + let session_data = session.data(); + let session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.authentication.clone() }; - let response = provider.sign_out(request, session.clone()).await?; + let response = if let Some(authenticated) = authenticated { + let provider = match self.providers.get(&authenticated.provider_id) { + Some(provider) => provider, + None => { + return Err(ProviderError::ProviderNotFound(authenticated.provider_id).into()) + } + }; + + provider + .sign_out( + SignOutRequest { + provider_id: authenticated.provider_id, + subprovider_id: authenticated.subprovider_id, + }, + session.clone(), + ) + .await? + } else { + // TODO: Should be configurable. + Response::Redirect("/".to_owned()) + }; session.purge().await?; diff --git a/packages/core/shield/src/shield_dyn.rs b/packages/core/shield/src/shield_dyn.rs index dc8e9b6..49c4a0f 100644 --- a/packages/core/shield/src/shield_dyn.rs +++ b/packages/core/shield/src/shield_dyn.rs @@ -5,7 +5,7 @@ use async_trait::async_trait; use crate::{ error::ShieldError, provider::{Subprovider, SubproviderVisualisation}, - request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, + request::{SignInCallbackRequest, SignInRequest}, response::Response, session::Session, shield::Shield, @@ -32,11 +32,7 @@ pub trait DynShield: Send + Sync { session: Session, ) -> Result; - async fn sign_out( - &self, - request: SignOutRequest, - session: Session, - ) -> Result; + async fn sign_out(&self, session: Session) -> Result; } #[async_trait] @@ -67,12 +63,8 @@ impl DynShield for Shield { self.sign_in_callback(request, session).await } - async fn sign_out( - &self, - request: SignOutRequest, - session: Session, - ) -> Result { - self.sign_out(request, session).await + async fn sign_out(&self, session: Session) -> Result { + self.sign_out(session).await } } @@ -109,11 +101,7 @@ impl ShieldDyn { self.0.sign_in_callback(request, session).await } - pub async fn sign_out( - &self, - request: SignOutRequest, - session: Session, - ) -> Result { - self.0.sign_out(request, session).await + pub async fn sign_out(&self, session: Session) -> Result { + self.0.sign_out(session).await } } diff --git a/packages/integrations/shield-leptos/src/routes/sign_out.rs b/packages/integrations/shield-leptos/src/routes/sign_out.rs index a6eea0a..11ab8e3 100644 --- a/packages/integrations/shield-leptos/src/routes/sign_out.rs +++ b/packages/integrations/shield-leptos/src/routes/sign_out.rs @@ -1,11 +1,8 @@ use leptos::prelude::*; #[server] -pub async fn sign_out( - provider_id: String, - subprovider_id: Option, -) -> Result<(), ServerFnError> { - use shield::{Response, ShieldError, SignOutRequest}; +pub async fn sign_out() -> Result<(), ServerFnError> { + use shield::{Response, ShieldError}; use crate::context::expect_server_integration; @@ -14,13 +11,7 @@ pub async fn sign_out( let session = server_integration.extract_session().await; let response = shield - .sign_out( - SignOutRequest { - provider_id, - subprovider_id, - }, - session, - ) + .sign_out(session) .await .map_err(ServerFnError::::from)?; @@ -41,9 +32,6 @@ pub fn SignOut() -> impl IntoView {

"Sign out"

- // - // - } diff --git a/packages/integrations/shield-tower/src/service.rs b/packages/integrations/shield-tower/src/service.rs index b453c83..7b8a684 100644 --- a/packages/integrations/shield-tower/src/service.rs +++ b/packages/integrations/shield-tower/src/service.rs @@ -75,13 +75,15 @@ where }; let shield_session = Session::new(session_storage); - let user_id = match shield_session.data().lock() { - Ok(session) => session.user_id.clone(), + let authenticated = match shield_session.data().lock() { + Ok(session) => session.authentication.clone(), Err(_err) => return Ok(Self::internal_server_error()), }; - let user = if let Some(user_id) = user_id { - match shield.storage().user_by_id(&user_id).await { + let user = if let Some(authenticated) = authenticated { + // TODO: Verify provider and subprovider still exist. + + match shield.storage().user_by_id(&authenticated.user_id).await { Ok(user) => { if user.is_none() { if let Err(_err) = shield_session.purge().await { diff --git a/packages/providers/shield-oidc/src/provider.rs b/packages/providers/shield-oidc/src/provider.rs index b38aaf3..51f24d0 100644 --- a/packages/providers/shield-oidc/src/provider.rs +++ b/packages/providers/shield-oidc/src/provider.rs @@ -7,9 +7,9 @@ use openidconnect::{ PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, UserInfoClaims, }; use shield::{ - ConfigurationError, CreateEmailAddress, CreateUser, Provider, ProviderError, Response, Session, - SessionError, ShieldError, SignInCallbackRequest, SignInRequest, SignOutRequest, Subprovider, - UpdateUser, User, + Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Provider, ProviderError, + Response, Session, SessionError, ShieldError, SignInCallbackRequest, SignInRequest, + SignOutRequest, Subprovider, UpdateUser, User, }; use tracing::debug; @@ -341,7 +341,7 @@ impl Provider for OidcProvider { let connection = self .create_oidc_connection( - subprovider.id, + subprovider.id.clone(), user.id(), claims.subject().to_string(), token_response, @@ -352,6 +352,8 @@ impl Provider for OidcProvider { } }; + debug!("signed in {:?} {:?}", user.id(), connection); + session.renew().await?; { @@ -360,13 +362,20 @@ impl Provider for OidcProvider { .lock() .map_err(|err| SessionError::Lock(err.to_string()))?; - session_data.user_id = Some(user.id()); + session_data.csrf = None; + session_data.nonce = None; + session_data.verifier = None; + + session_data.authentication = Some(Authentication { + provider_id: self.id(), + subprovider_id: Some(subprovider.id), + user_id: user.id(), + }); + session_data.oidc_connection_id = Some(connection.id); } session.update().await?; - debug!("signed in {:?} {:?}", user.id(), connection); - // TODO: Should be configurable. Ok(Response::Redirect("/".to_owned())) } @@ -381,25 +390,37 @@ impl Provider for OidcProvider { None => return Err(ProviderError::SubproviderMissing.into()), }; - // TODO: find access token - let token = AccessToken::new("".to_owned()); - - let client = subprovider.oidc_client().await?; + let connection_id = { + let session_data = session.data(); + let session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; - let revocation_request = match client.revoke_token(token.into()) { - Ok(revocation_request) => Some(revocation_request), - Err(openidconnect::ConfigurationError::MissingUrl("revocation")) => None, - Err(err) => return Err(ConfigurationError::Invalid(err.to_string()).into()), + session_data.oidc_connection_id.clone() }; - if let Some(revocation_request) = revocation_request { - revocation_request - .request_async(async_http_client) - .await - .expect("TODO: revocation request error"); - } + if let Some(connection_id) = connection_id { + if let Some(connection) = self.storage.oidc_connection_by_id(&connection_id).await? { + debug!("revoking access token {:?}", connection.access_token); + + let token = AccessToken::new(connection.access_token); - session.purge().await?; + let client = subprovider.oidc_client().await?; + + let revocation_request = match client.revoke_token(token.into()) { + Ok(revocation_request) => Some(revocation_request), + Err(openidconnect::ConfigurationError::MissingUrl("revocation")) => None, + Err(err) => return Err(ConfigurationError::Invalid(err.to_string()).into()), + }; + + if let Some(revocation_request) = revocation_request { + revocation_request + .request_async(async_http_client) + .await + .map_err(|err| ShieldError::Request(err.to_string()))?; + } + } + } // TODO: Should be configurable. Ok(Response::Redirect("/".to_owned())) diff --git a/packages/providers/shield-oidc/src/storage.rs b/packages/providers/shield-oidc/src/storage.rs index 55bc921..a55e0b1 100644 --- a/packages/providers/shield-oidc/src/storage.rs +++ b/packages/providers/shield-oidc/src/storage.rs @@ -16,6 +16,11 @@ pub trait OidcStorage: Storage + Sync { subprovider_id: &str, ) -> Result, StorageError>; + async fn oidc_connection_by_id( + &self, + connection_id: &str, + ) -> Result, StorageError>; + async fn oidc_connection_by_identifier( &self, subprovider_id: &str, diff --git a/packages/storage/shield-memory/src/providers/oidc.rs b/packages/storage/shield-memory/src/providers/oidc.rs index 8d40e02..fcffd7d 100644 --- a/packages/storage/shield-memory/src/providers/oidc.rs +++ b/packages/storage/shield-memory/src/providers/oidc.rs @@ -27,6 +27,20 @@ impl OidcStorage for MemoryStorage { Ok(None) } + async fn oidc_connection_by_id( + &self, + connection_id: &str, + ) -> Result, StorageError> { + Ok(self + .oidc + .connections + .lock() + .map_err(|err| StorageError::Engine(err.to_string()))? + .iter() + .find(|connection| connection.id == connection_id) + .cloned()) + } + async fn oidc_connection_by_identifier( &self, subprovider_id: &str, diff --git a/packages/storage/shield-sea-orm/src/providers/oidc.rs b/packages/storage/shield-sea-orm/src/providers/oidc.rs index d915727..29499e4 100644 --- a/packages/storage/shield-sea-orm/src/providers/oidc.rs +++ b/packages/storage/shield-sea-orm/src/providers/oidc.rs @@ -40,6 +40,17 @@ impl OidcStorage for SeaOrmStorage { }) } + async fn oidc_connection_by_id( + &self, + connection_id: &str, + ) -> Result, StorageError> { + oidc_provider_connection::Entity::find_by_id(Self::parse_uuid(connection_id)?) + .one(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + .map(|connection| connection.map(OidcConnection::from)) + } + async fn oidc_connection_by_identifier( &self, subprovider_id: &str,