Skip to content

Commit

Permalink
Add IO parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones committed Nov 18, 2024
1 parent 48252d2 commit 537f4f1
Show file tree
Hide file tree
Showing 5 changed files with 381 additions and 219 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,111 +20,201 @@ package com.spotify.scio.snowflake
import scala.util.chaining._
import com.spotify.scio.ScioContext
import com.spotify.scio.coders.{Coder, CoderMaterializer}
import com.spotify.scio.io.{EmptyTap, EmptyTapOf, ScioIO, Tap, TapT}
import com.spotify.scio.io.{EmptyTap, EmptyTapOf, ScioIO, Tap, TapT, TestIO}
import com.spotify.scio.util.ScioUtil
import com.spotify.scio.values.SCollection
import kantan.csv.{RowDecoder, RowEncoder}
import kantan.csv.{RowCodec, RowDecoder, RowEncoder}
import org.apache.beam.sdk.io.snowflake.SnowflakeIO.{CsvMapper, UserDataMapper}
import org.apache.beam.sdk.io.snowflake.data.SnowflakeTableSchema
import org.apache.beam.sdk.io.snowflake.enums.{CreateDisposition, WriteDisposition}
import org.apache.beam.sdk.io.{snowflake => beam}
import org.joda.time.Duration

object SnowflakeIO {

/**
 * Creates a test-only [[SnowflakeIO]] for use with `JobTest`: it carries no real
 * read/write logic, only a `testId` that matches the ids produced by the concrete IOs.
 *
 * @param opts  connection options used to derive the test id
 * @param query the query (or table) string used to derive the test id
 */
final def apply[T](opts: SnowflakeConnectionOptions, query: String): SnowflakeIO[T] =
new SnowflakeIO[T] with TestIO[T] {
// Tests never materialize output, so the tap type is the empty tap.
final override val tapT = EmptyTapOf[T]
override def testId: String = s"SnowflakeIO(${snowflakeIoId(opts, query)})"
}

// Builds a stable identifier for a Snowflake IO from the connection options and the
// read/write target (a query or a table name). Used as the `testId` matched by JobTest.
private[snowflake] def snowflakeIoId(opts: SnowflakeConnectionOptions, target: String): String = {
// source params
val params = Option(opts.database).map(db => s"db=$db") ++
Option(opts.warehouse).map(db => s"warehouse=$db")
// NOTE(review): when both database and warehouse are unset, mkString("?", "&", "")
// still emits the bare "?" prefix (e.g. "url?:target") — confirm this is intended.
s"${opts.url}${params.mkString("?", "&", "")}:$target"
}

object ReadParam {
// Escape hatch: lets callers tweak the underlying Beam read before it is applied.
type ConfigOverride[T] = beam.SnowflakeIO.Read[T] => beam.SnowflakeIO.Read[T]

// Defaults are null rather than Option to match the Beam API, where null means "unset";
// the IO only forwards a setting when the corresponding value is non-null.
val DefaultStagingBucketName: String = null
val DefaultQuotationMark: String = null
val DefaultConfigOverride = null
}
/**
 * Parameters for reading from Snowflake.
 *
 * @param storageIntegrationName name of the Snowflake storage integration (required)
 * @param stagingBucketName      GCS bucket used for staging; a temp dir is used when unset
 * @param quotationMark          quotation mark for the staged CSV; Beam default when unset
 * @param configOverride         optional hook to customize the Beam read transform
 */
final case class ReadParam[T](
storageIntegrationName: String,
stagingBucketName: String = ReadParam.DefaultStagingBucketName,
quotationMark: String = ReadParam.DefaultQuotationMark,
configOverride: ReadParam.ConfigOverride[T] = ReadParam.DefaultConfigOverride
)

object WriteParam {
// Escape hatch: lets callers tweak the underlying Beam write before it is applied.
type ConfigOverride[T] = beam.SnowflakeIO.Write[T] => beam.SnowflakeIO.Write[T]

// Defaults are null rather than Option to match the Beam API, where null means "unset";
// the IO only forwards a setting when the corresponding value is non-null.
val DefaultTableSchema: SnowflakeTableSchema = null
val DefaultCreateDisposition: CreateDisposition = null
val DefaultWriteDisposition: WriteDisposition = null
val DefaultSnowPipe: String = null
val DefaultShardNumber: Integer = null
val DefaultFlushRowLimit: Integer = null
val DefaultFlushTimeLimit: Duration = null
val DefaultStorageIntegrationName: String = null
val DefaultStagingBucketName: String = null
val DefaultQuotationMark: String = null
val DefaultConfigOverride = null
}
/**
 * Parameters for writing to Snowflake. Each field maps one-to-one onto a setter of the
 * Beam `SnowflakeIO.Write` transform and is only applied when non-null.
 *
 * @param tableSchema            schema used when Beam creates the target table
 * @param createDisposition      behavior when the target table is missing
 * @param writeDisposition       behavior when the target table already has data
 * @param snowPipe               SnowPipe name for streaming writes
 * @param shardNumber            number of output shards
 * @param flushRowLimit          row-count threshold triggering a flush
 * @param flushTimeLimit         time threshold triggering a flush
 * @param storageIntegrationName name of the Snowflake storage integration
 * @param stagingBucketName      GCS bucket used for staging; a temp dir is used when unset
 * @param quotationMark          quotation mark for the staged CSV
 * @param configOverride         optional hook to customize the Beam write transform
 */
final case class WriteParam[T](
tableSchema: SnowflakeTableSchema = WriteParam.DefaultTableSchema,
createDisposition: CreateDisposition = WriteParam.DefaultCreateDisposition,
writeDisposition: WriteDisposition = WriteParam.DefaultWriteDisposition,
snowPipe: String = WriteParam.DefaultSnowPipe,
shardNumber: Integer = WriteParam.DefaultShardNumber,
flushRowLimit: Integer = WriteParam.DefaultFlushRowLimit,
flushTimeLimit: Duration = WriteParam.DefaultFlushTimeLimit,
storageIntegrationName: String = WriteParam.DefaultStorageIntegrationName,
stagingBucketName: String = WriteParam.DefaultStagingBucketName,
quotationMark: String = WriteParam.DefaultQuotationMark,
configOverride: WriteParam.ConfigOverride[T] = WriteParam.DefaultConfigOverride
)

private[snowflake] def dataSourceConfiguration(connectionOptions: SnowflakeConnectionOptions) =
beam.SnowflakeIO.DataSourceConfiguration
.create()
.pipe(ds =>
connectionOptions.authenticationOptions match {
case SnowflakeUsernamePasswordAuthenticationOptions(username, password) =>
.withUrl(connectionOptions.url)
.pipe { ds =>
import SnowflakeAuthenticationOptions._
Option(connectionOptions.authenticationOptions).fold(ds) {
case UsernamePassword(username, password) =>
ds.withUsernamePasswordAuth(username, password)
case SnowflakeKeyPairAuthenticationOptions(username, privateKeyPath, None) =>
case KeyPair(username, privateKeyPath, None) =>
ds.withKeyPairPathAuth(username, privateKeyPath)
case SnowflakeKeyPairAuthenticationOptions(username, privateKeyPath, Some(passphrase)) =>
case KeyPair(username, privateKeyPath, Some(passphrase)) =>
ds.withKeyPairPathAuth(username, privateKeyPath, passphrase)
case SnowflakeOAuthTokenAuthenticationOptions(token) =>
ds.withOAuth(token)
case OAuthToken(token) =>
ds.withOAuth(token).withAuthenticator("oauth")
}
)
}
.pipe(ds => Option(connectionOptions.database).fold(ds)(ds.withDatabase))
.pipe(ds => Option(connectionOptions.role).fold(ds)(ds.withRole))
.pipe(ds => Option(connectionOptions.warehouse).fold(ds)(ds.withWarehouse))
.pipe(ds =>
ds
.withServerName(connectionOptions.serverName)
.withDatabase(connectionOptions.database)
.withRole(connectionOptions.role)
.withWarehouse(connectionOptions.warehouse)
Option(connectionOptions.loginTimeout)
.map[Integer](_.getStandardSeconds.toInt)
.fold(ds)(ds.withLoginTimeout)
)
.pipe(ds => connectionOptions.schema.fold(ds)(ds.withSchema))

private[snowflake] def buildCsvMapper[T](rowDecoder: RowDecoder[T]): CsvMapper[T] =
new CsvMapper[T] {
override def mapRow(parts: Array[String]): T = {
val unsnowedParts = parts.map {
case "\\N" => "" // needs to be mapped to an Option
case other => other
}.toSeq
rowDecoder.unsafeDecode(unsnowedParts)
}
}
.pipe(ds => Option(connectionOptions.schema).fold(ds)(ds.withSchema))

// Adapts a kantan RowDecoder into a Beam CsvMapper. Snowflake stages SQL NULL
// as the literal "\N"; it is rewritten to an empty field first so kantan can
// decode it as None for Option-typed columns.
private[snowflake] def csvMapper[T: RowDecoder]: CsvMapper[T] = { (parts: Array[String]) =>
val sanitized = parts.toSeq.map(field => if (field == "\\N") "" else field)
RowDecoder[T].unsafeDecode(sanitized)
}

private[snowflake] def prepareRead[T](
snowflakeOptions: SnowflakeOptions,
sc: ScioContext
)(implicit rowDecoder: RowDecoder[T], coder: Coder[T]): beam.SnowflakeIO.Read[T] =
beam.SnowflakeIO
.read()
.withDataSourceConfiguration(
SnowflakeIO.dataSourceConfiguration(snowflakeOptions.connectionOptions)
)
.withStagingBucketName(snowflakeOptions.stagingBucketName)
.withStorageIntegrationName(snowflakeOptions.storageIntegrationName)
.withCsvMapper(buildCsvMapper(rowDecoder))
.withCoder(CoderMaterializer.beam(sc, coder))
// Adapts a kantan RowEncoder into a Beam UserDataMapper: each element is encoded
// to its CSV row cells, which Beam then writes to the staging files.
private[snowflake] def userDataMapper[T: RowEncoder]: UserDataMapper[T] = { (row: T) =>
implicitly[RowEncoder[T]].encode(row).toArray
}
}

sealed trait SnowflakeIO[T] extends ScioIO[T]

final case class SnowflakeSelect[T](snowflakeOptions: SnowflakeOptions, select: String)(implicit
final case class SnowflakeSelect[T](connectionOptions: SnowflakeConnectionOptions, select: String)(
implicit
rowDecoder: RowDecoder[T],
coder: Coder[T]
) extends SnowflakeIO[T] {

override type ReadP = Unit
import SnowflakeIO._

override type ReadP = ReadParam[T]
override type WriteP = Unit
override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T]

override protected def read(sc: ScioContext, params: ReadP): SCollection[T] =
sc.applyTransform(SnowflakeIO.prepareRead(snowflakeOptions, sc).fromQuery(select))
override def testId: String = s"SnowflakeIO(${snowflakeIoId(connectionOptions, select)})"

override protected def read(sc: ScioContext, params: ReadP): SCollection[T] = {
val tempDirectory = ScioUtil.tempDirOrDefault(params.stagingBucketName, sc).toString
val t = beam.SnowflakeIO
.read[T]()
.fromQuery(select)
.withDataSourceConfiguration(dataSourceConfiguration(connectionOptions))
.withStorageIntegrationName(params.storageIntegrationName)
.withStagingBucketName(tempDirectory)
.pipe(r => Option(params.quotationMark).fold(r)(r.withQuotationMark))
.withCsvMapper(csvMapper)
.withCoder(CoderMaterializer.beam(sc, coder))
.pipe(r => Option(params.configOverride).fold(r)(_(r)))

sc.applyTransform(t)
}

override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] =
throw new UnsupportedOperationException("SnowflakeSelect is read-only")

override def tap(params: ReadP): Tap[Nothing] = EmptyTap
}

