Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify awsconfig loading #50809

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 97 additions & 32 deletions lib/cloud/awsconfig/awsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/gravitational/trace"
"go.opentelemetry.io/otel"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/modules"
)

Expand All @@ -43,12 +44,25 @@ const (
credentialsSourceIntegration
)

// IntegrationSessionProviderFunc defines a function that creates a credential provider from a region and an integration.
// This is used to generate aws configs for clients that must use an integration instead of ambient credentials.
type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error)
// OIDCIntegrationClient is an interface that indicates which APIs are
// required to generate an AWS OIDC integration token.
type OIDCIntegrationClient interface {
// GetIntegration returns the specified integration resource.
GetIntegration(ctx context.Context, name string) (types.Integration, error)

// AssumeRoleClientProviderFunc provides an AWS STS assume role API client.
type AssumeRoleClientProviderFunc func(aws.Config) stscreds.AssumeRoleAPIClient
// GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC
// Integration action.
GenerateAWSOIDCToken(ctx context.Context, integrationName string) (string, error)
}

// STSClient is a subset of the AWS STS API.
type STSClient interface {
stscreds.AssumeRoleAPIClient
stscreds.AssumeRoleWithWebIdentityAPIClient
}

// STSClientProviderFunc provides an AWS STS assume role API client.
type STSClientProviderFunc func(aws.Config) STSClient

