Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Scio on Jupyter #226

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,19 +152,23 @@ Status: POC
Use like

```scala
import $ivy.`org.jupyter-scala::scio:0.4.2`
import $ivy.`org.jupyter-scala::scio:0.4.3`, $ivy.`org.apache.beam:beam-runners-direct-java:2.6.0`

import jupyter.scio._

import com.spotify.scio._
import com.spotify.scio.accumulators._
import com.spotify.scio.bigquery._
import com.spotify.scio.experimental._

val sc = JupyterScioContext(
"runner" -> "DataflowPipelineRunner",
// Define JupyterScioContext
JupyterScioContext(
"runner" -> "DirectRunner", // DirectRunner or DataflowRunner
"project" -> "jupyter-scala",
"stagingLocation" -> "gs://bucket/staging"
).withGcpCredential("/path-to/credentials.json") // alternatively, set the env var GOOGLE_APPLICATION_CREDENTIALS to that path
)

sc.withGcpCredential("/path-to/credentials.json") // alternatively, set the env var GOOGLE_APPLICATION_CREDENTIALS to that path

// Access JupyterScioContext with `sc`
```

### Scalding
Expand Down
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ lazy val scio = project
Deps.kantanCsv,
Deps.macroParadise,
Deps.scioCore,
Deps.scioExtra
Deps.scioExtra,
Deps.dataflowRunner
)
},
disableScalaVersion("2.12")
Expand Down
4 changes: 3 additions & 1 deletion project/Deps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ object Deps {
def ammonium = "0.8.3-1"
def flink = "1.1.3"
def jupyterKernel = "0.4.1"
def scio = "0.2.12"
def scio = "0.6.1"
def beam = "2.6.0"
}

