diff --git a/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Channel.scala b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Channel.scala index 9922c16..4786b6f 100644 --- a/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Channel.scala +++ b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Channel.scala @@ -11,6 +11,7 @@ package com.snowplowanalytics.snowplow.snowflake.processing import cats.effect.{Async, Poll, Resource, Sync} +import cats.effect.std.Semaphore import cats.implicits._ import com.snowplowanalytics.snowplow.runtime.AppHealth import com.snowplowanalytics.snowplow.runtime.processing.Coldswap @@ -108,8 +109,9 @@ object Channel { ): Resource[F, Opener[F]] = for { client <- createClient(config, batchingConfig, retriesConfig, appHealth) + semaphore <- Resource.eval(Semaphore[F](1L)) } yield new Opener[F] { - def open: F[CloseableChannel[F]] = createChannel[F](config, client).map(impl[F]) + def open: F[CloseableChannel[F]] = createChannel[F](config, client).map(impl[F](_, semaphore)) } def provider[F[_]: Async]( @@ -134,15 +136,24 @@ object Channel { Resource.makeFull(make)(_.close) } - private def impl[F[_]: Async](channel: SnowflakeStreamingIngestChannel): CloseableChannel[F] = + private def impl[F[_]: Async](channel: SnowflakeStreamingIngestChannel, semaphore: Semaphore[F]): CloseableChannel[F] = new CloseableChannel[F] { def write(rows: Iterable[Map[String, AnyRef]]): F[WriteResult] = { - val attempt: F[WriteResult] = for { - response <- Sync[F].blocking(channel.insertRows(rows.map(_.asJava).asJava, null)) - _ <- flushChannel[F](channel) - isValid <- Sync[F].delay(channel.isValid) - } yield if (isValid) WriteResult.WriteFailures(parseResponse(response)) else WriteResult.ChannelIsInvalid + + val attempt: F[WriteResult] = semaphore.permit + .surround { + for { + response <- Sync[F].blocking(channel.insertRows(rows.map(_.asJava).asJava, null)) + future <- Sync[F].delay(SnowsFlakePlowInterop.flushChannel(channel)) + } yield (future, response) + } + .flatMap { case (future, response) => + for { + _ <- Async[F].fromCompletableFuture(Sync[F].pure(future)) + isValid <- Sync[F].delay(channel.isValid) + } yield if (isValid) WriteResult.WriteFailures(parseResponse(response)) else WriteResult.ChannelIsInvalid + } attempt.recover { case sfe: SFException if sfe.getVendorCode === SFErrorCode.INVALID_CHANNEL.getMessageCode => @@ -238,15 +249,4 @@ object Channel { Resource.makeFull(make)(client => Sync[F].blocking(client.close())) } - /** - * Flushes the channel - * - * The public interface of the Snowflake SDK does not tell us when the events are safely written - * to Snowflake. So we must cast it to an Internal class so we get access to the `flush()` method. - */ - private def flushChannel[F[_]: Async](channel: SnowflakeStreamingIngestChannel): F[Unit] = - Async[F].fromCompletableFuture { - Async[F].delay(SnowsFlakePlowInterop.flushChannel(channel)) - }.void - } diff --git a/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Processing.scala b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Processing.scala index 19b79c3..ef2adcc 100644 --- a/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Processing.scala +++ b/modules/core/src/main/scala/com.snowplowanalytics.snowplow.snowflake/processing/Processing.scala @@ -22,6 +22,7 @@ import org.typelevel.log4cats.slf4j.Slf4jLogger import java.nio.charset.StandardCharsets import java.time.OffsetDateTime +import scala.concurrent.duration.DurationLong import com.snowplowanalytics.iglu.schemaddl.parquet.Caster import com.snowplowanalytics.snowplow.analytics.scalasdk.Event @@ -122,6 +123,7 @@ object Processing { in.through(setLatency(env.metrics)) .through(parseAndTransform(env, badProcessor)) .through(BatchUp.withTimeout(env.batching.maxBytes, env.batching.maxDelay)) + .prefetchN(env.batching.uploadConcurrency) .through(writeToSnowflake(env, badProcessor)) .through(sendFailedEvents(env, badProcessor)) .through(sendMetrics(env)) @@ -357,9 +359,23 @@ object Processing { env.metrics.addGood(batch.origBatchCount - countBad) *> env.metrics.addBad(countBad) } - private def emitTokens[F[_]]: Pipe[F, BatchAfterTransform, Unique.Token] = - _.flatMap { batch => - Stream.emits(batch.tokens) + private implicit def batchable2: BatchUp.Batchable[BatchAfterTransform, Vector[Unique.Token]] = + new BatchUp.Batchable[BatchAfterTransform, Vector[Unique.Token]] { + def combine(b: Vector[Unique.Token], a: BatchAfterTransform): Vector[Unique.Token] = + b ++ a.tokens + + def single(a: BatchAfterTransform): Vector[Unique.Token] = + a.tokens + + def weightOf(a: BatchAfterTransform): Long = + 0L + } + + private def emitTokens[F[_]: Async]: Pipe[F, BatchAfterTransform, Unique.Token] = + BatchUp.withTimeout[F, BatchAfterTransform, Vector[Unique.Token]](Long.MaxValue, 10.seconds).andThen { + _.flatMap { tokens => + Stream.emits(tokens) + } } private def fastGetByIndex[A](items: IndexedSeq[A], index: Long): A = items(index.toInt)