Skip to content

Commit

Permalink
➖ Remove jwt-scala.
Browse files Browse the repository at this point in the history
  • Loading branch information
jwojnowski committed Oct 4, 2023
1 parent 54167d6 commit 2830a33
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 82 deletions.
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ on front-end side can lead to a simple, yet secure authentication for Single Pag

Scala versions 3.x and 2.13.x are supported.

[JWT Scala](https://github.com/jwt-scala/jwt-scala) is used for JWT verification under-the-hood.

## Getting started

To use this library with default Sttp/Circe implementations, add the following dependency to your `build.sbt`:
Expand Down
1 change: 0 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ lazy val core = (project in file("core")).settings(
name := "oidc4s-core",
libraryDependencies += "org.typelevel" %% "cats-core" % Versions.cats.core,
libraryDependencies += "org.typelevel" %% "cats-effect" % Versions.cats.effect,
libraryDependencies += "com.github.jwt-scala" %% "jwt-core" % Versions.jwtScala,
libraryDependencies += "org.scalameta" %% "munit" % Versions.mUnit % Test,
libraryDependencies += "org.typelevel" %% "munit-cats-effect-3" % Versions.mUnitCatsEffect % Test,
libraryDependencies += "org.typelevel" %% "cats-effect-testkit" % Versions.cats.effect % Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
package me.wojnowski.oidc4s.json.circe

import cats.data.NonEmptySet
import cats.syntax.all._
import io.circe.Decoder
import io.circe.parser
import me.wojnowski.oidc4s.IdTokenClaims
import me.wojnowski.oidc4s.IdTokenClaims.Audience
import me.wojnowski.oidc4s.Issuer
import me.wojnowski.oidc4s.JwtHeader
import me.wojnowski.oidc4s.PublicKeyProvider
import me.wojnowski.oidc4s.config.OpenIdConfig
import me.wojnowski.oidc4s.json.JsonDecoder
import me.wojnowski.oidc4s.json.JsonDecoder.ClaimsDecoder
import me.wojnowski.oidc4s.json.JsonSupport
import pdi.jwt.JwtHeader

trait CirceJsonSupport
extends JsonSupport
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package me.wojnowski.oidc4s.json.circe

import io.circe.Decoder
import pdi.jwt.JwtHeader
import me.wojnowski.oidc4s.Algorithm
import me.wojnowski.oidc4s.JwtHeader

trait JwtHeaderCirceDecoder {

private implicit val jwtAlgorithmCirceDecoder: Decoder[Algorithm] =
Decoder[String].map(Algorithm.fromString)

protected implicit val jwtHeaderCirceDecoder: Decoder[JwtHeader] =
Decoder.forProduct1[JwtHeader, String]("kid") { kid =>
JwtHeader(keyId = Some(kid))
}
Decoder.forProduct2[JwtHeader, String, Algorithm]("kid", "alg")(JwtHeader.apply)

}
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package me.wojnowski.oidc4s.json.circe

import me.wojnowski.oidc4s.Algorithm
import me.wojnowski.oidc4s.JwtHeader
import me.wojnowski.oidc4s.json.circe.CirceJsonSupport
import munit.FunSuite
import pdi.jwt.JwtHeader

class JwtHeaderCirceJsonSupportTest extends FunSuite {
test("JwtHeader is decoding") {
Expand All @@ -11,7 +12,7 @@ class JwtHeaderCirceJsonSupportTest extends FunSuite {

val result = CirceJsonSupport.jwtHeaderDecoder.decode(rawJson)

assertEquals(result, Right(JwtHeader(keyId = Some("thisiskeyid"))))
assertEquals(result, Right(JwtHeader(keyId = "thisiskeyid", algorithm = Algorithm.Rs256)))
}

test("JwtHeader decoding (missing field)") {
Expand Down
22 changes: 22 additions & 0 deletions core/src/main/scala/me/wojnowski/oidc4s/Algorithm.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package me.wojnowski.oidc4s

import cats.Order
import cats.data.NonEmptySet
import cats.implicits._

sealed abstract class Algorithm(val name: String, val fullName: String) extends Product with Serializable

// According to OIDC RFC, only RS256 should be supported
object Algorithm {
case object Rs256 extends Algorithm(name = "RS256", fullName = "SHA256withRSA")
case object Rs384 extends Algorithm(name = "RS384", fullName = "SHA384withRSA")
case object Rs512 extends Algorithm(name = "RS512", fullName = "SHA512withRSA")

case class Other(override val name: String) extends Algorithm(name, fullName = name)

implicit val order: Order[Algorithm] = Order.by(_.name)

val supportedAlgorithms: NonEmptySet[Algorithm] = NonEmptySet.of(Rs256, Rs384, Rs512)

def fromString(s: String): Algorithm = supportedAlgorithms.find(_.name === s).getOrElse(Other(s))
}
118 changes: 75 additions & 43 deletions core/src/main/scala/me/wojnowski/oidc4s/IdTokenVerifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@ package me.wojnowski.oidc4s

import cats.Monad
import cats.data.EitherT
import cats.data.NonEmptySet
import cats.effect.Clock
import cats.syntax.all._
import me.wojnowski.oidc4s.IdTokenClaims.Audience
import me.wojnowski.oidc4s.IdTokenVerifier.Error.CouldNotExtractKeyId
import me.wojnowski.oidc4s.IdTokenVerifier.Error.JwtVerificationError
import me.wojnowski.oidc4s.IdTokenVerifier.Error.CouldNotDecodeClaim
import me.wojnowski.oidc4s.IdTokenVerifier.Error.CouldNotDecodeHeader
import me.wojnowski.oidc4s.IdTokenVerifier.Error.InvalidSignature
import me.wojnowski.oidc4s.IdTokenVerifier.Error.InvalidToken
import me.wojnowski.oidc4s.IdTokenVerifier.Error.TokenExpired
import me.wojnowski.oidc4s.IdTokenVerifier.Error.UnsupportedAlgorithm
import me.wojnowski.oidc4s.config.OpenIdConnectDiscovery
import me.wojnowski.oidc4s.json.JsonDecoder
import me.wojnowski.oidc4s.json.JsonDecoder.ClaimsDecoder
import me.wojnowski.oidc4s.json.JsonSupport
import pdi.jwt.Jwt
import pdi.jwt.JwtAlgorithm
import pdi.jwt.JwtHeader

import java.nio.charset.StandardCharsets
import java.security.PublicKey
import java.security.Signature
import java.time.Instant
import java.time.ZoneId
import java.time.{Clock => JavaClock}
import java.util.Base64
import scala.util.Success
import scala.util.Try

trait IdTokenVerifier[F[_]] {
Expand Down Expand Up @@ -78,9 +80,6 @@ object IdTokenVerifier {
new IdTokenVerifier[F] {
import jsonSupport._

// According to OIDC RFC, only RS256 should be supported
private val supportedAlgorithms = Seq(JwtAlgorithm.RS256, JwtAlgorithm.RS384, JwtAlgorithm.RS512)

override def verify(rawToken: String, expectedClientId: ClientId): F[Either[Error, IdTokenClaims.Subject]] =
verifyAndDecode(rawToken).map(_.ensure(Error.ClientIdDoesNotMatch)(_.matchesClientId(expectedClientId)).map(_.subject))

Expand All @@ -104,49 +103,74 @@ object IdTokenVerifier {
): F[Either[IdTokenVerifier.Error, A]] = {
for {
issuer <- EitherT(issuerF)
instant <- EitherT.liftF(Clock[F].realTimeInstant)
javaClock = JavaClock.fixed(instant, ZoneId.of("UTC"))
now <- EitherT.liftF(Clock[F].realTimeInstant)
headerJson <- EitherT.fromEither(extractHeaderJson(rawToken))
kid <- EitherT.fromEither(extractKid(headerJson))
publicKey <- EitherT(publicKeyProvider.getKey(kid).map(_.leftMap(IdTokenVerifier.Error.CouldNotFindPublicKey.apply)))
header <- EitherT.fromEither(decodeHeader(headerJson))
publicKey <- EitherT(publicKeyProvider.getKey(header.keyId).map(_.leftMap(IdTokenVerifier.Error.CouldNotFindPublicKey.apply)))
result <- EitherT.fromEither {
decodeAndVerifyToken[(A, IdTokenClaims)](rawToken, javaClock, publicKey)
.flatMap { case (claims, standardClaims) =>
ensureExpectedIssuer(tokenIssuer = standardClaims.issuer, expectedIssuer = issuer)
.leftWiden[IdTokenVerifier.Error]
.flatTap { _ =>
standardClaimsCheck(standardClaims)
}
.as(claims)
}
decodeJwtAndVerifySignature[A](rawToken, publicKey, header).flatMap { case (claims, standardClaims) =>
List(
ensureNotExpired(now, standardClaims.expiration),
ensureExpectedIssuer(tokenIssuer = standardClaims.issuer, expectedIssuer = issuer),
standardClaimsCheck(standardClaims)
).sequence.as(claims)
}
}
} yield result
}.value

private def decodeAndVerifyToken[A: JsonDecoder](
rawToken: String,
javaClock: JavaClock,
publicKey: PublicKey
): Either[Error, A] =
Jwt(javaClock)
.decodeRaw(rawToken, publicKey, supportedAlgorithms)
.toEither
.leftMap[Error](throwable => JwtVerificationError(throwable))
.flatMap { rawClaims =>
JsonDecoder[A]
.decode(rawClaims)
.leftMap(IdTokenVerifier.Error.CouldNotDecodeClaim.apply)
}

private def ensureExpectedIssuer(tokenIssuer: Issuer, expectedIssuer: Issuer): Either[Error.UnexpectedIssuer, Unit] =
Either.cond(expectedIssuer === tokenIssuer, (), IdTokenVerifier.Error.UnexpectedIssuer(tokenIssuer, expectedIssuer))

private def extractKid(headerJson: String): Either[CouldNotExtractKeyId.type, String] =
private def ensureNotExpired(now: Instant, expiresAt: Instant): Either[Error.TokenExpired, Unit] =
Either.raiseWhen(expiresAt.isBefore(now))(TokenExpired(since = expiresAt))

private def decodeHeader(headerJson: String): Either[CouldNotDecodeHeader, JwtHeader] =
JsonDecoder[JwtHeader]
.decode(headerJson)
.toOption
.flatMap(_.keyId)
.toRight(CouldNotExtractKeyId)
.leftMap(CouldNotDecodeHeader.apply)

private def decodeJwtAndVerifySignature[A: ClaimsDecoder](rawToken: String, key: PublicKey, header: JwtHeader)
: Either[Error, (A, IdTokenClaims)] =
rawToken.split('.') match {
case Array(rawHeader, rawClaims, rawSignature) =>
for {
_ <- verifyAlgorithm(header.algorithm)
_ <- verifySignature(header.algorithm.fullName, key, rawHeader, rawClaims, rawSignature)
result <- parseClaims[A](rawClaims)
} yield result

case _ =>
InvalidToken.asLeft
}

private def parseClaims[A: ClaimsDecoder](rawClaims: String): Either[CouldNotDecodeClaim, (A, IdTokenClaims)] =
Try {
new String(Base64.getUrlDecoder.decode(rawClaims))
}.toEither.leftMap(t => CouldNotDecodeClaim(t.getMessage)).flatMap { rawJson =>
ClaimsDecoder[A].decode(rawJson).leftMap(CouldNotDecodeClaim.apply)
}

private def verifyAlgorithm(algorithm: Algorithm) =
Either.raiseUnless(Algorithm.supportedAlgorithms.contains_(algorithm))(UnsupportedAlgorithm(algorithm.name.some))

private def verifySignature(
signingAlgorithm: String,
publicKey: PublicKey,
rawHeader: String,
rawClaims: String,
rawSignature: String
) =
Try {
val decodedSignature = Base64.getUrlDecoder.decode(rawSignature)
val signatureInstance = Signature.getInstance(signingAlgorithm)
signatureInstance.initVerify(publicKey)
signatureInstance.update(s"$rawHeader.$rawClaims".getBytes(StandardCharsets.UTF_8))
signatureInstance.verify(decodedSignature)
} match {
case Success(true) => Either.unit
case _ => InvalidSignature.asLeft
}

private def extractHeaderJson(rawToken: String) =
Try {
Expand All @@ -171,9 +195,17 @@ object IdTokenVerifier {

case class CouldNotFindPublicKey(cause: PublicKeyProvider.Error) extends Error

case class CouldNotDecodeHeader(details: String) extends Error

case class CouldNotDecodeClaim(details: String) extends Error

case class JwtVerificationError(cause: Throwable) extends Error
case class TokenExpired(since: Instant) extends Error

case object InvalidToken extends Error

case object InvalidSignature extends Error

case class UnsupportedAlgorithm(providedAlgorithm: Option[String]) extends Error

case class UnexpectedIssuer(found: Issuer, expected: Issuer) extends Error
}
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/me/wojnowski/oidc4s/JwtHeader.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package me.wojnowski.oidc4s

case class JwtHeader(keyId: String, algorithm: Algorithm)
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package me.wojnowski.oidc4s.json

import me.wojnowski.oidc4s.IdTokenClaims
import me.wojnowski.oidc4s.JwtHeader
import me.wojnowski.oidc4s.PublicKeyProvider.JsonWebKeySet
import me.wojnowski.oidc4s.config.OpenIdConfig
import pdi.jwt.JwtHeader

trait JsonSupport {
implicit def jwtHeaderDecoder: JsonDecoder[JwtHeader]
Expand Down
41 changes: 17 additions & 24 deletions core/src/test/scala/me/wojnowski/oidc4s/IdTokenVerifierTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import cats.effect.IO
import cats.effect.testkit.TestControl
import cats.syntax.all._
import me.wojnowski.oidc4s.IdTokenClaims._
import me.wojnowski.oidc4s.IdTokenVerifier.Error.JwtVerificationError
import me.wojnowski.oidc4s.IdTokenVerifier.Error.InvalidSignature
import me.wojnowski.oidc4s.IdTokenVerifier.Error.TokenExpired
import me.wojnowski.oidc4s.IdTokenVerifier.Error.UnexpectedIssuer
import me.wojnowski.oidc4s.IdTokenVerifierTest.staticKeyProvider
import me.wojnowski.oidc4s.IdTokenVerifier.Error.UnsupportedAlgorithm
import me.wojnowski.oidc4s.PublicKeyProvider.KeyId
import me.wojnowski.oidc4s.PublicKeyProvider.KeyMap
import me.wojnowski.oidc4s.TimeUtils.InstantToFiniteDuration
Expand All @@ -21,17 +22,9 @@ import me.wojnowski.oidc4s.mocks.CacheMock
import me.wojnowski.oidc4s.mocks.HttpTransportMock
import me.wojnowski.oidc4s.mocks.JsonSupportMock
import munit.CatsEffectSuite
import pdi.jwt.JwtHeader
import pdi.jwt.exceptions.JwtEmptyAlgorithmException
import pdi.jwt.exceptions.JwtEmptySignatureException
import pdi.jwt.exceptions.JwtExpirationException
import pdi.jwt.exceptions.JwtValidationException

import java.security.KeyFactory
import java.security.PublicKey
import java.security.spec.X509EncodedKeySpec
import java.time.Instant
import java.util.Base64
import scala.annotation.unused

//noinspection ZeroIndexToHead
Expand Down Expand Up @@ -64,8 +57,8 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.discovery[IO](googleKeyProvider, discovery, jsonSupport)
.verifyAndDecode(rawIdToken)
.map {
case Left(JwtVerificationError(_: JwtExpirationException)) => ()
case _ => fail("expected JwtExpirationException")
case Left(TokenExpired(`idTokenExpiration`)) => ()
case _ => fail(s"expected $TokenExpired")
}
}
}
Expand Down Expand Up @@ -228,8 +221,8 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.discovery[IO](googleKeyProvider, discovery, jsonSupport)
.verifyAndDecode(tokenSignedWithOtherKey)
.map {
case Left(JwtVerificationError(_: JwtValidationException)) => ()
case e => fail(s"expected JwtValidationException, got $e")
case Left(InvalidSignature) => ()
case e => fail(s"expected JwtValidationException, got $e")
}
}
}
Expand Down Expand Up @@ -276,8 +269,8 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.discovery[IO](nonGoogleKeyProvider, discovery, jsonSupport)
.verifyAndDecode(tokenWithAlgorithmNone)
.map {
case Left(JwtVerificationError(_: JwtEmptySignatureException)) => ()
case e => fail(s"expected JwtEmptySignatureException, got $e")
case Left(InvalidToken) => ()
case e => fail(s"expected JwtEmptySignatureException, got $e")
}
}
}
Expand All @@ -291,8 +284,8 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.discovery[IO](nonGoogleKeyProvider, discovery, jsonSupport)
.verifyAndDecode(tokenWithAlgorithmNone)
.map {
case Left(JwtVerificationError(_: JwtEmptyAlgorithmException)) => ()
case e => fail(s"expected JwtEmptyAlgorithmException, got $e")
case Left(UnsupportedAlgorithm(_)) => ()
case e => fail(s"expected $UnsupportedAlgorithm, got $e")
}
}
}
Expand All @@ -305,8 +298,8 @@ class IdTokenVerifierTest extends CatsEffectSuite {
.discovery[IO](nonGoogleKeyProvider, discovery, jsonSupport)
.verifyAndDecode(tokenWithHs256Algorithm)
.map {
case Left(JwtVerificationError(_: JwtValidationException)) => ()
case e => fail(s"expected JwtValidationException, got $e")
case Left(UnsupportedAlgorithm(_)) => ()
case e => fail(s"expected $UnsupportedAlgorithm, got $e")
}
}
}
Expand Down Expand Up @@ -427,10 +420,10 @@ object IdTokenVerifierTest {

val decodedJwtHeaders =
List(
JwtHeader(keyId = Some("f9d97b4cae90bcd76aeb20026f6b770cac221783")),
JwtHeader(keyId = Some("11e03f39b8d300c8c9a1b800ddebfcfde4152c0c")),
JwtHeader(keyId = Some("11e03f39b8d300c8c9a1b800ddebfcfde4152c0c")),
JwtHeader(keyId = Some("11e03f39b8d300c8c9a1b800ddebfcfde4152c0c"))
JwtHeader(keyId = "f9d97b4cae90bcd76aeb20026f6b770cac221783", algorithm = Algorithm.Rs256),
JwtHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Rs256),
JwtHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Other("none")),
JwtHeader(keyId = "11e03f39b8d300c8c9a1b800ddebfcfde4152c0c", algorithm = Algorithm.Other("HS256"))
)

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package me.wojnowski.oidc4s.mocks

import me.wojnowski.oidc4s.IdTokenClaims
import me.wojnowski.oidc4s.JwtHeader
import me.wojnowski.oidc4s.PublicKeyProvider.JsonWebKeySet
import me.wojnowski.oidc4s.config.OpenIdConfig
import me.wojnowski.oidc4s.json.JsonDecoder
import me.wojnowski.oidc4s.json.JsonSupport
import pdi.jwt.JwtHeader

object JsonSupportMock {

Expand Down

0 comments on commit 2830a33

Please sign in to comment.