Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ 👌 Remove Algorithm.Other #57

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import me.wojnowski.oidc4s.JoseHeader
trait JoseHeaderCirceDecoder {

private implicit val jwtAlgorithmCirceDecoder: Decoder[Algorithm] =
Decoder[String].map(Algorithm.fromString)
Decoder[String].emap(shortName => Algorithm.findByShortName(shortName).toRight(s"Unsupported algorithm: $shortName"))

protected implicit val joseHeaderCirceDecoder: Decoder[JoseHeader] =
Decoder.forProduct2[JoseHeader, String, Algorithm]("kid", "alg")(JoseHeader.apply)
Expand Down
4 changes: 1 addition & 3 deletions core/src/main/scala/me/wojnowski/oidc4s/Algorithm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ object Algorithm {
case object Rs384 extends Algorithm(name = "RS384", fullName = "SHA384withRSA")
case object Rs512 extends Algorithm(name = "RS512", fullName = "SHA512withRSA")

case class Other(override val name: String) extends Algorithm(name, fullName = name)

implicit val order: Order[Algorithm] = Order.by(_.name)

val supportedAlgorithms: NonEmptySet[Algorithm] = NonEmptySet.of(Rs256, Rs384, Rs512)

def fromString(s: String): Algorithm = supportedAlgorithms.find(_.name === s).getOrElse(Other(s))
def findByShortName(s: String): Option[Algorithm] = supportedAlgorithms.find(_.name === s)
}
31 changes: 14 additions & 17 deletions core/src/main/scala/me/wojnowski/oidc4s/IdTokenVerifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,16 @@ import cats.Monad
import cats.data.EitherT
import cats.effect.Clock
import cats.syntax.all._
import me.wojnowski.oidc4s.IdTokenVerifier.Error.CouldNotDecodeClaim
import me.wojnowski.oidc4s.IdTokenVerifier.Error.CouldNotDecodeHeader
import me.wojnowski.oidc4s.IdTokenVerifier.Error.InvalidSignature
import me.wojnowski.oidc4s.IdTokenVerifier.Error.MalformedToken
import me.wojnowski.oidc4s.IdTokenVerifier.Error.TokenExpired
import me.wojnowski.oidc4s.IdTokenVerifier.Error.UnsupportedAlgorithm
import me.wojnowski.oidc4s.IdTokenVerifier.Error._
import me.wojnowski.oidc4s.config.OpenIdConnectDiscovery
import me.wojnowski.oidc4s.json.JsonDecoder
import me.wojnowski.oidc4s.json.JsonDecoder.ClaimsDecoder
import me.wojnowski.oidc4s.json.JsonDecoder
import me.wojnowski.oidc4s.json.JsonSupport

import java.nio.charset.StandardCharsets
import java.security.PublicKey
import java.security.Signature
import java.time.Instant
import java.time.ZoneId
import java.time.{Clock => JavaClock}
import java.util.Base64
import scala.util.Success
import scala.util.Try
Expand Down Expand Up @@ -125,17 +118,24 @@ object IdTokenVerifier {
private def ensureNotExpired(now: Instant, expiresAt: Instant): Either[Error.TokenExpired, Unit] =
Either.raiseWhen(expiresAt.isBefore(now))(TokenExpired(since = expiresAt))

private def decodeHeader(headerJson: String): Either[CouldNotDecodeHeader, JoseHeader] =
private def decodeHeader(headerJson: String): Either[Error, JoseHeader] = {
val unsupportedAlgorithmPrefix = "Unsupported algorithm: "
jwojnowski marked this conversation as resolved.
Show resolved Hide resolved

JsonDecoder[JoseHeader]
.decode(headerJson)
.leftMap(CouldNotDecodeHeader.apply)
.leftMap {
case details if details.startsWith(unsupportedAlgorithmPrefix) =>
UnsupportedAlgorithm(details.stripPrefix(unsupportedAlgorithmPrefix))
case details =>
CouldNotDecodeHeader(details)
}
}

private def decodeJwtAndVerifySignature[A: ClaimsDecoder](rawToken: String, key: PublicKey, header: JoseHeader)
: Either[Error, (A, IdTokenClaims)] =
rawToken.split('.') match {
case Array(rawHeader, rawClaims, rawSignature) =>
for {
_ <- verifyAlgorithm(header.algorithm)
_ <- verifySignature(header.algorithm.fullName, key, rawHeader, rawClaims, rawSignature)
result <- parseClaims[A](rawClaims)
} yield result
Expand All @@ -151,9 +151,6 @@ object IdTokenVerifier {
ClaimsDecoder[A].decode(rawJson).leftMap(CouldNotDecodeClaim.apply)
}

private def verifyAlgorithm(algorithm: Algorithm) =
Either.raiseUnless(Algorithm.supportedAlgorithms.contains_(algorithm))(UnsupportedAlgorithm(algorithm.name.some))

private def verifySignature(
signingAlgorithm: String,
publicKey: PublicKey,
Expand Down Expand Up @@ -197,6 +194,8 @@ object IdTokenVerifier {

case class CouldNotDecodeHeader(details: String) extends Error

case class UnsupportedAlgorithm(providedAlgorithm: String) extends Error

case class CouldNotDecodeClaim(details: String) extends Error

case class TokenExpired(since: Instant) extends Error
Expand All @@ -205,8 +204,6 @@ object IdTokenVerifier {

case object InvalidSignature extends Error

case class UnsupportedAlgorithm(providedAlgorithm: Option[String]) extends Error

case class UnexpectedIssuer(found: Issuer, expected: Issuer) extends Error
}

Expand Down
45 changes: 15 additions & 30 deletions core/src/test/scala/me/wojnowski/oidc4s/IdTokenVerifierTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.discovery[IO](googleKeyProvider, discovery, jsonSupport)
.verifyAndDecode(rawIdToken)
.map { result =>
assertEquals(result, Right(decodedIdTokenClaims(0)))
assertEquals(result, decodedIdTokenClaims(0))
}
}
}
Expand All @@ -68,7 +68,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {

implicit val jsonDecoder: JsonDecoder[(CustomClaims, IdTokenClaims)] = (rawClaims: String) =>
if (rawClaims === rawIdTokenClaims.head) {
Right(expectedCustomClaims, decodedIdTokenClaims.head)
decodedIdTokenClaims.head.map((expectedCustomClaims, _))
} else {
Left("Could not decode claims")
}
Expand All @@ -94,7 +94,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {

implicit val jsonDecoder: JsonDecoder[(CustomClaims, IdTokenClaims)] = (rawClaims: String) =>
if (rawClaims === rawIdTokenClaims.head) {
Right(expectedCustomClaims, decodedIdTokenClaims.head.copy(issuer = unexpectedIssuer))
decodedIdTokenClaims.head.map(claims => (expectedCustomClaims, claims.copy(issuer = unexpectedIssuer)))
} else {
Left("Could not decode claims")
}
Expand All @@ -105,7 +105,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.map { result =>
assertEquals(
result,
Left(UnexpectedIssuer(found = unexpectedIssuer, decodedIdTokenClaims.head.issuer))
Left(UnexpectedIssuer(found = unexpectedIssuer, decodedIdTokenClaims.head.value.issuer))
)
}
}
Expand All @@ -119,7 +119,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {

implicit val jsonDecoder: JsonDecoder[(CustomClaims, IdTokenClaims)] = (rawClaims: String) =>
if (rawClaims === rawIdTokenClaims.head) {
Right(expectedCustomClaims, decodedIdTokenClaims.head)
decodedIdTokenClaims.head.map((expectedCustomClaims, _))
} else {
Left("Could not decode claims")
}
Expand All @@ -145,7 +145,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {

implicit val jsonDecoder: JsonDecoder[(CustomClaims, IdTokenClaims)] = (rawClaims: String) =>
if (rawClaims === rawIdTokenClaims.head) {
Right(expectedCustomClaims, decodedIdTokenClaims.head)
decodedIdTokenClaims.head.map((expectedCustomClaims, _))
} else {
Left("Could not decode claims")
}
Expand All @@ -170,7 +170,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.map { result =>
assertEquals(
result,
Right(decodedIdTokenClaims.apply(0).subject)
decodedIdTokenClaims.apply(0).map(_.subject)
)
}
}
Expand Down Expand Up @@ -236,7 +236,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.discovery[IO](nonGoogleKeyProvider, discoveryWithDifferentIssuer, jsonSupport)
.verifyAndDecode(tokenWithOtherIssuer)
.map { result =>
assertEquals(result, Right(decodedIdTokenClaims.apply(1)))
assertEquals(result, decodedIdTokenClaims.apply(1))
}
}

Expand All @@ -257,21 +257,6 @@ class IdTokenVerifierTest extends CatsEffectSuite {
}
}

test("Algorithm: none and empty signature") {
val tokenWithAlgorithmNone =
"eyJhbGciOiJub25lIiwia2lkIjoiMTFlMDNmMzliOGQzMDBjOGM5YTFiODAwZGRlYmZjZmRlNDE1MmMwYyIsInR5cCI6IkpXVCJ9Cg.eyJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL3BhdGgiLCJhenAiOiJpbnRlZ3JhdGlvbi10ZXN0c0BjaGluZ29yLXRlc3QuaWFtLmdzZXJ2aWNlYWNjb3VudC5jb20iLCJlbWFpbCI6ImludGVncmF0aW9uLXRlc3RzQGNoaW5nb3ItdGVzdC5pYW0uZ3NlcnZpY2VhY2NvdW50LmNvbSIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJleHAiOjE1ODc2Mjk4ODgsImlhdCI6MTU4NzYyNjI4OCwiaXNzIjoiaHR0cHM6Ly9hY2NvdW50cy5nb29nbGUuY29tIiwic3ViIjoiMTA0MDI5MjkyODUzMDk5OTc4MjkzIn0."

runAtInstant(idTokenExpiration.minusSeconds(3)) {
IdTokenVerifier
.discovery[IO](nonGoogleKeyProvider, discovery, jsonSupport)
.verifyAndDecode(tokenWithAlgorithmNone)
.map {
case Left(MalformedToken) => ()
case e => fail(s"expected JwtEmptySignatureException, got $e")
}
}
}

test("Algorithm: none and random signature") {
val tokenWithAlgorithmNone =
"eyJhbGciOiJub25lIiwia2lkIjoiMTFlMDNmMzliOGQzMDBjOGM5YTFiODAwZGRlYmZjZmRlNDE1MmMwYyIsInR5cCI6IkpXVCJ9Cg.eyJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL3BhdGgiLCJhenAiOiJpbnRlZ3JhdGlvbi10ZXN0c0BjaGluZ29yLXRlc3QuaWFtLmdzZXJ2aWNlYWNjb3VudC5jb20iLCJlbWFpbCI6ImludGVncmF0aW9uLXRlc3RzQGNoaW5nb3ItdGVzdC5pYW0uZ3NlcnZpY2VhY2NvdW50LmNvbSIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJleHAiOjE1ODc2Mjk4ODgsImlhdCI6MTU4NzYyNjI4OCwiaXNzIjoiaHR0cHM6Ly9hY2NvdW50cy5nb29nbGUuY29tIiwic3ViIjoiMTA0MDI5MjkyODUzMDk5OTc4MjkzIn0.OZUxLojD0DddF9Phg63HQg3R6xvr2Gp3T3msRn09bHvaSUmr_SmMFrIAACiKHHJuQ43eZq9Qvc4ICCMrBwQfVV2FcOXffMQEA6SgTJxRcfSsfhdXX3QDJRGK27x0ynsarcWFrw9TefJyt_gPhhhE0yAzrxCHDPz0LRe8NCv4OvKnw9LZujF5P5k_AxduWTQJuNzHJGsx38E2NLW9SK93KbODZEzCX8YDkddfRfR_LZl2FciRsY6JoHtucqP5KMCFvSJBkfYqQGESeW8EUMxhBH8UrP1pcDD-6u7WXHM5bguC0rrGY6UPvRm3uZMcSYhyOnapC8f0zJBVXGyl9J5dWw"
Expand All @@ -282,7 +267,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.verifyAndDecode(tokenWithAlgorithmNone)
.map {
case Left(UnsupportedAlgorithm(_)) => ()
case e => fail(s"expected $UnsupportedAlgorithm, got $e")
case e => fail(s"expected ${UnsupportedAlgorithm}, got $e")
}
}
}
Expand All @@ -296,7 +281,7 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.verifyAndDecode(tokenWithHs256Algorithm)
.map {
case Left(UnsupportedAlgorithm(_)) => ()
case e => fail(s"expected $UnsupportedAlgorithm, got $e")
case e => fail(s"expected ${UnsupportedAlgorithm}, got $e")
}
}
}
Expand Down Expand Up @@ -405,7 +390,7 @@ object IdTokenVerifierTest {
expiration = idTokenExpiration,
issuedAt = Instant.parse("2020-04-23T07:18:08Z")
)
)
).map(Right(_))

