From 4d112726906f2f2d90c55d6f23161c9fe8c8ede5 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Wed, 8 Nov 2023 10:12:59 +0000 Subject: [PATCH] Format with black and run isort --- dask_databricks/__init__.py | 2 +- dask_databricks/cli.py | 19 +++++++++++-------- dask_databricks/databrickscluster.py | 23 ++++++++++++++--------- dask_databricks/tests/test_databricks.py | 6 +++++- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/dask_databricks/__init__.py b/dask_databricks/__init__.py index 04d7674..916d210 100644 --- a/dask_databricks/__init__.py +++ b/dask_databricks/__init__.py @@ -2,4 +2,4 @@ # # SPDX-License-Identifier: BSD-3 -from .databrickscluster import DatabricksCluster, get_client +from .databrickscluster import DatabricksCluster, get_client # noqa diff --git a/dask_databricks/cli.py b/dask_databricks/cli.py index 7849d06..09badd3 100644 --- a/dask_databricks/cli.py +++ b/dask_databricks/cli.py @@ -1,4 +1,3 @@ -import click import logging import os import socket @@ -6,18 +5,20 @@ import sys import time +import click from rich.logging import RichHandler + def get_logger(): - logging.basicConfig( - level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()] - ) + logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]) return logging.getLogger("dask_databricks") + @click.group(name="databricks") def main(): """Tools to launch Dask on Databricks.""" + @main.command() def run(): """Run Dask processes on a Databricks cluster.""" @@ -25,12 +26,14 @@ def run(): log.info("Setting up Dask on a Databricks cluster.") - DB_IS_DRIVER = os.getenv('DB_IS_DRIVER') - DB_DRIVER_IP = os.getenv('DB_DRIVER_IP') + DB_IS_DRIVER = os.getenv("DB_IS_DRIVER") + DB_DRIVER_IP = os.getenv("DB_DRIVER_IP") if DB_DRIVER_IP is None or DB_IS_DRIVER is None: - log.error("Unable to find expected environment variables DB_IS_DRIVER and DB_DRIVER_IP. " - "Are you running this command on a Databricks multi-node cluster?") + log.error( + "Unable to find expected environment variables DB_IS_DRIVER and DB_DRIVER_IP. " + "Are you running this command on a Databricks multi-node cluster?" + ) sys.exit(1) if DB_IS_DRIVER == "TRUE": diff --git a/dask_databricks/databrickscluster.py b/dask_databricks/databrickscluster.py index 69a1b44..7a83bc5 100644 --- a/dask_databricks/databrickscluster.py +++ b/dask_databricks/databrickscluster.py @@ -1,25 +1,30 @@ import os import uuid +from typing import Optional -from distributed.deploy.cluster import Cluster from distributed.core import rpc -from typing import Optional +from distributed.deploy.cluster import Cluster from tornado.ioloop import IOLoop # Databricks Notebooks injects the `spark` session variable -if 'spark' not in globals(): +if "spark" not in globals(): spark = None + class DatabricksCluster(Cluster): """Connect to a Dask cluster deployed via databricks.""" - def __init__(self, + def __init__( + self, loop: Optional[IOLoop] = None, - asynchronous: bool = False,): + asynchronous: bool = False, + ): self.spark_local_ip = os.getenv("SPARK_LOCAL_IP") if self.spark_local_ip is None: - raise KeyError("Unable to find expected environment variable SPARK_LOCAL_IP. " - "Are you running this on a Databricks driver node?") + raise KeyError( + "Unable to find expected environment variable SPARK_LOCAL_IP. " + "Are you running this on a Databricks driver node?" + ) try: name = spark.conf.get("spark.databricks.clusterUsageTags.clusterId") except AttributeError: @@ -30,11 +35,11 @@ def __init__(self, self._loop_runner.start() self.sync(self._start) - async def _start(self): - self.scheduler_comm = rpc(f'{self.spark_local_ip}:8786') + self.scheduler_comm = rpc(f"{self.spark_local_ip}:8786") await super()._start() + def get_client(): """Get a Dask client connected to a Databricks cluster.""" return DatabricksCluster().get_client() diff --git a/dask_databricks/tests/test_databricks.py b/dask_databricks/tests/test_databricks.py index 70e4052..62e56c8 100644 --- a/dask_databricks/tests/test_databricks.py +++ b/dask_databricks/tests/test_databricks.py @@ -4,14 +4,15 @@ from dask.distributed import Client from distributed.deploy import Cluster, LocalCluster - from dask_databricks import DatabricksCluster, get_client + @pytest.fixture(scope="session") def dask_cluster(): """Start a LocalCluster to simulate the cluster that would be started on Databricks.""" return LocalCluster(scheduler_port=8786) + @pytest.fixture def remove_spark_local_ip(): original_spark_local_ip = os.getenv("SPARK_LOCAL_IP") @@ -21,6 +22,7 @@ def remove_spark_local_ip(): if original_spark_local_ip: os.environ["SPARK_LOCAL_IP"] = original_spark_local_ip + @pytest.fixture def set_spark_local_ip(): original_spark_local_ip = os.getenv("SPARK_LOCAL_IP") @@ -31,10 +33,12 @@ def set_spark_local_ip(): else: del os.environ["SPARK_LOCAL_IP"] + def test_databricks_cluster_raises_key_error_when_initialised_outside_of_databricks(remove_spark_local_ip): with pytest.raises(KeyError): DatabricksCluster() + def test_databricks_cluster_create(set_spark_local_ip, dask_cluster): cluster = DatabricksCluster() assert isinstance(cluster, Cluster)