diff --git a/lib/cloud/aws/config/config.go b/lib/cloud/aws/config/config.go
new file mode 100644
index 0000000000000..815ebea3d0230
--- /dev/null
+++ b/lib/cloud/aws/config/config.go
@@ -0,0 +1,237 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package config
+
+import (
+ "context"
+ "log/slog"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/config"
+ "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/lib/modules"
+)
+
+const defaultRegion = "us-east-1"
+
+// credentialsSource defines where the credentials must come from.
+type credentialsSource int
+
+const (
+ // credentialsSourceAmbient uses the default Cloud SDK method to load the credentials.
+ credentialsSourceAmbient = iota + 1
+ // credentialsSourceIntegration uses an Integration to load the credentials.
+ credentialsSourceIntegration
+)
+
+// AWSIntegrationSessionProvider defines a function that creates a credential provider from a region and an integration.
+// This is used to generate aws configs for clients that must use an integration instead of ambient credentials.
+type AWSIntegrationCredentialProvider func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error)
+
+// awsOptions is a struct of additional options for assuming an AWS role
+// when construction an underlying AWS config.
+type awsOptions struct {
+ // baseConfigis a config to use instead of the default config for an
+ // AWS region, which is used to enable role chaining.
+ baseConfig *aws.Config
+ // assumeRoleARN is the AWS IAM Role ARN to assume.
+ assumeRoleARN string
+ // assumeRoleExternalID is used to assume an external AWS IAM Role.
+ assumeRoleExternalID string
+ // credentialsSource describes which source to use to fetch credentials.
+ credentialsSource credentialsSource
+ // integration is the name of the integration to be used to fetch the credentials.
+ integration string
+ // awsIntegrationCredentialsProvider is the integration credential provider to use.
+ awsIntegrationCredentialsProvider AWSIntegrationCredentialProvider
+ // customRetryer is a custom retryer to use for the config.
+ customRetryer func() aws.Retryer
+ // maxRetries is the maximum number of retries to use for the config.
+ maxRetries *int
+}
+
+func (a *awsOptions) checkAndSetDefaults() error {
+ switch a.credentialsSource {
+ case credentialsSourceAmbient:
+ if a.integration != "" {
+ return trace.BadParameter("integration and ambient credentials cannot be used at the same time")
+ }
+ case credentialsSourceIntegration:
+ if a.integration == "" {
+ return trace.BadParameter("missing integration name")
+ }
+ default:
+ return trace.BadParameter("missing credentials source (ambient or integration)")
+ }
+
+ return nil
+}
+
+// AWSOptionsFn is an option function for setting additional options
+// when getting an AWS config.
+type AWSOptionsFn func(*awsOptions)
+
+// WithAssumeRole configures options needed for assuming an AWS role.
+func WithAssumeRole(roleARN, externalID string) AWSOptionsFn {
+ return func(options *awsOptions) {
+ options.assumeRoleARN = roleARN
+ options.assumeRoleExternalID = externalID
+ }
+}
+
+// WithRetryer sets a custom retryer for the config.
+func WithRetryer(retryer func() aws.Retryer) AWSOptionsFn {
+ return func(options *awsOptions) {
+ options.customRetryer = retryer
+ }
+}
+
+// WithMaxRetries sets the maximum allowed value for the sdk to keep retrying.
+func WithMaxRetries(maxRetries int) AWSOptionsFn {
+ return func(options *awsOptions) {
+ options.maxRetries = &maxRetries
+ }
+}
+
+// WithCredentialsMaybeIntegration sets the credential source to be
+// - ambient if the integration is an empty string
+// - integration, otherwise
+func WithCredentialsMaybeIntegration(integration string) AWSOptionsFn {
+ if integration != "" {
+ return withIntegrationCredentials(integration)
+ }
+
+ return WithAmbientCredentials()
+}
+
+// withIntegrationCredentials configures options with an Integration that must be used to fetch Credentials to assume a role.
+// This prevents the usage of AWS environment credentials.
+func withIntegrationCredentials(integration string) AWSOptionsFn {
+ return func(options *awsOptions) {
+ options.credentialsSource = credentialsSourceIntegration
+ options.integration = integration
+ }
+}
+
+// WithAmbientCredentials configures options to use the ambient credentials.
+func WithAmbientCredentials() AWSOptionsFn {
+ return func(options *awsOptions) {
+ options.credentialsSource = credentialsSourceAmbient
+ }
+}
+
+// WithAWSIntegrationCredentialProvider sets the integration credential provider.
+func WithAWSIntegrationCredentialProvider(cred AWSIntegrationCredentialProvider) AWSOptionsFn {
+ return func(options *awsOptions) {
+ options.awsIntegrationCredentialsProvider = cred
+ }
+}
+
+// GetAWSConfig returns an AWS config for the specified region, optionally
+// assuming AWS IAM Roles.
+func GetAWSConfig(ctx context.Context, region string, opts ...AWSOptionsFn) (aws.Config, error) {
+ var options awsOptions
+ for _, opt := range opts {
+ opt(&options)
+ }
+ if options.baseConfig == nil {
+ cfg, err := getAWSConfigForRegion(ctx, region, options)
+ if err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+ options.baseConfig = &cfg
+ }
+ if options.assumeRoleARN == "" {
+ return *options.baseConfig, nil
+ }
+ return getAWSConfigForRole(ctx, region, options)
+}
+
+// awsAmbientConfigProvider loads a new config using the environment variables.
+func awsAmbientConfigProvider(region string, cred aws.CredentialsProvider, options awsOptions) (aws.Config, error) {
+ opts := buildAWSConfigOptions(region, cred, options)
+ cfg, err := config.LoadDefaultConfig(context.Background(), opts...)
+ return cfg, trace.Wrap(err)
+}
+
+func buildAWSConfigOptions(region string, cred aws.CredentialsProvider, options awsOptions) []func(*config.LoadOptions) error {
+ opts := []func(*config.LoadOptions) error{
+ config.WithDefaultRegion(defaultRegion),
+ config.WithRegion(region),
+ config.WithCredentialsProvider(cred),
+ }
+ if modules.GetModules().IsBoringBinary() {
+ opts = append(opts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
+ }
+ if options.customRetryer != nil {
+ opts = append(opts, config.WithRetryer(options.customRetryer))
+ }
+ if options.maxRetries != nil {
+ opts = append(opts, config.WithRetryMaxAttempts(*options.maxRetries))
+ }
+ return opts
+}
+
+// getAWSConfigForRegion returns AWS config for the specified region.
+func getAWSConfigForRegion(ctx context.Context, region string, options awsOptions) (aws.Config, error) {
+ if err := options.checkAndSetDefaults(); err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+
+ var cred aws.CredentialsProvider
+ if options.credentialsSource == credentialsSourceIntegration {
+ if options.awsIntegrationCredentialsProvider == nil {
+ return aws.Config{}, trace.BadParameter("missing aws integration credential provider")
+ }
+
+ slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", options.integration)
+ var err error
+ cred, err = options.awsIntegrationCredentialsProvider(ctx, region, options.integration)
+ if err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+ } else {
+ slog.DebugContext(ctx, "Initializing AWS config from environment", "region", region)
+ }
+
+ cfg, err := awsAmbientConfigProvider(region, cred, options)
+ return cfg, trace.Wrap(err)
+}
+
+// getAWSConfigForRole returns an AWS config for the specified region and role.
+func getAWSConfigForRole(ctx context.Context, region string, options awsOptions) (aws.Config, error) {
+ if err := options.checkAndSetDefaults(); err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+
+ stsClient := sts.NewFromConfig(*options.baseConfig)
+ cred := stscreds.NewAssumeRoleProvider(stsClient, options.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) {
+ if options.assumeRoleExternalID != "" {
+ aro.ExternalID = aws.String(options.assumeRoleExternalID)
+ }
+ })
+ if _, err := cred.Retrieve(ctx); err != nil {
+ return aws.Config{}, trace.Wrap(err)
+ }
+
+ opts := buildAWSConfigOptions(region, cred, options)
+ cfg, err := config.LoadDefaultConfig(ctx, opts...)
+ return cfg, trace.Wrap(err)
+}
diff --git a/lib/cloud/aws/config/config_test.go b/lib/cloud/aws/config/config_test.go
new file mode 100644
index 0000000000000..b6a0b867b7965
--- /dev/null
+++ b/lib/cloud/aws/config/config_test.go
@@ -0,0 +1,104 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package config
+
+import (
+ "context"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+)
+
+type mockCredentialProvider struct {
+ cred aws.Credentials
+}
+
+func (m *mockCredentialProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
+ return m.cred, nil
+}
+
+func TestGetAWSConfigIntegration(t *testing.T) {
+ t.Parallel()
+ dummyIntegration := "integration-test"
+ dummyRegion := "test-region-123"
+
+ t.Run("without an integration credential provider, must return missing credential provider error", func(t *testing.T) {
+ ctx := context.Background()
+ _, err := GetAWSConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration))
+ require.True(t, trace.IsBadParameter(err), "unexpected error: %v", err)
+ require.ErrorContains(t, err, "missing aws integration credential provider")
+ })
+
+ t.Run("with an integration credential provider, must return the credentials", func(t *testing.T) {
+ ctx := context.Background()
+
+ cfg, err := GetAWSConfig(ctx, dummyRegion,
+ WithCredentialsMaybeIntegration(dummyIntegration),
+ WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
+ if region == dummyRegion && integration == dummyIntegration {
+ return &mockCredentialProvider{
+ cred: aws.Credentials{
+ SessionToken: "foo-bar",
+ },
+ }, nil
+ }
+ return nil, trace.NotFound("no creds in region %q with integration %q", region, integration)
+ }))
+ require.NoError(t, err)
+ creds, err := cfg.Credentials.Retrieve(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "foo-bar", creds.SessionToken)
+ })
+
+ t.Run("with an integration credential provider, but using an empty integration falls back to ambient credentials", func(t *testing.T) {
+ ctx := context.Background()
+
+ _, err := GetAWSConfig(ctx, dummyRegion,
+ WithCredentialsMaybeIntegration(""),
+ WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
+ require.Fail(t, "this function should not be called")
+ return nil, nil
+ }))
+ require.NoError(t, err)
+ })
+
+ t.Run("with an integration credential provider, but using ambient credentials", func(t *testing.T) {
+ ctx := context.Background()
+
+ _, err := GetAWSConfig(ctx, dummyRegion,
+ WithAmbientCredentials(),
+ WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
+ require.Fail(t, "this function should not be called")
+ return nil, nil
+ }))
+ require.NoError(t, err)
+ })
+
+ t.Run("with an integration credential provider, but no credential source", func(t *testing.T) {
+ ctx := context.Background()
+
+ _, err := GetAWSConfig(ctx, dummyRegion,
+ WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
+ require.Fail(t, "this function should not be called")
+ return nil, nil
+ }))
+ require.Error(t, err)
+ require.ErrorContains(t, err, "missing credentials source")
+ })
+}
diff --git a/lib/cloud/aws/errors.go b/lib/cloud/aws/errors.go
index f13e1cf36c836..472590d35b1d4 100644
--- a/lib/cloud/aws/errors.go
+++ b/lib/cloud/aws/errors.go
@@ -36,11 +36,14 @@ import (
// the error without modifying it.
func ConvertRequestFailureError(err error) error {
var requestErr awserr.RequestFailure
- if !errors.As(err, &requestErr) {
- return err
+ if errors.As(err, &requestErr) {
+ return convertRequestFailureErrorFromStatusCode(requestErr.StatusCode(), requestErr)
}
-
- return convertRequestFailureErrorFromStatusCode(requestErr.StatusCode(), requestErr)
+ var re *awshttp.ResponseError
+ if errors.As(err, &re) {
+ return convertRequestFailureErrorFromStatusCode(re.HTTPStatusCode(), re.Err)
+ }
+ return err
}
func convertRequestFailureErrorFromStatusCode(statusCode int, requestErr error) error {
diff --git a/lib/cloud/aws/errors_test.go b/lib/cloud/aws/errors_test.go
index 448c2ef9a6e24..165456bfdb25b 100644
--- a/lib/cloud/aws/errors_test.go
+++ b/lib/cloud/aws/errors_test.go
@@ -73,6 +73,18 @@ func TestConvertRequestFailureError(t *testing.T) {
inputError: errors.New("not-aws-error"),
wantUnmodified: true,
},
+ {
+ name: "v2 sdk error",
+ inputError: &awshttp.ResponseError{
+ ResponseError: &smithyhttp.ResponseError{
+ Response: &smithyhttp.Response{Response: &http.Response{
+ StatusCode: http.StatusNotFound,
+ }},
+ Err: trace.Errorf(""),
+ },
+ },
+ wantIsError: trace.IsNotFound,
+ },
}
for _, test := range tests {
diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go
index bc21d126b9903..328ee76bcee0e 100644
--- a/lib/cloud/clients.go
+++ b/lib/cloud/clients.go
@@ -38,8 +38,6 @@ import (
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
awssession "github.com/aws/aws-sdk-go/aws/session"
- "github.com/aws/aws-sdk-go/service/ec2"
- "github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/aws/aws-sdk-go/service/eks"
"github.com/aws/aws-sdk-go/service/eks/eksiface"
"github.com/aws/aws-sdk-go/service/elasticache"
@@ -135,8 +133,6 @@ type AWSClients interface {
GetAWSIAMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (iamiface.IAMAPI, error)
// GetAWSSTSClient returns AWS STS client for the specified region.
GetAWSSTSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (stsiface.STSAPI, error)
- // GetAWSEC2Client returns AWS EC2 client for the specified region.
- GetAWSEC2Client(ctx context.Context, region string, opts ...AWSOptionsFn) (ec2iface.EC2API, error)
// GetAWSSSMClient returns AWS SSM client for the specified region.
GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error)
// GetAWSEKSClient returns AWS EKS client for the specified region.
@@ -372,7 +368,7 @@ type credentialsSource int
const (
// credentialsSourceAmbient uses the default Cloud SDK method to load the credentials.
credentialsSourceAmbient = iota + 1
- // CredentialsSourceIntegration uses an Integration to load the credentials.
+ // credentialsSourceIntegration uses an Integration to load the credentials.
credentialsSourceIntegration
)
@@ -596,15 +592,6 @@ func (c *cloudClients) GetAWSSTSClient(ctx context.Context, region string, opts
return sts.New(session), nil
}
-// GetAWSEC2Client returns AWS EC2 client for the specified region.
-func (c *cloudClients) GetAWSEC2Client(ctx context.Context, region string, opts ...AWSOptionsFn) (ec2iface.EC2API, error) {
- session, err := c.GetAWSSession(ctx, region, opts...)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- return ec2.New(session), nil
-}
-
// GetAWSSSMClient returns AWS SSM client for the specified region.
func (c *cloudClients) GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error) {
session, err := c.GetAWSSession(ctx, region, opts...)
@@ -1034,7 +1021,6 @@ type TestCloudClients struct {
GCPGKE gcp.GKEClient
GCPProjects gcp.ProjectsClient
GCPInstances gcp.InstancesClient
- EC2 ec2iface.EC2API
SSM ssmiface.SSMAPI
InstanceMetadata imds.Client
EKS eksiface.EKSAPI
@@ -1205,15 +1191,6 @@ func (c *TestCloudClients) GetAWSKMSClient(ctx context.Context, region string, o
return c.KMS, nil
}
-// GetAWSEC2Client returns AWS EC2 client for the specified region.
-func (c *TestCloudClients) GetAWSEC2Client(ctx context.Context, region string, opts ...AWSOptionsFn) (ec2iface.EC2API, error) {
- _, err := c.GetAWSSession(ctx, region, opts...)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- return c.EC2, nil
-}
-
// GetAWSSSMClient returns an AWS SSM client
func (c *TestCloudClients) GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error) {
_, err := c.GetAWSSession(ctx, region, opts...)
diff --git a/lib/srv/discovery/access_graph.go b/lib/srv/discovery/access_graph.go
index 7d65f99f88ef5..4bc207b21df01 100644
--- a/lib/srv/discovery/access_graph.go
+++ b/lib/srv/discovery/access_graph.go
@@ -502,6 +502,7 @@ func (s *Server) accessGraphFetchersFromMatchers(ctx context.Context, matchers M
ctx,
aws_sync.Config{
CloudClients: s.CloudClients,
+ GetEC2Client: s.GetEC2Client,
AssumeRole: assumeRole,
Regions: awsFetcher.Regions,
Integration: awsFetcher.Integration,
diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go
index efaa28ff71d47..d7ea81a940368 100644
--- a/lib/srv/discovery/discovery.go
+++ b/lib/srv/discovery/discovery.go
@@ -30,6 +30,8 @@ import (
"time"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3"
+ awsv2 "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
@@ -52,6 +54,7 @@ import (
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/cloud"
+ "github.com/gravitational/teleport/lib/cloud/aws/config"
gcpimds "github.com/gravitational/teleport/lib/cloud/imds/gcp"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/integrations/awsoidc"
@@ -111,6 +114,8 @@ type gcpInstaller interface {
type Config struct {
// CloudClients is an interface for retrieving cloud clients.
CloudClients cloud.Clients
+ // GetEC2Client gets an AWS EC2 client for the given region.
+ GetEC2Client server.EC2ClientGetter
// IntegrationOnlyCredentials discards any Matcher that don't have an Integration.
// When true, ambient credentials (used by the Cloud SDKs) are not used.
IntegrationOnlyCredentials bool
@@ -214,6 +219,34 @@ kubernetes matchers are present.`)
}
c.CloudClients = cloudClients
}
+ if c.GetEC2Client == nil {
+ c.GetEC2Client = func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) {
+ opts = append(opts, config.WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integrationName string) (awsv2.CredentialsProvider, error) {
+ integration, err := c.AccessPoint.GetIntegration(ctx, integrationName)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ if integration.GetAWSOIDCIntegrationSpec() == nil {
+ return nil, trace.BadParameter("integration does not have aws oidc spec fields %q", integrationName)
+ }
+ token, err := c.AccessPoint.GenerateAWSOIDCToken(ctx, integrationName)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ cred, err := awsoidc.NewAWSCredentialsProvider(ctx, &awsoidc.AWSClientRequest{
+ Token: token,
+ RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
+ Region: region,
+ })
+ return cred, trace.Wrap(err)
+ }))
+ cfg, err := config.GetAWSConfig(ctx, region, opts...)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return ec2.NewFromConfig(cfg), nil
+ }
+ }
if c.KubernetesClient == nil && len(c.Matchers.Kubernetes) > 0 {
cfg, err := rest.InClusterConfig()
if err != nil {
@@ -465,7 +498,7 @@ func (s *Server) initAWSWatchers(matchers []types.AWSMatcher) error {
})
const noDiscoveryConfig = ""
- s.staticServerAWSFetchers, err = server.MatchersToEC2InstanceFetchers(s.ctx, ec2Matchers, s.CloudClients, noDiscoveryConfig)
+ s.staticServerAWSFetchers, err = server.MatchersToEC2InstanceFetchers(s.ctx, ec2Matchers, s.GetEC2Client, noDiscoveryConfig)
if err != nil {
return trace.Wrap(err)
}
@@ -555,7 +588,7 @@ func (s *Server) awsServerFetchersFromMatchers(ctx context.Context, matchers []t
return matcherType == types.AWSMatcherEC2
})
- fetchers, err := server.MatchersToEC2InstanceFetchers(ctx, serverMatchers, s.CloudClients, discoveryConfig)
+ fetchers, err := server.MatchersToEC2InstanceFetchers(ctx, serverMatchers, s.GetEC2Client, discoveryConfig)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -910,7 +943,7 @@ func (s *Server) heartbeatEICEInstance(instances *server.EC2Instances) {
nodesToUpsert := make([]types.Server, 0, len(instances.Instances))
// Add EC2 Instances using EICE method
for _, ec2Instance := range instances.Instances {
- eiceNode, err := common.NewAWSNodeFromEC2v1Instance(ec2Instance.OriginalInstance, awsInfo)
+ eiceNode, err := common.NewAWSNodeFromEC2Instance(ec2Instance.OriginalInstance, awsInfo)
if err != nil {
s.Log.WarnContext(s.ctx, "Error converting to Teleport EICE Node", "error", err, "instance_id", ec2Instance.InstanceID)
diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go
index a4e379f63933e..d176c86ca6ba2 100644
--- a/lib/srv/discovery/discovery_test.go
+++ b/lib/srv/discovery/discovery_test.go
@@ -36,11 +36,12 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2"
+ awsv2 "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/ec2"
+ ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
- "github.com/aws/aws-sdk-go/service/ec2"
- "github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/aws/aws-sdk-go/service/eks"
"github.com/aws/aws-sdk-go/service/eks/eksiface"
"github.com/aws/aws-sdk-go/service/rds"
@@ -77,6 +78,7 @@ import (
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/cloud"
+ "github.com/gravitational/teleport/lib/cloud/aws/config"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/cloud/gcp"
gcpimds "github.com/gravitational/teleport/lib/cloud/imds/gcp"
@@ -161,16 +163,11 @@ func (m *mockUsageReporter) DiscoveryFetchEventCount() int {
}
type mockEC2Client struct {
- ec2iface.EC2API
output *ec2.DescribeInstancesOutput
}
-func (m *mockEC2Client) DescribeInstancesPagesWithContext(
- ctx context.Context, input *ec2.DescribeInstancesInput,
- f func(dio *ec2.DescribeInstancesOutput, b bool) bool, opts ...request.Option,
-) error {
- f(m.output, true)
- return nil
+func (m *mockEC2Client) DescribeInstances(ctx context.Context, input *ec2.DescribeInstancesInput, opts ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) {
+ return m.output, nil
}
func genEC2InstanceIDs(n int) []string {
@@ -181,17 +178,17 @@ func genEC2InstanceIDs(n int) []string {
return ec2InstanceIDs
}
-func genEC2Instances(n int) []*ec2.Instance {
- var ec2Instances []*ec2.Instance
+func genEC2Instances(n int) []ec2types.Instance {
+ var ec2Instances []ec2types.Instance
for _, id := range genEC2InstanceIDs(n) {
- ec2Instances = append(ec2Instances, &ec2.Instance{
- InstanceId: aws.String(id),
- Tags: []*ec2.Tag{{
- Key: aws.String("env"),
- Value: aws.String("dev"),
+ ec2Instances = append(ec2Instances, ec2types.Instance{
+ InstanceId: awsv2.String(id),
+ Tags: []ec2types.Tag{{
+ Key: awsv2.String("env"),
+ Value: awsv2.String("dev"),
}},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
})
}
@@ -301,7 +298,7 @@ func TestDiscoveryServer(t *testing.T) {
name string
// presentInstances is a list of servers already present in teleport
presentInstances []types.Server
- foundEC2Instances []*ec2.Instance
+ foundEC2Instances []ec2types.Instance
ssm *mockSSMClient
emitter *mockEmitter
discoveryConfig *discoveryconfig.DiscoveryConfig
@@ -314,15 +311,15 @@ func TestDiscoveryServer(t *testing.T) {
{
name: "no nodes present, 1 found ",
presentInstances: []types.Server{},
- foundEC2Instances: []*ec2.Instance{
+ foundEC2Instances: []ec2types.Instance{
{
- InstanceId: aws.String("instance-id-1"),
- Tags: []*ec2.Tag{{
- Key: aws.String("env"),
- Value: aws.String("dev"),
+ InstanceId: awsv2.String("instance-id-1"),
+ Tags: []ec2types.Tag{{
+ Key: awsv2.String("env"),
+ Value: awsv2.String("dev"),
}},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
},
},
@@ -372,15 +369,15 @@ func TestDiscoveryServer(t *testing.T) {
},
},
},
- foundEC2Instances: []*ec2.Instance{
+ foundEC2Instances: []ec2types.Instance{
{
- InstanceId: aws.String("instance-id-1"),
- Tags: []*ec2.Tag{{
- Key: aws.String("env"),
- Value: aws.String("dev"),
+ InstanceId: awsv2.String("instance-id-1"),
+ Tags: []ec2types.Tag{{
+ Key: awsv2.String("env"),
+ Value: awsv2.String("dev"),
}},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
},
},
@@ -413,15 +410,15 @@ func TestDiscoveryServer(t *testing.T) {
},
},
},
- foundEC2Instances: []*ec2.Instance{
+ foundEC2Instances: []ec2types.Instance{
{
- InstanceId: aws.String("instance-id-1"),
- Tags: []*ec2.Tag{{
- Key: aws.String("env"),
- Value: aws.String("dev"),
+ InstanceId: awsv2.String("instance-id-1"),
+ Tags: []ec2types.Tag{{
+ Key: awsv2.String("env"),
+ Value: awsv2.String("dev"),
}},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
},
},
@@ -462,15 +459,15 @@ func TestDiscoveryServer(t *testing.T) {
{
name: "no nodes present, 1 found using dynamic matchers",
presentInstances: []types.Server{},
- foundEC2Instances: []*ec2.Instance{
+ foundEC2Instances: []ec2types.Instance{
{
- InstanceId: aws.String("instance-id-1"),
- Tags: []*ec2.Tag{{
- Key: aws.String("env"),
- Value: aws.String("dev"),
+ InstanceId: awsv2.String("instance-id-1"),
+ Tags: []ec2types.Tag{{
+ Key: awsv2.String("env"),
+ Value: awsv2.String("dev"),
}},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
},
},
@@ -509,15 +506,15 @@ func TestDiscoveryServer(t *testing.T) {
{
name: "one node found with Script mode using Integration credentials",
presentInstances: []types.Server{},
- foundEC2Instances: []*ec2.Instance{
+ foundEC2Instances: []ec2types.Instance{
{
- InstanceId: aws.String("instance-id-1"),
- Tags: []*ec2.Tag{{
- Key: aws.String("env"),
- Value: aws.String("dev"),
+ InstanceId: awsv2.String("instance-id-1"),
+ Tags: []ec2types.Tag{{
+ Key: awsv2.String("env"),
+ Value: awsv2.String("dev"),
}},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
},
},
@@ -571,15 +568,15 @@ func TestDiscoveryServer(t *testing.T) {
{
name: "one node found but SSM Run fails and DiscoverEC2 User Task is created",
presentInstances: []types.Server{},
- foundEC2Instances: []*ec2.Instance{
+ foundEC2Instances: []ec2types.Instance{
{
- InstanceId: aws.String("instance-id-1"),
- Tags: []*ec2.Tag{{
- Key: aws.String("env"),
- Value: aws.String("dev"),
+ InstanceId: awsv2.String("instance-id-1"),
+ Tags: []ec2types.Tag{{
+ Key: awsv2.String("env"),
+ Value: awsv2.String("dev"),
}},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
},
},
@@ -644,18 +641,16 @@ func TestDiscoveryServer(t *testing.T) {
t.Parallel()
testCloudClients := &cloud.TestCloudClients{
- EC2: &mockEC2Client{
- output: &ec2.DescribeInstancesOutput{
- Reservations: []*ec2.Reservation{
- {
- OwnerId: aws.String("owner"),
- Instances: tc.foundEC2Instances,
- },
- },
- },
- },
SSM: tc.ssm,
}
+ ec2Client := &mockEC2Client{output: &ec2.DescribeInstancesOutput{
+ Reservations: []ec2types.Reservation{
+ {
+ OwnerId: awsv2.String("owner"),
+ Instances: tc.foundEC2Instances,
+ },
+ },
+ }}
ctx := context.Background()
// Create and start test auth server.
@@ -696,7 +691,10 @@ func TestDiscoveryServer(t *testing.T) {
}
server, err := New(authz.ContextWithUser(context.Background(), identity.I), &Config{
- CloudClients: testCloudClients,
+ CloudClients: testCloudClients,
+ GetEC2Client: func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) {
+ return ec2Client, nil
+ },
ClusterFeatures: func() proto.Features { return proto.Features{} },
KubernetesClient: fake.NewSimpleClientset(),
AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient),
@@ -785,24 +783,34 @@ func TestDiscoveryServerConcurrency(t *testing.T) {
},
}
- testCloudClients := &cloud.TestCloudClients{
- EC2: &mockEC2Client{output: &ec2.DescribeInstancesOutput{
- Reservations: []*ec2.Reservation{{
- OwnerId: aws.String("123456789012"),
- Instances: []*ec2.Instance{{
- InstanceId: aws.String("i-123456789012"),
- Tags: []*ec2.Tag{{
- Key: aws.String("env"),
- Value: aws.String("dev"),
- }},
- PrivateIpAddress: aws.String("172.0.1.2"),
- VpcId: aws.String("vpcId"),
- SubnetId: aws.String("subnetId"),
- PrivateDnsName: aws.String("privateDnsName"),
- State: &ec2.InstanceState{Name: aws.String(ec2.InstanceStateNameRunning)},
- }},
- }},
- }},
+ testCloudClients := &cloud.TestCloudClients{}
+
+ ec2Client := &mockEC2Client{
+ output: &ec2.DescribeInstancesOutput{
+ Reservations: []ec2types.Reservation{
+ {
+ OwnerId: awsv2.String("123456789012"),
+ Instances: []ec2types.Instance{
+ {
+ InstanceId: awsv2.String("i-123456789012"),
+ Tags: []ec2types.Tag{
+ {
+ Key: awsv2.String("env"),
+ Value: awsv2.String("dev"),
+ },
+ },
+ PrivateIpAddress: awsv2.String("172.0.1.2"),
+ VpcId: awsv2.String("vpcId"),
+ SubnetId: awsv2.String("subnetId"),
+ PrivateDnsName: awsv2.String("privateDnsName"),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
+ },
+ },
+ },
+ },
+ },
+ },
}
// Create and start test auth server.
@@ -822,9 +830,14 @@ func TestDiscoveryServerConcurrency(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, authClient.Close()) })
+ getEC2Client := func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) {
+ return ec2Client, nil
+ }
+
// Create Server1
server1, err := New(authz.ContextWithUser(ctx, identity.I), &Config{
CloudClients: testCloudClients,
+ GetEC2Client: getEC2Client,
ClusterFeatures: func() proto.Features { return proto.Features{} },
KubernetesClient: fake.NewSimpleClientset(),
AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient),
@@ -838,6 +851,7 @@ func TestDiscoveryServerConcurrency(t *testing.T) {
// Create Server2
server2, err := New(authz.ContextWithUser(ctx, identity.I), &Config{
CloudClients: testCloudClients,
+ GetEC2Client: getEC2Client,
ClusterFeatures: func() proto.Features { return proto.Features{} },
KubernetesClient: fake.NewSimpleClientset(),
AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient),
diff --git a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go
index 4e2c5510def86..d96e44075195e 100644
--- a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go
+++ b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go
@@ -24,6 +24,8 @@ import (
"sync"
"time"
+ awsv2 "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/service/sts"
@@ -33,6 +35,8 @@ import (
usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1"
accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
"github.com/gravitational/teleport/lib/cloud"
+ "github.com/gravitational/teleport/lib/cloud/aws/config"
+ "github.com/gravitational/teleport/lib/srv/server"
)
// pageSize is the default page size to use when fetching AWS resources
@@ -43,6 +47,8 @@ const pageSize int64 = 500
type Config struct {
// CloudClients is the cloud clients to use when fetching AWS resources.
CloudClients cloud.Clients
+ // GetEC2Client gets an AWS EC2 client for the given region.
+ GetEC2Client server.EC2ClientGetter
// AccountID is the AWS account ID to use when fetching resources.
AccountID string
// Regions is the list of AWS regions to fetch resources from.
@@ -318,6 +324,27 @@ func (a *awsFetcher) getAWSOptions() []cloud.AWSOptionsFn {
return opts
}
+// getAWSV2Options returns a list of options to be used when
+// creating AWS clients with the v2 sdk.
+func (a *awsFetcher) getAWSV2Options() []config.AWSOptionsFn {
+ opts := []config.AWSOptionsFn{
+ config.WithCredentialsMaybeIntegration(a.Config.Integration),
+ }
+
+ if a.Config.AssumeRole != nil {
+ opts = append(opts, config.WithAssumeRole(a.Config.AssumeRole.RoleARN, a.Config.AssumeRole.ExternalID))
+ }
+ const maxRetries = 10
+ opts = append(opts, config.WithRetryer(func() awsv2.Retryer {
+ return retry.NewStandard(func(so *retry.StandardOptions) {
+ so.MaxAttempts = maxRetries
+ so.Backoff = retry.NewExponentialJitterBackoff(300 * time.Second)
+ })
+ }))
+
+ return opts
+}
+
func (a *awsFetcher) getAccountId(ctx context.Context) (string, error) {
stsClient, err := a.CloudClients.GetAWSSTSClient(
ctx,
diff --git a/lib/srv/discovery/fetchers/aws-sync/ec2.go b/lib/srv/discovery/fetchers/aws-sync/ec2.go
index dcbe9b26f7167..7d32603676e24 100644
--- a/lib/srv/discovery/fetchers/aws-sync/ec2.go
+++ b/lib/srv/discovery/fetchers/aws-sync/ec2.go
@@ -24,7 +24,8 @@ import (
"time"
"github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go/service/ec2"
+ "github.com/aws/aws-sdk-go-v2/service/ec2"
+ ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/gravitational/trace"
"golang.org/x/sync/errgroup"
@@ -32,6 +33,7 @@ import (
"google.golang.org/protobuf/types/known/wrapperspb"
accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
+ libcloudaws "github.com/gravitational/teleport/lib/cloud/aws"
)
// pollAWSEC2Instances is a function that returns a function that fetches
@@ -86,16 +88,21 @@ func (a *awsFetcher) fetchAWSEC2Instances(ctx context.Context) ([]*accessgraphv1
return h.Region == region && h.AccountId == a.AccountID
},
)
- ec2Client, err := a.CloudClients.GetAWSEC2Client(ctx, region, a.getAWSOptions()...)
+ ec2Client, err := a.GetEC2Client(ctx, region, a.getAWSV2Options()...)
if err != nil {
collectHosts(prevIterationEc2, trace.Wrap(err))
return nil
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
- err = ec2Client.DescribeInstancesPagesWithContext(ctx, &ec2.DescribeInstancesInput{
- MaxResults: aws.Int64(pageSize),
- }, func(page *ec2.DescribeInstancesOutput, lastPage bool) bool {
+ paginator := ec2.NewDescribeInstancesPaginator(ec2Client, &ec2.DescribeInstancesInput{
+ MaxResults: aws.Int32(int32(pageSize)),
+ })
+ for paginator.HasMorePages() {
+ page, err := paginator.NextPage(ctx)
+ if err != nil {
+ return libcloudaws.ConvertRequestFailureError(err)
+ }
lHosts := make([]*accessgraphv1alpha.AWSInstanceV1, 0, len(page.Reservations))
for _, reservation := range page.Reservations {
for _, instance := range reservation.Instances {
@@ -103,11 +110,6 @@ func (a *awsFetcher) fetchAWSEC2Instances(ctx context.Context) ([]*accessgraphv1
}
}
collectHosts(lHosts, nil)
- return !lastPage
- })
-
- if err != nil {
- collectHosts(prevIterationEc2, trace.Wrap(err))
}
return nil
})
@@ -119,7 +121,7 @@ func (a *awsFetcher) fetchAWSEC2Instances(ctx context.Context) ([]*accessgraphv1
// awsInstanceToProtoInstance converts an ec2.Instance to accessgraphv1alpha.AWSInstanceV1
// representation.
-func awsInstanceToProtoInstance(instance *ec2.Instance, region string, accountID string) *accessgraphv1alpha.AWSInstanceV1 {
+func awsInstanceToProtoInstance(instance ec2types.Instance, region string, accountID string) *accessgraphv1alpha.AWSInstanceV1 {
var tags []*accessgraphv1alpha.AWSTag
for _, tag := range instance.Tags {
tags = append(tags, &accessgraphv1alpha.AWSTag{
diff --git a/lib/srv/server/azure_watcher_test.go b/lib/srv/server/azure_watcher_test.go
index 6c3989bc91a75..ad507911c6882 100644
--- a/lib/srv/server/azure_watcher_test.go
+++ b/lib/srv/server/azure_watcher_test.go
@@ -28,9 +28,16 @@ import (
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/azure"
)
+type mockClients struct {
+ cloud.Clients
+
+ azureClient azure.VirtualMachinesClient
+}
+
func (c *mockClients) GetAzureVirtualMachinesClient(subscription string) (azure.VirtualMachinesClient, error) {
return c.azureClient, nil
}
diff --git a/lib/srv/server/ec2_watcher.go b/lib/srv/server/ec2_watcher.go
index 25f12018e0170..4c6300e7bf661 100644
--- a/lib/srv/server/ec2_watcher.go
+++ b/lib/srv/server/ec2_watcher.go
@@ -23,16 +23,16 @@ import (
"sync"
"time"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/ec2"
- "github.com/aws/aws-sdk-go/service/ec2/ec2iface"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/ec2"
+ ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1"
"github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/lib/cloud"
- awslib "github.com/gravitational/teleport/lib/cloud/aws"
+ libcloudaws "github.com/gravitational/teleport/lib/cloud/aws"
+ "github.com/gravitational/teleport/lib/cloud/aws/config"
"github.com/gravitational/teleport/lib/labels"
)
@@ -81,20 +81,20 @@ type EC2Instance struct {
InstanceID string
InstanceName string
Tags map[string]string
- OriginalInstance ec2.Instance
+ OriginalInstance ec2types.Instance
}
-func toEC2Instance(originalInst *ec2.Instance) EC2Instance {
+func toEC2Instance(originalInst ec2types.Instance) EC2Instance {
inst := EC2Instance{
- InstanceID: aws.StringValue(originalInst.InstanceId),
+ InstanceID: aws.ToString(originalInst.InstanceId),
Tags: make(map[string]string, len(originalInst.Tags)),
- OriginalInstance: *originalInst,
+ OriginalInstance: originalInst,
}
for _, tag := range originalInst.Tags {
- if key := aws.StringValue(tag.Key); key != "" {
- inst.Tags[key] = aws.StringValue(tag.Value)
+ if key := aws.ToString(tag.Key); key != "" {
+ inst.Tags[key] = aws.ToString(tag.Value)
if key == "Name" {
- inst.InstanceName = aws.StringValue(tag.Value)
+ inst.InstanceName = aws.ToString(tag.Value)
}
}
}
@@ -102,7 +102,7 @@ func toEC2Instance(originalInst *ec2.Instance) EC2Instance {
}
// ToEC2Instances converts aws []*ec2.Instance to []EC2Instance
-func ToEC2Instances(insts []*ec2.Instance) []EC2Instance {
+func ToEC2Instances(insts []ec2types.Instance) []EC2Instance {
var ec2Insts []EC2Instance
for _, inst := range insts {
@@ -188,14 +188,17 @@ func NewEC2Watcher(ctx context.Context, fetchersFn func() []Fetcher, missedRotat
return &watcher, nil
}
+// EC2ClientGetter gets an AWS EC2 client for the given region.
+type EC2ClientGetter func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error)
+
// MatchersToEC2InstanceFetchers converts a list of AWS EC2 Matchers into a list of AWS EC2 Fetchers.
-func MatchersToEC2InstanceFetchers(ctx context.Context, matchers []types.AWSMatcher, clients cloud.Clients, discoveryConfig string) ([]Fetcher, error) {
+func MatchersToEC2InstanceFetchers(ctx context.Context, matchers []types.AWSMatcher, getEC2Client EC2ClientGetter, discoveryConfig string) ([]Fetcher, error) {
ret := []Fetcher{}
for _, matcher := range matchers {
for _, region := range matcher.Regions {
// TODO(gavin): support assume_role_arn for ec2.
- ec2Client, err := clients.GetAWSEC2Client(ctx, region,
- cloud.WithCredentialsMaybeIntegration(matcher.Integration),
+ ec2Client, err := getEC2Client(ctx, region,
+ config.WithCredentialsMaybeIntegration(matcher.Integration),
)
if err != nil {
return nil, trace.Wrap(err)
@@ -221,7 +224,7 @@ type ec2FetcherConfig struct {
Matcher types.AWSMatcher
Region string
Document string
- EC2Client ec2iface.EC2API
+ EC2Client ec2.DescribeInstancesAPIClient
Labels types.Labels
Integration string
DiscoveryConfig string
@@ -229,8 +232,8 @@ type ec2FetcherConfig struct {
}
type ec2InstanceFetcher struct {
- Filters []*ec2.Filter
- EC2 ec2iface.EC2API
+ Filters []ec2types.Filter
+ EC2 ec2.DescribeInstancesAPIClient
Region string
DocumentName string
Parameters map[string]string
@@ -289,16 +292,16 @@ const (
const awsEC2APIChunkSize = 50
func newEC2InstanceFetcher(cfg ec2FetcherConfig) *ec2InstanceFetcher {
- tagFilters := []*ec2.Filter{{
+ tagFilters := []ec2types.Filter{{
Name: aws.String(AWSInstanceStateName),
- Values: aws.StringSlice([]string{ec2.InstanceStateNameRunning}),
+ Values: []string{string(ec2types.InstanceStateNameRunning)},
}}
if _, ok := cfg.Labels["*"]; !ok {
for key, val := range cfg.Labels {
- tagFilters = append(tagFilters, &ec2.Filter{
+ tagFilters = append(tagFilters, ec2types.Filter{
Name: aws.String("tag:" + key),
- Values: aws.StringSlice(val),
+ Values: val,
})
}
} else {
@@ -411,38 +414,40 @@ func chunkInstances(insts EC2Instances) []Instances {
func (f *ec2InstanceFetcher) GetInstances(ctx context.Context, rotation bool) ([]Instances, error) {
var instances []Instances
f.cachedInstances.clear()
- err := f.EC2.DescribeInstancesPagesWithContext(ctx, &ec2.DescribeInstancesInput{
+ paginator := ec2.NewDescribeInstancesPaginator(f.EC2, &ec2.DescribeInstancesInput{
Filters: f.Filters,
- },
- func(dio *ec2.DescribeInstancesOutput, b bool) bool {
- for _, res := range dio.Reservations {
- for i := 0; i < len(res.Instances); i += awsEC2APIChunkSize {
- end := i + awsEC2APIChunkSize
- if end > len(res.Instances) {
- end = len(res.Instances)
- }
- ownerID := aws.StringValue(res.OwnerId)
- inst := EC2Instances{
- AccountID: ownerID,
- Region: f.Region,
- DocumentName: f.DocumentName,
- Instances: ToEC2Instances(res.Instances[i:end]),
- Parameters: f.Parameters,
- Rotation: rotation,
- Integration: f.Integration,
- DiscoveryConfig: f.DiscoveryConfig,
- EnrollMode: f.EnrollMode,
- }
- for _, ec2inst := range res.Instances[i:end] {
- f.cachedInstances.add(ownerID, aws.StringValue(ec2inst.InstanceId))
- }
- instances = append(instances, Instances{EC2: &inst})
+ })
+
+ for paginator.HasMorePages() {
+ page, err := paginator.NextPage(ctx)
+ if err != nil {
+ return nil, libcloudaws.ConvertRequestFailureError(err)
+ }
+
+ for _, res := range page.Reservations {
+ for i := 0; i < len(res.Instances); i += awsEC2APIChunkSize {
+ end := i + awsEC2APIChunkSize
+ if end > len(res.Instances) {
+ end = len(res.Instances)
}
+ ownerID := aws.ToString(res.OwnerId)
+ inst := EC2Instances{
+ AccountID: ownerID,
+ Region: f.Region,
+ DocumentName: f.DocumentName,
+ Instances: ToEC2Instances(res.Instances[i:end]),
+ Parameters: f.Parameters,
+ Rotation: rotation,
+ Integration: f.Integration,
+ DiscoveryConfig: f.DiscoveryConfig,
+ EnrollMode: f.EnrollMode,
+ }
+ for _, ec2inst := range res.Instances[i:end] {
+ f.cachedInstances.add(ownerID, aws.ToString(ec2inst.InstanceId))
+ }
+ instances = append(instances, Instances{EC2: &inst})
}
- return true
- })
- if err != nil {
- return nil, awslib.ConvertRequestFailureError(err)
+ }
}
if len(instances) == 0 {
diff --git a/lib/srv/server/ec2_watcher_test.go b/lib/srv/server/ec2_watcher_test.go
index 8279cbbd868f1..f7c9c0a85458d 100644
--- a/lib/srv/server/ec2_watcher_test.go
+++ b/lib/srv/server/ec2_watcher_test.go
@@ -22,81 +22,64 @@ import (
"context"
"testing"
+ awsv2 "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/ec2"
+ ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/request"
- "github.com/aws/aws-sdk-go/service/ec2"
- "github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils"
- "github.com/gravitational/teleport/lib/cloud"
- "github.com/gravitational/teleport/lib/cloud/azure"
+ "github.com/gravitational/teleport/lib/cloud/aws/config"
)
-type mockClients struct {
- cloud.Clients
-
- ec2Client *mockEC2Client
- azureClient azure.VirtualMachinesClient
-}
-
-func (c *mockClients) GetAWSEC2Client(ctx context.Context, region string, _ ...cloud.AWSOptionsFn) (ec2iface.EC2API, error) {
- return c.ec2Client, nil
-}
-
type mockEC2Client struct {
- ec2iface.EC2API
output *ec2.DescribeInstancesOutput
}
-func instanceMatches(inst *ec2.Instance, filters []*ec2.Filter) bool {
+func (m *mockEC2Client) DescribeInstances(ctx context.Context, input *ec2.DescribeInstancesInput, opts ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) {
+ var output ec2.DescribeInstancesOutput
+ for _, res := range m.output.Reservations {
+ var instances []ec2types.Instance
+ for _, inst := range res.Instances {
+ if instanceMatches(inst, input.Filters) {
+ instances = append(instances, inst)
+ }
+ }
+ output.Reservations = append(output.Reservations, ec2types.Reservation{
+ Instances: instances,
+ })
+ }
+ return &output, nil
+}
+
+func instanceMatches(inst ec2types.Instance, filters []ec2types.Filter) bool {
allMatched := true
for _, filter := range filters {
- name := aws.StringValue(filter.Name)
- val := aws.StringValue(filter.Values[0])
- if name == AWSInstanceStateName && aws.StringValue(inst.State.Name) != ec2.InstanceStateNameRunning {
+ name := awsv2.ToString(filter.Name)
+ val := filter.Values[0]
+ if name == AWSInstanceStateName && inst.State.Name != ec2types.InstanceStateNameRunning {
return false
}
for _, tag := range inst.Tags {
- if aws.StringValue(tag.Key) != name[4:] {
+ if awsv2.ToString(tag.Key) != name[4:] {
continue
}
- allMatched = allMatched && aws.StringValue(tag.Value) != val
+ allMatched = allMatched && awsv2.ToString(tag.Value) != val
}
}
return !allMatched
}
-func (m *mockEC2Client) DescribeInstancesPagesWithContext(
- ctx context.Context, input *ec2.DescribeInstancesInput,
- f func(dio *ec2.DescribeInstancesOutput, b bool) bool, opts ...request.Option) error {
- output := &ec2.DescribeInstancesOutput{}
- for _, res := range m.output.Reservations {
- var instances []*ec2.Instance
- for _, inst := range res.Instances {
- if instanceMatches(inst, input.Filters) {
- instances = append(instances, inst)
- }
- }
- output.Reservations = append(output.Reservations, &ec2.Reservation{
- Instances: instances,
- })
- }
-
- f(output, true)
- return nil
-}
-
func TestNewEC2InstanceFetcherTags(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
config ec2FetcherConfig
- expectedFilters []*ec2.Filter
+ expectedFilters []ec2types.Filter
}{
{
name: "with glob key",
@@ -106,10 +89,10 @@ func TestNewEC2InstanceFetcherTags(t *testing.T) {
"hello": []string{"other"},
},
},
- expectedFilters: []*ec2.Filter{
+ expectedFilters: []ec2types.Filter{
{
- Name: aws.String(AWSInstanceStateName),
- Values: aws.StringSlice([]string{ec2.InstanceStateNameRunning}),
+ Name: awsv2.String(AWSInstanceStateName),
+ Values: []string{string(ec2types.InstanceStateNameRunning)},
},
},
},
@@ -120,14 +103,14 @@ func TestNewEC2InstanceFetcherTags(t *testing.T) {
"hello": []string{"other"},
},
},
- expectedFilters: []*ec2.Filter{
+ expectedFilters: []ec2types.Filter{
{
- Name: aws.String(AWSInstanceStateName),
- Values: aws.StringSlice([]string{ec2.InstanceStateNameRunning}),
+ Name: awsv2.String(AWSInstanceStateName),
+ Values: []string{string(ec2types.InstanceStateNameRunning)},
},
{
- Name: aws.String("tag:hello"),
- Values: aws.StringSlice([]string{"other"}),
+ Name: awsv2.String("tag:hello"),
+ Values: []string{"other"},
},
},
},
@@ -141,9 +124,7 @@ func TestNewEC2InstanceFetcherTags(t *testing.T) {
func TestEC2Watcher(t *testing.T) {
t.Parallel()
- clients := mockClients{
- ec2Client: &mockEC2Client{},
- }
+ client := &mockEC2Client{}
matchers := []types.AWSMatcher{
{
Params: &types.InstallerParams{
@@ -174,80 +155,82 @@ func TestEC2Watcher(t *testing.T) {
}
ctx := context.Background()
- present := ec2.Instance{
- InstanceId: aws.String("instance-present"),
- Tags: []*ec2.Tag{
+ present := ec2types.Instance{
+ InstanceId: awsv2.String("instance-present"),
+ Tags: []ec2types.Tag{
{
- Key: aws.String("teleport"),
- Value: aws.String("yes"),
+ Key: awsv2.String("teleport"),
+ Value: awsv2.String("yes"),
},
{
- Key: aws.String("Name"),
- Value: aws.String("Present"),
+ Key: awsv2.String("Name"),
+ Value: awsv2.String("Present"),
},
},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
}
- presentOther := ec2.Instance{
- InstanceId: aws.String("instance-present-2"),
- Tags: []*ec2.Tag{{
- Key: aws.String("env"),
- Value: aws.String("dev"),
+ presentOther := ec2types.Instance{
+ InstanceId: awsv2.String("instance-present-2"),
+ Tags: []ec2types.Tag{{
+ Key: awsv2.String("env"),
+ Value: awsv2.String("dev"),
}},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
}
- presentForEICE := ec2.Instance{
- InstanceId: aws.String("instance-present-3"),
- Tags: []*ec2.Tag{{
- Key: aws.String("with-eice"),
- Value: aws.String("please"),
+ presentForEICE := ec2types.Instance{
+ InstanceId: awsv2.String("instance-present-3"),
+ Tags: []ec2types.Tag{{
+ Key: awsv2.String("with-eice"),
+ Value: awsv2.String("please"),
}},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
}
output := ec2.DescribeInstancesOutput{
- Reservations: []*ec2.Reservation{{
- Instances: []*ec2.Instance{
- &present,
- &presentOther,
- &presentForEICE,
+ Reservations: []ec2types.Reservation{{
+ Instances: []ec2types.Instance{
+ present,
+ presentOther,
+ presentForEICE,
{
- InstanceId: aws.String("instance-absent"),
- Tags: []*ec2.Tag{{
- Key: aws.String("env"),
- Value: aws.String("prod"),
+ InstanceId: awsv2.String("instance-absent"),
+ Tags: []ec2types.Tag{{
+ Key: awsv2.String("env"),
+ Value: awsv2.String("prod"),
}},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
},
{
- InstanceId: aws.String("instance-absent-3"),
- Tags: []*ec2.Tag{{
- Key: aws.String("env"),
- Value: aws.String("prod"),
+ InstanceId: awsv2.String("instance-absent-3"),
+ Tags: []ec2types.Tag{{
+ Key: awsv2.String("env"),
+ Value: awsv2.String("prod"),
}, {
- Key: aws.String("teleport"),
- Value: aws.String("yes"),
+ Key: awsv2.String("teleport"),
+ Value: awsv2.String("yes"),
}},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNamePending),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNamePending,
},
},
},
}},
}
- clients.ec2Client.output = &output
+ client.output = &output
const noDiscoveryConfig = ""
fetchersFn := func() []Fetcher {
- fetchers, err := MatchersToEC2InstanceFetchers(ctx, matchers, &clients, noDiscoveryConfig)
+ fetchers, err := MatchersToEC2InstanceFetchers(ctx, matchers, func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) {
+ return client, nil
+ }, noDiscoveryConfig)
require.NoError(t, err)
return fetchers
@@ -260,19 +243,19 @@ func TestEC2Watcher(t *testing.T) {
result := <-watcher.InstancesC
require.Equal(t, EC2Instances{
Region: "us-west-2",
- Instances: []EC2Instance{toEC2Instance(&present)},
+ Instances: []EC2Instance{toEC2Instance(present)},
Parameters: map[string]string{"token": "", "scriptName": ""},
}, *result.EC2)
result = <-watcher.InstancesC
require.Equal(t, EC2Instances{
Region: "us-west-2",
- Instances: []EC2Instance{toEC2Instance(&presentOther)},
+ Instances: []EC2Instance{toEC2Instance(presentOther)},
Parameters: map[string]string{"token": "", "scriptName": ""},
}, *result.EC2)
result = <-watcher.InstancesC
require.Equal(t, EC2Instances{
Region: "us-west-2",
- Instances: []EC2Instance{toEC2Instance(&presentForEICE)},
+ Instances: []EC2Instance{toEC2Instance(presentForEICE)},
Parameters: map[string]string{"token": "", "scriptName": "", "sshdConfigPath": ""},
Integration: "my-aws-integration",
}, *result.EC2)
@@ -368,9 +351,9 @@ func TestMakeEvents(t *testing.T) {
}
func TestToEC2Instances(t *testing.T) {
- sampleInstance := &ec2.Instance{
+ sampleInstance := ec2types.Instance{
InstanceId: aws.String("instance-001"),
- Tags: []*ec2.Tag{
+ Tags: []ec2types.Tag{
{
Key: aws.String("teleport"),
Value: aws.String("yes"),
@@ -380,32 +363,32 @@ func TestToEC2Instances(t *testing.T) {
Value: aws.String("MyInstanceName"),
},
},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
}
- sampleInstanceWithoutName := &ec2.Instance{
+ sampleInstanceWithoutName := ec2types.Instance{
InstanceId: aws.String("instance-001"),
- Tags: []*ec2.Tag{
+ Tags: []ec2types.Tag{
{
Key: aws.String("teleport"),
Value: aws.String("yes"),
},
},
- State: &ec2.InstanceState{
- Name: aws.String(ec2.InstanceStateNameRunning),
+ State: &ec2types.InstanceState{
+ Name: ec2types.InstanceStateNameRunning,
},
}
for _, tt := range []struct {
name string
- input []*ec2.Instance
+ input []ec2types.Instance
expected []EC2Instance
}{
{
name: "with name",
- input: []*ec2.Instance{sampleInstance},
+ input: []ec2types.Instance{sampleInstance},
expected: []EC2Instance{{
InstanceID: "instance-001",
Tags: map[string]string{
@@ -413,18 +396,18 @@ func TestToEC2Instances(t *testing.T) {
"teleport": "yes",
},
InstanceName: "MyInstanceName",
- OriginalInstance: *sampleInstance,
+ OriginalInstance: sampleInstance,
}},
},
{
name: "without name",
- input: []*ec2.Instance{sampleInstanceWithoutName},
+ input: []ec2types.Instance{sampleInstanceWithoutName},
expected: []EC2Instance{{
InstanceID: "instance-001",
Tags: map[string]string{
"teleport": "yes",
},
- OriginalInstance: *sampleInstanceWithoutName,
+ OriginalInstance: sampleInstanceWithoutName,
}},
},
} {