Skip to content

Commit

Permalink
Override defaults if GPU (#664)
Browse files Browse the repository at this point in the history
* first commit

* missing comma

* missing function

* simplification

* updates tables

* logging no 1

* more debug

* fixing things and log them

* fixing

* accelerator type is just a string not url

* Revert "accelerator type is just a string not url"

This reverts commit c09b6fe.

* more logging and stuff

* more logs

* typo

* more and more logging

* use zone instead of region

* styles

* do no stack accelerator type

* removing debug

* always load defaults for gpu plan

* forgotten line

* Some nice logs

* pasing only count

* adding statement

* disk spart way

* cleanup

* do not assign VMsize if gpu VM Type

* Add gpu to query tags (#670)

* added gpu to the tags list for quering images in api-selector

* added debug lines

* add gpu_vm_type to params in api selector

* removed debug lines

* Update CHANGELOG.md
  • Loading branch information
makemp authored Jul 24, 2023
1 parent b840303 commit ee57d65
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/).

### Removed

### Added
- Adding GPU Support

### Fixed

## [6.2.4] - 2019-10-29
Expand Down
110 changes: 97 additions & 13 deletions backend/gce.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,20 @@ Set-LocalUser -Name travis -Password $pw
"large": "n2-standard-4",
"x-large": "n2-standard-8",
"2x-large": "n2-standard-16",
"gpu-medium": "n1-standard-8",
"gpu-xlarge": "n1-standard-8",
}
)

func stringInSlice(a string, list []string) bool {
for _, b := range list {
if b == a {
return true
}
}
return false
}

