diff --git a/go.mod b/go.mod index 05a6613d..741ae63c 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( go.uber.org/mock v0.4.0 golang.org/x/net v0.29.0 golang.org/x/sys v0.25.0 - google.golang.org/grpc v1.66.2 + google.golang.org/grpc v1.67.0 google.golang.org/protobuf v1.34.2 k8s.io/apimachinery v0.31.1 k8s.io/klog/v2 v2.130.1 @@ -27,10 +27,10 @@ require ( github.com/go-resty/resty/v2 v2.13.1 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/moby/sys/mountinfo v0.7.1 // indirect - github.com/opencontainers/runc v1.1.13 // indirect + github.com/opencontainers/runc v1.1.14 // indirect github.com/opencontainers/runtime-spec v1.0.3-0.20220909204839-494a5a6aca78 // indirect github.com/sirupsen/logrus v1.9.3 // indirect golang.org/x/text v0.18.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240711142825-46eb208f015d // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect gopkg.in/ini.v1 v1.66.6 // indirect ) diff --git a/go.sum b/go.sum index f0b9f849..10641768 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ github.com/martinjungblut/go-cryptsetup v0.0.0-20220520180014-fd0874fd07a6 h1:YD github.com/martinjungblut/go-cryptsetup v0.0.0-20220520180014-fd0874fd07a6/go.mod h1:gZoZ0+POlM1ge/VUxWpMmZVNPzzMJ7l436CgkQ5+qzU= github.com/moby/sys/mountinfo v0.7.1 h1:/tTvQaSJRr2FshkhXiIpux6fQ2Zvc4j7tAhMTStAG2g= github.com/moby/sys/mountinfo v0.7.1/go.mod h1:IJb6JQeOklcdMU9F5xQ8ZALD+CUr5VlGpwtX+VE0rpI= -github.com/opencontainers/runc v1.1.13 h1:98S2srgG9vw0zWcDpFMn5TRrh8kLxa/5OFUstuUhmRs= -github.com/opencontainers/runc v1.1.13/go.mod h1:R016aXacfp/gwQBYw2FDGa9m+n6atbLWrYY8hNMT/sA= +github.com/opencontainers/runc v1.1.14 h1:rgSuzbmgz5DUJjeSnw337TxDbRuqjs6iqQck/2weR6w= +github.com/opencontainers/runc v1.1.14/go.mod h1:E4C2z+7BxR7GHXp0hAY53mek+x49X1LjPNeMTfRGvOA= github.com/opencontainers/runtime-spec v1.0.3-0.20220909204839-494a5a6aca78 h1:R5M2qXZiK/mWPMT4VldCOiSL9HIAMuxQZWdG0CSM5+4= github.com/opencontainers/runtime-spec v1.0.3-0.20220909204839-494a5a6aca78/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -105,10 +105,10 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240711142825-46eb208f015d h1:JU0iKnSg02Gmb5ZdV8nYsKEKsP6o/FGVWTrw4i1DA9A= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240711142825-46eb208f015d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= -google.golang.org/grpc v1.66.2 h1:3QdXkuq3Bkh7w+ywLdLvM56cmGvQHUMZpiCzt6Rqaoo= -google.golang.org/grpc v1.66.2/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 h1:e7S5W7MGGLaSu8j3YjdezkZ+m1/Nm0uRVRMEMGk26Xs= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.67.0 h1:IdH9y6PF5MPSdAntIcpjQ+tXO41pcQsfZV2RxtQgVcw= +google.golang.org/grpc v1.67.0/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/driver/controllerserver.go b/internal/driver/controllerserver.go index 5b056522..05bd6f73 100644 --- a/internal/driver/controllerserver.go +++ b/internal/driver/controllerserver.go @@ -4,7 +4,6 @@ import ( "context" "errors" "strconv" - "strings" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/linode/linodego" @@ -316,7 +315,8 @@ func (cs *ControllerServer) ControllerPublishVolume(ctx context.Context, req *cs } log.V(4).Info("Waiting for volume to attach", "volume_id", volumeID) - volume, err = cs.client.WaitForVolumeLinodeID(ctx, volumeID, &linodeID, waitTimeout()) + // Wait for the volume to be successfully attached to the instance + volume, err := cs.client.WaitForVolumeLinodeID(ctx, volumeID, &linodeID, waitTimeout()) if err != nil { return resp, errInternal("wait for volume to attach: %v", err) } diff --git a/internal/driver/controllerserver_helper.go b/internal/driver/controllerserver_helper.go index 5796a00b..4ca03cd4 100644 --- a/internal/driver/controllerserver_helper.go +++ b/internal/driver/controllerserver_helper.go @@ -85,18 +85,18 @@ const ( // // Whether or not another volume can be attached is based on how many instance // disks and block storage volumes are currently attached to the instance. -func (s *ControllerServer) canAttach(ctx context.Context, instance *linodego.Instance) (canAttach bool, err error) { +func (cs *ControllerServer) canAttach(ctx context.Context, instance *linodego.Instance) (canAttach bool, err error) { log := logger.GetLogger(ctx) log.V(4).Info("Checking if volume can be attached", "instance_id", instance.ID) // Get the maximum number of volume attachments allowed for the instance - limit, err := s.maxAllowedVolumeAttachments(ctx, instance) + limit, err := cs.maxAllowedVolumeAttachments(ctx, instance) if err != nil { return false, err } // List the volumes currently attached to the instance - volumes, err := s.client.ListInstanceVolumes(ctx, instance.ID, nil) + volumes, err := cs.client.ListInstanceVolumes(ctx, instance.ID, nil) if err != nil { return false, errInternal("list instance volumes: %v", err) } @@ -107,7 +107,7 @@ func (s *ControllerServer) canAttach(ctx context.Context, instance *linodego.Ins // maxAllowedVolumeAttachments calculates the maximum number of volumes that can be attached to a Linode instance, // taking into account the instance's memory and currently attached disks. -func (s *ControllerServer) maxAllowedVolumeAttachments(ctx context.Context, instance *linodego.Instance) (int, error) { +func (cs *ControllerServer) maxAllowedVolumeAttachments(ctx context.Context, instance *linodego.Instance) (int, error) { log := logger.GetLogger(ctx) log.V(4).Info("Calculating max volume attachments") @@ -117,7 +117,7 @@ func (s *ControllerServer) maxAllowedVolumeAttachments(ctx context.Context, inst } // Retrieve the list of disks currently attached to the instance - disks, err := s.client.ListInstanceDisks(ctx, instance.ID, nil) + disks, err := cs.client.ListInstanceDisks(ctx, instance.ID, nil) if err != nil { return 0, errInternal("list instance disks: %v", err) } @@ -480,3 +480,141 @@ func (cs *ControllerServer) prepareCreateVolumeResponse(ctx context.Context, vol return resp } + +// validateControllerPublishVolumeRequest validates the incoming ControllerPublishVolumeRequest. +// It extracts the Linode ID and Volume ID from the request and checks if the +// volume capability is provided and valid. If any validation fails, it returns +// an appropriate error. +func (cs *ControllerServer) validateControllerPublishVolumeRequest(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (int, int, error) { + log := logger.GetLogger(ctx) + log.V(4).Info("Entering validateControllerPublishVolumeRequest()", "req", req) + defer log.V(4).Info("Exiting validateControllerPublishVolumeRequest()") + + // extract the linode ID from the request + linodeID, statusErr := linodevolumes.NodeIdAsInt("ControllerPublishVolume", req) + if statusErr != nil { + return 0, 0, statusErr + } + + // extract the volume ID from the request + volumeID, statusErr := linodevolumes.VolumeIdAsInt("ControllerPublishVolume", req) + if statusErr != nil { + return 0, 0, statusErr + } + + // retrieve the volume capability from the request + cap := req.GetVolumeCapability() + // return an error if no volume capability is provided + if cap == nil { + return 0, 0, errNoVolumeCapability + } + // return an error if the volume capability is invalid + if !validVolumeCapabilities([]*csi.VolumeCapability{cap}) { + return 0, 0, errInvalidVolumeCapability([]*csi.VolumeCapability{cap}) + } + + log.V(4).Info("Validation passed", "linodeID", linodeID, "volumeID", volumeID) + return linodeID, volumeID, nil +} + +// getAndValidateVolume retrieves the volume by its ID and checks if it is +// attached to the specified Linode instance. If the volume is found and +// already attached to the instance, it returns its device path. +// If the volume is not found or attached to a different instance, it +// returns an appropriate error. +func (cs *ControllerServer) getAndValidateVolume(ctx context.Context, volumeID, linodeID int) (string, error) { + log := logger.GetLogger(ctx) + log.V(4).Info("Entering getAndValidateVolume()", "volumeID", volumeID, "linodeID", linodeID) + defer log.V(4).Info("Exiting getAndValidateVolume()") + + volume, err := cs.client.GetVolume(ctx, volumeID) + if linodego.IsNotFound(err) { + return "", errVolumeNotFound(volumeID) + } else if err != nil { + return "", errInternal("get volume %d: %v", volumeID, err) + } + + if volume.LinodeID != nil { + if *volume.LinodeID == linodeID { + log.V(4).Info("Volume already attached to instance", "volume_id", volume.ID, "node_id", *volume.LinodeID, "device_path", volume.FilesystemPath) + return volume.FilesystemPath, nil + } + return "", errVolumeAttached(volumeID, linodeID) + } + + log.V(4).Info("Volume validated and is not attached to instance", "volume_id", volume.ID, "node_id", linodeID) + return "", nil +} + +// getInstance retrieves the Linode instance by its ID. If the +// instance is not found, it returns an error indicating that the instance +// does not exist. If any other error occurs during retrieval, it returns +// an internal error. +func (cs *ControllerServer) getInstance(ctx context.Context, linodeID int) (*linodego.Instance, error) { + log := logger.GetLogger(ctx) + log.V(4).Info("Entering getInstance()", "linodeID", linodeID) + defer log.V(4).Info("Exiting getInstance()") + + instance, err := cs.client.GetInstance(ctx, linodeID) + if linodego.IsNotFound(err) { + return nil, errInstanceNotFound(linodeID) + } else if err != nil { + // If any other error occurs, return an internal error. + return nil, errInternal("get linode instance %d: %v", linodeID, err) + } + + log.V(4).Info("Instance retrieved", "instance", instance) + return instance, nil +} + +// checkAttachmentCapacity checks if the specified instance can accommodate +// additional volume attachments. It retrieves the maximum number of allowed +// attachments and compares it with the currently attached volumes. If the +// limit is exceeded, it returns an error indicating the maximum volume +// attachments allowed. +func (cs *ControllerServer) checkAttachmentCapacity(ctx context.Context, instance *linodego.Instance) error { + log := logger.GetLogger(ctx) + log.V(4).Info("Entering checkAttachmentCapacity()", "linodeID", instance.ID) + defer log.V(4).Info("Exiting checkAttachmentCapacity()") + + canAttach, err := cs.canAttach(ctx, instance) + if err != nil { + return err + } + if !canAttach { + // If the instance cannot accommodate more attachments, retrieve the maximum allowed attachments. + limit, err := cs.maxAllowedVolumeAttachments(ctx, instance) + if errors.Is(err, errNilInstance) { + return errInternal("cannot calculate max volume attachments for a nil instance") + } else if err != nil { + return errMaxAttachments // Return an error indicating the maximum attachments limit has been reached. + } + return errMaxVolumeAttachments(limit) // Return an error indicating the maximum volume attachments allowed. + } + return nil // Return nil if the instance can accommodate more attachments. +} + +// attachVolume attaches the specified volume to the given Linode instance. +// It logs the action and handles any errors that may occur during the +// attachment process. If the volume is already attached, it allows for a +// retry by returning an Unavailable error. +func (cs *ControllerServer) attachVolume(ctx context.Context, volumeID, linodeID int) error { + log := logger.GetLogger(ctx) + log.V(4).Info("Entering attachVolume()", "volume_id", volumeID, "node_id", linodeID) + defer log.V(4).Info("Exiting attachVolume()") + + persist := false + _, err := cs.client.AttachVolume(ctx, volumeID, &linodego.VolumeAttachOptions{ + LinodeID: linodeID, + PersistAcrossBoots: &persist, + }) + if err != nil { + code := codes.Internal // Default error code is Internal. + // Check if the error indicates that the volume is already attached. + if apiErr, ok := err.(*linodego.Error); ok && strings.Contains(apiErr.Message, "is already attached") { + code = codes.Unavailable // Allow a retry if the volume is already attached: race condition can occur here + } + return status.Errorf(code, "attach volume: %v", err) + } + return nil // Return nil if the volume is successfully attached. +} diff --git a/internal/driver/controllerserver_helper_test.go b/internal/driver/controllerserver_helper_test.go index 7ea037f3..c48b7a15 100644 --- a/internal/driver/controllerserver_helper_test.go +++ b/internal/driver/controllerserver_helper_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/http" "reflect" "testing" @@ -12,6 +13,8 @@ import ( linodevolumes "github.com/linode/linode-blockstorage-csi-driver/pkg/linode-volumes" "github.com/linode/linodego" "go.uber.org/mock/gomock" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func TestPrepareCreateVolumeResponse(t *testing.T) { @@ -300,7 +303,7 @@ func TestCreateAndWaitForVolume(t *testing.T) { volume, err := cs.createAndWaitForVolume(context.Background(), tc.volumeName, tc.sizeGB, tc.tags, tc.sourceInfo) - if !reflect.DeepEqual(tc.expectedError, err) { + if err != nil && !reflect.DeepEqual(tc.expectedError, err) { if tc.expectedError != nil { t.Errorf("expected error %v, got %v", tc.expectedError, err) } else { @@ -399,7 +402,7 @@ func TestPrepareVolumeParams(t *testing.T) { volumeName, sizeGB, size, err := cs.prepareVolumeParams(ctx, tt.req) - if !reflect.DeepEqual(tt.expectedError, err) { + if err != nil && !reflect.DeepEqual(tt.expectedError, err) { if tt.expectedError != nil { t.Errorf("expected error %v, got %v", tt.expectedError, err) } else { @@ -518,9 +521,551 @@ func TestValidateCreateVolumeRequest(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { gotErr := cs.validateCreateVolumeRequest(ctx, tc.req) - if !reflect.DeepEqual(gotErr, tc.wantErr) { + if gotErr != nil && !reflect.DeepEqual(gotErr, tc.wantErr) { t.Errorf("validateCreateVolumeRequest() error = %v, wantErr %v", gotErr, tc.wantErr) } }) } } + +func TestValidateControllerPublishVolumeRequest(t *testing.T) { + cs := &ControllerServer{} + ctx := context.Background() + + testCases := []struct { + name string + req *csi.ControllerPublishVolumeRequest + expectedNodeID int + expectedVolID int + expectedErr error + }{ + { + name: "Valid request", + req: &csi.ControllerPublishVolumeRequest{ + NodeId: "12345", + VolumeId: "67890-test-volume", + VolumeCapability: &csi.VolumeCapability{ + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, + }, + }, + }, + expectedNodeID: 12345, + expectedVolID: 67890, + expectedErr: nil, + }, + { + name: "missing node ID", + req: &csi.ControllerPublishVolumeRequest{ + NodeId: "", + VolumeId: "67890-test-volume", + VolumeCapability: &csi.VolumeCapability{ + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, + }, + }, + }, + expectedNodeID: 0, + expectedVolID: 0, + expectedErr: status.Error(codes.InvalidArgument, "ControllerPublishVolume Node ID must be provided"), + }, + { + name: "missing volume ID", + req: &csi.ControllerPublishVolumeRequest{ + NodeId: "12345", + VolumeId: "", + VolumeCapability: &csi.VolumeCapability{ + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, + }, + }, + }, + expectedNodeID: 0, + expectedVolID: 0, + expectedErr: status.Error(codes.InvalidArgument, "ControllerPublishVolume Volume ID must be provided"), + }, + { + name: "Missing volume capability", + req: &csi.ControllerPublishVolumeRequest{ + NodeId: "12345", + VolumeId: "67890-test-volume", + VolumeCapability: nil, + }, + expectedNodeID: 0, + expectedVolID: 0, + expectedErr: errNoVolumeCapability, + }, + { + name: "Invalid volume capability", + req: &csi.ControllerPublishVolumeRequest{ + NodeId: "12345", + VolumeId: "67890-test-volume", + VolumeCapability: &csi.VolumeCapability{ + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER, + }, + }, + }, + expectedNodeID: 0, + expectedVolID: 0, + expectedErr: errInvalidVolumeCapability([]*csi.VolumeCapability{{AccessMode: &csi.VolumeCapability_AccessMode{Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER}}}), + }, + { + name: "Nil access mode", + req: &csi.ControllerPublishVolumeRequest{ + NodeId: "12345", + VolumeId: "67890-test-volume", + VolumeCapability: &csi.VolumeCapability{ + AccessMode: nil, + }, + }, + expectedNodeID: 0, + expectedVolID: 0, + expectedErr: errInvalidVolumeCapability([]*csi.VolumeCapability{{AccessMode: nil}}), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + nodeID, volID, err := cs.validateControllerPublishVolumeRequest(ctx, tc.req) + + if err != nil && !reflect.DeepEqual(err, tc.expectedErr) { + t.Errorf("Expected error %v, but got %v", tc.expectedErr, err) + } + + if nodeID != tc.expectedNodeID { + t.Errorf("Expected node ID %d, but got %d", tc.expectedNodeID, nodeID) + } + + if volID != tc.expectedVolID { + t.Errorf("Expected volume ID %d, but got %d", tc.expectedVolID, volID) + } + }) + } +} + +func TestGetAndValidateVolume(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockLinodeClient(ctrl) + cs := &ControllerServer{ + client: mockClient, + } + + testCases := []struct { + name string + volumeID int + linodeID int + setupMocks func() + expectedResult string + expectedError error + }{ + { + name: "Volume found and attached to correct instance", + volumeID: 123, + linodeID: 456, + setupMocks: func() { + mockClient.EXPECT().GetVolume(gomock.Any(), 123).Return(&linodego.Volume{ + ID: 123, + LinodeID: &[]int{456}[0], + FilesystemPath: "/dev/disk/by-id/scsi-0Linode_Volume_test-volume", + }, nil) + }, + expectedResult: "/dev/disk/by-id/scsi-0Linode_Volume_test-volume", + expectedError: nil, + }, + { + name: "Volume found but not attached", + volumeID: 123, + linodeID: 456, + setupMocks: func() { + mockClient.EXPECT().GetVolume(gomock.Any(), 123).Return(&linodego.Volume{ + ID: 123, + LinodeID: nil, + }, nil) + }, + expectedResult: "", + expectedError: nil, + }, + { + name: "Volume found but attached to different instance", + volumeID: 123, + linodeID: 456, + setupMocks: func() { + mockClient.EXPECT().GetVolume(gomock.Any(), 123).Return(&linodego.Volume{ + ID: 123, + LinodeID: &[]int{789}[0], + }, nil) + }, + expectedResult: "", + expectedError: errVolumeAttached(123, 456), + }, + { + name: "Volume not found", + volumeID: 123, + linodeID: 456, + setupMocks: func() { + mockClient.EXPECT().GetVolume(gomock.Any(), 123).Return(nil, &linodego.Error{ + Code: http.StatusNotFound, + Message: "Not Found", + }) + }, + expectedResult: "", + expectedError: errVolumeNotFound(123), + }, + { + name: "API error", + volumeID: 123, + linodeID: 456, + setupMocks: func() { + mockClient.EXPECT().GetVolume(gomock.Any(), 123).Return(nil, errors.New("API error")) + }, + expectedResult: "", + expectedError: errInternal("get volume 123: API error"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.setupMocks() + + result, err := cs.getAndValidateVolume(context.Background(), tc.volumeID, tc.linodeID) + + if err != nil && !reflect.DeepEqual(tc.expectedError, err) { + t.Errorf("expected error %v, got %v", tc.expectedError, err) + } + + if tc.expectedResult != result { + t.Errorf("expected result %s, got %s", tc.expectedResult, result) + } + }) + } +} + +func TestCheckAttachmentCapacity(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockLinodeClient(ctrl) + cs := &ControllerServer{ + client: mockClient, + } + + testCases := []struct { + name string + instance *linodego.Instance + setupMocks func() + expectedError error + }{ + { + name: "Can attach volume", + instance: &linodego.Instance{ + ID: 123, + Specs: &linodego.InstanceSpec{ + Memory: 4096, + }, + }, + setupMocks: func() { + mockClient.EXPECT().ListInstanceVolumes(gomock.Any(), 123, gomock.Any()).Return([]linodego.Volume{}, nil) + mockClient.EXPECT().ListInstanceDisks(gomock.Any(), 123, gomock.Any()).Return([]linodego.InstanceDisk{}, nil) + }, + expectedError: nil, + }, + { + name: "Cannot attach volume - max attachments reached", + instance: &linodego.Instance{ + ID: 456, + Specs: &linodego.InstanceSpec{ + Memory: 1024, + }, + }, + setupMocks: func() { + mockClient.EXPECT().ListInstanceDisks(gomock.Any(), 456, gomock.Any()).Return([]linodego.InstanceDisk{linodego.InstanceDisk{ID: 1}, linodego.InstanceDisk{ID: 2}}, nil).AnyTimes() + mockClient.EXPECT().ListInstanceVolumes(gomock.Any(), 456, gomock.Any()).Return([]linodego.Volume{linodego.Volume{ID: 1}, linodego.Volume{ID: 2}, linodego.Volume{ID: 3}, linodego.Volume{ID: 4}, linodego.Volume{ID: 5}, linodego.Volume{ID: 6}}, nil) + }, + expectedError: errMaxVolumeAttachments(6), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.setupMocks() + + err := cs.checkAttachmentCapacity(context.Background(), tc.instance) + + if err != nil && !reflect.DeepEqual(tc.expectedError, err) { + t.Errorf("expected error %v, got %v", tc.expectedError, err) + } + }) + } +} + +func TestAttemptGetContentSourceVolume(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockLinodeClient(ctrl) + cs := &ControllerServer{ + client: mockClient, + metadata: Metadata{ + Region: "us-east", + }, + } + + testCases := []struct { + name string + contentSource *csi.VolumeContentSource + setupMocks func() + expectedResult *linodevolumes.LinodeVolumeKey + expectedError error + }{ + { + name: "Nil content source", + contentSource: nil, + setupMocks: func() {}, + expectedResult: nil, + expectedError: nil, + }, + { + name: "Invalid content source type", + contentSource: &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Snapshot{}, + }, + setupMocks: func() {}, + expectedResult: nil, + expectedError: errUnsupportedVolumeContentSource, + }, + { + name: "Nil volume", + contentSource: &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Volume{ + Volume: nil, + }, + }, + setupMocks: func() {}, + expectedResult: nil, + expectedError: errNoSourceVolume, + }, + { + name: "Invalid volume ID", + contentSource: &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Volume{ + Volume: &csi.VolumeContentSource_VolumeSource{ + VolumeId: "test-volume", + }, + }, + }, + setupMocks: func() {}, + expectedResult: nil, + expectedError: errInternal("parse volume info from content source: invalid linode volume id: \"test\""), + }, + { + name: "Valid content source, matching region", + contentSource: &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Volume{ + Volume: &csi.VolumeContentSource_VolumeSource{ + VolumeId: "123-testvolume", + }, + }, + }, + setupMocks: func() { + mockClient.EXPECT().GetVolume(gomock.Any(), 123).Return(&linodego.Volume{ + ID: 123, + Region: "us-east", + }, nil) + }, + expectedResult: &linodevolumes.LinodeVolumeKey{ + VolumeID: 123, + Label: "testvolume", + }, + expectedError: nil, + }, + { + name: "Valid content source, mismatched region", + contentSource: &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Volume{ + Volume: &csi.VolumeContentSource_VolumeSource{ + VolumeId: "456-othervolume", + }, + }, + }, + setupMocks: func() { + mockClient.EXPECT().GetVolume(gomock.Any(), 456).Return(&linodego.Volume{ + ID: 456, + Region: "us-west", + }, nil) + }, + expectedResult: nil, + expectedError: errRegionMismatch("us-west", "us-east"), + }, + { + name: "API error", + contentSource: &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Volume{ + Volume: &csi.VolumeContentSource_VolumeSource{ + VolumeId: "789-errorvolume", + }, + }, + }, + setupMocks: func() { + mockClient.EXPECT().GetVolume(gomock.Any(), 789).Return(nil, errors.New("API error")) + }, + expectedResult: nil, + expectedError: errInternal("get volume 789: API error"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.setupMocks() + + result, err := cs.getContentSourceVolume(context.Background(), tc.contentSource) + + if err != nil && !reflect.DeepEqual(tc.expectedError, err) { + t.Errorf("expected error %v, got %v", tc.expectedError, err) + } + + if !reflect.DeepEqual(tc.expectedResult, result) { + t.Errorf("expected result %+v, got %+v", tc.expectedResult, result) + } + }) + } +} + +func TestAttachVolume(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockLinodeClient(ctrl) + cs := &ControllerServer{ + client: mockClient, + } + + testCases := []struct { + name string + volumeID int + linodeID int + setupMocks func() + expectedError error + }{ + { + name: "Successful attachment", + volumeID: 123, + linodeID: 456, + setupMocks: func() { + mockClient.EXPECT().AttachVolume(gomock.Any(), 123, gomock.Any()).Return(&linodego.Volume{}, nil) + }, + expectedError: nil, + }, + { + name: "Volume already attached", + volumeID: 789, + linodeID: 101, + setupMocks: func() { + mockClient.EXPECT().AttachVolume(gomock.Any(), 789, gomock.Any()).Return(nil, &linodego.Error{Message: "Volume 789 is already attached"}) + }, + expectedError: status.Error(codes.Unavailable, "attach volume: [000] Volume 789 is already attached"), + }, + { + name: "API error", + volumeID: 202, + linodeID: 303, + setupMocks: func() { + mockClient.EXPECT().AttachVolume(gomock.Any(), 202, gomock.Any()).Return(nil, errors.New("API error")) + }, + expectedError: status.Error(codes.Internal, "attach volume: API error"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.setupMocks() + + err := cs.attachVolume(context.Background(), tc.volumeID, tc.linodeID) + + if tc.expectedError == nil && err != nil { + t.Errorf("expected no error, got %v", err) + } else if tc.expectedError != nil && err == nil { + t.Errorf("expected error %v, got nil", tc.expectedError) + } else if tc.expectedError != nil && err != nil { + if tc.expectedError.Error() != err.Error() { + t.Errorf("expected error %v, got %v", tc.expectedError, err) + } + } + }) + } +} + +func TestGetInstance(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockLinodeClient(ctrl) + cs := &ControllerServer{ + client: mockClient, + } + + testCases := []struct { + name string + linodeID int + setupMocks func() + expectedInstance *linodego.Instance + expectedError error + }{ + { + name: "Instance found", + linodeID: 123, + setupMocks: func() { + mockClient.EXPECT().GetInstance(gomock.Any(), 123).Return(&linodego.Instance{ + ID: 123, + Label: "test-instance", + Status: linodego.InstanceRunning, + }, nil) + }, + expectedInstance: &linodego.Instance{ + ID: 123, + Label: "test-instance", + Status: linodego.InstanceRunning, + }, + expectedError: nil, + }, + { + name: "Instance not found", + linodeID: 456, + setupMocks: func() { + mockClient.EXPECT().GetInstance(gomock.Any(), 456).Return(nil, &linodego.Error{ + Code: http.StatusNotFound, + Message: "Not Found", + }) + }, + expectedInstance: nil, + expectedError: errInstanceNotFound(456), + }, + { + name: "API error", + linodeID: 789, + setupMocks: func() { + mockClient.EXPECT().GetInstance(gomock.Any(), 789).Return(nil, errors.New("API error")) + }, + expectedInstance: nil, + expectedError: errInternal("get linode instance 789: API error"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.setupMocks() + + instance, err := cs.getInstance(context.Background(), tc.linodeID) + + if err != nil && !reflect.DeepEqual(tc.expectedError, err) { + t.Errorf("expected error %v, got %v", tc.expectedError, err) + } + + if !reflect.DeepEqual(tc.expectedInstance, instance) { + t.Errorf("expected instance %+v, got %+v", tc.expectedInstance, instance) + } + }) + } +}