diff --git a/xpk.py b/xpk.py index df66664b..62eba335 100644 --- a/xpk.py +++ b/xpk.py @@ -98,6 +98,7 @@ hostNetwork: true dnsPolicy: ClusterFirstWithHostNet terminationGracePeriodSeconds: {args.termination_grace_period_seconds} + {gpu_env} containers: {container} volumeMounts: @@ -2198,6 +2199,42 @@ def create_machine_label(accelerator_type, system) -> str: return f"{AcceleratorTypeToAcceleratorCharacteristics[accelerator_type].machine_label}: {system.topology}" return "" +def create_gpu_env(accelerator_type, args) -> str: + """Generates gpu env entries for workload_create_yaml. + + Args: + accelerator_type: type of accelerator. + system: system characteristics. + + Returns: + The machine label. + """ + if accelerator_type == AcceleratorType['TPU']: + return "" + + env = """env: + - name: GPU_WORKER_ID + valueFrom: + fieldRef: + fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index'] + - name: GPU_WORKER_HOSTNAMES + value: {gpu_worker_hostnames} + - name: MEGASCALE_NUM_SLICES + value: {args.num_slice} + - name: MEGASCALE_SLICE_ID + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/job-index'] + - name: MEGASCALE_COORDINATOR_ADDRESS + value: {args.workload}-slice-job-0-0.{args.workload} +""" + gpu_worker_hostnames = [] + for i in range(int(args.num_slice)): + gpu_worker_hostnames.append(f"slice-job-{i}-{args.workload}") + + env.format(args=args, gpu_worker_hostnames=",".join(gpu_worker_hostnames)) + return env + def get_system_characteristics(args) -> tuple[SystemCharacteristics|None, int]: """Get system characteristics based on user provided arguments. @@ -2277,7 +2314,8 @@ def workload_create(args) -> int: container=container, accelerator_label=create_accelerator_label(system.accelerator_type, system), machine_label=create_machine_label(system.accelerator_type, system), - resource_type=resource_type) + resource_type=resource_type, + gpu_env=create_gpu_env(system.accelerator_type, args)) tmp = write_temporary_file(yml_string) command = f'kubectl apply -f {str(tmp.file.name)}'