Commit
Update DaskCluster to use Cluster as a base (#18)
* Update DaskCluster to use Cluster as a base

* Allow indirect references

* Add fallback if the cluster id can't be autodetected

* Add tests

* Refactor files a little
jacobtomlinson authored Nov 7, 2023
1 parent 8eb38a9 commit 441fb88
Showing 8 changed files with 63 additions and 29 deletions.
README.md (7 changes: 4 additions & 3 deletions)
@@ -17,13 +17,14 @@ To launch a Dask cluster on Databricks you need to create an [init script](https
 dask databricks run
 ```
 
-Then from your Databricks Notebook connect a Dask `Client` to the scheduler running on the Spark Driver Node.
+Then from your Databricks Notebook you can use the `DatabricksCluster` class to quickly connect a Dask `Client` to the scheduler running on the Spark Driver Node.
 
 ```python
 from dask.distributed import Client
-import os
+from dask_databricks import DatabricksCluster
 
-client = Client(f'{os.environ["SPARK_LOCAL_IP"]}:8786')
+cluster = DatabricksCluster()
+client = Client(cluster)
 ```
 
 Now you can submit work from your notebook to the multi-node Dask cluster.
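
For example, a minimal sketch of what submitting work could look like once the client is connected (the computation here is illustrative, not part of this commit):

```python
from dask.distributed import Client

from dask_databricks import DatabricksCluster

cluster = DatabricksCluster()
client = Client(cluster)

# Submit a simple function call to the Dask workers and block on the result
future = client.submit(sum, [1, 2, 3])
assert future.result() == 6
```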
dask_databricks/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@
 #
 # SPDX-License-Identifier: BSD-3
 
-from .databricks import DatabricksCluster
+from .databrickscluster import DatabricksCluster
dask_databricks/databricks/__init__.py (1 change: 0 additions & 1 deletion)
This file was deleted.
dask_databricks/databricks/databrickscluster.py (14 changes: 0 additions & 14 deletions)
This file was deleted.
dask_databricks/databrickscluster.py (36 changes: 36 additions & 0 deletions)
@@ -0,0 +1,36 @@
import os
import uuid

from distributed.deploy.cluster import Cluster
from distributed.core import rpc
from typing import Optional
from tornado.ioloop import IOLoop

# Databricks Notebooks inject the `spark` session variable
if 'spark' not in globals():
    spark = None

class DatabricksCluster(Cluster):
    """Connect to a Dask cluster deployed via Databricks."""

    def __init__(self,
                 loop: Optional[IOLoop] = None,
                 asynchronous: bool = False):
        self.spark_local_ip = os.getenv("SPARK_LOCAL_IP")
        if self.spark_local_ip is None:
            raise KeyError("Unable to find expected environment variable SPARK_LOCAL_IP. "
                           "Are you running this on a Databricks driver node?")
        try:
            # Use the Databricks cluster id as the Dask cluster name
            name = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
        except AttributeError:
            # Fall back to a generated name if the cluster id can't be autodetected
            name = "unknown-databricks-" + uuid.uuid4().hex[:10]
        super().__init__(name=name, loop=loop, asynchronous=asynchronous)

        if not self.called_from_running_loop:
            self._loop_runner.start()
            self.sync(self._start)

    async def _start(self):
        # Connect to the scheduler started on the driver by `dask databricks run`
        self.scheduler_comm = rpc(f'{self.spark_local_ip}:8786')
        await super()._start()
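
As a standalone illustration of the fallback naming above (a sketch, not part of the commit): outside Databricks `spark` is `None`, so `spark.conf` raises `AttributeError` and a random placeholder name is generated.

```python
import uuid

spark = None  # outside Databricks the injected `spark` session is absent

try:
    name = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
except AttributeError:
    # No Spark session available, so generate a unique placeholder name
    name = "unknown-databricks-" + uuid.uuid4().hex[:10]

print(name)  # e.g. unknown-databricks-3f2a9c1b7d
```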
dask_databricks/tests/__init__.py (4 changes: 0 additions & 4 deletions)
This file was deleted.
@@ -1,10 +1,16 @@
 import os
 
 import pytest
-from distributed.deploy import Cluster
+from dask.distributed import Client
+from distributed.deploy import Cluster, LocalCluster
 
 
-from dask_databricks import databricks
+from dask_databricks import DatabricksCluster
 
+@pytest.fixture(scope="session")
+def dask_cluster():
+    """Start a LocalCluster to simulate the cluster that would be started on Databricks."""
+    return LocalCluster(scheduler_port=8786)
+
 @pytest.fixture
 def remove_spark_local_ip():
@@ -18,7 +24,7 @@ def remove_spark_local_ip():
 @pytest.fixture
 def set_spark_local_ip():
     original_spark_local_ip = os.getenv("SPARK_LOCAL_IP")
-    os.environ["SPARK_LOCAL_IP"] = "1.1.1"
+    os.environ["SPARK_LOCAL_IP"] = "127.0.0.1"
     yield None
     if original_spark_local_ip:
         os.environ["SPARK_LOCAL_IP"] = original_spark_local_ip
@@ -27,8 +33,15 @@ def set_spark_local_ip():

 def test_databricks_cluster_raises_key_error_when_initialised_outside_of_databricks(remove_spark_local_ip):
     with pytest.raises(KeyError):
-        cluster = databricks.DatabricksCluster()
+        DatabricksCluster()
 
-def test_databricks_cluster_creates_local_cluster_object(set_spark_local_ip):
-    cluster = databricks.DatabricksCluster()
+def test_databricks_cluster_create(set_spark_local_ip, dask_cluster):
+    cluster = DatabricksCluster()
     assert isinstance(cluster, Cluster)
+
+
+def test_databricks_cluster_create_client(set_spark_local_ip, dask_cluster):
+    cluster = DatabricksCluster()
+    client = Client(cluster)
+    assert isinstance(client, Client)
+    assert client.submit(sum, (10, 1)).result() == 11
pyproject.toml (3 changes: 3 additions & 0 deletions)
@@ -43,6 +43,9 @@ rye = { dev-dependencies = [
"pytest>=7.4.3",
] }

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.version]
path = "dask_databricks/__about__.py"

