Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate EC2 clients in discovery service to AWS SDK v2 #48950

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 237 additions & 0 deletions lib/cloud/aws/config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package config

import (
"context"
"log/slog"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/lib/modules"
)

const defaultRegion = "us-east-1"

// credentialsSource defines where the credentials must come from.
type credentialsSource int

const (
// credentialsSourceAmbient uses the default Cloud SDK method to load the credentials.
credentialsSourceAmbient = iota + 1
// credentialsSourceIntegration uses an Integration to load the credentials.
credentialsSourceIntegration
)

// AWSIntegrationSessionProvider defines a function that creates a credential provider from a region and an integration.
// This is used to generate aws configs for clients that must use an integration instead of ambient credentials.
type AWSIntegrationCredentialProvider func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error)

// awsOptions is a struct of additional options for assuming an AWS role
// when construction an underlying AWS config.
type awsOptions struct {
// baseConfigis a config to use instead of the default config for an
// AWS region, which is used to enable role chaining.
baseConfig *aws.Config
// assumeRoleARN is the AWS IAM Role ARN to assume.
assumeRoleARN string
// assumeRoleExternalID is used to assume an external AWS IAM Role.
assumeRoleExternalID string
// credentialsSource describes which source to use to fetch credentials.
credentialsSource credentialsSource
// integration is the name of the integration to be used to fetch the credentials.
integration string
// awsIntegrationCredentialsProvider is the integration credential provider to use.
awsIntegrationCredentialsProvider AWSIntegrationCredentialProvider
// customRetryer is a custom retryer to use for the config.
customRetryer func() aws.Retryer
// maxRetries is the maximum number of retries to use for the config.
maxRetries *int
}

func (a *awsOptions) checkAndSetDefaults() error {
switch a.credentialsSource {
case credentialsSourceAmbient:
if a.integration != "" {
return trace.BadParameter("integration and ambient credentials cannot be used at the same time")
}
case credentialsSourceIntegration:
if a.integration == "" {
return trace.BadParameter("missing integration name")
}
default:
return trace.BadParameter("missing credentials source (ambient or integration)")
}

return nil
}

// AWSOptionsFn is an option function for setting additional options
// when getting an AWS config.
type AWSOptionsFn func(*awsOptions)

// WithAssumeRole configures options needed for assuming an AWS role.
func WithAssumeRole(roleARN, externalID string) AWSOptionsFn {
return func(options *awsOptions) {
options.assumeRoleARN = roleARN
options.assumeRoleExternalID = externalID
}
}

// WithRetryer sets a custom retryer for the config.
func WithRetryer(retryer func() aws.Retryer) AWSOptionsFn {
return func(options *awsOptions) {
options.customRetryer = retryer
}
}

// WithMaxRetries sets the maximum allowed value for the sdk to keep retrying.
func WithMaxRetries(maxRetries int) AWSOptionsFn {
return func(options *awsOptions) {
options.maxRetries = &maxRetries
}
}

// WithCredentialsMaybeIntegration sets the credential source to be
// - ambient if the integration is an empty string
// - integration, otherwise
func WithCredentialsMaybeIntegration(integration string) AWSOptionsFn {
if integration != "" {
return withIntegrationCredentials(integration)
}

return WithAmbientCredentials()
}

// withIntegrationCredentials configures options with an Integration that must be used to fetch Credentials to assume a role.
// This prevents the usage of AWS environment credentials.
func withIntegrationCredentials(integration string) AWSOptionsFn {
return func(options *awsOptions) {
options.credentialsSource = credentialsSourceIntegration
options.integration = integration
}
}

// WithAmbientCredentials configures options to use the ambient credentials.
func WithAmbientCredentials() AWSOptionsFn {
return func(options *awsOptions) {
options.credentialsSource = credentialsSourceAmbient
}
}

// WithAWSIntegrationCredentialProvider sets the integration credential provider.
func WithAWSIntegrationCredentialProvider(cred AWSIntegrationCredentialProvider) AWSOptionsFn {
return func(options *awsOptions) {
options.awsIntegrationCredentialsProvider = cred
}
}

// GetAWSConfig returns an AWS config for the specified region, optionally
// assuming AWS IAM Roles.
func GetAWSConfig(ctx context.Context, region string, opts ...AWSOptionsFn) (aws.Config, error) {
var options awsOptions
for _, opt := range opts {
opt(&options)
}
if options.baseConfig == nil {
cfg, err := getAWSConfigForRegion(ctx, region, options)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
options.baseConfig = &cfg
}
if options.assumeRoleARN == "" {
return *options.baseConfig, nil
}
return getAWSConfigForRole(ctx, region, options)
}

// awsAmbientConfigProvider loads a new config using the environment variables.
func awsAmbientConfigProvider(region string, cred aws.CredentialsProvider, options awsOptions) (aws.Config, error) {
opts := buildAWSConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(context.Background(), opts...)
return cfg, trace.Wrap(err)
}

