From f162c2d920bc61e40af650bb3828602986d808e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Fri, 27 Dec 2024 14:06:38 +0100 Subject: [PATCH] Shield: Add user to storage --- packages/core/shield/src/error.rs | 4 +- packages/core/shield/src/storage.rs | 39 +++- packages/core/shield/src/user.rs | 35 ++- packages/providers/shield-oidc/src/claims.rs | 13 +- .../providers/shield-oidc/src/provider.rs | 216 +++++++++++++----- packages/storage/shield-memory/src/storage.rs | 18 +- .../shield-sea-orm/src/providers/oidc.rs | 19 +- .../storage/shield-sea-orm/src/storage.rs | 72 +++++- 8 files changed, 338 insertions(+), 78 deletions(-) diff --git a/packages/core/shield/src/error.rs b/packages/core/shield/src/error.rs index 8def4f0..cf56d0b 100644 --- a/packages/core/shield/src/error.rs +++ b/packages/core/shield/src/error.rs @@ -24,6 +24,8 @@ pub enum StorageError { Configuration(#[from] ConfigurationError), #[error("{0}")] Validation(String), + #[error("{0} with ID `{1}` not found.")] + NotFound(String, String), #[error("{0}")] Engine(String), } @@ -52,5 +54,5 @@ pub enum ShieldError { #[error("{0}")] Request(String), #[error("{0}")] - Verification(String), + Validation(String), } diff --git a/packages/core/shield/src/storage.rs b/packages/core/shield/src/storage.rs index 35a67be..bde7fe6 100644 --- a/packages/core/shield/src/storage.rs +++ b/packages/core/shield/src/storage.rs @@ -1,6 +1,9 @@ use async_trait::async_trait; -use crate::{error::StorageError, user::User}; +use crate::{ + error::StorageError, + user::{CreateEmailAddress, CreateUser, UpdateUser, User}, +}; #[async_trait] pub trait Storage: Send + Sync { @@ -9,13 +12,29 @@ pub trait Storage: Send + Sync { async fn user_by_id(&self, user_id: &str) -> Result, StorageError>; async fn user_by_email(&self, email: &str) -> Result, StorageError>; + + async fn create_user( + &self, + user: CreateUser, + email_address: CreateEmailAddress, + ) -> Result; + + async fn update_user(&self, user: UpdateUser) -> Result; + + async fn delete_user(&self, user_id: &str) -> Result<(), StorageError>; + + // TODO: create, update, delete email address } #[cfg(test)] pub(crate) mod tests { use async_trait::async_trait; - use crate::{error::StorageError, storage::Storage, user::tests::TestUser}; + use crate::{ + error::StorageError, + storage::Storage, + user::{tests::TestUser, CreateEmailAddress, CreateUser, UpdateUser}, + }; pub const TEST_STORAGE_ID: &str = "test"; @@ -35,5 +54,21 @@ pub(crate) mod tests { async fn user_by_email(&self, _email: &str) -> Result, StorageError> { todo!("user_by_email") } + + async fn create_user( + &self, + _user: CreateUser, + _email_address: CreateEmailAddress, + ) -> Result { + todo!("create_user") + } + + async fn update_user(&self, _user: UpdateUser) -> Result { + todo!("update_user") + } + + async fn delete_user(&self, _user_id: &str) -> Result<(), StorageError> { + todo!("delete_user") + } } } diff --git a/packages/core/shield/src/user.rs b/packages/core/shield/src/user.rs index 3c20fde..fa4385b 100644 --- a/packages/core/shield/src/user.rs +++ b/packages/core/shield/src/user.rs @@ -1,5 +1,20 @@ use chrono::{DateTime, Utc}; +pub trait User: Send + Sync { + fn id(&self) -> String; +} + +#[derive(Clone, Debug)] +pub struct CreateUser { + pub name: Option, +} + +#[derive(Clone, Debug)] +pub struct UpdateUser { + pub id: String, + pub name: Option>, +} + #[derive(Clone, Debug)] pub struct EmailAddress { pub id: String, @@ -12,8 +27,24 @@ pub struct EmailAddress { pub user_id: String, } -pub trait User: Send + Sync { - fn id(&self) -> String; +#[derive(Clone, Debug)] +pub struct CreateEmailAddress { + pub email: String, + pub is_primary: bool, + pub is_verified: bool, + pub verification_token: Option, + pub verification_token_expired_at: Option>, + pub verified_at: Option>, +} + +#[derive(Clone, Debug)] +pub struct UpdateEmailAddress { + pub id: String, + pub is_primary: Option, + pub is_verified: Option, + pub verification_token: Option>, + pub verification_token_expired_at: Option>>, + pub verified_at: Option>>, } #[cfg(test)] diff --git a/packages/providers/shield-oidc/src/claims.rs b/packages/providers/shield-oidc/src/claims.rs index d1d940a..0207a33 100644 --- a/packages/providers/shield-oidc/src/claims.rs +++ b/packages/providers/shield-oidc/src/claims.rs @@ -1,6 +1,6 @@ use openidconnect::{ - core::CoreGenderClaim, EmptyAdditionalClaims, EndUserEmail, IdTokenClaims, SubjectIdentifier, - UserInfoClaims, + core::CoreGenderClaim, EmptyAdditionalClaims, EndUserEmail, EndUserName, IdTokenClaims, + LocalizedClaim, SubjectIdentifier, UserInfoClaims, }; /// Unified interface for [`IdTokenClaims`] and [`UserInfoClaims`]. @@ -18,14 +18,19 @@ impl Claims { } } - // TODO: Remove allow dead code. - #[allow(dead_code)] pub fn email(&self) -> Option<&EndUserEmail> { match &self { Claims::IdToken(id_token_claims) => id_token_claims.email(), Claims::UserInfo(user_info_claims) => user_info_claims.email(), } } + + pub fn name(&self) -> Option<&LocalizedClaim> { + match &self { + Claims::IdToken(id_token_claims) => id_token_claims.name(), + Claims::UserInfo(user_info_claims) => user_info_claims.name(), + } + } } impl From> for Claims { diff --git a/packages/providers/shield-oidc/src/provider.rs b/packages/providers/shield-oidc/src/provider.rs index 33dd027..3f3445c 100644 --- a/packages/providers/shield-oidc/src/provider.rs +++ b/packages/providers/shield-oidc/src/provider.rs @@ -1,18 +1,20 @@ use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; use openidconnect::{ - core::{CoreAuthenticationFlow, CoreGenderClaim}, + core::{CoreAuthenticationFlow, CoreGenderClaim, CoreTokenResponse}, reqwest::async_http_client, AccessToken, AuthorizationCode, CsrfToken, EmptyAdditionalClaims, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, UserInfoClaims, }; use shield::{ - ConfigurationError, Provider, ProviderError, Response, Session, SessionError, ShieldError, - SignInCallbackRequest, SignInRequest, SignOutRequest, Subprovider, User, + ConfigurationError, CreateEmailAddress, CreateUser, Provider, ProviderError, Response, Session, + SessionError, ShieldError, SignInCallbackRequest, SignInRequest, SignOutRequest, Subprovider, + UpdateUser, User, }; use crate::{ claims::Claims, storage::OidcStorage, subprovider::OidcSubprovider, CreateOidcConnection, - OidcProviderPkceCodeChallenge, + OidcConnection, OidcProviderPkceCodeChallenge, UpdateOidcConnection, }; pub const OIDC_PROVIDER_ID: &str = "oidc"; @@ -57,8 +59,103 @@ impl OidcProvider { Err(ProviderError::SubproviderNotFound(subprovider_id.to_owned()).into()) } - async fn get_or_create_user(&self, _claims: &Claims) -> Result { - todo!("get_or_create_user") + async fn create_user(&self, claims: &Claims) -> Result { + if let Some(email) = claims.email() { + match self.storage.user_by_email(email).await? { + Some(_) => Err(ShieldError::Validation( + "\ + Email address `{email}` is already used by another account. \ + To link a new provider, sign in to with your exising account first. \ + If this is not your account, please contact support for assistence.\ + " + .to_owned(), + )), + None => Ok(self + .storage + .create_user( + CreateUser { + name: claims + .name() + .and_then(|name| name.get(None).map(|name| name.to_string())), + }, + CreateEmailAddress { + email: email.to_string(), + is_primary: true, + // TODO: from claim? + is_verified: false, + // TODO: generate if not verified + verification_token: None, + verification_token_expired_at: None, + verified_at: None, + }, + ) + .await?), + } + } else { + Err(ShieldError::Validation( + "Missing email address in OpenID Connect claims.".to_owned(), + )) + } + } + + async fn update_user(&self, user_id: &str, claims: &Claims) -> Result { + self.storage + .update_user(UpdateUser { + id: user_id.to_owned(), + name: claims + .name() + .map(|name| name.get(None).map(|name| name.to_string())), + }) + .await + .map_err(ShieldError::Storage) + } + + async fn create_oidc_connection( + &self, + subprovider_id: String, + user_id: String, + identifier: String, + token_response: CoreTokenResponse, + ) -> Result { + let (token_type, access_token, refresh_token, id_token, expired_at, scopes) = + parse_token_response(token_response)?; + + self.storage + .create_oidc_connection(CreateOidcConnection { + identifier, + token_type, + access_token, + refresh_token, + id_token, + expired_at, + scopes, + subprovider_id, + user_id, + }) + .await + .map_err(ShieldError::Storage) + } + + async fn update_oidc_connection( + &self, + connection_id: String, + token_response: CoreTokenResponse, + ) -> Result { + let (token_type, access_token, refresh_token, id_token, expired_at, scopes) = + parse_token_response(token_response)?; + + self.storage + .update_oidc_connection(UpdateOidcConnection { + id: connection_id, + token_type: Some(token_type), + access_token: Some(access_token), + refresh_token: Some(refresh_token), + id_token: Some(id_token), + expired_at: Some(expired_at), + scopes: Some(scopes), + }) + .await + .map_err(ShieldError::Storage) } } @@ -165,10 +262,10 @@ impl Provider for OidcProvider { .as_ref() .and_then(|query| query.get("state")) .and_then(|code| code.as_str()) - .ok_or_else(|| ShieldError::Verification("Missing state.".to_owned()))?; + .ok_or_else(|| ShieldError::Validation("Missing state.".to_owned()))?; if csrf.is_none_or(|csrf| csrf != state) { - return Err(ShieldError::Verification("Invalid state.".to_owned())); + return Err(ShieldError::Validation("Invalid state.".to_owned())); } let authorization_code = request @@ -176,7 +273,7 @@ impl Provider for OidcProvider { .as_ref() .and_then(|query| query.get("code")) .and_then(|code| code.as_str()) - .ok_or_else(|| ShieldError::Verification("Missing authorization code.".to_owned()))?; + .ok_or_else(|| ShieldError::Validation("Missing authorization code.".to_owned()))?; let subprovider = match request.subprovider_id { Some(subprovider_id) => self.oidc_subprovider_by_id(&subprovider_id).await?, @@ -191,9 +288,7 @@ impl Provider for OidcProvider { if let Some(pkce_verifier) = pkce_verifier { token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier)); } else if subprovider.pkce_code_challenge != OidcProviderPkceCodeChallenge::None { - return Err(ShieldError::Verification( - "Missing PKCE verifier.".to_owned(), - )); + return Err(ShieldError::Validation("Missing PKCE verifier.".to_owned())); } let token_response = token_request @@ -202,15 +297,15 @@ impl Provider for OidcProvider { .map_err(|err| ShieldError::Request(err.to_string()))?; let claims = if let Some(id_token) = token_response.id_token() { - let claims = - id_token - .claims( - &client.id_token_verifier(), - &Nonce::new(nonce.ok_or_else(|| { - ShieldError::Verification("Missing nonce.".to_owned()) - })?), - ) - .map_err(|err| ShieldError::Verification(err.to_string()))?; + let claims = id_token + .claims( + &client.id_token_verifier(), + &Nonce::new( + nonce + .ok_or_else(|| ShieldError::Validation("Missing nonce.".to_owned()))?, + ), + ) + .map_err(|err| ShieldError::Validation(err.to_string()))?; Claims::from(claims.clone()) } else { @@ -232,46 +327,24 @@ impl Provider for OidcProvider { .await? { Some(connection) => { - // TODO: update connection - - // TODO: get user + let connection = self + .update_oidc_connection(connection.id, token_response) + .await?; - let user = self - .storage - .user_by_id(&connection.user_id) - .await? - .expect("TODO"); + let user = self.update_user(&connection.user_id, &claims).await?; (connection, user) } None => { - // TODO: find or create user - let user = self.get_or_create_user(&claims).await?; + let user = self.create_user(&claims).await?; let connection = self - .storage - .create_oidc_connection(CreateOidcConnection { - identifier: claims.subject().to_string(), - // token_type: token_response.token_type(), - token_type: "".to_owned(), - access_token: token_response.access_token().secret().clone(), - refresh_token: token_response - .refresh_token() - .map(|refresh_token| refresh_token.secret().clone()), - id_token: token_response - .id_token() - .map(|id_token| id_token.to_string()), - // expired_at: token_response.expires_in().map(|expires_in| { - // Duration::from_std(expires_in) - // .map_err(|err| ShieldError::Verification(err.to_string())) - // }), - expired_at: None, - scopes: token_response - .scopes() - .map(|scopes| scopes.iter().map(|scope| scope.to_string()).collect()), - subprovider_id: subprovider.id, - user_id: user.id(), - }) + .create_oidc_connection( + subprovider.id, + user.id(), + claims.subject().to_string(), + token_response, + ) .await?; (connection, user) @@ -314,3 +387,38 @@ impl Provider for OidcProvider { Ok(Response::Redirect("/".to_owned())) } } + +type ParsedTokenResponse = ( + String, + String, + Option, + Option, + Option>, + Option>, +); + +fn parse_token_response( + token_response: CoreTokenResponse, +) -> Result { + Ok(( + token_response.token_type().as_ref().to_string(), + token_response.access_token().secret().clone(), + token_response + .refresh_token() + .map(|refresh_token| refresh_token.secret().clone()), + token_response + .id_token() + .map(|id_token| id_token.to_string()), + match token_response.expires_in() { + Some(expires_in) => Some( + Utc::now() + + Duration::from_std(expires_in) + .map_err(|err| ShieldError::Validation(err.to_string()))?, + ), + None => None, + }, + token_response + .scopes() + .map(|scopes| scopes.iter().map(|scope| scope.to_string()).collect()), + )) +} diff --git a/packages/storage/shield-memory/src/storage.rs b/packages/storage/shield-memory/src/storage.rs index b4e8e08..9734381 100644 --- a/packages/storage/shield-memory/src/storage.rs +++ b/packages/storage/shield-memory/src/storage.rs @@ -1,7 +1,7 @@ use std::sync::{Arc, Mutex}; use async_trait::async_trait; -use shield::{Storage, StorageError, User as _}; +use shield::{CreateEmailAddress, CreateUser, Storage, StorageError, UpdateUser, User as _}; use crate::user::User; @@ -49,4 +49,20 @@ impl Storage for MemoryStorage { }) .cloned()) } + + async fn create_user( + &self, + _user: CreateUser, + _email_address: CreateEmailAddress, + ) -> Result { + todo!("create_user") + } + + async fn update_user(&self, _user: UpdateUser) -> Result { + todo!("update_user") + } + + async fn delete_user(&self, _user_id: &str) -> Result<(), StorageError> { + todo!("delete_user") + } } diff --git a/packages/storage/shield-sea-orm/src/providers/oidc.rs b/packages/storage/shield-sea-orm/src/providers/oidc.rs index b594573..d915727 100644 --- a/packages/storage/shield-sea-orm/src/providers/oidc.rs +++ b/packages/storage/shield-sea-orm/src/providers/oidc.rs @@ -90,31 +90,26 @@ impl OidcStorage for SeaOrmStorage { .one(&self.database) .await .map_err(|err| StorageError::Engine(err.to_string()))? - .ok_or_else(|| { - StorageError::Validation(format!( - "OIDC connection `{}` not found.", - connection.id - )) - })? + .ok_or_else(|| StorageError::NotFound("OIDC Connection".to_owned(), connection.id))? .into(); if let Some(token_type) = connection.token_type { - active_model.token_type = ActiveValue::set(token_type); + active_model.token_type = ActiveValue::Set(token_type); } if let Some(access_token) = connection.access_token { - active_model.access_token = ActiveValue::set(access_token); + active_model.access_token = ActiveValue::Set(access_token); } if let Some(refresh_token) = connection.refresh_token { - active_model.refresh_token = ActiveValue::set(refresh_token); + active_model.refresh_token = ActiveValue::Set(refresh_token); } if let Some(id_token) = connection.id_token { - active_model.id_token = ActiveValue::set(id_token); + active_model.id_token = ActiveValue::Set(id_token); } if let Some(expired_at) = connection.expired_at { - active_model.expired_at = ActiveValue::set(expired_at); + active_model.expired_at = ActiveValue::Set(expired_at); } if let Some(scopes) = connection.scopes { - active_model.scopes = ActiveValue::set(scopes.map(|scopes| scopes.join(","))); + active_model.scopes = ActiveValue::Set(scopes.map(|scopes| scopes.join(","))); } active_model diff --git a/packages/storage/shield-sea-orm/src/storage.rs b/packages/storage/shield-sea-orm/src/storage.rs index 3d64a0a..35bb416 100644 --- a/packages/storage/shield-sea-orm/src/storage.rs +++ b/packages/storage/shield-sea-orm/src/storage.rs @@ -1,6 +1,9 @@ use async_trait::async_trait; -use sea_orm::{prelude::Uuid, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter}; -use shield::{Storage, StorageError}; +use sea_orm::{ + prelude::Uuid, ActiveModelTrait, ActiveValue, ColumnTrait, DatabaseConnection, EntityTrait, + QueryFilter, +}; +use shield::{CreateEmailAddress, CreateUser, Storage, StorageError, UpdateUser}; use crate::entities::{email_address, prelude::User, user}; @@ -48,4 +51,69 @@ impl Storage for SeaOrmStorage { .await .map_err(|err| StorageError::Engine(err.to_string())) } + + async fn create_user( + &self, + user: CreateUser, + email_address: CreateEmailAddress, + ) -> Result { + // TODO: transaction + + let active_model = user::ActiveModel { + name: ActiveValue::Set(user.name.unwrap_or_default()), + ..Default::default() + }; + + let user = active_model + .insert(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string()))?; + + let active_model = email_address::ActiveModel { + email: ActiveValue::Set(email_address.email), + is_primary: ActiveValue::Set(email_address.is_primary), + is_verified: ActiveValue::Set(email_address.is_verified), + verification_token: ActiveValue::Set(email_address.verification_token), + verification_token_expired_at: ActiveValue::Set( + email_address.verification_token_expired_at, + ), + verified_at: ActiveValue::Set(email_address.verified_at), + user_id: ActiveValue::Set(user.id), + ..Default::default() + }; + + active_model + .insert(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string()))?; + + Ok(user) + } + + async fn update_user(&self, user: UpdateUser) -> Result { + let mut active_model: user::ActiveModel = user::Entity::find() + .filter(user::Column::Id.eq(Self::parse_uuid(&user.id)?)) + .one(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string()))? + .ok_or_else(|| StorageError::NotFound("User".to_owned(), user.id))? + .into(); + + if let Some(Some(name)) = user.name { + active_model.name = ActiveValue::Set(name); + } + + active_model + .update(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + } + + async fn delete_user(&self, user_id: &str) -> Result<(), StorageError> { + user::Entity::delete_by_id(Self::parse_uuid(user_id)?) + .exec(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + .map(|_| ()) + } }