Skip to content

Commit

Permalink
update cluster_set_crd_yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
michelle-yooh committed Dec 19, 2023
1 parent 1c0b49b commit 92f215d
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions xpk.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@
name: {cluster_hardware_name}
spec:
nodeLabels:
cloud.google.com/gke-tpu-accelerator: {system.gke_accelerator}
cloud.google.com/gke-tpu-topology: {system.topology}
{node_labels}
---
apiVersion: kueue.x-k8s.io/v1beta1
kind: ClusterQueue
Expand Down Expand Up @@ -1219,7 +1218,7 @@ def run_gke_node_pool_create_command(args, system_characteristics) -> int:
f' --region={zone_to_region(args.zone)}'
f' --num-nodes={system_characteristics.vms_per_slice}'
f' --machine-type={system_characteristics.gce_machine_type}'
f' --tpu-topology={system_characteristics.topology}'
# f' --tpu-topology={system_characteristics.topology}'
f' --host-maintenance-interval={args.host_maintenance_interval}'
f' {capacity_args}'
' --scopes=storage-full,gke-default'
Expand Down Expand Up @@ -1386,6 +1385,7 @@ def enable_kueue_crds(args, system) -> int:
system=system,
cluster_hardware_name=cluster_hardware_name,
total_chips=total_chips,
node_labels=create_node_labels(args, system),
resource_type=get_resource_type(args)
)
tmp = write_temporary_file(yml_string)
Expand Down Expand Up @@ -1846,6 +1846,18 @@ def setup_docker_image(args) -> tuple[int, str]:

return 0, docker_image

def create_node_labels(args, system) -> str:
if args.device_type in TpuUserFacingNameToSystemCharacteristics:
return """cloud.google.com/gke-tpu-accelerator: {gke_accelerator}
cloud.google.com/gke-tpu-topology: {topology}
""".format(gke_accelerator=system.gke_accelerator, topology=system.topology)
elif args.device_type in GpuUserFacingNameToSystemCharacteristics:
return """cloud.google.com/gke-accelerator: {gke_accelerator}
cloud.google.com/gce-machine-type: {gce_machine_type}
""".format(gke_accelerator=system.gke_accelerator, gce_machine_type=system.gce_machine_type)
else:
raise ValueError("Unknown device type")


def create_node_selector(args, system) -> str:
if args.device_type in TpuUserFacingNameToSystemCharacteristics:
Expand Down

0 comments on commit 92f215d

Please sign in to comment.