Skip to content

Commit

Permalink
feat: add get gpu info from nvml recover (#52)
Browse files: browse the repository at this point in the history
  • Loading branch information
CorrectRoadH authored Apr 30, 2024
1 parent c2759bd commit 901de02
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 24 deletions.
54 changes: 31 additions & 23 deletions external/gpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,75 +37,82 @@ type NvidiaGPUInfo struct {
Utilization float32 `json:"utilization"`
}

func NvidiaGPUInfoListWithNVMl() ([]NvidiaGPUInfo, error) {
func NvidiaGPUInfoListWithNVML() (info []NvidiaGPUInfo, err error) {
// Recover from any panic raised below and report it through the returned error.
defer func() {
if r := recover(); r != nil {
fmt.Println("Recovered in f", r)
err = fmt.Errorf("error getting GPU info: %v", r)
}
}()
var GPUInfos []NvidiaGPUInfo

// Initialize NVML
if err := nvml.Init(); err != nvml.SUCCESS {
if result := nvml.Init(); result != nvml.SUCCESS {
return nil, fmt.Errorf("error initializing NVML: %w", err)
}
defer nvml.Shutdown()

// Get device count
deviceCount, err := nvml.DeviceGetCount()
if err != nvml.SUCCESS {
deviceCount, result := nvml.DeviceGetCount()
if result != nvml.SUCCESS {
return nil, fmt.Errorf("error getting device count: %w", err)
}

for i := 0; i < deviceCount; i++ {
device, err := nvml.DeviceGetHandleByIndex(i)
if err != nvml.SUCCESS {
device, result := nvml.DeviceGetHandleByIndex(i)
if result != nvml.SUCCESS {
return nil, fmt.Errorf("error getting device handle: %w", err)
}

info := NvidiaGPUInfo{}
info.Index = int(i)

info.UUID, err = device.GetUUID()
if err != nvml.SUCCESS {
info.UUID, result = device.GetUUID()
if result != nvml.SUCCESS {
return nil, fmt.Errorf("error getting UUID: %w", err)
}

utilization, err := device.GetUtilizationRates()
if err != nvml.SUCCESS {
utilization, result := device.GetUtilizationRates()
if result != nvml.SUCCESS {
return nil, fmt.Errorf("error getting utilization rates: %w", err)
}
info.UtilizationGPU = int(utilization.Gpu)
info.MemoryUtilization = float32(utilization.Memory)

memInfo, err := device.GetMemoryInfo()
if err != nvml.SUCCESS {
memInfo, result := device.GetMemoryInfo()
if result != nvml.SUCCESS {
return nil, fmt.Errorf("error getting memory info: %w", err)
}
info.MemoryTotal = int(memInfo.Total)
info.MemoryUsed = int(memInfo.Used)
info.MemoryFree = int(memInfo.Free)

info.Name, err = device.GetName()
if err != nvml.SUCCESS {
info.Name, result = device.GetName()
if result != nvml.SUCCESS {
return nil, fmt.Errorf("error getting name: %w", err)
}

driverVersion, err := nvml.SystemGetDriverVersion()
if err != nvml.SUCCESS {
driverVersion, result := nvml.SystemGetDriverVersion()
if result != nvml.SUCCESS {
return nil, fmt.Errorf("error getting driver version: %w", err)
}
info.DriverVersion = driverVersion

temp, err := device.GetTemperature(nvml.TEMPERATURE_GPU)
if err != nvml.SUCCESS {
temp, result := device.GetTemperature(nvml.TEMPERATURE_GPU)
if result != nvml.SUCCESS {
return nil, fmt.Errorf("error getting temperature: %w", err)
}
info.TemperatureGPU = int(temp)

powerDraw, err := device.GetPowerUsage()
if err != nvml.SUCCESS {
powerDraw, result := device.GetPowerUsage()
if result != nvml.SUCCESS {
return nil, fmt.Errorf("error getting power usage: %w", err)
}
info.PowerDraw = float32(powerDraw) / 1000.0

powerLimit, err := device.GetEnforcedPowerLimit()
if err != nvml.SUCCESS {
powerLimit, result := device.GetEnforcedPowerLimit()
if result != nvml.SUCCESS {
return nil, fmt.Errorf("error getting power limit: %w", err)
}
info.PowerLimit = float32(powerLimit) / 1000.0
Expand Down Expand Up @@ -165,8 +172,9 @@ func NvidiaGPUInfoListWithSMI() ([]NvidiaGPUInfo, error) {
}

func NvidiaGPUInfoList() ([]NvidiaGPUInfo, error) {
gpusInfo, err := NvidiaGPUInfoListWithNVMl()
gpusInfo, err := NvidiaGPUInfoListWithNVML()
if err != nil {
fmt.Println("Error getting GPU info with NVML, trying with nvidia-smi")
gpusInfo, err = NvidiaGPUInfoListWithSMI()
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion external/gpu_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestGPUTwoImplementInfo(t *testing.T) {
t.Skip()
result, err := external.NvidiaGPUInfoListWithSMI()
assert.NilError(t, err)
result2, err := external.NvidiaGPUInfoListWithNVMl()
result2, err := external.NvidiaGPUInfoListWithNVML()
assert.NilError(t, err)

assert.Equal(t, len(result), len(result2))
Expand Down

0 comments on commit 901de02

Please sign in to comment.