-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update DaskCluster to use Cluster as a base (#18)
* Update DaskCluster to use Cluster as a base * Allow indirect references * Add fallback if can't autodetect cluster id * Add tests * Refactor files a little
- Loading branch information
1 parent
8eb38a9
commit 441fb88
Showing
8 changed files
with
63 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import os | ||
import uuid | ||
|
||
from distributed.deploy.cluster import Cluster | ||
from distributed.core import rpc | ||
from typing import Optional | ||
from tornado.ioloop import IOLoop | ||
|
||
# Databricks notebooks inject a `spark` session object into the global
# namespace; probe for it and fall back to None outside that environment.
try:
    spark  # noqa: B018 -- existence probe for the injected session
except NameError:
    spark = None
|
||
class DatabricksCluster(Cluster):
    """Connect to a Dask cluster deployed via Databricks.

    Expects to run on a Databricks driver node, where a Dask scheduler is
    already listening at ``$SPARK_LOCAL_IP:<scheduler_port>``.

    Parameters
    ----------
    loop:
        Tornado ``IOLoop`` to use; when omitted the base ``Cluster``
        machinery manages its own loop runner.
    asynchronous:
        Set True when using the cluster from ``async``/``await`` code.
    scheduler_port:
        Port the Dask scheduler listens on. Defaults to 8786, Dask's
        standard scheduler port (previously hard-coded).

    Raises
    ------
    KeyError
        If the ``SPARK_LOCAL_IP`` environment variable is not set, i.e.
        this is not running on a Databricks driver node.
    """

    def __init__(self,
                 loop: Optional[IOLoop] = None,
                 asynchronous: bool = False,
                 scheduler_port: int = 8786):
        self.spark_local_ip = os.getenv("SPARK_LOCAL_IP")
        if self.spark_local_ip is None:
            raise KeyError("Unable to find expected environment variable SPARK_LOCAL_IP. "
                           "Are you running this on a Databricks driver node?")
        self.scheduler_port = scheduler_port
        try:
            # `spark` is injected by Databricks notebooks; outside that
            # environment it is None (see module top), so accessing `.conf`
            # raises AttributeError and we fall through to the fallback name.
            name = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
        except AttributeError:
            # Cluster id could not be autodetected; use a unique placeholder.
            name = "unknown-databricks-" + uuid.uuid4().hex[:10]
        super().__init__(name=name, loop=loop, asynchronous=asynchronous)

        if not self.called_from_running_loop:
            # Synchronous usage: spin up the loop runner and connect eagerly.
            self._loop_runner.start()
            self.sync(self._start)

    async def _start(self):
        # Attach to the already-running scheduler on the driver node rather
        # than launching one.
        self.scheduler_comm = rpc(f'{self.spark_local_ip}:{self.scheduler_port}')
        await super()._start()
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters