diff --git a/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/Environment.scala b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/Environment.scala index 131be5d..b816b5a 100644 --- a/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/Environment.scala +++ b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/Environment.scala @@ -15,7 +15,7 @@ import cats.effect.{Async, Resource} import com.snowplowanalytics.iglu.core.SchemaCriterion import com.snowplowanalytics.snowplow.runtime.{AppInfo, HealthProbe} import com.snowplowanalytics.snowplow.sinks.Sink -import com.snowplowanalytics.snowplow.snowflake.processing.{ChannelProvider, SnowflakeHealth, TableManager} +import com.snowplowanalytics.snowplow.snowflake.processing.{Channel, Coldswap, SnowflakeHealth, TableManager} import com.snowplowanalytics.snowplow.sources.SourceAndAck import org.http4s.blaze.client.BlazeClientBuilder import org.http4s.client.Client @@ -26,7 +26,7 @@ case class Environment[F[_]]( badSink: Sink[F], httpClient: Client[F], tableManager: TableManager[F], - channelProvider: ChannelProvider[F], + channel: Coldswap[F, Channel[F]], metrics: Metrics[F], batching: Config.Batching, schemasToSkip: List[SchemaCriterion] @@ -53,17 +53,17 @@ object Environment { badSink <- toSink(config.output.bad) metrics <- Resource.eval(Metrics.build(config.monitoring.metrics)) tableManager <- Resource.eval(TableManager.make(config.output.good, snowflakeHealth, config.retries, monitoring)) - _ <- Resource.eval(tableManager.initializeEventsTable()) - channelProvider <- ChannelProvider.make(config.output.good, snowflakeHealth, config.batching, config.retries, monitoring) + channelResource <- Channel.make(config.output.good, snowflakeHealth, config.batching, config.retries, monitoring) + channelColdswap <- Coldswap.make(channelResource) } yield Environment( - appInfo = appInfo, - source = sourceAndAck, - badSink = badSink, - httpClient = httpClient, - tableManager = tableManager, - channelProvider = channelProvider, - metrics = metrics, - batching = config.batching, - schemasToSkip = config.skipSchemas + appInfo = appInfo, + source = sourceAndAck, + badSink = badSink, + httpClient = httpClient, + tableManager = tableManager, + channel = channelColdswap, + metrics = metrics, + batching = config.batching, + schemasToSkip = config.skipSchemas ) } diff --git a/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/ChannelProvider.scala b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Channel.scala similarity index 65% rename from modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/ChannelProvider.scala rename to modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Channel.scala index ece0224..3bd89aa 100644 --- a/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/ChannelProvider.scala +++ b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Channel.scala @@ -10,10 +10,7 @@ package com.snowplowanalytics.snowplow.snowflake.processing -import cats.effect.implicits._ -import cats.effect.kernel.{Ref, Resource} -import cats.effect.std.{Hotswap, Semaphore} -import cats.effect.{Async, Poll, Sync} +import cats.effect.{Async, Poll, Resource, Sync} import cats.implicits._ import com.snowplowanalytics.snowplow.snowflake.{Alert, Config, Monitoring} import net.snowflake.ingest.streaming.internal.SnowsFlakePlowInterop @@ 
-26,23 +23,7 @@ import java.time.ZoneOffset import java.util.Properties import scala.jdk.CollectionConverters._ -trait ChannelProvider[F[_]] { - - /** - * Closes the open channel and opens a new channel - * - * This should be called if the channel becomes invalid. And the channel becomes invalid if the - * table is altered by another concurrent loader. - */ - def reset: F[Unit] - - /** - * Wraps an action which requires the channel to be closed - * - * This should be called when altering the table to add new columns. The newly opened channel will - * be able to use the new columns. - */ - def withClosedChannel[A](fa: F[A]): F[A] +trait Channel[F[_]] { /** * Writes rows to Snowflake @@ -52,10 +33,10 @@ trait ChannelProvider[F[_]] { * @return * List of the details of any insert failures. Empty list implies complete success. */ - def write(rows: Iterable[Map[String, AnyRef]]): F[ChannelProvider.WriteResult] + def write(rows: Iterable[Map[String, AnyRef]]): F[Channel.WriteResult] } -object ChannelProvider { +object Channel { private implicit def logger[F[_]: Sync] = Slf4jLogger.getLogger[F] @@ -95,88 +76,36 @@ object ChannelProvider { * Contains details of any failures to write events to Snowflake. If the write was completely * successful then this list is empty. */ - case class WriteFailures(value: List[ChannelProvider.WriteFailure]) extends WriteResult + case class WriteFailures(value: List[Channel.WriteFailure]) extends WriteResult } - /** A large number so we don't limit the number of permits for calls to `flush` and `enqueue` */ - private val allAvailablePermits: Long = Long.MaxValue - def make[F[_]: Async]( config: Config.Snowflake, snowflakeHealth: SnowflakeHealth[F], batchingConfig: Config.Batching, retriesConfig: Config.Retries, monitoring: Monitoring[F] - ): Resource[F, ChannelProvider[F]] = + ): Resource[F, Resource[F, Channel[F]]] = for { client <- createClient(config, batchingConfig) - channelResource = createChannel(config, client, snowflakeHealth, retriesConfig, monitoring) - (hs, channel) <- Hotswap.apply(channelResource) - ref <- Resource.eval(Ref[F].of(channel)) - sem <- Resource.eval(Semaphore[F](allAvailablePermits)) - } yield impl(ref, hs, sem, channelResource) - - private def impl[F[_]: Async]( - ref: Ref[F, SnowflakeStreamingIngestChannel], - hs: Hotswap[F, SnowflakeStreamingIngestChannel], - sem: Semaphore[F], - next: Resource[F, SnowflakeStreamingIngestChannel] - ): ChannelProvider[F] = - new ChannelProvider[F] { - def reset: F[Unit] = - withAllPermits(sem) { // Must have **all** permits so we don't conflict with a write - ref.get.flatMap { channel => - if (channel.isValid()) - // We might have concurrent fibers calling `reset`. This just means another fiber - // has already reset this channel. 
- Sync[F].unit - else - Sync[F].uncancelable { _ => - for { - _ <- hs.clear - channel <- hs.swap(next) - _ <- ref.set(channel) - } yield () - } - } - } + } yield createChannel[F](config, client, snowflakeHealth, retriesConfig, monitoring).map(impl[F]) - def withClosedChannel[A](fa: F[A]): F[A] = - withAllPermits(sem) { // Must have **all** permites so we don't conflict with a write - Sync[F].uncancelable { _ => - for { - _ <- hs.clear - a <- fa - channel <- hs.swap(next) - _ <- ref.set(channel) - } yield a - } - } + private def impl[F[_]: Async](channel: SnowflakeStreamingIngestChannel): Channel[F] = + new Channel[F] { - def write(rows: Iterable[Map[String, AnyRef]]): F[WriteResult] = - sem.permit - .use[WriteResult] { _ => - for { - channel <- ref.get - response <- Sync[F].blocking(channel.insertRows(rows.map(_.asJava).asJava, null)) - _ <- flushChannel[F](channel) - isValid <- Sync[F].delay(channel.isValid) - } yield if (isValid) WriteResult.WriteFailures(parseResponse(response)) else WriteResult.ChannelIsInvalid - } - .recover { - case sfe: SFException if sfe.getVendorCode === SFErrorCode.INVALID_CHANNEL.getMessageCode => - WriteResult.ChannelIsInvalid - } - } + def write(rows: Iterable[Map[String, AnyRef]]): F[WriteResult] = { + val attempt: F[WriteResult] = for { + response <- Sync[F].blocking(channel.insertRows(rows.map(_.asJava).asJava, null)) + _ <- flushChannel[F](channel) + isValid <- Sync[F].delay(channel.isValid) + } yield if (isValid) WriteResult.WriteFailures(parseResponse(response)) else WriteResult.ChannelIsInvalid - /** Wraps a `F[A]` so it only runs when no other fiber is using the channel at the same time */ - private def withAllPermits[F[_]: Sync, A](sem: Semaphore[F])(f: F[A]): F[A] = - Sync[F].uncancelable { poll => - for { - _ <- poll(sem.acquireN(allAvailablePermits)) - a <- f.guarantee(sem.releaseN(allAvailablePermits)) - } yield a + attempt.recover { + case sfe: SFException if sfe.getVendorCode === SFErrorCode.INVALID_CHANNEL.getMessageCode => + WriteResult.ChannelIsInvalid + } + } } private def parseResponse(response: InsertValidationResponse): List[WriteFailure] = diff --git a/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Coldswap.scala b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Coldswap.scala new file mode 100644 index 0000000..31fa051 --- /dev/null +++ b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Coldswap.scala @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2014-present Snowplow Analytics Ltd. All rights reserved. + * + * This software is made available by Snowplow Analytics, Ltd., + * under the terms of the Snowplow Limited Use License Agreement, Version 1.0 + * located at https://docs.snowplow.io/limited-use-license-1.0 + * BY INSTALLING, DOWNLOADING, ACCESSING, USING OR DISTRIBUTING ANY PORTION + * OF THE SOFTWARE, YOU AGREE TO THE TERMS OF SUCH LICENSE AGREEMENT. + */ + +package com.snowplowanalytics.snowplow.snowflake.processing + +import cats.effect.{Async, MonadCancelThrow, Ref, Resource, Sync} +import cats.effect.std.Semaphore +import cats.implicits._ + +/** + * Manages swapping of Resources + * + * Inspired by `cats.effect.std.Hotswap` but with differences. A Hotswap is "hot" because a `swap` + * acquires the next resource before closing the previous one. Whereas this Coldswap is "cold" + * because it always closes any previous Resources before acquiring the next one.
+ * + * '''Note''': The resource cannot be simultaneously open and closed, and so + * `coldswap.opened.surround(coldswap.closed.use_)` will deadlock. + */ +final class Coldswap[F[_]: Sync, A] private ( + sem: Semaphore[F], + ref: Ref[F, Coldswap.State[F, A]], + resource: Resource[F, A] +) { + import Coldswap._ + + /** + * Gets the current resource, or opens a new one if required. The returned `A` is guaranteed to be + * available for the duration of the `Resource.use` block. + */ + def opened: Resource[F, A] = + (sem.permit *> Resource.eval[F, State[F, A]](ref.get)).flatMap { + case Opened(a, _) => Resource.pure(a) + case Closed => + for { + _ <- releaseHeldPermit(sem) + _ <- acquireAllPermits(sem) + a <- Resource.eval(doOpen(ref, resource)) + } yield a + } + + /** + * Closes the resource if it was open. The resource is guaranteed to remain closed for the + * duration of the `Resource.use` block. + */ + def closed: Resource[F, Unit] = + (sem.permit *> Resource.eval(ref.get)).flatMap { + case Closed => Resource.unit + case Opened(_, _) => + for { + _ <- releaseHeldPermit(sem) + _ <- acquireAllPermits(sem) + _ <- Resource.eval(doClose(ref)) + } yield () + } + +} + +object Coldswap { + + private sealed trait State[+F[_], +A] + private case object Closed extends State[Nothing, Nothing] + private case class Opened[F[_], A](value: A, close: F[Unit]) extends State[F, A] + + def make[F[_]: Async, A](resource: Resource[F, A]): Resource[F, Coldswap[F, A]] = + for { + sem <- Resource.eval(Semaphore[F](Long.MaxValue)) + ref <- Resource.eval(Ref.of[F, State[F, A]](Closed)) + _ <- Resource.onFinalize(acquireAllPermits(sem).use(_ => doClose(ref))) + } yield new Coldswap(sem, ref, resource) + + private def releaseHeldPermit[F[_]: MonadCancelThrow](sem: Semaphore[F]): Resource[F, Unit] = + Resource.makeFull[F, Unit](poll => poll(sem.release))(_ => sem.acquire) + + private def acquireAllPermits[F[_]: MonadCancelThrow](sem: Semaphore[F]): Resource[F, Unit] = + Resource.makeFull[F, Unit](poll => poll(sem.acquireN(Long.MaxValue)))(_ => sem.releaseN(Long.MaxValue)) + + private def doClose[F[_]: Sync, A](ref: Ref[F, State[F, A]]): F[Unit] = + ref.get.flatMap { + case Closed => Sync[F].unit + case Opened(_, close) => + Sync[F].uncancelable { _ => + close *> ref.set(Closed) + } + } + + private def doOpen[F[_]: Sync, A](ref: Ref[F, State[F, A]], resource: Resource[F, A]): F[A] = + ref.get.flatMap { + case Opened(a, _) => Sync[F].pure(a) + case Closed => + Sync[F].uncancelable { _ => + for { + (a, close) <- resource.allocated + _ <- ref.set(Opened(a, close)) + } yield a + } + } + +} diff --git a/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Processing.scala b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Processing.scala index 6c884cf..0c8ef51 100644 --- a/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Processing.scala +++ b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Processing.scala @@ -11,7 +11,7 @@ package com.snowplowanalytics.snowplow.snowflake.processing import cats.implicits._ -import cats.{Applicative, Foldable, Monad} +import cats.{Applicative, Foldable} import cats.effect.{Async, Sync} import cats.effect.kernel.Unique import com.snowplowanalytics.iglu.core.SchemaCriterion @@ -39,7 +39,8 @@ object Processing { def stream[F[_]: Async](env: Environment[F]): Stream[F, Nothing] = { val eventProcessingConfig = EventProcessingConfig(EventProcessingConfig.NoWindowing) - 
env.source.stream(eventProcessingConfig, eventProcessor(env)) + Stream.eval(env.tableManager.initializeEventsTable()) *> + env.source.stream(eventProcessingConfig, eventProcessor(env)) } /** Model used between stages of the processing pipeline */ @@ -106,7 +107,7 @@ object Processing { private object ParsedWriteResult { def empty: ParsedWriteResult = ParsedWriteResult(Set.empty, Nil, Nil) - def buildFrom(events: ListOfList[EventWithTransform], writeFailures: List[ChannelProvider.WriteFailure]): ParsedWriteResult = + def buildFrom(events: ListOfList[EventWithTransform], writeFailures: List[Channel.WriteFailure]): ParsedWriteResult = if (writeFailures.isEmpty) empty else { @@ -121,9 +122,7 @@ object Processing { } } - private def eventProcessor[F[_]: Async]( - env: Environment[F] - ): EventProcessor[F] = { in => + private def eventProcessor[F[_]: Async](env: Environment[F]): EventProcessor[F] = { in => val badProcessor = BadRowProcessor(env.appInfo.name, env.appInfo.version) in.through(setLatency(env.metrics)) @@ -219,19 +218,23 @@ object Processing { env: Environment[F], batch: BatchAfterTransform )( - handleFailures: List[ChannelProvider.WriteFailure] => F[BatchAfterTransform] + handleFailures: List[Channel.WriteFailure] => F[BatchAfterTransform] ): F[BatchAfterTransform] = if (batch.toBeInserted.isEmpty) batch.pure[F] else Sync[F].untilDefinedM { - env.channelProvider.write(batch.toBeInserted.asIterable.map(_._2)).flatMap { - case ChannelProvider.WriteResult.ChannelIsInvalid => - // Reset the channel and immediately try again - env.channelProvider.reset.as(none) - case ChannelProvider.WriteResult.WriteFailures(notWritten) => - handleFailures(notWritten).map(Some(_)) - } + env.channel.opened + .use { channel => + channel.write(batch.toBeInserted.asIterable.map(_._2)) + } + .flatMap { + case Channel.WriteResult.ChannelIsInvalid => + // Reset the channel and immediately try again + env.channel.closed.use_.as(none) + case Channel.WriteResult.WriteFailures(notWritten) => + handleFailures(notWritten).map(Some(_)) + } } /** @@ -333,14 +336,14 @@ object Processing { * Alters the table to add any columns that were present in the Events but not currently in the * table */ - private def handleSchemaEvolution[F[_]: Monad]( + private def handleSchemaEvolution[F[_]: Sync]( env: Environment[F], extraColsRequired: Set[String] ): F[Unit] = if (extraColsRequired.isEmpty) ().pure[F] else - env.channelProvider.withClosedChannel { + env.channel.closed.surround { env.tableManager.addColumns(extraColsRequired.toList) } diff --git a/modules/core/src/test/scala/com.snowplowanalytics.snowplow.snowflake/MockEnvironment.scala b/modules/core/src/test/scala/com.snowplowanalytics.snowplow.snowflake/MockEnvironment.scala index f870ee5..22c99ce 100644 --- a/modules/core/src/test/scala/com.snowplowanalytics.snowplow.snowflake/MockEnvironment.scala +++ b/modules/core/src/test/scala/com.snowplowanalytics.snowplow.snowflake/MockEnvironment.scala @@ -17,7 +17,7 @@ import fs2.Stream import com.snowplowanalytics.snowplow.sources.{EventProcessingConfig, EventProcessor, SourceAndAck, TokenedEvents} import com.snowplowanalytics.snowplow.sinks.Sink -import com.snowplowanalytics.snowplow.snowflake.processing.{ChannelProvider, TableManager} +import com.snowplowanalytics.snowplow.snowflake.processing.{Channel, Coldswap, TableManager} import com.snowplowanalytics.snowplow.runtime.AppInfo import scala.concurrent.duration.{DurationInt, FiniteDuration} @@ -47,23 +47,24 @@ object MockEnvironment { * @param inputs * Input events to send 
into the environment. * @param channelResponses - * Responses we want the `ChannelProvider` to return when someone calls `write` + * Responses we want the `Channel` to return when someone calls `write` * @return * An environment and a Ref that records the actions made by the environment */ - def build(inputs: List[TokenedEvents], channelResponses: List[ChannelProvider.WriteResult] = Nil): IO[MockEnvironment] = + def build(inputs: List[TokenedEvents], channelResponses: List[Channel.WriteResult]): Resource[IO, MockEnvironment] = for { - state <- Ref[IO].of(Vector.empty[Action]) - channelProvider <- testChannelProvider(state, channelResponses) + state <- Resource.eval(Ref[IO].of(Vector.empty[Action])) + channelResource <- Resource.eval(testChannel(state, channelResponses)) + channelColdswap <- Coldswap.make(channelResource) } yield { val env = Environment( - appInfo = appInfo, - source = testSourceAndAck(inputs, state), - badSink = testSink(state), - httpClient = testHttpClient, - tableManager = testTableManager(state), - channelProvider = channelProvider, - metrics = testMetrics(state), + appInfo = appInfo, + source = testSourceAndAck(inputs, state), + badSink = testSink(state), + httpClient = testHttpClient, + tableManager = testTableManager(state), + channel = channelColdswap, + metrics = testMetrics(state), batching = Config.Batching( maxBytes = 16000000, maxDelay = 10.seconds, @@ -115,44 +116,40 @@ object MockEnvironment { } /** - * Mocked implementation of a `ChannelProvider` + * Mocked implementation of a `Channel` * * @param actionRef * Global Ref used to accumulate actions that happened * @param responses - * Responses that this mocked ChannelProvider should return each time someone calls `write`. If - * no responses given, then it will return with a successful response. + * Responses that this mocked Channel should return each time someone calls `write`. If no + * responses are given, it returns a successful response.
*/ - private def testChannelProvider( + private def testChannel( actionRef: Ref[IO, Vector[Action]], - responses: List[ChannelProvider.WriteResult] - ): IO[ChannelProvider[IO]] = + responses: List[Channel.WriteResult] + ): IO[Resource[IO, Channel[IO]]] = for { responseRef <- Ref[IO].of(responses) - } yield new ChannelProvider[IO] { - def reset: IO[Unit] = - actionRef.update(_ :+ ClosedChannel :+ OpenedChannel) - - def withClosedChannel[A](fa: IO[A]): IO[A] = - for { - _ <- actionRef.update(_ :+ ClosedChannel) - a <- fa - _ <- actionRef.update(_ :+ OpenedChannel) - } yield a - - def write(rows: Iterable[Map[String, AnyRef]]): IO[ChannelProvider.WriteResult] = - for { - response <- responseRef.modify { - case head :: tail => (tail, head) - case Nil => (Nil, ChannelProvider.WriteResult.WriteFailures(Nil)) - } - _ <- response match { - case ChannelProvider.WriteResult.WriteFailures(failures) => - actionRef.update(_ :+ WroteRowsToSnowflake(rows.size - failures.size)) - case ChannelProvider.WriteResult.ChannelIsInvalid => - IO.unit - } - } yield response + } yield { + val make = actionRef.update(_ :+ OpenedChannel).as { + new Channel[IO] { + def write(rows: Iterable[Map[String, AnyRef]]): IO[Channel.WriteResult] = + for { + response <- responseRef.modify { + case head :: tail => (tail, head) + case Nil => (Nil, Channel.WriteResult.WriteFailures(Nil)) + } + _ <- response match { + case Channel.WriteResult.WriteFailures(failures) => + actionRef.update(_ :+ WroteRowsToSnowflake(rows.size - failures.size)) + case Channel.WriteResult.ChannelIsInvalid => + IO.unit + } + } yield response + } + } + + Resource.make(make)(_ => actionRef.update(_ :+ ClosedChannel)) } def testMetrics(ref: Ref[IO, Vector[Action]]): Metrics[IO] = new Metrics[IO] { diff --git a/modules/core/src/test/scala/com.snowplowanalytics.snowplow.snowflake/processing/ProcessingSpec.scala b/modules/core/src/test/scala/com.snowplowanalytics.snowplow.snowflake/processing/ProcessingSpec.scala index 20d71c7..c102bf3 100644 --- a/modules/core/src/test/scala/com.snowplowanalytics.snowplow.snowflake/processing/ProcessingSpec.scala +++ b/modules/core/src/test/scala/com.snowplowanalytics.snowplow.snowflake/processing/ProcessingSpec.scala @@ -35,187 +35,207 @@ class ProcessingSpec extends Specification with CatsEffect { Insert events to Snowflake and ack the events $e1 Emit BadRows when there are badly formatted events $e2 Write good batches and bad events when input contains both $e3 - Alter the Snowflake table when the ChannelProvider reports missing columns $e4 - Emit BadRows when the ChannelProvider reports a problem with the data $e5 - Abort processing and don't ack events when the ChannelProvider reports a runtime error $e6 - Reset the Channel when the ChannelProvider reports the channel has become invalid $e7 + Alter the Snowflake table when the Channel reports missing columns $e4 + Emit BadRows when the Channel reports a problem with the data $e5 + Abort processing and don't ack events when the Channel reports a runtime error $e6 + Reset the Channel when a write reports that the channel has become invalid $e7 Set the latency metric based off the message timestamp $e8 """ def e1 = - for { - inputs <- generateEvents.take(2).compile.toList - control <- MockEnvironment.build(inputs) - _ <- Processing.stream(control.environment).compile.drain - state <- control.state.get - } yield state should beEqualTo( - Vector( - Action.WroteRowsToSnowflake(4), - Action.AddedGoodCountMetric(4), - Action.AddedBadCountMetric(0), - 
Action.Checkpointed(List(inputs(0).ack, inputs(1).ack)) + setup(generateEvents.take(2).compile.toList) { case (inputs, control) => + for { + _ <- Processing.stream(control.environment).compile.drain + state <- control.state.get + } yield state should beEqualTo( + Vector( + Action.InitEventsTable, + Action.OpenedChannel, + Action.WroteRowsToSnowflake(4), + Action.AddedGoodCountMetric(4), + Action.AddedBadCountMetric(0), + Action.Checkpointed(List(inputs(0).ack, inputs(1).ack)) + ) ) - ) + } def e2 = - for { - inputs <- generateBadlyFormatted.take(3).compile.toList - control <- MockEnvironment.build(inputs) - _ <- Processing.stream(control.environment).compile.drain - state <- control.state.get - } yield state should beEqualTo( - Vector( - Action.SentToBad(6), - Action.AddedGoodCountMetric(0), - Action.AddedBadCountMetric(6), - Action.Checkpointed(List(inputs(0).ack, inputs(1).ack, inputs(2).ack)) + setup(generateBadlyFormatted.take(3).compile.toList) { case (inputs, control) => + for { + _ <- Processing.stream(control.environment).compile.drain + state <- control.state.get + } yield state should beEqualTo( + Vector( + Action.InitEventsTable, + Action.SentToBad(6), + Action.AddedGoodCountMetric(0), + Action.AddedBadCountMetric(6), + Action.Checkpointed(List(inputs(0).ack, inputs(1).ack, inputs(2).ack)) + ) ) - ) + } - def e3 = - for { + def e3 = { + val toInputs = for { bads <- generateBadlyFormatted.take(3).compile.toList goods <- generateEvents.take(3).compile.toList - inputs = bads.zip(goods).map { case (bad, good) => - TokenedEvents(bad.events ++ good.events, good.ack, None) - } - control <- MockEnvironment.build(inputs) - _ <- Processing.stream(control.environment).compile.drain - state <- control.state.get - } yield state should beEqualTo( - Vector( - Action.WroteRowsToSnowflake(6), - Action.SentToBad(6), - Action.AddedGoodCountMetric(6), - Action.AddedBadCountMetric(6), - Action.Checkpointed(List(inputs(0).ack, inputs(1).ack, inputs(2).ack)) + } yield bads.zip(goods).map { case (bad, good) => + TokenedEvents(bad.events ++ good.events, good.ack, None) + } + setup(toInputs) { case (inputs, control) => + for { + _ <- Processing.stream(control.environment).compile.drain + state <- control.state.get + } yield state should beEqualTo( + Vector( + Action.InitEventsTable, + Action.OpenedChannel, + Action.WroteRowsToSnowflake(6), + Action.SentToBad(6), + Action.AddedGoodCountMetric(6), + Action.AddedBadCountMetric(6), + Action.Checkpointed(List(inputs(0).ack, inputs(1).ack, inputs(2).ack)) + ) ) - ) + } + } def e4 = { val mockedChannelResponses = List( - ChannelProvider.WriteResult.WriteFailures( + Channel.WriteResult.WriteFailures( List( - ChannelProvider.WriteFailure(0L, List("unstruct_event_xyz_1", "contexts_abc_2"), new SFException(ErrorCode.INVALID_FORMAT_ROW)) + Channel.WriteFailure(0L, List("unstruct_event_xyz_1", "contexts_abc_2"), new SFException(ErrorCode.INVALID_FORMAT_ROW)) ) ), - ChannelProvider.WriteResult.WriteFailures(Nil) + Channel.WriteResult.WriteFailures(Nil) ) - for { - inputs <- generateEvents.take(1).compile.toList - control <- MockEnvironment.build(inputs, mockedChannelResponses) - _ <- Processing.stream(control.environment).compile.drain - state <- control.state.get - } yield state should beEqualTo( - Vector( - Action.WroteRowsToSnowflake(1), - Action.ClosedChannel, - Action.AlterTableAddedColumns(List("unstruct_event_xyz_1", "contexts_abc_2")), - Action.OpenedChannel, - Action.WroteRowsToSnowflake(1), - Action.AddedGoodCountMetric(2), - Action.AddedBadCountMetric(0), - 
Action.Checkpointed(List(inputs(0).ack)) + setup(generateEvents.take(1).compile.toList, mockedChannelResponses) { case (inputs, control) => + for { + _ <- Processing.stream(control.environment).compile.drain + state <- control.state.get + } yield state should beEqualTo( + Vector( + Action.InitEventsTable, + Action.OpenedChannel, + Action.WroteRowsToSnowflake(1), + Action.ClosedChannel, + Action.AlterTableAddedColumns(List("unstruct_event_xyz_1", "contexts_abc_2")), + Action.OpenedChannel, + Action.WroteRowsToSnowflake(1), + Action.AddedGoodCountMetric(2), + Action.AddedBadCountMetric(0), + Action.Checkpointed(List(inputs(0).ack)) + ) ) - ) + } } def e5 = { val mockedChannelResponses = List( - ChannelProvider.WriteResult.WriteFailures( + Channel.WriteResult.WriteFailures( List( - ChannelProvider.WriteFailure(0L, Nil, new SFException(ErrorCode.INVALID_FORMAT_ROW)) + Channel.WriteFailure(0L, Nil, new SFException(ErrorCode.INVALID_FORMAT_ROW)) ) ), - ChannelProvider.WriteResult.WriteFailures(Nil) + Channel.WriteResult.WriteFailures(Nil) ) - for { - inputs <- generateEvents.take(1).compile.toList - control <- MockEnvironment.build(inputs, mockedChannelResponses) - _ <- Processing.stream(control.environment).compile.drain - state <- control.state.get - } yield state should beEqualTo( - Vector( - Action.WroteRowsToSnowflake(1), - Action.SentToBad(1), - Action.AddedGoodCountMetric(1), - Action.AddedBadCountMetric(1), - Action.Checkpointed(List(inputs(0).ack)) + setup(generateEvents.take(1).compile.toList, mockedChannelResponses) { case (inputs, control) => + for { + _ <- Processing.stream(control.environment).compile.drain + state <- control.state.get + } yield state should beEqualTo( + Vector( + Action.InitEventsTable, + Action.OpenedChannel, + Action.WroteRowsToSnowflake(1), + Action.SentToBad(1), + Action.AddedGoodCountMetric(1), + Action.AddedBadCountMetric(1), + Action.Checkpointed(List(inputs(0).ack)) + ) ) - ) + } } def e6 = { val mockedChannelResponses = List( - ChannelProvider.WriteResult.WriteFailures( + Channel.WriteResult.WriteFailures( List( - ChannelProvider.WriteFailure(0L, Nil, new SFException(ErrorCode.INTERNAL_ERROR)) + Channel.WriteFailure(0L, Nil, new SFException(ErrorCode.INTERNAL_ERROR)) ) ), - ChannelProvider.WriteResult.WriteFailures(Nil) + Channel.WriteResult.WriteFailures(Nil) ) - for { - inputs <- generateEvents.take(1).compile.toList - control <- MockEnvironment.build(inputs, mockedChannelResponses) - _ <- Processing.stream(control.environment).compile.drain.handleError(_ => ()) - state <- control.state.get - } yield state should beEqualTo( - Vector( - Action.WroteRowsToSnowflake(1) + setup(generateEvents.take(1).compile.toList, mockedChannelResponses) { case (_, control) => + for { + _ <- Processing.stream(control.environment).compile.drain.handleError(_ => ()) + state <- control.state.get + } yield state should beEqualTo( + Vector( + Action.InitEventsTable, + Action.OpenedChannel, + Action.WroteRowsToSnowflake(1) + ) ) - ) + } } def e7 = { val mockedChannelResponses = List( - ChannelProvider.WriteResult.ChannelIsInvalid, - ChannelProvider.WriteResult.WriteFailures(Nil) + Channel.WriteResult.ChannelIsInvalid, + Channel.WriteResult.WriteFailures(Nil) ) - for { - inputs <- generateEvents.take(1).compile.toList - control <- MockEnvironment.build(inputs, mockedChannelResponses) - _ <- Processing.stream(control.environment).compile.drain - state <- control.state.get - } yield state should beEqualTo( - Vector( - Action.ClosedChannel, - Action.OpenedChannel, - 
Action.WroteRowsToSnowflake(2), - Action.AddedGoodCountMetric(2), - Action.AddedBadCountMetric(0), - Action.Checkpointed(List(inputs(0).ack)) + setup(generateEvents.take(1).compile.toList, mockedChannelResponses) { case (inputs, control) => + for { + _ <- Processing.stream(control.environment).compile.drain + state <- control.state.get + } yield state should beEqualTo( + Vector( + Action.InitEventsTable, + Action.OpenedChannel, + Action.ClosedChannel, + Action.OpenedChannel, + Action.WroteRowsToSnowflake(2), + Action.AddedGoodCountMetric(2), + Action.AddedBadCountMetric(0), + Action.Checkpointed(List(inputs(0).ack)) + ) ) - ) + } } def e8 = { val messageTime = Instant.parse("2023-10-24T10:00:00.000Z") val processTime = Instant.parse("2023-10-24T10:00:42.123Z") - val io = for { - inputs <- generateEvents.take(2).compile.toList.map { - _.map { - _.copy(earliestSourceTstamp = Some(messageTime)) - } - } - control <- MockEnvironment.build(inputs) - _ <- IO.sleep(processTime.toEpochMilli.millis) - _ <- Processing.stream(control.environment).compile.drain - state <- control.state.get - } yield state should beEqualTo( - Vector( - Action.SetLatencyMetric(42123), - Action.SetLatencyMetric(42123), - Action.WroteRowsToSnowflake(4), - Action.AddedGoodCountMetric(4), - Action.AddedBadCountMetric(0), - Action.Checkpointed(List(inputs(0).ack, inputs(1).ack)) + val toInputs = generateEvents.take(2).compile.toList.map { + _.map { + _.copy(earliestSourceTstamp = Some(messageTime)) + } + } + + val io = setup(toInputs) { case (inputs, control) => + for { + _ <- IO.sleep(processTime.toEpochMilli.millis) + _ <- Processing.stream(control.environment).compile.drain + state <- control.state.get + } yield state should beEqualTo( + Vector( + Action.InitEventsTable, + Action.SetLatencyMetric(42123), + Action.SetLatencyMetric(42123), + Action.OpenedChannel, + Action.WroteRowsToSnowflake(4), + Action.AddedGoodCountMetric(4), + Action.AddedBadCountMetric(0), + Action.Checkpointed(List(inputs(0).ack, inputs(1).ack)) + ) ) - ) + } TestControl.executeEmbed(io) @@ -225,6 +245,18 @@ class ProcessingSpec extends Specification with CatsEffect { object ProcessingSpec { + def setup[A]( + toInputs: IO[List[TokenedEvents]], + channelResponses: List[Channel.WriteResult] = Nil + )( + f: (List[TokenedEvents], MockEnvironment) => IO[A] + ): IO[A] = + toInputs.flatMap { inputs => + MockEnvironment.build(inputs, channelResponses).use { control => + f(inputs, control) + } + } + def generateEvents: Stream[IO, TokenedEvents] = Stream.eval { for {
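
A minimal usage sketch of the new `Coldswap`, assuming only the `opened`/`closed` API introduced in this diff; `ColdswapExample`, `fakeChannel`, `program`, and the printed messages are hypothetical placeholders, not part of the codebase:

    import cats.effect.{IO, Resource}
    import com.snowplowanalytics.snowplow.snowflake.processing.Coldswap

    object ColdswapExample {
      // Hypothetical stand-in for the real Channel resource
      val fakeChannel: Resource[IO, String] =
        Resource.make(IO.println("channel opened").as("channel"))(_ => IO.println("channel closed"))

      val program: IO[Unit] =
        Coldswap.make(fakeChannel).use { swap =>
          for {
            _ <- swap.opened.use(c => IO.println(s"write via $c")) // opens lazily on first use
            _ <- swap.closed.use_                                  // held closed, e.g. around an ALTER TABLE
            _ <- swap.opened.use(c => IO.println(s"write via $c")) // re-opens on demand
          } yield ()
        }
    }

Per the scaladoc note, `opened` and `closed` must never be nested (`swap.opened.surround(swap.closed.use_)` deadlocks), which is why `Processing` sequences them: `opened.use` for writes, `closed.use_` to reset an invalid channel, and `closed.surround` while adding columns during schema evolution.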