Parallelism amendments
istreeter committed Nov 16, 2024
1 parent e4a3e07 commit df2089d
Showing 13 changed files with 107 additions and 66 deletions.
2 changes: 1 addition & 1 deletion config/config.azure.reference.hocon
@@ -70,7 +70,7 @@
"batching": {

# - Events are emitted to Snowflake when the batch reaches this size in bytes
"maxBytes": 16000000
"maxBytes": 64000000

# - Events are emitted to Snowflake for a maximum of this duration, even if the `maxBytes` size has not been reached
"maxDelay": "1 second"
2 changes: 1 addition & 1 deletion config/config.kinesis.reference.hocon
@@ -87,7 +87,7 @@
"batching": {

# - Events are emitted to Snowflake when the batch reaches this size in bytes
"maxBytes": 16000000
"maxBytes": 64000000

# - Events are emitted to Snowflake for a maximum of this duration, even if the `maxBytes` size has not been reached
"maxDelay": "1 second"
2 changes: 1 addition & 1 deletion config/config.pubsub.reference.hocon
@@ -76,7 +76,7 @@
"batching": {

# - Events are emitted to Snowflake when the batch reaches this size in bytes
"maxBytes": 16000000
"maxBytes": 64000000

# - Events are emitted to Snowflake for a maximum of this duration, even if the `maxBytes` size has not been reached
"maxDelay": "1 second"
2 changes: 1 addition & 1 deletion modules/core/src/main/resources/reference.conf
@@ -19,7 +19,7 @@
}

"batching": {
"maxBytes": 16000000
"maxBytes": 64000000
"maxDelay": "1 second"
"uploadConcurrency": 3
}
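For context on the setting touched in all four reference configs above: `maxBytes` and `maxDelay` are two independent flush triggers, and a batch is emitted as soon as either one is crossed. The sketch below illustrates that interplay with the new 64 MB default; the `Batching` case class here is an illustrative model mirroring the config keys, not necessarily the loader's exact `Config.Batching` type.

```scala
import scala.concurrent.duration._

object BatchingSketch extends App {
  // Illustrative model of the three batching settings from the reference configs.
  final case class Batching(maxBytes: Long, maxDelay: FiniteDuration, uploadConcurrency: Int)

  // Defaults after this change: 64 MB instead of 16 MB, flushed at least every second.
  val batching = Batching(maxBytes = 64000000L, maxDelay = 1.second, uploadConcurrency = 3)

  // A batch is emitted as soon as either threshold is crossed.
  def shouldEmit(bufferedBytes: Long, ageOfOldestEvent: FiniteDuration): Boolean =
    bufferedBytes >= batching.maxBytes || ageOfOldestEvent >= batching.maxDelay

  println(shouldEmit(bufferedBytes = 70000000L, ageOfOldestEvent = 200.millis)) // true: size trigger
  println(shouldEmit(bufferedBytes = 1000000L, ageOfOldestEvent = 2.seconds))   // true: delay trigger
}
```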
@@ -11,6 +11,7 @@
package com.snowplowanalytics.snowplow.snowflake.processing

import cats.effect.{Async, Poll, Resource, Sync}
import cats.effect.std.Semaphore
import cats.implicits._
import com.snowplowanalytics.snowplow.runtime.AppHealth
import com.snowplowanalytics.snowplow.runtime.processing.Coldswap
@@ -24,6 +25,7 @@ import org.typelevel.log4cats.slf4j.Slf4jLogger
import java.time.ZoneOffset
import java.util.Properties
import scala.jdk.CollectionConverters._
import scala.concurrent.duration.DurationLong

trait Channel[F[_]] {

@@ -35,7 +37,7 @@ trait Channel[F[_]] {
* @return
* List of the details of any insert failures. Empty list implies complete success.
*/
def write(rows: Iterable[Map[String, AnyRef]]): F[Channel.WriteResult]
def write(rows: List[Iterable[Map[String, AnyRef]]]): F[Channel.WriteResult]
}

object Channel {
@@ -62,7 +64,9 @@

/**
* The result of trying to enqueue an event for sending to Snowflake
* @param index
* @param outerIndex
* Refers to the batch number in the list of attempted batches
* @param innerIndex
* Refers to the row number in the batch of attempted events
* @param extraCols
* The column names which were present in the batch but missing in the table
@@ -71,7 +75,8 @@
* enqueue
*/
case class WriteFailure(
index: Long,
outerIndex: Long,
innerIndex: Long,
extraCols: List[String],
cause: SFException
)
@@ -108,8 +113,9 @@
): Resource[F, Opener[F]] =
for {
client <- createClient(config, batchingConfig, retriesConfig, appHealth)
semaphore <- Resource.eval(Semaphore[F](1L))
} yield new Opener[F] {
def open: F[CloseableChannel[F]] = createChannel[F](config, client).map(impl[F])
def open: F[CloseableChannel[F]] = createChannel[F](config, client).map(impl[F](_, semaphore))
}

def provider[F[_]: Async](
@@ -134,15 +140,32 @@
Resource.makeFull(make)(_.close)
}

private def impl[F[_]: Async](channel: SnowflakeStreamingIngestChannel): CloseableChannel[F] =
private def impl[F[_]: Async](channel: SnowflakeStreamingIngestChannel, semaphore: Semaphore[F]): CloseableChannel[F] =
new CloseableChannel[F] {

def write(rows: Iterable[Map[String, AnyRef]]): F[WriteResult] = {
val attempt: F[WriteResult] = for {
response <- Sync[F].blocking(channel.insertRows(rows.map(_.asJava).asJava, null))
_ <- flushChannel[F](channel)
isValid <- Sync[F].delay(channel.isValid)
} yield if (isValid) WriteResult.WriteFailures(parseResponse(response)) else WriteResult.ChannelIsInvalid
def write(rows: List[Iterable[Map[String, AnyRef]]]): F[WriteResult] = {

val attempt: F[WriteResult] = semaphore.permit
.surround {
for {
responses <- rows.traverse { inner =>
Sync[F].blocking(channel.insertRows(inner.map(_.asJava).asJava, null))
}
future <- Sync[F].delay(SnowsFlakePlowInterop.flushChannel(channel))
_ <- Sync[F].untilDefinedM {
for {
_ <- Sync[F].sleep(100.millis)
isEmpty <- Sync[F].delay(SnowsFlakePlowInterop.isEmpty(channel))
} yield if (isEmpty) Some(()) else None
}
} yield (future, responses)
}
.flatMap { case (future, responses) =>
for {
_ <- Async[F].fromCompletableFuture(Sync[F].pure(future))
isValid <- Sync[F].delay(channel.isValid)
} yield if (isValid) WriteResult.WriteFailures(parseResponse(responses)) else WriteResult.ChannelIsInvalid
}

attempt.recover {
case sfe: SFException if sfe.getVendorCode === SFErrorCode.INVALID_CHANNEL.getMessageCode =>
@@ -166,14 +189,17 @@
}
}

private def parseResponse(response: InsertValidationResponse): List[WriteFailure] =
response.getInsertErrors.asScala.map { insertError =>
WriteFailure(
insertError.getRowIndex,
Option(insertError.getExtraColNames).fold(List.empty[String])(_.asScala.toList),
insertError.getException
)
}.toList
private def parseResponse(responses: List[InsertValidationResponse]): List[WriteFailure] =
responses.zipWithIndex.flatMap { case (response, outerIndex) =>
response.getInsertErrors.asScala.map { insertError =>
WriteFailure(
outerIndex.toLong,
insertError.getRowIndex,
Option(insertError.getExtraColNames).fold(List.empty[String])(_.asScala.toList),
insertError.getException
)
}
}

private def createChannel[F[_]: Async](
config: Config.Snowflake,
@@ -211,7 +237,9 @@
props.setProperty(ParameterProvider.INSERT_THROTTLE_THRESHOLD_IN_PERCENTAGE, "0")
props.setProperty(ParameterProvider.INSERT_THROTTLE_THRESHOLD_IN_BYTES, "0")
props.setProperty(ParameterProvider.MAX_CHANNEL_SIZE_IN_BYTES, Long.MaxValue.toString)
props.setProperty(ParameterProvider.IO_TIME_CPU_RATIO, batchingConfig.uploadConcurrency.toString)

val _ = batchingConfig
// props.setProperty(ParameterProvider.IO_TIME_CPU_RATIO, batchingConfig.uploadConcurrency.toString)

props
}
@@ -238,15 +266,4 @@
Resource.makeFull(make)(client => Sync[F].blocking(client.close()))
}

/**
* Flushes the channel
*
* The public interface of the Snowflake SDK does not tell us when the events are safely written
* to Snowflake. So we must cast it to an Internal class so we get access to the `flush()` method.
*/
private def flushChannel[F[_]: Async](channel: SnowflakeStreamingIngestChannel): F[Unit] =
Async[F].fromCompletableFuture {
Async[F].delay(SnowsFlakePlowInterop.flushChannel(channel))
}.void

}
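To summarise the new write path in one self-contained place: `write` now accepts a list of row batches, takes a `Semaphore` permit so only one caller talks to the channel at a time, inserts each inner batch, triggers a flush, and polls the SDK's internal row buffer until it drains; the flush future is then awaited outside the permit. The sketch below reproduces that pattern against a fake channel — `FakeChannel` and its methods are illustrative stand-ins, not the Snowflake SDK's API.

```scala
import cats.effect.{IO, IOApp, Sync}
import cats.effect.std.Semaphore
import cats.syntax.all._
import scala.concurrent.duration._

object WritePatternSketch extends IOApp.Simple {

  // Illustrative stand-in for the SDK channel: it models a row buffer that
  // drains a little every time it is polled.
  final class FakeChannel {
    private var buffered: Int = 0
    def insertRows(count: Int): Unit = buffered += count
    def flush(): Unit = ()
    def isEmpty: Boolean = { buffered = math.max(0, buffered - 3); buffered == 0 }
  }

  // Same shape as the new Channel.write: one permit guards the insert/flush/poll
  // sequence, so concurrent callers cannot interleave rows on the shared channel.
  def write(sem: Semaphore[IO], channel: FakeChannel, batches: List[Int]): IO[Unit] =
    sem.permit.surround {
      batches.traverse_(count => IO.blocking(channel.insertRows(count))) *>
        IO.blocking(channel.flush()) *>
        Sync[IO].untilDefinedM {
          IO.sleep(100.millis) *> IO(if (channel.isEmpty) Some(()) else None)
        }
    }

  def run: IO[Unit] =
    Semaphore[IO](1).flatMap { sem =>
      val channel = new FakeChannel
      // Two fibers write concurrently; the semaphore serialises them.
      List(List(3, 2), List(5)).parTraverse_(batches => write(sem, channel, batches))
    }
}
```

In the real implementation the flush call returns a `CompletableFuture[Void]`, which is awaited with `Async[F].fromCompletableFuture` only after the permit has been released.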
@@ -22,6 +22,7 @@ import org.typelevel.log4cats.slf4j.Slf4jLogger

import java.nio.charset.StandardCharsets
import java.time.OffsetDateTime
import scala.concurrent.duration.DurationLong

import com.snowplowanalytics.iglu.schemaddl.parquet.Caster
import com.snowplowanalytics.snowplow.analytics.scalasdk.Event
@@ -74,7 +75,7 @@ object Processing {
* The tokens to be emitted after we have finished processing all events
*/
private case class BatchAfterTransform(
toBeInserted: ListOfList[EventWithTransform],
toBeInserted: List[List[EventWithTransform]],
origBatchBytes: Long,
origBatchCount: Int,
badAccumulated: ListOfList[BadRow],
@@ -99,15 +100,15 @@
)

private object ParsedWriteResult {
def empty: ParsedWriteResult = ParsedWriteResult(Set.empty, Nil, Nil)
private def empty: ParsedWriteResult = ParsedWriteResult(Set.empty, Nil, Nil)

def buildFrom(events: ListOfList[EventWithTransform], writeFailures: List[Channel.WriteFailure]): ParsedWriteResult =
def buildFrom(events: List[List[EventWithTransform]], writeFailures: List[Channel.WriteFailure]): ParsedWriteResult =
if (writeFailures.isEmpty)
empty
else {
val indexed = events.copyToIndexedSeq
val indexed = events.map(_.toIndexedSeq).toIndexedSeq
writeFailures.foldLeft(ParsedWriteResult.empty) { case (ParsedWriteResult(extraCols, eventsWithExtraCols, unexpected), failure) =>
val event = fastGetByIndex(indexed, failure.index)
val event = fastGetByIndex(indexed, failure.outerIndex, failure.innerIndex)
if (failure.extraCols.nonEmpty)
ParsedWriteResult(extraCols ++ failure.extraCols, event :: eventsWithExtraCols, unexpected)
else
@@ -122,6 +123,7 @@ object Processing {
in.through(setLatency(env.metrics))
.through(parseAndTransform(env, badProcessor))
.through(BatchUp.withTimeout(env.batching.maxBytes, env.batching.maxDelay))
.prefetchN(2)
.through(writeToSnowflake(env, badProcessor))
.through(sendFailedEvents(env, badProcessor))
.through(sendMetrics(env))
@@ -208,7 +210,7 @@
Sync[F].untilDefinedM {
env.channel.opened
.use { channel =>
channel.write(batch.toBeInserted.asIterable.map(_._2))
channel.write(batch.toBeInserted.map(_.view.map(_._2)))
}
.flatMap {
case Channel.WriteResult.ChannelIsInvalid =>
@@ -246,7 +248,7 @@
badRowFromEnqueueFailure(badProcessor, event, sfe)
}
batch.copy(
toBeInserted = ListOfList.ofLists(parsedResult.eventsWithExtraCols),
toBeInserted = parsedResult.eventsWithExtraCols.grouped(100).toList,
badAccumulated = batch.badAccumulated.prepend(moreBad)
)
}
@@ -267,15 +269,15 @@
val mapped = notWritten match {
case Nil => Nil
case more =>
val indexed = batch.toBeInserted.copyToIndexedSeq
more.map(f => (fastGetByIndex(indexed, f.index)._1, f.cause))
val indexed = batch.toBeInserted.map(_.toIndexedSeq).toIndexedSeq
more.map(f => (fastGetByIndex(indexed, f.outerIndex, f.innerIndex)._1, f.cause))
}
abortIfFatalException[F](mapped).as {
val moreBad = mapped.map { case (event, sfe) =>
badRowFromEnqueueFailure(badProcessor, event, sfe)
}
batch.copy(
toBeInserted = ListOfList.empty,
toBeInserted = Nil,
badAccumulated = batch.badAccumulated.prepend(moreBad)
)
}
@@ -357,18 +359,36 @@
env.metrics.addGood(batch.origBatchCount - countBad) *> env.metrics.addBad(countBad)
}

private def emitTokens[F[_]]: Pipe[F, BatchAfterTransform, Unique.Token] =
_.flatMap { batch =>
Stream.emits(batch.tokens)
private implicit def batchableTokens: BatchUp.Batchable[BatchAfterTransform, Vector[Unique.Token]] =
new BatchUp.Batchable[BatchAfterTransform, Vector[Unique.Token]] {
def combine(b: Vector[Unique.Token], a: BatchAfterTransform): Vector[Unique.Token] =
b ++ a.tokens

def single(a: BatchAfterTransform): Vector[Unique.Token] =
a.tokens

def weightOf(a: BatchAfterTransform): Long =
0L
}

private def emitTokens[F[_]: Async]: Pipe[F, BatchAfterTransform, Unique.Token] =
BatchUp.withTimeout[F, BatchAfterTransform, Vector[Unique.Token]](Long.MaxValue, 10.seconds).andThen {
_.flatMap { tokens =>
Stream.emits(tokens)
}
}

private def fastGetByIndex[A](items: IndexedSeq[A], index: Long): A = items(index.toInt)
private def fastGetByIndex[A](
items: IndexedSeq[IndexedSeq[A]],
outerIndex: Long,
innerIndex: Long
): A = items(outerIndex.toInt)(innerIndex.toInt)

private implicit def batchable: BatchUp.Batchable[TransformedBatch, BatchAfterTransform] =
new BatchUp.Batchable[TransformedBatch, BatchAfterTransform] {
def combine(b: BatchAfterTransform, a: TransformedBatch): BatchAfterTransform =
BatchAfterTransform(
toBeInserted = b.toBeInserted.prepend(a.events),
toBeInserted = if (a.events.isEmpty) b.toBeInserted else a.events :: b.toBeInserted,
origBatchBytes = b.origBatchBytes + a.countBytes,
origBatchCount = b.origBatchCount + a.countItems,
badAccumulated = b.badAccumulated.prepend(a.parseFailures).prepend(a.transformFailures),
Expand All @@ -377,7 +397,7 @@ object Processing {

def single(a: TransformedBatch): BatchAfterTransform =
BatchAfterTransform(
ListOfList.of(List(a.events)),
if (a.events.isEmpty) Nil else List(a.events),
a.countBytes,
a.countItems,
ListOfList.ofLists(a.parseFailures, a.transformFailures),
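One consequence of the new `emitTokens` above that is easy to miss: `weightOf` always returns 0, so the `Long.MaxValue` size threshold never fires and ack tokens are emitted purely on the 10-second timer rather than once per written batch. A rough fs2-only analogue of that time-driven behaviour — this sketch assumes nothing about the snowplow-runtime `BatchUp` internals and uses `groupWithin` as a stand-in:

```scala
import cats.effect.{IO, IOApp}
import fs2.Stream
import scala.concurrent.duration._

object TimeDrivenBatching extends IOApp.Simple {
  // groupWithin with an effectively unbounded chunk size behaves like
  // BatchUp.withTimeout(Long.MaxValue, 10.seconds): only the timer triggers emission.
  def run: IO[Unit] =
    Stream
      .awakeEvery[IO](1.second)
      .zipWithIndex
      .map(_._2)                                // pretend these are ack tokens
      .groupWithin(Int.MaxValue, 10.seconds)
      .evalMap(chunk => IO.println(s"emitting ${chunk.size} tokens"))
      .interruptAfter(25.seconds)
      .compile
      .drain
}
```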
@@ -10,6 +10,7 @@

package net.snowflake.ingest.streaming.internal

import cats.implicits._
import net.snowflake.ingest.streaming.SnowflakeStreamingIngestChannel
import net.snowflake.ingest.streaming.internal.SnowflakeStreamingIngestChannelInternal

@@ -26,4 +27,7 @@ object SnowsFlakePlowInterop {
def flushChannel(channel: SnowflakeStreamingIngestChannel): CompletableFuture[Void] =
channel.asInstanceOf[SnowflakeStreamingIngestChannelInternal[_]].flush(false)

def isEmpty(channel: SnowflakeStreamingIngestChannel): Boolean =
channel.asInstanceOf[SnowflakeStreamingIngestChannelInternal[_]].getRowBuffer.getSize === 0.0f

}
@@ -60,7 +60,7 @@ object MockEnvironment {
metrics = testMetrics(state),
appHealth = testAppHealth(state),
batching = Config.Batching(
maxBytes = 16000000,
maxBytes = 64000000,
maxDelay = 10.seconds,
uploadConcurrency = 1
),
@@ -139,7 +139,7 @@
Ref[IO].of(mockedResponses).map { responses =>
val make = actionRef.update(_ :+ OpenedChannel).as {
new Channel[IO] {
def write(rows: Iterable[Map[String, AnyRef]]): IO[Channel.WriteResult] =
def write(rows: List[Iterable[Map[String, AnyRef]]]): IO[Channel.WriteResult] =
for {
response <- responses.modify {
case head :: tail => (tail, head)
@@ -155,12 +155,12 @@

def updateActions(
state: Ref[IO, Vector[Action]],
rows: Iterable[Map[String, AnyRef]],
rows: List[Iterable[Map[String, AnyRef]]],
success: Response.Success[Channel.WriteResult]
): IO[Unit] =
success.value match {
case Channel.WriteResult.WriteFailures(failures) =>
state.update(_ :+ WroteRowsToSnowflake(rows.size - failures.size))
state.update(_ :+ WroteRowsToSnowflake(rows.flatten.size - failures.size))
case Channel.WriteResult.ChannelIsInvalid =>
IO.unit
}
@@ -311,7 +311,7 @@
}

private def testCloseableChannel(state: Ref[IO, Vector[Action]]): Channel.CloseableChannel[IO] = new Channel.CloseableChannel[IO] {
def write(rows: Iterable[Map[String, AnyRef]]): IO[Channel.WriteResult] = IO.pure(Channel.WriteResult.WriteFailures(Nil))
def write(rows: List[Iterable[Map[String, AnyRef]]]): IO[Channel.WriteResult] = IO.pure(Channel.WriteResult.WriteFailures(Nil))

def close: IO[Unit] = state.update(_ :+ Action.ClosedChannel)
}