val rawJoseHeaders =
List(
Expand All @@ -417,10 +402,10 @@ object IdTokenVerifierTest {

val decodedJoseHeaders =
List(
JoseHeader(keyId = "f9d97b4cae90bcd76aeb20026f6b770cac221783", algorithm = Algorithm.Rs256),
JoseHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Rs256),
JoseHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Other("none")),
JoseHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Other("HS256"))
Right(JoseHeader(keyId = "f9d97b4cae90bcd76aeb20026f6b770cac221783", algorithm = Algorithm.Rs256)),
Right(JoseHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Rs256)),
Left("Unsupported algorithm: none"),
Left("Unsupported algorithm: HS256")
)

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class OpenIdConnectDiscoveryTest extends FunSuite {
)
val transport = HttpTransportMock.const[Id](configurationUrl, response = "correct-config-response")
val jsonSupport = JsonSupportMock.instance(openIdConfigTranslations = { case "correct-config-response" =>
expectedConfig
Right(expectedConfig)
})

val discovery =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class PropertyIdTokenVerifierTest extends CatsEffectSuite with ScalaCheckEffectS
IdTokenVerifier.static(
publicKeyProvider,
issuer,
JsonSupportMock.instance(Map(rawClaims -> claims), Map(rawHeader -> header))
JsonSupportMock.instance(Map(rawClaims -> Right(claims)), Map(rawHeader -> Right(header)))
)

