From 901de02bcbb7baf4a823ba5d3cc92e2b10f9bcef Mon Sep 17 00:00:00 2001 From: CorrectRoad Date: Tue, 30 Apr 2024 11:55:43 +0800 Subject: [PATCH] feat: add get gpu info from nvml recover (#52) --- external/gpu.go | 54 +++++++++++++++++++++++++------------------- external/gpu_test.go | 2 +- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/external/gpu.go b/external/gpu.go index 656d2b9..4fb40ce 100644 --- a/external/gpu.go +++ b/external/gpu.go @@ -37,75 +37,82 @@ type NvidiaGPUInfo struct { Utilization float32 `json:"utilization"` } -func NvidiaGPUInfoListWithNVMl() ([]NvidiaGPUInfo, error) { +func NvidiaGPUInfoListWithNVML() (info []NvidiaGPUInfo, err error) { + // defer recover + defer func() { + if r := recover(); r != nil { + fmt.Println("Recovered in f", r) + err = fmt.Errorf("error getting GPU info: %v", r) + } + }() var GPUInfos []NvidiaGPUInfo // Initialize NVML - if err := nvml.Init(); err != nvml.SUCCESS { + if result := nvml.Init(); result != nvml.SUCCESS { return nil, fmt.Errorf("error initializing NVML: %w", err) } defer nvml.Shutdown() // Get device count - deviceCount, err := nvml.DeviceGetCount() - if err != nvml.SUCCESS { + deviceCount, result := nvml.DeviceGetCount() + if result != nvml.SUCCESS { return nil, fmt.Errorf("error getting device count: %w", err) } for i := 0; i < deviceCount; i++ { - device, err := nvml.DeviceGetHandleByIndex(i) - if err != nvml.SUCCESS { + device, result := nvml.DeviceGetHandleByIndex(i) + if result != nvml.SUCCESS { return nil, fmt.Errorf("error getting device handle: %w", err) } info := NvidiaGPUInfo{} info.Index = int(i) - info.UUID, err = device.GetUUID() - if err != nvml.SUCCESS { + info.UUID, result = device.GetUUID() + if result != nvml.SUCCESS { return nil, fmt.Errorf("error getting UUID: %w", err) } - utilization, err := device.GetUtilizationRates() - if err != nvml.SUCCESS { + utilization, result := device.GetUtilizationRates() + if result != nvml.SUCCESS { return nil, fmt.Errorf("error getting utilization rates: %w", err) } info.UtilizationGPU = int(utilization.Gpu) info.MemoryUtilization = float32(utilization.Memory) - memInfo, err := device.GetMemoryInfo() - if err != nvml.SUCCESS { + memInfo, result := device.GetMemoryInfo() + if result != nvml.SUCCESS { return nil, fmt.Errorf("error getting memory info: %w", err) } info.MemoryTotal = int(memInfo.Total) info.MemoryUsed = int(memInfo.Used) info.MemoryFree = int(memInfo.Free) - info.Name, err = device.GetName() - if err != nvml.SUCCESS { + info.Name, result = device.GetName() + if result != nvml.SUCCESS { return nil, fmt.Errorf("error getting name: %w", err) } - driverVersion, err := nvml.SystemGetDriverVersion() - if err != nvml.SUCCESS { + driverVersion, result := nvml.SystemGetDriverVersion() + if result != nvml.SUCCESS { return nil, fmt.Errorf("error getting driver version: %w", err) } info.DriverVersion = driverVersion - temp, err := device.GetTemperature(nvml.TEMPERATURE_GPU) - if err != nvml.SUCCESS { + temp, result := device.GetTemperature(nvml.TEMPERATURE_GPU) + if result != nvml.SUCCESS { return nil, fmt.Errorf("error getting temperature: %w", err) } info.TemperatureGPU = int(temp) - powerDraw, err := device.GetPowerUsage() - if err != nvml.SUCCESS { + powerDraw, result := device.GetPowerUsage() + if result != nvml.SUCCESS { return nil, fmt.Errorf("error getting power usage: %w", err) } info.PowerDraw = float32(powerDraw) / 1000.0 - powerLimit, err := device.GetEnforcedPowerLimit() - if err != nvml.SUCCESS { + powerLimit, result := device.GetEnforcedPowerLimit() + if result != nvml.SUCCESS { return nil, fmt.Errorf("error getting power limit: %w", err) } info.PowerLimit = float32(powerLimit) / 1000.0 @@ -165,8 +172,9 @@ func NvidiaGPUInfoListWithSMI() ([]NvidiaGPUInfo, error) { } func NvidiaGPUInfoList() ([]NvidiaGPUInfo, error) { - gpusInfo, err := NvidiaGPUInfoListWithNVMl() + gpusInfo, err := NvidiaGPUInfoListWithNVML() if err != nil { + fmt.Println("Error getting GPU info with NVML, trying with nvidia-smi") gpusInfo, err = NvidiaGPUInfoListWithSMI() if err != nil { return nil, err diff --git a/external/gpu_test.go b/external/gpu_test.go index e4e6f52..0807dd8 100644 --- a/external/gpu_test.go +++ b/external/gpu_test.go @@ -18,7 +18,7 @@ func TestGPUTwoImplementInfo(t *testing.T) { t.Skip() result, err := external.NvidiaGPUInfoListWithSMI() assert.NilError(t, err) - result2, err := external.NvidiaGPUInfoListWithNVMl() + result2, err := external.NvidiaGPUInfoListWithNVML() assert.NilError(t, err) assert.Equal(t, len(result), len(result2))