Commit: Format with black and run isort
jacobtomlinson committed Nov 8, 2023
1 parent 536a751 commit 4d11272

Showing 4 changed files with 31 additions and 19 deletions.
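Before the per-file diffs, a quick orientation: isort regroups imports (standard library first, then third-party, separated by a blank line) and black reflows statements to a single line length. A hedged sketch of the same transformation through the tools' Python APIs — line_length=120 is an assumption inferred from the long reflowed lines below, and in practice both tools are usually run as CLIs:

    import black
    import isort

    # A snippet shaped like the old cli.py header: third-party "click"
    # sits above the standard-library imports.
    src = "import click\nimport logging\nimport os\n"

    regrouped = isort.code(src)  # stdlib first, blank line, then click
    formatted = black.format_str(regrouped, mode=black.Mode(line_length=120))
    print(formatted)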
2 changes: 1 addition & 1 deletion dask_databricks/__init__.py
@@ -2,4 +2,4 @@
 #
 # SPDX-License-Identifier: BSD-3

-from .databrickscluster import DatabricksCluster, get_client
+from .databrickscluster import DatabricksCluster, get_client  # noqa
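(The appended `# noqa` tells flake8-style linters to skip this line, so the re-exported DatabricksCluster and get_client names are not flagged as unused imports.)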
19 changes: 11 additions & 8 deletions dask_databricks/cli.py
@@ -1,36 +1,39 @@
-import click
 import logging
 import os
 import socket
 import subprocess
 import sys
 import time

+import click
 from rich.logging import RichHandler

+
 def get_logger():
-    logging.basicConfig(
-        level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
-    )
+    logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])
     return logging.getLogger("dask_databricks")

+
 @click.group(name="databricks")
 def main():
     """Tools to launch Dask on Databricks."""

+
 @main.command()
 def run():
     """Run Dask processes on a Databricks cluster."""
     log = get_logger()

     log.info("Setting up Dask on a Databricks cluster.")

-    DB_IS_DRIVER = os.getenv('DB_IS_DRIVER')
-    DB_DRIVER_IP = os.getenv('DB_DRIVER_IP')
+    DB_IS_DRIVER = os.getenv("DB_IS_DRIVER")
+    DB_DRIVER_IP = os.getenv("DB_DRIVER_IP")

     if DB_DRIVER_IP is None or DB_IS_DRIVER is None:
-        log.error("Unable to find expected environment variables DB_IS_DRIVER and DB_DRIVER_IP. "
-                  "Are you running this command on a Databricks multi-node cluster?")
+        log.error(
+            "Unable to find expected environment variables DB_IS_DRIVER and DB_DRIVER_IP. "
+            "Are you running this command on a Databricks multi-node cluster?"
+        )
         sys.exit(1)

     if DB_IS_DRIVER == "TRUE":
…
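The diff view truncates here, just after the driver check. As a hedged sketch only — not the verbatim continuation of this file — the branch plausibly starts a scheduler on the driver and points every other node at it as a worker, using the same port 8786 that databrickscluster.py dials:

    import os
    import subprocess

    DB_IS_DRIVER = os.getenv("DB_IS_DRIVER")
    DB_DRIVER_IP = os.getenv("DB_DRIVER_IP")

    if DB_IS_DRIVER == "TRUE":
        # Driver node: host the Dask scheduler (default port 8786).
        subprocess.Popen(["dask", "scheduler"])
    else:
        # Worker node: connect back to the scheduler on the driver.
        subprocess.Popen(["dask", "worker", f"tcp://{DB_DRIVER_IP}:8786"])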
23 changes: 14 additions & 9 deletions dask_databricks/databrickscluster.py
@@ -1,25 +1,30 @@
 import os
 import uuid
+from typing import Optional

-from distributed.deploy.cluster import Cluster
 from distributed.core import rpc
-from typing import Optional
+from distributed.deploy.cluster import Cluster
 from tornado.ioloop import IOLoop

 # Databricks Notebooks injects the `spark` session variable
-if 'spark' not in globals():
+if "spark" not in globals():
     spark = None

+
 class DatabricksCluster(Cluster):
     """Connect to a Dask cluster deployed via databricks."""

-    def __init__(self,
+    def __init__(
+        self,
         loop: Optional[IOLoop] = None,
-        asynchronous: bool = False,):
+        asynchronous: bool = False,
+    ):
         self.spark_local_ip = os.getenv("SPARK_LOCAL_IP")
         if self.spark_local_ip is None:
-            raise KeyError("Unable to find expected environment variable SPARK_LOCAL_IP. "
-                           "Are you running this on a Databricks driver node?")
+            raise KeyError(
+                "Unable to find expected environment variable SPARK_LOCAL_IP. "
+                "Are you running this on a Databricks driver node?"
+            )
         try:
             name = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
         except AttributeError:
@@ -30,11 +35,11 @@ def __init__(self,
         self._loop_runner.start()
         self.sync(self._start)

-
     async def _start(self):
-        self.scheduler_comm = rpc(f'{self.spark_local_ip}:8786')
+        self.scheduler_comm = rpc(f"{self.spark_local_ip}:8786")
         await super()._start()

+
 def get_client():
     """Get a Dask client connected to a Databricks cluster."""
     return DatabricksCluster().get_client()
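get_client is the public entry point shown above; called on a Databricks driver it returns a distributed.Client wired to <SPARK_LOCAL_IP>:8786. A minimal usage sketch (the submitted computation is illustrative):

    import dask_databricks

    client = dask_databricks.get_client()

    # Work now runs on the Dask workers attached to the cluster.
    future = client.submit(sum, [1, 2, 3])
    print(future.result())  # -> 6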
6 changes: 5 additions & 1 deletion dask_databricks/tests/test_databricks.py
@@ -4,14 +4,15 @@
 from dask.distributed import Client
 from distributed.deploy import Cluster, LocalCluster

-
 from dask_databricks import DatabricksCluster, get_client

+
 @pytest.fixture(scope="session")
 def dask_cluster():
     """Start a LocalCluster to simulate the cluster that would be started on Databricks."""
     return LocalCluster(scheduler_port=8786)

+
 @pytest.fixture
 def remove_spark_local_ip():
     original_spark_local_ip = os.getenv("SPARK_LOCAL_IP")
@@ -21,6 +22,7 @@ def remove_spark_local_ip():
     if original_spark_local_ip:
         os.environ["SPARK_LOCAL_IP"] = original_spark_local_ip

+
 @pytest.fixture
 def set_spark_local_ip():
     original_spark_local_ip = os.getenv("SPARK_LOCAL_IP")
@@ -31,10 +33,12 @@ def set_spark_local_ip():
     else:
         del os.environ["SPARK_LOCAL_IP"]

+
 def test_databricks_cluster_raises_key_error_when_initialised_outside_of_databricks(remove_spark_local_ip):
     with pytest.raises(KeyError):
         DatabricksCluster()

+
 def test_databricks_cluster_create(set_spark_local_ip, dask_cluster):
     cluster = DatabricksCluster()
     assert isinstance(cluster, Cluster)
…
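These tests hang together because LocalCluster(scheduler_port=8786) listens on exactly the port DatabricksCluster._start dials, so the local scheduler stands in for the one the `run` command would launch on a real driver. A hedged sketch of the same handshake outside pytest (the loopback address is an assumption):

    import os

    from distributed import LocalCluster

    from dask_databricks import DatabricksCluster

    # Impersonate the Databricks driver: a scheduler on port 8786 plus the
    # environment variable DatabricksCluster expects.
    local = LocalCluster(scheduler_port=8786)
    os.environ["SPARK_LOCAL_IP"] = "127.0.0.1"

    cluster = DatabricksCluster()  # dials 127.0.0.1:8786
    client = cluster.get_client()
    print(client.scheduler.address)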
