chore: Factoring out reusable presets logic - Part 4 (kaito-project#332)
ishaansehgal99 authored Apr 2, 2024
1 parent cf7cf94 commit 77aa95b
Showing 2 changed files with 101 additions and 63 deletions.
92 changes: 29 additions & 63 deletions pkg/inference/preset-inferences.go
@@ -5,6 +5,7 @@ package inference
import (
"context"
"fmt"
"github.com/azure/kaito/pkg/utils"
"os"
"strconv"

@@ -19,10 +20,9 @@ import (
)

const (
- ProbePath = "/healthz"
- Port5000 = int32(5000)
- InferenceFile = "inference_api.py"
- DefaultVolumeMountPath = "/dev/shm"
+ ProbePath = "/healthz"
+ Port5000 = int32(5000)
+ InferenceFile = "inference_api.py"
)

var (
@@ -92,21 +92,21 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl
return nil
}

- func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) (string, []corev1.LocalObjectReference) {
- imageName := string(workspaceObj.Inference.Preset.Name)
- imageTag := inferenceObj.Tag
+ func GetInferenceImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, presetObj *model.PresetParam) (string, []corev1.LocalObjectReference) {
imagePullSecretRefs := []corev1.LocalObjectReference{}
- if inferenceObj.ImageAccessMode == "private" {
- imageName = string(workspaceObj.Inference.Preset.PresetOptions.Image)
+ if presetObj.ImageAccessMode == "private" {
+ imageName := workspaceObj.Inference.Preset.PresetOptions.Image
for _, secretName := range workspaceObj.Inference.Preset.PresetOptions.ImagePullSecrets {
imagePullSecretRefs = append(imagePullSecretRefs, corev1.LocalObjectReference{Name: secretName})
}
return imageName, imagePullSecretRefs
+ } else {
+ imageName := string(workspaceObj.Inference.Preset.Name)
+ imageTag := presetObj.Tag
+ registryName := os.Getenv("PRESET_REGISTRY_NAME")
+ imageName = fmt.Sprintf("%s/kaito-%s:%s", registryName, imageName, imageTag)
+ return imageName, imagePullSecretRefs
+ }
}
-
- registryName := os.Getenv("PRESET_REGISTRY_NAME")
- imageName = registryName + fmt.Sprintf("/kaito-%s:%s", imageName, imageTag)
- return imageName, imagePullSecretRefs
}

func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace,
@@ -118,17 +118,25 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
}
}

- volume, volumeMount := configVolume(workspaceObj, inferenceObj)
+ var volumes []corev1.Volume
+ var volumeMounts []corev1.VolumeMount
+ volume, volumeMount := utils.ConfigSHMVolume(workspaceObj)
+ if volume.Name != "" {
+ volumes = append(volumes, volume)
+ }
+ if volumeMount.Name != "" {
+ volumeMounts = append(volumeMounts, volumeMount)
+ }
commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj)
- image, imagePullSecrets := GetImageInfo(ctx, workspaceObj, inferenceObj)
+ image, imagePullSecrets := GetInferenceImageInfo(ctx, workspaceObj, inferenceObj)

var depObj client.Object
if supportDistributedInference {
depObj = resources.GenerateStatefulSetManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands,
- containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount)
+ containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volumes, volumeMounts)
} else {
depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands,
- containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount)
+ containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volumes, volumeMounts)
}
err := resources.CreateResource(ctx, depObj, kubeClient)
if client.IgnoreAlreadyExists(err) != nil {
@@ -142,10 +150,10 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
// and sets the GPU resources required for inference.
// Returns the command and resource configuration.
func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetParam) ([]string, corev1.ResourceRequirements) {
- torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams)
- torchCommand = buildCommandStr(torchCommand, inferenceObj.TorchRunRdzvParams)
- modelCommand := buildCommandStr(InferenceFile, inferenceObj.ModelRunParams)
- commands := shellCommand(torchCommand + " " + modelCommand)
+ torchCommand := utils.BuildCmdStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams)
+ torchCommand = utils.BuildCmdStr(torchCommand, inferenceObj.TorchRunRdzvParams)
+ modelCommand := utils.BuildCmdStr(InferenceFile, inferenceObj.ModelRunParams)
+ commands := utils.ShellCmd(torchCommand + " " + modelCommand)

resourceRequirements := corev1.ResourceRequirements{
Requests: corev1.ResourceList{
@@ -158,45 +166,3 @@ func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetP

return commands, resourceRequirements
}
-
- func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) ([]corev1.Volume, []corev1.VolumeMount) {
- volume := []corev1.Volume{}
- volumeMount := []corev1.VolumeMount{}
-
- // Signifies multinode inference requirement
- if *wObj.Resource.Count > 1 {
- // Append share memory volume to any existing volumes
- volume = append(volume, corev1.Volume{
- Name: "dshm",
- VolumeSource: corev1.VolumeSource{
- EmptyDir: &corev1.EmptyDirVolumeSource{
- Medium: "Memory",
- },
- },
- })
-
- volumeMount = append(volumeMount, corev1.VolumeMount{
- Name: volume[0].Name,
- MountPath: DefaultVolumeMountPath,
- })
- }
-
- return volume, volumeMount
- }
-
- func shellCommand(command string) []string {
- return []string{
- "/bin/sh",
- "-c",
- command,
- }
- }
-
- func buildCommandStr(baseCommand string, torchRunParams map[string]string) string {
- updatedBaseCommand := baseCommand
- for key, value := range torchRunParams {
- updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value)
- }
-
- return updatedBaseCommand
- }
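
For context on the utils.BuildCmdStr and utils.ShellCmd calls that replace buildCommandStr and shellCommand above, here is a minimal, self-contained sketch of how the inference container command is assembled after this refactor. The two local helpers mirror the bodies added in pkg/utils/common-preset.go below; the parameter values are invented for illustration.

package main

import "fmt"

// buildCmdStr mirrors utils.BuildCmdStr: append each parameter as --key=value.
// Go map iteration order is not deterministic, so the flag order may vary.
func buildCmdStr(baseCommand string, params map[string]string) string {
	updated := baseCommand
	for key, value := range params {
		updated = fmt.Sprintf("%s --%s=%s", updated, key, value)
	}
	return updated
}

// shellCmd mirrors utils.ShellCmd: wrap the assembled command for /bin/sh -c.
func shellCmd(command string) []string {
	return []string{"/bin/sh", "-c", command}
}

func main() {
	// Hypothetical torchrun/model parameters; real values come from the PresetParam.
	torchRunParams := map[string]string{"nnodes": "2", "nproc_per_node": "1"}
	modelRunParams := map[string]string{"pipeline": "text-generation"}

	torchCommand := buildCmdStr("torchrun", torchRunParams)
	modelCommand := buildCmdStr("inference_api.py", modelRunParams)
	commands := shellCmd(torchCommand + " " + modelCommand)

	fmt.Println(commands)
	// e.g. [/bin/sh -c torchrun --nnodes=2 --nproc_per_node=1 inference_api.py --pipeline=text-generation]
}
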
72 changes: 72 additions & 0 deletions pkg/utils/common-preset.go
@@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package utils

import (
"fmt"
kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
corev1 "k8s.io/api/core/v1"
)

const (
DefaultVolumeMountPath = "/dev/shm"
)

func ConfigSHMVolume(wObj *kaitov1alpha1.Workspace) (corev1.Volume, corev1.VolumeMount) {
volume := corev1.Volume{}
volumeMount := corev1.VolumeMount{}

// Signifies multinode inference requirement
if *wObj.Resource.Count > 1 {
// Append share memory volume to any existing volumes
volume = corev1.Volume{
Name: "dshm",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{
Medium: "Memory",
},
},
}

volumeMount = corev1.VolumeMount{
Name: volume.Name,
MountPath: DefaultVolumeMountPath,
}
}

return volume, volumeMount
}

func ConfigDataVolume() ([]corev1.Volume, []corev1.VolumeMount) {
var volumes []corev1.Volume
var volumeMounts []corev1.VolumeMount
volumes = append(volumes, corev1.Volume{
Name: "data-volume",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{},
},
})

volumeMounts = append(volumeMounts, corev1.VolumeMount{
Name: "data-volume",
MountPath: "/data",
})
return volumes, volumeMounts
}

func ShellCmd(command string) []string {
return []string{
"/bin/sh",
"-c",
command,
}
}

func BuildCmdStr(baseCommand string, torchRunParams map[string]string) string {
updatedBaseCommand := baseCommand
for key, value := range torchRunParams {
updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value)
}

return updatedBaseCommand
}
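
A hedged usage sketch for the new ConfigSHMVolume helper, mirroring the guard added in CreatePresetInference: the shared-memory volume is only produced for multi-node workspaces, so callers append it conditionally. It assumes the kaito module and k8s.io/api are on the import path and that Resource.Count is an *int, as the dereference in this diff suggests.

package main

import (
	"fmt"

	kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
	"github.com/azure/kaito/pkg/utils"
	corev1 "k8s.io/api/core/v1"
)

func main() {
	count := 2 // more than one node, so the shared-memory volume is needed
	wObj := &kaitov1alpha1.Workspace{}
	wObj.Resource.Count = &count

	var volumes []corev1.Volume
	var volumeMounts []corev1.VolumeMount

	// Same pattern as the refactored CreatePresetInference: append only when
	// the helper actually returned a named volume.
	volume, volumeMount := utils.ConfigSHMVolume(wObj)
	if volume.Name != "" {
		volumes = append(volumes, volume)
	}
	if volumeMount.Name != "" {
		volumeMounts = append(volumeMounts, volumeMount)
	}

	// With count > 1 this prints the "dshm" emptyDir volume mounted at /dev/shm.
	fmt.Println(volumes, volumeMounts)
}
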
