Skip to content

Commit

Permalink
Shield OIDC: Add support for params
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielleHuisman committed Dec 29, 2024
1 parent 9f439fb commit 8f263e1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion examples/leptos-axum/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async fn main() {
"client1",
)
.client_secret("xcpQsaGbRILTljPtX4npjmYMBjKrariJ")
.redirect_url(&format!(
.redirect_url(format!(
"http://localhost:{}/api/auth/sign-in/callback/oidc/keycloak",
addr.port()
))
Expand Down
31 changes: 31 additions & 0 deletions packages/providers/shield-oidc/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use chrono::{DateTime, Duration, Utc};
use openidconnect::{
core::{CoreAuthenticationFlow, CoreGenderClaim, CoreTokenResponse},
reqwest::async_http_client,
url::form_urlencoded::parse,
AccessToken, AuthorizationCode, CsrfToken, EmptyAdditionalClaims, Nonce, OAuth2TokenResponse,
PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, UserInfoClaims,
};
Expand Down Expand Up @@ -221,6 +222,15 @@ impl<U: User> Provider for OidcProvider<U> {
authorization_request.add_scopes(scopes.into_iter().map(Scope::new));
}

if let Some(authorization_url_params) = subprovider.authorization_url_params {
let params = parse(authorization_url_params.trim_start_matches('?').as_bytes());

for (name, value) in params {
authorization_request =
authorization_request.add_extra_param(name.into_owned(), value.into_owned());
}
}

let (auth_url, csrf_token, nonce) = authorization_request.url();

{
Expand Down Expand Up @@ -292,6 +302,15 @@ impl<U: User> Provider for OidcProvider<U> {
return Err(ShieldError::Validation("Missing PKCE verifier.".to_owned()));
}

if let Some(token_url_params) = subprovider.token_url_params {
let params = parse(token_url_params.trim_start_matches('?').as_bytes());

for (name, value) in params {
token_request =
token_request.add_extra_param(name.into_owned(), value.into_owned());
}
}

let token_response = token_request
.request_async(async_http_client)
.await
Expand Down Expand Up @@ -414,6 +433,18 @@ impl<U: User> Provider for OidcProvider<U> {
};

if let Some(revocation_request) = revocation_request {
let mut revocation_request = revocation_request;

if let Some(revocation_url_params) = subprovider.revocation_url_params {
let params =
parse(revocation_url_params.trim_start_matches('?').as_bytes());

for (name, value) in params {
revocation_request = revocation_request
.add_extra_param(name.into_owned(), value.into_owned());
}
}

revocation_request
.request_async(async_http_client)
.await
Expand Down

0 comments on commit 8f263e1

Please sign in to comment.