final case class SnowflakeTable[T](snowflakeOptions: SnowflakeOptions, table: String)(implicit
rowDecoder: RowDecoder[T],
rowEncoder: RowEncoder[T],
final case class SnowflakeTable[T](connectionOptions: SnowflakeConnectionOptions, table: String)(
implicit
rowCodec: RowCodec[T], // use codec for tap
coder: Coder[T]
) extends SnowflakeIO[T] {

override type ReadP = Unit
override type WriteP = Unit
override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T]
import SnowflakeIO._

override type ReadP = ReadParam[T]
override type WriteP = WriteParam[T]
override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] // TODO Create a tap

override def testId: String = s"SnowflakeIO(${snowflakeIoId(connectionOptions, table)})"

override protected def read(sc: ScioContext, params: ReadP): SCollection[T] =
sc.applyTransform(SnowflakeIO.prepareRead(snowflakeOptions, sc).fromTable(table))
override protected def read(sc: ScioContext, params: ReadP): SCollection[T] = {
val tempDirectory = ScioUtil.tempDirOrDefault(params.stagingBucketName, sc).toString
val t = beam.SnowflakeIO
.read[T]()
.fromTable(table)
.withDataSourceConfiguration(dataSourceConfiguration(connectionOptions))
.withStorageIntegrationName(params.storageIntegrationName)
.withStagingBucketName(tempDirectory)
.pipe(r => Option(params.quotationMark).fold(r)(r.withQuotationMark))
.withCsvMapper(csvMapper)
.withCoder(CoderMaterializer.beam(sc, coder))
.pipe(r => Option(params.configOverride).fold(r)(_(r)))

sc.applyTransform(t)
}

