diff --git a/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/SnowflakeIO.scala b/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/SnowflakeIO.scala index de38cfec34..a710a1d1d6 100644 --- a/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/SnowflakeIO.scala +++ b/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/SnowflakeIO.scala @@ -20,77 +20,145 @@ package com.spotify.scio.snowflake import scala.util.chaining._ import com.spotify.scio.ScioContext import com.spotify.scio.coders.{Coder, CoderMaterializer} -import com.spotify.scio.io.{EmptyTap, EmptyTapOf, ScioIO, Tap, TapT} +import com.spotify.scio.io.{EmptyTap, EmptyTapOf, ScioIO, Tap, TapT, TestIO} +import com.spotify.scio.util.ScioUtil import com.spotify.scio.values.SCollection -import kantan.csv.{RowDecoder, RowEncoder} +import kantan.csv.{RowCodec, RowDecoder, RowEncoder} import org.apache.beam.sdk.io.snowflake.SnowflakeIO.{CsvMapper, UserDataMapper} +import org.apache.beam.sdk.io.snowflake.data.SnowflakeTableSchema +import org.apache.beam.sdk.io.snowflake.enums.{CreateDisposition, WriteDisposition} import org.apache.beam.sdk.io.{snowflake => beam} +import org.joda.time.Duration object SnowflakeIO { + final def apply[T](opts: SnowflakeConnectionOptions, query: String): SnowflakeIO[T] = + new SnowflakeIO[T] with TestIO[T] { + final override val tapT = EmptyTapOf[T] + override def testId: String = s"SnowflakeIO(${snowflakeIoId(opts, query)})" + } + + private[snowflake] def snowflakeIoId(opts: SnowflakeConnectionOptions, target: String): String = { + // source params + val params = Option(opts.database).map(db => s"db=$db") ++ + Option(opts.warehouse).map(db => s"warehouse=$db") + s"${opts.url}${params.mkString("?", "&", "")}:$target" + } + + object ReadParam { + type ConfigOverride[T] = beam.SnowflakeIO.Read[T] => beam.SnowflakeIO.Read[T] + + val DefaultStagingBucketName: String = null + val DefaultQuotationMark: String = null + val DefaultConfigOverride = null + } + final case class ReadParam[T]( + storageIntegrationName: String, + stagingBucketName: String = ReadParam.DefaultStagingBucketName, + quotationMark: String = ReadParam.DefaultQuotationMark, + configOverride: ReadParam.ConfigOverride[T] = ReadParam.DefaultConfigOverride + ) + + object WriteParam { + type ConfigOverride[T] = beam.SnowflakeIO.Write[T] => beam.SnowflakeIO.Write[T] + + val DefaultTableSchema: SnowflakeTableSchema = null + val DefaultCreateDisposition: CreateDisposition = null + val DefaultWriteDisposition: WriteDisposition = null + val DefaultSnowPipe: String = null + val DefaultShardNumber: Integer = null + val DefaultFlushRowLimit: Integer = null + val DefaultFlushTimeLimit: Duration = null + val DefaultStorageIntegrationName: String = null + val DefaultStagingBucketName: String = null + val DefaultQuotationMark: String = null + val DefaultConfigOverride = null + } + final case class WriteParam[T]( + tableSchema: SnowflakeTableSchema = WriteParam.DefaultTableSchema, + createDisposition: CreateDisposition = WriteParam.DefaultCreateDisposition, + writeDisposition: WriteDisposition = WriteParam.DefaultWriteDisposition, + snowPipe: String = WriteParam.DefaultSnowPipe, + shardNumber: Integer = WriteParam.DefaultShardNumber, + flushRowLimit: Integer = WriteParam.DefaultFlushRowLimit, + flushTimeLimit: Duration = WriteParam.DefaultFlushTimeLimit, + storageIntegrationName: String = WriteParam.DefaultStorageIntegrationName, + stagingBucketName: String = WriteParam.DefaultStagingBucketName, + quotationMark: String = 
WriteParam.DefaultQuotationMark, + configOverride: WriteParam.ConfigOverride[T] = WriteParam.DefaultConfigOverride + ) + private[snowflake] def dataSourceConfiguration(connectionOptions: SnowflakeConnectionOptions) = beam.SnowflakeIO.DataSourceConfiguration .create() - .pipe(ds => - connectionOptions.authenticationOptions match { - case SnowflakeUsernamePasswordAuthenticationOptions(username, password) => + .withUrl(connectionOptions.url) + .pipe { ds => + import SnowflakeAuthenticationOptions._ + Option(connectionOptions.authenticationOptions).fold(ds) { + case UsernamePassword(username, password) => ds.withUsernamePasswordAuth(username, password) - case SnowflakeKeyPairAuthenticationOptions(username, privateKeyPath, None) => + case KeyPair(username, privateKeyPath, None) => ds.withKeyPairPathAuth(username, privateKeyPath) - case SnowflakeKeyPairAuthenticationOptions(username, privateKeyPath, Some(passphrase)) => + case KeyPair(username, privateKeyPath, Some(passphrase)) => ds.withKeyPairPathAuth(username, privateKeyPath, passphrase) - case SnowflakeOAuthTokenAuthenticationOptions(token) => - ds.withOAuth(token) + case OAuthToken(token) => + ds.withOAuth(token).withAuthenticator("oauth") } - ) + } + .pipe(ds => Option(connectionOptions.database).fold(ds)(ds.withDatabase)) + .pipe(ds => Option(connectionOptions.role).fold(ds)(ds.withRole)) + .pipe(ds => Option(connectionOptions.warehouse).fold(ds)(ds.withWarehouse)) .pipe(ds => - ds - .withServerName(connectionOptions.serverName) - .withDatabase(connectionOptions.database) - .withRole(connectionOptions.role) - .withWarehouse(connectionOptions.warehouse) + Option(connectionOptions.loginTimeout) + .map[Integer](_.getStandardSeconds.toInt) + .fold(ds)(ds.withLoginTimeout) ) - .pipe(ds => connectionOptions.schema.fold(ds)(ds.withSchema)) - - private[snowflake] def buildCsvMapper[T](rowDecoder: RowDecoder[T]): CsvMapper[T] = - new CsvMapper[T] { - override def mapRow(parts: Array[String]): T = { - val unsnowedParts = parts.map { - case "\\N" => "" // needs to be mapped to an Option - case other => other - }.toSeq - rowDecoder.unsafeDecode(unsnowedParts) - } - } + .pipe(ds => Option(connectionOptions.schema).fold(ds)(ds.withSchema)) + + private[snowflake] def csvMapper[T: RowDecoder]: CsvMapper[T] = { (parts: Array[String]) => + val unsnowedParts = parts.map { + case "\\N" => "" // needs to be mapped to an Option + case other => other + }.toSeq + RowDecoder[T].unsafeDecode(unsnowedParts) + } - private[snowflake] def prepareRead[T]( - snowflakeOptions: SnowflakeOptions, - sc: ScioContext - )(implicit rowDecoder: RowDecoder[T], coder: Coder[T]): beam.SnowflakeIO.Read[T] = - beam.SnowflakeIO - .read() - .withDataSourceConfiguration( - SnowflakeIO.dataSourceConfiguration(snowflakeOptions.connectionOptions) - ) - .withStagingBucketName(snowflakeOptions.stagingBucketName) - .withStorageIntegrationName(snowflakeOptions.storageIntegrationName) - .withCsvMapper(buildCsvMapper(rowDecoder)) - .withCoder(CoderMaterializer.beam(sc, coder)) + private[snowflake] def userDataMapper[T: RowEncoder]: UserDataMapper[T] = { (element: T) => + RowEncoder[T].encode(element).toArray + } } sealed trait SnowflakeIO[T] extends ScioIO[T] -final case class SnowflakeSelect[T](snowflakeOptions: SnowflakeOptions, select: String)(implicit +final case class SnowflakeSelect[T](connectionOptions: SnowflakeConnectionOptions, select: String)( + implicit rowDecoder: RowDecoder[T], coder: Coder[T] ) extends SnowflakeIO[T] { - override type ReadP = Unit + import SnowflakeIO._ + + 
override type ReadP = ReadParam[T] override type WriteP = Unit override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] - override protected def read(sc: ScioContext, params: ReadP): SCollection[T] = - sc.applyTransform(SnowflakeIO.prepareRead(snowflakeOptions, sc).fromQuery(select)) + override def testId: String = s"SnowflakeIO(${snowflakeIoId(connectionOptions, select)})" + + override protected def read(sc: ScioContext, params: ReadP): SCollection[T] = { + val tempDirectory = ScioUtil.tempDirOrDefault(params.stagingBucketName, sc).toString + val t = beam.SnowflakeIO + .read[T]() + .fromQuery(select) + .withDataSourceConfiguration(dataSourceConfiguration(connectionOptions)) + .withStorageIntegrationName(params.storageIntegrationName) + .withStagingBucketName(tempDirectory) + .pipe(r => Option(params.quotationMark).fold(r)(r.withQuotationMark)) + .withCsvMapper(csvMapper) + .withCoder(CoderMaterializer.beam(sc, coder)) + .pipe(r => Option(params.configOverride).fold(r)(_(r))) + + sc.applyTransform(t) + } override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] = throw new UnsupportedOperationException("SnowflakeSelect is read-only") @@ -98,33 +166,55 @@ final case class SnowflakeSelect[T](snowflakeOptions: SnowflakeOptions, select: override def tap(params: ReadP): Tap[Nothing] = EmptyTap } -final case class SnowflakeTable[T](snowflakeOptions: SnowflakeOptions, table: String)(implicit - rowDecoder: RowDecoder[T], - rowEncoder: RowEncoder[T], +final case class SnowflakeTable[T](connectionOptions: SnowflakeConnectionOptions, table: String)( + implicit + rowCodec: RowCodec[T], // use codec for tap coder: Coder[T] ) extends SnowflakeIO[T] { - override type ReadP = Unit - override type WriteP = Unit - override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] + import SnowflakeIO._ + + override type ReadP = ReadParam[T] + override type WriteP = WriteParam[T] + override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] // TODO Create a tap + + override def testId: String = s"SnowflakeIO(${snowflakeIoId(connectionOptions, table)})" - override protected def read(sc: ScioContext, params: ReadP): SCollection[T] = - sc.applyTransform(SnowflakeIO.prepareRead(snowflakeOptions, sc).fromTable(table)) + override protected def read(sc: ScioContext, params: ReadP): SCollection[T] = { + val tempDirectory = ScioUtil.tempDirOrDefault(params.stagingBucketName, sc).toString + val t = beam.SnowflakeIO + .read[T]() + .fromTable(table) + .withDataSourceConfiguration(dataSourceConfiguration(connectionOptions)) + .withStorageIntegrationName(params.storageIntegrationName) + .withStagingBucketName(tempDirectory) + .pipe(r => Option(params.quotationMark).fold(r)(r.withQuotationMark)) + .withCsvMapper(csvMapper) + .withCoder(CoderMaterializer.beam(sc, coder)) + .pipe(r => Option(params.configOverride).fold(r)(_(r))) + + sc.applyTransform(t) + } override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] = { - data.applyInternal( - beam.SnowflakeIO - .write[T]() - .withDataSourceConfiguration( - SnowflakeIO.dataSourceConfiguration(snowflakeOptions.connectionOptions) - ) - .to(table) - .withStagingBucketName(snowflakeOptions.stagingBucketName) - .withStorageIntegrationName(snowflakeOptions.storageIntegrationName) - .withUserDataMapper(new UserDataMapper[T] { - override def mapRow(element: T): Array[AnyRef] = rowEncoder.encode(element).toArray - }) - ) + val tempDirectory = ScioUtil.tempDirOrDefault(params.stagingBucketName, data.context).toString + val t = beam.SnowflakeIO + .write[T]() 
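+      // Optional write settings default to null (WriteParam.Default*); each is
+      // applied to the Beam builder only when set, via Option(...).fold below.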
+ .withDataSourceConfiguration(dataSourceConfiguration(connectionOptions)) + .to(table) + .pipe(w => Option(params.createDisposition).fold(w)(w.withCreateDisposition)) + .pipe(w => Option(params.writeDisposition).fold(w)(w.withWriteDisposition)) + .pipe(w => Option(params.snowPipe).fold(w)(w.withSnowPipe)) + .pipe(w => Option(params.shardNumber).fold(w)(w.withShardsNumber)) + .pipe(w => Option(params.flushRowLimit).fold(w)(w.withFlushRowLimit)) + .pipe(w => Option(params.flushTimeLimit).fold(w)(w.withFlushTimeLimit)) + .pipe(w => Option(params.quotationMark).fold(w)(w.withQuotationMark)) + .pipe(w => Option(params.storageIntegrationName).fold(w)(w.withStorageIntegrationName)) + .withStagingBucketName(tempDirectory) + .withUserDataMapper(userDataMapper) + .pipe(w => Option(params.configOverride).fold(w)(_(w))) + + data.applyInternal(t) EmptyTap } diff --git a/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/SnowflakeOptions.scala b/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/SnowflakeOptions.scala index a43fb74fe9..cdcd801866 100644 --- a/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/SnowflakeOptions.scala +++ b/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/SnowflakeOptions.scala @@ -17,84 +17,76 @@ package com.spotify.scio.snowflake +import org.joda.time.Duration + sealed trait SnowflakeAuthenticationOptions -/** - * Options for a Snowflake username/password authentication. - * - * @param username - * username - * @param password - * password - */ -final case class SnowflakeUsernamePasswordAuthenticationOptions( - username: String, - password: String -) extends SnowflakeAuthenticationOptions +object SnowflakeAuthenticationOptions { -/** - * Options for a Snowflake key pair authentication. - * - * @param username - * username - * @param privateKeyPath - * path to the private key - * @param privateKeyPassphrase - * passphrase for the private key (optional) - */ -final case class SnowflakeKeyPairAuthenticationOptions( - username: String, - privateKeyPath: String, - privateKeyPassphrase: Option[String] -) extends SnowflakeAuthenticationOptions + /** + * Snowflake username/password authentication. + * + * @param username + * username + * @param password + * password + */ + final case class UsernamePassword( + username: String, + password: String + ) extends SnowflakeAuthenticationOptions -/** - * Options for a Snowflake OAuth token authentication. - * - * @param token - * OAuth token - */ -final case class SnowflakeOAuthTokenAuthenticationOptions( - token: String -) extends SnowflakeAuthenticationOptions + /** + * Key pair authentication. + * + * @param username + * username + * @param privateKeyPath + * path to the private key + * @param privateKeyPassphrase + * passphrase for the private key (optional) + */ + final case class KeyPair( + username: String, + privateKeyPath: String, + privateKeyPassphrase: Option[String] = None + ) extends SnowflakeAuthenticationOptions + + /** + * OAuth token authentication. + * + * @param token + * OAuth token + */ + final case class OAuthToken(token: String) extends SnowflakeAuthenticationOptions + +} /** * Options for a Snowflake connection. * * @param authenticationOptions * authentication options - * @param serverName - * server name (e.g. 
"account.region.snowflakecomputing.com") + * @param url + * Sets URL of Snowflake server in following format: + * "jdbc:snowflake://[host]:[port].snowflakecomputing.com" * @param database - * database name + * database to use * @param role - * role name + * user's role to be used when running queries on Snowflake * @param warehouse * warehouse name * @param schema - * schema name (optional) + * schema to use when connecting to Snowflake + * @param loginTimeout + * Sets loginTimeout that will be used in [[net.snowflake.client.jdbc.SnowflakeBasicDataSource]]. */ final case class SnowflakeConnectionOptions( - authenticationOptions: SnowflakeAuthenticationOptions, - serverName: String, - database: String, - role: String, - warehouse: String, - schema: Option[String] -) - -/** - * Options for configuring a Neo4J driver. - * - * @param connectionOptions - * connection options - * @param stagingBucketName - * Snowflake staging bucket name where CSV files will be stored - * @param storageIntegrationName - * Storage integration name as created in Snowflake - */ -final case class SnowflakeOptions( - connectionOptions: SnowflakeConnectionOptions, - stagingBucketName: String, - storageIntegrationName: String + url: String, + authenticationOptions: SnowflakeAuthenticationOptions = null, + database: String = null, + role: String = null, + warehouse: String = null, + schema: String = null, + loginTimeout: Duration = null ) diff --git a/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/syntax/SCollectionSyntax.scala b/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/syntax/SCollectionSyntax.scala index 1ff2c3f235..3679dc593c 100644 --- a/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/syntax/SCollectionSyntax.scala +++ b/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/syntax/SCollectionSyntax.scala @@ -18,108 +18,88 @@ package com.spotify.scio.snowflake.syntax import com.spotify.scio.coders.Coder -import com.spotify.scio.io.{EmptyTap, Tap} -import com.spotify.scio.snowflake.SnowflakeOptions +import com.spotify.scio.io.ClosedTap +import com.spotify.scio.snowflake.{SnowflakeConnectionOptions, SnowflakeIO, SnowflakeTable} import com.spotify.scio.values.SCollection -import kantan.csv.{RowDecoder, RowEncoder} -import org.apache.beam.sdk.io.snowflake.SnowflakeIO.UserDataMapper -import org.apache.beam.sdk.io.{snowflake => beam} +import kantan.csv.RowCodec +import org.apache.beam.sdk.io.snowflake.data.SnowflakeTableSchema +import org.apache.beam.sdk.io.snowflake.enums.{CreateDisposition, WriteDisposition} +import org.joda.time.Duration /** * Enhanced version of [[com.spotify.scio.values.SCollection SCollection]] with Snowflake methods. */ final class SnowflakeSCollectionOps[T](private val self: SCollection[T]) extends AnyVal { - import com.spotify.scio.snowflake.SnowflakeIO._ - /** - * Execute the provided SQL query in Snowflake, COPYing the result in CSV format to the provided - * bucket, and return an [[SCollection]] of provided type, reading this bucket. - * - * The [[SCollection]] is generated using [[kantan.csv.RowDecoded]]. [[SCollection]] type - * properties must then match the order of the columns of the SELECT, that will be copied to the - * bucket. 
- * - * @see - * ''Reading from Snowflake'' in the - * [[https://beam.apache.org/documentation/io/built-in/snowflake/ Beam `SnowflakeIO` documentation]] - * @param snowflakeConf - * options for configuring a Snowflake integration - * @param query - * SQL select query - * @return - * [[SCollection]] containing the query results as parsed from the CSV bucket copied from - * Snowflake - */ - def snowflakeSelect[U]( - snowflakeConf: SnowflakeOptions, - query: String - )(implicit - rowDecoder: RowDecoder[U], - coder: Coder[U] - ): SCollection[U] = - self.context.applyTransform(prepareRead(snowflakeConf, self.context).fromQuery(query)) - - /** - * Copy the provided Snowflake table in CSV format to the provided bucket, and * return an - * [[SCollection]] of provided type, reading this bucket. - * - * The [[SCollection]] is generated using [[kantan.csv.RowDecoded]]. [[SCollection]] type - * properties must then match the order of the columns of the table, that will be copied to the - * bucket. + * Save this SCollection as a Snowflake database table. The [[SCollection]] is written to CSV + * files in a bucket, using a provided [[kantan.csv.RowEncoder]] to encode each element as a CSV + * row. The bucket is then COPYied to the Snowflake table. * * @see * ''Reading from Snowflake'' in the * [[https://beam.apache.org/documentation/io/built-in/snowflake/ Beam `SnowflakeIO` documentation]] - * @param snowflakeConf + * @param connectionOptions * options for configuring a Snowflake integration * @param table - * table + * table name to be written in Snowflake + * @param tableSchema + * table schema to be used during creating table + * @param createDisposition + * disposition to be used during table preparation + * @param writeDisposition + * disposition to be used during writing to table phase + * @param snowPipe + * name of created + * [[https://docs.snowflake.com/en/user-guide/data-load-snowpipe-intro SnowPipe]] in Snowflake + * dashboard + * @param shardNumber + * number of shards that are created per window + * @param flushRowLimit + * number of row limit that will be saved to the staged file and then loaded to Snowflake + * @param flushTimeLimit + * duration how often staged files will be created and then how often ingested by Snowflake + * during streaming + * @param storageIntegrationName + * Storage Integration in Snowflake to be used + * @param stagingBucketName + * cloud bucket (GCS by now) to use as tmp location of CSVs during COPY statement. + * @param quotationMark + * Snowflake-specific quotations around strings * @return * [[SCollection]] containing the table elements as parsed from the CSV bucket copied from * Snowflake table */ - def snowflakeTable[U]( - snowflakeConf: SnowflakeOptions, - table: String - )(implicit - rowDecoder: RowDecoder[U], - coder: Coder[U] - ): SCollection[U] = - self.context.applyTransform(prepareRead(snowflakeConf, self.context).fromTable(table)) - - /** - * Save this SCollection as a Snowflake database table. The [[SCollection]] is written to CSV - * files in a bucket, using the provided [[kantan.csv.RowEncoder]] to encode each element as a CSV - * row. The bucket is then COPYied to the Snowflake table. 
- * - * @see - * ''Writing to Snowflake tables'' in the - * [[https://beam.apache.org/documentation/io/built-in/snowflake/ Beam `SnowflakeIO` documentation]] - * - * @param snowflakeOptions - * options for configuring a Snowflake connexion - * @param table - * Snowflake table - */ - def saveAsSnowflakeTable( - snowflakeOptions: SnowflakeOptions, - table: String - )(implicit rowEncoder: RowEncoder[T], coder: Coder[T]): Tap[Nothing] = { - self.applyInternal( - beam.SnowflakeIO - .write[T]() - .withDataSourceConfiguration( - dataSourceConfiguration(snowflakeOptions.connectionOptions) - ) - .to(table) - .withStagingBucketName(snowflakeOptions.stagingBucketName) - .withStorageIntegrationName(snowflakeOptions.storageIntegrationName) - .withUserDataMapper(new UserDataMapper[T] { - override def mapRow(element: T): Array[AnyRef] = rowEncoder.encode(element).toArray - }) + def saveAsSnowflake( + connectionOptions: SnowflakeConnectionOptions, + table: String, + tableSchema: SnowflakeTableSchema = SnowflakeIO.WriteParam.DefaultTableSchema, + createDisposition: CreateDisposition = SnowflakeIO.WriteParam.DefaultCreateDisposition, + writeDisposition: WriteDisposition = SnowflakeIO.WriteParam.DefaultWriteDisposition, + snowPipe: String = SnowflakeIO.WriteParam.DefaultSnowPipe, + shardNumber: Integer = SnowflakeIO.WriteParam.DefaultShardNumber, + flushRowLimit: Integer = SnowflakeIO.WriteParam.DefaultFlushRowLimit, + flushTimeLimit: Duration = SnowflakeIO.WriteParam.DefaultFlushTimeLimit, + storageIntegrationName: String = SnowflakeIO.WriteParam.DefaultStorageIntegrationName, + stagingBucketName: String = SnowflakeIO.WriteParam.DefaultStagingBucketName, + quotationMark: String = SnowflakeIO.WriteParam.DefaultQuotationMark, + configOverride: SnowflakeIO.WriteParam.ConfigOverride[T] = + SnowflakeIO.WriteParam.DefaultConfigOverride + )(implicit rowCodec: RowCodec[T], coder: Coder[T]): ClosedTap[Nothing] = { + val param = SnowflakeIO.WriteParam( + tableSchema = tableSchema, + createDisposition = createDisposition, + writeDisposition = writeDisposition, + snowPipe = snowPipe, + shardNumber = shardNumber, + flushRowLimit = flushRowLimit, + flushTimeLimit = flushTimeLimit, + storageIntegrationName = storageIntegrationName, + stagingBucketName = stagingBucketName, + quotationMark = quotationMark, + configOverride = configOverride ) - EmptyTap + self.write(SnowflakeTable[T](connectionOptions, table))(param) } } diff --git a/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/syntax/ScioContextSyntax.scala b/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/syntax/ScioContextSyntax.scala index 846f437da2..392267b29b 100644 --- a/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/syntax/ScioContextSyntax.scala +++ b/scio-snowflake/src/main/scala/com/spotify/scio/snowflake/syntax/ScioContextSyntax.scala @@ -19,9 +19,14 @@ package com.spotify.scio.snowflake.syntax import com.spotify.scio.ScioContext import com.spotify.scio.coders.Coder -import com.spotify.scio.snowflake.{SnowflakeOptions, SnowflakeSelect} +import com.spotify.scio.snowflake.{ + SnowflakeConnectionOptions, + SnowflakeIO, + SnowflakeSelect, + SnowflakeTable +} import com.spotify.scio.values.SCollection -import kantan.csv.RowDecoder +import kantan.csv.{RowCodec, RowDecoder} /** Enhanced version of [[ScioContext]] with Snowflake methods. 
*/
 final class SnowflakeScioContextOps(private val self: ScioContext) extends AnyVal {
@@ -29,16 +34,72 @@ final class SnowflakeScioContextOps(private val self: ScioContext) extends AnyVa
   /**
    * Get an SCollection for a Snowflake SQL query
    *
-   * @param snowflakeOptions
+   * @param connectionOptions
    *   options for configuring a Snowflake connexion
    * @param query
    *   Snowflake SQL select query
+   * @param storageIntegrationName
+   *   name of the Snowflake storage integration to use
+   * @param stagingBucketName
+   *   cloud bucket (currently GCS only) used as a temporary location for CSV files during the
+   *   COPY statement
+   * @param quotationMark
+   *   Snowflake-specific quotation mark to use around strings
+   * @param configOverride
+   *   optional function customizing the underlying Beam `SnowflakeIO.Read` transform
+   */
+  def snowflakeQuery[T](
+    connectionOptions: SnowflakeConnectionOptions,
+    query: String,
+    storageIntegrationName: String,
+    stagingBucketName: String = SnowflakeIO.ReadParam.DefaultStagingBucketName,
+    quotationMark: String = SnowflakeIO.ReadParam.DefaultQuotationMark,
+    configOverride: SnowflakeIO.ReadParam.ConfigOverride[T] =
+      SnowflakeIO.ReadParam.DefaultConfigOverride
+  )(implicit rowDecoder: RowDecoder[T], coder: Coder[T]): SCollection[T] = {
+    val param = SnowflakeIO.ReadParam(
+      storageIntegrationName = storageIntegrationName,
+      stagingBucketName = stagingBucketName,
+      quotationMark = quotationMark,
+      configOverride = configOverride
+    )
+    self.read(SnowflakeSelect(connectionOptions, query))(param)
+  }
+
+  /**
+   * Get an SCollection for a Snowflake table
+   *
+   * @param connectionOptions
+   *   options for configuring a Snowflake connection
+   * @param table
+   *   Snowflake table
+   * @param storageIntegrationName
+   *   name of the Snowflake storage integration to use
+   * @param stagingBucketName
+   *   cloud bucket (currently GCS only) used as a temporary location for CSV files during the
+   *   COPY statement
+   * @param quotationMark
+   *   Snowflake-specific quotation mark to use around strings
+   * @param configOverride
+   *   optional function customizing the underlying Beam `SnowflakeIO.Read` transform
    */
-  def snowflakeQuery[T: RowDecoder: Coder](
-    snowflakeOptions: SnowflakeOptions,
-    query: String
-  ): SCollection[T] =
-    self.read(SnowflakeSelect(snowflakeOptions, query))
+  def snowflakeTable[T](
+    connectionOptions: SnowflakeConnectionOptions,
+    table: String,
+    storageIntegrationName: String,
+    stagingBucketName: String = SnowflakeIO.ReadParam.DefaultStagingBucketName,
+    quotationMark: String = SnowflakeIO.ReadParam.DefaultQuotationMark,
+    configOverride: SnowflakeIO.ReadParam.ConfigOverride[T] =
+      SnowflakeIO.ReadParam.DefaultConfigOverride
+  )(implicit rowDecoder: RowDecoder[T], coder: Coder[T]): SCollection[T] = {
+    // create a read-only codec: the encoder half is never used when reading
+    implicit val codec: RowCodec[T] = RowCodec.from(rowDecoder, null)
+    val param = SnowflakeIO.ReadParam(
+      storageIntegrationName = storageIntegrationName,
+      stagingBucketName = stagingBucketName,
+      quotationMark = quotationMark,
+      configOverride = configOverride
+    )
+    self.read(SnowflakeTable(connectionOptions, table))(param)
+  }
 }
 
 trait ScioContextSyntax {
diff --git a/scio-snowflake/src/test/scala/com/spotify/scio/snowflake/SnowflakeIOTest.scala b/scio-snowflake/src/test/scala/com/spotify/scio/snowflake/SnowflakeIOTest.scala
new file mode 100644
index 0000000000..289e318e78
--- /dev/null
+++ b/scio-snowflake/src/test/scala/com/spotify/scio/snowflake/SnowflakeIOTest.scala
@@ -0,0 +1,43 @@
+package com.spotify.scio.snowflake
+
+import com.spotify.scio.testing.ScioIOSpec
+import kantan.csv.RowCodec
+
+object SnowflakeIOTest {
+  final case class Data(value: String)
+}
+
+class SnowflakeIOTest extends ScioIOSpec {
+
+  import SnowflakeIOTest._
+
+  val connectionOptions = SnowflakeConnectionOptions(
+    url = "jdbc:snowflake://host.snowflakecomputing.com"
+  )
+
+  implicit val rowCodecData: RowCodec[Data] = RowCodec.caseCodec(Data.apply)(Data.unapply)
+
+  "SnowflakeIO" should "support query input" in {
+    val input = Seq(Data("a"), Data("b"), Data("c"))
+    val query = "SELECT * FROM table"
+    testJobTestInput(input, query)(SnowflakeIO(connectionOptions, _))(
+      _.snowflakeQuery(connectionOptions, _, "storage-integration")
+    )
+  }
+
+  it should "support table input" in {
+    val input = Seq(Data("a"), Data("b"), Data("c"))
+    val table = "table"
+    testJobTestInput(input, table)(SnowflakeIO(connectionOptions, _))(
+      _.snowflakeTable(connectionOptions, _, "storage-integration")
+    )
+  }
+
+  it should "support table output" in {
+    val output = Seq(Data("a"), Data("b"), Data("c"))
+    val table = "table"
+    testJobTestOutput(output, table)(SnowflakeIO(connectionOptions, _))(
+      _.saveAsSnowflake(connectionOptions, _)
+    )
+  }
+}