type gceStartupScriptData struct {
AutoImplode bool
HardTimeoutMinutes int64
Expand All @@ -180,6 +191,53 @@ func (oe *gceOpError) Error() string {
return strings.Join(errStrs, ", ")
}

type singleGpuMapping struct {
GpuCount int64
GpuType string
DiskSize int64
}

var gpuMedium = singleGpuMapping{
GpuCount: 1,
GpuType: "nvidia-tesla-t4",
DiskSize: 300,}
var gpuXLarge = singleGpuMapping{
GpuCount: 1,
GpuType: "nvidia-tesla-v100",
DiskSize: 300,}

func GpuMapping(vmSize string) (value singleGpuMapping) {
gpuMapping := map[string] singleGpuMapping{
"gpu-medium": gpuMedium,
"gpu-xlarge": gpuXLarge,
}
return gpuMapping[vmSize]
}


func GpuDefaultGpuCount(vmSize string) (gpuCountInt int64) {
return GpuMapping(vmSize).GpuCount
}

func GpuDefaultGpuDiskSize(vmSize string) (gpuDiskSizeInt int64) {
return GpuMapping(vmSize).DiskSize
}

func GpuDefaultGpuType(vmSize string) (gpuTypeString string) {
return GpuMapping(vmSize).GpuType
}

func GPUType(varSize string) string {
switch varSize {
case "gpu-medium":
return "gpu-medium"
case "gpu-xlarge":
return "gpu-xlarge"
default:
return ""
}
}

type gceAccountJSON struct {
ClientEmail string `json:"client_email"`
PrivateKey string `json:"private_key"`
Expand Down Expand Up @@ -827,7 +885,9 @@ func (p *gceProvider) Setup(ctx gocontext.Context) error {

machineTypes := []string{p.ic.MachineType, p.ic.PremiumMachineType}
for _, machineType := range gceVMSizeMapping {
machineTypes = append(machineTypes, machineType);
if !stringInSlice(machineType, machineTypes) {
machineTypes = append(machineTypes, machineType);
}
}
for _, zoneName := range append(zoneNames, p.alternateZones...) {
for _, machineType := range machineTypes {
Expand Down Expand Up @@ -1421,6 +1481,7 @@ func (p *gceProvider) imageSelect(ctx gocontext.Context, startAttributes *StartA

jobID, _ := context.JobIDFromContext(ctx)
repo, _ := context.RepositoryFromContext(ctx)
var gpuVMType = GPUType(startAttributes.VMSize)

if startAttributes.ImageName != "" {
imageName = startAttributes.ImageName
Expand All @@ -1434,6 +1495,7 @@ func (p *gceProvider) imageSelect(ctx gocontext.Context, startAttributes *StartA
OS: startAttributes.OS,
JobID: jobID,
Repo: repo,
GpuVMType: gpuVMType,
})

if err != nil {
Expand Down Expand Up @@ -1485,11 +1547,31 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
Zone: c.zoneName,
}

var gpuVMType = GPUType(c.startAttributes.VMSize)

machineType := p.ic.MachineType
if c.startAttributes.VMType == "premium" {
c.startAttributes.VMSize = "premium"
machineType = p.ic.PremiumMachineType
} else if c.startAttributes.VMSize != "" {
if mtype, ok := gceVMSizeMapping[c.startAttributes.VMSize]; ok {
machineType = mtype;
//storing converted machine type for instance size identification
if gpuVMType == "" {
c.startAttributes.VMSize = machineType
}
}
}

diskSize := p.ic.DiskSize
if c.startAttributes.OS == "windows" {
diskSize = p.ic.DiskSizeWindows
}

if gpuVMType != "" {
diskSize = GpuDefaultGpuDiskSize(gpuVMType)
}

diskInitParams := &compute.AttachedDiskInitializeParams{
SourceImage: c.image.SelfLink,
DiskType: gcePdSSDForZone(c.zoneName),
Expand All @@ -1506,18 +1588,6 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
},
}

machineType := p.ic.MachineType
if c.startAttributes.VMType == "premium" {
c.startAttributes.VMSize = "premium"
machineType = p.ic.PremiumMachineType
} else if c.startAttributes.VMSize != "" {
if mtype, ok := gceVMSizeMapping[c.startAttributes.VMSize]; ok {
machineType = mtype;
//storing converted machine type for instance size identification
c.startAttributes.VMSize = machineType
}
}

var ok bool
inst.MachineType, ok = p.machineTypeSelfLinks[gceMtKey(c.zoneName, machineType)]
if !ok {
Expand All @@ -1532,6 +1602,19 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
p.projectID,
c.startAttributes.VMConfig.Zone,
c.startAttributes.VMConfig.GpuType)
} else if gpuVMType != "" {
logger.WithField("acceleratorConfig.AcceleratorType", acceleratorConfig.AcceleratorType).Debug("Setting AcceleratorConfig")
if !strings.HasPrefix(acceleratorConfig.AcceleratorType, "https") {
notUrlAcceleratorType := GpuDefaultGpuType(gpuVMType)
logger.WithField("notUrlAcceleratorType", notUrlAcceleratorType).Debug("Retrieving AcceleratorType from defaults")
logger.WithField("AcceleratorCount", p.ic.AcceleratorConfig.AcceleratorCount).Debug("Retrieving AcceleratorCount from defaults")
acceleratorConfig.AcceleratorCount = GpuDefaultGpuCount(gpuVMType)
acceleratorConfig.AcceleratorType = fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/%s/zones/%s/acceleratorTypes/%s",
p.projectID,
c.zoneName,
notUrlAcceleratorType)
logger.WithField("acceleratorConfig.AcceleratorType", acceleratorConfig.AcceleratorType).Debug("Url for Accelerator Type is:")
}
}

var subnetwork string
Expand Down Expand Up @@ -1595,6 +1678,7 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
}

inst.GuestAccelerators = []*compute.AcceleratorConfig{}

if acceleratorConfig.AcceleratorCount > 0 {
logger.Debug("GPU requested, setting acceleratorConfig")
inst.GuestAccelerators = append(inst.GuestAccelerators, acceleratorConfig)
Expand Down
8 changes: 8 additions & 0 deletions image/api_selector.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func (as *APISelector) queryWithTags(ctx gocontext.Context, infra string, tags [
bodyLines := []string{}
lastJobID := uint64(0)
lastRepo := ""
gpuVMType := ""

for _, ts := range tags {
qs := url.Values{}
Expand All @@ -127,6 +128,7 @@ func (as *APISelector) queryWithTags(ctx gocontext.Context, infra string, tags [
qs.Set("limit", "1")
qs.Set("job_id", fmt.Sprintf("%v", ts.JobID))
qs.Set("repo", ts.Repo)
qs.Set("gpu_vm_type", ts.GpuVMType)
qs.Set("is_default", fmt.Sprintf("%v", ts.IsDefault))
if len(ts.Tags) > 0 {
qs.Set("tags", strings.Join(ts.Tags, ","))
Expand All @@ -135,6 +137,7 @@ func (as *APISelector) queryWithTags(ctx gocontext.Context, infra string, tags [
bodyLines = append(bodyLines, qs.Encode())
lastJobID = ts.JobID
lastRepo = ts.Repo
gpuVMType = ts.GpuVMType
}

qs := url.Values{}
Expand All @@ -144,6 +147,7 @@ func (as *APISelector) queryWithTags(ctx gocontext.Context, infra string, tags [
qs.Set("limit", "1")
qs.Set("job_id", fmt.Sprintf("%v", lastJobID))
qs.Set("repo", lastRepo)
qs.Set("gpu_vm_type", gpuVMType)

bodyLines = append(bodyLines, qs.Encode())

Expand Down Expand Up @@ -233,6 +237,7 @@ type tagSet struct {

JobID uint64
Repo string
GpuVMType string
}

func (ts *tagSet) GoString() string {
Expand All @@ -244,6 +249,7 @@ func (as *APISelector) buildCandidateTags(params *Params) ([]*tagSet, error) {
Tags: []string{},
JobID: params.JobID,
Repo: params.Repo,
GpuVMType: params.GpuVMType,
}
candidateTags := []*tagSet{}

Expand All @@ -255,6 +261,7 @@ func (as *APISelector) buildCandidateTags(params *Params) ([]*tagSet, error) {
Tags: []string{tag},
JobID: params.JobID,
Repo: params.Repo,
GpuVMType: params.GpuVMType,
})
}

Expand All @@ -265,6 +272,7 @@ func (as *APISelector) buildCandidateTags(params *Params) ([]*tagSet, error) {
Tags: tags,
JobID: params.JobID,
Repo: params.Repo,
GpuVMType: params.GpuVMType,
})
}

Expand Down
1 change: 1 addition & 0 deletions image/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ type Params struct {

JobID uint64
Repo string
GpuVMType string
}

0 comments on commit ee57d65

Please sign in to comment.