Add Precommit (#23)
* Add precommit config

* Format with black and run isort
jacobtomlinson authored Nov 8, 2023
1 parent ca324d8 commit 62e8d1b
Showing 6 changed files with 87 additions and 61 deletions.
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,18 @@
+# NOTE: autoupdate does not pick up flake8-bugbear since it is a transitive
+# dependency. Make sure to update flake8-bugbear manually on a regular basis.
+repos:
+  - repo: https://github.com/psf/black
+    rev: 23.10.1
+    hooks:
+      - id: black
+        language_version: python3
+        exclude: versioneer.py
+        args:
+          - --target-version=py39
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: "v0.1.3"
+    hooks:
+      - id: ruff
+        language_version: python3
+        args: [--fix, --exit-non-zero-on-fix]
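With this file in place, the hooks are typically enabled locally with `pre-commit install` and run across the whole repository with `pre-commit run --all-files` (standard pre-commit commands). For illustration only, the snippet below shows the kind of change the black hook makes under the project's 120-character line length from pyproject.toml; the same collapse appears in dask_databricks/cli.py further down.

import logging

from rich.logging import RichHandler

# Before the black hook: a call wrapped under a narrower line length.
logging.basicConfig(
    level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
)

# After the hook (line-length = 120), black joins it onto a single line:
# logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])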
2 changes: 1 addition & 1 deletion dask_databricks/__init__.py
@@ -2,4 +2,4 @@
 #
 # SPDX-License-Identifier: BSD-3
 
-from .databrickscluster import DatabricksCluster, get_client
+from .databrickscluster import DatabricksCluster, get_client  # noqa
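The `# noqa` keeps the new ruff hook's unused-import rule (F401) from failing on these re-exports. A hedged alternative, not used by this commit, is to declare the package's public API explicitly so the linter treats the imports as intentional:

# Hypothetical alternative to the `# noqa` comment: an explicit __all__
# marks the imported names as deliberate re-exports.
from .databrickscluster import DatabricksCluster, get_client

__all__ = ["DatabricksCluster", "get_client"]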
19 changes: 11 additions & 8 deletions dask_databricks/cli.py
@@ -1,36 +1,39 @@
-import click
 import logging
 import os
 import socket
 import subprocess
 import sys
 import time
 
+import click
 from rich.logging import RichHandler
 
 
 def get_logger():
-    logging.basicConfig(
-        level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
-    )
+    logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])
     return logging.getLogger("dask_databricks")
 
 
 @click.group(name="databricks")
 def main():
     """Tools to launch Dask on Databricks."""
 
 
 @main.command()
 def run():
     """Run Dask processes on a Databricks cluster."""
     log = get_logger()
 
     log.info("Setting up Dask on a Databricks cluster.")
 
-    DB_IS_DRIVER = os.getenv('DB_IS_DRIVER')
-    DB_DRIVER_IP = os.getenv('DB_DRIVER_IP')
+    DB_IS_DRIVER = os.getenv("DB_IS_DRIVER")
+    DB_DRIVER_IP = os.getenv("DB_DRIVER_IP")
 
     if DB_DRIVER_IP is None or DB_IS_DRIVER is None:
-        log.error("Unable to find expected environment variables DB_IS_DRIVER and DB_DRIVER_IP. "
-                  "Are you running this command on a Databricks multi-node cluster?")
+        log.error(
+            "Unable to find expected environment variables DB_IS_DRIVER and DB_DRIVER_IP. "
+            "Are you running this command on a Databricks multi-node cluster?"
+        )
         sys.exit(1)
 
     if DB_IS_DRIVER == "TRUE":
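The rest of run() is not shown in this view. As a rough, hypothetical sketch of the pattern it follows (not the verbatim code): the driver node hosts the Dask scheduler and the other nodes start workers pointed at DB_DRIVER_IP, which is why DatabricksCluster below dials port 8786.

# Hypothetical sketch only -- the remaining body of run() is truncated in this diff.
import os
import subprocess

DB_IS_DRIVER = os.getenv("DB_IS_DRIVER")
DB_DRIVER_IP = os.getenv("DB_DRIVER_IP")

if DB_IS_DRIVER == "TRUE":
    # Driver node: start the scheduler that DatabricksCluster connects to.
    subprocess.Popen(["dask", "scheduler"])
else:
    # Worker nodes: join the scheduler running on the driver (default port 8786).
    subprocess.Popen(["dask", "worker", f"tcp://{DB_DRIVER_IP}:8786"])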
23 changes: 14 additions & 9 deletions dask_databricks/databrickscluster.py
@@ -1,25 +1,30 @@
 import os
 import uuid
+from typing import Optional
 
-from distributed.deploy.cluster import Cluster
 from distributed.core import rpc
-from typing import Optional
+from distributed.deploy.cluster import Cluster
 from tornado.ioloop import IOLoop
 
 # Databricks Notebooks injects the `spark` session variable
-if 'spark' not in globals():
+if "spark" not in globals():
     spark = None
 
+
 class DatabricksCluster(Cluster):
     """Connect to a Dask cluster deployed via databricks."""
 
-    def __init__(self,
+    def __init__(
+        self,
         loop: Optional[IOLoop] = None,
-        asynchronous: bool = False,):
+        asynchronous: bool = False,
+    ):
         self.spark_local_ip = os.getenv("SPARK_LOCAL_IP")
         if self.spark_local_ip is None:
-            raise KeyError("Unable to find expected environment variable SPARK_LOCAL_IP. "
-                           "Are you running this on a Databricks driver node?")
+            raise KeyError(
+                "Unable to find expected environment variable SPARK_LOCAL_IP. "
+                "Are you running this on a Databricks driver node?"
+            )
         try:
             name = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
         except AttributeError:
@@ -30,11 +35,11 @@ def __init__(self,
         self._loop_runner.start()
         self.sync(self._start)
 
-
     async def _start(self):
-        self.scheduler_comm = rpc(f'{self.spark_local_ip}:8786')
+        self.scheduler_comm = rpc(f"{self.spark_local_ip}:8786")
         await super()._start()
 
+
 def get_client():
     """Get a Dask client connected to a Databricks cluster."""
     return DatabricksCluster().get_client()
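For context, a typical way to use this from a Databricks notebook on the driver node, assuming the `run` command defined in cli.py has already been launched (for example from a cluster init script) so a scheduler is listening on port 8786:

from dask_databricks import get_client

# get_client() constructs a DatabricksCluster (which reads SPARK_LOCAL_IP and
# connects to the scheduler on port 8786) and returns a dask.distributed.Client.
client = get_client()
print(client.dashboard_link)  # standard distributed Client attribute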
6 changes: 5 additions & 1 deletion dask_databricks/tests/test_databricks.py
@@ -4,14 +4,15 @@
 from dask.distributed import Client
 from distributed.deploy import Cluster, LocalCluster
 
-
 from dask_databricks import DatabricksCluster, get_client
 
+
 @pytest.fixture(scope="session")
 def dask_cluster():
     """Start a LocalCluster to simulate the cluster that would be started on Databricks."""
     return LocalCluster(scheduler_port=8786)
 
+
 @pytest.fixture
 def remove_spark_local_ip():
     original_spark_local_ip = os.getenv("SPARK_LOCAL_IP")
@@ -21,6 +22,7 @@ def remove_spark_local_ip():
     if original_spark_local_ip:
         os.environ["SPARK_LOCAL_IP"] = original_spark_local_ip
 
+
 @pytest.fixture
 def set_spark_local_ip():
     original_spark_local_ip = os.getenv("SPARK_LOCAL_IP")
@@ -31,10 +33,12 @@ def set_spark_local_ip():
     else:
         del os.environ["SPARK_LOCAL_IP"]
 
+
 def test_databricks_cluster_raises_key_error_when_initialised_outside_of_databricks(remove_spark_local_ip):
     with pytest.raises(KeyError):
         DatabricksCluster()
 
+
 def test_databricks_cluster_create(set_spark_local_ip, dask_cluster):
     cluster = DatabricksCluster()
     assert isinstance(cluster, Cluster)
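A hypothetical companion test, in the same style as the fixtures above, showing how get_client() could be exercised against the LocalCluster fixture (not part of this diff):

def test_databricks_cluster_get_client(set_spark_local_ip, dask_cluster):
    # Hypothetical example: get_client() should hand back a dask.distributed.Client
    # wired to the LocalCluster fixture listening on port 8786.
    client = get_client()
    assert isinstance(client, Client)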
80 changes: 38 additions & 42 deletions pyproject.toml
@@ -107,50 +107,46 @@ line-length = 120
 skip-string-normalization = true
 
 [tool.ruff]
-target-version = "py37"
-line-length = 120
-select = [
-  "A",
-  "ARG",
-  "B",
-  "C",
-  "DTZ",
-  "E",
-  "EM",
-  "F",
-  "FBT",
-  "I",
-  "ICN",
-  "ISC",
-  "N",
-  "PLC",
-  "PLE",
-  "PLR",
-  "PLW",
-  "Q",
-  "RUF",
-  "S",
-  "T",
-  "TID",
-  "UP",
-  "W",
-  "YTT",
-]
-ignore = [
-  # Allow non-abstract empty methods in abstract base classes
-  "B027",
-  # Allow boolean positional values in function calls, like `dict.get(... True)`
-  "FBT003",
-  # Ignore checks for possible passwords
-  "S105", "S106", "S107",
-  # Ignore complexity
-  "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
-]
-unfixable = [
-  # Don't touch unused imports
-  "F401",
-]
+# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
+select = ["E", "F", "I"]
+ignore = []
+
+# Allow autofix for all enabled rules (when `--fix`) is provided.
+fixable = ["I"]
+# unfixable = []
+
+# Exclude a variety of commonly ignored directories.
+exclude = [
+    ".bzr",
+    ".direnv",
+    ".eggs",
+    ".git",
+    ".hg",
+    ".mypy_cache",
+    ".nox",
+    ".pants.d",
+    ".pytype",
+    ".ruff_cache",
+    ".svn",
+    ".tox",
+    ".venv",
+    "__pypackages__",
+    "_build",
+    "buck-out",
+    "build",
+    "dist",
+    "node_modules",
+    "venv",
+]
+
+line-length = 120
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+# Assume Python 3.10.
+target-version = "py310"
 
 [tool.ruff.isort]
 known-first-party = ["dask_databricks"]
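For illustration (not part of the diff), the practical effect of `select = ["E", "F", "I"]` with `fixable = ["I"]`: ruff reports pycodestyle and Pyflakes issues, and when the pre-commit hook passes `--fix` it reorders imports the way isort would, which is exactly the change applied to cli.py and databrickscluster.py above.

# Illustrative file, before the hook runs: a third-party import sits inside the
# standard-library block, which ruff's I001 (unsorted-imports) rule flags.
import sys
import click
import os

# After `ruff --fix`, the imports are grouped stdlib-first, then third-party:
#
#   import os
#   import sys
#
#   import click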