func buildAWSConfigOptions(region string, cred aws.CredentialsProvider, options awsOptions) []func(*config.LoadOptions) error {
opts := []func(*config.LoadOptions) error{
config.WithDefaultRegion(defaultRegion),
config.WithRegion(region),
config.WithCredentialsProvider(cred),
}
if modules.GetModules().IsBoringBinary() {
opts = append(opts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
}
if options.customRetryer != nil {
opts = append(opts, config.WithRetryer(options.customRetryer))
}
if options.maxRetries != nil {
opts = append(opts, config.WithRetryMaxAttempts(*options.maxRetries))
}
return opts
}

// getAWSConfigForRegion returns AWS config for the specified region.
func getAWSConfigForRegion(ctx context.Context, region string, options awsOptions) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
}

var cred aws.CredentialsProvider
if options.credentialsSource == credentialsSourceIntegration {
if options.awsIntegrationCredentialsProvider == nil {
return aws.Config{}, trace.BadParameter("missing aws integration credential provider")
}

slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", options.integration)
var err error
cred, err = options.awsIntegrationCredentialsProvider(ctx, region, options.integration)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
} else {
slog.DebugContext(ctx, "Initializing AWS config from environment", "region", region)
}

cfg, err := awsAmbientConfigProvider(region, cred, options)
return cfg, trace.Wrap(err)
}

// getAWSConfigForRole returns an AWS config for the specified region and role.
func getAWSConfigForRole(ctx context.Context, region string, options awsOptions) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
}

stsClient := sts.NewFromConfig(*options.baseConfig)
cred := stscreds.NewAssumeRoleProvider(stsClient, options.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) {
greedy52 marked this conversation as resolved.
Show resolved Hide resolved
if options.assumeRoleExternalID != "" {
aro.ExternalID = aws.String(options.assumeRoleExternalID)
}
})
if _, err := cred.Retrieve(ctx); err != nil {
return aws.Config{}, trace.Wrap(err)
}

opts := buildAWSConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(ctx, opts...)
return cfg, trace.Wrap(err)
}
104 changes: 104 additions & 0 deletions lib/cloud/aws/config/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package config

import (
"context"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)

type mockCredentialProvider struct {
cred aws.Credentials
}

func (m *mockCredentialProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
return m.cred, nil
}

func TestGetAWSConfigIntegration(t *testing.T) {
t.Parallel()
dummyIntegration := "integration-test"
dummyRegion := "test-region-123"

t.Run("without an integration credential provider, must return missing credential provider error", func(t *testing.T) {
ctx := context.Background()
_, err := GetAWSConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration))
require.True(t, trace.IsBadParameter(err), "unexpected error: %v", err)
require.ErrorContains(t, err, "missing aws integration credential provider")
})

t.Run("with an integration credential provider, must return the credentials", func(t *testing.T) {
ctx := context.Background()

cfg, err := GetAWSConfig(ctx, dummyRegion,
WithCredentialsMaybeIntegration(dummyIntegration),
WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
if region == dummyRegion && integration == dummyIntegration {
return &mockCredentialProvider{
cred: aws.Credentials{
SessionToken: "foo-bar",
},
}, nil
}
return nil, trace.NotFound("no creds in region %q with integration %q", region, integration)
}))
require.NoError(t, err)
creds, err := cfg.Credentials.Retrieve(ctx)
require.NoError(t, err)
require.Equal(t, "foo-bar", creds.SessionToken)
})

t.Run("with an integration credential provider, but using an empty integration falls back to ambient credentials", func(t *testing.T) {
ctx := context.Background()

_, err := GetAWSConfig(ctx, dummyRegion,
WithCredentialsMaybeIntegration(""),
WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
return nil, nil
}))
require.NoError(t, err)
})

t.Run("with an integration credential provider, but using ambient credentials", func(t *testing.T) {
ctx := context.Background()

_, err := GetAWSConfig(ctx, dummyRegion,
WithAmbientCredentials(),
WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
return nil, nil
}))
require.NoError(t, err)
})

t.Run("with an integration credential provider, but no credential source", func(t *testing.T) {
ctx := context.Background()

_, err := GetAWSConfig(ctx, dummyRegion,
WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
return nil, nil
}))
require.Error(t, err)
require.ErrorContains(t, err, "missing credentials source")
})
}
11 changes: 7 additions & 4 deletions lib/cloud/aws/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ import (
// the error without modifying it.
func ConvertRequestFailureError(err error) error {
var requestErr awserr.RequestFailure
if !errors.As(err, &requestErr) {
return err
if errors.As(err, &requestErr) {
return convertRequestFailureErrorFromStatusCode(requestErr.StatusCode(), requestErr)
}

return convertRequestFailureErrorFromStatusCode(requestErr.StatusCode(), requestErr)
var re *awshttp.ResponseError
if errors.As(err, &re) {
return convertRequestFailureErrorFromStatusCode(re.HTTPStatusCode(), re.Err)
}
return err
}

func convertRequestFailureErrorFromStatusCode(statusCode int, requestErr error) error {
Expand Down
12 changes: 12 additions & 0 deletions lib/cloud/aws/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ func TestConvertRequestFailureError(t *testing.T) {
inputError: errors.New("not-aws-error"),
wantUnmodified: true,
},
{
name: "v2 sdk error",
inputError: &awshttp.ResponseError{
ResponseError: &smithyhttp.ResponseError{
Response: &smithyhttp.Response{Response: &http.Response{
StatusCode: http.StatusNotFound,
}},
Err: trace.Errorf(""),
},
},
wantIsError: trace.IsNotFound,
},
}

for _, test := range tests {
Expand Down
Loading
Loading