diff --git a/dask_databricks/cli.py b/dask_databricks/cli.py index 7849d06..d5548b1 100644 --- a/dask_databricks/cli.py +++ b/dask_databricks/cli.py @@ -19,7 +19,9 @@ def main(): """Tools to launch Dask on Databricks.""" @main.command() -def run(): +@click.option('--worker-command', help='Custom worker command') +@click.option('--worker-args', help='Additional worker arguments as a single string') +def run(worker_command, worker_args): """Run Dask processes on a Databricks cluster.""" log = get_logger() @@ -48,7 +50,15 @@ def run(): except ConnectionRefusedError: log.info("Scheduler not available yet. Waiting...") time.sleep(1) - subprocess.Popen(["dask", "worker", f"tcp://{DB_DRIVER_IP}:8786"]) + + # Construct the worker command + worker_command = worker_command.split() if worker_command else ["dask", "worker"] + if worker_args: + worker_command.extend(worker_args.split()) + + worker_command.append(f"tcp://{DB_DRIVER_IP}:8786") + + subprocess.Popen(worker_command) if __name__ == "__main__":