override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] = {
data.applyInternal(
beam.SnowflakeIO
.write[T]()
.withDataSourceConfiguration(
SnowflakeIO.dataSourceConfiguration(snowflakeOptions.connectionOptions)
)
.to(table)
.withStagingBucketName(snowflakeOptions.stagingBucketName)
.withStorageIntegrationName(snowflakeOptions.storageIntegrationName)
.withUserDataMapper(new UserDataMapper[T] {
override def mapRow(element: T): Array[AnyRef] = rowEncoder.encode(element).toArray
})
)
val tempDirectory = ScioUtil.tempDirOrDefault(params.stagingBucketName, data.context).toString
val t = beam.SnowflakeIO
.write[T]()
.withDataSourceConfiguration(dataSourceConfiguration(connectionOptions))
.to(table)
.pipe(w => Option(params.createDisposition).fold(w)(w.withCreateDisposition))
.pipe(w => Option(params.writeDisposition).fold(w)(w.withWriteDisposition))
.pipe(w => Option(params.snowPipe).fold(w)(w.withSnowPipe))
.pipe(w => Option(params.shardNumber).fold(w)(w.withShardsNumber))
.pipe(w => Option(params.flushRowLimit).fold(w)(w.withFlushRowLimit))
.pipe(w => Option(params.flushTimeLimit).fold(w)(w.withFlushTimeLimit))
.pipe(w => Option(params.quotationMark).fold(w)(w.withQuotationMark))
.pipe(w => Option(params.storageIntegrationName).fold(w)(w.withStorageIntegrationName))
.withStagingBucketName(tempDirectory)
.withUserDataMapper(userDataMapper)
.pipe(w => Option(params.configOverride).fold(w)(_(w)))

data.applyInternal(t)
EmptyTap
}

Expand Down
Loading

0 comments on commit 537f4f1

Please sign in to comment.