Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add alternative worker commands, config options #20

Merged
merged 12 commits into from
Nov 9, 2023
52 changes: 47 additions & 5 deletions dask_databricks/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
import socket
Expand All @@ -20,7 +21,16 @@ def main():


@main.command()
def run():
@click.option('--worker-command', help='Custom worker command')
@click.option('--worker-args', help='Additional worker arguments')
@click.option(
"--cuda",
is_flag=True,
show_default=True,
default=False,
help="Use `dask cuda worker` from the dask-cuda package when starting workers.",
)
def run(worker_command, worker_args, cuda):
"""Run Dask processes on a Databricks cluster."""
log = get_logger()

Expand All @@ -38,20 +48,52 @@ def run():

if DB_IS_DRIVER == "TRUE":
log.info("This node is the Dask scheduler.")
subprocess.Popen(["dask", "scheduler", "--dashboard-address", ":8787,:8087"])
scheduler_process = subprocess.Popen(["dask", "scheduler", "--dashboard-address", ":8787,:8087"])
time.sleep(5) # give the scheduler time to start
if scheduler_process.poll() is not None:
log.error("Scheduler process has exited prematurely.")
sys.exit(1)
else:
# Specify the same port for all workers
worker_port = 8786
log.info("This node is a Dask worker.")
log.info(f"Connecting to Dask scheduler at {DB_DRIVER_IP}:8786")
log.info(f"Connecting to Dask scheduler at {DB_DRIVER_IP}:{worker_port}")
while True:
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((DB_DRIVER_IP, 8786))
sock.connect((DB_DRIVER_IP, worker_port))
sock.close()
break
except ConnectionRefusedError:
log.info("Scheduler not available yet. Waiting...")
time.sleep(1)
subprocess.Popen(["dask", "worker", f"tcp://{DB_DRIVER_IP}:8786"])

# Construct the worker command
if worker_command:
worker_command = worker_command.split()
elif cuda:
worker_command = ["dask", "cuda", "worker"]
else:
worker_command = ["dask", "worker"]

if worker_args:
try:
# Try to decode the JSON-encoded worker_args
worker_args_list = json.loads(worker_args)
if not isinstance(worker_args_list, list):
raise ValueError("The JSON-encoded worker_args must be a list.")
except json.JSONDecodeError:
# If decoding as JSON fails, split worker_args by spaces
worker_args_list = worker_args.split()

worker_command.extend(worker_args_list)
worker_command.append(f"tcp://{DB_DRIVER_IP}:{worker_port}")

worker_process = subprocess.Popen(worker_command)
time.sleep(5) # give the worker time to start
if worker_process.poll() is not None:
log.error("Worker process has exited prematurely.")
sys.exit(1)


if __name__ == "__main__":
Expand Down
Loading