Skip to content

Commit

Permalink
Shield OIDC: Add support for subprovider slugs
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielleHuisman committed Dec 29, 2024
1 parent 3ee6b78 commit 7dbac7d
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 11 deletions.
16 changes: 10 additions & 6 deletions packages/providers/shield-oidc/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl<U: User> OidcProvider<U> {
self
}

async fn oidc_subprovider_by_id(
async fn oidc_subprovider_by_id_or_slug(
&self,
subprovider_id: &str,
) -> Result<OidcSubprovider, ShieldError> {
Expand All @@ -54,7 +54,11 @@ impl<U: User> OidcProvider<U> {
return Ok(subprovider.clone());
}

if let Some(subprovider) = self.storage.oidc_subprovider_by_id(subprovider_id).await? {
if let Some(subprovider) = self
.storage
.oidc_subprovider_by_id_or_slug(subprovider_id)
.await?
{
return Ok(subprovider);
}

Expand Down Expand Up @@ -183,7 +187,7 @@ impl<U: User> Provider for OidcProvider<U> {
&self,
subprovider_id: &str,
) -> Result<Option<Box<dyn Subprovider>>, ShieldError> {
self.oidc_subprovider_by_id(subprovider_id)
self.oidc_subprovider_by_id_or_slug(subprovider_id)
.await
.map(|subprovider| Some(Box::new(subprovider) as Box<dyn Subprovider>))
}
Expand All @@ -194,7 +198,7 @@ impl<U: User> Provider for OidcProvider<U> {
session: Session,
) -> Result<Response, ShieldError> {
let subprovider = match request.subprovider_id {
Some(subprovider_id) => self.oidc_subprovider_by_id(&subprovider_id).await?,
Some(subprovider_id) => self.oidc_subprovider_by_id_or_slug(&subprovider_id).await?,
None => return Err(ProviderError::SubproviderMissing.into()),
};

Expand Down Expand Up @@ -287,7 +291,7 @@ impl<U: User> Provider for OidcProvider<U> {
.ok_or_else(|| ShieldError::Validation("Missing authorization code.".to_owned()))?;

let subprovider = match request.subprovider_id {
Some(subprovider_id) => self.oidc_subprovider_by_id(&subprovider_id).await?,
Some(subprovider_id) => self.oidc_subprovider_by_id_or_slug(&subprovider_id).await?,
None => return Err(ProviderError::SubproviderMissing.into()),
};

Expand Down Expand Up @@ -405,7 +409,7 @@ impl<U: User> Provider for OidcProvider<U> {
session: Session,
) -> Result<Response, ShieldError> {
let subprovider = match request.subprovider_id {
Some(subprovider_id) => self.oidc_subprovider_by_id(&subprovider_id).await?,
Some(subprovider_id) => self.oidc_subprovider_by_id_or_slug(&subprovider_id).await?,
None => return Err(ProviderError::SubproviderMissing.into()),
};

Expand Down
2 changes: 1 addition & 1 deletion packages/providers/shield-oidc/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
pub trait OidcStorage<U: User>: Storage<U> + Sync {
async fn oidc_subproviders(&self) -> Result<Vec<OidcSubprovider>, StorageError>;

async fn oidc_subprovider_by_id(
async fn oidc_subprovider_by_id_or_slug(
&self,
subprovider_id: &str,
) -> Result<Option<OidcSubprovider>, StorageError>;
Expand Down
2 changes: 1 addition & 1 deletion packages/storage/shield-memory/src/providers/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl OidcStorage<User> for MemoryStorage {
Ok(vec![])
}

async fn oidc_subprovider_by_id(
async fn oidc_subprovider_by_id_or_slug(
&self,
_subprovider_id: &str,
) -> Result<Option<OidcSubprovider>, StorageError> {
Expand Down
11 changes: 8 additions & 3 deletions packages/storage/shield-sea-orm/src/providers/oidc.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use async_trait::async_trait;
use sea_orm::{ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, QueryFilter};
use sea_orm::{ActiveModelTrait, ActiveValue, ColumnTrait, Condition, EntityTrait, QueryFilter};
use shield::StorageError;
use shield_oidc::{
CreateOidcConnection, OidcConnection, OidcProviderPkceCodeChallenge, OidcProviderVisibility,
Expand Down Expand Up @@ -27,11 +27,16 @@ impl OidcStorage<User> for SeaOrmStorage {
})
}

async fn oidc_subprovider_by_id(
async fn oidc_subprovider_by_id_or_slug(
&self,
subprovider_id: &str,
) -> Result<Option<OidcSubprovider>, StorageError> {
oidc_provider::Entity::find_by_id(Self::parse_uuid(subprovider_id)?)
oidc_provider::Entity::find()
.filter(
Condition::any()
.add(oidc_provider::Column::Id.eq(Self::parse_uuid(subprovider_id)?))
.add(oidc_provider::Column::Slug.eq(subprovider_id.to_lowercase())),
)
.one(&self.database)
.await
.map_err(|err| StorageError::Engine(err.to_string()))
Expand Down

0 comments on commit 7dbac7d

Please sign in to comment.