From b3f7efd7f84667d600f3a8f3b50660a622ffe376 Mon Sep 17 00:00:00 2001 From: dongzezhao Date: Thu, 4 Jan 2024 20:25:48 +0800 Subject: [PATCH] feat(job): update kubeflow build (#1303) * feat(job): update kubeflow build * update PodTemplate build * update error handle * remove BuildPodSpec * set limit with request --- pkg/job/api/job.go | 9 + .../job/aitraining/kube_aitraining_job.go | 3 +- pkg/job/runtime_v2/job/mpi/kube_mpi_job.go | 22 +- .../runtime_v2/job/mpi/kube_mpi_job_test.go | 29 +- .../runtime_v2/job/paddle/kube_paddle_job.go | 12 +- .../job/paddle/kubeflow_paddle_job.go | 18 +- .../job/pytorch/kube_pytorch_job.go | 22 +- pkg/job/runtime_v2/job/ray/kube_ray_job.go | 16 +- .../job/tensorflow/kube_tensorflow_job.go | 18 +- .../job/util/kuberuntime/kube_job.go | 83 ------ .../job/util/kuberuntime/kube_job_builder.go | 263 ++++++++++++++++++ .../util/kuberuntime/kube_job_builder_test.go | 247 ++++++++++++++++ .../job/util/kuberuntime/kube_job_test.go | 213 -------------- 13 files changed, 578 insertions(+), 377 deletions(-) create mode 100644 pkg/job/runtime_v2/job/util/kuberuntime/kube_job_builder.go create mode 100644 pkg/job/runtime_v2/job/util/kuberuntime/kube_job_builder_test.go diff --git a/pkg/job/api/job.go b/pkg/job/api/job.go index 2f5f3ec84..5544bf5d8 100644 --- a/pkg/job/api/job.go +++ b/pkg/job/api/job.go @@ -134,6 +134,15 @@ func (pfj *PFJob) GetID() string { return pfj.ID } +func (pfj *PFJob) GetMember(roleName schema.MemberRole) schema.Member { + for _, member := range pfj.Tasks { + if member.Role == roleName { + return member + } + } + return schema.Member{} +} + type JobSyncInfo struct { ID string Namespace string diff --git a/pkg/job/runtime_v2/job/aitraining/kube_aitraining_job.go b/pkg/job/runtime_v2/job/aitraining/kube_aitraining_job.go index 955bced98..01d526852 100644 --- a/pkg/job/runtime_v2/job/aitraining/kube_aitraining_job.go +++ b/pkg/job/runtime_v2/job/aitraining/kube_aitraining_job.go @@ -138,7 +138,8 @@ func (pj *KubeAITrainingJob) patchReplicaSpec(rs *v1.ReplicaSpec, task pfschema. rs.Replicas = &replicas } // patch fs - return kuberuntime.BuildPodTemplateSpec(&rs.Template, jobID, &task) + kuberuntime.NewPodTemplateSpecBuilder(&rs.Template, jobID).Build(task) + return nil } func (pj *KubeAITrainingJob) builtinAITrainingJob(pdj *v1.TrainingJobSpec, job *api.PFJob) error { diff --git a/pkg/job/runtime_v2/job/mpi/kube_mpi_job.go b/pkg/job/runtime_v2/job/mpi/kube_mpi_job.go index c6cc8d9af..013119413 100644 --- a/pkg/job/runtime_v2/job/mpi/kube_mpi_job.go +++ b/pkg/job/runtime_v2/job/mpi/kube_mpi_job.go @@ -92,10 +92,8 @@ func (mj *KubeMPIJob) builtinMPIJobSpec(mpiJobSpec *mpiv1.MPIJobSpec, job *api.P replicaType = mpiv1.MPIReplicaTypeWorker } replicaSpec := mpiJobSpec.MPIReplicaSpecs[replicaType] - if err := kuberuntime.KubeflowReplicaSpec(replicaSpec, job.ID, &task); err != nil { - log.Errorf("build %s RepilcaSpec for %s failed, err: %v", replicaType, mj.String(jobName), err) - return err - } + // build kubeflowReplicaSpec + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, replicaSpec).ReplicaSpec(task) // calculate job minResources taskResources, _ := resources.NewResourceFromMap(task.Flavour.ToMap()) taskResources.Multi(task.Replicas) @@ -103,7 +101,9 @@ func (mj *KubeMPIJob) builtinMPIJobSpec(mpiJobSpec *mpiv1.MPIJobSpec, job *api.P } // set RunPolicy resourceList := k8s.NewResourceList(minResources) - return kuberuntime.KubeflowRunPolicy(&mpiJobSpec.RunPolicy, &resourceList, job.Conf.GetQueueName(), job.Conf.GetPriority()) + kuberuntime.NewKubeflowJobBuilder(job.ID, &mpiJobSpec.RunPolicy, nil). + RunPolicy(&resourceList, job.Conf.GetQueueName(), job.Conf.GetPriority()) + return nil } // customMPIJobSpec set custom MPIJob Spec @@ -111,17 +111,19 @@ func (mj *KubeMPIJob) customMPIJobSpec(mpiJobSpec *mpiv1.MPIJobSpec, job *api.PF jobName := job.NamespacedName() log.Debugf("patch %s spec:%#v", mj.String(jobName), mpiJobSpec) // patch metadata - ps, find := mpiJobSpec.MPIReplicaSpecs[mpiv1.MPIReplicaTypeLauncher] - if find && ps != nil { - kuberuntime.BuildTaskMetadata(&ps.Template.ObjectMeta, job.ID, &pfschema.Conf{}) + master, find := mpiJobSpec.MPIReplicaSpecs[mpiv1.MPIReplicaTypeLauncher] + if find && master != nil { + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, master).ReplicaSpec(job.GetMember(pfschema.RoleMaster)) } worker, find := mpiJobSpec.MPIReplicaSpecs[mpiv1.MPIReplicaTypeWorker] if find && worker != nil { - kuberuntime.BuildTaskMetadata(&worker.Template.ObjectMeta, job.ID, &pfschema.Conf{}) + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, worker).ReplicaSpec(job.GetMember(pfschema.RoleWorker)) } // TODO: patch mpi job from user // check RunPolicy - return kuberuntime.KubeflowRunPolicy(&mpiJobSpec.RunPolicy, nil, job.Conf.GetQueueName(), job.Conf.GetPriority()) + kuberuntime.NewKubeflowJobBuilder(job.ID, &mpiJobSpec.RunPolicy, nil). + RunPolicy(nil, job.Conf.GetQueueName(), job.Conf.GetPriority()) + return nil } func (mj *KubeMPIJob) AddEventListener(ctx context.Context, listenerType string, jobQueue workqueue.RateLimitingInterface, listener interface{}) error { diff --git a/pkg/job/runtime_v2/job/mpi/kube_mpi_job_test.go b/pkg/job/runtime_v2/job/mpi/kube_mpi_job_test.go index 85562287c..be2cedb97 100644 --- a/pkg/job/runtime_v2/job/mpi/kube_mpi_job_test.go +++ b/pkg/job/runtime_v2/job/mpi/kube_mpi_job_test.go @@ -269,33 +269,6 @@ func TestMPIJob_CreateJob(t *testing.T) { expectErr: "err", wantErr: false, }, - { - caseName: "Member absent", - jobObj: &api.PFJob{ - Name: "test-mpi-job", - ID: uuid.GenerateIDWithLength("job", 5), - Namespace: "default", - JobType: pfschema.TypeDistributed, - JobMode: pfschema.EnvJobModePS, - Framework: pfschema.FrameworkMPI, - Conf: pfschema.Conf{ - Name: "normal", - Command: "sleep 200", - Image: "mockImage", - }, - Tasks: []pfschema.Member{ - { - Replicas: 1, - Role: pfschema.RoleMaster, - Conf: pfschema.Conf{ - Flavour: pfschema.Flavour{Name: "", ResourceInfo: pfschema.ResourceInfo{CPU: "-1", Mem: "4Gi"}}, - }, - }, - }, - }, - expectErr: "negative resources not permitted: map[cpu:-1 memory:4Gi]", - wantErr: true, - }, { caseName: "flavour wrong", jobObj: &api.PFJob{ @@ -321,7 +294,7 @@ func TestMPIJob_CreateJob(t *testing.T) { }, }, expectErr: "quantities must match the regular expression '^([+-]?[0-9.]+)([eEinumkKMGTP]*[-+]?[0-9]*)$'", - wantErr: true, + wantErr: false, }, { caseName: "ExtensionTemplate", diff --git a/pkg/job/runtime_v2/job/paddle/kube_paddle_job.go b/pkg/job/runtime_v2/job/paddle/kube_paddle_job.go index cee294b41..423f5d57e 100644 --- a/pkg/job/runtime_v2/job/paddle/kube_paddle_job.go +++ b/pkg/job/runtime_v2/job/paddle/kube_paddle_job.go @@ -123,9 +123,9 @@ func patchCustomPaddleTask(rSpec *paddlejobv1.ResourceSpec, task pfschema.Member if rSpec.Replicas <= 0 { rSpec.Replicas = kuberuntime.DefaultReplicas } - kuberuntime.BuildTaskMetadata(&rSpec.Template.ObjectMeta, jobID, &task.Conf) - // build pod spec - return kuberuntime.BuildPodSpec(&rSpec.Template.Spec, task) + // build pod template + kuberuntime.NewPodTemplateSpecBuilder(&rSpec.Template, jobID).Build(task) + return nil } func (pj *KubePaddleJob) buildSchedulingPolicy(pdjSpec *paddlejobv1.PaddleJobSpec, jobConf pfschema.PFJobConf) error { @@ -229,9 +229,9 @@ func (pj *KubePaddleJob) patchPaddleTask(resourceSpec *paddlejobv1.ResourceSpec, if task.Name == "" { task.Name = uuid.GenerateIDWithLength(jobID, 3) } - kuberuntime.BuildTaskMetadata(&resourceSpec.Template.ObjectMeta, jobID, &task.Conf) - // build pod spec - return kuberuntime.BuildPodSpec(&resourceSpec.Template.Spec, task) + // build pod template spec + kuberuntime.NewPodTemplateSpecBuilder(&resourceSpec.Template, jobID).Build(task) + return nil } func (pj *KubePaddleJob) AddEventListener(ctx context.Context, listenerType string, jobQueue workqueue.RateLimitingInterface, listener interface{}) error { diff --git a/pkg/job/runtime_v2/job/paddle/kubeflow_paddle_job.go b/pkg/job/runtime_v2/job/paddle/kubeflow_paddle_job.go index 69bbd4bc2..fa2bc1c66 100644 --- a/pkg/job/runtime_v2/job/paddle/kubeflow_paddle_job.go +++ b/pkg/job/runtime_v2/job/paddle/kubeflow_paddle_job.go @@ -100,10 +100,8 @@ func (pj *KubeKFPaddleJob) builtinPaddleJobSpec(jobSpec *paddlev1.PaddleJobSpec, if !ok { return fmt.Errorf("replica type %s for %s is not supported", replicaType, pj.String(jobName)) } - if err := kuberuntime.KubeflowReplicaSpec(replicaSpec, job.ID, &task); err != nil { - log.Errorf("build %s RepilcaSpec for %s failed, err: %v", replicaType, pj.String(jobName), err) - return err - } + // build kubeflowReplicaSpec + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, replicaSpec).ReplicaSpec(task) // calculate job minResources taskResources, err := resources.NewResourceFromMap(task.Flavour.ToMap()) if err != nil { @@ -115,7 +113,9 @@ func (pj *KubeKFPaddleJob) builtinPaddleJobSpec(jobSpec *paddlev1.PaddleJobSpec, } // set RunPolicy resourceList := k8s.NewResourceList(minResources) - return kuberuntime.KubeflowRunPolicy(&jobSpec.RunPolicy, &resourceList, job.Conf.GetQueueName(), job.Conf.GetPriority()) + kuberuntime.NewKubeflowJobBuilder(job.ID, &jobSpec.RunPolicy, nil). + RunPolicy(&resourceList, job.Conf.GetQueueName(), job.Conf.GetPriority()) + return nil } func (pj *KubeKFPaddleJob) customPaddleJobSpec(jobSpec *paddlev1.PaddleJobSpec, job *api.PFJob) error { @@ -127,15 +127,17 @@ func (pj *KubeKFPaddleJob) customPaddleJobSpec(jobSpec *paddlev1.PaddleJobSpec, // patch metadata ps, find := jobSpec.PaddleReplicaSpecs[paddlev1.PaddleJobReplicaTypeMaster] if find && ps != nil { - kuberuntime.BuildTaskMetadata(&ps.Template.ObjectMeta, job.ID, &pfschema.Conf{}) + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, ps).ReplicaSpec(job.GetMember(pfschema.RoleMaster)) } worker, find := jobSpec.PaddleReplicaSpecs[paddlev1.PaddleJobReplicaTypeWorker] if find && worker != nil { - kuberuntime.BuildTaskMetadata(&worker.Template.ObjectMeta, job.ID, &pfschema.Conf{}) + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, worker).ReplicaSpec(job.GetMember(pfschema.RoleWorker)) } // TODO: patch paddle job from user // check RunPolicy - return kuberuntime.KubeflowRunPolicy(&jobSpec.RunPolicy, nil, job.Conf.GetQueueName(), job.Conf.GetPriority()) + kuberuntime.NewKubeflowJobBuilder(job.ID, &jobSpec.RunPolicy, nil). + RunPolicy(nil, job.Conf.GetQueueName(), job.Conf.GetPriority()) + return nil } func (pj *KubeKFPaddleJob) AddEventListener(ctx context.Context, listenerType string, jobQueue workqueue.RateLimitingInterface, listener interface{}) error { diff --git a/pkg/job/runtime_v2/job/pytorch/kube_pytorch_job.go b/pkg/job/runtime_v2/job/pytorch/kube_pytorch_job.go index 955ff041b..292b77f60 100644 --- a/pkg/job/runtime_v2/job/pytorch/kube_pytorch_job.go +++ b/pkg/job/runtime_v2/job/pytorch/kube_pytorch_job.go @@ -101,10 +101,8 @@ func (pj *KubePyTorchJob) builtinPyTorchJobSpec(torchJobSpec *pytorchv1.PyTorchJ if !ok { return fmt.Errorf("replica type %s for %s is not supported", replicaType, pj.String(jobName)) } - if err := kuberuntime.KubeflowReplicaSpec(replicaSpec, job.ID, &task); err != nil { - log.Errorf("build %s RepilcaSpec for %s failed, err: %v", replicaType, pj.String(jobName), err) - return err - } + // build kubeflowReplicaSpec + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, replicaSpec).ReplicaSpec(task) // calculate job minResources taskResources, err := resources.NewResourceFromMap(task.Flavour.ToMap()) if err != nil { @@ -116,7 +114,9 @@ func (pj *KubePyTorchJob) builtinPyTorchJobSpec(torchJobSpec *pytorchv1.PyTorchJ } // set RunPolicy resourceList := k8s.NewResourceList(minResources) - return kuberuntime.KubeflowRunPolicy(&torchJobSpec.RunPolicy, &resourceList, job.Conf.GetQueueName(), job.Conf.GetPriority()) + kuberuntime.NewKubeflowJobBuilder(job.ID, &torchJobSpec.RunPolicy, nil). + RunPolicy(&resourceList, job.Conf.GetQueueName(), job.Conf.GetPriority()) + return nil } // customPyTorchJobSpec set custom PyTorchJob Spec @@ -127,17 +127,19 @@ func (pj *KubePyTorchJob) customPyTorchJobSpec(torchJobSpec *pytorchv1.PyTorchJo jobName := job.NamespacedName() log.Debugf("patch %s spec:%#v", pj.String(jobName), torchJobSpec) // patch metadata - ps, find := torchJobSpec.PyTorchReplicaSpecs[pytorchv1.PyTorchReplicaTypeMaster] - if find && ps != nil { - kuberuntime.BuildTaskMetadata(&ps.Template.ObjectMeta, job.ID, &pfschema.Conf{}) + master, find := torchJobSpec.PyTorchReplicaSpecs[pytorchv1.PyTorchReplicaTypeMaster] + if find && master != nil { + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, master).ReplicaSpec(job.GetMember(pfschema.RoleMaster)) } worker, find := torchJobSpec.PyTorchReplicaSpecs[pytorchv1.PyTorchReplicaTypeWorker] if find && worker != nil { - kuberuntime.BuildTaskMetadata(&worker.Template.ObjectMeta, job.ID, &pfschema.Conf{}) + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, worker).ReplicaSpec(job.GetMember(pfschema.RoleWorker)) } // TODO: patch pytorch job from user // check RunPolicy - return kuberuntime.KubeflowRunPolicy(&torchJobSpec.RunPolicy, nil, job.Conf.GetQueueName(), job.Conf.GetPriority()) + kuberuntime.NewKubeflowJobBuilder(job.ID, &torchJobSpec.RunPolicy, nil). + RunPolicy(nil, job.Conf.GetQueueName(), job.Conf.GetPriority()) + return nil } func (pj *KubePyTorchJob) AddEventListener(ctx context.Context, listenerType string, jobQueue workqueue.RateLimitingInterface, listener interface{}) error { diff --git a/pkg/job/runtime_v2/job/ray/kube_ray_job.go b/pkg/job/runtime_v2/job/ray/kube_ray_job.go index dd3244c5d..5124a2ba7 100644 --- a/pkg/job/runtime_v2/job/ray/kube_ray_job.go +++ b/pkg/job/runtime_v2/job/ray/kube_ray_job.go @@ -143,10 +143,7 @@ func (rj *KubeRayJob) buildHeadPod(rayJobSpec *rayV1alpha1.RayJobSpec, jobID str task.Command = "" task.Args = []string{} // Template - if err := kuberuntime.BuildPodTemplateSpec(&headGroupSpec.Template, jobID, &task); err != nil { - log.Errorf("build head pod spec failed, err:%v", err) - return err - } + kuberuntime.NewPodTemplateSpecBuilder(&headGroupSpec.Template, jobID).Build(task) // patch queue name headGroupSpec.Template.Labels[pfschema.QueueLabelKey] = task.QueueName headGroupSpec.Template.Annotations[pfschema.QueueLabelKey] = task.QueueName @@ -217,10 +214,7 @@ func (rj *KubeRayJob) buildWorkerPod(rayJobSpec *rayV1alpha1.RayJobSpec, jobID s task.Command = "" task.Args = []string{} // Template - if err := kuberuntime.BuildPodTemplateSpec(&worker.Template, jobID, &task); err != nil { - log.Errorf("build head pod spec failed, err: %v", err) - return err - } + kuberuntime.NewPodTemplateSpecBuilder(&worker.Template, jobID).Build(task) // patch queue name worker.Template.Labels[pfschema.QueueLabelKey] = task.QueueName worker.Template.Annotations[pfschema.QueueLabelKey] = task.QueueName @@ -266,9 +260,11 @@ func (rj *KubeRayJob) customRayJobSpec(rayJobSpec *rayV1alpha1.RayJobSpec, job * jobName := job.NamespacedName() log.Debugf("patch %s spec:%#v", rj.String(jobName), rayJobSpec) // patch metadata - kuberuntime.BuildTaskMetadata(&rayJobSpec.RayClusterSpec.HeadGroupSpec.Template.ObjectMeta, job.ID, &pfschema.Conf{}) + kuberuntime.NewPodTemplateSpecBuilder(&rayJobSpec.RayClusterSpec.HeadGroupSpec.Template, job.ID). + Build(job.GetMember(pfschema.RoleMaster)) for i := range rayJobSpec.RayClusterSpec.WorkerGroupSpecs { - kuberuntime.BuildTaskMetadata(&rayJobSpec.RayClusterSpec.WorkerGroupSpecs[i].Template.ObjectMeta, job.ID, &pfschema.Conf{}) + kuberuntime.NewPodTemplateSpecBuilder(&rayJobSpec.RayClusterSpec.WorkerGroupSpecs[i].Template, job.ID). + Build(job.GetMember(pfschema.RoleWorker)) } // TODO: patch ray job from user return nil diff --git a/pkg/job/runtime_v2/job/tensorflow/kube_tensorflow_job.go b/pkg/job/runtime_v2/job/tensorflow/kube_tensorflow_job.go index 5fa67e264..9b48a3e2c 100644 --- a/pkg/job/runtime_v2/job/tensorflow/kube_tensorflow_job.go +++ b/pkg/job/runtime_v2/job/tensorflow/kube_tensorflow_job.go @@ -103,10 +103,8 @@ func (pj *KubeTFJob) builtinTFJobSpec(tfJobSpec *tfv1.TFJobSpec, job *api.PFJob) if !ok { return fmt.Errorf("replica type %s for %s is not supported", replicaType, pj.String(jobName)) } - if err := kuberuntime.KubeflowReplicaSpec(replicaSpec, job.ID, &task); err != nil { - log.Errorf("build %s RepilcaSpec for %s failed, err: %v", replicaType, pj.String(jobName), err) - return err - } + // build kubeflowReplicaSpec + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, replicaSpec).ReplicaSpec(task) // calculate job minResources taskResources, err := resources.NewResourceFromMap(task.Flavour.ToMap()) if err != nil { @@ -118,7 +116,9 @@ func (pj *KubeTFJob) builtinTFJobSpec(tfJobSpec *tfv1.TFJobSpec, job *api.PFJob) } // set RunPolicy resourceList := k8s.NewResourceList(minResources) - return kuberuntime.KubeflowRunPolicy(&tfJobSpec.RunPolicy, &resourceList, job.Conf.GetQueueName(), job.Conf.GetPriority()) + kuberuntime.NewKubeflowJobBuilder(job.ID, &tfJobSpec.RunPolicy, nil). + RunPolicy(&resourceList, job.Conf.GetQueueName(), job.Conf.GetPriority()) + return nil } // customTFJobSpec set custom TFJob Spec @@ -131,15 +131,17 @@ func (pj *KubeTFJob) customTFJobSpec(tfJobSpec *tfv1.TFJobSpec, job *api.PFJob) // patch metadata ps, find := tfJobSpec.TFReplicaSpecs[tfv1.TFReplicaTypePS] if find && ps != nil { - kuberuntime.BuildTaskMetadata(&ps.Template.ObjectMeta, job.ID, &pfschema.Conf{}) + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, ps).ReplicaSpec(job.GetMember(pfschema.RolePServer)) } worker, find := tfJobSpec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker] if find && worker != nil { - kuberuntime.BuildTaskMetadata(&worker.Template.ObjectMeta, job.ID, &pfschema.Conf{}) + kuberuntime.NewKubeflowJobBuilder(job.ID, nil, worker).ReplicaSpec(job.GetMember(pfschema.RoleWorker)) } // TODO: patch pytorch job from user // check RunPolicy - return kuberuntime.KubeflowRunPolicy(&tfJobSpec.RunPolicy, nil, job.Conf.GetQueueName(), job.Conf.GetPriority()) + kuberuntime.NewKubeflowJobBuilder(job.ID, &tfJobSpec.RunPolicy, nil). + RunPolicy(nil, job.Conf.GetQueueName(), job.Conf.GetPriority()) + return nil } func (pj *KubeTFJob) AddEventListener(ctx context.Context, listenerType string, jobQueue workqueue.RateLimitingInterface, listener interface{}) error { diff --git a/pkg/job/runtime_v2/job/util/kuberuntime/kube_job.go b/pkg/job/runtime_v2/job/util/kuberuntime/kube_job.go index 2cf146175..4bb36c72c 100644 --- a/pkg/job/runtime_v2/job/util/kuberuntime/kube_job.go +++ b/pkg/job/runtime_v2/job/util/kuberuntime/kube_job.go @@ -330,40 +330,6 @@ func appendMapsIfAbsent(Maps map[string]string, addMaps map[string]string) map[s return Maps } -func BuildPodSpec(podSpec *corev1.PodSpec, task schema.Member) error { - if podSpec == nil { - return fmt.Errorf("build pod spec failed, err: podSpec or task is nil") - } - // fill priorityClassName and schedulerName - err := buildPriorityAndScheduler(podSpec, task.Priority) - if err != nil { - log.Errorln(err) - return err - } - // fill volumes - fileSystems := task.Conf.GetAllFileSystem() - podSpec.Volumes = BuildVolumes(podSpec.Volumes, fileSystems) - // fill affinity - if len(fileSystems) != 0 { - var fsIDs []string - for _, fs := range fileSystems { - fsIDs = append(fsIDs, fs.ID) - } - podSpec.Affinity, err = generateAffinity(podSpec.Affinity, fsIDs) - if err != nil { - return err - } - } - // fill restartPolicy - patchRestartPolicy(podSpec, task) - // build containers - if err = buildPodContainers(podSpec, task); err != nil { - log.Errorf("failed to fill containers, err=%v", err) - return err - } - return nil -} - func buildPriorityAndScheduler(podSpec *corev1.PodSpec, priorityName string) error { if podSpec == nil { return fmt.Errorf("build scheduling policy failed, err: podSpec is nil") @@ -904,55 +870,6 @@ func GetKubeTime(t *metav1.Time) *time.Time { return &newT.Time } -// BuildPodTemplateSpec build PodTemplateSpec for built-in distributed job, such as PaddleJob, PyTorchJob, TFJob and so on -func BuildPodTemplateSpec(podSpec *corev1.PodTemplateSpec, jobID string, task *schema.Member) error { - if podSpec == nil || task == nil { - return fmt.Errorf("podTemplateSpec or task is nil") - } - // build task metadata - BuildTaskMetadata(&podSpec.ObjectMeta, jobID, &schema.Conf{}) - // build pod spec - err := BuildPodSpec(&podSpec.Spec, *task) - if err != nil { - log.Errorf("build pod spec failed, err: %v", err) - return err - } - return nil -} - -// KubeflowReplicaSpec build ReplicaSpec for kubeflow job, such as PyTorchJob, TFJob and so on. -func KubeflowReplicaSpec(replicaSpec *kubeflowv1.ReplicaSpec, jobID string, task *schema.Member) error { - if replicaSpec == nil || task == nil { - return fmt.Errorf("build kubeflow replica spec failed, err: replicaSpec or task is nil") - } - // set Replicas for job - replicas := int32(task.Replicas) - replicaSpec.Replicas = &replicas - // set RestartPolicy - // TODO: make RestartPolicy configurable - replicaSpec.RestartPolicy = kubeflowv1.RestartPolicyNever - // set PodTemplate - return BuildPodTemplateSpec(&replicaSpec.Template, jobID, task) -} - -// KubeflowRunPolicy build RunPolicy for kubeflow job, such as PyTorchJob, TFJob and so on. -func KubeflowRunPolicy(runPolicy *kubeflowv1.RunPolicy, minResources *corev1.ResourceList, queueName, priority string) error { - if runPolicy == nil { - return fmt.Errorf("build run policy for kubeflow job faield, err: runPolicy is nil") - } - // TODO set cleanPolicy - // set SchedulingPolicy - if runPolicy.SchedulingPolicy == nil { - runPolicy.SchedulingPolicy = &kubeflowv1.SchedulingPolicy{} - } - runPolicy.SchedulingPolicy.Queue = queueName - runPolicy.SchedulingPolicy.PriorityClass = KubePriorityClass(priority) - if minResources != nil { - runPolicy.SchedulingPolicy.MinResources = minResources - } - return nil -} - // Operations for kubernetes job, including single, paddle, sparkapp, tensorflow, pytorch, mpi jobs and so on. // getPodGroupName get the name of pod group diff --git a/pkg/job/runtime_v2/job/util/kuberuntime/kube_job_builder.go b/pkg/job/runtime_v2/job/util/kuberuntime/kube_job_builder.go new file mode 100644 index 000000000..8590b7556 --- /dev/null +++ b/pkg/job/runtime_v2/job/util/kuberuntime/kube_job_builder.go @@ -0,0 +1,263 @@ +/* +Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kuberuntime + +import ( + "strings" + + kubeflowv1 "github.com/kubeflow/common/pkg/apis/common/v1" + log "github.com/sirupsen/logrus" + corev1 "k8s.io/api/core/v1" + + "github.com/PaddlePaddle/PaddleFlow/pkg/common/config" + "github.com/PaddlePaddle/PaddleFlow/pkg/common/k8s" + "github.com/PaddlePaddle/PaddleFlow/pkg/common/resources" + "github.com/PaddlePaddle/PaddleFlow/pkg/common/schema" +) + +// PodSpecBuilder is used to build pod spec +type PodSpecBuilder struct { + podSpec *corev1.PodSpec + jobID string +} + +// NewPodSpecBuilder create a new PodSpecBuilder +func NewPodSpecBuilder(podSpec *corev1.PodSpec, jobID string) *PodSpecBuilder { + return &PodSpecBuilder{ + podSpec: podSpec, + jobID: jobID, + } +} + +// Scheduler set scheduler name for pod +func (p *PodSpecBuilder) Scheduler() *PodSpecBuilder { + // fill SchedulerName + p.podSpec.SchedulerName = config.GlobalServerConfig.Job.SchedulerName + return p +} + +// PriorityClassName set priority class name for pod +func (p *PodSpecBuilder) PriorityClassName(priorityName string) *PodSpecBuilder { + p.podSpec.PriorityClassName = KubePriorityClass(priorityName) + return p +} + +// RestartPolicy set restart policy for pod +func (p *PodSpecBuilder) RestartPolicy(restartPolicy string) *PodSpecBuilder { + // fill restartPolicy + if restartPolicy == string(corev1.RestartPolicyAlways) || + restartPolicy == string(corev1.RestartPolicyOnFailure) { + p.podSpec.RestartPolicy = corev1.RestartPolicy(restartPolicy) + } else { + p.podSpec.RestartPolicy = corev1.RestartPolicyNever + } + return p +} + +// PFS set volumes and volumeMounts for pod +func (p *PodSpecBuilder) PFS(fileSystems []schema.FileSystem) *PodSpecBuilder { + if len(fileSystems) == 0 { + return p + } + // fill volumes + p.podSpec.Volumes = BuildVolumes(p.podSpec.Volumes, fileSystems) + // fill volumeMount + for idx, container := range p.podSpec.Containers { + p.podSpec.Containers[idx].VolumeMounts = BuildVolumeMounts(container.VolumeMounts, fileSystems) + } + // fill affinity + var fsIDs []string + for _, fs := range fileSystems { + fsIDs = append(fsIDs, fs.ID) + } + var err error + p.podSpec.Affinity, err = generateAffinity(p.podSpec.Affinity, fsIDs) + if err != nil { + log.Infof("mergeNodeAffinity new affinity is nil") + } + return p +} + +// containerResources set resources for container +func (p *PodSpecBuilder) containerResources(container *corev1.Container, requestFlavour, limitFlavour schema.Flavour) { + // fill request resources + if !schema.IsEmptyResource(requestFlavour.ResourceInfo) { + flavourResource, err := resources.NewResourceFromMap(requestFlavour.ToMap()) + if err == nil { + // request set specified value + container.Resources.Requests = k8s.NewResourceList(flavourResource) + } else { + log.Errorf("GenerateResourceRequirements by request:[%+v] error:%v", requestFlavour, err) + } + } + // fill limit resources + if !schema.IsEmptyResource(limitFlavour.ResourceInfo) { + limitFlavourResource, err := resources.NewResourceFromMap(limitFlavour.ToMap()) + if err != nil { + log.Errorf("GenerateResourceRequirements by limitFlavour:[%+v] error:%v", limitFlavourResource, err) + return + } + if strings.ToUpper(limitFlavour.Name) == schema.EnvJobLimitFlavourNone { + container.Resources.Limits = nil + } else if limitFlavourResource.CPU() == 0 || limitFlavourResource.Memory() == 0 { + // limit set zero, patch the same value as request + container.Resources.Limits = container.Resources.Requests + } else { + // limit set specified value + container.Resources.Limits = k8s.NewResourceList(limitFlavourResource) + } + } + +} + +// Containers set containers for pod +func (p *PodSpecBuilder) Containers(task schema.Member) *PodSpecBuilder { + log.Debugf("fill containers for job[%s]", p.jobID) + for idx, container := range p.podSpec.Containers { + // fill container[*].name + if task.Name != "" { + p.podSpec.Containers[idx].Name = task.Name + } + // fill container[*].Image + if task.Image != "" { + p.podSpec.Containers[idx].Image = task.Image + } + // fill container[*].Command + if task.Command != "" { + workDir := getWorkDir(&task, nil, nil) + container.Command = generateContainerCommand(task.Command, workDir) + } + // fill container[*].Args + if len(task.Args) > 0 { + container.Args = task.Args + } + // fill container[*].Resources + p.containerResources(&p.podSpec.Containers[idx], task.Flavour, task.LimitFlavour) + // fill env + p.podSpec.Containers[idx].Env = BuildEnvVars(container.Env, task.Env) + } + log.Debugf("fill containers completed: %v", p.podSpec.Containers) + return p +} + +// Build create pod spec +func (p *PodSpecBuilder) Build(task schema.Member) { + if p.podSpec == nil { + return + } + p.Scheduler(). + RestartPolicy(task.GetRestartPolicy()). + PriorityClassName(task.Priority). + PFS(task.Conf.GetAllFileSystem()). + Containers(task) +} + +// PodTemplateSpecBuilder build pod template spec +type PodTemplateSpecBuilder struct { + podTemplateSpec *corev1.PodTemplateSpec + jobID string +} + +// NewPodTemplateSpecBuilder create pod template spec builder +func NewPodTemplateSpecBuilder(podTempSpec *corev1.PodTemplateSpec, jobID string) *PodTemplateSpecBuilder { + return &PodTemplateSpecBuilder{ + podTemplateSpec: podTempSpec, + jobID: jobID, + } +} + +// Metadata set metadata for pod template spec +func (k *PodTemplateSpecBuilder) Metadata(name, namespace string, labels, annotations map[string]string) *PodTemplateSpecBuilder { + if name != "" { + k.podTemplateSpec.Name = name + } + if namespace != "" { + k.podTemplateSpec.Namespace = namespace + } + // set annotations + k.podTemplateSpec.Annotations = appendMapsIfAbsent(k.podTemplateSpec.Annotations, annotations) + // set labels + k.podTemplateSpec.Labels = appendMapsIfAbsent(k.podTemplateSpec.Labels, labels) + k.podTemplateSpec.Labels[schema.JobIDLabel] = k.jobID + k.podTemplateSpec.Labels[schema.JobOwnerLabel] = schema.JobOwnerValue + return k +} + +// Build set pod spec for pod template spec +func (k *PodTemplateSpecBuilder) Build(task schema.Member) { + if k.podTemplateSpec == nil { + return + } + // build metadata + k.Metadata(task.GetName(), task.GetNamespace(), task.GetLabels(), task.GetAnnotations()) + // set pod spec + NewPodSpecBuilder(&k.podTemplateSpec.Spec, k.jobID).Build(task) +} + +// KubeflowJobBuilder build kubeflow job +type KubeflowJobBuilder struct { + runPolicy *kubeflowv1.RunPolicy + replicaSpec *kubeflowv1.ReplicaSpec + jobID string +} + +// NewKubeflowJobBuilder create kubeflow job builder +func NewKubeflowJobBuilder(jobID string, runPolicy *kubeflowv1.RunPolicy, + replicaSpec *kubeflowv1.ReplicaSpec) *KubeflowJobBuilder { + return &KubeflowJobBuilder{ + jobID: jobID, + runPolicy: runPolicy, + replicaSpec: replicaSpec, + } +} + +// ReplicaSpec set replica spec for kubeflow job +func (k *KubeflowJobBuilder) ReplicaSpec(task schema.Member) { + if k.replicaSpec == nil { + return + } + // set Replicas for job + var replicas int32 = 1 + if task.Replicas > 0 { + replicas = int32(task.Replicas) + k.replicaSpec.Replicas = &replicas + } + // set RestartPolicy + k.replicaSpec.RestartPolicy = kubeflowv1.RestartPolicyNever + // set PodTemplate + NewPodTemplateSpecBuilder(&k.replicaSpec.Template, k.jobID).Build(task) +} + +// RunPolicy set run policy for kubeflow job +func (k *KubeflowJobBuilder) RunPolicy(minResources *corev1.ResourceList, queueName, priority string) { + if k.runPolicy == nil { + return + } + // set SchedulingPolicy + if k.runPolicy.SchedulingPolicy == nil { + k.runPolicy.SchedulingPolicy = &kubeflowv1.SchedulingPolicy{} + } + if queueName != "" { + k.runPolicy.SchedulingPolicy.Queue = queueName + } + if priority != "" { + k.runPolicy.SchedulingPolicy.PriorityClass = KubePriorityClass(priority) + } + if minResources != nil { + k.runPolicy.SchedulingPolicy.MinResources = minResources + } +} diff --git a/pkg/job/runtime_v2/job/util/kuberuntime/kube_job_builder_test.go b/pkg/job/runtime_v2/job/util/kuberuntime/kube_job_builder_test.go new file mode 100644 index 000000000..d3a3e51be --- /dev/null +++ b/pkg/job/runtime_v2/job/util/kuberuntime/kube_job_builder_test.go @@ -0,0 +1,247 @@ +/* +Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kuberuntime + +import ( + "fmt" + "testing" + + kubeflowv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + + "github.com/PaddlePaddle/PaddleFlow/pkg/common/config" + "github.com/PaddlePaddle/PaddleFlow/pkg/common/schema" + "github.com/PaddlePaddle/PaddleFlow/pkg/model" + "github.com/PaddlePaddle/PaddleFlow/pkg/storage" + "github.com/PaddlePaddle/PaddleFlow/pkg/storage/driver" +) + +func TestPodSpecBuilder(t *testing.T) { + schedulerName := "testSchedulerName" + config.GlobalServerConfig = &config.ServerConfig{} + config.GlobalServerConfig.Job.SchedulerName = schedulerName + + testCases := []struct { + testName string + jobID string + podSpec *corev1.PodSpec + task schema.Member + err error + }{ + { + testName: "pod affinity is nil", + podSpec: &corev1.PodSpec{}, + jobID: "test-job-1", + task: schema.Member{ + Conf: schema.Conf{ + Name: "test-task-1", + QueueName: "test-queue", + Priority: "NORMAL", + FileSystem: schema.FileSystem{ + ID: "fs-root-test1", + Name: "test", + Type: "s3", + MountPath: "/home/work/mnt", + }, + }, + }, + }, + { + testName: "pod has affinity", + podSpec: &corev1.PodSpec{ + Affinity: &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "kubernetes.io/hostname", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"instance1"}, + }, + }, + }, + }, + }, + }, + }, + }, + jobID: "test-job-2", + task: schema.Member{ + Conf: schema.Conf{ + Name: "test-task-1", + QueueName: "test-queue", + Priority: "NORMAL", + FileSystem: schema.FileSystem{ + ID: "fs-root-test2", + Name: "test", + Type: "s3", + MountPath: "/home/work/mnt", + }, + }, + }, + }, + } + + driver.InitMockDB() + for _, testCase := range testCases { + t.Run(testCase.testName, func(t *testing.T) { + err := storage.FsCache.Add(&model.FSCache{ + FsID: testCase.task.Conf.FileSystem.ID, + CacheDir: "./xx", + NodeName: "instance1", + ClusterID: "xxx", + }) + assert.Equal(t, nil, err) + NewPodSpecBuilder(testCase.podSpec, testCase.jobID).Build(testCase.task) + t.Logf("build pod spec: %v", testCase.podSpec) + }) + } +} + +func TestPodTemplateSpecBuilder(t *testing.T) { + schedulerName := "testSchedulerName" + config.GlobalServerConfig = &config.ServerConfig{} + config.GlobalServerConfig.Job.SchedulerName = schedulerName + + testCases := []struct { + testName string + jobID string + podSpec *corev1.PodTemplateSpec + task schema.Member + err error + }{ + { + testName: "pod affinity is nil", + podSpec: &corev1.PodTemplateSpec{}, + task: schema.Member{ + Conf: schema.Conf{ + Name: "test-task-1", + QueueName: "test-queue", + Priority: "NORMAL", + FileSystem: schema.FileSystem{ + ID: "fs-root-test1", + Name: "test", + Type: "s3", + MountPath: "/home/work/mnt", + }, + }, + }, + err: nil, + }, + { + testName: "wrong flavour nil", + podSpec: &corev1.PodTemplateSpec{}, + task: schema.Member{ + Conf: schema.Conf{ + Name: "test-task-1", + QueueName: "test-queue", + Priority: "NORMAL", + Flavour: schema.Flavour{Name: "", ResourceInfo: schema.ResourceInfo{CPU: "4a", Mem: "4Gi"}}, + }, + }, + err: fmt.Errorf("quantities must match the regular expression '^([+-]?[0-9.]+)([eEinumkKMGTP]*[-+]?[0-9]*)$'"), + }, + { + testName: "replicaSpec is nil", + podSpec: nil, + task: schema.Member{ + Conf: schema.Conf{ + Name: "test-task-1", + QueueName: "test-queue", + Priority: "NORMAL", + FileSystem: schema.FileSystem{ + ID: "fs-root-test1", + Name: "test", + Type: "s3", + MountPath: "/home/work/mnt", + }, + }, + }, + err: fmt.Errorf("podTemplateSpec or task is nil"), + }, + } + + driver.InitMockDB() + for _, tt := range testCases { + t.Run(tt.testName, func(t *testing.T) { + NewPodTemplateSpecBuilder(tt.podSpec, tt.jobID).Build(tt.task) + t.Logf("builder pod tempalte spec: %v", tt.podSpec) + }) + } +} + +func TestKubeflowReplicaSpec(t *testing.T) { + schedulerName := "testSchedulerName" + config.GlobalServerConfig = &config.ServerConfig{} + config.GlobalServerConfig.Job.SchedulerName = schedulerName + + testCases := []struct { + testName string + jobID string + replicaSpec *kubeflowv1.ReplicaSpec + task schema.Member + err error + }{ + { + testName: "pod affinity is nil", + replicaSpec: &kubeflowv1.ReplicaSpec{}, + task: schema.Member{ + Conf: schema.Conf{ + Name: "test-task-1", + QueueName: "test-queue", + Priority: "NORMAL", + FileSystem: schema.FileSystem{ + ID: "fs-root-test1", + Name: "test", + Type: "s3", + MountPath: "/home/work/mnt", + }, + }, + }, + err: nil, + }, + { + testName: "replicaSpec is nil", + replicaSpec: nil, + task: schema.Member{ + Conf: schema.Conf{ + Name: "test-task-1", + QueueName: "test-queue", + Priority: "NORMAL", + FileSystem: schema.FileSystem{ + ID: "fs-root-test1", + Name: "test", + Type: "s3", + MountPath: "/home/work/mnt", + }, + }, + }, + err: fmt.Errorf("build kubeflow replica spec failed, err: replicaSpec or task is nil"), + }, + } + + driver.InitMockDB() + for _, tt := range testCases { + t.Run(tt.testName, func(t *testing.T) { + NewKubeflowJobBuilder(tt.jobID, nil, tt.replicaSpec).ReplicaSpec(tt.task) + t.Logf("builder replica spec: %v", tt.replicaSpec) + }) + } +} diff --git a/pkg/job/runtime_v2/job/util/kuberuntime/kube_job_test.go b/pkg/job/runtime_v2/job/util/kuberuntime/kube_job_test.go index 127993f0f..07d0a058f 100644 --- a/pkg/job/runtime_v2/job/util/kuberuntime/kube_job_test.go +++ b/pkg/job/runtime_v2/job/util/kuberuntime/kube_job_test.go @@ -21,7 +21,6 @@ import ( "net/http/httptest" "testing" - kubeflowv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -74,218 +73,6 @@ func TestBuildSchedulingPolicy(t *testing.T) { assert.Equal(t, schedulerName, pod.Spec.SchedulerName) } -func TestBuildPodSpec(t *testing.T) { - schedulerName := "testSchedulerName" - config.GlobalServerConfig = &config.ServerConfig{} - config.GlobalServerConfig.Job.SchedulerName = schedulerName - - testCases := []struct { - testName string - podSpec *corev1.PodSpec - task schema.Member - err error - }{ - { - testName: "pod affinity is nil", - podSpec: &corev1.PodSpec{}, - task: schema.Member{ - Conf: schema.Conf{ - Name: "test-task-1", - QueueName: "test-queue", - Priority: "NORMAL", - FileSystem: schema.FileSystem{ - ID: "fs-root-test1", - Name: "test", - Type: "s3", - MountPath: "/home/work/mnt", - }, - }, - }, - }, - { - testName: "pod has affinity", - podSpec: &corev1.PodSpec{ - Affinity: &corev1.Affinity{ - NodeAffinity: &corev1.NodeAffinity{ - RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ - NodeSelectorTerms: []corev1.NodeSelectorTerm{ - { - MatchExpressions: []corev1.NodeSelectorRequirement{ - { - Key: "kubernetes.io/hostname", - Operator: corev1.NodeSelectorOpIn, - Values: []string{"instance1"}, - }, - }, - }, - }, - }, - }, - }, - }, - task: schema.Member{ - Conf: schema.Conf{ - Name: "test-task-1", - QueueName: "test-queue", - Priority: "NORMAL", - FileSystem: schema.FileSystem{ - ID: "fs-root-test2", - Name: "test", - Type: "s3", - MountPath: "/home/work/mnt", - }, - }, - }, - }, - } - - driver.InitMockDB() - for _, testCase := range testCases { - t.Run(testCase.testName, func(t *testing.T) { - err := storage.FsCache.Add(&model.FSCache{ - FsID: testCase.task.Conf.FileSystem.ID, - CacheDir: "./xx", - NodeName: "instance1", - ClusterID: "xxx", - }) - assert.Equal(t, nil, err) - err = BuildPodSpec(testCase.podSpec, testCase.task) - assert.Equal(t, testCase.err, err) - }) - } -} - -func TestKubeflowReplicaSpec(t *testing.T) { - schedulerName := "testSchedulerName" - config.GlobalServerConfig = &config.ServerConfig{} - config.GlobalServerConfig.Job.SchedulerName = schedulerName - - testCases := []struct { - testName string - jobID string - replicaSpec *kubeflowv1.ReplicaSpec - task schema.Member - err error - }{ - { - testName: "pod affinity is nil", - replicaSpec: &kubeflowv1.ReplicaSpec{}, - task: schema.Member{ - Conf: schema.Conf{ - Name: "test-task-1", - QueueName: "test-queue", - Priority: "NORMAL", - FileSystem: schema.FileSystem{ - ID: "fs-root-test1", - Name: "test", - Type: "s3", - MountPath: "/home/work/mnt", - }, - }, - }, - err: nil, - }, - { - testName: "replicaSpec is nil", - replicaSpec: nil, - task: schema.Member{ - Conf: schema.Conf{ - Name: "test-task-1", - QueueName: "test-queue", - Priority: "NORMAL", - FileSystem: schema.FileSystem{ - ID: "fs-root-test1", - Name: "test", - Type: "s3", - MountPath: "/home/work/mnt", - }, - }, - }, - err: fmt.Errorf("build kubeflow replica spec failed, err: replicaSpec or task is nil"), - }, - } - - driver.InitMockDB() - for _, tt := range testCases { - t.Run(tt.testName, func(t *testing.T) { - err := KubeflowReplicaSpec(tt.replicaSpec, tt.jobID, &tt.task) - assert.Equal(t, tt.err, err) - }) - } -} - -func TestBuildPodTemplateSpec(t *testing.T) { - schedulerName := "testSchedulerName" - config.GlobalServerConfig = &config.ServerConfig{} - config.GlobalServerConfig.Job.SchedulerName = schedulerName - - testCases := []struct { - testName string - jobID string - podSpec *corev1.PodTemplateSpec - task schema.Member - err error - }{ - { - testName: "pod affinity is nil", - podSpec: &corev1.PodTemplateSpec{}, - task: schema.Member{ - Conf: schema.Conf{ - Name: "test-task-1", - QueueName: "test-queue", - Priority: "NORMAL", - FileSystem: schema.FileSystem{ - ID: "fs-root-test1", - Name: "test", - Type: "s3", - MountPath: "/home/work/mnt", - }, - }, - }, - err: nil, - }, - { - testName: "wrong flavour nil", - podSpec: &corev1.PodTemplateSpec{}, - task: schema.Member{ - Conf: schema.Conf{ - Name: "test-task-1", - QueueName: "test-queue", - Priority: "NORMAL", - Flavour: schema.Flavour{Name: "", ResourceInfo: schema.ResourceInfo{CPU: "4a", Mem: "4Gi"}}, - }, - }, - err: fmt.Errorf("quantities must match the regular expression '^([+-]?[0-9.]+)([eEinumkKMGTP]*[-+]?[0-9]*)$'"), - }, - { - testName: "replicaSpec is nil", - podSpec: nil, - task: schema.Member{ - Conf: schema.Conf{ - Name: "test-task-1", - QueueName: "test-queue", - Priority: "NORMAL", - FileSystem: schema.FileSystem{ - ID: "fs-root-test1", - Name: "test", - Type: "s3", - MountPath: "/home/work/mnt", - }, - }, - }, - err: fmt.Errorf("podTemplateSpec or task is nil"), - }, - } - - driver.InitMockDB() - for _, tt := range testCases { - t.Run(tt.testName, func(t *testing.T) { - err := BuildPodTemplateSpec(tt.podSpec, tt.jobID, &tt.task) - assert.Equal(t, tt.err, err) - }) - } -} - func TestGenerateResourceRequirements(t *testing.T) { schedulerName := "testSchedulerName" config.GlobalServerConfig = &config.ServerConfig{}