From 26152000830b0af27b54614a9bc5b648efb961f0 Mon Sep 17 00:00:00 2001 From: Jakub Wojnowski <29680262+jwojnowski@users.noreply.github.com> Date: Thu, 28 Sep 2023 23:46:43 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9E=96=20Remove=20jwt-scala.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 - build.sbt | 1 - .../oidc4s/json/circe/CirceJsonSupport.scala | 5 +- .../json/circe/JwtHeaderCirceDecoder.scala | 10 +- .../circe/JwtHeaderCirceJsonSupportTest.scala | 5 +- .../scala/me/wojnowski/oidc4s/Algorithm.scala | 22 ++++ .../me/wojnowski/oidc4s/IdTokenVerifier.scala | 118 +++++++++++------- .../scala/me/wojnowski/oidc4s/JwtHeader.scala | 3 + .../wojnowski/oidc4s/json/JsonSupport.scala | 2 +- .../oidc4s/IdTokenVerifierTest.scala | 40 +++--- .../oidc4s/mocks/JsonSupportMock.scala | 2 +- 11 files changed, 129 insertions(+), 81 deletions(-) create mode 100644 core/src/main/scala/me/wojnowski/oidc4s/Algorithm.scala create mode 100644 core/src/main/scala/me/wojnowski/oidc4s/JwtHeader.scala diff --git a/README.md b/README.md index 5abf20d..f1e5732 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,6 @@ on front-end side can lead to a simple, yet secure authentication for Single Pag Scala versions 3.x and 2.13.x are supported. -[JWT Scala](https://github.com/jwt-scala/jwt-scala) is used for JWT verification under-the-hood. - ## Getting started To use this library with default Sttp/Circe implementations, add the following dependency to your `build.sbt`: diff --git a/build.sbt b/build.sbt index 9673839..ebbeeef 100644 --- a/build.sbt +++ b/build.sbt @@ -52,7 +52,6 @@ lazy val core = (project in file("core")).settings( name := "oidc4s-core", libraryDependencies += "org.typelevel" %% "cats-core" % Versions.cats.core, libraryDependencies += "org.typelevel" %% "cats-effect" % Versions.cats.effect, - libraryDependencies += "com.github.jwt-scala" %% "jwt-core" % Versions.jwtScala, libraryDependencies += "org.scalameta" %% "munit" % Versions.mUnit % Test, libraryDependencies += "org.typelevel" %% "munit-cats-effect-3" % Versions.mUnitCatsEffect % Test, libraryDependencies += "org.typelevel" %% "cats-effect-testkit" % Versions.cats.effect % Test diff --git a/circe/src/main/scala/me/wojnowski/oidc4s/json/circe/CirceJsonSupport.scala b/circe/src/main/scala/me/wojnowski/oidc4s/json/circe/CirceJsonSupport.scala index 55b6fbe..deae8a6 100644 --- a/circe/src/main/scala/me/wojnowski/oidc4s/json/circe/CirceJsonSupport.scala +++ b/circe/src/main/scala/me/wojnowski/oidc4s/json/circe/CirceJsonSupport.scala @@ -1,18 +1,15 @@ package me.wojnowski.oidc4s.json.circe -import cats.data.NonEmptySet import cats.syntax.all._ import io.circe.Decoder import io.circe.parser import me.wojnowski.oidc4s.IdTokenClaims -import me.wojnowski.oidc4s.IdTokenClaims.Audience -import me.wojnowski.oidc4s.Issuer +import me.wojnowski.oidc4s.JwtHeader import me.wojnowski.oidc4s.PublicKeyProvider import me.wojnowski.oidc4s.config.OpenIdConfig import me.wojnowski.oidc4s.json.JsonDecoder import me.wojnowski.oidc4s.json.JsonDecoder.ClaimsDecoder import me.wojnowski.oidc4s.json.JsonSupport -import pdi.jwt.JwtHeader trait CirceJsonSupport extends JsonSupport diff --git a/circe/src/main/scala/me/wojnowski/oidc4s/json/circe/JwtHeaderCirceDecoder.scala b/circe/src/main/scala/me/wojnowski/oidc4s/json/circe/JwtHeaderCirceDecoder.scala index 43edac4..bfe40d6 100644 --- a/circe/src/main/scala/me/wojnowski/oidc4s/json/circe/JwtHeaderCirceDecoder.scala +++ b/circe/src/main/scala/me/wojnowski/oidc4s/json/circe/JwtHeaderCirceDecoder.scala @@ -1,13 +1,15 @@ package me.wojnowski.oidc4s.json.circe import io.circe.Decoder -import pdi.jwt.JwtHeader +import me.wojnowski.oidc4s.Algorithm +import me.wojnowski.oidc4s.JwtHeader trait JwtHeaderCirceDecoder { protected implicit val jwtHeaderCirceDecoder: Decoder[JwtHeader] = - Decoder.forProduct1[JwtHeader, String]("kid") { kid => - JwtHeader(keyId = Some(kid)) - } + Decoder.forProduct2[JwtHeader, String, Algorithm]("kid", "alg")(JwtHeader.apply) + + private implicit val jwtAlgorithmCirceDecoder: Decoder[Algorithm] = + Decoder[String].map(Algorithm.fromString) } diff --git a/circe/src/test/scala/me/wojnowski/oidc4s/json/circe/JwtHeaderCirceJsonSupportTest.scala b/circe/src/test/scala/me/wojnowski/oidc4s/json/circe/JwtHeaderCirceJsonSupportTest.scala index a0a08cd..60dad49 100644 --- a/circe/src/test/scala/me/wojnowski/oidc4s/json/circe/JwtHeaderCirceJsonSupportTest.scala +++ b/circe/src/test/scala/me/wojnowski/oidc4s/json/circe/JwtHeaderCirceJsonSupportTest.scala @@ -1,8 +1,9 @@ package me.wojnowski.oidc4s.json.circe +import me.wojnowski.oidc4s.Algorithm +import me.wojnowski.oidc4s.JwtHeader import me.wojnowski.oidc4s.json.circe.CirceJsonSupport import munit.FunSuite -import pdi.jwt.JwtHeader class JwtHeaderCirceJsonSupportTest extends FunSuite { test("JwtHeader is decoding") { @@ -11,7 +12,7 @@ class JwtHeaderCirceJsonSupportTest extends FunSuite { val result = CirceJsonSupport.jwtHeaderDecoder.decode(rawJson) - assertEquals(result, Right(JwtHeader(keyId = Some("thisiskeyid")))) + assertEquals(result, Right(JwtHeader(keyId = "thisiskeyid", algorithm = Algorithm.Rs256))) } test("JwtHeader decoding (missing field)") { diff --git a/core/src/main/scala/me/wojnowski/oidc4s/Algorithm.scala b/core/src/main/scala/me/wojnowski/oidc4s/Algorithm.scala new file mode 100644 index 0000000..7e076a7 --- /dev/null +++ b/core/src/main/scala/me/wojnowski/oidc4s/Algorithm.scala @@ -0,0 +1,22 @@ +package me.wojnowski.oidc4s + +import cats.Order +import cats.data.NonEmptySet +import cats.implicits._ + +sealed abstract class Algorithm(val name: String, val fullName: String) extends Product with Serializable + +// According to OIDC RFC, only RS256 should be supported +object Algorithm { + case object Rs256 extends Algorithm(name = "RS256", fullName = "SHA256withRSA") + case object Rs384 extends Algorithm(name = "RS384", fullName = "SHA384withRSA") + case object Rs512 extends Algorithm(name = "RS512", fullName = "SHA512withRSA") + + case class Other(override val name: String) extends Algorithm(name, fullName = name) + + implicit val order: Order[Algorithm] = Order.by(_.name) + + val supportedAlgorithms: NonEmptySet[Algorithm] = NonEmptySet.of(Rs256, Rs384, Rs512) + + def fromString(s: String): Algorithm = supportedAlgorithms.find(_.name === s).getOrElse(Other(s)) +} diff --git a/core/src/main/scala/me/wojnowski/oidc4s/IdTokenVerifier.scala b/core/src/main/scala/me/wojnowski/oidc4s/IdTokenVerifier.scala index fb89abf..f108a44 100644 --- a/core/src/main/scala/me/wojnowski/oidc4s/IdTokenVerifier.scala +++ b/core/src/main/scala/me/wojnowski/oidc4s/IdTokenVerifier.scala @@ -2,25 +2,27 @@ package me.wojnowski.oidc4s import cats.Monad import cats.data.EitherT -import cats.data.NonEmptySet import cats.effect.Clock import cats.syntax.all._ -import me.wojnowski.oidc4s.IdTokenClaims.Audience -import me.wojnowski.oidc4s.IdTokenVerifier.Error.CouldNotExtractKeyId -import me.wojnowski.oidc4s.IdTokenVerifier.Error.JwtVerificationError +import me.wojnowski.oidc4s.IdTokenVerifier.Error.CouldNotDecodeClaim +import me.wojnowski.oidc4s.IdTokenVerifier.Error.CouldNotDecodeHeader +import me.wojnowski.oidc4s.IdTokenVerifier.Error.InvalidSignature +import me.wojnowski.oidc4s.IdTokenVerifier.Error.InvalidToken +import me.wojnowski.oidc4s.IdTokenVerifier.Error.TokenExpired +import me.wojnowski.oidc4s.IdTokenVerifier.Error.UnsupportedAlgorithm import me.wojnowski.oidc4s.config.OpenIdConnectDiscovery import me.wojnowski.oidc4s.json.JsonDecoder import me.wojnowski.oidc4s.json.JsonDecoder.ClaimsDecoder import me.wojnowski.oidc4s.json.JsonSupport -import pdi.jwt.Jwt -import pdi.jwt.JwtAlgorithm -import pdi.jwt.JwtHeader import java.nio.charset.StandardCharsets import java.security.PublicKey +import java.security.Signature +import java.time.Instant import java.time.ZoneId import java.time.{Clock => JavaClock} import java.util.Base64 +import scala.util.Success import scala.util.Try trait IdTokenVerifier[F[_]] { @@ -78,9 +80,6 @@ object IdTokenVerifier { new IdTokenVerifier[F] { import jsonSupport._ - // According to OIDC RFC, only RS256 should be supported - private val supportedAlgorithms = Seq(JwtAlgorithm.RS256, JwtAlgorithm.RS384, JwtAlgorithm.RS512) - override def verify(rawToken: String, expectedClientId: ClientId): F[Either[Error, IdTokenClaims.Subject]] = verifyAndDecode(rawToken).map(_.ensure(Error.ClientIdDoesNotMatch)(_.matchesClientId(expectedClientId)).map(_.subject)) @@ -104,49 +103,74 @@ object IdTokenVerifier { ): F[Either[IdTokenVerifier.Error, A]] = { for { issuer <- EitherT(issuerF) - instant <- EitherT.liftF(Clock[F].realTimeInstant) - javaClock = JavaClock.fixed(instant, ZoneId.of("UTC")) + now <- EitherT.liftF(Clock[F].realTimeInstant) headerJson <- EitherT.fromEither(extractHeaderJson(rawToken)) - kid <- EitherT.fromEither(extractKid(headerJson)) - publicKey <- EitherT(publicKeyProvider.getKey(kid).map(_.leftMap(IdTokenVerifier.Error.CouldNotFindPublicKey.apply))) + header <- EitherT.fromEither(decodeHeader(headerJson)) + publicKey <- EitherT(publicKeyProvider.getKey(header.keyId).map(_.leftMap(IdTokenVerifier.Error.CouldNotFindPublicKey.apply))) result <- EitherT.fromEither { - decodeAndVerifyToken[(A, IdTokenClaims)](rawToken, javaClock, publicKey) - .flatMap { case (claims, standardClaims) => - ensureExpectedIssuer(tokenIssuer = standardClaims.issuer, expectedIssuer = issuer) - .leftWiden[IdTokenVerifier.Error] - .flatTap { _ => - standardClaimsCheck(standardClaims) - } - .as(claims) - } + decodeJwtAndVerifySignature[A](rawToken, publicKey, header).flatMap { case (claims, standardClaims) => + List( + ensureNotExpired(now, standardClaims.expiration), + ensureExpectedIssuer(tokenIssuer = standardClaims.issuer, expectedIssuer = issuer), + standardClaimsCheck(standardClaims) + ).sequence.as(claims) + } } } yield result }.value - private def decodeAndVerifyToken[A: JsonDecoder]( - rawToken: String, - javaClock: JavaClock, - publicKey: PublicKey - ): Either[Error, A] = - Jwt(javaClock) - .decodeRaw(rawToken, publicKey, supportedAlgorithms) - .toEither - .leftMap[Error](throwable => JwtVerificationError(throwable)) - .flatMap { rawClaims => - JsonDecoder[A] - .decode(rawClaims) - .leftMap(IdTokenVerifier.Error.CouldNotDecodeClaim.apply) - } - private def ensureExpectedIssuer(tokenIssuer: Issuer, expectedIssuer: Issuer): Either[Error.UnexpectedIssuer, Unit] = Either.cond(expectedIssuer === tokenIssuer, (), IdTokenVerifier.Error.UnexpectedIssuer(tokenIssuer, expectedIssuer)) - private def extractKid(headerJson: String): Either[CouldNotExtractKeyId.type, String] = + private def ensureNotExpired(now: Instant, expiresAt: Instant): Either[Error.TokenExpired, Unit] = + Either.raiseWhen(expiresAt.isBefore(now))(TokenExpired(since = expiresAt)) + + private def decodeHeader(headerJson: String): Either[CouldNotDecodeHeader, JwtHeader] = JsonDecoder[JwtHeader] .decode(headerJson) - .toOption - .flatMap(_.keyId) - .toRight(CouldNotExtractKeyId) + .leftMap(CouldNotDecodeHeader.apply) + + private def decodeJwtAndVerifySignature[A: ClaimsDecoder](rawToken: String, key: PublicKey, header: JwtHeader) + : Either[Error, (A, IdTokenClaims)] = + rawToken.split('.') match { + case Array(rawHeader, rawClaims, rawSignature) => + for { + _ <- verifyAlgorithm(header.algorithm) + _ <- verifySignature(header.algorithm.fullName, key, rawHeader, rawClaims, rawSignature) + result <- parseClaims[A](rawClaims) + } yield result + + case _ => + InvalidToken.asLeft + } + + private def parseClaims[A: ClaimsDecoder](rawClaims: String): Either[CouldNotDecodeClaim, (A, IdTokenClaims)] = + Try { + new String(Base64.getUrlDecoder.decode(rawClaims)) + }.toEither.leftMap(t => CouldNotDecodeClaim(t.getMessage)).flatMap { rawJson => + ClaimsDecoder[A].decode(rawJson).leftMap(CouldNotDecodeClaim.apply) + } + + private def verifyAlgorithm(algorithm: Algorithm) = + Either.raiseUnless(Algorithm.supportedAlgorithms.contains_(algorithm))(UnsupportedAlgorithm(algorithm.name.some)) + + private def verifySignature( + signingAlgorithm: String, + publicKey: PublicKey, + rawHeader: String, + rawClaims: String, + rawSignature: String + ) = + Try { + val decodedSignature = Base64.getUrlDecoder.decode(rawSignature) + val signatureInstance = Signature.getInstance(signingAlgorithm) + signatureInstance.initVerify(publicKey) + signatureInstance.update(s"$rawHeader.$rawClaims".getBytes(StandardCharsets.UTF_8)) + signatureInstance.verify(decodedSignature) + } match { + case Success(true) => Either.unit + case _ => InvalidSignature.asLeft + } private def extractHeaderJson(rawToken: String) = Try { @@ -171,9 +195,17 @@ object IdTokenVerifier { case class CouldNotFindPublicKey(cause: PublicKeyProvider.Error) extends Error + case class CouldNotDecodeHeader(details: String) extends Error + case class CouldNotDecodeClaim(details: String) extends Error - case class JwtVerificationError(cause: Throwable) extends Error + case class TokenExpired(since: Instant) extends Error + + case object InvalidToken extends Error + + case object InvalidSignature extends Error + + case class UnsupportedAlgorithm(providedAlgorithm: Option[String]) extends Error case class UnexpectedIssuer(found: Issuer, expected: Issuer) extends Error } diff --git a/core/src/main/scala/me/wojnowski/oidc4s/JwtHeader.scala b/core/src/main/scala/me/wojnowski/oidc4s/JwtHeader.scala new file mode 100644 index 0000000..75a4e85 --- /dev/null +++ b/core/src/main/scala/me/wojnowski/oidc4s/JwtHeader.scala @@ -0,0 +1,3 @@ +package me.wojnowski.oidc4s + +case class JwtHeader(keyId: String, algorithm: Algorithm) diff --git a/core/src/main/scala/me/wojnowski/oidc4s/json/JsonSupport.scala b/core/src/main/scala/me/wojnowski/oidc4s/json/JsonSupport.scala index 2471f36..0d7a5de 100644 --- a/core/src/main/scala/me/wojnowski/oidc4s/json/JsonSupport.scala +++ b/core/src/main/scala/me/wojnowski/oidc4s/json/JsonSupport.scala @@ -1,9 +1,9 @@ package me.wojnowski.oidc4s.json import me.wojnowski.oidc4s.IdTokenClaims +import me.wojnowski.oidc4s.JwtHeader import me.wojnowski.oidc4s.PublicKeyProvider.JsonWebKeySet import me.wojnowski.oidc4s.config.OpenIdConfig -import pdi.jwt.JwtHeader trait JsonSupport { implicit def jwtHeaderDecoder: JsonDecoder[JwtHeader] diff --git a/core/src/test/scala/me/wojnowski/oidc4s/IdTokenVerifierTest.scala b/core/src/test/scala/me/wojnowski/oidc4s/IdTokenVerifierTest.scala index 3792c66..9c38bb6 100644 --- a/core/src/test/scala/me/wojnowski/oidc4s/IdTokenVerifierTest.scala +++ b/core/src/test/scala/me/wojnowski/oidc4s/IdTokenVerifierTest.scala @@ -7,8 +7,10 @@ import cats.effect.IO import cats.effect.testkit.TestControl import cats.syntax.all._ import me.wojnowski.oidc4s.IdTokenClaims._ -import me.wojnowski.oidc4s.IdTokenVerifier.Error.JwtVerificationError +import me.wojnowski.oidc4s.IdTokenVerifier.Error.InvalidSignature +import me.wojnowski.oidc4s.IdTokenVerifier.Error.TokenExpired import me.wojnowski.oidc4s.IdTokenVerifier.Error.UnexpectedIssuer +import me.wojnowski.oidc4s.IdTokenVerifier.Error.UnsupportedAlgorithm import me.wojnowski.oidc4s.IdTokenVerifierTest.staticKeyProvider import me.wojnowski.oidc4s.PublicKeyProvider.KeyId import me.wojnowski.oidc4s.PublicKeyProvider.KeyMap @@ -21,17 +23,9 @@ import me.wojnowski.oidc4s.mocks.CacheMock import me.wojnowski.oidc4s.mocks.HttpTransportMock import me.wojnowski.oidc4s.mocks.JsonSupportMock import munit.CatsEffectSuite -import pdi.jwt.JwtHeader -import pdi.jwt.exceptions.JwtEmptyAlgorithmException -import pdi.jwt.exceptions.JwtEmptySignatureException -import pdi.jwt.exceptions.JwtExpirationException -import pdi.jwt.exceptions.JwtValidationException -import java.security.KeyFactory import java.security.PublicKey -import java.security.spec.X509EncodedKeySpec import java.time.Instant -import java.util.Base64 import scala.annotation.unused //noinspection ZeroIndexToHead @@ -64,8 +58,8 @@ class IdTokenVerifierTest extends CatsEffectSuite { .discovery[IO](googleKeyProvider, discovery, jsonSupport) .verifyAndDecode(rawIdToken) .map { - case Left(JwtVerificationError(_: JwtExpirationException)) => () - case _ => fail("expected JwtExpirationException") + case Left(TokenExpired(`idTokenExpiration`)) => () + case _ => fail(s"expected $TokenExpired") } } } @@ -228,8 +222,8 @@ class IdTokenVerifierTest extends CatsEffectSuite { .discovery[IO](googleKeyProvider, discovery, jsonSupport) .verifyAndDecode(tokenSignedWithOtherKey) .map { - case Left(JwtVerificationError(_: JwtValidationException)) => () - case e => fail(s"expected JwtValidationException, got $e") + case Left(InvalidSignature) => () + case e => fail(s"expected JwtValidationException, got $e") } } } @@ -276,8 +270,8 @@ class IdTokenVerifierTest extends CatsEffectSuite { .discovery[IO](nonGoogleKeyProvider, discovery, jsonSupport) .verifyAndDecode(tokenWithAlgorithmNone) .map { - case Left(JwtVerificationError(_: JwtEmptySignatureException)) => () - case e => fail(s"expected JwtEmptySignatureException, got $e") + case Left(InvalidSignature) => () + case e => fail(s"expected JwtEmptySignatureException, got $e") } } } @@ -291,8 +285,8 @@ class IdTokenVerifierTest extends CatsEffectSuite { .discovery[IO](nonGoogleKeyProvider, discovery, jsonSupport) .verifyAndDecode(tokenWithAlgorithmNone) .map { - case Left(JwtVerificationError(_: JwtEmptyAlgorithmException)) => () - case e => fail(s"expected JwtEmptyAlgorithmException, got $e") + case Left(UnsupportedAlgorithm(_)) => () + case e => fail(s"expected $UnsupportedAlgorithm, got $e") } } } @@ -305,8 +299,8 @@ class IdTokenVerifierTest extends CatsEffectSuite { .discovery[IO](nonGoogleKeyProvider, discovery, jsonSupport) .verifyAndDecode(tokenWithHs256Algorithm) .map { - case Left(JwtVerificationError(_: JwtValidationException)) => () - case e => fail(s"expected JwtValidationException, got $e") + case Left(UnsupportedAlgorithm(_)) => () + case e => fail(s"expected $UnsupportedAlgorithm, got $e") } } } @@ -427,10 +421,10 @@ object IdTokenVerifierTest { val decodedJwtHeaders = List( - JwtHeader(keyId = Some("f9d97b4cae90bcd76aeb20026f6b770cac221783")), - JwtHeader(keyId = Some("11e03f39b8d300c8c9a1b800ddebfcfde4152c0c")), - JwtHeader(keyId = Some("11e03f39b8d300c8c9a1b800ddebfcfde4152c0c")), - JwtHeader(keyId = Some("11e03f39b8d300c8c9a1b800ddebfcfde4152c0c")) + JwtHeader(keyId = "f9d97b4cae90bcd76aeb20026f6b770cac221783", algorithm = Algorithm.Rs256), + JwtHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Rs256), + JwtHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Other("none")), + JwtHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Other("HS256")) ) } diff --git a/core/src/test/scala/me/wojnowski/oidc4s/mocks/JsonSupportMock.scala b/core/src/test/scala/me/wojnowski/oidc4s/mocks/JsonSupportMock.scala index 9280c62..d8ba4f0 100644 --- a/core/src/test/scala/me/wojnowski/oidc4s/mocks/JsonSupportMock.scala +++ b/core/src/test/scala/me/wojnowski/oidc4s/mocks/JsonSupportMock.scala @@ -1,11 +1,11 @@ package me.wojnowski.oidc4s.mocks import me.wojnowski.oidc4s.IdTokenClaims +import me.wojnowski.oidc4s.JwtHeader import me.wojnowski.oidc4s.PublicKeyProvider.JsonWebKeySet import me.wojnowski.oidc4s.config.OpenIdConfig import me.wojnowski.oidc4s.json.JsonDecoder import me.wojnowski.oidc4s.json.JsonSupport -import pdi.jwt.JwtHeader object JsonSupportMock {