diff --git a/build.sbt b/build.sbt index 2e35672..01b9035 100644 --- a/build.sbt +++ b/build.sbt @@ -81,6 +81,15 @@ lazy val sttp = (project in file("sttp")) ) .dependsOn(core % "compile->compile;test->test") +lazy val testkit = (project in file("testkit")) + .settings( + commonSettings ++ Seq( + name := "oidc4s-testkit", + mimaPreviousArtifacts := Set.empty // TODO remove after release + ) + ) + .dependsOn(core) + lazy val quickSttpCirce = (project in file("quick-sttp-circe")) .settings( commonSettings ++ Seq( @@ -94,4 +103,4 @@ lazy val root = (project in file(".")) publish / skip := true, mimaFailOnNoPrevious := false ) - .aggregate(core, circe, sttp, quickSttpCirce) + .aggregate(core, circe, sttp, quickSttpCirce, testkit) diff --git a/testkit/src/main/scala/IdTokenVerifierMock.scala b/testkit/src/main/scala/IdTokenVerifierMock.scala new file mode 100644 index 0000000..1e0f166 --- /dev/null +++ b/testkit/src/main/scala/IdTokenVerifierMock.scala @@ -0,0 +1,129 @@ +import cats.Applicative +import cats.Traverse +import cats.data.NonEmptySet +import cats.effect.Clock +import cats.implicits._ +import me.wojnowski.oidc4s.IdTokenClaims.Audience +import me.wojnowski.oidc4s.json.JsonDecoder.ClaimsDecoder +import me.wojnowski.oidc4s.ClientId +import me.wojnowski.oidc4s.IdTokenClaims +import me.wojnowski.oidc4s.IdTokenVerifier +import me.wojnowski.oidc4s.Issuer +import me.wojnowski.oidc4s.IdTokenVerifier.Error.CouldNotDecodeClaim +import me.wojnowski.oidc4s.json.JsonDecoder +import me.wojnowski.oidc4s.json.JsonSupport + +object IdTokenVerifierMock { + + def constRawClaims[F[_]: Applicative](rawClaims: String)(implicit jsonSupport: JsonSupport): IdTokenVerifier[F] = + constRawClaimsEither(Right(rawClaims)) + + def constRawClaimsEither[F[_]: Applicative](rawClaimsEither: Either[IdTokenVerifier.Error, String])(implicit jsonSupport: JsonSupport) + : IdTokenVerifier[F] = constRawClaimsEitherPF(_ => rawClaimsEither) + + def constRawClaimsEitherPF[F[_]: Applicative]( + rawTokenToRawClaimsEither: PartialFunction[String, Either[IdTokenVerifier.Error, String]] + )(implicit jsonSupport: JsonSupport + ): IdTokenVerifier[F] = new IdTokenVerifier[F] { + import jsonSupport._ + + override def verify(rawToken: String, expectedClientId: ClientId): F[Either[IdTokenVerifier.Error, IdTokenClaims.Subject]] = + verifyAndDecode(rawToken).map(_.map(_.subject)) + + override def verifyAndDecode(rawToken: String): F[Either[IdTokenVerifier.Error, IdTokenClaims]] = + rawTokenToRawClaimsEither + .lift(rawToken) + .toRight(IdTokenVerifier.Error.MalformedToken: IdTokenVerifier.Error) + .flatten + .flatMap { rawClaims => + JsonDecoder[IdTokenClaims] + .decode(rawClaims) + .leftMap(IdTokenVerifier.Error.CouldNotDecodeClaim(_): IdTokenVerifier.Error) + } + .pure[F] + + override def verifyAndDecodeCustom[A](rawToken: String)(implicit decoder: ClaimsDecoder[A]): F[Either[IdTokenVerifier.Error, A]] = + rawTokenToRawClaimsEither + .lift(rawToken) + .toRight(IdTokenVerifier.Error.MalformedToken: IdTokenVerifier.Error) + .flatten + .flatMap { rawClaims => + ClaimsDecoder[A] + .decode(rawClaims) + .map(_._1) + .leftMap(IdTokenVerifier.Error.CouldNotDecodeClaim(_): IdTokenVerifier.Error) + } + .pure[F] + + override def verifyAndDecodeCustom[A](rawToken: String, expectedClientId: ClientId)(implicit decoder: ClaimsDecoder[A]) + : F[Either[IdTokenVerifier.Error, A]] = + verifyAndDecodeCustom(rawToken) + + } + + def constSubject[F[_]: Applicative: Traverse: Clock](subject: IdTokenClaims.Subject): IdTokenVerifier[F] = + constSubjectEither[F](Right(subject)) + + def constSubjectEither[F[_]: Applicative: Traverse: Clock](errorOrSubject: Either[IdTokenVerifier.Error, IdTokenClaims.Subject]) + : IdTokenVerifier[F] = + constSubjectPF[F]((_: String) => errorOrSubject) + + def constSubjectPF[F[_]: Applicative: Traverse: Clock]( + rawTokenToSubjectPF: PartialFunction[String, Either[IdTokenVerifier.Error, IdTokenClaims.Subject]] + ): IdTokenVerifier[F] = + constStandardClaimsPF( + rawTokenToSubjectPF.map { errorOrSubject => + Applicative[F].map(Clock[F].realTimeInstant) { now => + errorOrSubject.map(subject => + IdTokenClaims( + Issuer("https://example.com"), + subject, + NonEmptySet.of(Audience("https://example.com")), + expiration = now.plusSeconds(600), + issuedAt = now + ) + ) + } + } + ) + + def constStandardClaims[F[_]: Applicative: Traverse](claims: IdTokenClaims): IdTokenVerifier[F] = constStandardClaimsEither(Right(claims)) + + def constStandardClaimsEither[F[_]: Applicative: Traverse](claimsEither: Either[IdTokenVerifier.Error, IdTokenClaims]) + : IdTokenVerifier[F] = + constStandardClaimsPF[F](_ => claimsEither.pure[F]) + + def constStandardClaimsPF[F[_]: Applicative: Traverse]( + rawTokenToClaimsPF: PartialFunction[String, F[Either[IdTokenVerifier.Error, IdTokenClaims]]] + ): IdTokenVerifier[F] = new IdTokenVerifier[F] { + + override def verifyAndDecode(rawToken: String): F[Either[IdTokenVerifier.Error, IdTokenClaims]] = + Applicative[F].map( + rawTokenToClaimsPF + .lift(rawToken) + .toRight(IdTokenVerifier.Error.MalformedToken: IdTokenVerifier.Error) + .sequence + )(_.flatten) + + override def verify(rawToken: String, expectedClientId: ClientId): F[Either[IdTokenVerifier.Error, IdTokenClaims.Subject]] = + Applicative[F].map(verifyAndDecode(rawToken))(_.map(_.subject)) + + override def verifyAndDecodeCustom[A](rawToken: String)(implicit decoder: ClaimsDecoder[A]): F[Either[IdTokenVerifier.Error, A]] = + Applicative[F].map( + rawTokenToClaimsPF + .lift(rawToken) + .toRight(IdTokenVerifier.Error.MalformedToken: IdTokenVerifier.Error) + .sequence + ) { + _.flatten.flatMap { _ => + CouldNotDecodeClaim("mock").asLeft[A] + } + } + + override def verifyAndDecodeCustom[A](rawToken: String, expectedClientId: ClientId)(implicit decoder: ClaimsDecoder[A]) + : F[Either[IdTokenVerifier.Error, A]] = + verifyAndDecodeCustom(rawToken) + + } + +}