Skip to content

Commit

Permalink
support tokens without audience and scope from claims when using JwtB…
Browse files Browse the repository at this point in the history
…earerGrant (#13)

* breaking-change: allow tokens without audience to be provided via OAuth2TokenCallback.kt
* api change on OAuth2TokenCallback.kt, audience now returns List<String> instead of String
* an empty list for audience() in OAuth2TokenCallback.kt will yield a token without audience
* support returned response scope from assertion claim

Co-authored-by: Tommy Trøen <tommy.troen@nav.no>
  • Loading branch information
ybelMekk and tommytroen authored Oct 7, 2020
1 parent 94fcbfa commit 600c2ec
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class MockOAuth2Server(
DefaultOAuth2TokenCallback(
issuerId,
subject,
audience,
audience?.let { listOf(it) },
claims,
expiry
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal class AuthorizationCodeHandler(
private class LoginOAuth2TokenCallback(val login: Login, val OAuth2TokenCallback: OAuth2TokenCallback) : OAuth2TokenCallback {
override fun issuerId(): String = OAuth2TokenCallback.issuerId()
override fun subject(tokenRequest: TokenRequest): String = login.username
override fun audience(tokenRequest: TokenRequest): String = OAuth2TokenCallback.audience(tokenRequest)
override fun audience(tokenRequest: TokenRequest): List<String> = OAuth2TokenCallback.audience(tokenRequest)
override fun addClaims(tokenRequest: TokenRequest): Map<String, Any> =
OAuth2TokenCallback.addClaims(tokenRequest).toMutableMap().apply {
login.acr?.let { put("acr", it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.expiresIn
import no.nav.security.mock.oauth2.http.OAuth2HttpRequest
import no.nav.security.mock.oauth2.http.OAuth2TokenResponse
import no.nav.security.mock.oauth2.invalidRequest
import no.nav.security.mock.oauth2.token.OAuth2TokenCallback
import no.nav.security.mock.oauth2.token.OAuth2TokenProvider
import okhttp3.HttpUrl
Expand All @@ -20,7 +21,7 @@ internal class JwtBearerGrantHandler(private val tokenProvider: OAuth2TokenProvi
oAuth2TokenCallback: OAuth2TokenCallback
): OAuth2TokenResponse {
val tokenRequest = request.asNimbusTokenRequest()
val receivedClaimsSet = assertion(tokenRequest)
val receivedClaimsSet = tokenRequest.assertion()
val accessToken = tokenProvider.exchangeAccessToken(
tokenRequest,
issuerUrl,
Expand All @@ -31,11 +32,17 @@ internal class JwtBearerGrantHandler(private val tokenProvider: OAuth2TokenProvi
tokenType = "Bearer",
accessToken = accessToken.serialize(),
expiresIn = accessToken.expiresIn(),
scope = tokenRequest.scope.toString()
scope = tokenRequest.responseScope()
)
}

private fun assertion(tokenRequest: TokenRequest): JWTClaimsSet =
(tokenRequest.authorizationGrant as? JWTBearerGrant)?.jwtAssertion?.jwtClaimsSet
private fun TokenRequest.responseScope(): String {
return scope?.toString()
?: assertion().getClaim("scope")?.toString()
?: invalidRequest("scope must be specified in request or as a claim in assertion parameter")
}

private fun TokenRequest.assertion(): JWTClaimsSet =
(this.authorizationGrant as? JWTBearerGrant)?.jwtAssertion?.jwtClaimsSet
?: throw OAuth2Exception(OAuth2Error.INVALID_REQUEST, "missing required parameter assertion")
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ val TOKEN_EXCHANGE = GrantType("urn:ietf:params:oauth:grant-type:token-exchange"
class TokenExchangeGrant(
val subjectTokenType: String,
val subjectToken: String,
val audience: String
val audience: MutableList<String>
) : AuthorizationGrant(TOKEN_EXCHANGE) {

override fun toParameters(): MutableMap<String, MutableList<String>> =
mutableMapOf(
"grant_type" to mutableListOf(TOKEN_EXCHANGE.value),
"subject_token_type" to mutableListOf(subjectTokenType),
"subject_token" to mutableListOf(subjectToken),
"audience" to mutableListOf(audience)
"audience" to audience
)

companion object {
Expand All @@ -27,6 +27,8 @@ class TokenExchangeGrant(
parameters.require("subject_token_type"),
parameters.require("subject_token"),
parameters.require("audience")
.split(" ")
.toMutableList()
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import com.nimbusds.openid.connect.sdk.OIDCScopeValue
import java.util.UUID
import no.nav.security.mock.oauth2.extensions.clientIdAsString
import no.nav.security.mock.oauth2.extensions.grantType
import no.nav.security.mock.oauth2.grant.TokenExchangeGrant

interface OAuth2TokenCallback {
fun issuerId(): String
fun subject(tokenRequest: TokenRequest): String
fun audience(tokenRequest: TokenRequest): String
fun audience(tokenRequest: TokenRequest): List<String>
fun addClaims(tokenRequest: TokenRequest): Map<String, Any>
fun tokenExpiry(): Long
}
Expand All @@ -19,7 +20,8 @@ interface OAuth2TokenCallback {
open class DefaultOAuth2TokenCallback(
private val issuerId: String = "default",
private val subject: String = UUID.randomUUID().toString(),
private val audience: String? = null,
// needs to be nullable in order to know if a list has explicitly been set, empty list should be a allowable value
private val audience: List<String>? = null,
private val claims: Map<String, Any> = emptyMap(),
private val expiry: Long = 3600
) : OAuth2TokenCallback {
Expand All @@ -33,15 +35,14 @@ open class DefaultOAuth2TokenCallback(
}
}

override fun audience(tokenRequest: TokenRequest): String {
override fun audience(tokenRequest: TokenRequest): List<String> {
val oidcScopeList = OIDCScopeValue.values().map { it.toString() }
return audience
?: (tokenRequest.authorizationGrant as? TokenExchangeGrant)?.audience
?: let {
tokenRequest.scope?.toStringList()
?.filterNot { oidcScopeList.contains(it) }?.firstOrNull()
}
?: tokenRequest.customParameters["audience"]?.first()
?: "default"
?.filterNot { oidcScopeList.contains(it) }
} ?: listOf("default")
}

override fun addClaims(tokenRequest: TokenRequest): Map<String, Any> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class OAuth2TokenProvider {
defaultClaims(
issuerUrl,
oAuth2TokenCallback.subject(tokenRequest),
tokenRequest.clientIdAsString(),
listOf(tokenRequest.clientIdAsString()),
nonce,
oAuth2TokenCallback.addClaims(tokenRequest),
oAuth2TokenCallback.tokenExpiry()
Expand Down Expand Up @@ -90,7 +90,7 @@ class OAuth2TokenProvider {
private fun defaultClaims(
issuerUrl: HttpUrl,
subject: String,
audience: String,
audience: List<String>,
nonce: String?,
additionalClaims: Map<String, Any>,
expiry: Long
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ class MockOAuth2ServerTest {
DefaultOAuth2TokenCallback(
issuerId = "custom",
subject = "yolo",
audience = "myaud"
audience = listOf("myaud")
)
)

Expand Down Expand Up @@ -322,7 +322,7 @@ class MockOAuth2ServerTest {
DefaultOAuth2TokenCallback(
issuerId = "default",
subject = "mysub",
audience = "muyaud",
audience = listOf("muyaud"),
claims = mapOf("someclaim" to "claimvalue")
)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package no.nav.security.mock.oauth2.e2e

import com.nimbusds.oauth2.sdk.GrantType
import io.kotest.matchers.collections.shouldBeEmpty
import io.kotest.matchers.collections.shouldContainExactly
import io.kotest.matchers.nulls.shouldNotBeNull
import io.kotest.matchers.should
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
Expand Down Expand Up @@ -62,4 +64,52 @@ class JwtBearerGrantIntegrationTest {
response.accessToken.claims["claim2"] shouldBe "value2"
}
}

@Test
fun `token request with JwtBearerGrant should exchange assertion with a new token with scope specified in assertion claim or request parmas`() {
withMockOAuth2Server {
val initialSubject = "mysub"
val initialToken = this.issueToken(
issuerId = "idprovider",
clientId = "client1",
tokenCallback = DefaultOAuth2TokenCallback(
issuerId = "idprovider",
subject = initialSubject,
audience = emptyList(),
claims = mapOf(
"claim1" to "value1",
"claim2" to "value2",
"scope" to "ascope",
"resource" to "aud1",
)
)
)

initialToken.audience.shouldBeEmpty()

val issuerId = "aad"

this.enqueueCallback(DefaultOAuth2TokenCallback(issuerId = issuerId, audience = emptyList()))

val response: ParsedTokenResponse = client.tokenRequest(
url = this.tokenEndpointUrl(issuerId),
parameters = mapOf(
"grant_type" to GrantType.JWT_BEARER.value,
"assertion" to initialToken.serialize()
)
).toTokenResponse()

println("YOLO:" + response.accessToken?.serialize())

response shouldBeValidFor GrantType.JWT_BEARER
response.scope shouldContain "ascope"
response.issuedTokenType shouldBe null
response.accessToken.shouldNotBeNull()
response.accessToken should verifyWith(issuerId, this, listOf("sub", "iss", "iat", "exp"))
response.accessToken.subject shouldBe initialSubject
response.accessToken.audience.shouldBeEmpty()
response.accessToken.claims["claim1"] shouldBe "value1"
response.accessToken.claims["claim2"] shouldBe "value2"
}
}
}
18 changes: 12 additions & 6 deletions src/test/kotlin/no/nav/security/mock/oauth2/testutils/Token.kt
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,14 @@ infix fun ParsedTokenResponse.shouldBeValidFor(type: GrantType) {
}
}

fun verifyWith(issuerId: String, server: MockOAuth2Server) = object : Matcher<SignedJWT> {
fun verifyWith(
issuerId: String,
server: MockOAuth2Server,
requiredClaims: List<String> = listOf("sub", "iss", "iat", "exp", "aud")
) = object : Matcher<SignedJWT> {
override fun test(value: SignedJWT): MatcherResult {
return try {
value.verifyWith(server.issuerUrl(issuerId), server.jwksUrl(issuerId))
value.verifyWith(server.issuerUrl(issuerId), server.jwksUrl(issuerId), requiredClaims)
MatcherResult(
true,
"should not happen, famous last words",
Expand All @@ -105,17 +109,19 @@ val SignedJWT.issuer: String get() = jwtClaimsSet.issuer
val SignedJWT.subject: String get() = jwtClaimsSet.subject
val SignedJWT.claims: Map<String, Any> get() = jwtClaimsSet.claims

fun SignedJWT.verifyWith(issuer: HttpUrl, jwkSetUri: HttpUrl): JWTClaimsSet {
fun SignedJWT.verifyWith(
issuer: HttpUrl,
jwkSetUri: HttpUrl,
requiredClaims: List<String> = listOf("sub", "iss", "iat", "exp", "aud")
): JWTClaimsSet {
return DefaultJWTProcessor<SecurityContext?>()
.apply {
jwsKeySelector = JWSVerificationKeySelector(JWSAlgorithm.RS256, RemoteJWKSet(jwkSetUri.toUrl()))
jwtClaimsSetVerifier = DefaultJWTClaimsVerifier(
JWTClaimsSet.Builder()
.issuer(issuer.toString())
.build(),
HashSet(
listOf("sub", "iss", "iat", "exp", "aud")
)
HashSet(requiredClaims)
)
}.process(this, null)
}
Expand Down

0 comments on commit 600c2ec

Please sign in to comment.