From c187166a427df5e7d9861d930a408f4ec2a12deb Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 25 Sep 2024 15:17:20 +0500 Subject: [PATCH] Support AWS placement groups for cluster fleets (#1725) * Support aws placement groups for cluster fleets * Test placement group processing * Adapt placement groups to work with dstack sky * Update docs * Fix tests --- docs/docs/concepts/fleets.md | 9 +- .../_internal/core/backends/__init__.py | 3 + .../_internal/core/backends/aws/compute.py | 43 ++++++- .../_internal/core/backends/aws/resources.py | 9 +- .../_internal/core/backends/base/compute.py | 24 +++- .../core/backends/remote/provisioning.py | 1 + src/dstack/_internal/core/errors.py | 4 + src/dstack/_internal/core/models/instances.py | 9 +- src/dstack/_internal/core/models/placement.py | 27 ++++ .../_internal/server/background/__init__.py | 4 + .../server/background/tasks/process_fleets.py | 15 ++- .../background/tasks/process_instances.py | 119 ++++++++++++++++-- .../tasks/process_placement_groups.py | 97 ++++++++++++++ .../a7b46c073fa1_add_placementgroupmodel.py | 58 +++++++++ src/dstack/_internal/server/models.py | 26 ++++ .../_internal/server/services/fleets.py | 25 ++++ .../_internal/server/services/placement.py | 53 ++++++++ src/dstack/_internal/server/services/pools.py | 16 ++- src/dstack/_internal/server/testing/common.py | 56 +++++++++ .../tasks/test_process_placement_groups.py | 42 +++++++ 20 files changed, 611 insertions(+), 29 deletions(-) create mode 100644 src/dstack/_internal/core/models/placement.py create mode 100644 src/dstack/_internal/server/background/tasks/process_placement_groups.py create mode 100644 src/dstack/_internal/server/migrations/versions/a7b46c073fa1_add_placementgroupmodel.py create mode 100644 src/dstack/_internal/server/services/placement.py create mode 100644 src/tests/_internal/server/background/tasks/test_process_placement_groups.py diff --git a/docs/docs/concepts/fleets.md b/docs/docs/concepts/fleets.md index 
1ce53d22f..5b72804d8 100644 --- a/docs/docs/concepts/fleets.md +++ b/docs/docs/concepts/fleets.md @@ -49,8 +49,10 @@ are both acceptable). to the specified parameters. !!! info "Network" - Set `placement` to `cluster` if the nodes should be interconnected (e.g. if you'd like to use them for multi-node tasks). - In that case, `dstack` will provision all nodes in the same backend and region. + Set `placement` to `cluster` if the nodes should be interconnected + (e.g. if you'd like to use them for [multi-node tasks](reference/dstack.yml/task.md#distributed-tasks)). + In that case, `dstack` will provision all nodes in the same backend and region and configure the optimal + connectivity via availability zones, placement groups, etc. Note that cloud fleets aren't supported for the `kubernetes`, `vastai`, and `runpod` backends. @@ -120,7 +122,8 @@ are both acceptable). ``` !!! info "Network" - Set `placement` to `cluster` if the hosts are interconnected (e.g. if you'd like to use them for multi-node tasks). + Set `placement` to `cluster` if the hosts are interconnected + (e.g. if you'd like to use them for [multi-node tasks](reference/dstack.yml/task.md#distributed-tasks)). In that case, by default, `dstack` will automatically detect the private network. You can specify the [`network`](../reference/dstack.yml/fleet.md#network) parameter manually. 
diff --git a/src/dstack/_internal/core/backends/__init__.py b/src/dstack/_internal/core/backends/__init__.py index d9f17f527..729e0475b 100644 --- a/src/dstack/_internal/core/backends/__init__.py +++ b/src/dstack/_internal/core/backends/__init__.py @@ -18,6 +18,9 @@ BackendType.OCI, BackendType.TENSORDOCK, ] +BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT = [ + BackendType.AWS, +] BACKENDS_WITH_GATEWAY_SUPPORT = [ BackendType.AWS, BackendType.AZURE, diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 34913a6d2..da4bcf7da 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -16,7 +16,7 @@ get_user_data, ) from dstack._internal.core.backends.base.offers import get_catalog_offers -from dstack._internal.core.errors import ComputeError, NoCapacityError +from dstack._internal.core.errors import ComputeError, NoCapacityError, PlacementGroupInUseError from dstack._internal.core.models.backends.aws import AWSAccessKeyCreds from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel, is_core_model_instance @@ -31,6 +31,7 @@ InstanceOfferWithAvailability, SSHKey, ) +from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData from dstack._internal.core.models.resources import Memory, Range from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import ( @@ -116,7 +117,7 @@ def terminate_instance( ec2_client.terminate_instances(InstanceIds=[instance_id]) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == "InvalidInstanceID.NotFound": - pass + logger.debug("Skipping instance %s termination. 
Instance not found.", instance_id) else: raise e @@ -181,6 +182,7 @@ def create_instance( spot=instance_offer.instance.resources.spot, subnet_id=subnet_id, allocate_public_ip=allocate_public_ip, + placement_group_name=instance_config.placement_group_name, ) ) instance = response[0] @@ -239,6 +241,41 @@ def run_job( instance_config.availability_zone = volume.provisioning_data.availability_zone return self.create_instance(instance_offer, instance_config) + def create_placement_group( + self, + placement_group: PlacementGroup, + ) -> PlacementGroupProvisioningData: + ec2_client = self.session.client("ec2", region_name=placement_group.configuration.region) + logger.debug("Creating placement group %s...", placement_group.name) + ec2_client.create_placement_group( + GroupName=placement_group.name, + Strategy=placement_group.configuration.placement_strategy.value, + ) + logger.debug("Created placement group %s", placement_group.name) + return PlacementGroupProvisioningData( + backend=BackendType.AWS, + backend_data=None, + ) + + def delete_placement_group( + self, + placement_group: PlacementGroup, + ): + ec2_client = self.session.client("ec2", region_name=placement_group.configuration.region) + logger.debug("Deleting placement group %s...", placement_group.name) + try: + ec2_client.delete_placement_group(GroupName=placement_group.name) + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "InvalidPlacementGroup.Unknown": + logger.debug("Placement group %s not found", placement_group.name) + return + elif e.response["Error"]["Code"] == "InvalidPlacementGroup.InUse": + logger.debug("Placement group %s is in use", placement_group.name) + raise PlacementGroupInUseError() + else: + raise e + logger.debug("Deleted placement group %s", placement_group.name) + def create_gateway( self, configuration: GatewayComputeConfiguration, @@ -372,7 +409,7 @@ def create_gateway( def terminate_gateway( self, - instance_id, + instance_id: str, configuration: 
GatewayComputeConfiguration, backend_data: Optional[str] = None, ): diff --git a/src/dstack/_internal/core/backends/aws/resources.py b/src/dstack/_internal/core/backends/aws/resources.py index c5aeca80c..bd02defb0 100644 --- a/src/dstack/_internal/core/backends/aws/resources.py +++ b/src/dstack/_internal/core/backends/aws/resources.py @@ -105,8 +105,9 @@ def create_instances_struct( spot: bool, subnet_id: Optional[str] = None, allocate_public_ip: bool = True, + placement_group_name: Optional[str] = None, ) -> Dict[str, Any]: - struct = dict( + struct: Dict[str, Any] = dict( BlockDeviceMappings=[ { "DeviceName": "/dev/sda1", @@ -151,6 +152,12 @@ def create_instances_struct( ] else: struct["SecurityGroupIds"] = [security_group_id] + + if placement_group_name is not None: + struct["Placement"] = { + "GroupName": placement_group_name, + } + return struct diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index e3ac108fa..c6241dfe4 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -17,6 +17,7 @@ InstanceConfiguration, InstanceOfferWithAvailability, ) +from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import ( Volume, @@ -62,8 +63,8 @@ def terminate_instance( backend_data: Optional[str] = None, ) -> None: """ - Terminates an instance by `instance_id`. If instance does not exist, - it should not raise errors but return silently. + Terminates an instance by `instance_id`. + If the instance does not exist, it should not raise errors but return silently. 
""" pass @@ -95,6 +96,25 @@ def update_provisioning_data( """ pass + def create_placement_group( + self, + placement_group: PlacementGroup, + ) -> PlacementGroupProvisioningData: + """ + Creates a placement group. + """ + raise NotImplementedError() + + def delete_placement_group( + self, + placement_group: PlacementGroup, + ): + """ + Deletes a placement group. + If the group does not exist, it should not raise errors but return silently. + """ + raise NotImplementedError() + def create_gateway( self, configuration: GatewayComputeConfiguration, diff --git a/src/dstack/_internal/core/backends/remote/provisioning.py b/src/dstack/_internal/core/backends/remote/provisioning.py index 67629887c..4cbd08763 100644 --- a/src/dstack/_internal/core/backends/remote/provisioning.py +++ b/src/dstack/_internal/core/backends/remote/provisioning.py @@ -7,6 +7,7 @@ import paramiko from gpuhunt import correct_gpu_memory_gib +# FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute from dstack._internal.core.errors import ProvisioningError from dstack._internal.core.models.instances import ( Disk, diff --git a/src/dstack/_internal/core/errors.py b/src/dstack/_internal/core/errors.py index 7ae82f544..231eb7dbf 100644 --- a/src/dstack/_internal/core/errors.py +++ b/src/dstack/_internal/core/errors.py @@ -98,6 +98,10 @@ class ComputeResourceNotFoundError(ComputeError): pass +class PlacementGroupInUseError(ComputeError): + pass + + class CLIError(DstackError): pass diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py index 218b5c21e..8c2f13075 100644 --- a/src/dstack/_internal/core/models/instances.py +++ b/src/dstack/_internal/core/models/instances.py @@ -99,12 +99,13 @@ class DockerConfig(CoreModel): class InstanceConfiguration(CoreModel): project_name: str - instance_name: str # unique in pool - instance_id: Optional[str] = None - ssh_keys: List[SSHKey] - job_docker_config: 
Optional[DockerConfig] + instance_name: str user: str # dstack user name + ssh_keys: List[SSHKey] + instance_id: Optional[str] = None availability_zone: Optional[str] = None + placement_group_name: Optional[str] = None + job_docker_config: Optional[DockerConfig] # FIXME: cannot find any usages – remove? def get_public_keys(self) -> List[str]: return [ssh_key.public.strip() for ssh_key in self.ssh_keys] diff --git a/src/dstack/_internal/core/models/placement.py b/src/dstack/_internal/core/models/placement.py new file mode 100644 index 000000000..93b0cf09d --- /dev/null +++ b/src/dstack/_internal/core/models/placement.py @@ -0,0 +1,27 @@ +from enum import Enum +from typing import Optional + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import CoreModel + + +class PlacementStrategy(str, Enum): + CLUSTER = "cluster" + + +class PlacementGroupConfiguration(CoreModel): + backend: BackendType + region: str + placement_strategy: PlacementStrategy + + +class PlacementGroupProvisioningData(CoreModel): + backend: BackendType # can be different from configuration backend + backend_data: Optional[str] = None + + +class PlacementGroup(CoreModel): + name: str + project_name: str + configuration: PlacementGroupConfiguration + provisioning_data: Optional[PlacementGroupProvisioningData] = None diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py index 0a2e5f757..811600779 100644 --- a/src/dstack/_internal/server/background/__init__.py +++ b/src/dstack/_internal/server/background/__init__.py @@ -9,6 +9,9 @@ from dstack._internal.server.background.tasks.process_instances import ( process_instances, ) +from dstack._internal.server.background.tasks.process_placement_groups import ( + process_placement_groups, +) from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs from 
dstack._internal.server.background.tasks.process_runs import process_runs from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs @@ -45,5 +48,6 @@ def start_background_tasks() -> AsyncIOScheduler: process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 ) _scheduler.add_job(process_fleets, IntervalTrigger(seconds=15)) + _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30)) _scheduler.start() return _scheduler diff --git a/src/dstack/_internal/server/background/tasks/process_fleets.py b/src/dstack/_internal/server/background/tasks/process_fleets.py index ef92d5c3e..88c877eca 100644 --- a/src/dstack/_internal/server/background/tasks/process_fleets.py +++ b/src/dstack/_internal/server/background/tasks/process_fleets.py @@ -1,10 +1,10 @@ -from sqlalchemy import select +from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from dstack._internal.core.models.fleets import FleetStatus from dstack._internal.server.db import get_session_ctx -from dstack._internal.server.models import FleetModel +from dstack._internal.server.models import FleetModel, PlacementGroupModel from dstack._internal.server.services.fleets import ( fleet_model_to_fleet, is_fleet_empty, @@ -73,5 +73,16 @@ async def _autodelete_fleet(session: AsyncSession, fleet_model: FleetModel): fleet_model.status = FleetStatus.TERMINATED fleet_model.deleted = True fleet_model.last_processed_at = get_current_datetime() + await _mark_placement_groups_as_ready_for_deletion(session=session, fleet_model=fleet_model) await session.commit() logger.info("Fleet %s deleted", fleet_model.name) + + +async def _mark_placement_groups_as_ready_for_deletion( + session: AsyncSession, fleet_model: FleetModel +): + await session.execute( + update(PlacementGroupModel) + .where(PlacementGroupModel.fleet_id == fleet_model.id) + .values(fleet_deleted=True) + ) diff --git 
a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 45b796097..b3903cb31 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -12,7 +12,10 @@ from sqlalchemy.orm import joinedload, lazyload from dstack._internal import settings -from dstack._internal.core.backends import BACKENDS_WITH_CREATE_INSTANCE_SUPPORT +from dstack._internal.core.backends import ( + BACKENDS_WITH_CREATE_INSTANCE_SUPPORT, + BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT, +) from dstack._internal.core.backends.base.compute import ( DSTACK_WORKING_DIR, get_dstack_runner_version, @@ -29,6 +32,8 @@ run_shim_as_systemd_service, upload_envs, ) + +# FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute from dstack._internal.core.errors import BackendError, ProvisioningError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.fleets import InstanceGroupPlacement @@ -41,19 +46,27 @@ InstanceType, RemoteConnectionInfo, ) +from dstack._internal.core.models.placement import ( + PlacementGroup, + PlacementGroupConfiguration, + PlacementStrategy, +) from dstack._internal.core.models.profiles import ( - Profile, RetryEvent, TerminationPolicy, ) from dstack._internal.core.models.runs import ( JobProvisioningData, - Requirements, Retry, ) from dstack._internal.core.services.profiles import get_retry from dstack._internal.server.db import get_session_ctx -from dstack._internal.server.models import FleetModel, InstanceModel, ProjectModel +from dstack._internal.server.models import ( + FleetModel, + InstanceModel, + PlacementGroupModel, + ProjectModel, +) from dstack._internal.server.schemas.runner import HealthcheckResponse from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.fleets 
import fleet_model_to_fleet @@ -61,7 +74,16 @@ terminate_job_provisioning_data_instance, ) from dstack._internal.server.services.locking import get_locker -from dstack._internal.server.services.pools import get_instance_provisioning_data +from dstack._internal.server.services.placement import ( + get_fleet_placement_groups, + placement_group_model_to_placement_group, +) +from dstack._internal.server.services.pools import ( + get_instance_configuration, + get_instance_profile, + get_instance_provisioning_data, + get_instance_requirements, +) from dstack._internal.server.services.runner import client as runner_client from dstack._internal.server.services.runner.client import HealthStatus from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel @@ -138,9 +160,12 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel): await _terminate_idle_instance(instance) elif instance.status == InstanceStatus.PENDING: if instance.remote_connection_info is not None: - await _add_remote(instance=instance) + await _add_remote(instance) else: - await _create_instance(instance) + await _create_instance( + session=session, + instance=instance, + ) elif instance.status in ( InstanceStatus.PROVISIONING, InstanceStatus.IDLE, @@ -370,7 +395,7 @@ def _deploy_instance( return health, host_info -async def _create_instance(instance: InstanceModel) -> None: +async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None: if instance.last_retry_at is not None: last_retry = instance.last_retry_at.replace(tzinfo=datetime.timezone.utc) if get_current_datetime() < last_retry + timedelta(minutes=1): @@ -399,11 +424,9 @@ async def _create_instance(instance: InstanceModel) -> None: return try: - profile: Profile = Profile.__response__.parse_raw(instance.profile) - requirements: Requirements = Requirements.__response__.parse_raw(instance.requirements) - instance_configuration: InstanceConfiguration = ( - 
InstanceConfiguration.__response__.parse_raw(instance.instance_configuration) - ) + instance_configuration = get_instance_configuration(instance) + profile = get_instance_profile(instance) + requirements = get_instance_requirements(instance) except ValidationError as e: instance.status = InstanceStatus.TERMINATED instance.termination_reason = ( @@ -455,9 +478,36 @@ async def _create_instance(instance: InstanceModel) -> None: ) return + placement_groups = [] + if instance.fleet_id: + placement_groups = await get_fleet_placement_groups( + session=session, fleet_id=instance.fleet_id + ) + + instance_configuration = _patch_instance_configuration(instance) + for backend, instance_offer in offers: if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT: continue + if ( + instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT + and instance.fleet + and instance_configuration.placement_group_name + ): + placement_group_model = _create_placement_group_if_does_not_exist( + session=session, + fleet_model=instance.fleet, + placement_groups=placement_groups, + name=instance_configuration.placement_group_name, + backend=instance_offer.backend, + region=instance_offer.region, + ) + if placement_group_model is not None: + placement_group = placement_group_model_to_placement_group(placement_group_model) + pgpd = await run_async(backend.compute().create_placement_group, placement_group) + placement_group_model.provisioning_data = pgpd.json() + session.add(placement_group_model) + placement_groups.append(placement_group) logger.debug( "Trying %s in %s/%s for $%0.4f per hour", instance_offer.instance.name, @@ -489,6 +539,7 @@ async def _create_instance(instance: InstanceModel) -> None: instance.backend = backend.TYPE instance.region = instance_offer.region instance.price = instance_offer.price + instance.instance_configuration = instance_configuration.json() instance.job_provisioning_data = job_provisioning_data.json() instance.offer = instance_offer.json() 
instance.started_at = get_current_datetime() @@ -755,6 +806,48 @@ def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool: ) +def _patch_instance_configuration(instance: InstanceModel) -> InstanceConfiguration: + instance_configuration = get_instance_configuration(instance) + if instance.fleet is None: + return instance_configuration + + fleet = fleet_model_to_fleet(instance.fleet) + master_instance = instance.fleet.instances[0] + master_job_provisioning_data = get_instance_provisioning_data(master_instance) + if ( + fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER + and master_job_provisioning_data is not None + ): + instance_configuration.availability_zone = master_job_provisioning_data.availability_zone + + return instance_configuration + + +def _create_placement_group_if_does_not_exist( + session: AsyncSession, + fleet_model: FleetModel, + placement_groups: List[PlacementGroup], + name: str, + backend: BackendType, + region: str, +) -> Optional[PlacementGroupModel]: + for pg in placement_groups: + if pg.configuration.backend == backend and pg.configuration.region == region: + return None + placement_group_model = PlacementGroupModel( + name=name, + project=fleet_model.project, + fleet=fleet_model, + configuration=PlacementGroupConfiguration( + backend=backend, + region=region, + placement_strategy=PlacementStrategy.CLUSTER, + ).json(), + ) + session.add(placement_group_model) + return placement_group_model + + def _get_instance_idle_duration(instance: InstanceModel) -> datetime.timedelta: last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc) if instance.last_job_processed_at is not None: diff --git a/src/dstack/_internal/server/background/tasks/process_placement_groups.py b/src/dstack/_internal/server/background/tasks/process_placement_groups.py new file mode 100644 index 000000000..9cd751a90 --- /dev/null +++ b/src/dstack/_internal/server/background/tasks/process_placement_groups.py @@ -0,0 +1,97 @@ +from 
typing import List +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from dstack._internal.core.errors import PlacementGroupInUseError +from dstack._internal.server.db import get_session_ctx +from dstack._internal.server.models import PlacementGroupModel, ProjectModel +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.placement import placement_group_model_to_placement_group +from dstack._internal.server.utils.common import run_async +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +async def process_placement_groups(): + lock, lockset = get_locker().get_lockset(PlacementGroupModel.__tablename__) + async with get_session_ctx() as session: + async with lock: + res = await session.execute( + select(PlacementGroupModel) + .where( + PlacementGroupModel.fleet_deleted == True, + PlacementGroupModel.deleted == False, + PlacementGroupModel.id.not_in(lockset), + ) + .with_for_update(skip_locked=True) + ) + placement_group_models = res.scalars().all() + if len(placement_group_models) == 0: + return + placement_group_models_ids = [pg.id for pg in placement_group_models] + lockset.update(placement_group_models_ids) + try: + await _delete_placement_groups( + session=session, + placement_group_models_ids=placement_group_models_ids, + ) + finally: + lockset.difference_update(placement_group_models_ids) + + +async def _delete_placement_groups( + session: AsyncSession, + placement_group_models_ids: List[UUID], +): + res = await session.execute( + select(PlacementGroupModel) + .where( + PlacementGroupModel.id.in_(placement_group_models_ids), + PlacementGroupModel.deleted == False, + ) + 
.options(joinedload(PlacementGroupModel.project).joinedload(ProjectModel.backends)) + .execution_options(populate_existing=True) + ) + placement_group_models = res.unique().scalars().all() + for pg in placement_group_models: + await _delete_placement_group(pg) + await session.commit() + + +async def _delete_placement_group(placement_group_model: PlacementGroupModel): + logger.info("Deleting placement group %s", placement_group_model.name) + placement_group = placement_group_model_to_placement_group(placement_group_model) + if placement_group.provisioning_data is None: + logger.error( + "Failed to delete placement group %s. provisioning_data is None.", placement_group.name + ) + return + backend = await backends_services.get_project_backend_by_type( + project=placement_group_model.project, + backend_type=placement_group.provisioning_data.backend, + ) + if backend is None: + logger.error( + "Failed to delete placement group %s. Backend not available.", placement_group.name + ) + return + try: + await run_async(backend.compute().delete_placement_group, placement_group) + except PlacementGroupInUseError: + logger.info( + "Placement group %s is still in use. 
Skipping deletion for now.", placement_group.name + ) + return + except Exception: + logger.exception("Got exception when deleting placement group %s", placement_group.name) + return + + placement_group_model.deleted = True + placement_group_model.deleted_at = get_current_datetime() + logger.info("Deleted placement group %s", placement_group_model.name) diff --git a/src/dstack/_internal/server/migrations/versions/a7b46c073fa1_add_placementgroupmodel.py b/src/dstack/_internal/server/migrations/versions/a7b46c073fa1_add_placementgroupmodel.py new file mode 100644 index 000000000..291b0b7e6 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/a7b46c073fa1_add_placementgroupmodel.py @@ -0,0 +1,58 @@ +"""Add PlacementGroupModel + +Revision ID: a7b46c073fa1 +Revises: e3b7db07727f +Create Date: 2024-09-25 13:52:28.701586 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "a7b46c073fa1" +down_revision = "e3b7db07727f" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "placement_groups", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column( + "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.Column("fleet_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("fleet_deleted", sa.Boolean(), nullable=False), + sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), + sa.Column( + "last_processed_at", dstack._internal.server.models.NaiveDateTime(), nullable=False + ), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", dstack._internal.server.models.NaiveDateTime(), nullable=True), + sa.Column("configuration", sa.Text(), nullable=False), + sa.Column("provisioning_data", sa.Text(), nullable=True), + sa.ForeignKeyConstraint( + ["fleet_id"], ["fleets.id"], name=op.f("fk_placement_groups_fleet_id_fleets") + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + name=op.f("fk_placement_groups_project_id_projects"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_placement_groups")), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("placement_groups") + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 9d603b130..35ddd5d8e 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -561,3 +561,29 @@ class VolumeModel(BaseModel): Column("volume_id", ForeignKey("volumes.id"), primary_key=True), Column("instace_id", ForeignKey("instances.id"), primary_key=True), ) + + +class PlacementGroupModel(BaseModel): + __tablename__ = "placement_groups" + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + name: Mapped[str] = mapped_column(String(100)) + + project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) + project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id]) + + fleet_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("fleets.id")) + fleet: Mapped["FleetModel"] = relationship(foreign_keys=[fleet_id]) + fleet_deleted: Mapped[bool] = mapped_column(Boolean, default=False) + + created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + last_processed_at: Mapped[datetime] = mapped_column( + NaiveDateTime, default=get_current_datetime + ) + deleted: Mapped[bool] = mapped_column(Boolean, default=False) + deleted_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) + + configuration: Mapped[str] = mapped_column(Text) + provisioning_data: Mapped[Optional[str]] = mapped_column(Text) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 435e5f70a..718cf9d93 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -1,3 +1,5 @@ +import random +import string import uuid from datetime import timezone from typing import List, Optional, Union @@ -17,6 +19,7 @@ Fleet, FleetSpec, FleetStatus, + 
InstanceGroupPlacement, SSHHostParams, SSHParams, ) @@ -162,6 +165,10 @@ async def create_fleet( ) fleet_model.instances.append(instances_model) else: + placement_group_name = _get_placement_group_name( + project=project, + fleet_spec=spec, + ) for i in range(_get_fleet_nodes_to_provision(spec)): instance_model = await create_fleet_instance_model( session=session, @@ -169,6 +176,7 @@ async def create_fleet( user=user, pool=pool, spec=spec, + placement_group_name=placement_group_name, instance_num=i, ) fleet_model.instances.append(instance_model) @@ -182,6 +190,7 @@ async def create_fleet_instance_model( user: UserModel, pool: PoolModel, spec: FleetSpec, + placement_group_name: Optional[str], instance_num: int, ) -> InstanceModel: profile = spec.merged_profile @@ -199,6 +208,7 @@ async def create_fleet_instance_model( requirements=requirements, instance_name=f"{spec.configuration.name}-{instance_num}", instance_num=instance_num, + placement_group_name=placement_group_name, ) return instance_model @@ -422,3 +432,18 @@ def _terminate_fleet_instances(fleet_model: FleetModel, instance_nums: Optional[ instance.deleted = True else: instance.status = InstanceStatus.TERMINATING + + +def _get_placement_group_name( + project: ProjectModel, + fleet_spec: FleetSpec, +) -> Optional[str]: + if fleet_spec.configuration.placement != InstanceGroupPlacement.CLUSTER: + return None + # A random suffix to avoid clashing with to-be-deleted placement groups left by old fleets + suffix = _generate_random_placement_group_suffix() + return f"{project.name}-{fleet_spec.configuration.name}-{suffix}-pg" + + +def _generate_random_placement_group_suffix(length: int = 8) -> str: + return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length)) diff --git a/src/dstack/_internal/server/services/placement.py b/src/dstack/_internal/server/services/placement.py new file mode 100644 index 000000000..70d138507 --- /dev/null +++ 
b/src/dstack/_internal/server/services/placement.py @@ -0,0 +1,53 @@ +from typing import Optional +from uuid import UUID + +from typing import List +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.placement import ( + PlacementGroup, + PlacementGroupConfiguration, + PlacementGroupProvisioningData, +) +from dstack._internal.server.models import PlacementGroupModel + + +async def get_fleet_placement_groups( + session: AsyncSession, + fleet_id: UUID, +) -> List[PlacementGroup]: + res = await session.execute( + select(PlacementGroupModel).where(PlacementGroupModel.fleet_id == fleet_id) + ) + placement_groups = res.scalars().all() + return [placement_group_model_to_placement_group(pg) for pg in placement_groups] + + +def placement_group_model_to_placement_group( + placement_group_model: PlacementGroupModel, +) -> PlacementGroup: + configuration = get_placement_group_configuration(placement_group_model) + provisioning_data = get_placement_group_provisioning_data(placement_group_model) + return PlacementGroup( + name=placement_group_model.name, + project_name=placement_group_model.project.name, + configuration=configuration, + provisioning_data=provisioning_data, + ) + + +def get_placement_group_configuration( + placement_group_model: PlacementGroupModel, +) -> PlacementGroupConfiguration: + return PlacementGroupConfiguration.__response__.parse_raw(placement_group_model.configuration) + + +def get_placement_group_provisioning_data( + placement_group_model: PlacementGroupModel, +) -> Optional[PlacementGroupProvisioningData]: + if placement_group_model.provisioning_data is None: + return None + return PlacementGroupProvisioningData.__response__.parse_raw( + placement_group_model.provisioning_data + ) diff --git a/src/dstack/_internal/server/services/pools.py index 491cd7b6e..98be08070 100644 --- a/src/dstack/_internal/server/services/pools.py +++
b/src/dstack/_internal/server/services/pools.py @@ -263,6 +263,18 @@ def get_instance_offer(instance_model: InstanceModel) -> Optional[InstanceOfferW return InstanceOfferWithAvailability.__response__.parse_raw(instance_model.offer) +def get_instance_configuration(instance_model: InstanceModel) -> InstanceConfiguration: + return InstanceConfiguration.__response__.parse_raw(instance_model.instance_configuration) + + +def get_instance_profile(instance_model: InstanceModel) -> Profile: + return Profile.__response__.parse_raw(instance_model.profile) + + +def get_instance_requirements(instance_model: InstanceModel) -> Requirements: + return Requirements.__response__.parse_raw(instance_model.requirements) + + async def generate_instance_name( session: AsyncSession, project: ProjectModel, @@ -582,6 +594,7 @@ async def create_instance_model( requirements: Requirements, instance_name: str, instance_num: int, + placement_group_name: Optional[str], ) -> InstanceModel: instance = InstanceModel( id=uuid.uuid4(), @@ -608,13 +621,14 @@ async def create_instance_model( instance_config = InstanceConfiguration( project_name=project.name, instance_name=instance_name, + user=user.name, instance_id=str(instance.id), ssh_keys=[project_ssh_key], + placement_group_name=placement_group_name, job_docker_config=DockerConfig( image=dstack_default_image, registry_auth=None, ), - user=user.name, ) instance.instance_configuration = instance_config.json() return instance diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 418a1b100..6b90488c2 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -20,6 +20,11 @@ InstanceType, Resources, ) +from dstack._internal.core.models.placement import ( + PlacementGroupConfiguration, + PlacementGroupProvisioningData, + PlacementStrategy, +) from dstack._internal.core.models.profiles import ( DEFAULT_POOL_NAME, 
DEFAULT_POOL_TERMINATION_IDLE_TIME, @@ -50,6 +55,7 @@ GatewayModel, InstanceModel, JobModel, + PlacementGroupModel, PoolModel, ProjectModel, RepoModel, @@ -561,6 +567,56 @@ def get_volume_provisioning_data( ) +async def create_placement_group( + session: AsyncSession, + project: ProjectModel, + fleet: FleetModel, + name: str = "test-pg", + created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + configuration: Optional[PlacementGroupConfiguration] = None, + provisioning_data: Optional[PlacementGroupProvisioningData] = None, + fleet_deleted: Optional[bool] = False, + deleted: Optional[bool] = False, + deleted_at: Optional[datetime] = None, +) -> PlacementGroupModel: + if configuration is None: + configuration = get_placement_group_configuration() + if provisioning_data is None: + provisioning_data = get_placement_group_provisioning_data() + pg = PlacementGroupModel( + project=project, + fleet=fleet, + name=name, + created_at=created_at, + configuration=configuration.json(), + provisioning_data=provisioning_data.json(), + fleet_deleted=fleet_deleted, + deleted=deleted, + deleted_at=deleted_at, + ) + session.add(pg) + await session.commit() + return pg + + +def get_placement_group_configuration( + backend: BackendType = BackendType.AWS, + region: str = "eu-central-1", + strategy: PlacementStrategy = PlacementStrategy.CLUSTER, +) -> PlacementGroupConfiguration: + return PlacementGroupConfiguration( + backend=backend, + region=region, + placement_strategy=strategy, + ) + + +def get_placement_group_provisioning_data( + backend: BackendType = BackendType.AWS, +) -> PlacementGroupProvisioningData: + return PlacementGroupProvisioningData(backend=backend) + + def get_private_key_string() -> str: return """ -----BEGIN RSA PRIVATE KEY----- diff --git a/src/tests/_internal/server/background/tasks/test_process_placement_groups.py b/src/tests/_internal/server/background/tasks/test_process_placement_groups.py new file mode 100644 index 000000000..becc8ec7c --- 
/dev/null +++ b/src/tests/_internal/server/background/tasks/test_process_placement_groups.py @@ -0,0 +1,42 @@ +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.server.background.tasks.process_placement_groups import ( + process_placement_groups, +) +from dstack._internal.server.testing.common import ( + create_fleet, + create_placement_group, + create_project, +) + + +class TestProcessPlacementGroups: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_deletes_placement_groups(self, test_db, session: AsyncSession): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + ) + placement_group1 = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="test1-pg", + ) + placement_group2 = await create_placement_group( + session=session, project=project, fleet=fleet, name="test2-pg", fleet_deleted=True + ) + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + aws_mock = Mock() + m.return_value = aws_mock + await process_placement_groups() + aws_mock.compute.return_value.delete_placement_group.assert_called_once() + await session.refresh(placement_group1) + await session.refresh(placement_group2) + assert not placement_group1.deleted + assert placement_group2.deleted