Skip to content

Commit

Permalink
[azure]: Added extra_vm_options (#407)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom Augspurger <taugspurger@microsoft.com>
Co-authored-by: Jacob Tomlinson <jtomlinson@nvidia.com>
  • Loading branch information
3 people authored Sep 16, 2024
1 parent 536e182 commit 0bd7964
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 1 deletion.
18 changes: 17 additions & 1 deletion dask_cloudprovider/azure/azurevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
extra_bootstrap=None,
auto_shutdown: bool = None,
marketplace_plan: dict = {},
extra_vm_options: Optional[dict] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
Expand All @@ -71,6 +72,7 @@ def __init__(
self.auto_shutdown = auto_shutdown
self.env_vars = env_vars
self.marketplace_plan = marketplace_plan
self.extra_vm_options = extra_vm_options or {}

async def create_vm(self):
[subnet_info, *_] = await self.cluster.call_async(
Expand Down Expand Up @@ -179,6 +181,13 @@ async def create_vm(self):
vm_parameters["storage_profile"]["image_reference"]["version"] = "latest"
self.cluster._log("Using Marketplace VM image with a Plan")

repeated = self.extra_vm_options.keys() & vm_parameters.keys()
if repeated:
raise TypeError(
f"Parameters are passed in both 'extra_vm_options' and as regular parameters: {repeated}"
)
vm_parameters = {**self.extra_vm_options, **vm_parameters}

self.cluster._log("Creating VM")
if self.cluster.debug:
self.cluster._log(
Expand Down Expand Up @@ -344,6 +353,9 @@ class AzureVMCluster(VMCluster):
The ID of the Azure Subscription to create the virtual machines in. If not specified, then
dask-cloudprovider will attempt to use the configured default for the Azure CLI. List your
subscriptions with ``az account list``.
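extra_vm_options: dict[str, Any] (optional)
Additional arguments to provide to Azure's ``VirtualMachinesOperations.begin_create_or_update``
when creating the scheduler and worker VMs. If not specified, the value of the
``azurevm.extra_vm_options`` configuration key is used. Keys must not duplicate VM parameters
that dask-cloudprovider sets itself; a ``TypeError`` is raised if they do.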
Examples
--------
Expand Down Expand Up @@ -472,6 +484,7 @@ def __init__(
debug: bool = False,
marketplace_plan: dict = {},
subscription_id: Optional[str] = None,
extra_vm_options: Optional[dict] = None,
**kwargs,
):
self.config = ClusterConfig(dask.config.get("cloudprovider.azure", {}))
Expand Down Expand Up @@ -550,7 +563,9 @@ def __init__(
"""To create a virtual machine from Marketplace image or a custom image sourced
from a Marketplace image with a plan, all 3 fields 'name', 'publisher' and 'product' must be passed."""
)

self.extra_vm_options = extra_vm_options or self.config.get(
"azurevm.extra_vm_options"
)
self.options = {
"cluster": self,
"config": self.config,
Expand All @@ -563,6 +578,7 @@ def __init__(
"auto_shutdown": self.auto_shutdown,
"docker_image": self.docker_image,
"marketplace_plan": self.marketplace_plan,
"extra_vm_options": self.extra_vm_options,
}
self.scheduler_options = {
"vm_size": self.scheduler_vm_size,
Expand Down
1 change: 1 addition & 0 deletions dask_cloudprovider/cloudprovider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ cloudprovider:
# name: "ngc-base-version-21-02-2"
# publisher: "nvidia"
# product: "ngc_azure_17_11"
extra_vm_options: {} # Additional options to provide when creating the VMs.

digitalocean:
token: null # API token for interacting with the Digital Ocean API
Expand Down
33 changes: 33 additions & 0 deletions doc/source/azure.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,39 @@ or specific IP.

Again take note of this security group name for later.

Extra options
^^^^^^^^^^^^^

To further customize the VMs created, you can provide ``extra_vm_options`` to :class:`AzureVMCluster`. For example, to set the identity
of the virtual machines to a (previously created) user-assigned identity, create an ``azure.mgmt.compute.models.VirtualMachineIdentity``:

.. code-block:: python

    >>> import os
    >>> import azure.identity
    >>> import dask_cloudprovider.azure
    >>> import azure.mgmt.compute.models
    >>> subscription_id = os.environ["DASK_CLOUDPROVIDER__AZURE__SUBSCRIPTION_ID"]
    >>> rg_name = os.environ["DASK_CLOUDPROVIDER__AZURE__RESOURCE_GROUP"]
    >>> identity_name = "dask-cloudprovider-identity"
    >>> v = azure.mgmt.compute.models.UserAssignedIdentitiesValue()
    >>> user_assigned_identities = {
    ...     f"/subscriptions/{subscription_id}/resourcegroups/{rg_name}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{identity_name}": v
    ... }
    >>> identity = azure.mgmt.compute.models.VirtualMachineIdentity(
    ...     type="UserAssigned",
    ...     user_assigned_identities=user_assigned_identities
    ... )

And then provide that to :class:`AzureVMCluster`:

.. code-block:: python

    >>> cluster = dask_cloudprovider.azure.AzureVMCluster(extra_vm_options={"identity": identity.as_dict()})
    >>> cluster.scale(1)

Dask Configuration
^^^^^^^^^^^^^^^^^^

Expand Down

0 comments on commit 0bd7964

Please sign in to comment.