diff --git a/external/gpu.go b/external/gpu.go index 3943e17..656d2b9 100644 --- a/external/gpu.go +++ b/external/gpu.go @@ -1,10 +1,12 @@ package external import ( + "fmt" "os/exec" "strconv" "strings" + "github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/samber/lo" ) @@ -35,7 +37,87 @@ type NvidiaGPUInfo struct { Utilization float32 `json:"utilization"` } -func NvidiaGPUInfoList() ([]NvidiaGPUInfo, error) { +func NvidiaGPUInfoListWithNVMl() ([]NvidiaGPUInfo, error) { + var GPUInfos []NvidiaGPUInfo + + // Initialize NVML + if err := nvml.Init(); err != nvml.SUCCESS { + return nil, fmt.Errorf("error initializing NVML: %w", err) + } + defer nvml.Shutdown() + + // Get device count + deviceCount, err := nvml.DeviceGetCount() + if err != nvml.SUCCESS { + return nil, fmt.Errorf("error getting device count: %w", err) + } + + for i := 0; i < deviceCount; i++ { + device, err := nvml.DeviceGetHandleByIndex(i) + if err != nvml.SUCCESS { + return nil, fmt.Errorf("error getting device handle: %w", err) + } + + info := NvidiaGPUInfo{} + info.Index = int(i) + + info.UUID, err = device.GetUUID() + if err != nvml.SUCCESS { + return nil, fmt.Errorf("error getting UUID: %w", err) + } + + utilization, err := device.GetUtilizationRates() + if err != nvml.SUCCESS { + return nil, fmt.Errorf("error getting utilization rates: %w", err) + } + info.UtilizationGPU = int(utilization.Gpu) + info.MemoryUtilization = float32(utilization.Memory) + + memInfo, err := device.GetMemoryInfo() + if err != nvml.SUCCESS { + return nil, fmt.Errorf("error getting memory info: %w", err) + } + info.MemoryTotal = int(memInfo.Total) + info.MemoryUsed = int(memInfo.Used) + info.MemoryFree = int(memInfo.Free) + + info.Name, err = device.GetName() + if err != nvml.SUCCESS { + return nil, fmt.Errorf("error getting name: %w", err) + } + + driverVersion, err := nvml.SystemGetDriverVersion() + if err != nvml.SUCCESS { + return nil, fmt.Errorf("error getting driver version: %w", err) + } + info.DriverVersion = driverVersion + + temp, err := device.GetTemperature(nvml.TEMPERATURE_GPU) + if err != nvml.SUCCESS { + return nil, fmt.Errorf("error getting temperature: %w", err) + } + info.TemperatureGPU = int(temp) + + powerDraw, err := device.GetPowerUsage() + if err != nvml.SUCCESS { + return nil, fmt.Errorf("error getting power usage: %w", err) + } + info.PowerDraw = float32(powerDraw) / 1000.0 + + powerLimit, err := device.GetEnforcedPowerLimit() + if err != nvml.SUCCESS { + return nil, fmt.Errorf("error getting power limit: %w", err) + } + info.PowerLimit = float32(powerLimit) / 1000.0 + info.GPUSerial = "[N/A]" + GPUInfos = append(GPUInfos, info) + + } + + return GPUInfos, nil +} + +func NvidiaGPUInfoListWithSMI() ([]NvidiaGPUInfo, error) { GPUInfos := []NvidiaGPUInfo{} output, err := exec.Command("nvidia-smi", "--query-gpu=index,uuid,utilization.gpu,memory.total,memory.used,memory.free,driver_version,name,gpu_serial,display_active,display_mode,temperature.gpu,utilization.gpu,utilization.memory,power.draw,power.limit", "--format=csv,noheader,nounits").Output() @@ -82,6 +164,17 @@ func NvidiaGPUInfoList() ([]NvidiaGPUInfo, error) { return GPUInfos, nil } +func NvidiaGPUInfoList() ([]NvidiaGPUInfo, error) { + gpusInfo, err := NvidiaGPUInfoListWithNVMl() + if err != nil { + gpusInfo, err = NvidiaGPUInfoListWithSMI() + if err != nil { + return nil, err + } + } + return gpusInfo, nil +} + func GPUInfoList() ([]GPUInfo, error) { GPUInfos := []GPUInfo{} nvidiaGPUInfoList, err := NvidiaGPUInfoList() diff --git a/external/gpu_test.go b/external/gpu_test.go index 7e45f9c..e4e6f52 100644 --- a/external/gpu_test.go +++ b/external/gpu_test.go @@ -13,3 +13,22 @@ func TestGPUInfo(t *testing.T) { assert.NilError(t, err) assert.Equal(t, len(result), 1) } + +func TestGPUTwoImplementInfo(t *testing.T) { + t.Skip() + result, err := external.NvidiaGPUInfoListWithSMI() + assert.NilError(t, err) + result2, err := external.NvidiaGPUInfoListWithNVMl() + assert.NilError(t, err) + + assert.Equal(t, len(result), len(result2)) + for i := range result { + assert.Equal(t, result[i].Name, result2[i].Name) + assert.Equal(t, result[i].DriverVersion, result2[i].DriverVersion) + assert.Equal(t, result[i].Name, result2[i].Name) + assert.Equal(t, result[i].DisplayMode, result2[i].DisplayMode) + assert.Equal(t, result[i].PowerLimit, result2[i].PowerLimit) + assert.Equal(t, result[i].MemoryUtilization, result2[i].MemoryUtilization) + assert.Equal(t, result[i].Utilization, result2[i].Utilization) + } +} diff --git a/go.mod b/go.mod index 8d2de14..c3066ab 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/IceWhaleTech/CasaOS-Common go 1.20 require ( + github.com/NVIDIA/go-nvml v0.12.0-5 github.com/coreos/go-systemd/v22 v22.5.0 github.com/gin-gonic/gin v1.7.7 github.com/golang-jwt/jwt/v4 v4.5.0 @@ -10,7 +11,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.16 github.com/mholt/archiver/v3 v3.5.1 github.com/sirupsen/logrus v1.9.0 - github.com/stretchr/testify v1.8.2 + github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.0 go.uber.org/zap v1.24.0 gopkg.in/ini.v1 v1.67.0 diff --git a/go.sum b/go.sum index cc091e8..114fa6b 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/NVIDIA/go-nvml v0.12.0-5 h1:4DYsngBqJEAEj+/RFmBZ43Q3ymoR3tyS0oBuJk12Fag= +github.com/NVIDIA/go-nvml v0.12.0-5/go.mod h1:8Llmj+1Rr+9VGGwZuRer5N/aCjxGuR5nPb/9ebBiIEQ= github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= @@ -94,8 +96,9 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=