diff --git a/pkg/apiserver/controller/job/get.go b/pkg/apiserver/controller/job/get.go index 419ac55e0..0f44c6469 100644 --- a/pkg/apiserver/controller/job/get.go +++ b/pkg/apiserver/controller/job/get.go @@ -479,10 +479,32 @@ func getPaddleJobRoleAndIndex(name string, annotations map[string]string) (schem return taskRole, taskIndex } +func getPyTorchJobRoleAndIndex(name string, annotations map[string]string) (schema.MemberRole, int) { + taskRole, taskIndex := schema.RoleWorker, 0 + if annotations["volcano.sh/task-spec"] == "master" { + taskRole = schema.RoleMaster + } else { + // worker is named with format: xxxx-worker-0 + items := strings.Split(name, "-") + if len(items) > 0 { + taskIndex, _ = strconv.Atoi(items[len(items)-1]) + } + } + return taskRole, taskIndex +} + // getMPIJobRoleAndIndex returns the runtime info of mpi job func getMPIJobRoleAndIndex(name string, annotations map[string]string) (schema.MemberRole, int) { taskRole, taskIndex := schema.RoleWorker, 0 - // TODO: support mpi job + if annotations["volcano.sh/task-spec"] == "master" { + taskRole = schema.RoleMaster + } else { + // worker is named with format: xxxx-worker-0 + items := strings.Split(name, "-") + if len(items) > 0 { + taskIndex, _ = strconv.Atoi(items[len(items)-1]) + } + } return taskRole, taskIndex } @@ -513,6 +535,8 @@ func getTaskRoleAndIndex(task model.JobTask, kgv *schema.KindGroupVersion) (stri taskRole, taskIndex = getAITrainingJobRoleAndIndex(task.Name, task.Annotations) case schema.MPIKindGroupVersion.String(): taskRole, taskIndex = getMPIJobRoleAndIndex(task.Name, task.Annotations) + case schema.PyTorchKindGroupVersion.String(): + taskRole, taskIndex = getPyTorchJobRoleAndIndex(task.Name, task.Annotations) default: find = false }