// AssumeRole is an AWS role to assume, optionally with an external ID.
type AssumeRole struct {
Expand All @@ -68,14 +82,16 @@ type options struct {
credentialsSource credentialsSource
// integration is the name of the integration to be used to fetch the credentials.
integration string
// integrationCredentialsProvider is the integration credential provider to use.
integrationCredentialsProvider IntegrationCredentialProviderFunc
// oidcIntegrationClient provides APIs to generate AWS OIDC tokens, which
// can then be exchanged for IAM credentials.
// Required if integration credentials are requested.
oidcIntegrationClient OIDCIntegrationClient
// customRetryer is a custom retryer to use for the config.
customRetryer func() aws.Retryer
// maxRetries is the maximum number of retries to use for the config.
maxRetries *int
// assumeRoleClientProvider sets the STS assume role client provider func.
assumeRoleClientProvider AssumeRoleClientProviderFunc
// stsClientProvider sets the STS assume role client provider func.
stsClientProvider STSClientProviderFunc
}

func buildOptions(optFns ...OptionsFn) (*options, error) {
Expand All @@ -99,15 +115,18 @@ func (o *options) checkAndSetDefaults() error {
if o.integration == "" {
return trace.BadParameter("missing integration name")
}
if o.oidcIntegrationClient == nil {
return trace.BadParameter("missing AWS OIDC integration client")
}
default:
return trace.BadParameter("missing credentials source (ambient or integration)")
}
if len(o.assumeRoles) > 2 {
return trace.BadParameter("role chain contains more than 2 roles")
}

if o.assumeRoleClientProvider == nil {
o.assumeRoleClientProvider = func(cfg aws.Config) stscreds.AssumeRoleAPIClient {
if o.stsClientProvider == nil {
o.stsClientProvider = func(cfg aws.Config) STSClient {
return sts.NewFromConfig(cfg, func(o *sts.Options) {
o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider())
})
Expand Down Expand Up @@ -175,18 +194,17 @@ func WithAmbientCredentials() OptionsFn {
}
}

// WithIntegrationCredentialProvider sets the integration credential provider.
func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) OptionsFn {
// WithSTSClientProvider sets the STS API client factory func.
func WithSTSClientProvider(fn STSClientProviderFunc) OptionsFn {
return func(options *options) {
options.integrationCredentialsProvider = cred
options.stsClientProvider = fn
}
}

// WithAssumeRoleClientProviderFunc sets the STS API client factory func used to
// assume roles.
func WithAssumeRoleClientProviderFunc(fn AssumeRoleClientProviderFunc) OptionsFn {
// WithOIDCIntegrationClient sets the OIDC integration client.
func WithOIDCIntegrationClient(c OIDCIntegrationClient) OptionsFn {
return func(options *options) {
options.assumeRoleClientProvider = fn
options.oidcIntegrationClient = c
}
}

Expand All @@ -202,7 +220,7 @@ func GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Con
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.assumeRoleClientProvider)
return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.stsClientProvider)
}

// loadDefaultConfig loads a new config.
Expand All @@ -217,6 +235,7 @@ func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *optio
config.WithDefaultRegion(defaultRegion),
config.WithRegion(region),
config.WithCredentialsProvider(cred),
config.WithCredentialsCacheOptions(awsCredentialsCacheOptions),
}
if modules.GetModules().IsBoringBinary() {
configOpts = append(configOpts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
Expand All @@ -232,27 +251,35 @@ func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *optio

// getBaseConfig returns an AWS config without assuming any roles.
func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) {
var cred aws.CredentialsProvider
slog.DebugContext(ctx, "Initializing AWS config from default credential chain",
"region", region,
)
cfg, err := loadDefaultConfig(ctx, region, nil, opts)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}

if opts.credentialsSource == credentialsSourceIntegration {
if opts.integrationCredentialsProvider == nil {
return aws.Config{}, trace.BadParameter("missing aws integration credential provider")
slog.DebugContext(ctx, "Initializing AWS config with OIDC integration credentials",
"region", region,
"integration", opts.integration,
)
provider := &integrationCredentialsProvider{
OIDCIntegrationClient: opts.oidcIntegrationClient,
stsClt: opts.stsClientProvider(cfg),
integrationName: opts.integration,
}

slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", opts.integration)
var err error
cred, err = opts.integrationCredentialsProvider(ctx, region, opts.integration)
cc := aws.NewCredentialsCache(provider, awsCredentialsCacheOptions)
_, err := cc.Retrieve(ctx)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
} else {
slog.DebugContext(ctx, "Initializing AWS config from default credential chain", "region", region)
cfg.Credentials = cc
}

cfg, err := loadDefaultConfig(ctx, region, cred, opts)
return cfg, trace.Wrap(err)
return cfg, nil
}

func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn AssumeRoleClientProviderFunc) (aws.Config, error) {
func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn STSClientProviderFunc) (aws.Config, error) {
for _, r := range roles {
cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r)
}
Expand All @@ -277,3 +304,41 @@ func getAssumeRoleProvider(ctx context.Context, clt stscreds.AssumeRoleAPIClient
}
})
}

// staticIdentityToken provides itself as a JWT []byte token to implement
// [stscreds.IdentityTokenRetriever].
type staticIdentityToken string

// GetIdentityToken retrieves the JWT token.
func (t staticIdentityToken) GetIdentityToken() ([]byte, error) {
return []byte(t), nil
}

// integrationCredentialsProvider provides AWS OIDC integration credentials.
type integrationCredentialsProvider struct {
OIDCIntegrationClient
stsClt STSClient
integrationName string
}

// Retrieve provides [aws.Credentials] for an AWS OIDC integration.
func (p *integrationCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
integration, err := p.GetIntegration(ctx, p.integrationName)
if err != nil {
return aws.Credentials{}, trace.Wrap(err)
}
spec := integration.GetAWSOIDCIntegrationSpec()
if spec == nil {
return aws.Credentials{}, trace.BadParameter("invalid integration subkind, expected awsoidc, got %s", integration.GetSubKind())
}
token, err := p.GenerateAWSOIDCToken(ctx, p.integrationName)
if err != nil {
return aws.Credentials{}, trace.Wrap(err)
}
cred, err := stscreds.NewWebIdentityRoleProvider(
p.stsClt,
spec.RoleARN,
staticIdentityToken(token),
).Retrieve(ctx)
return cred, trace.Wrap(err)
}
Loading
Loading