diff --git a/internal/config/modifier_test.go b/internal/config/modifier_test.go new file mode 100644 index 0000000..ee1bca1 --- /dev/null +++ b/internal/config/modifier_test.go @@ -0,0 +1,122 @@ +package config + +import ( + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/reecetech/ebs-bootstrap/internal/service" + "github.com/reecetech/ebs-bootstrap/internal/utils" +) + +func TestAwsNitroNVMeModifier(t *testing.T) { + ns := service.NewMockNVMeService() + ds := service.NewMockDeviceService() + + subtests := []struct { + Name string + Config *Config + GetBlockDevices func() ([]string, error) + GetBlockDeviceMapping func(name string) (string, error) + ExpectedOutput *Config + ExpectedErr error + }{ + { + Name: "Root Device + EBS Device (Non-Nitro Instance)", + Config: &Config{ + Devices: map[string]Device{ + "/dev/sdb": {}, + }, + }, + GetBlockDevices: func() ([]string, error) { + return []string{"/dev/sda1", "/dev/sdb"}, nil + }, + GetBlockDeviceMapping: func(name string) (string, error) { + return "", fmt.Errorf("🔴 GetBlockDeviceMapping() should not be called") + }, + ExpectedOutput: &Config{ + Devices: map[string]Device{ + "/dev/sdb": {}, + }, + }, + ExpectedErr: nil, + }, + { + Name: "Root Device + EBS/Instance Store Device (Nitro Instance)", + Config: &Config{ + Devices: map[string]Device{ + "/dev/sdb": {}, + }, + }, + GetBlockDevices: func() ([]string, error) { + return []string{"/dev/nvme0n1", "/dev/nvme1n1"}, nil + }, + GetBlockDeviceMapping: func(name string) (string, error) { + switch name { + case "/dev/nvme0n1": // Root Device + return "/dev/sda1", nil + default: // EBS/Instance Store + return "/dev/sdb", nil + } + }, + // Config will be left unchanged when error is encountered during modification stage + ExpectedOutput: &Config{ + Devices: map[string]Device{ + "/dev/nvme1n1": {}, + }, + }, + ExpectedErr: nil, + }, + { + Name: "NVMe Device that is not AWS-managed", + Config: &Config{ + Devices: map[string]Device{ + "/dev/sdb": {}, + }, + }, + GetBlockDevices: func() ([]string, error) { + return []string{"/dev/nvme0n1"}, nil + }, + GetBlockDeviceMapping: func(name string) (string, error) { + return "", fmt.Errorf("🔴 %s is not an AWS-managed NVME device", name) + }, + ExpectedOutput: &Config{ + Devices: map[string]Device{ + "/dev/sdb": {}, + }, + }, + ExpectedErr: fmt.Errorf("🔴 /dev/nvme0n1 is not an AWS-managed NVME device"), + }, + { + Name: "Failure to Retrieve Block Devices", + Config: &Config{ + Devices: map[string]Device{ + "/dev/sdb": {}, + }, + }, + GetBlockDevices: func() ([]string, error) { + return nil, fmt.Errorf("🔴 lsblk: Could not retrieve block devices") + }, + GetBlockDeviceMapping: func(name string) (string, error) { + return "", fmt.Errorf("🔴 GetBlockDeviceMapping() should not be called") + }, + // Config will be left unchanged when error is encountered during modification stage + ExpectedOutput: &Config{ + Devices: map[string]Device{ + "/dev/sdb": {}, + }, + }, + ExpectedErr: fmt.Errorf("🔴 lsblk: Could not retrieve block devices"), + }, + } + + for _, subtest := range subtests { + ds.StubGetBlockDevices = subtest.GetBlockDevices + ns.StubGetBlockDeviceMapping = subtest.GetBlockDeviceMapping + + andm := NewAwsNVMeDriverModifier(ns, ds) + err := andm.Modify(subtest.Config) + utils.CheckErrorGlob("andm.Modify()", t, subtest.ExpectedErr, err) + utils.CheckOutput("andm.Modify()", t, subtest.ExpectedOutput, subtest.Config, cmp.AllowUnexported(Config{})) + } +} diff --git a/internal/config/validator_test.go b/internal/config/validator_test.go index 4e1b20e..2d90760 100644 --- a/internal/config/validator_test.go +++ b/internal/config/validator_test.go @@ -38,7 +38,7 @@ func TestDeviceValidator(t *testing.T) { }, }, GetBlockDevice: func(name string) (*model.BlockDevice, error) { - return nil, fmt.Errorf("🔴 lsblk: /dev/nonexist: not a block device") + return nil, fmt.Errorf("🔴 lsblk: /dev/nonexist is not a block device") }, ExpectedErr: fmt.Errorf("🔴 /dev/nonexist is not a block device"), }, diff --git a/internal/service/mock.go b/internal/service/mock.go index 88d48f8..4359706 100644 --- a/internal/service/mock.go +++ b/internal/service/mock.go @@ -84,3 +84,19 @@ func (mos *MockOwnerService) GetUser(usr string) (*model.User, error) { func (mos *MockOwnerService) GetGroup(grp string) (*model.Group, error) { return mos.StubGetGroup(grp) } + +type MockNVMeService struct { + StubGetBlockDeviceMapping func(device string) (string, error) +} + +func NewMockNVMeService() *MockNVMeService { + return &MockNVMeService{ + StubGetBlockDeviceMapping: func(device string) (string, error) { + return "", utils.NotImeplementedError("GetBlockDeviceMapping()") + }, + } +} + +func (mns *MockNVMeService) GetBlockDeviceMapping(device string) (string, error) { + return mns.StubGetBlockDeviceMapping(device) +}