Skip to content

Commit

Permalink
Shield: Improve error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielleHuisman committed Dec 15, 2024
1 parent b84024d commit 034e368
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 84 deletions.
30 changes: 22 additions & 8 deletions packages/core/shield/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,30 @@ pub struct SignInRequest {
}

#[derive(Debug, Error)]

pub enum SignInError {
pub enum ProviderError {
#[error("provider `{0}` not found")]
ProviderNotFound(String),
#[error("subprovider is missing")]
SubproviderMissing,
#[error("subprovider `{0}` not found")]
SubproviderNotFound(String),
}

#[derive(Debug, Error)]
pub enum ConfigurationError {
#[error("missing configuration: {0}")]
Missing(String),
#[error("invalid configuration: {0}")]
Invalid(String),
}

#[derive(Debug, Error)]

pub enum SignInError {
#[error(transparent)]
Provider(#[from] ProviderError),
#[error(transparent)]
Configuration(#[from] ConfigurationError),
#[error(transparent)]
Storage(#[from] StorageError),
}
Expand All @@ -34,12 +50,10 @@ pub struct SignOutRequest {
#[derive(Debug, Error)]

pub enum SignOutError {
#[error("provider `{0}` not found")]
ProviderNotFound(String),
#[error("subprovider is missing")]
SubproviderMissing,
#[error("subprovider `{0}` not found")]
SubproviderNotFound(String),
#[error(transparent)]
Provider(#[from] ProviderError),
#[error(transparent)]
Configuration(#[from] ConfigurationError),
#[error(transparent)]
Storage(#[from] StorageError),
}
8 changes: 4 additions & 4 deletions packages/core/shield/src/shield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use std::{collections::HashMap, sync::Arc};
use futures::future::try_join_all;

use crate::{
provider::{Provider, Subprovider},
request::{SignInRequest, SignOutRequest},
provider::{Provider, Subprovider, SubproviderVisualisation},
request::{SignInError, SignInRequest, SignOutError, SignOutRequest},
storage::{Storage, StorageError},
SignInError, SignOutError, SubproviderVisualisation,
ProviderError,
};

#[derive(Clone)]
Expand Down Expand Up @@ -76,7 +76,7 @@ impl Shield {
pub async fn sign_in(&self, request: SignInRequest) -> Result<(), SignInError> {
let provider = match self.providers.get(&request.provider_id) {
Some(provider) => provider,
None => return Err(SignInError::ProviderNotFound(request.provider_id)),
None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()),
};

// let subprovider = match request.subprovider_id {
Expand Down
61 changes: 36 additions & 25 deletions packages/providers/shield-oauth/src/provider.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use async_trait::async_trait;
use shield::{
Provider, SignInError, SignInRequest, SignOutError, SignOutRequest, StorageError, Subprovider,
Provider, ProviderError, SignInError, SignInRequest, SignOutError, SignOutRequest,
StorageError, Subprovider,
};

use crate::storage::OauthStorage;
Expand Down Expand Up @@ -30,6 +31,27 @@ impl OauthProvider {
self.subproviders = subproviders.into_iter().collect();
self
}

async fn oauth_subprovider_by_id(
&self,
subprovider_id: &str,
) -> Result<Option<OauthSubprovider>, StorageError> {
if let Some(subprovider) = self
.subproviders
.iter()
.find(|subprovider| subprovider.id == subprovider_id)
{
return Ok(Some(subprovider.clone()));
}

if let Some(storage) = &self.storage {
if let Some(subprovider) = storage.oauth_subprovider_by_id(subprovider_id).await? {
return Ok(Some(subprovider));
}
}

Ok(None)
}
}

#[async_trait]
Expand Down Expand Up @@ -58,42 +80,31 @@ impl Provider for OauthProvider {
&self,
subprovider_id: &str,
) -> Result<Option<Box<dyn Subprovider>>, StorageError> {
if let Some(subprovider) = self
.subproviders
.iter()
.find(|subprovider| subprovider.id == subprovider_id)
{
return Ok(Some(Box::new(subprovider.clone()) as Box<dyn Subprovider>));
}

if let Some(storage) = &self.storage {
if let Some(subprovider) = storage.oauth_subprovider_by_id(subprovider_id).await? {
return Ok(Some(Box::new(subprovider.clone()) as Box<dyn Subprovider>));
}
}

Ok(None)
self.oauth_subprovider_by_id(subprovider_id)
.await
.map(|subprovider| {
subprovider.map(|subprovider| Box::new(subprovider) as Box<dyn Subprovider>)
})
}

async fn sign_in(&self, request: SignInRequest) -> Result<(), SignInError> {
let _subprovider = match request.subprovider_id {
Some(subprovider_id) => match self.subprovider_by_id(&subprovider_id).await? {
Some(subprovider) => Some(subprovider),
None => return Err(SignInError::SubproviderNotFound(subprovider_id)),
Some(subprovider_id) => match self.oauth_subprovider_by_id(&subprovider_id).await? {
Some(subprovider) => subprovider,
None => return Err(ProviderError::SubproviderNotFound(subprovider_id).into()),
},
None => None,
None => return Err(ProviderError::SubproviderMissing.into()),
};

todo!()
}

async fn sign_out(&self, request: SignOutRequest) -> Result<(), SignOutError> {
let _subprovider = match request.subprovider_id {
Some(subprovider_id) => match self.subprovider_by_id(&subprovider_id).await? {
Some(subprovider) => Some(subprovider),
None => return Err(SignOutError::SubproviderNotFound(subprovider_id)),
Some(subprovider_id) => match self.oauth_subprovider_by_id(&subprovider_id).await? {
Some(subprovider) => subprovider,
None => return Err(ProviderError::SubproviderNotFound(subprovider_id).into()),
},
None => None,
None => return Err(ProviderError::SubproviderMissing.into()),
};

todo!()
Expand Down
139 changes: 92 additions & 47 deletions packages/providers/shield-oidc/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ use async_trait::async_trait;
use openidconnect::{
core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
reqwest::async_http_client,
AuthUrl, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, Scope, TokenUrl, UserInfoUrl,
AccessToken, AuthUrl, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, Scope, TokenUrl,
UserInfoUrl,
};
use shield::{
Provider, SignInError, SignInRequest, SignOutError, SignOutRequest, StorageError, Subprovider,
ConfigurationError, Provider, ProviderError, SignInError, SignInRequest, SignOutError,
SignOutRequest, StorageError, Subprovider,
};

use crate::{storage::OidcStorage, subprovider::OidcSubprovider};
Expand Down Expand Up @@ -56,6 +58,69 @@ impl OidcProvider {

Ok(None)
}

async fn oidc_client_for_subprovider(
subprovider: &OidcSubprovider,
) -> Result<CoreClient, ConfigurationError> {
let client = if let Some(discovery_url) = &subprovider.discovery_url {
let provider_metadata = CoreProviderMetadata::discover_async(
// TODO: Consider stripping `/.well-known/openid-configuration` so `openidconnect` doesn't error.
IssuerUrl::new(discovery_url.clone())
.map_err(|err| ConfigurationError::Invalid(err.to_string()))?,
async_http_client,
)
.await
.map_err(|err| ConfigurationError::Invalid(err.to_string()))?;

CoreClient::from_provider_metadata(
provider_metadata,
ClientId::new(subprovider.client_id.clone()),
subprovider.client_secret.clone().map(ClientSecret::new),
)
} else {
CoreClient::new(
ClientId::new(subprovider.client_id.clone()),
subprovider.client_secret.clone().map(ClientSecret::new),
IssuerUrl::new(
subprovider
.issuer_url
.clone()
.ok_or(ConfigurationError::Missing("issuer URL".to_owned()))?,
)
.map_err(|err| ConfigurationError::Invalid(err.to_string()))?,
subprovider
.authorization_url
.as_ref()
.ok_or(ConfigurationError::Missing("authorization URL".to_owned()))
.and_then(|authorization_url| {
AuthUrl::new(authorization_url.clone())
.map_err(|err| ConfigurationError::Invalid(err.to_string()))
})?,
match &subprovider.token_url {
Some(token_url) => Some(
TokenUrl::new(token_url.clone())
.map_err(|err| ConfigurationError::Invalid(err.to_string()))?,
),
None => None,
},
match &subprovider.user_info_url {
Some(user_info_url) => Some(
UserInfoUrl::new(user_info_url.clone())
.map_err(|err| ConfigurationError::Invalid(err.to_string()))?,
),
None => None,
},
subprovider
.json_web_key_set
.clone()
.ok_or(ConfigurationError::Missing("JSON Web Key Set".to_owned()))?,
)
};

// TODO: Common client options.

Ok(client)
}
}

#[async_trait]
Expand Down Expand Up @@ -95,48 +160,12 @@ impl Provider for OidcProvider {
let subprovider = match request.subprovider_id {
Some(subprovider_id) => match self.oidc_subprovider_by_id(&subprovider_id).await? {
Some(subprovider) => subprovider,
None => return Err(SignInError::SubproviderNotFound(subprovider_id)),
None => return Err(ProviderError::SubproviderNotFound(subprovider_id).into()),
},
None => return Err(SignInError::SubproviderMissing),
};

let client = if let Some(discovery_url) = subprovider.discovery_url {
let provider_metadata = CoreProviderMetadata::discover_async(
// TODO: Consider stripping `/.well-known/openid-configuration` so `openidconnect` doesn't error.
IssuerUrl::new(discovery_url).expect("TODO: issuer url error"),
async_http_client,
)
.await
.expect("TODO: provider metadata error");

CoreClient::from_provider_metadata(
provider_metadata,
ClientId::new(subprovider.client_id),
subprovider.client_secret.map(ClientSecret::new),
)
} else {
CoreClient::new(
ClientId::new(subprovider.client_id),
subprovider.client_secret.map(ClientSecret::new),
IssuerUrl::new(subprovider.issuer_url.expect("TODO: missing issuer url"))
.expect("TODO: issuer url error"),
subprovider
.authorization_url
.map(|authorization_url| {
AuthUrl::new(authorization_url).expect("TODO: auth url error")
})
.expect("TODO: missing authorization url"),
subprovider
.token_url
.map(|token_url| TokenUrl::new(token_url).expect("TODO: token url error")),
subprovider.user_info_url.map(|user_info_url| {
UserInfoUrl::new(user_info_url).expect("TODO: user info url error")
}),
subprovider.json_web_key_set.expect("TODO: missing jwks"),
)
None => return Err(ProviderError::SubproviderMissing.into()),
};

// TODO: Common client options.
let client = Self::oidc_client_for_subprovider(&subprovider).await?;

let mut authorization_request = client.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
Expand All @@ -160,14 +189,30 @@ impl Provider for OidcProvider {
}

async fn sign_out(&self, request: SignOutRequest) -> Result<(), SignOutError> {
let _subprovider = match request.subprovider_id {
Some(subprovider_id) => match self.subprovider_by_id(&subprovider_id).await? {
Some(subprovider) => Some(subprovider),
None => return Err(SignOutError::SubproviderNotFound(subprovider_id)),
let subprovider = match request.subprovider_id {
Some(subprovider_id) => match self.oidc_subprovider_by_id(&subprovider_id).await? {
Some(subprovider) => subprovider,
None => return Err(ProviderError::SubproviderNotFound(subprovider_id).into()),
},
None => return Err(SignOutError::SubproviderMissing),
None => return Err(ProviderError::SubproviderMissing.into()),
};

todo!("oidc sign out")
// TODO: find access token
let token = AccessToken::new("".to_owned());

let client = Self::oidc_client_for_subprovider(&subprovider).await?;

let revocation_request = match client.revoke_token(token.into()) {
Ok(revocation_request) => revocation_request,
Err(openidconnect::ConfigurationError::MissingUrl("revocation")) => return Ok(()),
Err(err) => return Err(ConfigurationError::Invalid(err.to_string()).into()),
};

revocation_request
.request_async(async_http_client)
.await
.expect("TODO: revocation request error");

Ok(())
}
}

0 comments on commit 034e368

Please sign in to comment.