From b56931a3af677ca3593229dedb6e63e5ccd941ec Mon Sep 17 00:00:00 2001 From: Jakub Wojnowski <29680262+jwojnowski@users.noreply.github.com> Date: Thu, 5 Oct 2023 00:28:44 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20Add=20property-based=20tests.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- build.sbt | 5 +- .../oidc4s/PropertyIdTokenVerifierTest.scala | 140 ++++++++++++++++++ 2 files changed, 144 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/me/wojnowski/oidc4s/PropertyIdTokenVerifierTest.scala diff --git a/build.sbt b/build.sbt index ebbeeef..413b56e 100644 --- a/build.sbt +++ b/build.sbt @@ -54,7 +54,10 @@ lazy val core = (project in file("core")).settings( libraryDependencies += "org.typelevel" %% "cats-effect" % Versions.cats.effect, libraryDependencies += "org.scalameta" %% "munit" % Versions.mUnit % Test, libraryDependencies += "org.typelevel" %% "munit-cats-effect-3" % Versions.mUnitCatsEffect % Test, - libraryDependencies += "org.typelevel" %% "cats-effect-testkit" % Versions.cats.effect % Test + libraryDependencies += "org.typelevel" %% "cats-effect-testkit" % Versions.cats.effect % Test, + libraryDependencies += "org.scalameta" %% "munit-scalacheck" % "0.7.29" % Test, + libraryDependencies += "org.typelevel" %% "scalacheck-effect-munit" % "1.0.4" % Test, + libraryDependencies += "com.github.jwt-scala" %% "jwt-core" % "9.4.4" % Test ) ) diff --git a/core/src/test/scala/me/wojnowski/oidc4s/PropertyIdTokenVerifierTest.scala b/core/src/test/scala/me/wojnowski/oidc4s/PropertyIdTokenVerifierTest.scala new file mode 100644 index 0000000..92b961f --- /dev/null +++ b/core/src/test/scala/me/wojnowski/oidc4s/PropertyIdTokenVerifierTest.scala @@ -0,0 +1,140 @@ +package me.wojnowski.oidc4s + +import cats.data.NonEmptySet +import cats.effect.IO +import cats.effect.testkit.TestControl +import cats.implicits._ +import me.wojnowski.oidc4s.IdTokenClaims.Audience +import me.wojnowski.oidc4s.IdTokenClaims.Subject +import me.wojnowski.oidc4s.PropertyIdTokenVerifierTest._ +import me.wojnowski.oidc4s.mocks.JsonSupportMock +import munit.CatsEffectSuite +import munit.ScalaCheckEffectSuite +import org.scalacheck.Test.Parameters +import org.scalacheck.Gen +import org.scalacheck.Test +import org.scalacheck.effect.PropF.forAllF +import pdi.jwt.Jwt +import pdi.jwt.JwtClaim + +import java.security.KeyPairGenerator +import java.security.PrivateKey +import java.security.PublicKey +import java.time.Clock +import java.time.Instant +import java.time.ZoneId +import java.util.UUID +import scala.concurrent.duration.DurationLong +import scala.concurrent.duration.FiniteDuration + +class PropertyIdTokenVerifierTest extends CatsEffectSuite with ScalaCheckEffectSuite { + + override val munitTimeout: FiniteDuration = 90.seconds + + override protected def scalaCheckTestParameters: Test.Parameters = Parameters.default.withMinSuccessfulTests(30) + + test("Verification succeeds given correct public key") { + forAllF(keyIdGen, matchingKeyPairGen, clockGen, algorithmGen, clientIdGen) { + case (keyId, (privateKey, publicKey), clock, algorithm, clientId) => + TestControl.executeEmbed { + val (rawJwt, verifier, subject) = prepareJwtAndVerifier(keyId, privateKey, publicKey, clock, algorithm, clientId) + + for { + _ <- IO.sleep(clock.instant.getEpochSecond.seconds) + result <- verifier.verify(rawJwt, clientId) + } yield assertEquals(result, Right(subject)) + } + } + } + + test("Verification fails given incorrect public key") { + forAllF(keyIdGen, mismatchedKeyPairGen, clockGen, algorithmGen, clientIdGen) { + case (keyId, (privateKey, publicKey), clock, algorithm, clientId) => + TestControl.executeEmbed { + val (rawJwt, verifier, _) = prepareJwtAndVerifier(keyId, privateKey, publicKey, clock, algorithm, clientId) + + for { + _ <- IO.sleep(clock.instant.getEpochSecond.seconds) + result <- verifier.verify(rawJwt, clientId) + } yield assertEquals(result, Left(IdTokenVerifier.Error.InvalidSignature)) + } + } + } + + private def prepareJwtAndVerifier( + keyId: UUID, + privateKey: PrivateKey, + publicKey: PublicKey, + clock: java.time.Clock, + algorithm: Algorithm, + clientId: ClientId + ): (String, IdTokenVerifier[IO], Subject) = { + val publicKeyProvider = PublicKeyProvider.static[IO](Map(keyId.toString -> publicKey)) + val issuer = Issuer("https://example.com") + val subject = Subject("user-id") + + val issuedAt = clock.instant() + val expiresAt = issuedAt.plusSeconds(600) + + val issuedAtSeconds = issuedAt.getEpochSecond + val expiresAtSeconds = expiresAt.getEpochSecond + + val rawClaims = s"""{"sub":"${subject.value}","aud":["$clientId"],"exp":$expiresAtSeconds,"iat":$issuedAtSeconds}""" + val claims = IdTokenClaims(issuer, subject, NonEmptySet.of(Audience(clientId.value)), expiresAt, issuedAt) + + val rawHeader = s"""{"alg":"${algorithm.name}","kid":"$keyId"}""" + val header = JwtHeader(keyId.toString, algorithm) + + val verifier = + IdTokenVerifier.static( + publicKeyProvider, + issuer, + JsonSupportMock.instance(Map(rawClaims -> claims), Map(rawHeader -> header)) + ) + + val rawJwt = + Jwt(clock).encode( + pdi.jwt.JwtHeader(pdi.jwt.JwtAlgorithm.fromString(algorithm.name).some, keyId = keyId.toString.some), + JwtClaim(rawClaims), + privateKey + ) + + (rawJwt, verifier, subject) + } + +} + +object PropertyIdTokenVerifierTest { + val keyIdGen: Gen[UUID] = Gen.uuid + + private def keyPairGen(keySize: Int) = Gen.delay { + val keyPairGenerator = KeyPairGenerator.getInstance("RSA") + keyPairGenerator.initialize(keySize) + val keyPair = keyPairGenerator.generateKeyPair() + Gen.const((keyPair.getPrivate, keyPair.getPublic)) + } + + val matchingKeyPairGen: Gen[(PrivateKey, PublicKey)] = Gen.oneOf(2048, 4096).flatMap { keySize => + keyPairGen(keySize) + } + + val mismatchedKeyPairGen: Gen[(PrivateKey, PublicKey)] = Gen.oneOf(2048, 4096).flatMap { keySize => + keyPairGen(keySize).flatMap { case firstPair @ (firstPrivateKey, _) => + keyPairGen(keySize).suchThat(_ != firstPair).map { case (_, secondPublicKey) => + (firstPrivateKey, secondPublicKey) + } + } + } + + val clockGen: Gen[Clock] = + Gen.choose(2137, 1696369856).map { seconds => + java.time.Clock.fixed(Instant.ofEpochSecond(seconds), ZoneId.of("UTC")) + } + + val algorithmGen: Gen[Algorithm] = + Gen.oneOf(Algorithm.supportedAlgorithms.toSortedSet) + + val clientIdGen: Gen[ClientId] = + Gen.choose(3, 24).flatMap(length => Gen.listOfN(length, Gen.alphaNumChar).map(_.mkString)).map(ClientId.apply) + +}