From 762fb13c8756a7af995918ddcf79d7c14323c62c Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Tue, 7 Jan 2025 17:44:16 -0800 Subject: [PATCH] migrate AWS Redshift Serverless to AWS SDK v2 --- go.mod | 1 + go.sum | 2 + integrations/event-handler/go.mod | 1 + integrations/event-handler/go.sum | 2 + integrations/terraform/go.mod | 1 + integrations/terraform/go.sum | 2 + lib/cloud/aws/errors.go | 3 +- lib/cloud/aws/tags_helpers.go | 6 +- lib/cloud/awstesthelpers/tags.go | 20 +++ lib/cloud/clients.go | 23 --- lib/cloud/mocks/aws_redshift_serverless.go | 94 ++++++------ lib/srv/db/access_test.go | 37 ++--- lib/srv/db/cloud/meta.go | 45 ++++-- lib/srv/db/cloud/meta_test.go | 22 ++- lib/srv/db/cloud/resource_checker_url_aws.go | 16 +- .../db/cloud/resource_checker_url_aws_test.go | 20 +-- lib/srv/db/common/auth.go | 29 ++-- lib/srv/db/common/auth_test.go | 28 ++-- lib/srv/db/watcher_test.go | 67 +++++++-- lib/srv/discovery/common/database.go | 18 +-- lib/srv/discovery/common/database_test.go | 6 +- lib/srv/discovery/discovery_test.go | 6 + lib/srv/discovery/fetchers/db/aws_rds_test.go | 6 + lib/srv/discovery/fetchers/db/aws_redshift.go | 2 +- .../fetchers/db/aws_redshift_serverless.go | 139 ++++++++++++------ .../db/aws_redshift_serverless_test.go | 59 ++++---- lib/srv/discovery/fetchers/db/db.go | 15 +- 27 files changed, 405 insertions(+), 265 deletions(-) diff --git a/go.mod b/go.mod index 625a780eb3ff6..4958f262fd6ac 100644 --- a/go.mod +++ b/go.mod @@ -66,6 +66,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/organizations v1.37.0 github.com/aws/aws-sdk-go-v2/service/rds v1.93.2 github.com/aws/aws-sdk-go-v2/service/redshift v1.53.1 + github.com/aws/aws-sdk-go-v2/service/redshiftserverless v1.25.1 github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.8 github.com/aws/aws-sdk-go-v2/service/sns v1.33.8 diff --git a/go.sum b/go.sum index 5bf38ba7fc0c4..81924d7ad6fba 100644 --- a/go.sum +++ b/go.sum @@ -928,6 +928,8 @@ github.com/aws/aws-sdk-go-v2/service/rds v1.93.2 h1:Fv2//DyCH9n6LqEOvpeIFYYRfIhv github.com/aws/aws-sdk-go-v2/service/rds v1.93.2/go.mod h1:QEpwiX4BS6nos2d/ele6gRGalNW0Hzc1TZMmhkywQb0= github.com/aws/aws-sdk-go-v2/service/redshift v1.53.1 h1:fpuhuF5DuY26w61bBq8YrMYecLVs6eiQK7JbD9womPI= github.com/aws/aws-sdk-go-v2/service/redshift v1.53.1/go.mod h1:Uz+PdLUo8+x/iXFrZGc+j+w/AVAfc7Qmju9XjCiQGHE= +github.com/aws/aws-sdk-go-v2/service/redshiftserverless v1.25.1 h1:anb79RuKbIO8z+SgNiDGCQln5CBI3Edzp9mXTAcZuNg= +github.com/aws/aws-sdk-go-v2/service/redshiftserverless v1.25.1/go.mod h1:u4NPdVb3te3+QB4rdjFGE9Of4V3vPqrPbTk6fAR6qf8= github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0 h1:SAfh4pNx5LuTafKKWR02Y+hL3A+3TX8cTKG1OIAJaBk= github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0/go.mod h1:r+xl5yzMk9083rMR+sJ5TYj9Tihvf/l1oxzZXDgGj2Q= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.8 h1:WT3EPriVEpHE2jeNqHqj7l43JCIWPoZjNNRluZ7agII= diff --git a/integrations/event-handler/go.mod b/integrations/event-handler/go.mod index 19d919b359e39..9a316bf1670b9 100644 --- a/integrations/event-handler/go.mod +++ b/integrations/event-handler/go.mod @@ -88,6 +88,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/organizations v1.37.0 // indirect github.com/aws/aws-sdk-go-v2/service/rds v1.93.2 // indirect github.com/aws/aws-sdk-go-v2/service/redshift v1.53.1 // indirect + github.com/aws/aws-sdk-go-v2/service/redshiftserverless v1.25.1 // indirect github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0 // indirect github.com/aws/aws-sdk-go-v2/service/ssm v1.56.2 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.8 // indirect diff --git a/integrations/event-handler/go.sum b/integrations/event-handler/go.sum index 1f0435df0d184..29df1db6e8112 100644 --- a/integrations/event-handler/go.sum +++ b/integrations/event-handler/go.sum @@ -779,6 +779,8 @@ github.com/aws/aws-sdk-go-v2/service/rds v1.93.2 h1:Fv2//DyCH9n6LqEOvpeIFYYRfIhv github.com/aws/aws-sdk-go-v2/service/rds v1.93.2/go.mod h1:QEpwiX4BS6nos2d/ele6gRGalNW0Hzc1TZMmhkywQb0= github.com/aws/aws-sdk-go-v2/service/redshift v1.53.1 h1:fpuhuF5DuY26w61bBq8YrMYecLVs6eiQK7JbD9womPI= github.com/aws/aws-sdk-go-v2/service/redshift v1.53.1/go.mod h1:Uz+PdLUo8+x/iXFrZGc+j+w/AVAfc7Qmju9XjCiQGHE= +github.com/aws/aws-sdk-go-v2/service/redshiftserverless v1.25.1 h1:anb79RuKbIO8z+SgNiDGCQln5CBI3Edzp9mXTAcZuNg= +github.com/aws/aws-sdk-go-v2/service/redshiftserverless v1.25.1/go.mod h1:u4NPdVb3te3+QB4rdjFGE9Of4V3vPqrPbTk6fAR6qf8= github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0 h1:SAfh4pNx5LuTafKKWR02Y+hL3A+3TX8cTKG1OIAJaBk= github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0/go.mod h1:r+xl5yzMk9083rMR+sJ5TYj9Tihvf/l1oxzZXDgGj2Q= github.com/aws/aws-sdk-go-v2/service/ssm v1.56.2 h1:MOxvXH2kRP5exvqJxAZ0/H9Ar51VmADJh95SgZE8u60= diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod index 5222dc914a105..3c55e81bf8a73 100644 --- a/integrations/terraform/go.mod +++ b/integrations/terraform/go.mod @@ -100,6 +100,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/organizations v1.37.0 // indirect github.com/aws/aws-sdk-go-v2/service/rds v1.93.2 // indirect github.com/aws/aws-sdk-go-v2/service/redshift v1.53.1 // indirect + github.com/aws/aws-sdk-go-v2/service/redshiftserverless v1.25.1 // indirect github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0 // indirect github.com/aws/aws-sdk-go-v2/service/ssm v1.56.2 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.8 // indirect diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index da4bca430e263..a2e2ba7a07617 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -852,6 +852,8 @@ github.com/aws/aws-sdk-go-v2/service/rds v1.93.2 h1:Fv2//DyCH9n6LqEOvpeIFYYRfIhv github.com/aws/aws-sdk-go-v2/service/rds v1.93.2/go.mod h1:QEpwiX4BS6nos2d/ele6gRGalNW0Hzc1TZMmhkywQb0= github.com/aws/aws-sdk-go-v2/service/redshift v1.53.1 h1:fpuhuF5DuY26w61bBq8YrMYecLVs6eiQK7JbD9womPI= github.com/aws/aws-sdk-go-v2/service/redshift v1.53.1/go.mod h1:Uz+PdLUo8+x/iXFrZGc+j+w/AVAfc7Qmju9XjCiQGHE= +github.com/aws/aws-sdk-go-v2/service/redshiftserverless v1.25.1 h1:anb79RuKbIO8z+SgNiDGCQln5CBI3Edzp9mXTAcZuNg= +github.com/aws/aws-sdk-go-v2/service/redshiftserverless v1.25.1/go.mod h1:u4NPdVb3te3+QB4rdjFGE9Of4V3vPqrPbTk6fAR6qf8= github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0 h1:SAfh4pNx5LuTafKKWR02Y+hL3A+3TX8cTKG1OIAJaBk= github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0/go.mod h1:r+xl5yzMk9083rMR+sJ5TYj9Tihvf/l1oxzZXDgGj2Q= github.com/aws/aws-sdk-go-v2/service/sns v1.33.8 h1:zKokiUMOfbZSrAUVqw+bSjr6gl9u/JcvPzHTmL+tmdQ= diff --git a/lib/cloud/aws/errors.go b/lib/cloud/aws/errors.go index 63a9ffa75ca95..946369bbe0ae6 100644 --- a/lib/cloud/aws/errors.go +++ b/lib/cloud/aws/errors.go @@ -28,7 +28,6 @@ import ( iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/gravitational/trace" ) @@ -71,7 +70,7 @@ func convertRequestFailureErrorFromStatusCode(statusCode int, requestErr error) case http.StatusBadRequest: // Some services like memorydb, redshiftserverless may return 400 with // "AccessDeniedException" instead of 403. - if strings.Contains(requestErr.Error(), redshiftserverless.ErrCodeAccessDeniedException) { + if strings.Contains(requestErr.Error(), "AccessDeniedException") { return trace.AccessDenied(requestErr.Error()) } diff --git a/lib/cloud/aws/tags_helpers.go b/lib/cloud/aws/tags_helpers.go index 43f6ba48f61ca..a3e1b87ebbf4e 100644 --- a/lib/cloud/aws/tags_helpers.go +++ b/lib/cloud/aws/tags_helpers.go @@ -26,12 +26,12 @@ import ( ec2TypesV2 "github.com/aws/aws-sdk-go-v2/service/ec2/types" rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" + rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/aws/aws-sdk-go/service/secretsmanager" "golang.org/x/exp/maps" @@ -48,7 +48,7 @@ type ResourceTag interface { *ec2.Tag | *elasticache.Tag | *memorydb.Tag | - *redshiftserverless.Tag | + rsstypes.Tag | *opensearchservice.Tag | *secretsmanager.Tag } @@ -80,7 +80,7 @@ func resourceTagToKeyValue[Tag ResourceTag](tag Tag) (string, string) { return aws.StringValue(v.Key), aws.StringValue(v.Value) case *memorydb.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) - case *redshiftserverless.Tag: + case rsstypes.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) case rdstypes.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) diff --git a/lib/cloud/awstesthelpers/tags.go b/lib/cloud/awstesthelpers/tags.go index 28bed6b973f0b..e17ebfb9c88a7 100644 --- a/lib/cloud/awstesthelpers/tags.go +++ b/lib/cloud/awstesthelpers/tags.go @@ -24,6 +24,7 @@ import ( rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" + rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types" ) // LabelsToRedshiftTags converts labels into [redshifttypes.Tag] list. @@ -63,3 +64,22 @@ func LabelsToRDSTags(labels map[string]string) []rdstypes.Tag { return ret } + +// LabelsToRedshiftServerlessTags converts labels into a [rsstypes.Tag] list. +func LabelsToRedshiftServerlessTags(labels map[string]string) []rsstypes.Tag { + keys := slices.Collect(maps.Keys(labels)) + slices.Sort(keys) + + ret := make([]rsstypes.Tag, 0, len(keys)) + for _, key := range keys { + key := key + value := labels[key] + + ret = append(ret, rsstypes.Tag{ + Key: &key, + Value: &value, + }) + } + + return ret +} diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index cc50c98c1ba4f..a39b367f92bd1 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -49,8 +49,6 @@ import ( "github.com/aws/aws-sdk-go/service/memorydb/memorydbiface" "github.com/aws/aws-sdk-go/service/opensearchservice" "github.com/aws/aws-sdk-go/service/opensearchservice/opensearchserviceiface" - "github.com/aws/aws-sdk-go/service/redshiftserverless" - "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3iface" "github.com/aws/aws-sdk-go/service/secretsmanager" @@ -107,8 +105,6 @@ type GCPClients interface { type AWSClients interface { // GetAWSSession returns AWS session for the specified region and any role(s). GetAWSSession(ctx context.Context, region string, opts ...AWSOptionsFn) (*awssession.Session, error) - // GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. - GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) // GetAWSElastiCacheClient returns AWS ElastiCache client for the specified region. GetAWSElastiCacheClient(ctx context.Context, region string, opts ...AWSOptionsFn) (elasticacheiface.ElastiCacheAPI, error) // GetAWSMemoryDBClient returns AWS MemoryDB client for the specified region. @@ -496,15 +492,6 @@ func (c *cloudClients) GetAWSSession(ctx context.Context, region string, opts .. return c.getAWSSessionForRole(ctx, region, options) } -// GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. -func (c *cloudClients) GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) { - session, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return redshiftserverless.New(session), nil -} - // GetAWSElastiCacheClient returns AWS ElastiCache client for the specified region. func (c *cloudClients) GetAWSElastiCacheClient(ctx context.Context, region string, opts ...AWSOptionsFn) (elasticacheiface.ElastiCacheAPI, error) { session, err := c.GetAWSSession(ctx, region, opts...) @@ -992,7 +979,6 @@ var _ Clients = (*TestCloudClients)(nil) // TestCloudClients are used in tests. type TestCloudClients struct { - RedshiftServerless redshiftserverlessiface.RedshiftServerlessAPI ElastiCache elasticacheiface.ElastiCacheAPI OpenSearch opensearchserviceiface.OpenSearchServiceAPI MemoryDB memorydbiface.MemoryDBAPI @@ -1060,15 +1046,6 @@ func (c *TestCloudClients) getAWSSessionForRegion(region string) (*awssession.Se }) } -// GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. -func (c *TestCloudClients) GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) { - _, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return c.RedshiftServerless, nil -} - // GetAWSElastiCacheClient returns AWS ElastiCache client for the specified region. func (c *TestCloudClients) GetAWSElastiCacheClient(ctx context.Context, region string, opts ...AWSOptionsFn) (elasticacheiface.ElastiCacheAPI, error) { _, err := c.GetAWSSession(ctx, region, opts...) diff --git a/lib/cloud/mocks/aws_redshift_serverless.go b/lib/cloud/mocks/aws_redshift_serverless.go index 5518352a04b34..c22afed51af4c 100644 --- a/lib/cloud/mocks/aws_redshift_serverless.go +++ b/lib/cloud/mocks/aws_redshift_serverless.go @@ -19,123 +19,127 @@ package mocks import ( + "context" "fmt" "time" + rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless" + rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/redshiftserverless" - "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" ) -// RedshiftServerlessMock mocks RedshiftServerless API. -type RedshiftServerlessMock struct { - redshiftserverlessiface.RedshiftServerlessAPI - +type RedshiftServerlessClient struct { Unauth bool - Workgroups []*redshiftserverless.Workgroup - Endpoints []*redshiftserverless.EndpointAccess - TagsByARN map[string][]*redshiftserverless.Tag - GetCredentialsOutput *redshiftserverless.GetCredentialsOutput + Workgroups []rsstypes.Workgroup + Endpoints []rsstypes.EndpointAccess + TagsByARN map[string][]rsstypes.Tag + GetCredentialsOutput *rss.GetCredentialsOutput } -func (m RedshiftServerlessMock) GetWorkgroupWithContext(_ aws.Context, input *redshiftserverless.GetWorkgroupInput, _ ...request.Option) (*redshiftserverless.GetWorkgroupOutput, error) { +func (m RedshiftServerlessClient) GetWorkgroup(_ context.Context, input *rss.GetWorkgroupInput, _ ...func(*rss.Options)) (*rss.GetWorkgroupOutput, error) { if m.Unauth { return nil, trace.AccessDenied("unauthorized") } for _, workgroup := range m.Workgroups { if aws.StringValue(workgroup.WorkgroupName) == aws.StringValue(input.WorkgroupName) { - return new(redshiftserverless.GetWorkgroupOutput).SetWorkgroup(workgroup), nil + return &rss.GetWorkgroupOutput{ + Workgroup: &workgroup, + }, nil } } return nil, trace.NotFound("workgroup %q not found", aws.StringValue(input.WorkgroupName)) } -func (m RedshiftServerlessMock) GetEndpointAccessWithContext(_ aws.Context, input *redshiftserverless.GetEndpointAccessInput, _ ...request.Option) (*redshiftserverless.GetEndpointAccessOutput, error) { + +func (m RedshiftServerlessClient) GetEndpointAccess(_ context.Context, input *rss.GetEndpointAccessInput, _ ...func(*rss.Options)) (*rss.GetEndpointAccessOutput, error) { if m.Unauth { return nil, trace.AccessDenied("unauthorized") } for _, endpoint := range m.Endpoints { if aws.StringValue(endpoint.EndpointName) == aws.StringValue(input.EndpointName) { - return new(redshiftserverless.GetEndpointAccessOutput).SetEndpoint(endpoint), nil + return &rss.GetEndpointAccessOutput{ + Endpoint: &endpoint, + }, nil } } return nil, trace.NotFound("endpoint %q not found", aws.StringValue(input.EndpointName)) } -func (m RedshiftServerlessMock) ListWorkgroupsPagesWithContext(_ aws.Context, input *redshiftserverless.ListWorkgroupsInput, fn func(*redshiftserverless.ListWorkgroupsOutput, bool) bool, _ ...request.Option) error { + +func (m RedshiftServerlessClient) ListWorkgroups(_ context.Context, input *rss.ListWorkgroupsInput, _ ...func(*rss.Options)) (*rss.ListWorkgroupsOutput, error) { if m.Unauth { - return trace.AccessDenied("unauthorized") + return nil, trace.AccessDenied("unauthorized") } - fn(&redshiftserverless.ListWorkgroupsOutput{ + return &rss.ListWorkgroupsOutput{ Workgroups: m.Workgroups, - }, true) - return nil + }, nil } -func (m RedshiftServerlessMock) ListEndpointAccessPagesWithContext(_ aws.Context, input *redshiftserverless.ListEndpointAccessInput, fn func(*redshiftserverless.ListEndpointAccessOutput, bool) bool, _ ...request.Option) error { + +func (m RedshiftServerlessClient) ListEndpointAccess(_ context.Context, input *rss.ListEndpointAccessInput, _ ...func(*rss.Options)) (*rss.ListEndpointAccessOutput, error) { if m.Unauth { - return trace.AccessDenied("unauthorized") + return nil, trace.AccessDenied("unauthorized") } - fn(&redshiftserverless.ListEndpointAccessOutput{ + return &rss.ListEndpointAccessOutput{ Endpoints: m.Endpoints, - }, true) - return nil + }, nil } -func (m RedshiftServerlessMock) ListTagsForResourceWithContext(_ aws.Context, input *redshiftserverless.ListTagsForResourceInput, _ ...request.Option) (*redshiftserverless.ListTagsForResourceOutput, error) { + +func (m RedshiftServerlessClient) ListTagsForResource(_ context.Context, input *rss.ListTagsForResourceInput, _ ...func(*rss.Options)) (*rss.ListTagsForResourceOutput, error) { if m.Unauth { return nil, trace.AccessDenied("unauthorized") } if m.TagsByARN == nil { - return &redshiftserverless.ListTagsForResourceOutput{}, nil + return &rss.ListTagsForResourceOutput{}, nil } - return &redshiftserverless.ListTagsForResourceOutput{ + return &rss.ListTagsForResourceOutput{ Tags: m.TagsByARN[aws.StringValue(input.ResourceArn)], }, nil } -func (m RedshiftServerlessMock) GetCredentialsWithContext(aws.Context, *redshiftserverless.GetCredentialsInput, ...request.Option) (*redshiftserverless.GetCredentialsOutput, error) { + +func (m RedshiftServerlessClient) GetCredentials(context.Context, *rss.GetCredentialsInput, ...func(*rss.Options)) (*rss.GetCredentialsOutput, error) { if m.Unauth || m.GetCredentialsOutput == nil { return nil, trace.AccessDenied("access denied") } return m.GetCredentialsOutput, nil } -// RedshiftServerlessWorkgroup returns a sample redshiftserverless.Workgroup. -func RedshiftServerlessWorkgroup(name, region string) *redshiftserverless.Workgroup { - return &redshiftserverless.Workgroup{ - BaseCapacity: aws.Int64(32), - ConfigParameters: []*redshiftserverless.ConfigParameter{{ +// RedshiftServerlessWorkgroup returns a sample rsstypes.Workgroup. +func RedshiftServerlessWorkgroup(name, region string) *rsstypes.Workgroup { + return &rsstypes.Workgroup{ + BaseCapacity: aws.Int32(32), + ConfigParameters: []rsstypes.ConfigParameter{{ ParameterKey: aws.String("max_query_execution_time"), ParameterValue: aws.String("14400"), }}, CreationDate: aws.Time(sampleTime), - Endpoint: &redshiftserverless.Endpoint{ + Endpoint: &rsstypes.Endpoint{ Address: aws.String(fmt.Sprintf("%v.123456789012.%v.redshift-serverless.amazonaws.com", name, region)), - Port: aws.Int64(5439), - VpcEndpoints: []*redshiftserverless.VpcEndpoint{{ + Port: aws.Int32(5439), + VpcEndpoints: []rsstypes.VpcEndpoint{{ VpcEndpointId: aws.String("vpc-endpoint-id"), VpcId: aws.String("vpc-id"), }}, }, NamespaceName: aws.String("my-namespace"), PubliclyAccessible: aws.Bool(true), - Status: aws.String("AVAILABLE"), + Status: rsstypes.WorkgroupStatusAvailable, WorkgroupArn: aws.String(fmt.Sprintf("arn:aws:redshift-serverless:%v:123456789012:workgroup/some-uuid-for-%v", region, name)), WorkgroupId: aws.String(fmt.Sprintf("some-uuid-for-%v", name)), WorkgroupName: aws.String(name), } } -// RedshiftServerlessEndpointAccess returns a sample redshiftserverless.EndpointAccess. -func RedshiftServerlessEndpointAccess(workgroup *redshiftserverless.Workgroup, name, region string) *redshiftserverless.EndpointAccess { - return &redshiftserverless.EndpointAccess{ +// RedshiftServerlessEndpointAccess returns a sample rsstypes.EndpointAccess. +func RedshiftServerlessEndpointAccess(workgroup *rsstypes.Workgroup, name, region string) *rsstypes.EndpointAccess { + return &rsstypes.EndpointAccess{ Address: aws.String(fmt.Sprintf("%s-endpoint-xxxyyyzzz.123456789012.%s.redshift-serverless.amazonaws.com", name, region)), EndpointArn: aws.String(fmt.Sprintf("arn:aws:redshift-serverless:%s:123456789012:managedvpcendpoint/some-uuid-for-%v", region, name)), EndpointCreateTime: aws.Time(sampleTime), EndpointName: aws.String(name), EndpointStatus: aws.String("AVAILABLE"), - Port: aws.Int64(5439), - VpcEndpoint: &redshiftserverless.VpcEndpoint{ + Port: aws.Int32(5439), + VpcEndpoint: &rsstypes.VpcEndpoint{ VpcEndpointId: aws.String("vpce-id"), VpcId: aws.String("vpc-id"), }, @@ -144,11 +148,11 @@ func RedshiftServerlessEndpointAccess(workgroup *redshiftserverless.Workgroup, n } // RedshiftServerlessGetCredentialsOutput return a sample redshiftserverless.GetCredentialsOutput. -func RedshiftServerlessGetCredentialsOutput(user, password string, clock clockwork.Clock) *redshiftserverless.GetCredentialsOutput { +func RedshiftServerlessGetCredentialsOutput(user, password string, clock clockwork.Clock) *rss.GetCredentialsOutput { if clock == nil { clock = clockwork.NewRealClock() } - return &redshiftserverless.GetCredentialsOutput{ + return &rss.GetCredentialsOutput{ DbUser: aws.String(user), DbPassword: aws.String(password), Expiration: aws.Time(clock.Now().Add(15 * time.Minute)), diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 906acfd06c7cd..ee02ee02d903b 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -95,6 +95,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/snowflake" "github.com/gravitational/teleport/lib/srv/db/spanner" "github.com/gravitational/teleport/lib/srv/db/sqlserver" + "github.com/gravitational/teleport/lib/srv/discovery/fetchers/db" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/cert" @@ -2450,6 +2451,8 @@ type agentParams struct { CloudClients clients.Clients // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. AWSConfigProvider awsconfig.Provider + // AWSDatabaseFetcherFactory provides AWS database fetchers + AWSDatabaseFetcherFactory *db.AWSFetcherFactory // AWSMatchers is a list of AWS databases matchers. AWSMatchers []types.AWSMatcher // AzureMatchers is a list of Azure databases matchers. @@ -2490,13 +2493,12 @@ func (p *agentParams) setDefaults(c *testContext) { if p.CloudClients == nil { p.CloudClients = &clients.TestCloudClients{ - STS: &mocks.STSClientV1{}, - RedshiftServerless: &mocks.RedshiftServerlessMock{}, - ElastiCache: p.ElastiCache, - MemoryDB: p.MemoryDB, - SecretsManager: secrets.NewMockSecretsManagerClient(secrets.MockSecretsManagerClientConfig{}), - IAM: &mocks.IAMMock{}, - GCPSQL: p.GCPSQL, + STS: &mocks.STSClientV1{}, + ElastiCache: p.ElastiCache, + MemoryDB: p.MemoryDB, + SecretsManager: secrets.NewMockSecretsManagerClient(secrets.MockSecretsManagerClientConfig{}), + IAM: &mocks.IAMMock{}, + GCPSQL: p.GCPSQL, } } if p.AWSConfigProvider == nil { @@ -2603,16 +2605,17 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t testing.TB, p a Clock: c.clock, }) }, - CADownloader: p.CADownloader, - OnReconcile: p.OnReconcile, - ConnectionMonitor: connMonitor, - CloudClients: p.CloudClients, - AWSConfigProvider: p.AWSConfigProvider, - AWSMatchers: p.AWSMatchers, - AzureMatchers: p.AzureMatchers, - ShutdownPollPeriod: 100 * time.Millisecond, - InventoryHandle: inventoryHandle, - discoveryResourceChecker: p.DiscoveryResourceChecker, + CADownloader: p.CADownloader, + OnReconcile: p.OnReconcile, + ConnectionMonitor: connMonitor, + CloudClients: p.CloudClients, + AWSConfigProvider: p.AWSConfigProvider, + AWSDatabaseFetcherFactory: p.AWSDatabaseFetcherFactory, + AWSMatchers: p.AWSMatchers, + AzureMatchers: p.AzureMatchers, + ShutdownPollPeriod: 100 * time.Millisecond, + InventoryHandle: inventoryHandle, + discoveryResourceChecker: p.DiscoveryResourceChecker, }) require.NoError(t, err) diff --git a/lib/srv/db/cloud/meta.go b/lib/srv/db/cloud/meta.go index 0956759422b07..b57059da18ab7 100644 --- a/lib/srv/db/cloud/meta.go +++ b/lib/srv/db/cloud/meta.go @@ -28,12 +28,12 @@ import ( rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go-v2/service/redshift" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" + rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless" + rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/memorydb/memorydbiface" - "github.com/aws/aws-sdk-go/service/redshiftserverless" - "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/gravitational/trace" "github.com/gravitational/teleport" @@ -60,10 +60,17 @@ type redshiftClient interface { redshift.DescribeClustersAPIClient } +// rssClient defines a subset of the AWS Redshift Serverless client API. +type rssClient interface { + GetEndpointAccess(ctx context.Context, params *rss.GetEndpointAccessInput, optFns ...func(*rss.Options)) (*rss.GetEndpointAccessOutput, error) + GetWorkgroup(ctx context.Context, params *rss.GetWorkgroupInput, optFns ...func(*rss.Options)) (*rss.GetWorkgroupOutput, error) +} + // awsClientProvider is an AWS SDK client provider. type awsClientProvider interface { getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient getRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient + getRedshiftServerlessClient(cfg aws.Config, optFns ...func(*rss.Options)) rssClient } type defaultAWSClients struct{} @@ -76,6 +83,10 @@ func (defaultAWSClients) getRedshiftClient(cfg aws.Config, optFns ...func(*redsh return redshift.NewFromConfig(cfg, optFns...) } +func (defaultAWSClients) getRedshiftServerlessClient(cfg aws.Config, optFns ...func(*rss.Options)) rssClient { + return rss.NewFromConfig(cfg, optFns...) +} + // MetadataConfig is the cloud metadata service config. type MetadataConfig struct { // Clients is an interface for retrieving cloud clients. @@ -242,18 +253,19 @@ func (m *Metadata) fetchRedshiftMetadata(ctx context.Context, database types.Dat // Serverless database. func (m *Metadata) fetchRedshiftServerlessMetadata(ctx context.Context, database types.Database) (*types.AWS, error) { meta := database.GetAWS() - client, err := m.cfg.Clients.GetAWSRedshiftServerlessClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := m.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return nil, trace.Wrap(err) } + clt := m.cfg.awsClients.getRedshiftServerlessClient(awsCfg) if meta.RedshiftServerless.EndpointName != "" { - return fetchRedshiftServerlessVPCEndpointMetadata(ctx, client, meta.RedshiftServerless.EndpointName) + return fetchRedshiftServerlessVPCEndpointMetadata(ctx, clt, meta.RedshiftServerless.EndpointName) } - return fetchRedshiftServerlessWorkgroupMetadata(ctx, client, meta.RedshiftServerless.WorkgroupName) + return fetchRedshiftServerlessWorkgroupMetadata(ctx, clt, meta.RedshiftServerless.WorkgroupName) } // fetchElastiCacheMetadata fetches metadata for the provided ElastiCache database. @@ -449,14 +461,14 @@ func describeRDSProxyCustomEndpointAndFindURI(ctx context.Context, clt rdsClient return nil, trace.BadParameter("could not find RDS Proxy custom endpoint %v with URI %v, got %s", proxyEndpointName, uri, endpoints) } -func fetchRedshiftServerlessWorkgroupMetadata(ctx context.Context, client redshiftserverlessiface.RedshiftServerlessAPI, workgroupName string) (*types.AWS, error) { +func fetchRedshiftServerlessWorkgroupMetadata(ctx context.Context, client rssClient, workgroupName string) (*types.AWS, error) { workgroup, err := describeRedshiftServerlessWorkgroup(ctx, client, workgroupName) if err != nil { return nil, trace.Wrap(err) } return discoverycommon.MetadataFromRedshiftServerlessWorkgroup(workgroup) } -func fetchRedshiftServerlessVPCEndpointMetadata(ctx context.Context, client redshiftserverlessiface.RedshiftServerlessAPI, endpointName string) (*types.AWS, error) { +func fetchRedshiftServerlessVPCEndpointMetadata(ctx context.Context, client rssClient, endpointName string) (*types.AWS, error) { endpoint, err := describeRedshiftServerlessVCPEndpoint(ctx, client, endpointName) if err != nil { return nil, trace.Wrap(err) @@ -467,17 +479,20 @@ func fetchRedshiftServerlessVPCEndpointMetadata(ctx context.Context, client reds } return discoverycommon.MetadataFromRedshiftServerlessVPCEndpoint(endpoint, workgroup) } -func describeRedshiftServerlessWorkgroup(ctx context.Context, client redshiftserverlessiface.RedshiftServerlessAPI, workgroupName string) (*redshiftserverless.Workgroup, error) { - input := new(redshiftserverless.GetWorkgroupInput).SetWorkgroupName(workgroupName) - output, err := client.GetWorkgroupWithContext(ctx, input) +func describeRedshiftServerlessWorkgroup(ctx context.Context, client rssClient, workgroupName string) (*rsstypes.Workgroup, error) { + output, err := client.GetWorkgroup(ctx, &rss.GetWorkgroupInput{ + WorkgroupName: aws.String(workgroupName), + }) if err != nil { return nil, common.ConvertError(err) } return output.Workgroup, nil } -func describeRedshiftServerlessVCPEndpoint(ctx context.Context, client redshiftserverlessiface.RedshiftServerlessAPI, endpointName string) (*redshiftserverless.EndpointAccess, error) { - input := new(redshiftserverless.GetEndpointAccessInput).SetEndpointName(endpointName) - output, err := client.GetEndpointAccessWithContext(ctx, input) + +func describeRedshiftServerlessVCPEndpoint(ctx context.Context, client rssClient, endpointName string) (*rsstypes.EndpointAccess, error) { + output, err := client.GetEndpointAccess(ctx, &rss.GetEndpointAccessInput{ + EndpointName: aws.String(endpointName), + }) if err != nil { return nil, common.ConvertError(err) } diff --git a/lib/srv/db/cloud/meta_test.go b/lib/srv/db/cloud/meta_test.go index 9c8805f026820..9cbefec8457b9 100644 --- a/lib/srv/db/cloud/meta_test.go +++ b/lib/srv/db/cloud/meta_test.go @@ -27,9 +27,10 @@ import ( rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go-v2/service/redshift" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" + rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless" + rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" - "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -124,18 +125,17 @@ func TestAWSMetadata(t *testing.T) { // Configure Redshift Serverless API mock. redshiftServerlessWorkgroup := mocks.RedshiftServerlessWorkgroup("my-workgroup", "us-west-1") redshiftServerlessEndpoint := mocks.RedshiftServerlessEndpointAccess(redshiftServerlessWorkgroup, "my-endpoint", "us-west-1") - redshiftServerless := &mocks.RedshiftServerlessMock{ - Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup}, - Endpoints: []*redshiftserverless.EndpointAccess{redshiftServerlessEndpoint}, + redshiftServerless := &mocks.RedshiftServerlessClient{ + Workgroups: []rsstypes.Workgroup{*redshiftServerlessWorkgroup}, + Endpoints: []rsstypes.EndpointAccess{*redshiftServerlessEndpoint}, } // Create metadata fetcher. metadata, err := NewMetadata(MetadataConfig{ Clients: &cloud.TestCloudClients{ - ElastiCache: elasticache, - MemoryDB: memorydb, - RedshiftServerless: redshiftServerless, - STS: &fakeSTS.STSClientV1, + ElastiCache: elasticache, + MemoryDB: memorydb, + STS: &fakeSTS.STSClientV1, }, AWSConfigProvider: &mocks.AWSConfigProvider{ STSClient: fakeSTS, @@ -143,6 +143,7 @@ func TestAWSMetadata(t *testing.T) { awsClients: fakeAWSClients{ rdsClient: rdsClt, redshiftClient: redshiftClt, + rssClient: redshiftServerless, }, }) require.NoError(t, err) @@ -503,6 +504,7 @@ func TestAWSMetadataNoPermissions(t *testing.T) { type fakeAWSClients struct { rdsClient rdsClient redshiftClient redshiftClient + rssClient rssClient } func (f fakeAWSClients) getRDSClient(aws.Config, ...func(*rds.Options)) rdsClient { @@ -512,3 +514,7 @@ func (f fakeAWSClients) getRDSClient(aws.Config, ...func(*rds.Options)) rdsClien func (f fakeAWSClients) getRedshiftClient(aws.Config, ...func(*redshift.Options)) redshiftClient { return f.redshiftClient } + +func (f fakeAWSClients) getRedshiftServerlessClient(aws.Config, ...func(*rss.Options)) rssClient { + return f.rssClient +} diff --git a/lib/srv/db/cloud/resource_checker_url_aws.go b/lib/srv/db/cloud/resource_checker_url_aws.go index 5b87d643ea7b7..f5de449ec2c65 100644 --- a/lib/srv/db/cloud/resource_checker_url_aws.go +++ b/lib/srv/db/cloud/resource_checker_url_aws.go @@ -24,7 +24,6 @@ import ( rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" @@ -181,21 +180,22 @@ func (c *urlChecker) checkRedshift(ctx context.Context, database types.Database) func (c *urlChecker) checkRedshiftServerless(ctx context.Context, database types.Database) error { meta := database.GetAWS() - client, err := c.clients.GetAWSRedshiftServerlessClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := c.awsConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return trace.Wrap(err) } + clt := c.awsClients.getRedshiftServerlessClient(awsCfg) if meta.RedshiftServerless.EndpointName != "" { - return trace.Wrap(c.checkRedshiftServerlessVPCEndpoint(ctx, database, client, meta.RedshiftServerless.EndpointName)) + return trace.Wrap(c.checkRedshiftServerlessVPCEndpoint(ctx, database, clt, meta.RedshiftServerless.EndpointName)) } - return trace.Wrap(c.checkRedshiftServerlessWorkgroup(ctx, database, client, meta.RedshiftServerless.WorkgroupName)) + return trace.Wrap(c.checkRedshiftServerlessWorkgroup(ctx, database, clt, meta.RedshiftServerless.WorkgroupName)) } -func (c *urlChecker) checkRedshiftServerlessVPCEndpoint(ctx context.Context, database types.Database, client redshiftserverlessiface.RedshiftServerlessAPI, endpointName string) error { +func (c *urlChecker) checkRedshiftServerlessVPCEndpoint(ctx context.Context, database types.Database, client rssClient, endpointName string) error { endpoint, err := describeRedshiftServerlessVCPEndpoint(ctx, client, endpointName) if err != nil { return trace.Wrap(err) @@ -203,7 +203,7 @@ func (c *urlChecker) checkRedshiftServerlessVPCEndpoint(ctx context.Context, dat return trace.Wrap(requireDatabaseAddressPort(database, endpoint.Address, endpoint.Port)) } -func (c *urlChecker) checkRedshiftServerlessWorkgroup(ctx context.Context, database types.Database, client redshiftserverlessiface.RedshiftServerlessAPI, workgroupName string) error { +func (c *urlChecker) checkRedshiftServerlessWorkgroup(ctx context.Context, database types.Database, client rssClient, workgroupName string) error { workgroup, err := describeRedshiftServerlessWorkgroup(ctx, client, workgroupName) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/db/cloud/resource_checker_url_aws_test.go b/lib/srv/db/cloud/resource_checker_url_aws_test.go index 40095f7efafe0..493fdcb76400d 100644 --- a/lib/srv/db/cloud/resource_checker_url_aws_test.go +++ b/lib/srv/db/cloud/resource_checker_url_aws_test.go @@ -24,10 +24,10 @@ import ( rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" + rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -121,10 +121,6 @@ func TestURLChecker_AWS(t *testing.T) { // Mock cloud clients. mockClients := &cloud.TestCloudClients{ - RedshiftServerless: &mocks.RedshiftServerlessMock{ - Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup}, - Endpoints: []*redshiftserverless.EndpointAccess{redshiftServerlessVPCEndpoint}, - }, ElastiCache: &mocks.ElastiCacheMock{ ReplicationGroups: []*elasticache.ReplicationGroup{elastiCacheClusterConfigurationMode, elastiCacheCluster}, }, @@ -137,11 +133,10 @@ func TestURLChecker_AWS(t *testing.T) { STS: &mocks.STSClientV1{}, } mockClientsUnauth := &cloud.TestCloudClients{ - RedshiftServerless: &mocks.RedshiftServerlessMock{Unauth: true}, - ElastiCache: &mocks.ElastiCacheMock{Unauth: true}, - MemoryDB: &mocks.MemoryDBMock{Unauth: true}, - OpenSearch: &mocks.OpenSearchMock{Unauth: true}, - STS: &mocks.STSClientV1{}, + ElastiCache: &mocks.ElastiCacheMock{Unauth: true}, + MemoryDB: &mocks.MemoryDBMock{Unauth: true}, + OpenSearch: &mocks.OpenSearchMock{Unauth: true}, + STS: &mocks.STSClientV1{}, } // Test both check methods. @@ -167,6 +162,10 @@ func TestURLChecker_AWS(t *testing.T) { redshiftClient: &mocks.RedshiftClient{ Clusters: []redshifttypes.Cluster{redshiftCluster}, }, + rssClient: &mocks.RedshiftServerlessClient{ + Workgroups: []rsstypes.Workgroup{*redshiftServerlessWorkgroup}, + Endpoints: []rsstypes.EndpointAccess{*redshiftServerlessVPCEndpoint}, + }, }, }, { @@ -176,6 +175,7 @@ func TestURLChecker_AWS(t *testing.T) { awsClients: fakeAWSClients{ rdsClient: &mocks.RDSClient{Unauth: true}, redshiftClient: &mocks.RedshiftClient{Unauth: true}, + rssClient: &mocks.RedshiftServerlessClient{Unauth: true}, }, }, } diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go index ad7183e70563b..35723b4e0fbe8 100644 --- a/lib/srv/db/common/auth.go +++ b/lib/srv/db/common/auth.go @@ -37,11 +37,11 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" rdsauth "github.com/aws/aws-sdk-go-v2/feature/rds/auth" "github.com/aws/aws-sdk-go-v2/service/redshift" + rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless" "github.com/aws/aws-sdk-go/aws/credentials" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" - "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "golang.org/x/oauth2" @@ -131,9 +131,15 @@ type redshiftClient interface { GetClusterCredentials(context.Context, *redshift.GetClusterCredentialsInput, ...func(*redshift.Options)) (*redshift.GetClusterCredentialsOutput, error) } +// rssClient defines a subset of the AWS Redshift Serverless client API. +type rssClient interface { + GetCredentials(context.Context, *rss.GetCredentialsInput, ...func(*rss.Options)) (*rss.GetCredentialsOutput, error) +} + // awsClientProvider is an AWS SDK client provider. type awsClientProvider interface { getRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient + getRedshiftServerlessClient(cfg aws.Config, optFns ...func(*rss.Options)) rssClient } type defaultAWSClients struct{} @@ -142,6 +148,10 @@ func (defaultAWSClients) getRedshiftClient(cfg aws.Config, optFns ...func(*redsh return redshift.NewFromConfig(cfg, optFns...) } +func (defaultAWSClients) getRedshiftServerlessClient(cfg aws.Config, optFns ...func(*rss.Options)) rssClient { + return rss.NewFromConfig(cfg, optFns...) +} + // AuthConfig is the database access authenticator configuration. type AuthConfig struct { // AuthClient is the cluster auth client. @@ -400,18 +410,12 @@ func (a *dbAuth) GetRedshiftServerlessAuthToken(ctx context.Context, database ty if err != nil { return "", "", trace.Wrap(err) } - baseSession, err := a.cfg.Clients.GetAWSSession(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), - ) - if err != nil { - return "", "", trace.Wrap(err) - } // Assume the configured AWS role before assuming the role we need to get the // auth token. This allows cross-account AWS access. - client, err := a.cfg.Clients.GetAWSRedshiftServerlessClient(ctx, meta.Region, - cloud.WithChainedAssumeRole(baseSession, roleARN, externalIDForChainedAssumeRole(meta)), - cloud.WithAmbientCredentials(), + awsCfg, err := a.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAssumeRole(roleARN, externalIDForChainedAssumeRole(meta)), + awsconfig.WithAmbientCredentials(), ) if err != nil { return "", "", trace.AccessDenied(`Could not generate Redshift Serverless auth token: @@ -421,6 +425,7 @@ func (a *dbAuth) GetRedshiftServerlessAuthToken(ctx context.Context, database ty Make sure that IAM role %q has a trust relationship with Teleport database agent's IAM identity. `, err, roleARN) } + clt := a.cfg.awsClients.getRedshiftServerlessClient(awsCfg) // Now make the API call to generate the temporary credentials. a.cfg.Logger.DebugContext(ctx, "Generating Redshift Serverless auth token", @@ -428,7 +433,7 @@ Make sure that IAM role %q has a trust relationship with Teleport database agent "database_user", databaseUser, "database_name", databaseName, ) - resp, err := client.GetCredentialsWithContext(ctx, &redshiftserverless.GetCredentialsInput{ + resp, err := clt.GetCredentials(ctx, &rss.GetCredentialsInput{ WorkgroupName: aws.String(meta.RedshiftServerless.WorkgroupName), DbName: aws.String(databaseName), }) diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go index d85df87c5fd54..f98de77911fe8 100644 --- a/lib/srv/db/common/auth_test.go +++ b/lib/srv/db/common/auth_test.go @@ -33,6 +33,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/redshift" + rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -116,19 +117,19 @@ func TestAuthGetRedshiftServerlessAuthToken(t *testing.T) { t.Parallel() // setup mock aws sessions. - stsMock := &mocks.STSClientV1{} + stsMock := &mocks.STSClient{} clock := clockwork.NewFakeClock() auth, err := NewAuth(AuthConfig{ - Clock: clock, - AuthClient: new(authClientMock), - AccessPoint: new(accessPointMock), - Clients: &cloud.TestCloudClients{ - STS: stsMock, - RedshiftServerless: &mocks.RedshiftServerlessMock{ + Clock: clock, + AuthClient: new(authClientMock), + AccessPoint: new(accessPointMock), + Clients: &cloud.TestCloudClients{}, + AWSConfigProvider: &mocks.AWSConfigProvider{STSClient: stsMock}, + awsClients: fakeAWSClients{ + rssClient: &mocks.RedshiftServerlessClient{ GetCredentialsOutput: mocks.RedshiftServerlessGetCredentialsOutput("IAM:some-user", "some-password", clock), }, }, - AWSConfigProvider: &mocks.AWSConfigProvider{}, }) require.NoError(t, err) @@ -609,9 +610,6 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { AccessPoint: new(accessPointMock), Clients: &cloud.TestCloudClients{ STS: &fakeSTS.STSClientV1, - RedshiftServerless: &mocks.RedshiftServerlessMock{ - GetCredentialsOutput: mocks.RedshiftServerlessGetCredentialsOutput("IAM:some-user", "some-password", clock), - }, }, AWSConfigProvider: &mocks.AWSConfigProvider{ STSClient: fakeSTS, @@ -621,6 +619,9 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { GetClusterCredentialsOutput: mocks.RedshiftGetClusterCredentialsOutput("IAM:some-user", "some-password", clock), GetClusterCredentialsWithIAMOutput: mocks.RedshiftGetClusterCredentialsWithIAMOutput("IAM:some-role", "some-password-for-some-role", clock), }, + rssClient: &mocks.RedshiftServerlessClient{ + GetCredentialsOutput: mocks.RedshiftServerlessGetCredentialsOutput("IAM:some-user", "some-password", clock), + }, }, }) require.NoError(t, err) @@ -1023,8 +1024,13 @@ func (m *imdsMock) GetType() types.InstanceMetadataType { type fakeAWSClients struct { redshiftClient redshiftClient + rssClient rssClient } func (f fakeAWSClients) getRedshiftClient(aws.Config, ...func(*redshift.Options)) redshiftClient { return f.redshiftClient } + +func (f fakeAWSClients) getRedshiftServerlessClient(cfg aws.Config, optFns ...func(*rss.Options)) rssClient { + return f.rssClient +} diff --git a/lib/srv/db/watcher_test.go b/lib/srv/db/watcher_test.go index 6020547ea9590..6d0571e2c56ea 100644 --- a/lib/srv/db/watcher_test.go +++ b/lib/srv/db/watcher_test.go @@ -27,7 +27,11 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql" - "github.com/aws/aws-sdk-go/service/redshiftserverless" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + "github.com/aws/aws-sdk-go-v2/service/redshift" + rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless" + rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/gravitational/trace" @@ -40,6 +44,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" discovery "github.com/gravitational/teleport/lib/srv/discovery/common" + "github.com/gravitational/teleport/lib/srv/discovery/fetchers/db" ) // TestWatcher verifies that database server properly detects and applies @@ -295,8 +300,7 @@ func setDiscoveryTypeLabel(r types.ResourceWithLabels, matcherType string) { // TestWatcherCloudFetchers tests usage of discovery database fetchers by the // database service. func TestWatcherCloudFetchers(t *testing.T) { - // Test an AWS fetcher. Note that status AWS can be set by Metadata - // service. + // Test an AWS fetcher. redshiftServerlessWorkgroup := mocks.RedshiftServerlessWorkgroup("discovery-aws", "us-east-1") redshiftServerlessDatabase, err := discovery.NewDatabaseFromRedshiftServerlessWorkgroup(redshiftServerlessWorkgroup, nil) require.NoError(t, err) @@ -313,21 +317,31 @@ func TestWatcherCloudFetchers(t *testing.T) { ctx := context.Background() testCtx := setupTestContext(ctx, t) + testCloudClients := &clients.TestCloudClients{ + AzureSQLServer: azure.NewSQLClientByAPI(&azure.ARMSQLServerMock{ + AllServers: []*armsql.Server{azSQLServer}, + }), + AzureManagedSQLServer: azure.NewManagedSQLClientByAPI(&azure.ARMSQLManagedServerMock{}), + } + dbFetcherFactory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ + AWSConfigProvider: &mocks.AWSConfigProvider{}, + CloudClients: testCloudClients, + AWSClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{Unauth: true}, // Access denied error should not affect other fetchers. + rssClient: &mocks.RedshiftServerlessClient{ + Workgroups: []rsstypes.Workgroup{*redshiftServerlessWorkgroup}, + }, + }, + }) + require.NoError(t, err) reconcileCh := make(chan types.Databases) testCtx.setupDatabaseServer(ctx, t, agentParams{ // Keep ResourceMatchers as nil to disable resource matchers. OnReconcile: func(d types.Databases) { reconcileCh <- d }, - CloudClients: &clients.TestCloudClients{ - RedshiftServerless: &mocks.RedshiftServerlessMock{ - Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup}, - }, - AzureSQLServer: azure.NewSQLClientByAPI(&azure.ARMSQLServerMock{ - AllServers: []*armsql.Server{azSQLServer}, - }), - AzureManagedSQLServer: azure.NewManagedSQLClientByAPI(&azure.ARMSQLManagedServerMock{}), - }, + CloudClients: testCloudClients, + AWSDatabaseFetcherFactory: dbFetcherFactory, AzureMatchers: []types.AzureMatcher{{ Subscriptions: []string{"sub"}, Types: []string{types.AzureMatcherSQLServer}, @@ -343,18 +357,21 @@ func TestWatcherCloudFetchers(t *testing.T) { wantDatabases := types.Databases{azSQLServerDatabase, redshiftServerlessDatabase} sort.Sort(wantDatabases) - assertReconciledResource(t, reconcileCh, wantDatabases) + // cloud metadata updater is disabled, so don't check the AWS metadata status. + assertReconciledResource(t, reconcileCh, wantDatabases, cmpopts.IgnoreFields(types.DatabaseStatusV3{}, "AWS")) } -func assertReconciledResource(t *testing.T, ch chan types.Databases, databases types.Databases) { +func assertReconciledResource(t *testing.T, ch chan types.Databases, databases types.Databases, opts ...cmp.Option) { t.Helper() select { case d := <-ch: sort.Sort(d) require.Equal(t, len(d), len(databases)) require.Empty(t, cmp.Diff(databases, d, - cmpopts.IgnoreFields(types.Metadata{}, "Revision"), - cmpopts.IgnoreFields(types.DatabaseStatusV3{}, "CACert"), + append(cmp.Options{ + cmpopts.IgnoreFields(types.Metadata{}, "Revision"), + cmpopts.IgnoreFields(types.DatabaseStatusV3{}, "CACert"), + }, opts...), )) case <-time.After(time.Second): require.FailNow(t, "Didn't receive reconcile event after 1s.") @@ -420,3 +437,21 @@ func makeAzureSQLServer(t *testing.T, name, group string) (*armsql.Server, types discovery.ApplyAzureDatabaseNameSuffix(database, types.AzureMatcherSQLServer) return server, database } + +type fakeAWSClients struct { + rdsClient db.RDSClient + redshiftClient db.RedshiftClient + rssClient db.RSSClient +} + +func (f fakeAWSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) db.RDSClient { + return f.rdsClient +} + +func (f fakeAWSClients) GetRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) db.RedshiftClient { + return f.redshiftClient +} + +func (f fakeAWSClients) GetRedshiftServerlessClient(cfg aws.Config, optFns ...func(*rss.Options)) db.RSSClient { + return f.rssClient +} diff --git a/lib/srv/discovery/common/database.go b/lib/srv/discovery/common/database.go index dcff7a2c0f614..aa8b354cd2531 100644 --- a/lib/srv/discovery/common/database.go +++ b/lib/srv/discovery/common/database.go @@ -31,11 +31,11 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" + rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types" "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" @@ -990,7 +990,7 @@ func NewDatabaseFromMemoryDBCluster(cluster *memorydb.Cluster, extraLabels map[s // NewDatabaseFromRedshiftServerlessWorkgroup creates a database resource from // a Redshift Serverless Workgroup. -func NewDatabaseFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Workgroup, tags []*redshiftserverless.Tag) (types.Database, error) { +func NewDatabaseFromRedshiftServerlessWorkgroup(workgroup *rsstypes.Workgroup, tags []rsstypes.Tag) (types.Database, error) { if workgroup.Endpoint == nil { return nil, trace.BadParameter("missing endpoint") } @@ -1007,14 +1007,14 @@ func NewDatabaseFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Wo }, metadata.RedshiftServerless.WorkgroupName), types.DatabaseSpecV3{ Protocol: defaults.ProtocolPostgres, - URI: fmt.Sprintf("%v:%v", aws.ToString(workgroup.Endpoint.Address), aws.ToInt64(workgroup.Endpoint.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(workgroup.Endpoint.Address), aws.ToInt32(workgroup.Endpoint.Port)), AWS: *metadata, }) } // NewDatabaseFromRedshiftServerlessVPCEndpoint creates a database resource from // a Redshift Serverless VPC endpoint. -func NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.EndpointAccess, workgroup *redshiftserverless.Workgroup, tags []*redshiftserverless.Tag) (types.Database, error) { +func NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint *rsstypes.EndpointAccess, workgroup *rsstypes.Workgroup, tags []rsstypes.Tag) (types.Database, error) { if workgroup.Endpoint == nil { return nil, trace.BadParameter("missing endpoint") } @@ -1031,7 +1031,7 @@ func NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.E }, metadata.RedshiftServerless.WorkgroupName, metadata.RedshiftServerless.EndpointName), types.DatabaseSpecV3{ Protocol: defaults.ProtocolPostgres, - URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint.Address), aws.ToInt64(endpoint.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint.Address), aws.ToInt32(endpoint.Port)), AWS: *metadata, // Use workgroup's default address as the server name. @@ -1220,7 +1220,7 @@ func MetadataFromMemoryDBCluster(cluster *memorydb.Cluster, endpointType string) // MetadataFromRedshiftServerlessWorkgroup creates AWS metadata for the // provided Redshift Serverless Workgroup. -func MetadataFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Workgroup) (*types.AWS, error) { +func MetadataFromRedshiftServerlessWorkgroup(workgroup *rsstypes.Workgroup) (*types.AWS, error) { parsedARN, err := arn.Parse(aws.ToString(workgroup.WorkgroupArn)) if err != nil { return nil, trace.Wrap(err) @@ -1238,7 +1238,7 @@ func MetadataFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Workg // MetadataFromRedshiftServerlessVPCEndpoint creates AWS metadata for the // provided Redshift Serverless VPC endpoint. -func MetadataFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.EndpointAccess, workgroup *redshiftserverless.Workgroup) (*types.AWS, error) { +func MetadataFromRedshiftServerlessVPCEndpoint(endpoint *rsstypes.EndpointAccess, workgroup *rsstypes.Workgroup) (*types.AWS, error) { parsedARN, err := arn.Parse(aws.ToString(endpoint.EndpointArn)) if err != nil { return nil, trace.Wrap(err) @@ -1472,7 +1472,7 @@ func labelsFromRedshiftCluster(cluster *redshifttypes.Cluster, meta *types.AWS) return addLabels(labels, libcloudaws.TagsToLabels(cluster.Tags)) } -func labelsFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Workgroup, meta *types.AWS, tags []*redshiftserverless.Tag) map[string]string { +func labelsFromRedshiftServerlessWorkgroup(workgroup *rsstypes.Workgroup, meta *types.AWS, tags []rsstypes.Tag) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEndpointType] = services.RedshiftServerlessWorkgroupEndpoint labels[types.DiscoveryLabelNamespace] = aws.ToString(workgroup.NamespaceName) @@ -1482,7 +1482,7 @@ func labelsFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Workgro return addLabels(labels, libcloudaws.TagsToLabels(tags)) } -func labelsFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.EndpointAccess, workgroup *redshiftserverless.Workgroup, meta *types.AWS, tags []*redshiftserverless.Tag) map[string]string { +func labelsFromRedshiftServerlessVPCEndpoint(endpoint *rsstypes.EndpointAccess, workgroup *rsstypes.Workgroup, meta *types.AWS, tags []rsstypes.Tag) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEndpointType] = services.RedshiftServerlessVPCEndpoint labels[types.DiscoveryLabelWorkgroup] = aws.ToString(endpoint.WorkgroupName) diff --git a/lib/srv/discovery/common/database_test.go b/lib/srv/discovery/common/database_test.go index 891c31a18bc13..819db125399b2 100644 --- a/lib/srv/discovery/common/database_test.go +++ b/lib/srv/discovery/common/database_test.go @@ -32,7 +32,6 @@ import ( redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" - "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/gravitational/trace" @@ -42,6 +41,7 @@ import ( awsutils "github.com/gravitational/teleport/api/utils/aws" azureutils "github.com/gravitational/teleport/api/utils/azure" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awstesthelpers" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/defaults" @@ -1658,7 +1658,7 @@ func TestDatabaseFromMemoryDBCluster(t *testing.T) { func TestDatabaseFromRedshiftServerlessWorkgroup(t *testing.T) { workgroup := mocks.RedshiftServerlessWorkgroup("my-workgroup", "eu-west-2") - tags := libcloudaws.LabelsToTags[redshiftserverless.Tag](map[string]string{"env": "prod"}) + tags := awstesthelpers.LabelsToRedshiftServerlessTags(map[string]string{"env": "prod"}) expected, err := types.NewDatabaseV3(types.Metadata{ Name: "my-workgroup", Description: "Redshift Serverless workgroup in eu-west-2", @@ -1693,7 +1693,7 @@ func TestDatabaseFromRedshiftServerlessWorkgroup(t *testing.T) { func TestDatabaseFromRedshiftServerlessVPCEndpoint(t *testing.T) { workgroup := mocks.RedshiftServerlessWorkgroup("my-workgroup", "eu-west-2") endpoint := mocks.RedshiftServerlessEndpointAccess(workgroup, "my-endpoint", "eu-west-2") - tags := libcloudaws.LabelsToTags[redshiftserverless.Tag](map[string]string{"env": "prod"}) + tags := awstesthelpers.LabelsToRedshiftServerlessTags(map[string]string{"env": "prod"}) expected, err := types.NewDatabaseV3(types.Metadata{ Name: "my-workgroup-my-endpoint", Description: "Redshift Serverless endpoint in eu-west-2", diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 2948e10cdb916..841079a5469ec 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -45,6 +45,7 @@ import ( rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go-v2/service/redshift" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" + rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless" "github.com/aws/aws-sdk-go-v2/service/ssm" ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/google/go-cmp/cmp" @@ -3766,6 +3767,7 @@ func newPopulatedGCPProjectsMock() *mockProjectsAPI { type fakeAWSClients struct { rdsClient db.RDSClient redshiftClient db.RedshiftClient + rssClient db.RSSClient } func (f fakeAWSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) db.RDSClient { @@ -3775,3 +3777,7 @@ func (f fakeAWSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options func (f fakeAWSClients) GetRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) db.RedshiftClient { return f.redshiftClient } + +func (f fakeAWSClients) GetRedshiftServerlessClient(cfg aws.Config, optFns ...func(*rss.Options)) db.RSSClient { + return f.rssClient +} diff --git a/lib/srv/discovery/fetchers/db/aws_rds_test.go b/lib/srv/discovery/fetchers/db/aws_rds_test.go index db4aeeb376cc3..b86529f22f7a6 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_test.go @@ -25,6 +25,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go-v2/service/redshift" + rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -309,6 +310,7 @@ func newRegionalFakeRDSClientProvider(cs map[string]RDSClient) fakeRegionalRDSCl type fakeAWSClients struct { rdsClient RDSClient redshiftClient RedshiftClient + rssClient RSSClient } func (f fakeAWSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient { @@ -319,6 +321,10 @@ func (f fakeAWSClients) GetRedshiftClient(cfg aws.Config, optFns ...func(*redshi return f.redshiftClient } +func (f fakeAWSClients) GetRedshiftServerlessClient(cfg aws.Config, optFns ...func(*rss.Options)) RSSClient { + return f.rssClient +} + type fakeRegionalRDSClients struct { AWSClientProvider rdsClients map[string]RDSClient diff --git a/lib/srv/discovery/fetchers/db/aws_redshift.go b/lib/srv/discovery/fetchers/db/aws_redshift.go index b6a17f32ede5e..ccfa726e36e9e 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift.go @@ -89,7 +89,7 @@ func (f *redshiftPlugin) ComponentShortName() string { // getRedshiftClusters fetches all Reshift clusters using the provided client, // up to the specified max number of pages -func getRedshiftClusters(ctx context.Context, clt redshift.DescribeClustersAPIClient) ([]redshifttypes.Cluster, error) { +func getRedshiftClusters(ctx context.Context, clt RedshiftClient) ([]redshifttypes.Cluster, error) { pager := redshift.NewDescribeClustersPaginator(clt, &redshift.DescribeClustersInput{}, func(dcpo *redshift.DescribeClustersPaginatorOptions) { diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go b/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go index 651034882b239..81dfb6cc27468 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go @@ -22,17 +22,25 @@ import ( "context" "log/slog" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/redshiftserverless" - "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" + "github.com/aws/aws-sdk-go-v2/aws" + rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless" + rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/srv/discovery/common" ) +// RSSClient is a subset of the AWS Redshift Serverless API. +type RSSClient interface { + rss.ListEndpointAccessAPIClient + rss.ListWorkgroupsAPIClient + + ListTagsForResource(context.Context, *rss.ListTagsForResourceInput, ...func(*rss.Options)) (*rss.ListTagsForResourceOutput, error) +} + // newRedshiftServerlessFetcher returns a new AWS fetcher for Redshift // Serverless databases. func newRedshiftServerlessFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { @@ -40,9 +48,9 @@ func newRedshiftServerlessFetcher(cfg awsFetcherConfig) (common.Fetcher, error) } type workgroupWithTags struct { - *redshiftserverless.Workgroup + *rsstypes.Workgroup - Tags []*redshiftserverless.Tag + Tags []rsstypes.Tag } // redshiftServerlessPlugin retrieves Redshift Serverless databases. @@ -53,25 +61,23 @@ func (f *redshiftServerlessPlugin) ComponentShortName() string { return "rss<" } -// rssAPI is a type alias for brevity alone. -type rssAPI = redshiftserverlessiface.RedshiftServerlessAPI - // GetDatabases returns Redshift Serverless databases matching the watcher's selectors. func (f *redshiftServerlessPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { - client, err := cfg.AWSClients.GetAWSRedshiftServerlessClient(ctx, cfg.Region, - cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), - cloud.WithCredentialsMaybeIntegration(cfg.Integration), + awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, + awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), + awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), ) if err != nil { return nil, trace.Wrap(err) } - databases, workgroups, err := getDatabasesFromWorkgroups(ctx, client, cfg.Logger) + clt := cfg.awsClients.GetRedshiftServerlessClient(awsCfg) + databases, workgroups, err := getDatabasesFromWorkgroups(ctx, clt, cfg.Logger) if err != nil { return nil, trace.Wrap(err) } if len(workgroups) > 0 { - vpcEndpointDatabases, err := getDatabasesFromVPCEndpoints(ctx, workgroups, client, cfg.Logger) + vpcEndpointDatabases, err := getDatabasesFromVPCEndpoints(ctx, workgroups, clt, cfg.Logger) if err != nil { if trace.IsAccessDenied(err) { cfg.Logger.DebugContext(ctx, "No permission to get Redshift Serverless VPC endpoints", "error", err) @@ -85,7 +91,7 @@ func (f *redshiftServerlessPlugin) GetDatabases(ctx context.Context, cfg *awsFet return databases, nil } -func getDatabasesFromWorkgroups(ctx context.Context, client rssAPI, logger *slog.Logger) (types.Databases, []*workgroupWithTags, error) { +func getDatabasesFromWorkgroups(ctx context.Context, client RSSClient, logger *slog.Logger) (types.Databases, []*workgroupWithTags, error) { workgroups, err := getRSSWorkgroups(ctx, client) if err != nil { return nil, nil, trace.Wrap(err) @@ -94,19 +100,19 @@ func getDatabasesFromWorkgroups(ctx context.Context, client rssAPI, logger *slog var databases types.Databases var workgroupsWithTags []*workgroupWithTags for _, workgroup := range workgroups { - if !libcloudaws.IsResourceAvailable(workgroup, workgroup.Status) { - logger.DebugContext(ctx, "Skipping unavailable Redshift Serverless workgroup", - "workgroup", aws.StringValue(workgroup.WorkgroupName), - "status", aws.StringValue(workgroup.Status), + if !isWorkgroupAvailable(logger, &workgroup) { + logger.DebugContext(ctx, "Skipping unavailable Redshift Serverless workgroup", + "status", workgroup.Status, + "workgroup", aws.ToString(workgroup.WorkgroupName), ) continue } tags := getRSSResourceTags(ctx, workgroup.WorkgroupArn, client, logger) - database, err := common.NewDatabaseFromRedshiftServerlessWorkgroup(workgroup, tags) + database, err := common.NewDatabaseFromRedshiftServerlessWorkgroup(&workgroup, tags) if err != nil { logger.InfoContext(ctx, "Could not convert Redshift Serverless workgroup to database resource", - "workgroup", aws.StringValue(workgroup.WorkgroupName), + "workgroup", aws.ToString(workgroup.WorkgroupName), "error", err, ) continue @@ -114,14 +120,14 @@ func getDatabasesFromWorkgroups(ctx context.Context, client rssAPI, logger *slog databases = append(databases, database) workgroupsWithTags = append(workgroupsWithTags, &workgroupWithTags{ - Workgroup: workgroup, + Workgroup: &workgroup, Tags: tags, }) } return databases, workgroupsWithTags, nil } -func getDatabasesFromVPCEndpoints(ctx context.Context, workgroups []*workgroupWithTags, client rssAPI, logger *slog.Logger) (types.Databases, error) { +func getDatabasesFromVPCEndpoints(ctx context.Context, workgroups []*workgroupWithTags, client RSSClient, logger *slog.Logger) (types.Databases, error) { endpoints, err := getRSSVPCEndpoints(ctx, client) if err != nil { return nil, trace.Wrap(err) @@ -129,26 +135,26 @@ func getDatabasesFromVPCEndpoints(ctx context.Context, workgroups []*workgroupWi var databases types.Databases for _, endpoint := range endpoints { - workgroup, found := findWorkgroupWithName(workgroups, aws.StringValue(endpoint.WorkgroupName)) + workgroup, found := findWorkgroupWithName(workgroups, aws.ToString(endpoint.WorkgroupName)) if !found { - logger.DebugContext(ctx, "Could not find matching workgroup for Redshift Serverless endpoint", "endpoint", aws.StringValue(endpoint.EndpointName)) + logger.DebugContext(ctx, "Could not find matching workgroup for Redshift Serverless endpoint", "endpoint", aws.ToString(endpoint.EndpointName)) continue } if !libcloudaws.IsResourceAvailable(endpoint, endpoint.EndpointStatus) { logger.DebugContext(ctx, "Skipping unavailable Redshift Serverless endpoint", - "endpoint", aws.StringValue(endpoint.EndpointName), - "status", aws.StringValue(endpoint.EndpointStatus), + "endpoint", aws.ToString(endpoint.EndpointName), + "status", aws.ToString(endpoint.EndpointStatus), ) continue } // VPC endpoints do not have resource tags attached to them. Use the // tags from the workgroups instead. - database, err := common.NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint, workgroup.Workgroup, workgroup.Tags) + database, err := common.NewDatabaseFromRedshiftServerlessVPCEndpoint(&endpoint, workgroup.Workgroup, workgroup.Tags) if err != nil { logger.InfoContext(ctx, "Could not convert Redshift Serverless endpoint to database resource", - "endpoint", aws.StringValue(endpoint.EndpointName), + "endpoint", aws.ToString(endpoint.EndpointName), "error", err, ) continue @@ -158,20 +164,20 @@ func getDatabasesFromVPCEndpoints(ctx context.Context, workgroups []*workgroupWi return databases, nil } -func getRSSResourceTags(ctx context.Context, arn *string, client rssAPI, logger *slog.Logger) []*redshiftserverless.Tag { - output, err := client.ListTagsForResourceWithContext(ctx, &redshiftserverless.ListTagsForResourceInput{ +func getRSSResourceTags(ctx context.Context, arn *string, client RSSClient, logger *slog.Logger) []rsstypes.Tag { + output, err := client.ListTagsForResource(ctx, &rss.ListTagsForResourceInput{ ResourceArn: arn, }) if err != nil { // Log errors here and return nil. if trace.IsAccessDenied(err) { logger.DebugContext(ctx, "No Permission to get Redshift Serverless tags", - "arn", aws.StringValue(arn), + "arn", aws.ToString(arn), "error", err, ) } else { logger.WarnContext(ctx, "Failed to get Redshift Serverless tags", - "arn", aws.StringValue(arn), + "arn", aws.ToString(arn), "error", err, ) } @@ -180,29 +186,66 @@ func getRSSResourceTags(ctx context.Context, arn *string, client rssAPI, logger return output.Tags } -func getRSSWorkgroups(ctx context.Context, client rssAPI) ([]*redshiftserverless.Workgroup, error) { - var pages [][]*redshiftserverless.Workgroup - err := client.ListWorkgroupsPagesWithContext(ctx, nil, func(page *redshiftserverless.ListWorkgroupsOutput, lastPage bool) bool { - pages = append(pages, page.Workgroups) - return len(pages) <= maxAWSPages - }) - return flatten(pages), libcloudaws.ConvertRequestFailureError(err) +func getRSSWorkgroups(ctx context.Context, clt RSSClient) ([]rsstypes.Workgroup, error) { + var out []rsstypes.Workgroup + pager := rss.NewListWorkgroupsPaginator(clt, + &rss.ListWorkgroupsInput{}, + func(o *rss.ListWorkgroupsPaginatorOptions) { + o.StopOnDuplicateToken = true + }, + ) + for i := 0; i < maxAWSPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, libcloudaws.ConvertRequestFailureErrorV2(err) + } + out = append(out, page.Workgroups...) + } + return out, nil } -func getRSSVPCEndpoints(ctx context.Context, client rssAPI) ([]*redshiftserverless.EndpointAccess, error) { - var pages [][]*redshiftserverless.EndpointAccess - err := client.ListEndpointAccessPagesWithContext(ctx, nil, func(page *redshiftserverless.ListEndpointAccessOutput, lastPage bool) bool { - pages = append(pages, page.Endpoints) - return len(pages) <= maxAWSPages - }) - return flatten(pages), libcloudaws.ConvertRequestFailureError(err) +func getRSSVPCEndpoints(ctx context.Context, clt RSSClient) ([]rsstypes.EndpointAccess, error) { + var out []rsstypes.EndpointAccess + pager := rss.NewListEndpointAccessPaginator(clt, + &rss.ListEndpointAccessInput{}, + func(o *rss.ListEndpointAccessPaginatorOptions) { + o.StopOnDuplicateToken = true + }, + ) + for i := 0; i < maxAWSPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, libcloudaws.ConvertRequestFailureErrorV2(err) + } + out = append(out, page.Endpoints...) + } + return out, nil } func findWorkgroupWithName(workgroups []*workgroupWithTags, name string) (*workgroupWithTags, bool) { for _, workgroup := range workgroups { - if aws.StringValue(workgroup.WorkgroupName) == name { + if aws.ToString(workgroup.WorkgroupName) == name { return workgroup, true } } return nil, false } + +func isWorkgroupAvailable(logger *slog.Logger, wg *rsstypes.Workgroup) bool { + switch wg.Status { + case + rsstypes.WorkgroupStatusAvailable, + rsstypes.WorkgroupStatusModifying: + return true + case + rsstypes.WorkgroupStatusCreating, + rsstypes.WorkgroupStatusDeleting: + return false + default: + logger.WarnContext(context.Background(), "Assuming Redshift Serverless workgroup with an unknown status is available", + "status", wg.Status, + "workgroup", aws.ToString(wg.NamespaceName), + ) + return true + } +} diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go b/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go index cb90c400e038e..dd64dcdade45f 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go @@ -21,13 +21,13 @@ package db import ( "testing" + rsstypes "github.com/aws/aws-sdk-go-v2/service/redshiftserverless/types" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud" - libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awstesthelpers" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -39,24 +39,27 @@ func TestRedshiftServerlessFetcher(t *testing.T) { workgroupDev, workgroupDevDB := makeRedshiftServerlessWorkgroup(t, "wg2", "us-east-1", envDevLabels) endpointProd, endpointProdDB := makeRedshiftServerlessEndpoint(t, workgroupProd, "endpoint1", "us-east-1", envProdLabels) endpointDev, endpointProdDev := makeRedshiftServerlessEndpoint(t, workgroupDev, "endpoint2", "us-east-1", envDevLabels) - tagsByARN := map[string][]*redshiftserverless.Tag{ - aws.StringValue(workgroupProd.WorkgroupArn): libcloudaws.LabelsToTags[redshiftserverless.Tag](envProdLabels), - aws.StringValue(workgroupDev.WorkgroupArn): libcloudaws.LabelsToTags[redshiftserverless.Tag](envDevLabels), + tagsByARN := map[string][]rsstypes.Tag{ + aws.StringValue(workgroupProd.WorkgroupArn): awstesthelpers.LabelsToRedshiftServerlessTags(envProdLabels), + aws.StringValue(workgroupDev.WorkgroupArn): awstesthelpers.LabelsToRedshiftServerlessTags(envDevLabels), } workgroupNotAvailable := mocks.RedshiftServerlessWorkgroup("wg-creating", "us-east-1") - workgroupNotAvailable.Status = aws.String("creating") + workgroupNotAvailable.Status = rsstypes.WorkgroupStatusCreating endpointNotAvailable := mocks.RedshiftServerlessEndpointAccess(workgroupNotAvailable, "endpoint-creating", "us-east-1") endpointNotAvailable.EndpointStatus = aws.String("creating") tests := []awsFetcherTest{ { - name: "fetch all", - inputClients: &cloud.TestCloudClients{ - RedshiftServerless: &mocks.RedshiftServerlessMock{ - Workgroups: []*redshiftserverless.Workgroup{workgroupProd, workgroupDev}, - Endpoints: []*redshiftserverless.EndpointAccess{endpointProd, endpointDev}, - TagsByARN: tagsByARN, + name: "fetch all", + inputClients: &cloud.TestCloudClients{}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: fakeAWSClients{ + rssClient: &mocks.RedshiftServerlessClient{ + Workgroups: []rsstypes.Workgroup{*workgroupProd, *workgroupDev}, + Endpoints: []rsstypes.EndpointAccess{*endpointProd, *endpointDev}, + TagsByARN: tagsByARN, + }, }, }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshiftServerless, "us-east-1", wildcardLabels), @@ -64,11 +67,13 @@ func TestRedshiftServerlessFetcher(t *testing.T) { }, { name: "fetch prod", - inputClients: &cloud.TestCloudClients{ - RedshiftServerless: &mocks.RedshiftServerlessMock{ - Workgroups: []*redshiftserverless.Workgroup{workgroupProd, workgroupDev}, - Endpoints: []*redshiftserverless.EndpointAccess{endpointProd, endpointDev}, - TagsByARN: tagsByARN, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: fakeAWSClients{ + rssClient: &mocks.RedshiftServerlessClient{ + Workgroups: []rsstypes.Workgroup{*workgroupProd, *workgroupDev}, + Endpoints: []rsstypes.EndpointAccess{*endpointProd, *endpointDev}, + TagsByARN: tagsByARN, + }, }, }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshiftServerless, "us-east-1", envProdLabels), @@ -76,11 +81,13 @@ func TestRedshiftServerlessFetcher(t *testing.T) { }, { name: "skip unavailable", - inputClients: &cloud.TestCloudClients{ - RedshiftServerless: &mocks.RedshiftServerlessMock{ - Workgroups: []*redshiftserverless.Workgroup{workgroupProd, workgroupNotAvailable}, - Endpoints: []*redshiftserverless.EndpointAccess{endpointNotAvailable}, - TagsByARN: tagsByARN, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: fakeAWSClients{ + rssClient: &mocks.RedshiftServerlessClient{ + Workgroups: []rsstypes.Workgroup{*workgroupProd, *workgroupNotAvailable}, + Endpoints: []rsstypes.EndpointAccess{*endpointNotAvailable}, + TagsByARN: tagsByARN, + }, }, }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshiftServerless, "us-east-1", wildcardLabels), @@ -90,18 +97,18 @@ func TestRedshiftServerlessFetcher(t *testing.T) { testAWSFetchers(t, tests...) } -func makeRedshiftServerlessWorkgroup(t *testing.T, name, region string, labels map[string]string) (*redshiftserverless.Workgroup, types.Database) { +func makeRedshiftServerlessWorkgroup(t *testing.T, name, region string, labels map[string]string) (*rsstypes.Workgroup, types.Database) { workgroup := mocks.RedshiftServerlessWorkgroup(name, region) - tags := libcloudaws.LabelsToTags[redshiftserverless.Tag](labels) + tags := awstesthelpers.LabelsToRedshiftServerlessTags(labels) database, err := common.NewDatabaseFromRedshiftServerlessWorkgroup(workgroup, tags) require.NoError(t, err) common.ApplyAWSDatabaseNameSuffix(database, types.AWSMatcherRedshiftServerless) return workgroup, database } -func makeRedshiftServerlessEndpoint(t *testing.T, workgroup *redshiftserverless.Workgroup, name, region string, labels map[string]string) (*redshiftserverless.EndpointAccess, types.Database) { +func makeRedshiftServerlessEndpoint(t *testing.T, workgroup *rsstypes.Workgroup, name, region string, labels map[string]string) (*rsstypes.EndpointAccess, types.Database) { endpoint := mocks.RedshiftServerlessEndpointAccess(workgroup, name, region) - tags := libcloudaws.LabelsToTags[redshiftserverless.Tag](labels) + tags := awstesthelpers.LabelsToRedshiftServerlessTags(labels) database, err := common.NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint, workgroup, tags) require.NoError(t, err) common.ApplyAWSDatabaseNameSuffix(database, types.AWSMatcherRedshiftServerless) diff --git a/lib/srv/discovery/fetchers/db/db.go b/lib/srv/discovery/fetchers/db/db.go index cd4df7269a14e..72ee5093786d4 100644 --- a/lib/srv/discovery/fetchers/db/db.go +++ b/lib/srv/discovery/fetchers/db/db.go @@ -25,6 +25,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/redshift" + rss "github.com/aws/aws-sdk-go-v2/service/redshiftserverless" "github.com/gravitational/trace" "golang.org/x/exp/maps" @@ -74,6 +75,8 @@ type AWSClientProvider interface { GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient // GetRedshiftClient provides an [RedshiftClient]. GetRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient + // GetRedshiftServerlessClient provides an [RSSClient]. + GetRedshiftServerlessClient(cfg aws.Config, optFns ...func(*rss.Options)) RSSClient } type defaultAWSClients struct{} @@ -86,6 +89,10 @@ func (defaultAWSClients) GetRedshiftClient(cfg aws.Config, optFns ...func(*redsh return redshift.NewFromConfig(cfg, optFns...) } +func (defaultAWSClients) GetRedshiftServerlessClient(cfg aws.Config, optFns ...func(*rss.Options)) RSSClient { + return rss.NewFromConfig(cfg, optFns...) +} + // AWSFetcherFactoryConfig is the configuration for an [AWSFetcherFactory]. type AWSFetcherFactoryConfig struct { // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. @@ -211,11 +218,3 @@ func filterDatabasesByLabels(ctx context.Context, databases types.Databases, lab } return matchedDatabases } - -// flatten flattens a nested slice [][]T to []T. -func flatten[T any](s [][]T) (result []T) { - for i := range s { - result = append(result, s[i]...) - } - return -}