diff --git a/dask_databricks/cli.py b/dask_databricks/cli.py
index 09badd3..0651ca5 100644
--- a/dask_databricks/cli.py
+++ b/dask_databricks/cli.py
@@ -38,7 +38,7 @@ def run():
 
     if DB_IS_DRIVER == "TRUE":
         log.info("This node is the Dask scheduler.")
-        subprocess.Popen(["dask", "scheduler"])
+        subprocess.Popen(["dask", "scheduler", "--dashboard-address", ":8787,:8087"])
     else:
         log.info("This node is a Dask worker.")
         log.info(f"Connecting to Dask scheduler at {DB_DRIVER_IP}:8786")
diff --git a/dask_databricks/databrickscluster.py b/dask_databricks/databrickscluster.py
index 7a83bc5..96be4f8 100644
--- a/dask_databricks/databrickscluster.py
+++ b/dask_databricks/databrickscluster.py
@@ -6,8 +6,12 @@
 from distributed.deploy.cluster import Cluster
 from tornado.ioloop import IOLoop
 
-# Databricks Notebooks injects the `spark` session variable
-if "spark" not in globals():
+# Databricks Notebooks injects the `spark` session variable but we need to create it ourselves
+try:
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.getActiveSession()
+except ImportError:
     spark = None
 
 
@@ -39,6 +43,12 @@ async def _start(self):
         self.scheduler_comm = rpc(f"{self.spark_local_ip}:8786")
         await super()._start()
 
+    @property
+    def dashboard_link(self):
+        cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
+        org_id = spark.conf.get("spark.databricks.clusterUsageTags.orgId")
+        return f"https://dbc-dp-{org_id}.cloud.databricks.com/driver-proxy/o/{org_id}/{cluster_id}/8087/status"
+
 
 def get_client():
     """Get a Dask client connected to a Databricks cluster."""
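
For context, a minimal sketch of what this change enables from a Databricks notebook. It assumes the scheduler and workers were already started by the init script patched above, and that `get_client` (defined in `dask_databricks/databrickscluster.py`) is importable from the package top level:

```python
# Hypothetical notebook usage; assumes the Databricks cluster was launched
# with the dask-databricks init script so the scheduler is already running
# on the Spark driver.
import dask_databricks

# get_client() returns a distributed.Client connected to the scheduler
# at <driver IP>:8786 via a DatabricksCluster.
client = dask_databricks.get_client()

# With this patch the link resolves through the Databricks driver-proxy
# to the dashboard instance on port 8087, rather than pointing at a raw
# scheduler address that is unreachable from outside the cluster, e.g.
# https://dbc-dp-<orgId>.cloud.databricks.com/driver-proxy/o/<orgId>/<clusterId>/8087/status
print(client.dashboard_link)
```

Note that the scheduler now serves its dashboard on two addresses: the Dask default `:8787` stays intact for direct access on the driver, while the `:8087` instance is the one targeted by the proxied `dashboard_link`.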