def ammonium = ("org.jupyter-scala" % "ammonite" % Versions.ammonium).cross(CrossVersion.full)
Expand All @@ -35,6 +36,7 @@ object Deps {
def scalaXml = "org.scala-lang.modules" %% "scala-xml" % "1.0.6"
def scioCore = "com.spotify" %% "scio-core" % Versions.scio
def scioExtra = "com.spotify" %% "scio-extra" % Versions.scio
def dataflowRunner = "org.apache.beam" % "beam-runners-google-cloud-dataflow-java" % Versions.beam
def slf4jSimple = "org.slf4j" % "slf4j-simple" % "1.7.24"
def sparkSql1 = "org.apache.spark" %% "spark-sql" % "1.3.1"
def sparkSql = "org.apache.spark" %% "spark-sql" % "2.0.2"
Expand Down
138 changes: 79 additions & 59 deletions scio/src/main/scala/com/spotify/scio/jupyter/JupyterScioContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,59 +6,62 @@ import java.nio.file.{Files, Path}
import ammonite.repl.RuntimeAPI
import ammonite.runtime.InterpAPI

import com.google.api.client.auth.oauth2.Credential
import com.google.api.client.googleapis.auth.oauth2.GoogleCredential
import com.google.api.services.dataflow.DataflowScopes
import com.google.cloud.dataflow.sdk.options.{DataflowPipelineOptions, PipelineOptions, PipelineOptionsFactory}
import com.google.auth.Credentials
import com.google.auth.oauth2.GoogleCredentials
import com.spotify.scio.{ScioContext, ScioResult}
import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions
import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._

// similar to com.spotify.scio.repl.ReplScioContext

// in the com.spotify.scio namespace to access private[scio] things

class JupyterScioContext(
options: PipelineOptions,
replJarPath: Path
)(implicit
interpApi: InterpAPI,
runtimeApi: RuntimeAPI
) extends ScioContext(options, Nil) {

addArtifacts(
replJarPath.toAbsolutePath.toString ::
runtimeApi.sess.frames
.flatMap(_.classpath)
.map(_.getAbsolutePath)
)
class JupyterScioContext(options: PipelineOptions,
replJarPath: Path)
(implicit interpApi: InterpAPI,
runtimeApi: RuntimeAPI
) extends ScioContext(options,
replJarPath.toAbsolutePath.toString ::
runtimeApi.sess.frames
.flatMap(_.classpath)
.map(_.getAbsolutePath)
) {

import JupyterScioContext._

interpApi.load.onJarAdded {
case Seq() => // just in case
case jars =>
addArtifacts(jars.map(_.getAbsolutePath).toList)
case _ =>
throw new RuntimeException("Cannot add jars after ScioContext Initialization")
}

def setGcpCredential(credentials: Credentials): Unit = {
// Save credentials for all future context creation
_gcpCredentials = Some(credentials)

require(_currentContext.isDefined && !_currentContext.get.isClosed,
"Scio Context is not yet defined or already closed")

_currentContext.get.options.as(classOf[DataflowPipelineOptions])
.setGcpCredential(credentials)
}

def setGcpCredential(credential: Credential): Unit =
options.as(classOf[DataflowPipelineOptions]).setGcpCredential(credential)
def setGcpCredential(path: String): Unit =
setGcpCredential(
GoogleCredential.fromStream(new FileInputStream(new File(path))).createScoped(
GoogleCredentials.fromStream(new FileInputStream(new File(path))).createScoped(
List(DataflowScopes.CLOUD_PLATFORM).asJava
)
)

def withGcpCredential(credential: Credential): this.type = {
setGcpCredential(credential)
this
}
def withGcpCredential(path: String): this.type = {
setGcpCredential(path)
this
}

/** Enhanced version that dumps REPL session jar. */
override def close(): ScioResult = {
// Some APIs exposed only for Jupyter will close Scio Context
// even if the user intends to use it further. A new Context would be created again.
logger.info("Closing Scio Context")
runtimeApi.sess.sessionJarFile(replJarPath.toFile)
super.close()
}
Expand All @@ -70,38 +73,55 @@ class JupyterScioContext(

}

/**
* Allow only one active Scio Context.
* Also manage the existing Scio Context, create a new one if the current has been closed.
*/
object JupyterScioContext {

def apply(args: (String, String)*)(implicit
interpApi: InterpAPI,
runtimeApi: RuntimeAPI
): JupyterScioContext =
JupyterScioContext(
PipelineOptionsFactory.fromArgs(
args
.map { case (k, v) => s"--$k=$v" }
.toArray
).as(classOf[DataflowPipelineOptions]),
nextReplJarPath()
)

def apply(options: PipelineOptions)(implicit
interpApi: InterpAPI,
runtimeApi: RuntimeAPI
): JupyterScioContext =
JupyterScioContext(options, nextReplJarPath())
private val logger = LoggerFactory.getLogger(this.getClass)

private var _currentContext: Option[JupyterScioContext] = None
private var _pipelineOptions: Option[PipelineOptions] = None
private var _gcpCredentials: Option[Credentials] = None

/**
* Always returns a new Scio Context, and forgets the old context
*/
def apply(args: (String, String)*)
(implicit interpApi: InterpAPI,
runtimeApi: RuntimeAPI): Unit = JupyterScioContext(
PipelineOptionsFactory.fromArgs(
args.map { case (k, v) => s"--$k=$v" }: _*
).as(classOf[PipelineOptions])
)

def apply(
options: PipelineOptions,
replJarPath: Path
)(implicit
interpApi: InterpAPI,
runtimeApi: RuntimeAPI
): JupyterScioContext =
new JupyterScioContext(options, replJarPath)
/**
* Always returns a new Scio Context, and forgets the old context
*/
def apply(options: PipelineOptions)
(implicit interpApi: InterpAPI, runtimeApi: RuntimeAPI): Unit = {
_pipelineOptions = Some(options)
_currentContext = Some(new JupyterScioContext(options, nextReplJarPath()))
logger.info("ScioContext is accessible as sc")
}

/**
* Get Scio Context with currently defined options.
* Get new Scio Context if previous is closed.
*
* @return
*/
def sc(implicit interpApi: InterpAPI, runtimeApi: RuntimeAPI): JupyterScioContext = {
if (_currentContext.isEmpty || _currentContext.get.isClosed) {
// Create a new Scio Context
JupyterScioContext(_pipelineOptions.getOrElse(PipelineOptionsFactory.create()))
_gcpCredentials.foreach(_currentContext.get.setGcpCredential)
}
_currentContext.get
}

def nextReplJarPath(prefix: String = "jupyter-scala-scio-", suffix: String = ".jar"): Path =
private def nextReplJarPath(prefix: String = "jupyter-scala-scio-", suffix: String = ".jar"): Path =
Files.createTempFile(prefix, suffix)

}
}
59 changes: 54 additions & 5 deletions scio/src/main/scala/jupyter/scio/package.scala
Original file line number Diff line number Diff line change
@@ -1,24 +1,73 @@
package jupyter

import _root_.scala.tools.nsc.interpreter.Helper

import java.io.File

import com.google.api.client.auth.oauth2.Credential
import ammonite.repl.RuntimeAPI
import ammonite.runtime.InterpAPI

import com.google.auth.Credentials
import com.spotify.scio.bigquery.BigQueryClient
import com.spotify.scio.io.Tap
import com.spotify.scio.values.SCollection

import _root_.scala.tools.nsc.interpreter.Helper

package object scio {

// Alias to reduce number of imports in notebook
val JupyterScioContext: com.spotify.scio.jupyter.JupyterScioContext.type =
com.spotify.scio.jupyter.JupyterScioContext

def sc(implicit interpApi: InterpAPI, runtimeApi: RuntimeAPI) = JupyterScioContext.sc

def bigQueryClient(project: String): BigQueryClient =
Helper.bigQueryClient(project)

def bigQueryClient(project: String, credential: Credential): BigQueryClient =
Helper.bigQueryClient(project, credential)
def bigQueryClient(project: String, credentials: Credentials): BigQueryClient =
Helper.bigQueryClient(project, credentials)

def bigQueryClient(project: String, secretFile: File): BigQueryClient =
Helper.bigQueryClient(project, secretFile)

// Helpers for interactive analysis
implicit class JupyterSCollection[T](self: SCollection[T]) {

/**
* Get first n elements of the SCollection as a String separated by \n
*/
private def asString(numElements: Int): String =
self
.withName(s"Take $numElements elements")
.take(numElements)
.tap()
.value
.mkString("\n")

/**
* Closes the ScioContext and print elements on screen
*/
def show(numElements: Int = 20): Unit = println(asString(numElements))

/**
* Closes the ScioContext and gets SCollection as a Tap
*/
def tap(): Tap[T] = {
val mSelf = self.materialize
self.context.close().waitUntilDone()
mSelf.waitForResult() // Should be ready
}
}

implicit class JupyterTap[T](self: Tap[T]) {

/**
* Print the contents of a tap on screen
*/
def show(numElements: Int = 20): Unit = println(self
.value
.take(numElements)
.mkString("\n")
)
}

}
6 changes: 3 additions & 3 deletions scio/src/main/scala/scala/tools/nsc/interpreter/Helper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package scala.tools.nsc.interpreter

import java.io.File

import com.google.api.client.auth.oauth2.Credential
import com.google.auth.Credentials
import com.spotify.scio.bigquery.BigQueryClient

/**
Expand All @@ -19,8 +19,8 @@ object Helper {
BigQueryClient(project, new File(secret))
}

def bigQueryClient(project: String, credential: Credential): BigQueryClient =
BigQueryClient(project, credential)
def bigQueryClient(project: String, credentials: Credentials): BigQueryClient =
BigQueryClient(project, credentials)

def bigQueryClient(project: String, secretFile: File): BigQueryClient =
BigQueryClient(project, secretFile)
Expand Down