diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
new file mode 100644
index 0000000..01e5b19
--- /dev/null
+++ b/.github/workflows/test.yaml
@@ -0,0 +1,27 @@
+name: "Test"
+on:
+  pull_request:
+  push:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    timeout-minutes: 45
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.9", "3.10", "3.11"]
+
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install hatch
+        run: pipx install hatch
+      - name: Run tests
+        run: hatch run test:run
diff --git a/.gitignore b/.gitignore
index 68bc17f..8b276f9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,3 +158,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+_version.py
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..5f0e401
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,18 @@
+# NOTE: autoupdate does not pick up flake8-bugbear since it is a transitive
+# dependency. Make sure to update flake8-bugbear manually on a regular basis.
+repos:
+  - repo: https://github.com/psf/black
+    rev: 23.10.1
+    hooks:
+      - id: black
+        language_version: python3
+        exclude: versioneer.py
+        args:
+          - --target-version=py39
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: "v0.1.3"
+    hooks:
+      - id: ruff
+        language_version: python3
+        args: [--fix, --exit-non-zero-on-fix]
diff --git a/README.md b/README.md
index 2cb357a..f3dca96 100644
--- a/README.md
+++ b/README.md
@@ -17,14 +17,12 @@ To launch a Dask cluster on Databricks you need to create an [init script](https
 dask databricks run
 ```
 
-Then from your Databricks Notebook you can use the `DatabricksCluster` class to quickly connect a Dask `Client` to the scheduler running on the Spark Driver Node.
+Then from your Databricks Notebook you can quickly connect a Dask `Client` to the scheduler running on the Spark Driver Node.
 
 ```python
-from dask.distributed import Client
-from dask_databricks import DatabricksCluster
+import dask_databricks
 
-cluster = DatabricksCluster()
-client = Client(cluster)
+client = dask_databricks.get_client()
 ```
 
 Now you can submit work from your notebook to the multi-node Dask cluster.
@@ -36,3 +34,16 @@ def inc(x):
 x = client.submit(inc, 10)
 x.result()
 ```
+
+### Dashboard
+
+You can access the [Dask dashboard](https://docs.dask.org/en/latest/dashboard.html) via the Databricks driver-node proxy. The link can be found in the `Client` or `DatabricksCluster` repr, or via `client.dashboard_link`.
+
+```python
+>>> print(client.dashboard_link)
+https://dbc-dp-xxxx.cloud.databricks.com/driver-proxy/o/xxxx/xx-xxx-xxxx/8087/status
+```
+
+![](https://user-images.githubusercontent.com/1610850/281442274-450d41c6-2eb6-42a1-8de6-c4a1a1b84193.png)
+
+![](https://user-images.githubusercontent.com/1610850/281441285-9b84d5f1-d58a-45dc-9354-7385e1599d1f.png)
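To make the README change above concrete, here is a minimal sketch of what a notebook cell looks like once the init script has started the cluster. It only uses the `dask_databricks.get_client()` helper and the `Client` API shown in the diff; the `inc` function is the same illustrative example the README already uses.

```python
import dask_databricks

# Connect a Dask Client to the scheduler that `dask databricks run`
# started on the Spark driver node.
client = dask_databricks.get_client()


def inc(x):
    return x + 1


# Submit work to the multi-node Dask cluster and pull the result back.
future = client.submit(inc, 10)
print(future.result())  # 11

# The dashboard is served through the Databricks driver proxy on port 8087.
print(client.dashboard_link)
```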
diff --git a/dask_databricks/__about__.py b/dask_databricks/__about__.py
deleted file mode 100644
index 2d6adc3..0000000
--- a/dask_databricks/__about__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# SPDX-FileCopyrightText: 2023-present Dask Developers
-#
-# SPDX-License-Identifier: BSD-3
-__version__ = "0.0.1"
diff --git a/dask_databricks/__init__.py b/dask_databricks/__init__.py
index e2e5750..bd6f92a 100644
--- a/dask_databricks/__init__.py
+++ b/dask_databricks/__init__.py
@@ -2,4 +2,22 @@
 #
 # SPDX-License-Identifier: BSD-3
 
-from .databrickscluster import DatabricksCluster
+from .databrickscluster import DatabricksCluster, get_client  # noqa
+
+# Define the variable '__version__':
+try:
+    # If setuptools_scm is installed (e.g. in a development environment with
+    # an editable install), then use it to determine the version dynamically.
+    from setuptools_scm import get_version
+
+    # This will fail with LookupError if the package is not installed in
+    # editable mode or if Git is not installed.
+    __version__ = get_version(root="..", relative_to=__file__)
+except (ImportError, LookupError):
+    # As a fallback, use the version that is hard-coded in the file.
+    try:
+        from dask_databricks._version import __version__  # noqa: F401
+    except ModuleNotFoundError:
+        # The user is probably trying to run this without having installed
+        # the package, so complain.
+        raise RuntimeError("dask-databricks is not correctly installed. " "Please install it with pip.")
diff --git a/dask_databricks/cli.py b/dask_databricks/cli.py
index d5548b1..e94b830 100644
--- a/dask_databricks/cli.py
+++ b/dask_databricks/cli.py
@@ -1,4 +1,3 @@
-import click
 import logging
 import os
 import socket
@@ -6,18 +5,20 @@
 import sys
 import time
 
+import click
 from rich.logging import RichHandler
 
+
 def get_logger():
-    logging.basicConfig(
-        level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
-    )
+    logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])
     return logging.getLogger("dask_databricks")
 
+
 @click.group(name="databricks")
 def main():
     """Tools to launch Dask on Databricks."""
 
+
 @main.command()
 @click.option('--worker-command', help='Custom worker command')
 @click.option('--worker-args', help='Additional worker arguments as a single string')
@@ -27,17 +28,19 @@ def run(worker_command, worker_args):
 
     log.info("Setting up Dask on a Databricks cluster.")
 
-    DB_IS_DRIVER = os.getenv('DB_IS_DRIVER')
-    DB_DRIVER_IP = os.getenv('DB_DRIVER_IP')
+    DB_IS_DRIVER = os.getenv("DB_IS_DRIVER")
+    DB_DRIVER_IP = os.getenv("DB_DRIVER_IP")
 
     if DB_DRIVER_IP is None or DB_IS_DRIVER is None:
-        log.error("Unable to find expected environment variables DB_IS_DRIVER and DB_DRIVER_IP. "
-                  "Are you running this command on a Databricks multi-node cluster?")
+        log.error(
+            "Unable to find expected environment variables DB_IS_DRIVER and DB_DRIVER_IP. "
+            "Are you running this command on a Databricks multi-node cluster?"
+        )
         sys.exit(1)
 
     if DB_IS_DRIVER == "TRUE":
         log.info("This node is the Dask scheduler.")
-        subprocess.Popen(["dask", "scheduler"])
+        subprocess.Popen(["dask", "scheduler", "--dashboard-address", ":8787,:8087"])
     else:
         log.info("This node is a Dask worker.")
         log.info(f"Connecting to Dask scheduler at {DB_DRIVER_IP}:8786")
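The behavioural change in `cli.py` is easiest to see with the CLI plumbing stripped away: the same scheduler/worker split on `DB_IS_DRIVER`, but the scheduler now serves its dashboard on both the default port 8787 and port 8087 so it can be reached through the Databricks driver proxy. A condensed sketch of that logic follows (not the actual module; the exact worker invocation is an assumption, since the worker `Popen` line sits outside the hunk shown above):

```python
import os
import subprocess
import sys

# Databricks sets these environment variables on every node of a multi-node cluster.
db_is_driver = os.getenv("DB_IS_DRIVER")
db_driver_ip = os.getenv("DB_DRIVER_IP")

if db_is_driver is None or db_driver_ip is None:
    sys.exit("Not running on a Databricks multi-node cluster.")

if db_is_driver == "TRUE":
    # Driver node: start the scheduler with two dashboard addresses,
    # :8787 (Dask default) and :8087 (exposed via the driver proxy).
    subprocess.Popen(["dask", "scheduler", "--dashboard-address", ":8787,:8087"])
else:
    # Worker node: connect to the scheduler on the driver (assumed default
    # worker command; the real CLI also honours --worker-command/--worker-args).
    subprocess.Popen(["dask", "worker", f"tcp://{db_driver_ip}:8786"])
```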
diff --git a/dask_databricks/databrickscluster.py b/dask_databricks/databrickscluster.py
index ea10f75..96be4f8 100644
--- a/dask_databricks/databrickscluster.py
+++ b/dask_databricks/databrickscluster.py
@@ -1,25 +1,34 @@
 import os
 import uuid
+from typing import Optional
 
-from distributed.deploy.cluster import Cluster
 from distributed.core import rpc
-from typing import Optional
+from distributed.deploy.cluster import Cluster
 from tornado.ioloop import IOLoop
 
-# Databricks Notebooks injects the `spark` session variable
-if 'spark' not in globals():
+# Databricks Notebooks injects the `spark` session variable but we need to create it ourselves
+try:
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.getActiveSession()
+except ImportError:
     spark = None
 
+
 class DatabricksCluster(Cluster):
     """Connect to a Dask cluster deployed via databricks."""
 
-    def __init__(self,
+    def __init__(
+        self,
         loop: Optional[IOLoop] = None,
-        asynchronous: bool = False,):
+        asynchronous: bool = False,
+    ):
         self.spark_local_ip = os.getenv("SPARK_LOCAL_IP")
         if self.spark_local_ip is None:
-            raise KeyError("Unable to find expected environment variable SPARK_LOCAL_IP. "
-                           "Are you running this on a Databricks driver node?")
+            raise KeyError(
+                "Unable to find expected environment variable SPARK_LOCAL_IP. "
+                "Are you running this on a Databricks driver node?"
+            )
         try:
             name = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
         except AttributeError:
@@ -30,7 +39,17 @@ def __init__(self,
         self._loop_runner.start()
         self.sync(self._start)
 
-
     async def _start(self):
-        self.scheduler_comm = rpc(f'{self.spark_local_ip}:8786')
+        self.scheduler_comm = rpc(f"{self.spark_local_ip}:8786")
         await super()._start()
+
+    @property
+    def dashboard_link(self):
+        cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
+        org_id = spark.conf.get("spark.databricks.clusterUsageTags.orgId")
+        return f"https://dbc-dp-{org_id}.cloud.databricks.com/driver-proxy/o/{org_id}/{cluster_id}/8087/status"
+
+
+def get_client():
+    """Get a Dask client connected to a Databricks cluster."""
+    return DatabricksCluster().get_client()
diff --git a/dask_databricks/tests/test_databricks.py b/dask_databricks/tests/test_databricks.py
index 18dfc15..62e56c8 100644
--- a/dask_databricks/tests/test_databricks.py
+++ b/dask_databricks/tests/test_databricks.py
@@ -4,14 +4,15 @@
 from dask.distributed import Client
 from distributed.deploy import Cluster, LocalCluster
 
+from dask_databricks import DatabricksCluster, get_client
 
-from dask_databricks import DatabricksCluster
 
 @pytest.fixture(scope="session")
 def dask_cluster():
     """Start a LocalCluster to simulate the cluster that would be started on Databricks."""
     return LocalCluster(scheduler_port=8786)
 
+
 @pytest.fixture
 def remove_spark_local_ip():
     original_spark_local_ip = os.getenv("SPARK_LOCAL_IP")
@@ -21,6 +22,7 @@ def remove_spark_local_ip():
     if original_spark_local_ip:
         os.environ["SPARK_LOCAL_IP"] = original_spark_local_ip
 
+
 @pytest.fixture
 def set_spark_local_ip():
     original_spark_local_ip = os.getenv("SPARK_LOCAL_IP")
@@ -31,10 +33,12 @@ def set_spark_local_ip():
     else:
         del os.environ["SPARK_LOCAL_IP"]
 
+
 def test_databricks_cluster_raises_key_error_when_initialised_outside_of_databricks(remove_spark_local_ip):
     with pytest.raises(KeyError):
         DatabricksCluster()
 
+
 def test_databricks_cluster_create(set_spark_local_ip, dask_cluster):
     cluster = DatabricksCluster()
     assert isinstance(cluster, Cluster)
@@ -45,3 +49,10 @@ def test_databricks_cluster_create_client(set_spark_local_ip, dask_cluster):
     client = Client(cluster)
     assert isinstance(client, Client)
     assert client.submit(sum, (10, 1)).result() == 11
+
+
+def test_get_client(set_spark_local_ip, dask_cluster):
+    client = get_client()
+    assert isinstance(client, Client)
+    assert isinstance(client.cluster, DatabricksCluster)
+    assert client.submit(sum, (10, 1)).result() == 11
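The new `test_get_client` case relies on the same trick as the existing tests: a `LocalCluster` pinned to port 8786 stands in for the scheduler that `dask databricks run` would start, and `SPARK_LOCAL_IP` points `DatabricksCluster` at it as if it were on a driver node. A standalone sketch of that setup, outside pytest and with the address hard-coded to `127.0.0.1` (an assumption for illustration), looks like this:

```python
import os

from dask.distributed import LocalCluster

from dask_databricks import get_client

# Stand in for the scheduler that `dask databricks run` starts on the driver node.
local_cluster = LocalCluster(scheduler_port=8786)

# DatabricksCluster reads SPARK_LOCAL_IP to find the scheduler address.
os.environ["SPARK_LOCAL_IP"] = "127.0.0.1"

client = get_client()
print(client.submit(sum, (10, 1)).result())  # 11

client.close()
local_cluster.close()
```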
diff --git a/pyproject.toml b/pyproject.toml
index dbe796f..6da2ea3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
 [build-system]
-requires = ["hatchling"]
+requires = ["hatchling", "hatch-vcs"]
 build-backend = "hatchling.build"
 
 [project]
@@ -47,7 +47,10 @@ rye = { dev-dependencies = [
 allow-direct-references = true
 
 [tool.hatch.version]
-path = "dask_databricks/__about__.py"
+source = "vcs"
+
+[tool.hatch.build.hooks.vcs]
+version-file = "dask_databricks/_version.py"
 
 [tool.hatch.envs.default]
 dependencies = [
@@ -92,56 +95,61 @@ all = [
   "typing",
 ]
 
+[tool.hatch.envs.test]
+dependencies = [
+  "pytest>=7.2.2",
+  "pytest-timeout>=2.1.0",
+]
+
+[tool.hatch.envs.test.scripts]
+run = "pytest"
+
 [tool.black]
 target-version = ["py37"]
 line-length = 120
 skip-string-normalization = true
 
 [tool.ruff]
-target-version = "py37"
-line-length = 120
-select = [
-  "A",
-  "ARG",
-  "B",
-  "C",
-  "DTZ",
-  "E",
-  "EM",
-  "F",
-  "FBT",
-  "I",
-  "ICN",
-  "ISC",
-  "N",
-  "PLC",
-  "PLE",
-  "PLR",
-  "PLW",
-  "Q",
-  "RUF",
-  "S",
-  "T",
-  "TID",
-  "UP",
-  "W",
-  "YTT",
-]
-ignore = [
-  # Allow non-abstract empty methods in abstract base classes
-  "B027",
-  # Allow boolean positional values in function calls, like `dict.get(... True)`
-  "FBT003",
-  # Ignore checks for possible passwords
-  "S105", "S106", "S107",
-  # Ignore complexity
-  "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
-]
-unfixable = [
-  # Don't touch unused imports
-  "F401",
+# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
+select = ["E", "F", "I"]
+ignore = []
+
+# Allow autofix for all enabled rules (when `--fix`) is provided.
+fixable = ["I"]
+# unfixable = []
+
+# Exclude a variety of commonly ignored directories.
+exclude = [
+    ".bzr",
+    ".direnv",
+    ".eggs",
+    ".git",
+    ".hg",
+    ".mypy_cache",
+    ".nox",
+    ".pants.d",
+    ".pytype",
+    ".ruff_cache",
+    ".svn",
+    ".tox",
+    ".venv",
+    "__pypackages__",
+    "_build",
+    "buck-out",
+    "build",
+    "dist",
+    "node_modules",
+    "venv",
 ]
 
+line-length = 120
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+# Assume Python 3.10.
+target-version = "py310"
+
 [tool.ruff.isort]
 known-first-party = ["dask_databricks"]
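With versioning moved from the hard-coded `__about__.py` to hatch-vcs / setuptools_scm and a generated `dask_databricks/_version.py` (now gitignored), a quick way to confirm the dynamic version resolves after installing the package is a check along these lines:

```python
import dask_databricks

# With an editable install and setuptools_scm available, __init__.py derives the
# version from Git metadata; otherwise it falls back to the _version.py file that
# hatch-vcs writes at build time.
print(dask_databricks.__version__)
```

If neither source is available, `__init__.py` raises the `RuntimeError` shown in the diff, which is a clearer failure mode than silently reporting a stale hard-coded version.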