Add Precommit #23

Merged
2 commits merged on Nov 8, 2023
Changes from all commits
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,18 @@
+# NOTE: autoupdate does not pick up flake8-bugbear since it is a transitive
+# dependency. Make sure to update flake8-bugbear manually on a regular basis.
+repos:
+  - repo: https://github.com/psf/black
+    rev: 23.10.1
+    hooks:
+      - id: black
+        language_version: python3
+        exclude: versioneer.py
+        args:
+          - --target-version=py39
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: "v0.1.3"
+    hooks:
+      - id: ruff
+        language_version: python3
+        args: [--fix, --exit-non-zero-on-fix]
2 changes: 1 addition & 1 deletion dask_databricks/__init__.py
@@ -2,4 +2,4 @@
 #
 # SPDX-License-Identifier: BSD-3
 
-from .databrickscluster import DatabricksCluster, get_client
+from .databrickscluster import DatabricksCluster, get_client  # noqa
19 changes: 11 additions & 8 deletions dask_databricks/cli.py
@@ -1,36 +1,39 @@
-import click
 import logging
 import os
 import socket
 import subprocess
 import sys
 import time
 
+import click
 from rich.logging import RichHandler
 
+
 def get_logger():
-    logging.basicConfig(
-        level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
-    )
+    logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])
     return logging.getLogger("dask_databricks")
 
+
 @click.group(name="databricks")
 def main():
     """Tools to launch Dask on Databricks."""
 
+
 @main.command()
 def run():
     """Run Dask processes on a Databricks cluster."""
     log = get_logger()
 
     log.info("Setting up Dask on a Databricks cluster.")
 
-    DB_IS_DRIVER = os.getenv('DB_IS_DRIVER')
-    DB_DRIVER_IP = os.getenv('DB_DRIVER_IP')
+    DB_IS_DRIVER = os.getenv("DB_IS_DRIVER")
+    DB_DRIVER_IP = os.getenv("DB_DRIVER_IP")
 
     if DB_DRIVER_IP is None or DB_IS_DRIVER is None:
-        log.error("Unable to find expected environment variables DB_IS_DRIVER and DB_DRIVER_IP. "
-                  "Are you running this command on a Databricks multi-node cluster?")
+        log.error(
+            "Unable to find expected environment variables DB_IS_DRIVER and DB_DRIVER_IP. "
+            "Are you running this command on a Databricks multi-node cluster?"
+        )
         sys.exit(1)
 
     if DB_IS_DRIVER == "TRUE":
23 changes: 14 additions & 9 deletions dask_databricks/databrickscluster.py
@@ -1,25 +1,30 @@
 import os
 import uuid
+from typing import Optional
 
-from distributed.deploy.cluster import Cluster
 from distributed.core import rpc
-from typing import Optional
+from distributed.deploy.cluster import Cluster
 from tornado.ioloop import IOLoop
 
 # Databricks Notebooks injects the `spark` session variable
-if 'spark' not in globals():
+if "spark" not in globals():
     spark = None
 
+
 class DatabricksCluster(Cluster):
     """Connect to a Dask cluster deployed via databricks."""
 
-    def __init__(self,
+    def __init__(
+        self,
         loop: Optional[IOLoop] = None,
-        asynchronous: bool = False,):
+        asynchronous: bool = False,
+    ):
         self.spark_local_ip = os.getenv("SPARK_LOCAL_IP")
         if self.spark_local_ip is None:
-            raise KeyError("Unable to find expected environment variable SPARK_LOCAL_IP. "
-                           "Are you running this on a Databricks driver node?")
+            raise KeyError(
+                "Unable to find expected environment variable SPARK_LOCAL_IP. "
+                "Are you running this on a Databricks driver node?"
+            )
         try:
             name = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
         except AttributeError:
@@ -30,11 +35,11 @@ def __init__(self,
         self._loop_runner.start()
         self.sync(self._start)
 
-
     async def _start(self):
-        self.scheduler_comm = rpc(f'{self.spark_local_ip}:8786')
+        self.scheduler_comm = rpc(f"{self.spark_local_ip}:8786")
         await super()._start()
 
+
 def get_client():
     """Get a Dask client connected to a Databricks cluster."""
     return DatabricksCluster().get_client()
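
For context, a minimal usage sketch (not part of this diff): assuming a Dask scheduler is already running on the driver node, for example one started by the `run` command in dask_databricks/cli.py above, the `get_client()` helper can be used from a Databricks notebook roughly as follows. The squaring computation is purely illustrative.

# Illustrative sketch only -- not part of this PR. Assumes the scheduler
# is already listening on the driver at port 8786 (see cli.py above).
import dask_databricks

client = dask_databricks.get_client()            # connects a distributed Client via DatabricksCluster
futures = client.map(lambda x: x**2, range(10))  # submit work to the Dask workers
print(client.gather(futures))                    # [0, 1, 4, ..., 81]
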
6 changes: 5 additions & 1 deletion dask_databricks/tests/test_databricks.py
@@ -4,14 +4,15 @@
 from dask.distributed import Client
 from distributed.deploy import Cluster, LocalCluster
 
-
 from dask_databricks import DatabricksCluster, get_client
 
+
 @pytest.fixture(scope="session")
 def dask_cluster():
     """Start a LocalCluster to simulate the cluster that would be started on Databricks."""
     return LocalCluster(scheduler_port=8786)
 
+
 @pytest.fixture
 def remove_spark_local_ip():
     original_spark_local_ip = os.getenv("SPARK_LOCAL_IP")
@@ -21,6 +22,7 @@ def remove_spark_local_ip():
     if original_spark_local_ip:
         os.environ["SPARK_LOCAL_IP"] = original_spark_local_ip
 
+
 @pytest.fixture
 def set_spark_local_ip():
     original_spark_local_ip = os.getenv("SPARK_LOCAL_IP")
@@ -31,10 +33,12 @@ def set_spark_local_ip():
     else:
         del os.environ["SPARK_LOCAL_IP"]
 
+
 def test_databricks_cluster_raises_key_error_when_initialised_outside_of_databricks(remove_spark_local_ip):
     with pytest.raises(KeyError):
         DatabricksCluster()
 
+
 def test_databricks_cluster_create(set_spark_local_ip, dask_cluster):
     cluster = DatabricksCluster()
     assert isinstance(cluster, Cluster)
80 changes: 38 additions & 42 deletions pyproject.toml
@@ -107,50 +107,46 @@ line-length = 120
 skip-string-normalization = true
 
 [tool.ruff]
-target-version = "py37"
-line-length = 120
-select = [
-  "A",
-  "ARG",
-  "B",
-  "C",
-  "DTZ",
-  "E",
-  "EM",
-  "F",
-  "FBT",
-  "I",
-  "ICN",
-  "ISC",
-  "N",
-  "PLC",
-  "PLE",
-  "PLR",
-  "PLW",
-  "Q",
-  "RUF",
-  "S",
-  "T",
-  "TID",
-  "UP",
-  "W",
-  "YTT",
-]
-ignore = [
-  # Allow non-abstract empty methods in abstract base classes
-  "B027",
-  # Allow boolean positional values in function calls, like `dict.get(... True)`
-  "FBT003",
-  # Ignore checks for possible passwords
-  "S105", "S106", "S107",
-  # Ignore complexity
-  "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
-]
-unfixable = [
-  # Don't touch unused imports
-  "F401",
+# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
+select = ["E", "F", "I"]
+ignore = []
+
+# Allow autofix for all enabled rules (when `--fix`) is provided.
+fixable = ["I"]
+# unfixable = []
+
+# Exclude a variety of commonly ignored directories.
+exclude = [
+  ".bzr",
+  ".direnv",
+  ".eggs",
+  ".git",
+  ".hg",
+  ".mypy_cache",
+  ".nox",
+  ".pants.d",
+  ".pytype",
+  ".ruff_cache",
+  ".svn",
+  ".tox",
+  ".venv",
+  "__pypackages__",
+  "_build",
+  "buck-out",
+  "build",
+  "dist",
+  "node_modules",
+  "venv",
 ]
+
+line-length = 120
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+# Assume Python 3.10.
+target-version = "py310"
 
 [tool.ruff.isort]
 known-first-party = ["dask_databricks"]