Skip to content

Commit

Permalink
Shield OIDC: Add session integration
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielleHuisman committed Dec 25, 2024
1 parent 4c60bd2 commit f5ed668
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 21 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion packages/core/shield/src/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
error::ShieldError,
form::Form,
provider::{Provider, Subprovider},
request::{SignInRequest, SignOutRequest},
request::{SignInCallbackRequest, SignInRequest, SignOutRequest},
response::Response,
session::Session,
storage::Storage,
Expand Down Expand Up @@ -48,6 +48,14 @@ impl Provider for DummyProvider {
todo!("redirect back?")
}

async fn sign_in_callback(
&self,
_request: SignInCallbackRequest,
_session: Session,
) -> Result<Response, ShieldError> {
todo!("redirect back?")
}

async fn sign_out(
&self,
_request: SignOutRequest,
Expand Down
6 changes: 6 additions & 0 deletions packages/core/shield/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,11 @@ pub enum ShieldError {
#[error(transparent)]
Configuration(#[from] ConfigurationError),
#[error(transparent)]
Session(#[from] SessionError),
#[error(transparent)]
Storage(#[from] StorageError),
#[error("{0}")]
Request(String),
#[error("{0}")]
Verification(String),
}
18 changes: 16 additions & 2 deletions packages/core/shield/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use crate::{
error::ShieldError,
form::Form,
request::{SignInRequest, SignOutRequest},
request::{SignInCallbackRequest, SignInRequest, SignOutRequest},
response::Response,
session::Session,
};
Expand All @@ -26,6 +26,12 @@ pub trait Provider: Send + Sync {
session: Session,
) -> Result<Response, ShieldError>;

async fn sign_in_callback(
&self,
request: SignInCallbackRequest,
session: Session,
) -> Result<Response, ShieldError>;

async fn sign_out(
&self,
request: SignOutRequest,
Expand Down Expand Up @@ -57,7 +63,7 @@ pub(crate) mod tests {

use crate::{
error::ShieldError,
request::{SignInRequest, SignOutRequest},
request::{SignInCallbackRequest, SignInRequest, SignOutRequest},
response::Response,
session::Session,
};
Expand Down Expand Up @@ -103,6 +109,14 @@ pub(crate) mod tests {
todo!("redirect back?")
}

async fn sign_in_callback(
&self,
_request: SignInCallbackRequest,
_session: Session,
) -> Result<Response, ShieldError> {
todo!("redirect back?")
}

async fn sign_out(
&self,
_request: SignOutRequest,
Expand Down
7 changes: 7 additions & 0 deletions packages/core/shield/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ pub struct SignInRequest {
pub form_data: Option<Value>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SignInCallbackRequest {
pub provider_id: String,
pub subprovider_id: Option<String>,
pub data: Option<Value>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SignOutRequest {
pub provider_id: String,
Expand Down
5 changes: 5 additions & 0 deletions packages/core/shield/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,9 @@ impl Session {
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SessionData {
pub user_id: Option<String>,

// TODO: allow arbitrary data to be stored by providers?
pub csrf: Option<String>,
pub nonce: Option<String>,
pub verifier: Option<String>,
}
23 changes: 21 additions & 2 deletions packages/core/shield/src/shield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use futures::future::try_join_all;
use crate::{
error::{ProviderError, ShieldError},
provider::{Provider, Subprovider, SubproviderVisualisation},
request::{SignInRequest, SignOutRequest},
request::{SignInCallbackRequest, SignInRequest, SignOutRequest},
response::Response,
session::Session,
storage::Storage,
Expand Down Expand Up @@ -90,6 +90,21 @@ impl Shield {
provider.sign_in(request, session).await
}

pub async fn sign_in_callback(
&self,
request: SignInCallbackRequest,
session: Session,
) -> Result<Response, ShieldError> {
println!("sign in callback {:?}", request);

let provider = match self.providers.get(&request.provider_id) {
Some(provider) => provider,
None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()),
};

provider.sign_in_callback(request, session).await
}

pub async fn sign_out(
&self,
request: SignOutRequest,
Expand All @@ -102,7 +117,11 @@ impl Shield {
None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()),
};

provider.sign_out(request, session).await
let response = provider.sign_out(request, session.clone()).await?;

session.purge().await?;

Ok(response)
}
}

Expand Down
15 changes: 10 additions & 5 deletions packages/integrations/shield-leptos/src/routes/sign_in.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use leptos::{either::Either, prelude::*};
use shield::{ServerIntegration, SubproviderVisualisation};
use shield::SubproviderVisualisation;

#[server]
pub async fn subproviders() -> Result<Vec<SubproviderVisualisation>, ServerFnError> {
use shield::Shield;
use std::sync::Arc;

let shield = expect_context::<Shield>();
use shield::ServerIntegration;

let server_integration = expect_context::<Arc<dyn ServerIntegration>>();
let shield = server_integration.extract_shield().await;

shield
.subprovider_visualisations()
Expand All @@ -18,9 +21,11 @@ pub async fn sign_in(
provider_id: String,
subprovider_id: Option<String>,
) -> Result<(), ServerFnError> {
use shield::{Response, ShieldError, SignInRequest};
use std::sync::Arc;

use shield::{Response, ServerIntegration, ShieldError, SignInRequest};

let server_integration = expect_context::<&dyn ServerIntegration>();
let server_integration = expect_context::<Arc<dyn ServerIntegration>>();
let shield = server_integration.extract_shield().await;
let session = server_integration.extract_session().await;

Expand Down
18 changes: 16 additions & 2 deletions packages/providers/shield-oauth/src/provider.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use async_trait::async_trait;
use shield::{
Provider, ProviderError, Response, Session, ShieldError, SignInRequest, SignOutRequest,
Subprovider,
Provider, ProviderError, Response, Session, ShieldError, SignInCallbackRequest, SignInRequest,
SignOutRequest, Subprovider,
};

use crate::{storage::OauthStorage, subprovider::OauthSubprovider};
Expand Down Expand Up @@ -84,6 +84,7 @@ impl Provider for OauthProvider {
.await
.map(|subprovider| Some(Box::new(subprovider) as Box<dyn Subprovider>))
}

async fn sign_in(
&self,
request: SignInRequest,
Expand All @@ -97,6 +98,19 @@ impl Provider for OauthProvider {
todo!("oauth sign in")
}

async fn sign_in_callback(
&self,
request: SignInCallbackRequest,
_session: Session,
) -> Result<Response, ShieldError> {
let _subprovider = match request.subprovider_id {
Some(subprovider_id) => self.oauth_subprovider_by_id(&subprovider_id).await?,
None => return Err(ProviderError::SubproviderMissing.into()),
};

todo!("oauth sign in callback")
}

async fn sign_out(
&self,
request: SignOutRequest,
Expand Down
1 change: 1 addition & 0 deletions packages/providers/shield-oidc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ version.workspace = true

[dependencies]
async-trait.workspace = true
oauth2 = {version = "4.4.2", features = ["pkce-plain"]}
openidconnect = "3.5.0"
shield = { path = "../../core/shield" }
95 changes: 86 additions & 9 deletions packages/providers/shield-oidc/src/provider.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use async_trait::async_trait;
use openidconnect::{
core::CoreAuthenticationFlow, reqwest::async_http_client, AccessToken, CsrfToken, Nonce, Scope,
core::CoreAuthenticationFlow, reqwest::async_http_client, AccessToken, AuthorizationCode,
CsrfToken, Nonce, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse,
};
use shield::{
ConfigurationError, Provider, ProviderError, Response, Session, ShieldError, SignInRequest,
SignOutRequest, Subprovider,
ConfigurationError, Provider, ProviderError, Response, Session, SessionError, ShieldError,
SignInCallbackRequest, SignInRequest, SignOutRequest, Subprovider,
};

use crate::{storage::OidcStorage, subprovider::OidcSubprovider};
use crate::{storage::OidcStorage, subprovider::OidcSubprovider, OidcProviderPkceCodeChallenge};

pub const OIDC_PROVIDER_ID: &str = "oidc";

Expand Down Expand Up @@ -91,7 +92,7 @@ impl Provider for OidcProvider {
async fn sign_in(
&self,
request: SignInRequest,
_session: Session,
session: Session,
) -> Result<Response, ShieldError> {
let subprovider = match request.subprovider_id {
Some(subprovider_id) => self.oidc_subprovider_by_id(&subprovider_id).await?,
Expand All @@ -106,21 +107,97 @@ impl Provider for OidcProvider {
Nonce::new_random,
);

// TODO: PKCE code challenge.
let pkce_code_challenge = match subprovider.pkce_code_challenge {
OidcProviderPkceCodeChallenge::None => None,
OidcProviderPkceCodeChallenge::Plain => Some(PkceCodeChallenge::new_random_plain()),
OidcProviderPkceCodeChallenge::S256 => Some(PkceCodeChallenge::new_random_sha256()),
};

if let Some((pkce_code_challenge, _)) = &pkce_code_challenge {
authorization_request =
authorization_request.set_pkce_challenge(pkce_code_challenge.clone());
}

if let Some(scopes) = subprovider.scopes {
authorization_request =
authorization_request.add_scopes(scopes.into_iter().map(Scope::new));
}

let (auth_url, _csrf_token, _nonce) = authorization_request.url();
let (auth_url, csrf_token, nonce) = authorization_request.url();

// TODO: Store CSRF and nonce in session.
// TODO: Redirect.
{
let session_data = session.data();
let mut session_data = session_data
.lock()
.map_err(|err| SessionError::Lock(err.to_string()))?;

session_data.csrf = Some(csrf_token.secret().clone());
session_data.nonce = Some(nonce.secret().clone());
session_data.verifier = pkce_code_challenge
.map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone());
}

Ok(Response::Redirect(auth_url.to_string()))
}

async fn sign_in_callback(
&self,
request: SignInCallbackRequest,
session: Session,
) -> Result<Response, ShieldError> {
let (pkce_verifier, nonce) = {
let session_data = session.data();
let session_data = session_data
.lock()
.map_err(|err| SessionError::Lock(err.to_string()))?;

(session_data.verifier.clone(), session_data.nonce.clone())
};

let subprovider = match request.subprovider_id {
Some(subprovider_id) => self.oidc_subprovider_by_id(&subprovider_id).await?,
None => return Err(ProviderError::SubproviderMissing.into()),
};

let client = subprovider.oidc_client().await?;

let authorization_code = "".to_owned();

let mut token_request = client.exchange_code(AuthorizationCode::new(authorization_code));

if let Some(pkce_verifier) = pkce_verifier {
token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier));
} else if subprovider.pkce_code_challenge != OidcProviderPkceCodeChallenge::None {
return Err(ShieldError::Verification(
"Missing PKCE verifier.".to_owned(),
));
}

let token_response = token_request
.request_async(async_http_client)
.await
.map_err(|err| ShieldError::Request(err.to_string()))?;

if let Some(id_token) = token_response.id_token() {
let claims =
id_token
.claims(
&client.id_token_verifier(),
&Nonce::new(nonce.ok_or_else(|| {
ShieldError::Verification("Missing nonce.".to_owned())
})?),
)
.map_err(|err| ShieldError::Verification(err.to_string()))?;

println!("{:?}", claims);
}

// let user_info = client.user_info(token_response.access_token(), None)

// TODO
Ok(Response::Redirect("/".to_owned()))
}

async fn sign_out(
&self,
request: SignOutRequest,
Expand Down

0 comments on commit f5ed668

Please sign in to comment.