diff --git a/controllers/object_controls.go b/controllers/object_controls.go index 2da92f40e..9b2f26893 100644 --- a/controllers/object_controls.go +++ b/controllers/object_controls.go @@ -1114,14 +1114,16 @@ func TransformToolkit(obj *appsv1.DaemonSet, config *gpuv1.ClusterPolicySpec, n return fmt.Errorf("error getting path to runtime config file: %v", err) } sourceConfigFileName := path.Base(runtimeConfigFile) - // update runtime args + + var configEnvvarName string if runtime == gpuv1.Containerd.String() { - setContainerEnv(&(obj.Spec.Template.Spec.Containers[0]), "CONTAINERD_CONFIG", DefaultRuntimeConfigTargetDir+sourceConfigFileName) + configEnvvarName = "CONTAINERD_CONFIG" } else if runtime == gpuv1.Docker.String() { - setContainerEnv(&(obj.Spec.Template.Spec.Containers[0]), "DOCKER_CONFIG", DefaultRuntimeConfigTargetDir+sourceConfigFileName) + configEnvvarName = "DOCKER_CONFIG" } else if runtime == gpuv1.CRIO.String() { - setContainerEnv(&(obj.Spec.Template.Spec.Containers[0]), "CRIO_CONFIG", DefaultRuntimeConfigTargetDir+sourceConfigFileName) + configEnvvarName = "CRIO_CONFIG" } + setContainerEnv(&(obj.Spec.Template.Spec.Containers[0]), configEnvvarName, DefaultRuntimeConfigTargetDir+sourceConfigFileName) volMountConfigName := fmt.Sprintf("%s-config", runtime) volMountConfig := corev1.VolumeMount{Name: volMountConfigName, MountPath: DefaultRuntimeConfigTargetDir} @@ -1131,22 +1133,21 @@ func TransformToolkit(obj *appsv1.DaemonSet, config *gpuv1.ClusterPolicySpec, n obj.Spec.Template.Spec.Volumes = append(obj.Spec.Template.Spec.Volumes, configVol) // setup mounts for runtime socket file - if runtime == gpuv1.Docker.String() || runtime == gpuv1.Containerd.String() { - runtimeSocketFile, err := getRuntimeSocketFile(&(obj.Spec.Template.Spec.Containers[0]), runtime) - if err != nil { - return fmt.Errorf("error getting path to runtime socket: %v", err) - } + runtimeSocketFile, err := getRuntimeSocketFile(&(obj.Spec.Template.Spec.Containers[0]), runtime) + if err != nil { + return fmt.Errorf("error getting path to runtime socket: %v", err) + } + if runtimeSocketFile != "" { sourceSocketFileName := path.Base(runtimeSocketFile) - // update runtime args - // update runtime args + // set envvar for runtime socket + var socketEnvvarName string if runtime == gpuv1.Containerd.String() { - setContainerEnv(&(obj.Spec.Template.Spec.Containers[0]), "CONTAINERD_SOCKET", DefaultRuntimeSocketTargetDir+sourceSocketFileName) + socketEnvvarName = "CONTAINERD_SOCKET" } else if runtime == gpuv1.Docker.String() { - setContainerEnv(&(obj.Spec.Template.Spec.Containers[0]), "DOCKER_SOCKET", DefaultRuntimeSocketTargetDir+sourceSocketFileName) - } else if runtime == gpuv1.CRIO.String() { - runtimeArgs := " --socket " + DefaultRuntimeSocketTargetDir + sourceSocketFileName - setContainerEnv(&(obj.Spec.Template.Spec.Containers[0]), "RUNTIME_ARGS", runtimeArgs) + socketEnvvarName = "DOCKER_SOCKET" } + setContainerEnv(&(obj.Spec.Template.Spec.Containers[0]), socketEnvvarName, DefaultRuntimeSocketTargetDir+sourceSocketFileName) + volMountSocketName := fmt.Sprintf("%s-socket", runtime) volMountSocket := corev1.VolumeMount{Name: volMountSocketName, MountPath: DefaultRuntimeSocketTargetDir} obj.Spec.Template.Spec.Containers[0].VolumeMounts = append(obj.Spec.Template.Spec.Containers[0].VolumeMounts, volMountSocket) @@ -1951,6 +1952,8 @@ func getRuntimeSocketFile(c *corev1.Container, runtime string) (string, error) { if getContainerEnv(c, "CONTAINERD_SOCKET") != "" { runtimeSocketFile = getContainerEnv(c, "CONTAINERD_SOCKET") } + case gpuv1.CRIO.String(): + runtimeSocketFile = "" default: return "", fmt.Errorf("invalid runtime: %s", runtime) }