Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ 👌 Remove Algorithm.Other #57

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ package me.wojnowski.oidc4s.json.circe

import io.circe.Decoder
import me.wojnowski.oidc4s.Algorithm
import me.wojnowski.oidc4s.IdTokenVerifier.Error.UnsupportedAlgorithm
import me.wojnowski.oidc4s.JoseHeader
import me.wojnowski.oidc4s.json.JsonSupport

trait JoseHeaderCirceDecoder {

private implicit val jwtAlgorithmCirceDecoder: Decoder[Algorithm] =
Decoder[String].map(Algorithm.fromString)
Decoder[String].emap { shortName =>
Algorithm.findByShortName(shortName).toRight(UnsupportedAlgorithm(shortName).toRawError)
}

protected implicit val joseHeaderCirceDecoder: Decoder[JoseHeader] =
Decoder.forProduct2[JoseHeader, String, Algorithm]("kid", "alg")(JoseHeader.apply)
Expand Down
4 changes: 1 addition & 3 deletions core/src/main/scala/me/wojnowski/oidc4s/Algorithm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ object Algorithm {
case object Rs384 extends Algorithm(name = "RS384", fullName = "SHA384withRSA")
case object Rs512 extends Algorithm(name = "RS512", fullName = "SHA512withRSA")

case class Other(override val name: String) extends Algorithm(name, fullName = name)

implicit val order: Order[Algorithm] = Order.by(_.name)

val supportedAlgorithms: NonEmptySet[Algorithm] = NonEmptySet.of(Rs256, Rs384, Rs512)

def fromString(s: String): Algorithm = supportedAlgorithms.find(_.name === s).getOrElse(Other(s))
def findByShortName(s: String): Option[Algorithm] = supportedAlgorithms.find(_.name === s)
}
38 changes: 19 additions & 19 deletions core/src/main/scala/me/wojnowski/oidc4s/IdTokenVerifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,16 @@ import cats.Monad
import cats.data.EitherT
import cats.effect.Clock
import cats.syntax.all._
import me.wojnowski.oidc4s.IdTokenVerifier.Error.CouldNotDecodeClaim
import me.wojnowski.oidc4s.IdTokenVerifier.Error.CouldNotDecodeHeader
import me.wojnowski.oidc4s.IdTokenVerifier.Error.InvalidSignature
import me.wojnowski.oidc4s.IdTokenVerifier.Error.MalformedToken
import me.wojnowski.oidc4s.IdTokenVerifier.Error.TokenExpired
import me.wojnowski.oidc4s.IdTokenVerifier.Error.UnsupportedAlgorithm
import me.wojnowski.oidc4s.IdTokenVerifier.Error._
import me.wojnowski.oidc4s.config.OpenIdConnectDiscovery
import me.wojnowski.oidc4s.json.JsonDecoder
import me.wojnowski.oidc4s.json.JsonDecoder.ClaimsDecoder
import me.wojnowski.oidc4s.json.JsonDecoder
import me.wojnowski.oidc4s.json.JsonSupport

import java.nio.charset.StandardCharsets
import java.security.PublicKey
import java.security.Signature
import java.time.Instant
import java.time.ZoneId
import java.time.{Clock => JavaClock}
import java.util.Base64
import scala.util.Success
import scala.util.Try
Expand Down Expand Up @@ -125,17 +118,18 @@ object IdTokenVerifier {
private def ensureNotExpired(now: Instant, expiresAt: Instant): Either[Error.TokenExpired, Unit] =
Either.raiseWhen(expiresAt.isBefore(now))(TokenExpired(since = expiresAt))

private def decodeHeader(headerJson: String): Either[CouldNotDecodeHeader, JoseHeader] =
private def decodeHeader(headerJson: String): Either[Error, JoseHeader] =
JsonDecoder[JoseHeader]
.decode(headerJson)
.leftMap(CouldNotDecodeHeader.apply)
.leftMap { rawError =>
UnsupportedAlgorithm.fromRawError(rawError).getOrElse(CouldNotDecodeHeader(rawError))
}

private def decodeJwtAndVerifySignature[A: ClaimsDecoder](rawToken: String, key: PublicKey, header: JoseHeader)
: Either[Error, (A, IdTokenClaims)] =
rawToken.split('.') match {
case Array(rawHeader, rawClaims, rawSignature) =>
for {
_ <- verifyAlgorithm(header.algorithm)
_ <- verifySignature(header.algorithm.fullName, key, rawHeader, rawClaims, rawSignature)
result <- parseClaims[A](rawClaims)
} yield result
Expand All @@ -151,9 +145,6 @@ object IdTokenVerifier {
ClaimsDecoder[A].decode(rawJson).leftMap(CouldNotDecodeClaim.apply)
}

private def verifyAlgorithm(algorithm: Algorithm) =
Either.raiseUnless(Algorithm.supportedAlgorithms.contains_(algorithm))(UnsupportedAlgorithm(algorithm.name.some))

private def verifySignature(
signingAlgorithm: String,
publicKey: PublicKey,
Expand Down Expand Up @@ -191,12 +182,23 @@ object IdTokenVerifier {

case object CouldNotExtractHeader extends Error

case object CouldNotExtractKeyId extends Error

case class CouldNotFindPublicKey(cause: PublicKeyProvider.Error) extends Error

case class CouldNotDecodeHeader(details: String) extends Error

case class UnsupportedAlgorithm(providedAlgorithm: String) extends Error {
private[oidc4s] def toRawError: String = s"${UnsupportedAlgorithm.rawErrorPrefix}$providedAlgorithm"
}

object UnsupportedAlgorithm {

private val rawErrorPrefix = "Unsupported algorithm: "

private[oidc4s] def fromRawError(details: String): Option[UnsupportedAlgorithm] =
Option.when(details.startsWith(rawErrorPrefix))(UnsupportedAlgorithm(details.stripPrefix(rawErrorPrefix)))

}

case class CouldNotDecodeClaim(details: String) extends Error

case class TokenExpired(since: Instant) extends Error
Expand All @@ -205,8 +207,6 @@ object IdTokenVerifier {

case object InvalidSignature extends Error

case class UnsupportedAlgorithm(providedAlgorithm: Option[String]) extends Error

case class UnexpectedIssuer(found: Issuer, expected: Issuer) extends Error
}

Expand Down
61 changes: 32 additions & 29 deletions core/src/test/scala/me/wojnowski/oidc4s/IdTokenVerifierTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import me.wojnowski.oidc4s.config.Location
import me.wojnowski.oidc4s.config.OpenIdConfig
import me.wojnowski.oidc4s.config.OpenIdConnectDiscovery
import me.wojnowski.oidc4s.json.JsonDecoder
import me.wojnowski.oidc4s.json.JsonSupport
import me.wojnowski.oidc4s.mocks.CacheMock
import me.wojnowski.oidc4s.mocks.HttpTransportMock
import me.wojnowski.oidc4s.mocks.JsonSupportMock
Expand Down Expand Up @@ -43,7 +44,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.discovery[IO](googleKeyProvider, discovery, jsonSupport)
.verifyAndDecode(rawIdToken)
.map { result =>
assertEquals(result, Right(decodedIdTokenClaims(0)))
assertEquals(result, decodedIdTokenClaims(0))
}
}
}
Expand All @@ -68,7 +69,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {

implicit val jsonDecoder: JsonDecoder[(CustomClaims, IdTokenClaims)] = (rawClaims: String) =>
if (rawClaims === rawIdTokenClaims.head) {
Right(expectedCustomClaims, decodedIdTokenClaims.head)
decodedIdTokenClaims.head.map((expectedCustomClaims, _))
} else {
Left("Could not decode claims")
}
Expand All @@ -94,7 +95,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {

implicit val jsonDecoder: JsonDecoder[(CustomClaims, IdTokenClaims)] = (rawClaims: String) =>
if (rawClaims === rawIdTokenClaims.head) {
Right(expectedCustomClaims, decodedIdTokenClaims.head.copy(issuer = unexpectedIssuer))
decodedIdTokenClaims.head.map(claims => (expectedCustomClaims, claims.copy(issuer = unexpectedIssuer)))
} else {
Left("Could not decode claims")
}
Expand All @@ -105,7 +106,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.map { result =>
assertEquals(
result,
Left(UnexpectedIssuer(found = unexpectedIssuer, decodedIdTokenClaims.head.issuer))
Left(UnexpectedIssuer(found = unexpectedIssuer, decodedIdTokenClaims.head.value.issuer))
)
}
}
Expand All @@ -119,7 +120,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {

implicit val jsonDecoder: JsonDecoder[(CustomClaims, IdTokenClaims)] = (rawClaims: String) =>
if (rawClaims === rawIdTokenClaims.head) {
Right(expectedCustomClaims, decodedIdTokenClaims.head)
decodedIdTokenClaims.head.map((expectedCustomClaims, _))
} else {
Left("Could not decode claims")
}
Expand All @@ -145,7 +146,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {

implicit val jsonDecoder: JsonDecoder[(CustomClaims, IdTokenClaims)] = (rawClaims: String) =>
if (rawClaims === rawIdTokenClaims.head) {
Right(expectedCustomClaims, decodedIdTokenClaims.head)
decodedIdTokenClaims.head.map((expectedCustomClaims, _))
} else {
Left("Could not decode claims")
}
Expand All @@ -170,7 +171,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.map { result =>
assertEquals(
result,
Right(decodedIdTokenClaims.apply(0).subject)
decodedIdTokenClaims.apply(0).map(_.subject)
)
}
}
Expand Down Expand Up @@ -236,7 +237,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.discovery[IO](nonGoogleKeyProvider, discoveryWithDifferentIssuer, jsonSupport)
.verifyAndDecode(tokenWithOtherIssuer)
.map { result =>
assertEquals(result, Right(decodedIdTokenClaims.apply(1)))
assertEquals(result, decodedIdTokenClaims.apply(1))
}
}

Expand All @@ -257,21 +258,6 @@ class IdTokenVerifierTest extends CatsEffectSuite {
}
}

test("Algorithm: none and empty signature") {
val tokenWithAlgorithmNone =
"eyJhbGciOiJub25lIiwia2lkIjoiMTFlMDNmMzliOGQzMDBjOGM5YTFiODAwZGRlYmZjZmRlNDE1MmMwYyIsInR5cCI6IkpXVCJ9Cg.eyJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL3BhdGgiLCJhenAiOiJpbnRlZ3JhdGlvbi10ZXN0c0BjaGluZ29yLXRlc3QuaWFtLmdzZXJ2aWNlYWNjb3VudC5jb20iLCJlbWFpbCI6ImludGVncmF0aW9uLXRlc3RzQGNoaW5nb3ItdGVzdC5pYW0uZ3NlcnZpY2VhY2NvdW50LmNvbSIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJleHAiOjE1ODc2Mjk4ODgsImlhdCI6MTU4NzYyNjI4OCwiaXNzIjoiaHR0cHM6Ly9hY2NvdW50cy5nb29nbGUuY29tIiwic3ViIjoiMTA0MDI5MjkyODUzMDk5OTc4MjkzIn0."

runAtInstant(idTokenExpiration.minusSeconds(3)) {
IdTokenVerifier
.discovery[IO](nonGoogleKeyProvider, discovery, jsonSupport)
.verifyAndDecode(tokenWithAlgorithmNone)
.map {
case Left(MalformedToken) => ()
case e => fail(s"expected JwtEmptySignatureException, got $e")
}
}
}

test("Algorithm: none and random signature") {
val tokenWithAlgorithmNone =
"eyJhbGciOiJub25lIiwia2lkIjoiMTFlMDNmMzliOGQzMDBjOGM5YTFiODAwZGRlYmZjZmRlNDE1MmMwYyIsInR5cCI6IkpXVCJ9Cg.eyJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL3BhdGgiLCJhenAiOiJpbnRlZ3JhdGlvbi10ZXN0c0BjaGluZ29yLXRlc3QuaWFtLmdzZXJ2aWNlYWNjb3VudC5jb20iLCJlbWFpbCI6ImludGVncmF0aW9uLXRlc3RzQGNoaW5nb3ItdGVzdC5pYW0uZ3NlcnZpY2VhY2NvdW50LmNvbSIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJleHAiOjE1ODc2Mjk4ODgsImlhdCI6MTU4NzYyNjI4OCwiaXNzIjoiaHR0cHM6Ly9hY2NvdW50cy5nb29nbGUuY29tIiwic3ViIjoiMTA0MDI5MjkyODUzMDk5OTc4MjkzIn0.OZUxLojD0DddF9Phg63HQg3R6xvr2Gp3T3msRn09bHvaSUmr_SmMFrIAACiKHHJuQ43eZq9Qvc4ICCMrBwQfVV2FcOXffMQEA6SgTJxRcfSsfhdXX3QDJRGK27x0ynsarcWFrw9TefJyt_gPhhhE0yAzrxCHDPz0LRe8NCv4OvKnw9LZujF5P5k_AxduWTQJuNzHJGsx38E2NLW9SK93KbODZEzCX8YDkddfRfR_LZl2FciRsY6JoHtucqP5KMCFvSJBkfYqQGESeW8EUMxhBH8UrP1pcDD-6u7WXHM5bguC0rrGY6UPvRm3uZMcSYhyOnapC8f0zJBVXGyl9J5dWw"
Expand Down Expand Up @@ -301,6 +287,21 @@ class IdTokenVerifierTest extends CatsEffectSuite {
}
}

test("Invalid header") {
val tokenWithInvalidHeader =
"dGhpcyBpcyBub3QgZXZlbiBhIEpTT04K.eyJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL3BhdGgiLCJhenAiOiJpbnRlZ3JhdGlvbi10ZXN0c0BjaGluZ29yLXRlc3QuaWFtLmdzZXJ2aWNlYWNjb3VudC5jb20iLCJlbWFpbCI6ImludGVncmF0aW9uLXRlc3RzQGNoaW5nb3ItdGVzdC5pYW0uZ3NlcnZpY2VhY2NvdW50LmNvbSIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJleHAiOjE1ODc2Mjk4ODgsImlhdCI6MTU4NzYyNjI4OCwiaXNzIjoiaHR0cHM6Ly9hY2NvdW50cy5nb29nbGUuY29tIiwic3ViIjoiMTA0MDI5MjkyODUzMDk5OTc4MjkzIn0.OZUxLojD0DddF9Phg63HQg3R6xvr2Gp3T3msRn09bHvaSUmr_SmMFrIAACiKHHJuQ43eZq9Qvc4ICCMrBwQfVV2FcOXffMQEA6SgTJxRcfSsfhdXX3QDJRGK27x0ynsarcWFrw9TefJyt_gPhhhE0yAzrxCHDPz0LRe8NCv4OvKnw9LZujF5P5k_AxduWTQJuNzHJGsx38E2NLW9SK93KbODZEzCX8YDkddfRfR_LZl2FciRsY6JoHtucqP5KMCFvSJBkfYqQGESeW8EUMxhBH8UrP1pcDD-6u7WXHM5bguC0rrGY6UPvRm3uZMcSYhyOnapC8f0zJBVXGyl9J5dWw"

runAtInstant(idTokenExpiration.minusSeconds(3)) {
IdTokenVerifier
.discovery[IO](nonGoogleKeyProvider, discovery, jsonSupport)
.verifyAndDecode(tokenWithInvalidHeader)
.map {
case Left(CouldNotDecodeHeader("not a JSON")) => ()
case e => fail(s"expected $CouldNotDecodeHeader, got $e")
}
}
}

test("Every verification calls config") {
runAtInstant(idTokenExpiration.minusSeconds(3)) {
CacheMock
Expand Down Expand Up @@ -405,22 +406,24 @@ object IdTokenVerifierTest {
expiration = idTokenExpiration,
issuedAt = Instant.parse("2020-04-23T07:18:08Z")
)
)
).map(Right(_))

val rawJoseHeaders =
List(
"""{"alg":"RS256","kid":"f9d97b4cae90bcd76aeb20026f6b770cac221783","typ":"JWT"}""",
"""{"alg":"RS256","kid":"11e03f39b8d300c8c9a1b800ddebfcfde4152c0c","typ":"JWT"}""",
"""{"alg":"none","kid":"11e03f39b8d300c8c9a1b800ddebfcfde4152c0c","typ":"JWT"}""",
"""{"alg":"HS256","kid":"11e03f39b8d300c8c9a1b800ddebfcfde4152c0c","typ":"JWT"}"""
"""{"alg":"HS256","kid":"11e03f39b8d300c8c9a1b800ddebfcfde4152c0c","typ":"JWT"}""",
"""this is not even a JSON"""
)

val decodedJoseHeaders =
List(
JoseHeader(keyId = "f9d97b4cae90bcd76aeb20026f6b770cac221783", algorithm = Algorithm.Rs256),
JoseHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Rs256),
JoseHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Other("none")),
JoseHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Other("HS256"))
Right(JoseHeader(keyId = "f9d97b4cae90bcd76aeb20026f6b770cac221783", algorithm = Algorithm.Rs256)),
Right(JoseHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Rs256)),
Left(UnsupportedAlgorithm(providedAlgorithm = "none").toRawError),
Left(UnsupportedAlgorithm(providedAlgorithm = "HS256").toRawError),
Left("not a JSON")
)

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class OpenIdConnectDiscoveryTest extends FunSuite {
)
val transport = HttpTransportMock.const[Id](configurationUrl, response = "correct-config-response")
val jsonSupport = JsonSupportMock.instance(openIdConfigTranslations = { case "correct-config-response" =>
expectedConfig
Right(expectedConfig)
})

