diff --git a/siopv2/src/provider.rs b/siopv2/src/provider.rs index eca6a825..5ab8384a 100644 --- a/siopv2/src/provider.rs +++ b/siopv2/src/provider.rs @@ -33,22 +33,30 @@ where /// request by value. If the [`RequestUrl`] is a request by value, the request is decoded by the [`Subject`] of the [`Provider`]. /// If the request is valid, the request is returned. pub async fn validate_request(&self, request: RequestUrl) -> Result { - let request = match request { - RequestUrl::AuthorizationRequest(request) => *request, - RequestUrl::RequestUri { request_uri } => { - let client = reqwest::Client::new(); - let builder = client.get(request_uri); - let request_value = builder.send().await?.text().await?; - self.subject.decode(request_value).await? - } + let authorization_request = if let RequestUrl::Request(request) = request { + *request + } else { + let (request_object, client_id) = match request { + RequestUrl::RequestUri { request_uri, client_id } => { + let client = reqwest::Client::new(); + let builder = client.get(request_uri); + let request_value = builder.send().await?.text().await?; + (request_value, client_id) + } + RequestUrl::RequestObject { request, client_id } => (request, client_id), + _ => unreachable!(), + }; + let authorization_request: AuthorizationRequest = self.subject.decode(request_object).await?; + anyhow::ensure!(*authorization_request.client_id() == client_id, "Client id mismatch."); + authorization_request }; self.subject_syntax_types_supported().and_then(|supported| { - request.subject_syntax_types_supported().map_or_else( + authorization_request.subject_syntax_types_supported().map_or_else( || Err(anyhow!("No supported subject syntax types found.")), |supported_types| { supported_types.iter().find(|sst| supported.contains(sst)).map_or_else( || Err(anyhow!("Subject syntax type not supported.")), - |_| Ok(request.clone()), + |_| Ok(authorization_request.clone()), ) }, ) diff --git a/siopv2/src/relying_party.rs b/siopv2/src/relying_party.rs index 7e4873c4..78aa7223 100644 --- a/siopv2/src/relying_party.rs +++ b/siopv2/src/relying_party.rs @@ -146,6 +146,7 @@ mod tests { // Create a new RequestUrl which includes a `request_uri` pointing to the mock server's `request_uri` endpoint. let request_url = RequestUrl::builder() + .client_id("did:mock:1".to_string()) .request_uri(format!("{server_url}/request_uri")) .build() .unwrap(); diff --git a/siopv2/src/request/mod.rs b/siopv2/src/request/mod.rs index 5bce4536..892c7263 100644 --- a/siopv2/src/request/mod.rs +++ b/siopv2/src/request/mod.rs @@ -1,3 +1,4 @@ +use crate::token::id_token::RFC7519Claims; use crate::{claims::ClaimRequests, Registration, RequestUrlBuilder, Scope, StandardClaimsRequests}; use anyhow::{anyhow, Result}; use derive_more::Display; @@ -20,10 +21,11 @@ pub mod request_builder; /// /// // An example of a form-urlencoded request with only the `request_uri` parameter will be parsed as a /// // `RequestUrl::RequestUri` variant. -/// let request_url = RequestUrl::from_str("siopv2://idtoken?request_uri=https://example.com/request_uri").unwrap(); +/// let request_url = RequestUrl::from_str("siopv2://idtoken?client_id=did%3Aexample%3AEiDrihTRe0GMdc3K16kgJB3Xbl9Hb8oqVHjzm6ufHcYDGA&request_uri=https://example.com/request_uri").unwrap(); /// assert_eq!( /// request_url, /// RequestUrl::RequestUri { +/// client_id: "did:example:EiDrihTRe0GMdc3K16kgJB3Xbl9Hb8oqVHjzm6ufHcYDGA".to_string(), /// request_uri: "https://example.com/request_uri".to_string() /// } /// ); @@ -45,16 +47,17 @@ pub mod request_builder; /// ) /// .unwrap(); /// assert!(match request_url { -/// RequestUrl::AuthorizationRequest(_) => Ok(()), +/// RequestUrl::Request(_) => Ok(()), /// RequestUrl::RequestUri { .. } => Err(()), +/// RequestUrl::RequestObject { .. } => Err(()), /// }.is_ok()); /// ``` -#[derive(Deserialize, Debug, PartialEq, Clone, Serialize)] +#[derive(Deserialize, Debug, PartialEq, Serialize, Clone)] #[serde(untagged, deny_unknown_fields)] pub enum RequestUrl { - AuthorizationRequest(Box), - // TODO: Add client_id parameter. - RequestUri { request_uri: String }, + Request(Box), + RequestObject { client_id: String, request: String }, + RequestUri { client_id: String, request_uri: String }, } impl RequestUrl { @@ -68,8 +71,9 @@ impl TryInto for RequestUrl { fn try_into(self) -> Result { match self { - RequestUrl::AuthorizationRequest(request) => Ok(*request), - RequestUrl::RequestUri { .. } => Err(anyhow!("AuthorizationRequest is a request URI.")), + RequestUrl::Request(request) => Ok(*request), + RequestUrl::RequestUri { .. } => Err(anyhow!("Request is a request URI.")), + RequestUrl::RequestObject { .. } => Err(anyhow!("Request is a request object.")), } } } @@ -127,9 +131,12 @@ pub enum ResponseType { /// [`AuthorizationRequest`] is a request from a [crate::relying_party::RelyingParty] (RP) to a [crate::provider::Provider] (SIOP). #[allow(dead_code)] -#[derive(Debug, Getters, PartialEq, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Getters, PartialEq, Default, Serialize, Deserialize, Clone)] #[serde(deny_unknown_fields)] pub struct AuthorizationRequest { + #[serde(flatten)] + #[getset(get = "pub")] + pub(super) rfc7519_claims: RFC7519Claims, pub(crate) response_type: ResponseType, pub(crate) response_mode: Option, #[getset(get = "pub")] @@ -144,11 +151,6 @@ pub struct AuthorizationRequest { pub(crate) nonce: String, #[getset(get = "pub")] pub(crate) registration: Option, - pub(crate) iss: Option, - pub(crate) iat: Option, - pub(crate) exp: Option, - pub(crate) nbf: Option, - pub(crate) jti: Option, #[getset(get = "pub")] pub(crate) state: Option, } @@ -183,11 +185,12 @@ mod tests { #[test] fn test_valid_request_uri() { // A form urlencoded string with a `request_uri` parameter should deserialize into the `RequestUrl::RequestUri` variant. - let request_url = RequestUrl::from_str("siopv2://idtoken?request_uri=https://example.com/request_uri").unwrap(); + let request_url = RequestUrl::from_str("siopv2://idtoken?client_id=https%3A%2F%2Fclient.example.org%2Fcb&request_uri=https://example.com/request_uri").unwrap(); assert_eq!( request_url, RequestUrl::RequestUri { - request_uri: "https://example.com/request_uri".to_string() + client_id: "https://client.example.org/cb".to_string(), + request_uri: "https://example.com/request_uri".to_string(), } ); } @@ -212,7 +215,8 @@ mod tests { .unwrap(); assert_eq!( request_url.clone(), - RequestUrl::AuthorizationRequest(Box::new(AuthorizationRequest { + RequestUrl::Request(Box::new(AuthorizationRequest { + rfc7519_claims: RFC7519Claims::default(), response_type: ResponseType::IdToken, response_mode: Some("post".to_string()), client_id: "did:example:\ @@ -227,11 +231,6 @@ mod tests { .with_subject_syntax_types_supported(vec!["did:mock".to_string()]) .with_id_token_signing_alg_values_supported(vec!["EdDSA".to_string()]), ), - iss: None, - iat: None, - exp: None, - nbf: None, - jti: None, state: None, })) ); @@ -242,6 +241,22 @@ mod tests { ); } + #[test] + fn test_valid_request_object() { + // A form urlencoded string with a `request` parameter should deserialize into the `RequestUrl::RequestObject` variant. + let request_url = RequestUrl::from_str( + "siopv2://idtoken?client_id=https%3A%2F%2Fclient.example.org%2Fcb&request=eyJhb...lMGzw", + ) + .unwrap(); + assert_eq!( + request_url, + RequestUrl::RequestObject { + client_id: "https://client.example.org/cb".to_string(), + request: "eyJhb...lMGzw".to_string() + } + ); + } + #[test] fn test_invalid_request() { // A form urlencoded string with an otherwise valid request is invalid when the `request_uri` parameter is also diff --git a/siopv2/src/request/request_builder.rs b/siopv2/src/request/request_builder.rs index c6ac4bc1..c89da14c 100644 --- a/siopv2/src/request/request_builder.rs +++ b/siopv2/src/request/request_builder.rs @@ -1,6 +1,8 @@ use crate::{ + builder_fn, claims::ClaimRequests, request::{AuthorizationRequest, RequestUrl, ResponseType}, + token::id_token::RFC7519Claims, Registration, Scope, }; use anyhow::{anyhow, Result}; @@ -8,68 +10,61 @@ use is_empty::IsEmpty; #[derive(Default, IsEmpty)] pub struct RequestUrlBuilder { + rfc7519_claims: RFC7519Claims, + client_id: Option, + request: Option, request_uri: Option, response_type: Option, response_mode: Option, - client_id: Option, scope: Option, claims: Option>, redirect_uri: Option, nonce: Option, registration: Option, - iss: Option, - iat: Option, - exp: Option, - nbf: Option, - jti: Option, state: Option, } -macro_rules! builder_fn { - ($name:ident, $ty:ty) => { - pub fn $name(mut self, value: $ty) -> Self { - self.$name = Some(value); - self - } - }; -} - impl RequestUrlBuilder { pub fn new() -> Self { RequestUrlBuilder::default() } - pub fn build(&mut self) -> Result { - let request_uri = self.request_uri.take(); - match (request_uri, self.is_empty()) { - (Some(request_uri), true) => Ok(RequestUrl::RequestUri { request_uri }), - (None, _) => Ok(RequestUrl::AuthorizationRequest(Box::new(AuthorizationRequest { + pub fn build(mut self) -> Result { + match ( + self.client_id.take(), + self.request.take(), + self.request_uri.take(), + self.is_empty(), + ) { + (None, _, _, _) => Err(anyhow!("client_id parameter is required.")), + (Some(client_id), Some(request), None, true) => Ok(RequestUrl::RequestObject { client_id, request }), + (Some(client_id), None, Some(request_uri), true) => Ok(RequestUrl::RequestUri { client_id, request_uri }), + (Some(client_id), None, None, false) => Ok(RequestUrl::Request(Box::new(AuthorizationRequest { + rfc7519_claims: self.rfc7519_claims, + client_id, response_type: self .response_type .take() - .ok_or(anyhow!("response_type parameter is required."))?, + .ok_or_else(|| anyhow!("response_type parameter is required."))?, response_mode: self.response_mode.take(), - client_id: self - .client_id + scope: self + .scope .take() - .ok_or(anyhow!("client_id parameter is required."))?, - scope: self.scope.take().ok_or(anyhow!("scope parameter is required."))?, + .ok_or_else(|| anyhow!("scope parameter is required."))?, claims: self.claims.take().transpose()?, redirect_uri: self .redirect_uri .take() - .ok_or(anyhow!("redirect_uri parameter is required."))?, - nonce: self.nonce.take().ok_or(anyhow!("nonce parameter is required."))?, + .ok_or_else(|| anyhow!("redirect_uri parameter is required."))?, + nonce: self + .nonce + .take() + .ok_or_else(|| anyhow!("nonce parameter is required."))?, registration: self.registration.take(), - iss: self.iss.take(), - iat: self.iat, - exp: self.exp, - nbf: self.nbf, - jti: self.jti.take(), state: self.state.take(), }))), _ => Err(anyhow!( - "request_uri and other parameters cannot be set at the same time." + "one of either request_uri, request or other parameters should be set" )), } } @@ -79,6 +74,13 @@ impl RequestUrlBuilder { self } + builder_fn!(rfc7519_claims, iss, String); + builder_fn!(rfc7519_claims, sub, String); + builder_fn!(rfc7519_claims, aud, String); + builder_fn!(rfc7519_claims, exp, i64); + builder_fn!(rfc7519_claims, nbf, i64); + builder_fn!(rfc7519_claims, iat, i64); + builder_fn!(rfc7519_claims, jti, String); builder_fn!(request_uri, String); builder_fn!(response_type, ResponseType); builder_fn!(response_mode, String); @@ -87,11 +89,6 @@ impl RequestUrlBuilder { builder_fn!(redirect_uri, String); builder_fn!(nonce, String); builder_fn!(registration, Registration); - builder_fn!(iss, String); - builder_fn!(iat, i64); - builder_fn!(exp, i64); - builder_fn!(nbf, i64); - builder_fn!(jti, String); builder_fn!(state, String); } @@ -120,7 +117,8 @@ mod tests { assert_eq!( request_url, - RequestUrl::AuthorizationRequest(Box::new(AuthorizationRequest { + RequestUrl::Request(Box::new(AuthorizationRequest { + rfc7519_claims: RFC7519Claims::default(), response_type: ResponseType::IdToken, response_mode: None, client_id: "did:example:123".to_string(), @@ -135,11 +133,6 @@ mod tests { redirect_uri: "https://example.com".to_string(), nonce: "nonce".to_string(), registration: None, - iss: None, - iat: None, - exp: None, - nbf: None, - jti: None, state: None, })) ); @@ -167,10 +160,10 @@ mod tests { .nonce("nonce".to_string()) .claims( r#"{ - "id_token": { - "name": "invalid" - } - }"#, + "id_token": { + "name": "invalid" + } + }"#, ) .build() .is_err()); @@ -179,6 +172,7 @@ mod tests { #[test] fn test_valid_request_uri_builder() { let request_url = RequestUrl::builder() + .client_id("did:example:123".to_string()) .request_uri("https://example.com/request_uri".to_string()) .build() .unwrap(); @@ -186,6 +180,7 @@ mod tests { assert_eq!( request_url, RequestUrl::RequestUri { + client_id: "did:example:123".to_string(), request_uri: "https://example.com/request_uri".to_string() } ); diff --git a/siopv2/src/token/id_token.rs b/siopv2/src/token/id_token.rs index c44533ce..6abcd7e8 100644 --- a/siopv2/src/token/id_token.rs +++ b/siopv2/src/token/id_token.rs @@ -1,6 +1,7 @@ use super::id_token_builder::IdTokenBuilder; use crate::{parse_other, StandardClaimsValues}; use getset::Getters; +use is_empty::IsEmpty; use serde::{Deserialize, Serialize}; use serde_with::skip_serializing_none; @@ -41,15 +42,15 @@ pub struct SubJwk { /// Set of IANA registered claims by the Internet Engineering Task Force (IETF) in /// [RFC 7519](https://tools.ietf.org/html/rfc7519#section-4.1). -#[derive(Serialize, Deserialize, Debug, Default, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone, IsEmpty)] pub struct RFC7519Claims { - pub(super) iss: Option, - pub(super) sub: Option, - pub(super) aud: Option, - pub(super) exp: Option, - pub(super) nbf: Option, - pub(super) iat: Option, - pub(super) jti: Option, + pub(crate) iss: Option, + pub(crate) sub: Option, + pub(crate) aud: Option, + pub(crate) exp: Option, + pub(crate) nbf: Option, + pub(crate) iat: Option, + pub(crate) jti: Option, } #[cfg(test)]