From 7ad1ad3687198b4ee12b782710231795f3dc0d1a Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Mon, 6 Jan 2025 18:48:40 -0800 Subject: [PATCH] Simplify awsconfig loading This replaces awsconfig.WithIntegrationCredentialProvider option with the awsconfig.WithOIDCIntegrationClient option. This solves a chicken/egg problem with AWS config loading - callers no longer need to load AWS config (to create a credential provider) to load AWS config. The OIDCIntegrationClient interface is also much simpler to implement. This also adds default option overrides when creating an awsconfig.Cache. For now, this is used to add an OIDCIntegrationClient when creating the cache so that dependent callers don't have to. --- lib/cloud/awsconfig/awsconfig.go | 129 +++++++++--- lib/cloud/awsconfig/awsconfig_test.go | 194 ++++++++++++------ lib/cloud/awsconfig/cache.go | 36 +++- lib/cloud/mocks/aws_config.go | 35 +++- lib/cloud/mocks/aws_sts.go | 19 +- lib/integrations/awsoidc/clientsv1.go | 3 - lib/srv/discovery/discovery.go | 34 +-- lib/srv/discovery/discovery_test.go | 45 ++-- lib/srv/discovery/fetchers/db/aws.go | 3 - lib/srv/discovery/fetchers/db/aws_redshift.go | 1 - lib/srv/discovery/fetchers/db/db.go | 22 +- 11 files changed, 349 insertions(+), 172 deletions(-) diff --git a/lib/cloud/awsconfig/awsconfig.go b/lib/cloud/awsconfig/awsconfig.go index 8be00483f4012..7b1cabe5ffe75 100644 --- a/lib/cloud/awsconfig/awsconfig.go +++ b/lib/cloud/awsconfig/awsconfig.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/trace" "go.opentelemetry.io/otel" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/modules" ) @@ -43,12 +44,25 @@ const ( credentialsSourceIntegration ) -// IntegrationSessionProviderFunc defines a function that creates a credential provider from a region and an integration. -// This is used to generate aws configs for clients that must use an integration instead of ambient credentials. -type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) +// OIDCIntegrationClient is an interface that indicates which APIs are +// required to generate an AWS OIDC integration token. +type OIDCIntegrationClient interface { + // GetIntegration returns the specified integration resource. + GetIntegration(ctx context.Context, name string) (types.Integration, error) -// AssumeRoleClientProviderFunc provides an AWS STS assume role API client. -type AssumeRoleClientProviderFunc func(aws.Config) stscreds.AssumeRoleAPIClient + // GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC + // Integration action. + GenerateAWSOIDCToken(ctx context.Context, integrationName string) (string, error) +} + +// STSClient is a subset of the AWS STS API. +type STSClient interface { + stscreds.AssumeRoleAPIClient + stscreds.AssumeRoleWithWebIdentityAPIClient +} + +// STSClientProviderFunc provides an AWS STS assume role API client. +type STSClientProviderFunc func(aws.Config) STSClient // AssumeRole is an AWS role to assume, optionally with an external ID. type AssumeRole struct { @@ -68,14 +82,16 @@ type options struct { credentialsSource credentialsSource // integration is the name of the integration to be used to fetch the credentials. integration string - // integrationCredentialsProvider is the integration credential provider to use. - integrationCredentialsProvider IntegrationCredentialProviderFunc + // oidcIntegrationClient provides APIs to generate AWS OIDC tokens, which + // can then be exchanged for IAM credentials. + // Required if integration credentials are requested. + oidcIntegrationClient OIDCIntegrationClient // customRetryer is a custom retryer to use for the config. customRetryer func() aws.Retryer // maxRetries is the maximum number of retries to use for the config. maxRetries *int - // assumeRoleClientProvider sets the STS assume role client provider func. - assumeRoleClientProvider AssumeRoleClientProviderFunc + // stsClientProvider sets the STS assume role client provider func. + stsClientProvider STSClientProviderFunc } func buildOptions(optFns ...OptionsFn) (*options, error) { @@ -99,6 +115,9 @@ func (o *options) checkAndSetDefaults() error { if o.integration == "" { return trace.BadParameter("missing integration name") } + if o.oidcIntegrationClient == nil { + return trace.BadParameter("missing AWS OIDC integration client") + } default: return trace.BadParameter("missing credentials source (ambient or integration)") } @@ -106,8 +125,8 @@ func (o *options) checkAndSetDefaults() error { return trace.BadParameter("role chain contains more than 2 roles") } - if o.assumeRoleClientProvider == nil { - o.assumeRoleClientProvider = func(cfg aws.Config) stscreds.AssumeRoleAPIClient { + if o.stsClientProvider == nil { + o.stsClientProvider = func(cfg aws.Config) STSClient { return sts.NewFromConfig(cfg, func(o *sts.Options) { o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider()) }) @@ -175,18 +194,17 @@ func WithAmbientCredentials() OptionsFn { } } -// WithIntegrationCredentialProvider sets the integration credential provider. -func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) OptionsFn { +// WithSTSClientProvider sets the STS API client factory func. +func WithSTSClientProvider(fn STSClientProviderFunc) OptionsFn { return func(options *options) { - options.integrationCredentialsProvider = cred + options.stsClientProvider = fn } } -// WithAssumeRoleClientProviderFunc sets the STS API client factory func used to -// assume roles. -func WithAssumeRoleClientProviderFunc(fn AssumeRoleClientProviderFunc) OptionsFn { +// WithOIDCIntegrationClient sets the OIDC integration client. +func WithOIDCIntegrationClient(c OIDCIntegrationClient) OptionsFn { return func(options *options) { - options.assumeRoleClientProvider = fn + options.oidcIntegrationClient = c } } @@ -202,7 +220,7 @@ func GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Con if err != nil { return aws.Config{}, trace.Wrap(err) } - return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.assumeRoleClientProvider) + return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.stsClientProvider) } // loadDefaultConfig loads a new config. @@ -217,6 +235,7 @@ func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *optio config.WithDefaultRegion(defaultRegion), config.WithRegion(region), config.WithCredentialsProvider(cred), + config.WithCredentialsCacheOptions(awsCredentialsCacheOptions), } if modules.GetModules().IsBoringBinary() { configOpts = append(configOpts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled)) @@ -232,27 +251,35 @@ func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *optio // getBaseConfig returns an AWS config without assuming any roles. func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) { - var cred aws.CredentialsProvider + slog.DebugContext(ctx, "Initializing AWS config from default credential chain", + "region", region, + ) + cfg, err := loadDefaultConfig(ctx, region, nil, opts) + if err != nil { + return aws.Config{}, trace.Wrap(err) + } + if opts.credentialsSource == credentialsSourceIntegration { - if opts.integrationCredentialsProvider == nil { - return aws.Config{}, trace.BadParameter("missing aws integration credential provider") + slog.DebugContext(ctx, "Initializing AWS config with OIDC integration credentials", + "region", region, + "integration", opts.integration, + ) + provider := &integrationCredentialsProvider{ + OIDCIntegrationClient: opts.oidcIntegrationClient, + stsClt: opts.stsClientProvider(cfg), + integrationName: opts.integration, } - - slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", opts.integration) - var err error - cred, err = opts.integrationCredentialsProvider(ctx, region, opts.integration) + cc := aws.NewCredentialsCache(provider, awsCredentialsCacheOptions) + _, err := cc.Retrieve(ctx) if err != nil { return aws.Config{}, trace.Wrap(err) } - } else { - slog.DebugContext(ctx, "Initializing AWS config from default credential chain", "region", region) + cfg.Credentials = cc } - - cfg, err := loadDefaultConfig(ctx, region, cred, opts) - return cfg, trace.Wrap(err) + return cfg, nil } -func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn AssumeRoleClientProviderFunc) (aws.Config, error) { +func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn STSClientProviderFunc) (aws.Config, error) { for _, r := range roles { cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r) } @@ -277,3 +304,41 @@ func getAssumeRoleProvider(ctx context.Context, clt stscreds.AssumeRoleAPIClient } }) } + +// staticIdentityToken provides itself as a JWT []byte token to implement +// [stscreds.IdentityTokenRetriever]. +type staticIdentityToken string + +// GetIdentityToken retrieves the JWT token. +func (t staticIdentityToken) GetIdentityToken() ([]byte, error) { + return []byte(t), nil +} + +// integrationCredentialsProvider provides AWS OIDC integration credentials. +type integrationCredentialsProvider struct { + OIDCIntegrationClient + stsClt STSClient + integrationName string +} + +// Retrieve provides [aws.Credentials] for an AWS OIDC integration. +func (p *integrationCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + integration, err := p.GetIntegration(ctx, p.integrationName) + if err != nil { + return aws.Credentials{}, trace.Wrap(err) + } + spec := integration.GetAWSOIDCIntegrationSpec() + if spec == nil { + return aws.Credentials{}, trace.BadParameter("invalid integration subkind, expected awsoidc, got %s", integration.GetSubKind()) + } + token, err := p.GenerateAWSOIDCToken(ctx, p.integrationName) + if err != nil { + return aws.Credentials{}, trace.Wrap(err) + } + cred, err := stscreds.NewWebIdentityRoleProvider( + p.stsClt, + spec.RoleARN, + staticIdentityToken(token), + ).Retrieve(ctx) + return cred, trace.Wrap(err) +} diff --git a/lib/cloud/awsconfig/awsconfig_test.go b/lib/cloud/awsconfig/awsconfig_test.go index 3cb2c4eda3123..2de624fe86c54 100644 --- a/lib/cloud/awsconfig/awsconfig_test.go +++ b/lib/cloud/awsconfig/awsconfig_test.go @@ -24,20 +24,13 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts" ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/gravitational/trace" "github.com/stretchr/testify/require" -) - -type mockCredentialProvider struct { - cred aws.Credentials -} -func (m *mockCredentialProvider) Retrieve(_ context.Context) (aws.Credentials, error) { - return m.cred, nil -} + "github.com/gravitational/teleport/api/types" +) type mockAssumeRoleAPIClient struct{} @@ -57,6 +50,18 @@ func (m *mockAssumeRoleAPIClient) AssumeRole(_ context.Context, params *sts.Assu }, nil } +func (m *mockAssumeRoleAPIClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + expiry := time.Now().Add(60 * time.Minute) + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &ststypes.Credentials{ + AccessKeyId: in.RoleArn, + SecretAccessKey: in.WebIdentityToken, + SessionToken: aws.String("token"), + Expiration: &expiry, + }, + }, nil +} + func TestGetConfigIntegration(t *testing.T) { t.Parallel() @@ -86,32 +91,100 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { dummyIntegration := "integration-test" dummyRegion := "test-region-123" - t.Run("without an integration credential provider, must return missing credential provider error", func(t *testing.T) { + awsOIDCIntegration, err := types.NewIntegrationAWSOIDC( + types.Metadata{Name: "integration-test"}, + &types.AWSOIDCIntegrationSpecV1{ + RoleARN: "arn:aws:sts::123456789012:role/TestRole", + }, + ) + require.NoError(t, err) + fakeIntegrationClt := fakeOIDCIntegrationClient{ + getIntegrationFn: func(context.Context, string) (types.Integration, error) { + return awsOIDCIntegration, nil + }, + getTokenFn: func(context.Context, string) (string, error) { + return "oidc-token", nil + }, + } + + stsClt := func(cfg aws.Config) STSClient { + return &mockAssumeRoleAPIClient{} + } + + t.Run("without an integration client, must return missing credential provider error", func(t *testing.T) { ctx := context.Background() _, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration)) require.True(t, trace.IsBadParameter(err), "unexpected error: %v", err) - require.ErrorContains(t, err, "missing aws integration credential provider") + require.ErrorContains(t, err, "missing AWS OIDC integration client") + }) + + t.Run("with an integration client, must return integration fetch error", func(t *testing.T) { + ctx := context.Background() + + fakeIntegrationClt := fakeIntegrationClt + fakeIntegrationClt.getIntegrationFn = func(context.Context, string) (types.Integration, error) { + return nil, trace.NotFound("integration not found") + } + _, err := provider.GetConfig(ctx, dummyRegion, + WithCredentialsMaybeIntegration(dummyIntegration), + WithOIDCIntegrationClient(&fakeIntegrationClt), + WithSTSClientProvider(stsClt), + ) + require.Error(t, err) + require.ErrorContains(t, err, "integration not found") + }) + + t.Run("with an integration client, must check for AWS integration subkind", func(t *testing.T) { + ctx := context.Background() + + azureIntegration, err := types.NewIntegrationAzureOIDC( + types.Metadata{Name: "integration-test"}, + &types.AzureOIDCIntegrationSpecV1{ + TenantID: "abc", + ClientID: "123", + }, + ) + require.NoError(t, err) + fakeIntegrationClt := fakeIntegrationClt + fakeIntegrationClt.getIntegrationFn = func(context.Context, string) (types.Integration, error) { + return azureIntegration, nil + } + _, err = provider.GetConfig(ctx, dummyRegion, + WithCredentialsMaybeIntegration(dummyIntegration), + WithOIDCIntegrationClient(&fakeIntegrationClt), + WithSTSClientProvider(stsClt), + ) + require.Error(t, err) + require.ErrorContains(t, err, "invalid integration subkind") + }) + + t.Run("with an integration client, must return token generation errors", func(t *testing.T) { + ctx := context.Background() + fakeIntegrationClt := fakeIntegrationClt + fakeIntegrationClt.getTokenFn = func(context.Context, string) (string, error) { + return "", trace.BadParameter("failed to generate OIDC token") + } + _, err = provider.GetConfig(ctx, dummyRegion, + WithCredentialsMaybeIntegration(dummyIntegration), + WithOIDCIntegrationClient(&fakeIntegrationClt), + WithSTSClientProvider(stsClt), + ) + require.Error(t, err) + require.ErrorContains(t, err, "failed to generate OIDC token") }) - t.Run("with an integration credential provider, must return the credentials", func(t *testing.T) { + t.Run("with an integration client, must return the credentials", func(t *testing.T) { ctx := context.Background() cfg, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration), - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - if region == dummyRegion && integration == dummyIntegration { - return &mockCredentialProvider{ - cred: aws.Credentials{ - SessionToken: "foo-bar", - }, - }, nil - } - return nil, trace.NotFound("no creds in region %q with integration %q", region, integration) - })) + WithOIDCIntegrationClient(&fakeIntegrationClt), + WithSTSClientProvider(stsClt), + ) require.NoError(t, err) creds, err := cfg.Credentials.Retrieve(ctx) require.NoError(t, err) - require.Equal(t, "foo-bar", creds.SessionToken) + require.Equal(t, "oidc-token", creds.SecretAccessKey) }) t.Run("with an integration credential provider assuming a role, must return assumed role credentials", func(t *testing.T) { @@ -119,23 +192,9 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { cfg, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration), - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - if region == dummyRegion && integration == dummyIntegration { - return &mockCredentialProvider{ - cred: aws.Credentials{ - SessionToken: "foo-bar", - }, - }, nil - } - return nil, trace.NotFound("no creds in region %q with integration %q", region, integration) - }), + WithOIDCIntegrationClient(&fakeIntegrationClt), WithAssumeRole("roleA", "abc123"), - WithAssumeRoleClientProviderFunc(func(cfg aws.Config) stscreds.AssumeRoleAPIClient { - creds, err := cfg.Credentials.Retrieve(context.Background()) - require.NoError(t, err) - require.Equal(t, "foo-bar", creds.SessionToken) - return &mockAssumeRoleAPIClient{} - }), + WithSTSClientProvider(stsClt), ) require.NoError(t, err) creds, err := cfg.Credentials.Retrieve(ctx) @@ -148,25 +207,11 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { ctx := context.Background() _, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration), - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - if region == dummyRegion && integration == dummyIntegration { - return &mockCredentialProvider{ - cred: aws.Credentials{ - SessionToken: "foo-bar", - }, - }, nil - } - return nil, trace.NotFound("no creds in region %q with integration %q", region, integration) - }), + WithOIDCIntegrationClient(&fakeIntegrationClt), WithAssumeRole("roleA", "abc123"), WithAssumeRole("roleB", "abc123"), WithAssumeRole("roleC", "abc123"), - WithAssumeRoleClientProviderFunc(func(cfg aws.Config) stscreds.AssumeRoleAPIClient { - creds, err := cfg.Credentials.Retrieve(context.Background()) - require.NoError(t, err) - require.Equal(t, "foo-bar", creds.SessionToken) - return &mockAssumeRoleAPIClient{} - }), + WithSTSClientProvider(stsClt), ) require.Error(t, err) require.ErrorContains(t, err, "role chain contains more than 2 roles") @@ -177,10 +222,8 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { _, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(""), - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - require.Fail(t, "this function should not be called") - return nil, nil - })) + WithOIDCIntegrationClient(&fakeOIDCIntegrationClient{unauth: true}), + ) require.NoError(t, err) }) @@ -189,10 +232,8 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { _, err := provider.GetConfig(ctx, dummyRegion, WithAmbientCredentials(), - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - require.Fail(t, "this function should not be called") - return nil, nil - })) + WithOIDCIntegrationClient(&fakeOIDCIntegrationClient{unauth: true}), + ) require.NoError(t, err) }) @@ -200,10 +241,8 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { ctx := context.Background() _, err := provider.GetConfig(ctx, dummyRegion, - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - require.Fail(t, "this function should not be called") - return nil, nil - })) + WithOIDCIntegrationClient(&fakeOIDCIntegrationClient{unauth: true}), + ) require.Error(t, err) require.ErrorContains(t, err, "missing credentials source") }) @@ -221,3 +260,24 @@ func TestNewCacheKey(t *testing.T) { `) require.Equal(t, want, got) } + +type fakeOIDCIntegrationClient struct { + unauth bool + + getIntegrationFn func(context.Context, string) (types.Integration, error) + getTokenFn func(context.Context, string) (string, error) +} + +func (f *fakeOIDCIntegrationClient) GetIntegration(ctx context.Context, name string) (types.Integration, error) { + if f.unauth { + return nil, trace.AccessDenied("unauthorized") + } + return f.getIntegrationFn(ctx, name) +} + +func (f *fakeOIDCIntegrationClient) GenerateAWSOIDCToken(ctx context.Context, integrationName string) (string, error) { + if f.unauth { + return "", trace.AccessDenied("unauthorized") + } + return f.getTokenFn(ctx, integrationName) +} diff --git a/lib/cloud/awsconfig/cache.go b/lib/cloud/awsconfig/cache.go index 3d664ba04c350..cdb315703212a 100644 --- a/lib/cloud/awsconfig/cache.go +++ b/lib/cloud/awsconfig/cache.go @@ -36,10 +36,23 @@ func awsCredentialsCacheOptions(opts *aws.CredentialsCacheOptions) { // role. type Cache struct { awsConfigCache *utils.FnCache + defaultOptions []OptionsFn +} + +// CacheOption is an option func for setting additional options when creating +// a new config cache. +type CacheOption func(*Cache) + +// WithDefaults is a [CacheOption] function that sets default [OptionsFn] to +// use when getting AWS config. +func WithDefaults(optFns ...OptionsFn) CacheOption { + return func(c *Cache) { + c.defaultOptions = optFns + } } // NewCache returns a new [Cache]. -func NewCache() (*Cache, error) { +func NewCache(optFns ...CacheOption) (*Cache, error) { c, err := utils.NewFnCache(utils.FnCacheConfig{ TTL: 15 * time.Minute, ReloadOnErr: true, @@ -47,14 +60,27 @@ func NewCache() (*Cache, error) { if err != nil { return nil, trace.Wrap(err) } - return &Cache{ + cache := &Cache{ awsConfigCache: c, - }, nil + } + for _, fn := range optFns { + fn(cache) + } + return cache, nil +} + +// withDefaultOptions prepends default options to the given option funcs, +// providing for default cache options and per-call options. +func (c *Cache) withDefaultOptions(optFns []OptionsFn) []OptionsFn { + if c.defaultOptions != nil { + return append(c.defaultOptions, optFns...) + } + return optFns } // GetConfig returns an [aws.Config] for the given region and options. func (c *Cache) GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) { - opts, err := buildOptions(optFns...) + opts, err := buildOptions(c.withDefaultOptions(optFns)...) if err != nil { return aws.Config{}, trace.Wrap(err) } @@ -112,7 +138,7 @@ func (c *Cache) getConfigForRoleChain(ctx context.Context, cfg aws.Config, opts } credProvider, err := utils.FnCacheGet(ctx, c.awsConfigCache, cacheKey, func(ctx context.Context) (aws.CredentialsProvider, error) { - clt := opts.assumeRoleClientProvider(cfg) + clt := opts.stsClientProvider(cfg) credProvider := getAssumeRoleProvider(ctx, clt, r) cc := aws.NewCredentialsCache(credProvider, awsCredentialsCacheOptions, diff --git a/lib/cloud/mocks/aws_config.go b/lib/cloud/mocks/aws_config.go index 7edadf80a9e20..b52dfbd36d74a 100644 --- a/lib/cloud/mocks/aws_config.go +++ b/lib/cloud/mocks/aws_config.go @@ -22,12 +22,15 @@ import ( "context" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud/awsconfig" ) type AWSConfigProvider struct { - STSClient *STSClient + STSClient *STSClient + OIDCIntegrationClient awsconfig.OIDCIntegrationClient } func (f *AWSConfigProvider) GetConfig(ctx context.Context, region string, optFns ...awsconfig.OptionsFn) (aws.Config, error) { @@ -35,8 +38,32 @@ func (f *AWSConfigProvider) GetConfig(ctx context.Context, region string, optFns if stsClt == nil { stsClt = &STSClient{} } - optFns = append(optFns, awsconfig.WithAssumeRoleClientProviderFunc( - newAssumeRoleClientProviderFunc(stsClt), - )) + optFns = append(optFns, + awsconfig.WithOIDCIntegrationClient(f.OIDCIntegrationClient), + awsconfig.WithSTSClientProvider( + newAssumeRoleClientProviderFunc(stsClt), + ), + ) return awsconfig.GetConfig(ctx, region, optFns...) } + +type FakeOIDCIntegrationClient struct { + Unauth bool + + Integration types.Integration + Token string +} + +func (f *FakeOIDCIntegrationClient) GetIntegration(ctx context.Context, name string) (types.Integration, error) { + if f.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + return f.Integration, nil +} + +func (f *FakeOIDCIntegrationClient) GenerateAWSOIDCToken(ctx context.Context, integrationName string) (string, error) { + if f.Unauth { + return "", trace.AccessDenied("unauthorized") + } + return f.Token, nil +} diff --git a/lib/cloud/mocks/aws_sts.go b/lib/cloud/mocks/aws_sts.go index 713de480ebf86..178a1259669a4 100644 --- a/lib/cloud/mocks/aws_sts.go +++ b/lib/cloud/mocks/aws_sts.go @@ -54,7 +54,20 @@ type STSClient struct { recordFn func(roleARN, externalID string) } -func (m *STSClient) AssumeRole(ctx context.Context, in *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { +func (m *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + m.record(aws.ToString(in.RoleArn), "") + expiry := time.Now().Add(60 * time.Minute) + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &ststypes.Credentials{ + AccessKeyId: in.RoleArn, + SecretAccessKey: aws.String("secret"), + SessionToken: aws.String("token"), + Expiration: &expiry, + }, + }, nil +} + +func (m *STSClient) AssumeRole(ctx context.Context, in *sts.AssumeRoleInput, _ ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { // Retrieve credentials if we have a credential provider, so that all // assume-role providers in a role chain are triggered to call AssumeRole. if m.credentialProvider != nil { @@ -93,8 +106,8 @@ func (m *STSClient) record(roleARN, externalID string) { } } -func newAssumeRoleClientProviderFunc(base *STSClient) awsconfig.AssumeRoleClientProviderFunc { - return func(cfg aws.Config) stscreds.AssumeRoleAPIClient { +func newAssumeRoleClientProviderFunc(base *STSClient) awsconfig.STSClientProviderFunc { + return func(cfg aws.Config) awsconfig.STSClient { if cfg.Credentials != nil { if _, ok := cfg.Credentials.(*stscreds.AssumeRoleProvider); ok { // Create a new fake client linked to the old one. diff --git a/lib/integrations/awsoidc/clientsv1.go b/lib/integrations/awsoidc/clientsv1.go index 8c16f4c66156a..ae2e0be6a186b 100644 --- a/lib/integrations/awsoidc/clientsv1.go +++ b/lib/integrations/awsoidc/clientsv1.go @@ -44,9 +44,6 @@ type IntegrationTokenGenerator interface { // GetIntegration returns the specified integration resources. GetIntegration(ctx context.Context, name string) (types.Integration, error) - // GetProxies returns a list of registered proxies. - GetProxies() ([]types.Server, error) - // GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC Integration action. GenerateAWSOIDCToken(ctx context.Context, integration string) (string, error) } diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index 28690130d51a7..f37ba025d2450 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -224,7 +224,11 @@ kubernetes matchers are present.`) c.CloudClients = cloudClients } if c.AWSConfigProvider == nil { - provider, err := awsconfig.NewCache() + provider, err := awsconfig.NewCache( + awsconfig.WithDefaults( + awsconfig.WithOIDCIntegrationClient(c.AccessPoint), + ), + ) if err != nil { return trace.Wrap(err, "unable to create AWS config provider cache") } @@ -232,9 +236,8 @@ kubernetes matchers are present.`) } if c.AWSDatabaseFetcherFactory == nil { factory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ - CloudClients: c.CloudClients, - AWSConfigProvider: c.AWSConfigProvider, - IntegrationCredentialProviderFn: c.getIntegrationCredentialProviderFn(), + CloudClients: c.CloudClients, + AWSConfigProvider: c.AWSConfigProvider, }) if err != nil { return trace.Wrap(err) @@ -312,33 +315,10 @@ kubernetes matchers are present.`) } func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (aws.Config, error) { - opts = append(opts, awsconfig.WithIntegrationCredentialProvider(c.getIntegrationCredentialProviderFn())) cfg, err := c.AWSConfigProvider.GetConfig(ctx, region, opts...) return cfg, trace.Wrap(err) } -func (c *Config) getIntegrationCredentialProviderFn() awsconfig.IntegrationCredentialProviderFunc { - return func(ctx context.Context, region, integrationName string) (aws.CredentialsProvider, error) { - integration, err := c.AccessPoint.GetIntegration(ctx, integrationName) - if err != nil { - return nil, trace.Wrap(err) - } - if integration.GetAWSOIDCIntegrationSpec() == nil { - return nil, trace.BadParameter("integration does not have aws oidc spec fields %q", integrationName) - } - token, err := c.AccessPoint.GenerateAWSOIDCToken(ctx, integrationName) - if err != nil { - return nil, trace.Wrap(err) - } - cred, err := awsoidc.NewAWSCredentialsProvider(ctx, &awsoidc.AWSClientRequest{ - Token: token, - RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN, - Region: region, - }) - return cred, trace.Wrap(err) - } -} - // Server is a discovery server, used to discover cloud resources for // inclusion in Teleport type Server struct { diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index f3c387a475932..865517ba4c33c 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -37,7 +37,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3" awsv2 "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/redshift" @@ -2032,18 +2031,6 @@ func TestDiscoveryDatabase(t *testing.T) { Clusters: []*eks.Cluster{eksAWSResource}, }, } - fakeConfigProvider := &mocks.AWSConfigProvider{} - dbFetcherFactory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ - AWSConfigProvider: fakeConfigProvider, - CloudClients: testCloudClients, - IntegrationCredentialProviderFn: func(_ context.Context, _, _ string) (awsv2.CredentialsProvider, error) { - return credentials.NewStaticCredentialsProvider("key", "secret", "session"), nil - }, - RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ - Clusters: []redshifttypes.Cluster{*awsRedshiftResource}, - }), - }) - require.NoError(t, err) tcs := []struct { name string @@ -2334,6 +2321,23 @@ func TestDiscoveryDatabase(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, tlsServer.Close()) }) + awsOIDCIntegration, err := types.NewIntegrationAWSOIDC(types.Metadata{ + Name: integrationName, + }, &types.AWSOIDCIntegrationSpecV1{ + RoleARN: "arn:aws:iam::123456789012:role/teleport", + }) + require.NoError(t, err) + + testAuthServer.AuthServer.IntegrationsTokenGenerator = &mockIntegrationsTokenGenerator{ + proxies: nil, + integrations: map[string]types.Integration{ + awsOIDCIntegration.GetName(): awsOIDCIntegration, + }, + } + + _, err = tlsServer.Auth().CreateIntegration(ctx, awsOIDCIntegration) + require.NoError(t, err) + // Auth client for discovery service. identity := auth.TestServerID(types.RoleDiscovery, "hostID") authClient, err := tlsServer.NewClient(identity) @@ -2349,6 +2353,19 @@ func TestDiscoveryDatabase(t *testing.T) { waitForReconcile := make(chan struct{}) reporter := &mockUsageReporter{} tlsServer.Auth().SetUsageReporter(reporter) + accessPoint := getDiscoveryAccessPoint(tlsServer.Auth(), authClient) + fakeConfigProvider := &mocks.AWSConfigProvider{ + OIDCIntegrationClient: accessPoint, + } + dbFetcherFactory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ + AWSConfigProvider: fakeConfigProvider, + CloudClients: testCloudClients, + RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{*awsRedshiftResource}, + }), + }) + require.NoError(t, err) + srv, err := New( authz.ContextWithUser(ctx, identity.I), &Config{ @@ -2358,7 +2375,7 @@ func TestDiscoveryDatabase(t *testing.T) { AWSConfigProvider: fakeConfigProvider, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewSimpleClientset(), - AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), + AccessPoint: accessPoint, Matchers: Matchers{ AWS: tc.awsMatchers, Azure: tc.azureMatchers, diff --git a/lib/srv/discovery/fetchers/db/aws.go b/lib/srv/discovery/fetchers/db/aws.go index f87e0e9a6c443..d6d70912d7092 100644 --- a/lib/srv/discovery/fetchers/db/aws.go +++ b/lib/srv/discovery/fetchers/db/aws.go @@ -55,9 +55,6 @@ type awsFetcherConfig struct { AWSClients cloud.AWSClients // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. AWSConfigProvider awsconfig.Provider - // IntegrationCredentialProviderFn is a required function that provides - // credentials via AWS OIDC integration. - IntegrationCredentialProviderFn awsconfig.IntegrationCredentialProviderFunc // Type is the type of DB matcher, for example "rds", "redshift", etc. Type string // AssumeRole provides a role ARN and ExternalID to assume an AWS role diff --git a/lib/srv/discovery/fetchers/db/aws_redshift.go b/lib/srv/discovery/fetchers/db/aws_redshift.go index 508cb6e8810f1..0cda0b478e67b 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift.go @@ -53,7 +53,6 @@ func (f *redshiftPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), - awsconfig.WithIntegrationCredentialProvider(cfg.IntegrationCredentialProviderFn), ) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/discovery/fetchers/db/db.go b/lib/srv/discovery/fetchers/db/db.go index 3ef56532d90af..8d79bc2bb65bc 100644 --- a/lib/srv/discovery/fetchers/db/db.go +++ b/lib/srv/discovery/fetchers/db/db.go @@ -73,9 +73,6 @@ type AWSFetcherFactoryConfig struct { AWSConfigProvider awsconfig.Provider // CloudClients is an interface for retrieving AWS SDK v1 cloud clients. CloudClients cloud.AWSClients - // IntegrationCredentialProviderFn is an optional function that provides - // credentials via AWS OIDC integration. - IntegrationCredentialProviderFn awsconfig.IntegrationCredentialProviderFunc // RedshiftClientProviderFn is an optional function that provides RedshiftClientProviderFn RedshiftClientProviderFunc } @@ -128,16 +125,15 @@ func (f *AWSFetcherFactory) MakeFetchers(ctx context.Context, matchers []types.A for _, makeFetcher := range makeFetchers { for _, region := range matcher.Regions { fetcher, err := makeFetcher(awsFetcherConfig{ - AWSClients: f.cfg.CloudClients, - Type: matcherType, - AssumeRole: assumeRole, - Labels: matcher.Tags, - Region: region, - Integration: matcher.Integration, - DiscoveryConfigName: discoveryConfigName, - AWSConfigProvider: f.cfg.AWSConfigProvider, - IntegrationCredentialProviderFn: f.cfg.IntegrationCredentialProviderFn, - redshiftClientProviderFn: f.cfg.RedshiftClientProviderFn, + AWSClients: f.cfg.CloudClients, + Type: matcherType, + AssumeRole: assumeRole, + Labels: matcher.Tags, + Region: region, + Integration: matcher.Integration, + DiscoveryConfigName: discoveryConfigName, + AWSConfigProvider: f.cfg.AWSConfigProvider, + redshiftClientProviderFn: f.cfg.RedshiftClientProviderFn, }) if err != nil { return nil, trace.Wrap(err)