diff --git a/README.md b/README.md
index f40a09b67..2d500f46e 100644
--- a/README.md
+++ b/README.md
@@ -152,19 +152,23 @@ Status: POC
 
 Use like
 ```scala
-import $ivy.`org.jupyter-scala::scio:0.4.2`
+import $ivy.`org.jupyter-scala::scio:0.4.3`, $ivy.`org.apache.beam:beam-runners-direct-java:2.6.0`
+
 import jupyter.scio._
 
 import com.spotify.scio._
-import com.spotify.scio.accumulators._
 import com.spotify.scio.bigquery._
-import com.spotify.scio.experimental._
 
-val sc = JupyterScioContext(
-  "runner" -> "DataflowPipelineRunner",
+// Define JupyterScioContext
+JupyterScioContext(
+  "runner" -> "DirectRunner", // DirectRunner or DataflowRunner
   "project" -> "jupyter-scala",
   "stagingLocation" -> "gs://bucket/staging"
-).withGcpCredential("/path-to/credentials.json") // alternatively, set the env var GOOGLE_APPLICATION_CREDENTIALS to that path
+)
+
+sc.withGcpCredential("/path-to/credentials.json") // alternatively, set the env var GOOGLE_APPLICATION_CREDENTIALS to that path
+
+// Access JupyterScioContext with `sc`
 ```
 
 ### Scalding
diff --git a/build.sbt b/build.sbt
index 264607ad9..a2550d1ef 100644
--- a/build.sbt
+++ b/build.sbt
@@ -127,7 +127,8 @@ lazy val scio = project
       Deps.kantanCsv,
       Deps.macroParadise,
       Deps.scioCore,
-      Deps.scioExtra
+      Deps.scioExtra,
+      Deps.dataflowRunner
     )
   },
   disableScalaVersion("2.12")
diff --git a/project/Deps.scala b/project/Deps.scala
index e97382d91..f738d1766 100644
--- a/project/Deps.scala
+++ b/project/Deps.scala
@@ -9,7 +9,8 @@ object Deps {
     def ammonium = "0.8.3-1"
     def flink = "1.1.3"
     def jupyterKernel = "0.4.1"
-    def scio = "0.2.12"
+    def scio = "0.6.1"
+    def beam = "2.6.0"
   }
 
   def ammonium = ("org.jupyter-scala" % "ammonite" % Versions.ammonium).cross(CrossVersion.full)
@@ -35,6 +36,7 @@
   def scalaXml = "org.scala-lang.modules" %% "scala-xml" % "1.0.6"
   def scioCore = "com.spotify" %% "scio-core" % Versions.scio
   def scioExtra = "com.spotify" %% "scio-extra" % Versions.scio
+  def dataflowRunner = "org.apache.beam" % "beam-runners-google-cloud-dataflow-java" % Versions.beam
   def slf4jSimple = "org.slf4j" % "slf4j-simple" % "1.7.24"
   def sparkSql1 = "org.apache.spark" %% "spark-sql" % "1.3.1"
   def sparkSql = "org.apache.spark" %% "spark-sql" % "2.0.2"
diff --git a/scio/src/main/scala/com/spotify/scio/jupyter/JupyterScioContext.scala b/scio/src/main/scala/com/spotify/scio/jupyter/JupyterScioContext.scala
index c751031c6..f5da17887 100644
--- a/scio/src/main/scala/com/spotify/scio/jupyter/JupyterScioContext.scala
+++ b/scio/src/main/scala/com/spotify/scio/jupyter/JupyterScioContext.scala
@@ -6,11 +6,13 @@ import java.nio.file.{Files, Path}
 
 import ammonite.repl.RuntimeAPI
 import ammonite.runtime.InterpAPI
-import com.google.api.client.auth.oauth2.Credential
-import com.google.api.client.googleapis.auth.oauth2.GoogleCredential
 import com.google.api.services.dataflow.DataflowScopes
-import com.google.cloud.dataflow.sdk.options.{DataflowPipelineOptions, PipelineOptions, PipelineOptionsFactory}
+import com.google.auth.Credentials
+import com.google.auth.oauth2.GoogleCredentials
 import com.spotify.scio.{ScioContext, ScioResult}
+import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions
+import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory}
+import org.slf4j.LoggerFactory
 
 import scala.collection.JavaConverters._
 
@@ -18,47 +20,48 @@ import scala.collection.JavaConverters._
 
 // in the com.spotify.scio namespace to access private[scio] things
 
-class JupyterScioContext(
-  options: PipelineOptions,
-  replJarPath: Path
-)(implicit
-  interpApi: InterpAPI,
-  runtimeApi: RuntimeAPI
-) extends ScioContext(options, Nil) {
-
-  addArtifacts(
-    replJarPath.toAbsolutePath.toString ::
-      runtimeApi.sess.frames
-        .flatMap(_.classpath)
-        .map(_.getAbsolutePath)
-  )
+class JupyterScioContext(options: PipelineOptions,
+                         replJarPath: Path)
+                        (implicit interpApi: InterpAPI,
+                         runtimeApi: RuntimeAPI
+                        ) extends ScioContext(options,
+  replJarPath.toAbsolutePath.toString ::
+    runtimeApi.sess.frames
+      .flatMap(_.classpath)
+      .map(_.getAbsolutePath)
+) {
+
+  import JupyterScioContext._
 
   interpApi.load.onJarAdded {
     case Seq() => // just in case
-    case jars =>
-      addArtifacts(jars.map(_.getAbsolutePath).toList)
+    case _ =>
+      throw new RuntimeException("Cannot add jars after ScioContext Initialization")
+  }
+
+  def setGcpCredential(credentials: Credentials): Unit = {
+    // Save credentials for all future context creation
+    _gcpCredentials = Some(credentials)
+
+    require(_currentContext.isDefined && !_currentContext.get.isClosed,
+      "Scio Context is not yet defined or already closed")
+
+    _currentContext.get.options.as(classOf[DataflowPipelineOptions])
+      .setGcpCredential(credentials)
   }
 
-  def setGcpCredential(credential: Credential): Unit =
-    options.as(classOf[DataflowPipelineOptions]).setGcpCredential(credential)
   def setGcpCredential(path: String): Unit =
     setGcpCredential(
-      GoogleCredential.fromStream(new FileInputStream(new File(path))).createScoped(
+      GoogleCredentials.fromStream(new FileInputStream(new File(path))).createScoped(
        List(DataflowScopes.CLOUD_PLATFORM).asJava
       )
     )
 
-  def withGcpCredential(credential: Credential): this.type = {
-    setGcpCredential(credential)
-    this
-  }
-  def withGcpCredential(path: String): this.type = {
-    setGcpCredential(path)
-    this
-  }
-
   /** Enhanced version that dumps REPL session jar. */
   override def close(): ScioResult = {
+    // Some APIs exposed only for Jupyter will close Scio Context
+    // even if the user intends to use it further. A new Context would be created again.
+    logger.info("Closing Scio Context")
     runtimeApi.sess.sessionJarFile(replJarPath.toFile)
     super.close()
   }
@@ -70,38 +73,55 @@
 
 }
 
+/**
+ * Allow only one active Scio Context.
+ * Also manage the existing Scio Context, create a new one if the current has been closed.
+ */
 object JupyterScioContext {
 
-  def apply(args: (String, String)*)(implicit
-    interpApi: InterpAPI,
-    runtimeApi: RuntimeAPI
-  ): JupyterScioContext =
-    JupyterScioContext(
-      PipelineOptionsFactory.fromArgs(
-        args
-          .map { case (k, v) => s"--$k=$v" }
-          .toArray
-      ).as(classOf[DataflowPipelineOptions]),
-      nextReplJarPath()
-    )
-
-  def apply(options: PipelineOptions)(implicit
-    interpApi: InterpAPI,
-    runtimeApi: RuntimeAPI
-  ): JupyterScioContext =
-    JupyterScioContext(options, nextReplJarPath())
+  private val logger = LoggerFactory.getLogger(this.getClass)
+
+  private var _currentContext: Option[JupyterScioContext] = None
+  private var _pipelineOptions: Option[PipelineOptions] = None
+  private var _gcpCredentials: Option[Credentials] = None
+
+  /**
+   * Always returns a new Scio Context, and forgets the old context
+   */
+  def apply(args: (String, String)*)
+           (implicit interpApi: InterpAPI,
+            runtimeApi: RuntimeAPI): Unit = JupyterScioContext(
+    PipelineOptionsFactory.fromArgs(
+      args.map { case (k, v) => s"--$k=$v" }: _*
+    ).as(classOf[PipelineOptions])
+  )
 
 
-  def apply(
-    options: PipelineOptions,
-    replJarPath: Path
-  )(implicit
-    interpApi: InterpAPI,
-    runtimeApi: RuntimeAPI
-  ): JupyterScioContext =
-    new JupyterScioContext(options, replJarPath)
+  /**
+   * Always returns a new Scio Context, and forgets the old context
+   */
+  def apply(options: PipelineOptions)
+           (implicit interpApi: InterpAPI, runtimeApi: RuntimeAPI): Unit = {
+    _pipelineOptions = Some(options)
+    _currentContext = Some(new JupyterScioContext(options, nextReplJarPath()))
+    logger.info("ScioContext is accessible as sc")
+  }
+  /**
+   * Get Scio Context with currently defined options.
+   * Get new Scio Context if previous is closed.
+   *
+   * @return
+   */
+  def sc(implicit interpApi: InterpAPI, runtimeApi: RuntimeAPI): JupyterScioContext = {
+    if (_currentContext.isEmpty || _currentContext.get.isClosed) {
+      // Create a new Scio Context
+      JupyterScioContext(_pipelineOptions.getOrElse(PipelineOptionsFactory.create()))
+      _gcpCredentials.foreach(_currentContext.get.setGcpCredential)
+    }
+    _currentContext.get
+  }
 
 
-  def nextReplJarPath(prefix: String = "jupyter-scala-scio-", suffix: String = ".jar"): Path =
+  private def nextReplJarPath(prefix: String = "jupyter-scala-scio-", suffix: String = ".jar"): Path =
     Files.createTempFile(prefix, suffix)
-}
\ No newline at end of file
+}
diff --git a/scio/src/main/scala/jupyter/scio/package.scala b/scio/src/main/scala/jupyter/scio/package.scala
index 4c3501095..8fe3862dd 100644
--- a/scio/src/main/scala/jupyter/scio/package.scala
+++ b/scio/src/main/scala/jupyter/scio/package.scala
@@ -1,24 +1,73 @@
 package jupyter
 
-import _root_.scala.tools.nsc.interpreter.Helper
-
 import java.io.File
 
-import com.google.api.client.auth.oauth2.Credential
+import ammonite.repl.RuntimeAPI
+import ammonite.runtime.InterpAPI
+
+import com.google.auth.Credentials
 import com.spotify.scio.bigquery.BigQueryClient
+import com.spotify.scio.io.Tap
+import com.spotify.scio.values.SCollection
+
+import _root_.scala.tools.nsc.interpreter.Helper
 
 package object scio {
 
+  // Alias to reduce number of imports in notebook
   val JupyterScioContext: com.spotify.scio.jupyter.JupyterScioContext.type =
     com.spotify.scio.jupyter.JupyterScioContext
 
+  def sc(implicit interpApi: InterpAPI, runtimeApi: RuntimeAPI) = JupyterScioContext.sc
+
   def bigQueryClient(project: String): BigQueryClient =
     Helper.bigQueryClient(project)
 
-  def bigQueryClient(project: String, credential: Credential): BigQueryClient =
-    Helper.bigQueryClient(project, credential)
+  def bigQueryClient(project: String, credentials: Credentials): BigQueryClient =
+    Helper.bigQueryClient(project, credentials)
 
   def bigQueryClient(project: String, secretFile: File): BigQueryClient =
     Helper.bigQueryClient(project, secretFile)
 
+  // Helpers for interactive analysis
+  implicit class JupyterSCollection[T](self: SCollection[T]) {
+
+    /**
+     * Get first n elements of the SCollection as a String separated by \n
+     */
+    private def asString(numElements: Int): String =
+      self
+        .withName(s"Take $numElements elements")
+        .take(numElements)
+        .tap()
+        .value
+        .mkString("\n")
+
+    /**
+     * Closes the ScioContext and print elements on screen
+     */
+    def show(numElements: Int = 20): Unit = println(asString(numElements))
+
+    /**
+     * Closes the ScioContext and gets SCollection as a Tap
+     */
+    def tap(): Tap[T] = {
+      val mSelf = self.materialize
+      self.context.close().waitUntilDone()
+      mSelf.waitForResult() // Should be ready
+    }
+  }
+
+  implicit class JupyterTap[T](self: Tap[T]) {
+
+    /**
+     * Print the contents of a tap on screen
+     */
+    def show(numElements: Int = 20): Unit = println(self
+      .value
+      .take(numElements)
+      .mkString("\n")
+    )
+  }
+
 }
diff --git a/scio/src/main/scala/scala/tools/nsc/interpreter/Helper.scala b/scio/src/main/scala/scala/tools/nsc/interpreter/Helper.scala
index bb99ecf22..ba1bf3a63 100644
--- a/scio/src/main/scala/scala/tools/nsc/interpreter/Helper.scala
+++ b/scio/src/main/scala/scala/tools/nsc/interpreter/Helper.scala
@@ -2,7 +2,7 @@ package scala.tools.nsc.interpreter
 
 import java.io.File
 
-import com.google.api.client.auth.oauth2.Credential
+import com.google.auth.Credentials
 import com.spotify.scio.bigquery.BigQueryClient
 
 /**
@@ -19,8 +19,8 @@ object Helper {
     BigQueryClient(project, new File(secret))
   }
 
-  def bigQueryClient(project: String, credential: Credential): BigQueryClient =
-    BigQueryClient(project, credential)
+  def bigQueryClient(project: String, credentials: Credentials): BigQueryClient =
+    BigQueryClient(project, credentials)
 
   def bigQueryClient(project: String, secretFile: File): BigQueryClient =
     BigQueryClient(project, secretFile)
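
A sketch of how a jupyter-scala notebook session might exercise the helpers this patch introduces (the managed `sc`, `JupyterScioContext(...)`, and the `show` helper on `SCollection`). The runner, project, and staging values are the README's placeholders, and the `parallelize` input data is invented purely for illustration.

```scala
import $ivy.`org.jupyter-scala::scio:0.4.3`, $ivy.`org.apache.beam:beam-runners-direct-java:2.6.0`

import jupyter.scio._
import com.spotify.scio._

// Register the pipeline options once; the managed context is then reached through `sc`
JupyterScioContext(
  "runner" -> "DirectRunner",
  "project" -> "jupyter-scala",
  "stagingLocation" -> "gs://bucket/staging"
)

// Build a small SCollection and print its first elements.
// `show` runs the pipeline by closing the current ScioContext; the next
// use of `sc` creates a fresh context with the same options.
sc.parallelize(1 to 10)
  .map(_ * 2)
  .show(numElements = 5)
```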