From c5461e9eee87cf5d290768536aa8f875a2e39a59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Thu, 26 Dec 2024 10:14:24 +0100 Subject: [PATCH] Shield OIDC: Add user info support --- examples/leptos-actix/src/app.rs | 15 +++++---- examples/leptos-axum/src/app.rs | 15 +++++---- examples/leptos-axum/src/main.rs | 2 +- packages/providers/shield-oidc/src/claims.rs | 31 +++++++++++++++++++ packages/providers/shield-oidc/src/lib.rs | 1 + .../providers/shield-oidc/src/provider.rs | 28 ++++++++++++----- 6 files changed, 68 insertions(+), 24 deletions(-) create mode 100644 packages/providers/shield-oidc/src/claims.rs diff --git a/examples/leptos-actix/src/app.rs b/examples/leptos-actix/src/app.rs index d0c8fb0..de5fecd 100644 --- a/examples/leptos-actix/src/app.rs +++ b/examples/leptos-actix/src/app.rs @@ -1,7 +1,7 @@ use leptos::prelude::*; use leptos_meta::{provide_meta_context, MetaTags, Title}; use leptos_router::{ - components::{Route, Router, Routes}, + components::{Route, Router, Routes, A}, path, }; use shield_leptos::routes::SignIn; @@ -29,11 +29,11 @@ pub fn App() -> impl IntoView { provide_meta_context(); view! { - + <Title text="Shield Leptos Actix Example"/> <Router> <main> - <Routes fallback=|| "Page not found.".into_view()> + <Routes fallback=|| "Not found.".into_view()> <Route path=path!("") view=HomePage/> <Route path=path!("/auth/sign-in") view=SignIn /> @@ -45,11 +45,10 @@ pub fn App() -> impl IntoView { #[component] fn HomePage() -> impl IntoView { - let count = RwSignal::new(0); - let on_click = move |_| *count.write() += 1; - view! { - <h1>"Welcome to Leptos!"</h1> - <button on:click=on_click>"Click Me: " {count}</button> + <h1>"Shield Leptos Actix Example"</h1> + <A href="/auth/sign-in"> + <button>"Sign in"</button> + </A> } } diff --git a/examples/leptos-axum/src/app.rs b/examples/leptos-axum/src/app.rs index d0c8fb0..23d707c 100644 --- a/examples/leptos-axum/src/app.rs +++ b/examples/leptos-axum/src/app.rs @@ -1,7 +1,7 @@ use leptos::prelude::*; use leptos_meta::{provide_meta_context, MetaTags, Title}; use leptos_router::{ - components::{Route, Router, Routes}, + components::{Route, Router, Routes, A}, path, }; use shield_leptos::routes::SignIn; @@ -29,11 +29,11 @@ pub fn App() -> impl IntoView { provide_meta_context(); view! { - <Title text="Welcome to Leptos"/> + <Title text="Shield Leptos Axum Example"/> <Router> <main> - <Routes fallback=|| "Page not found.".into_view()> + <Routes fallback=|| "Not found.".into_view()> <Route path=path!("") view=HomePage/> <Route path=path!("/auth/sign-in") view=SignIn /> @@ -45,11 +45,10 @@ pub fn App() -> impl IntoView { #[component] fn HomePage() -> impl IntoView { - let count = RwSignal::new(0); - let on_click = move |_| *count.write() += 1; - view! { - <h1>"Welcome to Leptos!"</h1> - <button on:click=on_click>"Click Me: " {count}</button> + <h1>"Shield Leptos Axum Example"</h1> + <A href="/auth/sign-in"> + <button>"Sign in"</button> + </A> } } diff --git a/examples/leptos-axum/src/main.rs b/examples/leptos-axum/src/main.rs index 5868b3e..f30f021 100644 --- a/examples/leptos-axum/src/main.rs +++ b/examples/leptos-axum/src/main.rs @@ -25,7 +25,7 @@ async fn main() { let session_store = MemoryStore::default(); let session_layer = SessionManagerLayer::new(session_store) .with_secure(false) - .with_expiry(Expiry::OnInactivity(Duration::hours(1))); + .with_expiry(Expiry::OnInactivity(Duration::minutes(10))); // Initialize Shield let shield = Shield::new( diff --git a/packages/providers/shield-oidc/src/claims.rs b/packages/providers/shield-oidc/src/claims.rs new file mode 100644 index 0000000..da4b3bc --- /dev/null +++ b/packages/providers/shield-oidc/src/claims.rs @@ -0,0 +1,31 @@ +use openidconnect::{ + core::CoreGenderClaim, EmptyAdditionalClaims, IdTokenClaims, SubjectIdentifier, UserInfoClaims, +}; + +/// Unified interface for [`IdTokenClaims`] and [`UserInfoClaims`]. +#[derive(Clone, Debug)] +pub enum Claims { + IdToken(IdTokenClaims<EmptyAdditionalClaims, CoreGenderClaim>), + UserInfo(UserInfoClaims<EmptyAdditionalClaims, CoreGenderClaim>), +} + +impl Claims { + pub fn subject(&self) -> &SubjectIdentifier { + match &self { + Claims::IdToken(id_token_claims) => id_token_claims.subject(), + Claims::UserInfo(user_info_claims) => user_info_claims.subject(), + } + } +} + +impl From<IdTokenClaims<EmptyAdditionalClaims, CoreGenderClaim>> for Claims { + fn from(value: IdTokenClaims<EmptyAdditionalClaims, CoreGenderClaim>) -> Self { + Self::IdToken(value) + } +} + +impl From<UserInfoClaims<EmptyAdditionalClaims, CoreGenderClaim>> for Claims { + fn from(value: UserInfoClaims<EmptyAdditionalClaims, CoreGenderClaim>) -> Self { + Self::UserInfo(value) + } +} diff --git a/packages/providers/shield-oidc/src/lib.rs b/packages/providers/shield-oidc/src/lib.rs index 854b659..c5e3afa 100644 --- a/packages/providers/shield-oidc/src/lib.rs +++ b/packages/providers/shield-oidc/src/lib.rs @@ -1,4 +1,5 @@ mod builders; +mod claims; mod provider; mod storage; mod subprovider; diff --git a/packages/providers/shield-oidc/src/provider.rs b/packages/providers/shield-oidc/src/provider.rs index 7691439..b0db70d 100644 --- a/packages/providers/shield-oidc/src/provider.rs +++ b/packages/providers/shield-oidc/src/provider.rs @@ -1,14 +1,19 @@ use async_trait::async_trait; use openidconnect::{ - core::CoreAuthenticationFlow, reqwest::async_http_client, AccessToken, AuthorizationCode, - CsrfToken, Nonce, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, + core::{CoreAuthenticationFlow, CoreGenderClaim}, + reqwest::async_http_client, + AccessToken, AuthorizationCode, CsrfToken, EmptyAdditionalClaims, Nonce, OAuth2TokenResponse, + PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, UserInfoClaims, }; use shield::{ ConfigurationError, Provider, ProviderError, Response, Session, SessionError, ShieldError, SignInCallbackRequest, SignInRequest, SignOutRequest, Subprovider, }; -use crate::{storage::OidcStorage, subprovider::OidcSubprovider, OidcProviderPkceCodeChallenge}; +use crate::{ + claims::Claims, storage::OidcStorage, subprovider::OidcSubprovider, + OidcProviderPkceCodeChallenge, +}; pub const OIDC_PROVIDER_ID: &str = "oidc"; @@ -201,7 +206,7 @@ impl Provider for OidcProvider { .await .map_err(|err| ShieldError::Request(err.to_string()))?; - if let Some(id_token) = token_response.id_token() { + let claims = if let Some(id_token) = token_response.id_token() { let claims = id_token .claims( @@ -212,10 +217,19 @@ impl Provider for OidcProvider { ) .map_err(|err| ShieldError::Verification(err.to_string()))?; - println!("{:?}", claims); - } + Claims::from(claims.clone()) + } else { + let claims: UserInfoClaims<EmptyAdditionalClaims, CoreGenderClaim> = client + .user_info(token_response.access_token().to_owned(), None) + .map_err(|err| ConfigurationError::Missing(err.to_string()))? + .request_async(async_http_client) + .await + .map_err(|err| ShieldError::Request(err.to_string()))?; + + Claims::from(claims) + }; - // let user_info = client.user_info(token_response.access_token(), None) + println!("{:?}\n{:?}", claims.subject(), claims); // TODO Ok(Response::Redirect("/".to_owned()))