val rawJwt =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class PublicKeyProviderTest extends CatsEffectSuite {

private val transport = HttpTransportMock.const[Id](jwksUrl, "")

private val jsonSupport = JsonSupportMock.instance(jsonWebKeySetTranslations = { _ => jsonWebKeySet })
private val jsonSupport = JsonSupportMock.instance(jsonWebKeySetTranslations = { _ => Right(jsonWebKeySet) })

val discovery: OpenIdConnectDiscovery[Id] = OpenIdConnectDiscovery.static[Id](OpenIdConfig(issuer = Issuer(""), jwksUrl))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import me.wojnowski.oidc4s.json.JsonSupport
object JsonSupportMock {

def instance(
idTokenTranslations: PartialFunction[String, IdTokenClaims] = PartialFunction.empty,
joseHeaderTranslations: PartialFunction[String, JoseHeader] = PartialFunction.empty,
openIdConfigTranslations: PartialFunction[String, OpenIdConfig] = PartialFunction.empty,
jsonWebKeySetTranslations: PartialFunction[String, JsonWebKeySet] = PartialFunction.empty
idTokenTranslations: PartialFunction[String, Either[String, IdTokenClaims]] = PartialFunction.empty,
joseHeaderTranslations: PartialFunction[String, Either[String, JoseHeader]] = PartialFunction.empty,
openIdConfigTranslations: PartialFunction[String, Either[String, OpenIdConfig]] = PartialFunction.empty,
jsonWebKeySetTranslations: PartialFunction[String, Either[String, JsonWebKeySet]] = PartialFunction.empty
): JsonSupport = new JsonSupport {

override implicit val joseHeaderDecoder: JsonDecoder[JoseHeader] =
Expand All @@ -28,8 +28,8 @@ object JsonSupportMock {
override implicit val jwksDecoder: JsonDecoder[JsonWebKeySet] =
translateOrFail(jsonWebKeySetTranslations, "JsonWebKeySet")

private def translateOrFail[A](translations: PartialFunction[String, A], name: String): JsonDecoder[A] =
(rawJson: String) => translations.lift(rawJson.trim).toRight(s"could not find $name for [$rawJson]")
private def translateOrFail[A](translations: PartialFunction[String, Either[String, A]], name: String): JsonDecoder[A] =
(rawJson: String) => translations.lift(rawJson.trim).toRight(s"could not find $name for [$rawJson]").flatten
}

}
Loading