Skip to content

Commit

Permalink
Shield: Add sign out implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielleHuisman committed Dec 28, 2024
1 parent 8a59e76 commit 16fbf20
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 74 deletions.
2 changes: 1 addition & 1 deletion examples/leptos-actix/src/home.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub fn HomePage() -> impl IntoView {
{move || Suspend::new(async move { match user.await {
Ok(user) => Either::Left(match user {
Some(user) => Either::Left(view! {
{user.id}
<p><b>User ID:</b> {user.id}</p>

<A href="/auth/sign-out">
<button>"Sign out"</button>
Expand Down
2 changes: 1 addition & 1 deletion examples/leptos-axum/src/home.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub fn HomePage() -> impl IntoView {
{move || Suspend::new(async move { match user.await {
Ok(user) => Either::Left(match user {
Some(user) => Either::Left(view! {
{user.id}
<p><b>User ID:</b> {user.id}</p>

<A href="/auth/sign-out">
<button>"Sign out"</button>
Expand Down
12 changes: 10 additions & 2 deletions packages/core/shield/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,18 @@ impl Session {

#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SessionData {
pub user_id: Option<String>,
pub authentication: Option<Authentication>,

// TODO: allow arbitrary data to be stored by providers?
// TODO: Allow arbitrary data to be stored by providers?
pub csrf: Option<String>,
pub nonce: Option<String>,
pub verifier: Option<String>,
pub oidc_connection_id: Option<String>,
}

#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Authentication {
pub provider_id: String,
pub subprovider_id: Option<String>,
pub user_id: String,
}
42 changes: 31 additions & 11 deletions packages/core/shield/src/shield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use futures::future::try_join_all;
use tracing::debug;

use crate::{
error::{ProviderError, ShieldError},
error::{ProviderError, SessionError, ShieldError},
provider::{Provider, Subprovider, SubproviderVisualisation},
request::{SignInCallbackRequest, SignInRequest, SignOutRequest},
response::Response,
Expand Down Expand Up @@ -107,19 +107,39 @@ impl<U: User> Shield<U> {
provider.sign_in_callback(request, session).await
}

pub async fn sign_out(
&self,
request: SignOutRequest,
session: Session,
) -> Result<Response, ShieldError> {
debug!("sign out {:?}", request);
pub async fn sign_out(&self, session: Session) -> Result<Response, ShieldError> {
debug!("sign out");

let provider = match self.providers.get(&request.provider_id) {
Some(provider) => provider,
None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()),
let authenticated = {
let session_data = session.data();
let session_data = session_data
.lock()
.map_err(|err| SessionError::Lock(err.to_string()))?;

session_data.authentication.clone()
};

let response = provider.sign_out(request, session.clone()).await?;
let response = if let Some(authenticated) = authenticated {
let provider = match self.providers.get(&authenticated.provider_id) {
Some(provider) => provider,
None => {
return Err(ProviderError::ProviderNotFound(authenticated.provider_id).into())
}
};

provider
.sign_out(
SignOutRequest {
provider_id: authenticated.provider_id,
subprovider_id: authenticated.subprovider_id,
},
session.clone(),
)
.await?
} else {
// TODO: Should be configurable.
Response::Redirect("/".to_owned())
};

session.purge().await?;

Expand Down
24 changes: 6 additions & 18 deletions packages/core/shield/src/shield_dyn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use async_trait::async_trait;
use crate::{
error::ShieldError,
provider::{Subprovider, SubproviderVisualisation},
request::{SignInCallbackRequest, SignInRequest, SignOutRequest},
request::{SignInCallbackRequest, SignInRequest},
response::Response,
session::Session,
shield::Shield,
Expand All @@ -32,11 +32,7 @@ pub trait DynShield: Send + Sync {
session: Session,
) -> Result<Response, ShieldError>;

async fn sign_out(
&self,
request: SignOutRequest,
session: Session,
) -> Result<Response, ShieldError>;
async fn sign_out(&self, session: Session) -> Result<Response, ShieldError>;
}

#[async_trait]
Expand Down Expand Up @@ -67,12 +63,8 @@ impl<U: User> DynShield for Shield<U> {
self.sign_in_callback(request, session).await
}

async fn sign_out(
&self,
request: SignOutRequest,
session: Session,
) -> Result<Response, ShieldError> {
self.sign_out(request, session).await
async fn sign_out(&self, session: Session) -> Result<Response, ShieldError> {
self.sign_out(session).await
}
}

Expand Down Expand Up @@ -109,11 +101,7 @@ impl ShieldDyn {
self.0.sign_in_callback(request, session).await
}

pub async fn sign_out(
&self,
request: SignOutRequest,
session: Session,
) -> Result<Response, ShieldError> {
self.0.sign_out(request, session).await
pub async fn sign_out(&self, session: Session) -> Result<Response, ShieldError> {
self.0.sign_out(session).await
}
}
18 changes: 3 additions & 15 deletions packages/integrations/shield-leptos/src/routes/sign_out.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use leptos::prelude::*;

#[server]
pub async fn sign_out(
provider_id: String,
subprovider_id: Option<String>,
) -> Result<(), ServerFnError> {
use shield::{Response, ShieldError, SignOutRequest};
pub async fn sign_out() -> Result<(), ServerFnError> {
use shield::{Response, ShieldError};

use crate::context::expect_server_integration;

Expand All @@ -14,13 +11,7 @@ pub async fn sign_out(
let session = server_integration.extract_session().await;

let response = shield
.sign_out(
SignOutRequest {
provider_id,
subprovider_id,
},
session,
)
.sign_out(session)
.await
.map_err(ServerFnError::<ShieldError>::from)?;

Expand All @@ -41,9 +32,6 @@ pub fn SignOut() -> impl IntoView {
<h1>"Sign out"</h1>

<ActionForm action=sign_out>
// <input name="provider_id" type="hidden" value=subprovider.provider_id />
// <input name="subprovider_id" type="hidden" value=subprovider.subprovider_id />

<button type="submit">"Sign out"</button>
</ActionForm>
}
Expand Down
10 changes: 6 additions & 4 deletions packages/integrations/shield-tower/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,15 @@ where
};
let shield_session = Session::new(session_storage);

let user_id = match shield_session.data().lock() {
Ok(session) => session.user_id.clone(),
let authenticated = match shield_session.data().lock() {
Ok(session) => session.authentication.clone(),
Err(_err) => return Ok(Self::internal_server_error()),
};

let user = if let Some(user_id) = user_id {
match shield.storage().user_by_id(&user_id).await {
let user = if let Some(authenticated) = authenticated {
// TODO: Verify provider and subprovider still exist.

match shield.storage().user_by_id(&authenticated.user_id).await {
Ok(user) => {
if user.is_none() {
if let Err(_err) = shield_session.purge().await {
Expand Down
65 changes: 43 additions & 22 deletions packages/providers/shield-oidc/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use openidconnect::{
PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, UserInfoClaims,
};
use shield::{
ConfigurationError, CreateEmailAddress, CreateUser, Provider, ProviderError, Response, Session,
SessionError, ShieldError, SignInCallbackRequest, SignInRequest, SignOutRequest, Subprovider,
UpdateUser, User,
Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Provider, ProviderError,
Response, Session, SessionError, ShieldError, SignInCallbackRequest, SignInRequest,
SignOutRequest, Subprovider, UpdateUser, User,
};
use tracing::debug;

Expand Down Expand Up @@ -341,7 +341,7 @@ impl<U: User> Provider for OidcProvider<U> {

let connection = self
.create_oidc_connection(
subprovider.id,
subprovider.id.clone(),
user.id(),
claims.subject().to_string(),
token_response,
Expand All @@ -352,6 +352,8 @@ impl<U: User> Provider for OidcProvider<U> {
}
};

debug!("signed in {:?} {:?}", user.id(), connection);

session.renew().await?;

{
Expand All @@ -360,13 +362,20 @@ impl<U: User> Provider for OidcProvider<U> {
.lock()
.map_err(|err| SessionError::Lock(err.to_string()))?;

session_data.user_id = Some(user.id());
session_data.csrf = None;
session_data.nonce = None;
session_data.verifier = None;

session_data.authentication = Some(Authentication {
provider_id: self.id(),
subprovider_id: Some(subprovider.id),
user_id: user.id(),
});
session_data.oidc_connection_id = Some(connection.id);
}

session.update().await?;

debug!("signed in {:?} {:?}", user.id(), connection);

// TODO: Should be configurable.
Ok(Response::Redirect("/".to_owned()))
}
Expand All @@ -381,25 +390,37 @@ impl<U: User> Provider for OidcProvider<U> {
None => return Err(ProviderError::SubproviderMissing.into()),
};

// TODO: find access token
let token = AccessToken::new("".to_owned());

let client = subprovider.oidc_client().await?;
let connection_id = {
let session_data = session.data();
let session_data = session_data
.lock()
.map_err(|err| SessionError::Lock(err.to_string()))?;

let revocation_request = match client.revoke_token(token.into()) {
Ok(revocation_request) => Some(revocation_request),
Err(openidconnect::ConfigurationError::MissingUrl("revocation")) => None,
Err(err) => return Err(ConfigurationError::Invalid(err.to_string()).into()),
session_data.oidc_connection_id.clone()
};

if let Some(revocation_request) = revocation_request {
revocation_request
.request_async(async_http_client)
.await
.expect("TODO: revocation request error");
}
if let Some(connection_id) = connection_id {
if let Some(connection) = self.storage.oidc_connection_by_id(&connection_id).await? {
debug!("revoking access token {:?}", connection.access_token);

let token = AccessToken::new(connection.access_token);

session.purge().await?;
let client = subprovider.oidc_client().await?;

let revocation_request = match client.revoke_token(token.into()) {
Ok(revocation_request) => Some(revocation_request),
Err(openidconnect::ConfigurationError::MissingUrl("revocation")) => None,
Err(err) => return Err(ConfigurationError::Invalid(err.to_string()).into()),
};

if let Some(revocation_request) = revocation_request {
revocation_request
.request_async(async_http_client)
.await
.map_err(|err| ShieldError::Request(err.to_string()))?;
}
}
}

// TODO: Should be configurable.
Ok(Response::Redirect("/".to_owned()))
Expand Down
5 changes: 5 additions & 0 deletions packages/providers/shield-oidc/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ pub trait OidcStorage<U: User>: Storage<U> + Sync {
subprovider_id: &str,
) -> Result<Option<OidcSubprovider>, StorageError>;

async fn oidc_connection_by_id(
&self,
connection_id: &str,
) -> Result<Option<OidcConnection>, StorageError>;

async fn oidc_connection_by_identifier(
&self,
subprovider_id: &str,
Expand Down
14 changes: 14 additions & 0 deletions packages/storage/shield-memory/src/providers/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ impl OidcStorage<User> for MemoryStorage {
Ok(None)
}

async fn oidc_connection_by_id(
&self,
connection_id: &str,
) -> Result<Option<OidcConnection>, StorageError> {
Ok(self
.oidc
.connections
.lock()
.map_err(|err| StorageError::Engine(err.to_string()))?
.iter()
.find(|connection| connection.id == connection_id)
.cloned())
}

async fn oidc_connection_by_identifier(
&self,
subprovider_id: &str,
Expand Down
11 changes: 11 additions & 0 deletions packages/storage/shield-sea-orm/src/providers/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ impl OidcStorage<user::Model> for SeaOrmStorage {
})
}

async fn oidc_connection_by_id(
&self,
connection_id: &str,
) -> Result<Option<OidcConnection>, StorageError> {
oidc_provider_connection::Entity::find_by_id(Self::parse_uuid(connection_id)?)
.one(&self.database)
.await
.map_err(|err| StorageError::Engine(err.to_string()))
.map(|connection| connection.map(OidcConnection::from))
}

async fn oidc_connection_by_identifier(
&self,
subprovider_id: &str,
Expand Down

0 comments on commit 16fbf20

Please sign in to comment.