From 2e42701783de9eb5b92e428875fad738348fab12 Mon Sep 17 00:00:00 2001
From: Sheilah Kirui <71867292+skirui-source@users.noreply.github.com>
Date: Thu, 9 Nov 2023 08:52:12 -0800
Subject: [PATCH] Add alternative worker commands, config options (#20)

* add alternative worker commands, config options

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add support for Json worker args

* polls the scheduler for health check

* extra health checks for dask workers

* Handle dask not being on the path

* Revert sys.executable change

* Fix when worker args are not specified

* Clean up and add --cuda flag

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jacob Tomlinson
---
 dask_databricks/cli.py | 52 ++++++++++++++++++++++++++++++++++++++----
 1 file changed, 47 insertions(+), 5 deletions(-)

diff --git a/dask_databricks/cli.py b/dask_databricks/cli.py
index 0651ca5..91d317a 100644
--- a/dask_databricks/cli.py
+++ b/dask_databricks/cli.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import socket
@@ -20,7 +21,16 @@ def main():
 
 
 @main.command()
-def run():
+@click.option('--worker-command', help='Custom worker command')
+@click.option('--worker-args', help='Additional worker arguments')
+@click.option(
+    "--cuda",
+    is_flag=True,
+    show_default=True,
+    default=False,
+    help="Use `dask cuda worker` from the dask-cuda package when starting workers.",
+)
+def run(worker_command, worker_args, cuda):
     """Run Dask processes on a Databricks cluster."""
     log = get_logger()
 
@@ -38,20 +48,52 @@ def run():
 
     if DB_IS_DRIVER == "TRUE":
         log.info("This node is the Dask scheduler.")
-        subprocess.Popen(["dask", "scheduler", "--dashboard-address", ":8787,:8087"])
+        scheduler_process = subprocess.Popen(["dask", "scheduler", "--dashboard-address", ":8787,:8087"])
+        time.sleep(5)  # give the scheduler time to start
+        if scheduler_process.poll() is not None:
+            log.error("Scheduler process has exited prematurely.")
+            sys.exit(1)
     else:
+        # Specify the same port for all workers
+        worker_port = 8786
         log.info("This node is a Dask worker.")
-        log.info(f"Connecting to Dask scheduler at {DB_DRIVER_IP}:8786")
+        log.info(f"Connecting to Dask scheduler at {DB_DRIVER_IP}:{worker_port}")
         while True:
             try:
                 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-                sock.connect((DB_DRIVER_IP, 8786))
+                sock.connect((DB_DRIVER_IP, worker_port))
                 sock.close()
                 break
             except ConnectionRefusedError:
                 log.info("Scheduler not available yet. Waiting...")
                 time.sleep(1)
-        subprocess.Popen(["dask", "worker", f"tcp://{DB_DRIVER_IP}:8786"])
+
+        # Construct the worker command
+        if worker_command:
+            worker_command = worker_command.split()
+        elif cuda:
+            worker_command = ["dask", "cuda", "worker"]
+        else:
+            worker_command = ["dask", "worker"]
+
+        if worker_args:
+            try:
+                # Try to decode the JSON-encoded worker_args
+                worker_args_list = json.loads(worker_args)
+                if not isinstance(worker_args_list, list):
+                    raise ValueError("The JSON-encoded worker_args must be a list.")
+            except json.JSONDecodeError:
+                # If decoding as JSON fails, split worker_args by spaces
+                worker_args_list = worker_args.split()
+
+            worker_command.extend(worker_args_list)
+        worker_command.append(f"tcp://{DB_DRIVER_IP}:{worker_port}")
+
+        worker_process = subprocess.Popen(worker_command)
+        time.sleep(5)  # give the worker time to start
+        if worker_process.poll() is not None:
+            log.error("Worker process has exited prematurely.")
+            sys.exit(1)
 
 
 if __name__ == "__main__":