diff --git a/pkg/azuredisk/azuredisk.go b/pkg/azuredisk/azuredisk.go index db8072353e..36b601406f 100644 --- a/pkg/azuredisk/azuredisk.go +++ b/pkg/azuredisk/azuredisk.go @@ -115,6 +115,7 @@ type DriverCore struct { attachDetachInitialDelayInMs int64 vmType string enableWindowsHostProcess bool + listDisksUsingWinCIM bool getNodeIDFromIMDS bool enableOtelTracing bool shouldWaitForSnapshotReady bool @@ -170,6 +171,7 @@ func newDriverV1(options *DriverOptions) *Driver { driver.volStatsCacheExpireInMinutes = options.VolStatsCacheExpireInMinutes driver.vmType = options.VMType driver.enableWindowsHostProcess = options.EnableWindowsHostProcess + driver.listDisksUsingWinCIM = options.ListDisksUsingWinCIM driver.getNodeIDFromIMDS = options.GetNodeIDFromIMDS driver.enableOtelTracing = options.EnableOtelTracing driver.shouldWaitForSnapshotReady = options.WaitForSnapshotReady @@ -267,7 +269,7 @@ func newDriverV1(options *DriverOptions) *Driver { } } - driver.mounter, err = mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) + driver.mounter, err = mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.listDisksUsingWinCIM, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) if err != nil { klog.Fatalf("Failed to get safe mounter. Error: %v", err) } diff --git a/pkg/azuredisk/azuredisk_option.go b/pkg/azuredisk/azuredisk_option.go index 85b552ab9b..3a7f10fc7c 100644 --- a/pkg/azuredisk/azuredisk_option.go +++ b/pkg/azuredisk/azuredisk_option.go @@ -53,6 +53,7 @@ type DriverOptions struct { VolStatsCacheExpireInMinutes int64 VMType string EnableWindowsHostProcess bool + ListDisksUsingWinCIM bool GetNodeIDFromIMDS bool WaitForSnapshotReady bool CheckDiskLUNCollision bool @@ -97,6 +98,7 @@ func (o *DriverOptions) AddFlags() *flag.FlagSet { fs.Int64Var(&o.VolStatsCacheExpireInMinutes, "vol-stats-cache-expire-in-minutes", 10, "The cache expire time in minutes for volume stats cache") fs.StringVar(&o.VMType, "vm-type", "", "type of agent node. available values: vmss, standard") fs.BoolVar(&o.EnableWindowsHostProcess, "enable-windows-host-process", false, "enable windows host process") + fs.BoolVar(&o.ListDisksUsingWinCIM, "list-disks-using-win-cim", true, "list disks using CIM API on Windows") fs.BoolVar(&o.GetNodeIDFromIMDS, "get-nodeid-from-imds", false, "boolean flag to get NodeID from IMDS") fs.BoolVar(&o.WaitForSnapshotReady, "wait-for-snapshot-ready", true, "boolean flag to wait for snapshot ready when creating snapshot in same region") fs.BoolVar(&o.CheckDiskLUNCollision, "check-disk-lun-collision", true, "boolean flag to check disk lun collisio before attaching disk") diff --git a/pkg/azuredisk/azuredisk_v2.go b/pkg/azuredisk/azuredisk_v2.go index cad50b0850..2e3ba525b8 100644 --- a/pkg/azuredisk/azuredisk_v2.go +++ b/pkg/azuredisk/azuredisk_v2.go @@ -144,7 +144,7 @@ func newDriverV2(options *DriverOptions) *DriverV2 { } } - driver.mounter, err = mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) + driver.mounter, err = mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.listDisksUsingWinCIM, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) if err != nil { klog.Fatalf("Failed to get safe mounter. Error: %v", err) } diff --git a/pkg/azuredisk/fake_azuredisk.go b/pkg/azuredisk/fake_azuredisk.go index e3fe666e6d..6411029242 100644 --- a/pkg/azuredisk/fake_azuredisk.go +++ b/pkg/azuredisk/fake_azuredisk.go @@ -125,7 +125,7 @@ func newFakeDriverV1(ctrl *gomock.Controller) (*fakeDriverV1, error) { driver.diskController = NewManagedDiskController(driver.cloud) driver.clientFactory = driver.cloud.ComputeClientFactory - mounter, err := mounter.NewSafeMounter(true, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) + mounter, err := mounter.NewSafeMounter(true, true, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) if err != nil { return nil, err } diff --git a/pkg/azuredisk/fake_azuredisk_v2.go b/pkg/azuredisk/fake_azuredisk_v2.go index ece5748b07..4d3b549fd9 100644 --- a/pkg/azuredisk/fake_azuredisk_v2.go +++ b/pkg/azuredisk/fake_azuredisk_v2.go @@ -76,7 +76,7 @@ func newFakeDriverV2(ctrl *gomock.Controller) (*fakeDriverV2, error) { driver.diskController = NewManagedDiskController(driver.cloud) driver.clientFactory = driver.cloud.ComputeClientFactory - mounter, err := mounter.NewSafeMounter(true, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) + mounter, err := mounter.NewSafeMounter(true, true, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) if err != nil { return nil, err } diff --git a/pkg/mounter/fake_safe_mounter.go b/pkg/mounter/fake_safe_mounter.go index e5d079b616..7111167b1a 100644 --- a/pkg/mounter/fake_safe_mounter.go +++ b/pkg/mounter/fake_safe_mounter.go @@ -36,7 +36,7 @@ type FakeSafeMounter struct { // NewFakeSafeMounter creates a mount.SafeFormatAndMount instance suitable for use in unit tests. func NewFakeSafeMounter() (*mount.SafeFormatAndMount, error) { if runtime.GOOS == "windows" { - return NewSafeMounter(true, true, 2, time.Duration(120)*time.Second) + return NewSafeMounter(true, true, true, 2, time.Duration(120)*time.Second) } fakeSafeMounter := FakeSafeMounter{} diff --git a/pkg/mounter/safe_mounter_host_process_windows.go b/pkg/mounter/safe_mounter_host_process_windows.go index 0d5fd6c34f..930e17127b 100644 --- a/pkg/mounter/safe_mounter_host_process_windows.go +++ b/pkg/mounter/safe_mounter_host_process_windows.go @@ -40,10 +40,14 @@ import ( var _ CSIProxyMounter = &winMounter{} -type winMounter struct{} +type winMounter struct { + listDisksUsingWinCIM bool +} -func NewWinMounter() *winMounter { - return &winMounter{} +func NewWinMounter(listDisksUsingWinCIM bool) *winMounter { + return &winMounter{ + listDisksUsingWinCIM: listDisksUsingWinCIM, + } } // Mount just creates a soft link at target pointing to source. @@ -206,7 +210,13 @@ func (mounter *winMounter) Rescan() error { // FindDiskByLun - given a lun number, find out the corresponding disk func (mounter *winMounter) FindDiskByLun(lun string) (diskNum string, err error) { - diskLocations, err := disk.ListDiskLocations() + var diskLocations map[uint32]disk.Location + + if mounter.listDisksUsingWinCIM { + diskLocations, err = disk.ListDisksUsingCIM() + } else { + diskLocations, err = disk.ListDiskLocations() + } if err != nil { return "", err } diff --git a/pkg/mounter/safe_mounter_unix.go b/pkg/mounter/safe_mounter_unix.go index 2fe95f9c75..fb5e442f7b 100644 --- a/pkg/mounter/safe_mounter_unix.go +++ b/pkg/mounter/safe_mounter_unix.go @@ -26,7 +26,7 @@ import ( utilexec "k8s.io/utils/exec" ) -func NewSafeMounter(_, _ bool, maxConcurrentFormat int, concurrentFormatTimeout time.Duration) (*mount.SafeFormatAndMount, error) { +func NewSafeMounter(_, _, _ bool, maxConcurrentFormat int, concurrentFormatTimeout time.Duration) (*mount.SafeFormatAndMount, error) { opt := mount.WithMaxConcurrentFormat(maxConcurrentFormat, concurrentFormatTimeout) return mount.NewSafeFormatAndMount(mount.New(""), utilexec.New(), opt), nil } diff --git a/pkg/mounter/safe_mounter_unix_test.go b/pkg/mounter/safe_mounter_unix_test.go index 2ae13cccef..684116f689 100644 --- a/pkg/mounter/safe_mounter_unix_test.go +++ b/pkg/mounter/safe_mounter_unix_test.go @@ -24,7 +24,7 @@ import ( ) func TestNewSafeMounter(t *testing.T) { - resp, err := NewSafeMounter(true, true, 2, time.Duration(120)*time.Second) + resp, err := NewSafeMounter(true, true, true, 2, time.Duration(120)*time.Second) assert.NotNil(t, resp) assert.Nil(t, err) } diff --git a/pkg/mounter/safe_mounter_windows.go b/pkg/mounter/safe_mounter_windows.go index f0832f9f8a..0c65bb98bd 100644 --- a/pkg/mounter/safe_mounter_windows.go +++ b/pkg/mounter/safe_mounter_windows.go @@ -412,11 +412,11 @@ func newCSIProxyMounter() (*csiProxyMounter, error) { }, nil } -func NewSafeMounter(enableWindowsHostProcess, useCSIProxyGAInterface bool, maxConcurrentFormat int, concurrentFormatTimeout time.Duration) (*mount.SafeFormatAndMount, error) { +func NewSafeMounter(enableWindowsHostProcess, listDisksUsingWinCIM, useCSIProxyGAInterface bool, maxConcurrentFormat int, concurrentFormatTimeout time.Duration) (*mount.SafeFormatAndMount, error) { if enableWindowsHostProcess { klog.V(2).Infof("using windows host process mounter") opt := mount.WithMaxConcurrentFormat(maxConcurrentFormat, concurrentFormatTimeout) - return mount.NewSafeFormatAndMount(NewWinMounter(), utilexec.New(), opt), nil + return mount.NewSafeFormatAndMount(NewWinMounter(listDisksUsingWinCIM), utilexec.New(), opt), nil } else { if useCSIProxyGAInterface { csiProxyMounter, err := newCSIProxyMounter() diff --git a/pkg/os/disk/disk.go b/pkg/os/disk/disk.go index 27ea144c03..a40b39f3fa 100644 --- a/pkg/os/disk/disk.go +++ b/pkg/os/disk/disk.go @@ -42,9 +42,9 @@ const ( IOCTL_STORAGE_QUERY_PROPERTY = 0x002d1400 ) -// ListDiskLocations - constructs a map with the disk number as the key and the DiskLocation structure +// ListDisksUsingCIM - constructs a map with the disk number as the key and the DiskLocation structure // as the value. The DiskLocation struct has various fields like the Adapter, Bus, Target and LUNID. -func ListDiskLocations() (map[uint32]Location, error) { +func ListDisksUsingCIM() (map[uint32]Location, error) { // sample response // [{ // "Index": 3, @@ -53,11 +53,12 @@ func ListDiskLocations() (map[uint32]Location, error) { // "SCSIPort": 1, // "SCSIBus": 0 // }, ...] - cmd := fmt.Sprintf("ConvertTo-Json @(Get-CimInstance win32_diskdrive|where-object -FilterScript {$_.SCSIPort -Ne 0}|Select Index,SCSILogicalUnit,SCSITargetId,SCSIPort,SCSIBus)") + cmd := fmt.Sprintf("ConvertTo-Json @(Get-CimInstance win32_diskdrive|Where-Object { $_.Model -eq "Virtual_Disk NVME Premium" -or $_.SCSIPort -eq 0 }|Select Index,SCSILogicalUnit,SCSITargetId,SCSIPort,SCSIBus)") out, err := azureutils.RunPowershellCmd(cmd) if err != nil { return nil, fmt.Errorf("failed to list disk location. cmd: %q, output: %q, err %v", cmd, string(out), err) } + klog.V(6).Infof("ListDisksUsingCIM output: %s", string(out)) var getCimInstance []struct { Index uint32 `json:"Index"` @@ -83,6 +84,67 @@ func ListDiskLocations() (map[uint32]Location, error) { return m, nil } +// ListDiskLocations - constructs a map with the disk number as the key and the DiskLocation structure +// as the value. The DiskLocation struct has various fields like the Adapter, Bus, Target and LUNID. +func ListDiskLocations() (map[uint32]Location, error) { + // sample response + // [{ + // "number": 0, + // "location": "PCI Slot 3 : Adapter 0 : Port 0 : Target 1 : LUN 0" + // }, ...] + cmd := fmt.Sprintf("ConvertTo-Json @(Get-Disk | select Number, Location, PartitionStyle)") + out, err := azureutils.RunPowershellCmd(cmd) + if err != nil { + return nil, fmt.Errorf("failed to list disk location. cmd: %q, output: %q, err %v", cmd, string(out), err) + } + klog.V(6).Infof("ListDiskLocations output: %s", string(out)) + + var getDisk []map[string]interface{} + err = json.Unmarshal(out, &getDisk) + if err != nil { + return nil, err + } + + m := make(map[uint32]Location) + for _, v := range getDisk { + str := v["Location"].(string) + num := v["Number"].(float64) + partitionStyle := v["PartitionStyle"].(string) + if strings.EqualFold(partitionStyle, "MBR") { + klog.V(2).Infof("skipping MBR disk, number: %d, location: %s", int(num), str) + continue + } + + found := false + s := strings.Split(str, ":") + if len(s) >= 5 { + var d Location + for _, item := range s { + item = strings.TrimSpace(item) + itemSplit := strings.Split(item, " ") + if len(itemSplit) == 2 { + found = true + switch strings.TrimSpace(itemSplit[0]) { + case "Adapter": + d.Adapter = strings.TrimSpace(itemSplit[1]) + case "Target": + d.Target = strings.TrimSpace(itemSplit[1]) + case "LUN": + d.LUNID = strings.TrimSpace(itemSplit[1]) + default: + klog.Warningf("Got unknown field : %s=%s", itemSplit[0], itemSplit[1]) + } + } + } + + if found { + m[uint32(num)] = d + } + } + } + return m, nil +} + func Rescan() error { cmd := "Update-HostStorageCache" out, err := azureutils.RunPowershellCmd(cmd)