diff --git a/lib/cloud/aws/config/config.go b/lib/cloud/aws/config/config.go new file mode 100644 index 0000000000000..815ebea3d0230 --- /dev/null +++ b/lib/cloud/aws/config/config.go @@ -0,0 +1,237 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package config + +import ( + "context" + "log/slog" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/modules" +) + +const defaultRegion = "us-east-1" + +// credentialsSource defines where the credentials must come from. +type credentialsSource int + +const ( + // credentialsSourceAmbient uses the default Cloud SDK method to load the credentials. + credentialsSourceAmbient = iota + 1 + // credentialsSourceIntegration uses an Integration to load the credentials. + credentialsSourceIntegration +) + +// AWSIntegrationSessionProvider defines a function that creates a credential provider from a region and an integration. +// This is used to generate aws configs for clients that must use an integration instead of ambient credentials. +type AWSIntegrationCredentialProvider func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) + +// awsOptions is a struct of additional options for assuming an AWS role +// when construction an underlying AWS config. +type awsOptions struct { + // baseConfigis a config to use instead of the default config for an + // AWS region, which is used to enable role chaining. + baseConfig *aws.Config + // assumeRoleARN is the AWS IAM Role ARN to assume. + assumeRoleARN string + // assumeRoleExternalID is used to assume an external AWS IAM Role. + assumeRoleExternalID string + // credentialsSource describes which source to use to fetch credentials. + credentialsSource credentialsSource + // integration is the name of the integration to be used to fetch the credentials. + integration string + // awsIntegrationCredentialsProvider is the integration credential provider to use. + awsIntegrationCredentialsProvider AWSIntegrationCredentialProvider + // customRetryer is a custom retryer to use for the config. + customRetryer func() aws.Retryer + // maxRetries is the maximum number of retries to use for the config. + maxRetries *int +} + +func (a *awsOptions) checkAndSetDefaults() error { + switch a.credentialsSource { + case credentialsSourceAmbient: + if a.integration != "" { + return trace.BadParameter("integration and ambient credentials cannot be used at the same time") + } + case credentialsSourceIntegration: + if a.integration == "" { + return trace.BadParameter("missing integration name") + } + default: + return trace.BadParameter("missing credentials source (ambient or integration)") + } + + return nil +} + +// AWSOptionsFn is an option function for setting additional options +// when getting an AWS config. +type AWSOptionsFn func(*awsOptions) + +// WithAssumeRole configures options needed for assuming an AWS role. +func WithAssumeRole(roleARN, externalID string) AWSOptionsFn { + return func(options *awsOptions) { + options.assumeRoleARN = roleARN + options.assumeRoleExternalID = externalID + } +} + +// WithRetryer sets a custom retryer for the config. +func WithRetryer(retryer func() aws.Retryer) AWSOptionsFn { + return func(options *awsOptions) { + options.customRetryer = retryer + } +} + +// WithMaxRetries sets the maximum allowed value for the sdk to keep retrying. +func WithMaxRetries(maxRetries int) AWSOptionsFn { + return func(options *awsOptions) { + options.maxRetries = &maxRetries + } +} + +// WithCredentialsMaybeIntegration sets the credential source to be +// - ambient if the integration is an empty string +// - integration, otherwise +func WithCredentialsMaybeIntegration(integration string) AWSOptionsFn { + if integration != "" { + return withIntegrationCredentials(integration) + } + + return WithAmbientCredentials() +} + +// withIntegrationCredentials configures options with an Integration that must be used to fetch Credentials to assume a role. +// This prevents the usage of AWS environment credentials. +func withIntegrationCredentials(integration string) AWSOptionsFn { + return func(options *awsOptions) { + options.credentialsSource = credentialsSourceIntegration + options.integration = integration + } +} + +// WithAmbientCredentials configures options to use the ambient credentials. +func WithAmbientCredentials() AWSOptionsFn { + return func(options *awsOptions) { + options.credentialsSource = credentialsSourceAmbient + } +} + +// WithAWSIntegrationCredentialProvider sets the integration credential provider. +func WithAWSIntegrationCredentialProvider(cred AWSIntegrationCredentialProvider) AWSOptionsFn { + return func(options *awsOptions) { + options.awsIntegrationCredentialsProvider = cred + } +} + +// GetAWSConfig returns an AWS config for the specified region, optionally +// assuming AWS IAM Roles. +func GetAWSConfig(ctx context.Context, region string, opts ...AWSOptionsFn) (aws.Config, error) { + var options awsOptions + for _, opt := range opts { + opt(&options) + } + if options.baseConfig == nil { + cfg, err := getAWSConfigForRegion(ctx, region, options) + if err != nil { + return aws.Config{}, trace.Wrap(err) + } + options.baseConfig = &cfg + } + if options.assumeRoleARN == "" { + return *options.baseConfig, nil + } + return getAWSConfigForRole(ctx, region, options) +} + +// awsAmbientConfigProvider loads a new config using the environment variables. +func awsAmbientConfigProvider(region string, cred aws.CredentialsProvider, options awsOptions) (aws.Config, error) { + opts := buildAWSConfigOptions(region, cred, options) + cfg, err := config.LoadDefaultConfig(context.Background(), opts...) + return cfg, trace.Wrap(err) +} + +func buildAWSConfigOptions(region string, cred aws.CredentialsProvider, options awsOptions) []func(*config.LoadOptions) error { + opts := []func(*config.LoadOptions) error{ + config.WithDefaultRegion(defaultRegion), + config.WithRegion(region), + config.WithCredentialsProvider(cred), + } + if modules.GetModules().IsBoringBinary() { + opts = append(opts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled)) + } + if options.customRetryer != nil { + opts = append(opts, config.WithRetryer(options.customRetryer)) + } + if options.maxRetries != nil { + opts = append(opts, config.WithRetryMaxAttempts(*options.maxRetries)) + } + return opts +} + +// getAWSConfigForRegion returns AWS config for the specified region. +func getAWSConfigForRegion(ctx context.Context, region string, options awsOptions) (aws.Config, error) { + if err := options.checkAndSetDefaults(); err != nil { + return aws.Config{}, trace.Wrap(err) + } + + var cred aws.CredentialsProvider + if options.credentialsSource == credentialsSourceIntegration { + if options.awsIntegrationCredentialsProvider == nil { + return aws.Config{}, trace.BadParameter("missing aws integration credential provider") + } + + slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", options.integration) + var err error + cred, err = options.awsIntegrationCredentialsProvider(ctx, region, options.integration) + if err != nil { + return aws.Config{}, trace.Wrap(err) + } + } else { + slog.DebugContext(ctx, "Initializing AWS config from environment", "region", region) + } + + cfg, err := awsAmbientConfigProvider(region, cred, options) + return cfg, trace.Wrap(err) +} + +// getAWSConfigForRole returns an AWS config for the specified region and role. +func getAWSConfigForRole(ctx context.Context, region string, options awsOptions) (aws.Config, error) { + if err := options.checkAndSetDefaults(); err != nil { + return aws.Config{}, trace.Wrap(err) + } + + stsClient := sts.NewFromConfig(*options.baseConfig) + cred := stscreds.NewAssumeRoleProvider(stsClient, options.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) { + if options.assumeRoleExternalID != "" { + aro.ExternalID = aws.String(options.assumeRoleExternalID) + } + }) + if _, err := cred.Retrieve(ctx); err != nil { + return aws.Config{}, trace.Wrap(err) + } + + opts := buildAWSConfigOptions(region, cred, options) + cfg, err := config.LoadDefaultConfig(ctx, opts...) + return cfg, trace.Wrap(err) +} diff --git a/lib/cloud/aws/config/config_test.go b/lib/cloud/aws/config/config_test.go new file mode 100644 index 0000000000000..b6a0b867b7965 --- /dev/null +++ b/lib/cloud/aws/config/config_test.go @@ -0,0 +1,104 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package config + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +type mockCredentialProvider struct { + cred aws.Credentials +} + +func (m *mockCredentialProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + return m.cred, nil +} + +func TestGetAWSConfigIntegration(t *testing.T) { + t.Parallel() + dummyIntegration := "integration-test" + dummyRegion := "test-region-123" + + t.Run("without an integration credential provider, must return missing credential provider error", func(t *testing.T) { + ctx := context.Background() + _, err := GetAWSConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration)) + require.True(t, trace.IsBadParameter(err), "unexpected error: %v", err) + require.ErrorContains(t, err, "missing aws integration credential provider") + }) + + t.Run("with an integration credential provider, must return the credentials", func(t *testing.T) { + ctx := context.Background() + + cfg, err := GetAWSConfig(ctx, dummyRegion, + WithCredentialsMaybeIntegration(dummyIntegration), + WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { + if region == dummyRegion && integration == dummyIntegration { + return &mockCredentialProvider{ + cred: aws.Credentials{ + SessionToken: "foo-bar", + }, + }, nil + } + return nil, trace.NotFound("no creds in region %q with integration %q", region, integration) + })) + require.NoError(t, err) + creds, err := cfg.Credentials.Retrieve(ctx) + require.NoError(t, err) + require.Equal(t, "foo-bar", creds.SessionToken) + }) + + t.Run("with an integration credential provider, but using an empty integration falls back to ambient credentials", func(t *testing.T) { + ctx := context.Background() + + _, err := GetAWSConfig(ctx, dummyRegion, + WithCredentialsMaybeIntegration(""), + WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { + require.Fail(t, "this function should not be called") + return nil, nil + })) + require.NoError(t, err) + }) + + t.Run("with an integration credential provider, but using ambient credentials", func(t *testing.T) { + ctx := context.Background() + + _, err := GetAWSConfig(ctx, dummyRegion, + WithAmbientCredentials(), + WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { + require.Fail(t, "this function should not be called") + return nil, nil + })) + require.NoError(t, err) + }) + + t.Run("with an integration credential provider, but no credential source", func(t *testing.T) { + ctx := context.Background() + + _, err := GetAWSConfig(ctx, dummyRegion, + WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { + require.Fail(t, "this function should not be called") + return nil, nil + })) + require.Error(t, err) + require.ErrorContains(t, err, "missing credentials source") + }) +} diff --git a/lib/cloud/aws/errors.go b/lib/cloud/aws/errors.go index f13e1cf36c836..472590d35b1d4 100644 --- a/lib/cloud/aws/errors.go +++ b/lib/cloud/aws/errors.go @@ -36,11 +36,14 @@ import ( // the error without modifying it. func ConvertRequestFailureError(err error) error { var requestErr awserr.RequestFailure - if !errors.As(err, &requestErr) { - return err + if errors.As(err, &requestErr) { + return convertRequestFailureErrorFromStatusCode(requestErr.StatusCode(), requestErr) } - - return convertRequestFailureErrorFromStatusCode(requestErr.StatusCode(), requestErr) + var re *awshttp.ResponseError + if errors.As(err, &re) { + return convertRequestFailureErrorFromStatusCode(re.HTTPStatusCode(), re.Err) + } + return err } func convertRequestFailureErrorFromStatusCode(statusCode int, requestErr error) error { diff --git a/lib/cloud/aws/errors_test.go b/lib/cloud/aws/errors_test.go index 448c2ef9a6e24..165456bfdb25b 100644 --- a/lib/cloud/aws/errors_test.go +++ b/lib/cloud/aws/errors_test.go @@ -73,6 +73,18 @@ func TestConvertRequestFailureError(t *testing.T) { inputError: errors.New("not-aws-error"), wantUnmodified: true, }, + { + name: "v2 sdk error", + inputError: &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: &http.Response{ + StatusCode: http.StatusNotFound, + }}, + Err: trace.Errorf(""), + }, + }, + wantIsError: trace.IsNotFound, + }, } for _, test := range tests { diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index bc21d126b9903..328ee76bcee0e 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -38,8 +38,6 @@ import ( "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" awssession "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/aws/aws-sdk-go/service/eks" "github.com/aws/aws-sdk-go/service/eks/eksiface" "github.com/aws/aws-sdk-go/service/elasticache" @@ -135,8 +133,6 @@ type AWSClients interface { GetAWSIAMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (iamiface.IAMAPI, error) // GetAWSSTSClient returns AWS STS client for the specified region. GetAWSSTSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (stsiface.STSAPI, error) - // GetAWSEC2Client returns AWS EC2 client for the specified region. - GetAWSEC2Client(ctx context.Context, region string, opts ...AWSOptionsFn) (ec2iface.EC2API, error) // GetAWSSSMClient returns AWS SSM client for the specified region. GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error) // GetAWSEKSClient returns AWS EKS client for the specified region. @@ -372,7 +368,7 @@ type credentialsSource int const ( // credentialsSourceAmbient uses the default Cloud SDK method to load the credentials. credentialsSourceAmbient = iota + 1 - // CredentialsSourceIntegration uses an Integration to load the credentials. + // credentialsSourceIntegration uses an Integration to load the credentials. credentialsSourceIntegration ) @@ -596,15 +592,6 @@ func (c *cloudClients) GetAWSSTSClient(ctx context.Context, region string, opts return sts.New(session), nil } -// GetAWSEC2Client returns AWS EC2 client for the specified region. -func (c *cloudClients) GetAWSEC2Client(ctx context.Context, region string, opts ...AWSOptionsFn) (ec2iface.EC2API, error) { - session, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return ec2.New(session), nil -} - // GetAWSSSMClient returns AWS SSM client for the specified region. func (c *cloudClients) GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error) { session, err := c.GetAWSSession(ctx, region, opts...) @@ -1034,7 +1021,6 @@ type TestCloudClients struct { GCPGKE gcp.GKEClient GCPProjects gcp.ProjectsClient GCPInstances gcp.InstancesClient - EC2 ec2iface.EC2API SSM ssmiface.SSMAPI InstanceMetadata imds.Client EKS eksiface.EKSAPI @@ -1205,15 +1191,6 @@ func (c *TestCloudClients) GetAWSKMSClient(ctx context.Context, region string, o return c.KMS, nil } -// GetAWSEC2Client returns AWS EC2 client for the specified region. -func (c *TestCloudClients) GetAWSEC2Client(ctx context.Context, region string, opts ...AWSOptionsFn) (ec2iface.EC2API, error) { - _, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return c.EC2, nil -} - // GetAWSSSMClient returns an AWS SSM client func (c *TestCloudClients) GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error) { _, err := c.GetAWSSession(ctx, region, opts...) diff --git a/lib/srv/discovery/access_graph.go b/lib/srv/discovery/access_graph.go index 7d65f99f88ef5..4bc207b21df01 100644 --- a/lib/srv/discovery/access_graph.go +++ b/lib/srv/discovery/access_graph.go @@ -502,6 +502,7 @@ func (s *Server) accessGraphFetchersFromMatchers(ctx context.Context, matchers M ctx, aws_sync.Config{ CloudClients: s.CloudClients, + GetEC2Client: s.GetEC2Client, AssumeRole: assumeRole, Regions: awsFetcher.Regions, Integration: awsFetcher.Integration, diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index efaa28ff71d47..d7ea81a940368 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -30,6 +30,8 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3" + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/session" @@ -52,6 +54,7 @@ import ( "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/aws/config" gcpimds "github.com/gravitational/teleport/lib/cloud/imds/gcp" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/integrations/awsoidc" @@ -111,6 +114,8 @@ type gcpInstaller interface { type Config struct { // CloudClients is an interface for retrieving cloud clients. CloudClients cloud.Clients + // GetEC2Client gets an AWS EC2 client for the given region. + GetEC2Client server.EC2ClientGetter // IntegrationOnlyCredentials discards any Matcher that don't have an Integration. // When true, ambient credentials (used by the Cloud SDKs) are not used. IntegrationOnlyCredentials bool @@ -214,6 +219,34 @@ kubernetes matchers are present.`) } c.CloudClients = cloudClients } + if c.GetEC2Client == nil { + c.GetEC2Client = func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) { + opts = append(opts, config.WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integrationName string) (awsv2.CredentialsProvider, error) { + integration, err := c.AccessPoint.GetIntegration(ctx, integrationName) + if err != nil { + return nil, trace.Wrap(err) + } + if integration.GetAWSOIDCIntegrationSpec() == nil { + return nil, trace.BadParameter("integration does not have aws oidc spec fields %q", integrationName) + } + token, err := c.AccessPoint.GenerateAWSOIDCToken(ctx, integrationName) + if err != nil { + return nil, trace.Wrap(err) + } + cred, err := awsoidc.NewAWSCredentialsProvider(ctx, &awsoidc.AWSClientRequest{ + Token: token, + RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN, + Region: region, + }) + return cred, trace.Wrap(err) + })) + cfg, err := config.GetAWSConfig(ctx, region, opts...) + if err != nil { + return nil, trace.Wrap(err) + } + return ec2.NewFromConfig(cfg), nil + } + } if c.KubernetesClient == nil && len(c.Matchers.Kubernetes) > 0 { cfg, err := rest.InClusterConfig() if err != nil { @@ -465,7 +498,7 @@ func (s *Server) initAWSWatchers(matchers []types.AWSMatcher) error { }) const noDiscoveryConfig = "" - s.staticServerAWSFetchers, err = server.MatchersToEC2InstanceFetchers(s.ctx, ec2Matchers, s.CloudClients, noDiscoveryConfig) + s.staticServerAWSFetchers, err = server.MatchersToEC2InstanceFetchers(s.ctx, ec2Matchers, s.GetEC2Client, noDiscoveryConfig) if err != nil { return trace.Wrap(err) } @@ -555,7 +588,7 @@ func (s *Server) awsServerFetchersFromMatchers(ctx context.Context, matchers []t return matcherType == types.AWSMatcherEC2 }) - fetchers, err := server.MatchersToEC2InstanceFetchers(ctx, serverMatchers, s.CloudClients, discoveryConfig) + fetchers, err := server.MatchersToEC2InstanceFetchers(ctx, serverMatchers, s.GetEC2Client, discoveryConfig) if err != nil { return nil, trace.Wrap(err) } @@ -910,7 +943,7 @@ func (s *Server) heartbeatEICEInstance(instances *server.EC2Instances) { nodesToUpsert := make([]types.Server, 0, len(instances.Instances)) // Add EC2 Instances using EICE method for _, ec2Instance := range instances.Instances { - eiceNode, err := common.NewAWSNodeFromEC2v1Instance(ec2Instance.OriginalInstance, awsInfo) + eiceNode, err := common.NewAWSNodeFromEC2Instance(ec2Instance.OriginalInstance, awsInfo) if err != nil { s.Log.WarnContext(s.ctx, "Error converting to Teleport EICE Node", "error", err, "instance_id", ec2Instance.InstanceID) diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index a4e379f63933e..d176c86ca6ba2 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -36,11 +36,12 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2" + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/aws/aws-sdk-go/service/eks" "github.com/aws/aws-sdk-go/service/eks/eksiface" "github.com/aws/aws-sdk-go/service/rds" @@ -77,6 +78,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/aws/config" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/cloud/gcp" gcpimds "github.com/gravitational/teleport/lib/cloud/imds/gcp" @@ -161,16 +163,11 @@ func (m *mockUsageReporter) DiscoveryFetchEventCount() int { } type mockEC2Client struct { - ec2iface.EC2API output *ec2.DescribeInstancesOutput } -func (m *mockEC2Client) DescribeInstancesPagesWithContext( - ctx context.Context, input *ec2.DescribeInstancesInput, - f func(dio *ec2.DescribeInstancesOutput, b bool) bool, opts ...request.Option, -) error { - f(m.output, true) - return nil +func (m *mockEC2Client) DescribeInstances(ctx context.Context, input *ec2.DescribeInstancesInput, opts ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + return m.output, nil } func genEC2InstanceIDs(n int) []string { @@ -181,17 +178,17 @@ func genEC2InstanceIDs(n int) []string { return ec2InstanceIDs } -func genEC2Instances(n int) []*ec2.Instance { - var ec2Instances []*ec2.Instance +func genEC2Instances(n int) []ec2types.Instance { + var ec2Instances []ec2types.Instance for _, id := range genEC2InstanceIDs(n) { - ec2Instances = append(ec2Instances, &ec2.Instance{ - InstanceId: aws.String(id), - Tags: []*ec2.Tag{{ - Key: aws.String("env"), - Value: aws.String("dev"), + ec2Instances = append(ec2Instances, ec2types.Instance{ + InstanceId: awsv2.String(id), + Tags: []ec2types.Tag{{ + Key: awsv2.String("env"), + Value: awsv2.String("dev"), }}, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, }) } @@ -301,7 +298,7 @@ func TestDiscoveryServer(t *testing.T) { name string // presentInstances is a list of servers already present in teleport presentInstances []types.Server - foundEC2Instances []*ec2.Instance + foundEC2Instances []ec2types.Instance ssm *mockSSMClient emitter *mockEmitter discoveryConfig *discoveryconfig.DiscoveryConfig @@ -314,15 +311,15 @@ func TestDiscoveryServer(t *testing.T) { { name: "no nodes present, 1 found ", presentInstances: []types.Server{}, - foundEC2Instances: []*ec2.Instance{ + foundEC2Instances: []ec2types.Instance{ { - InstanceId: aws.String("instance-id-1"), - Tags: []*ec2.Tag{{ - Key: aws.String("env"), - Value: aws.String("dev"), + InstanceId: awsv2.String("instance-id-1"), + Tags: []ec2types.Tag{{ + Key: awsv2.String("env"), + Value: awsv2.String("dev"), }}, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, }, }, @@ -372,15 +369,15 @@ func TestDiscoveryServer(t *testing.T) { }, }, }, - foundEC2Instances: []*ec2.Instance{ + foundEC2Instances: []ec2types.Instance{ { - InstanceId: aws.String("instance-id-1"), - Tags: []*ec2.Tag{{ - Key: aws.String("env"), - Value: aws.String("dev"), + InstanceId: awsv2.String("instance-id-1"), + Tags: []ec2types.Tag{{ + Key: awsv2.String("env"), + Value: awsv2.String("dev"), }}, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, }, }, @@ -413,15 +410,15 @@ func TestDiscoveryServer(t *testing.T) { }, }, }, - foundEC2Instances: []*ec2.Instance{ + foundEC2Instances: []ec2types.Instance{ { - InstanceId: aws.String("instance-id-1"), - Tags: []*ec2.Tag{{ - Key: aws.String("env"), - Value: aws.String("dev"), + InstanceId: awsv2.String("instance-id-1"), + Tags: []ec2types.Tag{{ + Key: awsv2.String("env"), + Value: awsv2.String("dev"), }}, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, }, }, @@ -462,15 +459,15 @@ func TestDiscoveryServer(t *testing.T) { { name: "no nodes present, 1 found using dynamic matchers", presentInstances: []types.Server{}, - foundEC2Instances: []*ec2.Instance{ + foundEC2Instances: []ec2types.Instance{ { - InstanceId: aws.String("instance-id-1"), - Tags: []*ec2.Tag{{ - Key: aws.String("env"), - Value: aws.String("dev"), + InstanceId: awsv2.String("instance-id-1"), + Tags: []ec2types.Tag{{ + Key: awsv2.String("env"), + Value: awsv2.String("dev"), }}, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, }, }, @@ -509,15 +506,15 @@ func TestDiscoveryServer(t *testing.T) { { name: "one node found with Script mode using Integration credentials", presentInstances: []types.Server{}, - foundEC2Instances: []*ec2.Instance{ + foundEC2Instances: []ec2types.Instance{ { - InstanceId: aws.String("instance-id-1"), - Tags: []*ec2.Tag{{ - Key: aws.String("env"), - Value: aws.String("dev"), + InstanceId: awsv2.String("instance-id-1"), + Tags: []ec2types.Tag{{ + Key: awsv2.String("env"), + Value: awsv2.String("dev"), }}, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, }, }, @@ -571,15 +568,15 @@ func TestDiscoveryServer(t *testing.T) { { name: "one node found but SSM Run fails and DiscoverEC2 User Task is created", presentInstances: []types.Server{}, - foundEC2Instances: []*ec2.Instance{ + foundEC2Instances: []ec2types.Instance{ { - InstanceId: aws.String("instance-id-1"), - Tags: []*ec2.Tag{{ - Key: aws.String("env"), - Value: aws.String("dev"), + InstanceId: awsv2.String("instance-id-1"), + Tags: []ec2types.Tag{{ + Key: awsv2.String("env"), + Value: awsv2.String("dev"), }}, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, }, }, @@ -644,18 +641,16 @@ func TestDiscoveryServer(t *testing.T) { t.Parallel() testCloudClients := &cloud.TestCloudClients{ - EC2: &mockEC2Client{ - output: &ec2.DescribeInstancesOutput{ - Reservations: []*ec2.Reservation{ - { - OwnerId: aws.String("owner"), - Instances: tc.foundEC2Instances, - }, - }, - }, - }, SSM: tc.ssm, } + ec2Client := &mockEC2Client{output: &ec2.DescribeInstancesOutput{ + Reservations: []ec2types.Reservation{ + { + OwnerId: awsv2.String("owner"), + Instances: tc.foundEC2Instances, + }, + }, + }} ctx := context.Background() // Create and start test auth server. @@ -696,7 +691,10 @@ func TestDiscoveryServer(t *testing.T) { } server, err := New(authz.ContextWithUser(context.Background(), identity.I), &Config{ - CloudClients: testCloudClients, + CloudClients: testCloudClients, + GetEC2Client: func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) { + return ec2Client, nil + }, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewSimpleClientset(), AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), @@ -785,24 +783,34 @@ func TestDiscoveryServerConcurrency(t *testing.T) { }, } - testCloudClients := &cloud.TestCloudClients{ - EC2: &mockEC2Client{output: &ec2.DescribeInstancesOutput{ - Reservations: []*ec2.Reservation{{ - OwnerId: aws.String("123456789012"), - Instances: []*ec2.Instance{{ - InstanceId: aws.String("i-123456789012"), - Tags: []*ec2.Tag{{ - Key: aws.String("env"), - Value: aws.String("dev"), - }}, - PrivateIpAddress: aws.String("172.0.1.2"), - VpcId: aws.String("vpcId"), - SubnetId: aws.String("subnetId"), - PrivateDnsName: aws.String("privateDnsName"), - State: &ec2.InstanceState{Name: aws.String(ec2.InstanceStateNameRunning)}, - }}, - }}, - }}, + testCloudClients := &cloud.TestCloudClients{} + + ec2Client := &mockEC2Client{ + output: &ec2.DescribeInstancesOutput{ + Reservations: []ec2types.Reservation{ + { + OwnerId: awsv2.String("123456789012"), + Instances: []ec2types.Instance{ + { + InstanceId: awsv2.String("i-123456789012"), + Tags: []ec2types.Tag{ + { + Key: awsv2.String("env"), + Value: awsv2.String("dev"), + }, + }, + PrivateIpAddress: awsv2.String("172.0.1.2"), + VpcId: awsv2.String("vpcId"), + SubnetId: awsv2.String("subnetId"), + PrivateDnsName: awsv2.String("privateDnsName"), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, + }, + }, + }, + }, + }, + }, } // Create and start test auth server. @@ -822,9 +830,14 @@ func TestDiscoveryServerConcurrency(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, authClient.Close()) }) + getEC2Client := func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) { + return ec2Client, nil + } + // Create Server1 server1, err := New(authz.ContextWithUser(ctx, identity.I), &Config{ CloudClients: testCloudClients, + GetEC2Client: getEC2Client, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewSimpleClientset(), AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), @@ -838,6 +851,7 @@ func TestDiscoveryServerConcurrency(t *testing.T) { // Create Server2 server2, err := New(authz.ContextWithUser(ctx, identity.I), &Config{ CloudClients: testCloudClients, + GetEC2Client: getEC2Client, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewSimpleClientset(), AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), diff --git a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go index 4e2c5510def86..d96e44075195e 100644 --- a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go +++ b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go @@ -24,6 +24,8 @@ import ( "sync" "time" + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/service/sts" @@ -33,6 +35,8 @@ import ( usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1" accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/aws/config" + "github.com/gravitational/teleport/lib/srv/server" ) // pageSize is the default page size to use when fetching AWS resources @@ -43,6 +47,8 @@ const pageSize int64 = 500 type Config struct { // CloudClients is the cloud clients to use when fetching AWS resources. CloudClients cloud.Clients + // GetEC2Client gets an AWS EC2 client for the given region. + GetEC2Client server.EC2ClientGetter // AccountID is the AWS account ID to use when fetching resources. AccountID string // Regions is the list of AWS regions to fetch resources from. @@ -318,6 +324,27 @@ func (a *awsFetcher) getAWSOptions() []cloud.AWSOptionsFn { return opts } +// getAWSV2Options returns a list of options to be used when +// creating AWS clients with the v2 sdk. +func (a *awsFetcher) getAWSV2Options() []config.AWSOptionsFn { + opts := []config.AWSOptionsFn{ + config.WithCredentialsMaybeIntegration(a.Config.Integration), + } + + if a.Config.AssumeRole != nil { + opts = append(opts, config.WithAssumeRole(a.Config.AssumeRole.RoleARN, a.Config.AssumeRole.ExternalID)) + } + const maxRetries = 10 + opts = append(opts, config.WithRetryer(func() awsv2.Retryer { + return retry.NewStandard(func(so *retry.StandardOptions) { + so.MaxAttempts = maxRetries + so.Backoff = retry.NewExponentialJitterBackoff(300 * time.Second) + }) + })) + + return opts +} + func (a *awsFetcher) getAccountId(ctx context.Context) (string, error) { stsClient, err := a.CloudClients.GetAWSSTSClient( ctx, diff --git a/lib/srv/discovery/fetchers/aws-sync/ec2.go b/lib/srv/discovery/fetchers/aws-sync/ec2.go index dcbe9b26f7167..7d32603676e24 100644 --- a/lib/srv/discovery/fetchers/aws-sync/ec2.go +++ b/lib/srv/discovery/fetchers/aws-sync/ec2.go @@ -24,7 +24,8 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go/service/iam" "github.com/gravitational/trace" "golang.org/x/sync/errgroup" @@ -32,6 +33,7 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" + libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" ) // pollAWSEC2Instances is a function that returns a function that fetches @@ -86,16 +88,21 @@ func (a *awsFetcher) fetchAWSEC2Instances(ctx context.Context) ([]*accessgraphv1 return h.Region == region && h.AccountId == a.AccountID }, ) - ec2Client, err := a.CloudClients.GetAWSEC2Client(ctx, region, a.getAWSOptions()...) + ec2Client, err := a.GetEC2Client(ctx, region, a.getAWSV2Options()...) if err != nil { collectHosts(prevIterationEc2, trace.Wrap(err)) return nil } ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - err = ec2Client.DescribeInstancesPagesWithContext(ctx, &ec2.DescribeInstancesInput{ - MaxResults: aws.Int64(pageSize), - }, func(page *ec2.DescribeInstancesOutput, lastPage bool) bool { + paginator := ec2.NewDescribeInstancesPaginator(ec2Client, &ec2.DescribeInstancesInput{ + MaxResults: aws.Int32(int32(pageSize)), + }) + for paginator.HasMorePages() { + page, err := paginator.NextPage(ctx) + if err != nil { + return libcloudaws.ConvertRequestFailureError(err) + } lHosts := make([]*accessgraphv1alpha.AWSInstanceV1, 0, len(page.Reservations)) for _, reservation := range page.Reservations { for _, instance := range reservation.Instances { @@ -103,11 +110,6 @@ func (a *awsFetcher) fetchAWSEC2Instances(ctx context.Context) ([]*accessgraphv1 } } collectHosts(lHosts, nil) - return !lastPage - }) - - if err != nil { - collectHosts(prevIterationEc2, trace.Wrap(err)) } return nil }) @@ -119,7 +121,7 @@ func (a *awsFetcher) fetchAWSEC2Instances(ctx context.Context) ([]*accessgraphv1 // awsInstanceToProtoInstance converts an ec2.Instance to accessgraphv1alpha.AWSInstanceV1 // representation. -func awsInstanceToProtoInstance(instance *ec2.Instance, region string, accountID string) *accessgraphv1alpha.AWSInstanceV1 { +func awsInstanceToProtoInstance(instance ec2types.Instance, region string, accountID string) *accessgraphv1alpha.AWSInstanceV1 { var tags []*accessgraphv1alpha.AWSTag for _, tag := range instance.Tags { tags = append(tags, &accessgraphv1alpha.AWSTag{ diff --git a/lib/srv/server/azure_watcher_test.go b/lib/srv/server/azure_watcher_test.go index 6c3989bc91a75..ad507911c6882 100644 --- a/lib/srv/server/azure_watcher_test.go +++ b/lib/srv/server/azure_watcher_test.go @@ -28,9 +28,16 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/azure" ) +type mockClients struct { + cloud.Clients + + azureClient azure.VirtualMachinesClient +} + func (c *mockClients) GetAzureVirtualMachinesClient(subscription string) (azure.VirtualMachinesClient, error) { return c.azureClient, nil } diff --git a/lib/srv/server/ec2_watcher.go b/lib/srv/server/ec2_watcher.go index 25f12018e0170..4c6300e7bf661 100644 --- a/lib/srv/server/ec2_watcher.go +++ b/lib/srv/server/ec2_watcher.go @@ -23,16 +23,16 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gravitational/trace" log "github.com/sirupsen/logrus" usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" - awslib "github.com/gravitational/teleport/lib/cloud/aws" + libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/aws/config" "github.com/gravitational/teleport/lib/labels" ) @@ -81,20 +81,20 @@ type EC2Instance struct { InstanceID string InstanceName string Tags map[string]string - OriginalInstance ec2.Instance + OriginalInstance ec2types.Instance } -func toEC2Instance(originalInst *ec2.Instance) EC2Instance { +func toEC2Instance(originalInst ec2types.Instance) EC2Instance { inst := EC2Instance{ - InstanceID: aws.StringValue(originalInst.InstanceId), + InstanceID: aws.ToString(originalInst.InstanceId), Tags: make(map[string]string, len(originalInst.Tags)), - OriginalInstance: *originalInst, + OriginalInstance: originalInst, } for _, tag := range originalInst.Tags { - if key := aws.StringValue(tag.Key); key != "" { - inst.Tags[key] = aws.StringValue(tag.Value) + if key := aws.ToString(tag.Key); key != "" { + inst.Tags[key] = aws.ToString(tag.Value) if key == "Name" { - inst.InstanceName = aws.StringValue(tag.Value) + inst.InstanceName = aws.ToString(tag.Value) } } } @@ -102,7 +102,7 @@ func toEC2Instance(originalInst *ec2.Instance) EC2Instance { } // ToEC2Instances converts aws []*ec2.Instance to []EC2Instance -func ToEC2Instances(insts []*ec2.Instance) []EC2Instance { +func ToEC2Instances(insts []ec2types.Instance) []EC2Instance { var ec2Insts []EC2Instance for _, inst := range insts { @@ -188,14 +188,17 @@ func NewEC2Watcher(ctx context.Context, fetchersFn func() []Fetcher, missedRotat return &watcher, nil } +// EC2ClientGetter gets an AWS EC2 client for the given region. +type EC2ClientGetter func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) + // MatchersToEC2InstanceFetchers converts a list of AWS EC2 Matchers into a list of AWS EC2 Fetchers. -func MatchersToEC2InstanceFetchers(ctx context.Context, matchers []types.AWSMatcher, clients cloud.Clients, discoveryConfig string) ([]Fetcher, error) { +func MatchersToEC2InstanceFetchers(ctx context.Context, matchers []types.AWSMatcher, getEC2Client EC2ClientGetter, discoveryConfig string) ([]Fetcher, error) { ret := []Fetcher{} for _, matcher := range matchers { for _, region := range matcher.Regions { // TODO(gavin): support assume_role_arn for ec2. - ec2Client, err := clients.GetAWSEC2Client(ctx, region, - cloud.WithCredentialsMaybeIntegration(matcher.Integration), + ec2Client, err := getEC2Client(ctx, region, + config.WithCredentialsMaybeIntegration(matcher.Integration), ) if err != nil { return nil, trace.Wrap(err) @@ -221,7 +224,7 @@ type ec2FetcherConfig struct { Matcher types.AWSMatcher Region string Document string - EC2Client ec2iface.EC2API + EC2Client ec2.DescribeInstancesAPIClient Labels types.Labels Integration string DiscoveryConfig string @@ -229,8 +232,8 @@ type ec2FetcherConfig struct { } type ec2InstanceFetcher struct { - Filters []*ec2.Filter - EC2 ec2iface.EC2API + Filters []ec2types.Filter + EC2 ec2.DescribeInstancesAPIClient Region string DocumentName string Parameters map[string]string @@ -289,16 +292,16 @@ const ( const awsEC2APIChunkSize = 50 func newEC2InstanceFetcher(cfg ec2FetcherConfig) *ec2InstanceFetcher { - tagFilters := []*ec2.Filter{{ + tagFilters := []ec2types.Filter{{ Name: aws.String(AWSInstanceStateName), - Values: aws.StringSlice([]string{ec2.InstanceStateNameRunning}), + Values: []string{string(ec2types.InstanceStateNameRunning)}, }} if _, ok := cfg.Labels["*"]; !ok { for key, val := range cfg.Labels { - tagFilters = append(tagFilters, &ec2.Filter{ + tagFilters = append(tagFilters, ec2types.Filter{ Name: aws.String("tag:" + key), - Values: aws.StringSlice(val), + Values: val, }) } } else { @@ -411,38 +414,40 @@ func chunkInstances(insts EC2Instances) []Instances { func (f *ec2InstanceFetcher) GetInstances(ctx context.Context, rotation bool) ([]Instances, error) { var instances []Instances f.cachedInstances.clear() - err := f.EC2.DescribeInstancesPagesWithContext(ctx, &ec2.DescribeInstancesInput{ + paginator := ec2.NewDescribeInstancesPaginator(f.EC2, &ec2.DescribeInstancesInput{ Filters: f.Filters, - }, - func(dio *ec2.DescribeInstancesOutput, b bool) bool { - for _, res := range dio.Reservations { - for i := 0; i < len(res.Instances); i += awsEC2APIChunkSize { - end := i + awsEC2APIChunkSize - if end > len(res.Instances) { - end = len(res.Instances) - } - ownerID := aws.StringValue(res.OwnerId) - inst := EC2Instances{ - AccountID: ownerID, - Region: f.Region, - DocumentName: f.DocumentName, - Instances: ToEC2Instances(res.Instances[i:end]), - Parameters: f.Parameters, - Rotation: rotation, - Integration: f.Integration, - DiscoveryConfig: f.DiscoveryConfig, - EnrollMode: f.EnrollMode, - } - for _, ec2inst := range res.Instances[i:end] { - f.cachedInstances.add(ownerID, aws.StringValue(ec2inst.InstanceId)) - } - instances = append(instances, Instances{EC2: &inst}) + }) + + for paginator.HasMorePages() { + page, err := paginator.NextPage(ctx) + if err != nil { + return nil, libcloudaws.ConvertRequestFailureError(err) + } + + for _, res := range page.Reservations { + for i := 0; i < len(res.Instances); i += awsEC2APIChunkSize { + end := i + awsEC2APIChunkSize + if end > len(res.Instances) { + end = len(res.Instances) } + ownerID := aws.ToString(res.OwnerId) + inst := EC2Instances{ + AccountID: ownerID, + Region: f.Region, + DocumentName: f.DocumentName, + Instances: ToEC2Instances(res.Instances[i:end]), + Parameters: f.Parameters, + Rotation: rotation, + Integration: f.Integration, + DiscoveryConfig: f.DiscoveryConfig, + EnrollMode: f.EnrollMode, + } + for _, ec2inst := range res.Instances[i:end] { + f.cachedInstances.add(ownerID, aws.ToString(ec2inst.InstanceId)) + } + instances = append(instances, Instances{EC2: &inst}) } - return true - }) - if err != nil { - return nil, awslib.ConvertRequestFailureError(err) + } } if len(instances) == 0 { diff --git a/lib/srv/server/ec2_watcher_test.go b/lib/srv/server/ec2_watcher_test.go index 8279cbbd868f1..f7c9c0a85458d 100644 --- a/lib/srv/server/ec2_watcher_test.go +++ b/lib/srv/server/ec2_watcher_test.go @@ -22,81 +22,64 @@ import ( "context" "testing" + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils" - "github.com/gravitational/teleport/lib/cloud" - "github.com/gravitational/teleport/lib/cloud/azure" + "github.com/gravitational/teleport/lib/cloud/aws/config" ) -type mockClients struct { - cloud.Clients - - ec2Client *mockEC2Client - azureClient azure.VirtualMachinesClient -} - -func (c *mockClients) GetAWSEC2Client(ctx context.Context, region string, _ ...cloud.AWSOptionsFn) (ec2iface.EC2API, error) { - return c.ec2Client, nil -} - type mockEC2Client struct { - ec2iface.EC2API output *ec2.DescribeInstancesOutput } -func instanceMatches(inst *ec2.Instance, filters []*ec2.Filter) bool { +func (m *mockEC2Client) DescribeInstances(ctx context.Context, input *ec2.DescribeInstancesInput, opts ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + var output ec2.DescribeInstancesOutput + for _, res := range m.output.Reservations { + var instances []ec2types.Instance + for _, inst := range res.Instances { + if instanceMatches(inst, input.Filters) { + instances = append(instances, inst) + } + } + output.Reservations = append(output.Reservations, ec2types.Reservation{ + Instances: instances, + }) + } + return &output, nil +} + +func instanceMatches(inst ec2types.Instance, filters []ec2types.Filter) bool { allMatched := true for _, filter := range filters { - name := aws.StringValue(filter.Name) - val := aws.StringValue(filter.Values[0]) - if name == AWSInstanceStateName && aws.StringValue(inst.State.Name) != ec2.InstanceStateNameRunning { + name := awsv2.ToString(filter.Name) + val := filter.Values[0] + if name == AWSInstanceStateName && inst.State.Name != ec2types.InstanceStateNameRunning { return false } for _, tag := range inst.Tags { - if aws.StringValue(tag.Key) != name[4:] { + if awsv2.ToString(tag.Key) != name[4:] { continue } - allMatched = allMatched && aws.StringValue(tag.Value) != val + allMatched = allMatched && awsv2.ToString(tag.Value) != val } } return !allMatched } -func (m *mockEC2Client) DescribeInstancesPagesWithContext( - ctx context.Context, input *ec2.DescribeInstancesInput, - f func(dio *ec2.DescribeInstancesOutput, b bool) bool, opts ...request.Option) error { - output := &ec2.DescribeInstancesOutput{} - for _, res := range m.output.Reservations { - var instances []*ec2.Instance - for _, inst := range res.Instances { - if instanceMatches(inst, input.Filters) { - instances = append(instances, inst) - } - } - output.Reservations = append(output.Reservations, &ec2.Reservation{ - Instances: instances, - }) - } - - f(output, true) - return nil -} - func TestNewEC2InstanceFetcherTags(t *testing.T) { t.Parallel() for _, tc := range []struct { name string config ec2FetcherConfig - expectedFilters []*ec2.Filter + expectedFilters []ec2types.Filter }{ { name: "with glob key", @@ -106,10 +89,10 @@ func TestNewEC2InstanceFetcherTags(t *testing.T) { "hello": []string{"other"}, }, }, - expectedFilters: []*ec2.Filter{ + expectedFilters: []ec2types.Filter{ { - Name: aws.String(AWSInstanceStateName), - Values: aws.StringSlice([]string{ec2.InstanceStateNameRunning}), + Name: awsv2.String(AWSInstanceStateName), + Values: []string{string(ec2types.InstanceStateNameRunning)}, }, }, }, @@ -120,14 +103,14 @@ func TestNewEC2InstanceFetcherTags(t *testing.T) { "hello": []string{"other"}, }, }, - expectedFilters: []*ec2.Filter{ + expectedFilters: []ec2types.Filter{ { - Name: aws.String(AWSInstanceStateName), - Values: aws.StringSlice([]string{ec2.InstanceStateNameRunning}), + Name: awsv2.String(AWSInstanceStateName), + Values: []string{string(ec2types.InstanceStateNameRunning)}, }, { - Name: aws.String("tag:hello"), - Values: aws.StringSlice([]string{"other"}), + Name: awsv2.String("tag:hello"), + Values: []string{"other"}, }, }, }, @@ -141,9 +124,7 @@ func TestNewEC2InstanceFetcherTags(t *testing.T) { func TestEC2Watcher(t *testing.T) { t.Parallel() - clients := mockClients{ - ec2Client: &mockEC2Client{}, - } + client := &mockEC2Client{} matchers := []types.AWSMatcher{ { Params: &types.InstallerParams{ @@ -174,80 +155,82 @@ func TestEC2Watcher(t *testing.T) { } ctx := context.Background() - present := ec2.Instance{ - InstanceId: aws.String("instance-present"), - Tags: []*ec2.Tag{ + present := ec2types.Instance{ + InstanceId: awsv2.String("instance-present"), + Tags: []ec2types.Tag{ { - Key: aws.String("teleport"), - Value: aws.String("yes"), + Key: awsv2.String("teleport"), + Value: awsv2.String("yes"), }, { - Key: aws.String("Name"), - Value: aws.String("Present"), + Key: awsv2.String("Name"), + Value: awsv2.String("Present"), }, }, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, } - presentOther := ec2.Instance{ - InstanceId: aws.String("instance-present-2"), - Tags: []*ec2.Tag{{ - Key: aws.String("env"), - Value: aws.String("dev"), + presentOther := ec2types.Instance{ + InstanceId: awsv2.String("instance-present-2"), + Tags: []ec2types.Tag{{ + Key: awsv2.String("env"), + Value: awsv2.String("dev"), }}, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, } - presentForEICE := ec2.Instance{ - InstanceId: aws.String("instance-present-3"), - Tags: []*ec2.Tag{{ - Key: aws.String("with-eice"), - Value: aws.String("please"), + presentForEICE := ec2types.Instance{ + InstanceId: awsv2.String("instance-present-3"), + Tags: []ec2types.Tag{{ + Key: awsv2.String("with-eice"), + Value: awsv2.String("please"), }}, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, } output := ec2.DescribeInstancesOutput{ - Reservations: []*ec2.Reservation{{ - Instances: []*ec2.Instance{ - &present, - &presentOther, - &presentForEICE, + Reservations: []ec2types.Reservation{{ + Instances: []ec2types.Instance{ + present, + presentOther, + presentForEICE, { - InstanceId: aws.String("instance-absent"), - Tags: []*ec2.Tag{{ - Key: aws.String("env"), - Value: aws.String("prod"), + InstanceId: awsv2.String("instance-absent"), + Tags: []ec2types.Tag{{ + Key: awsv2.String("env"), + Value: awsv2.String("prod"), }}, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, }, { - InstanceId: aws.String("instance-absent-3"), - Tags: []*ec2.Tag{{ - Key: aws.String("env"), - Value: aws.String("prod"), + InstanceId: awsv2.String("instance-absent-3"), + Tags: []ec2types.Tag{{ + Key: awsv2.String("env"), + Value: awsv2.String("prod"), }, { - Key: aws.String("teleport"), - Value: aws.String("yes"), + Key: awsv2.String("teleport"), + Value: awsv2.String("yes"), }}, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNamePending), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNamePending, }, }, }, }}, } - clients.ec2Client.output = &output + client.output = &output const noDiscoveryConfig = "" fetchersFn := func() []Fetcher { - fetchers, err := MatchersToEC2InstanceFetchers(ctx, matchers, &clients, noDiscoveryConfig) + fetchers, err := MatchersToEC2InstanceFetchers(ctx, matchers, func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) { + return client, nil + }, noDiscoveryConfig) require.NoError(t, err) return fetchers @@ -260,19 +243,19 @@ func TestEC2Watcher(t *testing.T) { result := <-watcher.InstancesC require.Equal(t, EC2Instances{ Region: "us-west-2", - Instances: []EC2Instance{toEC2Instance(&present)}, + Instances: []EC2Instance{toEC2Instance(present)}, Parameters: map[string]string{"token": "", "scriptName": ""}, }, *result.EC2) result = <-watcher.InstancesC require.Equal(t, EC2Instances{ Region: "us-west-2", - Instances: []EC2Instance{toEC2Instance(&presentOther)}, + Instances: []EC2Instance{toEC2Instance(presentOther)}, Parameters: map[string]string{"token": "", "scriptName": ""}, }, *result.EC2) result = <-watcher.InstancesC require.Equal(t, EC2Instances{ Region: "us-west-2", - Instances: []EC2Instance{toEC2Instance(&presentForEICE)}, + Instances: []EC2Instance{toEC2Instance(presentForEICE)}, Parameters: map[string]string{"token": "", "scriptName": "", "sshdConfigPath": ""}, Integration: "my-aws-integration", }, *result.EC2) @@ -368,9 +351,9 @@ func TestMakeEvents(t *testing.T) { } func TestToEC2Instances(t *testing.T) { - sampleInstance := &ec2.Instance{ + sampleInstance := ec2types.Instance{ InstanceId: aws.String("instance-001"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String("teleport"), Value: aws.String("yes"), @@ -380,32 +363,32 @@ func TestToEC2Instances(t *testing.T) { Value: aws.String("MyInstanceName"), }, }, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, } - sampleInstanceWithoutName := &ec2.Instance{ + sampleInstanceWithoutName := ec2types.Instance{ InstanceId: aws.String("instance-001"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String("teleport"), Value: aws.String("yes"), }, }, - State: &ec2.InstanceState{ - Name: aws.String(ec2.InstanceStateNameRunning), + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, }, } for _, tt := range []struct { name string - input []*ec2.Instance + input []ec2types.Instance expected []EC2Instance }{ { name: "with name", - input: []*ec2.Instance{sampleInstance}, + input: []ec2types.Instance{sampleInstance}, expected: []EC2Instance{{ InstanceID: "instance-001", Tags: map[string]string{ @@ -413,18 +396,18 @@ func TestToEC2Instances(t *testing.T) { "teleport": "yes", }, InstanceName: "MyInstanceName", - OriginalInstance: *sampleInstance, + OriginalInstance: sampleInstance, }}, }, { name: "without name", - input: []*ec2.Instance{sampleInstanceWithoutName}, + input: []ec2types.Instance{sampleInstanceWithoutName}, expected: []EC2Instance{{ InstanceID: "instance-001", Tags: map[string]string{ "teleport": "yes", }, - OriginalInstance: *sampleInstanceWithoutName, + OriginalInstance: sampleInstanceWithoutName, }}, }, } {