val discovery =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class PropertyIdTokenVerifierTest extends CatsEffectSuite with ScalaCheckEffectS
IdTokenVerifier.static(
publicKeyProvider,
issuer,
JsonSupportMock.instance(Map(rawClaims -> claims), Map(rawHeader -> header))
JsonSupportMock.instance(Map(rawClaims -> Right(claims)), Map(rawHeader -> Right(header)))
)

val rawJwt =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class PublicKeyProviderTest extends CatsEffectSuite {

private val transport = HttpTransportMock.const[Id](jwksUrl, "")

private val jsonSupport = JsonSupportMock.instance(jsonWebKeySetTranslations = { _ => jsonWebKeySet })
private val jsonSupport = JsonSupportMock.instance(jsonWebKeySetTranslations = { _ => Right(jsonWebKeySet) })

val discovery: OpenIdConnectDiscovery[Id] = OpenIdConnectDiscovery.static[Id](OpenIdConfig(issuer = Issuer(""), jwksUrl))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import me.wojnowski.oidc4s.json.JsonSupport
object JsonSupportMock {

def instance(
idTokenTranslations: PartialFunction[String, IdTokenClaims] = PartialFunction.empty,
joseHeaderTranslations: PartialFunction[String, JoseHeader] = PartialFunction.empty,
openIdConfigTranslations: PartialFunction[String, OpenIdConfig] = PartialFunction.empty,
jsonWebKeySetTranslations: PartialFunction[String, JsonWebKeySet] = PartialFunction.empty
idTokenTranslations: PartialFunction[String, Either[String, IdTokenClaims]] = PartialFunction.empty,
joseHeaderTranslations: PartialFunction[String, Either[String, JoseHeader]] = PartialFunction.empty,
openIdConfigTranslations: PartialFunction[String, Either[String, OpenIdConfig]] = PartialFunction.empty,
jsonWebKeySetTranslations: PartialFunction[String, Either[String, JsonWebKeySet]] = PartialFunction.empty
): JsonSupport = new JsonSupport {

override implicit val joseHeaderDecoder: JsonDecoder[JoseHeader] =
Expand All @@ -28,8 +28,8 @@ object JsonSupportMock {
override implicit val jwksDecoder: JsonDecoder[JsonWebKeySet] =
translateOrFail(jsonWebKeySetTranslations, "JsonWebKeySet")

private def translateOrFail[A](translations: PartialFunction[String, A], name: String): JsonDecoder[A] =
(rawJson: String) => translations.lift(rawJson.trim).toRight(s"could not find $name for [$rawJson]")
private def translateOrFail[A](translations: PartialFunction[String, Either[String, A]], name: String): JsonDecoder[A] =
(rawJson: String) => translations.lift(rawJson.trim).toRight(s"could not find $name for [$rawJson]").flatten
}

}