From 5e48a04df3a48f740b14a1c4d355623bf326f254 Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Thu, 22 Feb 2024 10:54:08 -0500 Subject: [PATCH] [v14] Improve error message when auto-user provisioning fails on reader endpoints (#37954) * Improve error message when auto-user provisioning fails on reader endpoints * fix lint --- api/types/database.go | 6 ++++ api/types/database_test.go | 53 +++++++++++++++++++++++++++++----- api/utils/aws/endpoint.go | 36 ++++++++++++++++++++--- api/utils/aws/endpoint_test.go | 26 +++++++++++++---- lib/services/database.go | 43 +++++++++------------------ lib/srv/db/mysql/autousers.go | 6 ++++ lib/srv/db/postgres/users.go | 6 ++++ 7 files changed, 130 insertions(+), 46 deletions(-) diff --git a/api/types/database.go b/api/types/database.go index 8548e743211ea..18a9d3bb04c7f 100644 --- a/api/types/database.go +++ b/api/types/database.go @@ -1017,6 +1017,12 @@ func (d *DatabaseV3) GetEndpointType() string { return d.GetAWS().MemoryDB.EndpointType case DatabaseTypeOpenSearch: return d.GetAWS().OpenSearch.EndpointType + case DatabaseTypeRDS: + // If not available from discovery tags, get the endpoint type from the + // URL. + if details, err := awsutils.ParseRDSEndpoint(d.GetURI()); err == nil { + return details.EndpointType + } } return "" } diff --git a/api/types/database_test.go b/api/types/database_test.go index 3ca2d4cdbf69c..41a5eabe651eb 100644 --- a/api/types/database_test.go +++ b/api/types/database_test.go @@ -35,10 +35,12 @@ func TestDatabaseRDSEndpoint(t *testing.T) { } for _, tt := range []struct { - name string - spec DatabaseSpecV3 - errorCheck require.ErrorAssertionFunc - expectedAWS AWS + name string + labels map[string]string + spec DatabaseSpecV3 + errorCheck require.ErrorAssertionFunc + expectedAWS AWS + expectedEndpointType string }{ { name: "aurora instance", @@ -53,6 +55,7 @@ func TestDatabaseRDSEndpoint(t *testing.T) { InstanceID: "aurora-instance-1", }, }, + expectedEndpointType: "instance", }, { name: "invalid account id", @@ -69,7 +72,7 @@ func TestDatabaseRDSEndpoint(t *testing.T) { name: "valid account id", spec: DatabaseSpecV3{ Protocol: "postgres", - URI: "marcotest-db001.abcdefghijklmnop.us-east-1.rds.amazonaws.com:5432", + URI: "marcotest-db001.cluster-ro-abcdefghijklmnop.us-east-1.rds.amazonaws.com:5432", AWS: AWS{ AccountID: "123456789012", }, @@ -78,17 +81,52 @@ func TestDatabaseRDSEndpoint(t *testing.T) { expectedAWS: AWS{ Region: "us-east-1", RDS: RDS{ - InstanceID: "marcotest-db001", + ClusterID: "marcotest-db001", }, AccountID: "123456789012", }, + expectedEndpointType: "reader", + }, + { + name: "discovered instance", + labels: map[string]string{ + "account-id": "123456789012", + "endpoint-type": "primary", + "engine": "aurora-postgresql", + "engine-version": "15.2", + "region": "us-west-1", + "teleport.dev/cloud": "AWS", + "teleport.dev/origin": "cloud", + "teleport.internal/discovered-name": "rds", + }, + spec: DatabaseSpecV3{ + Protocol: "postgres", + URI: "discovered.rds.com:5432", + AWS: AWS{ + Region: "us-west-1", + RDS: RDS{ + InstanceID: "aurora-instance-1", + IAMAuth: true, + }, + }, + }, + errorCheck: require.NoError, + expectedAWS: AWS{ + Region: "us-west-1", + RDS: RDS{ + InstanceID: "aurora-instance-1", + IAMAuth: true, + }, + }, + expectedEndpointType: "primary", }, } { tt := tt t.Run(tt.name, func(t *testing.T) { database, err := NewDatabaseV3( Metadata{ - Name: "rds", + Labels: tt.labels, + Name: "rds", }, tt.spec, ) @@ -98,6 +136,7 @@ func TestDatabaseRDSEndpoint(t *testing.T) { } require.Equal(t, tt.expectedAWS, database.GetAWS()) + require.Equal(t, tt.expectedEndpointType, database.GetEndpointType()) }) } } diff --git a/api/utils/aws/endpoint.go b/api/utils/aws/endpoint.go index e7458cd8747cb..315658234c56b 100644 --- a/api/utils/aws/endpoint.go +++ b/api/utils/aws/endpoint.go @@ -93,6 +93,12 @@ type RDSEndpointDetails struct { ProxyCustomEndpointName string // Region is the AWS region the database resides in. Region string + // EndpointType specifies the type of the endpoint, if available. + // + // Note that the endpoint type of RDS Proxies are determined by their + // targets, so the endpoint type will be empty for RDS Proxies here as it + // cannot be decided by the endpoint URL itself. + EndpointType string } // IsProxy returns true if the RDS endpoint is an RDS Proxy. @@ -189,12 +195,21 @@ func parseRDSWithoutSuffixes(endpoint string, parts []string, region string) (*R return &RDSEndpointDetails{ ClusterCustomEndpointName: parts[0], Region: region, + EndpointType: RDSEndpointTypeCustom, + }, nil + + case strings.HasPrefix(parts[1], "cluster-ro-"): + return &RDSEndpointDetails{ + ClusterID: parts[0], + Region: region, + EndpointType: RDSEndpointTypeReader, }, nil case strings.HasPrefix(parts[1], "cluster-"): return &RDSEndpointDetails{ - ClusterID: parts[0], - Region: region, + ClusterID: parts[0], + Region: region, + EndpointType: RDSEndpointTypePrimary, }, nil case strings.HasPrefix(parts[1], "proxy-"): @@ -205,8 +220,9 @@ func parseRDSWithoutSuffixes(endpoint string, parts []string, region string) (*R default: return &RDSEndpointDetails{ - InstanceID: parts[0], - Region: region, + InstanceID: parts[0], + Region: region, + EndpointType: RDSEndpointTypeInstance, }, nil } @@ -364,6 +380,18 @@ const ( OpenSearchCustomEndpoint = "custom" // OpenSearchVPCEndpoint is the VPC endpoint for domain. OpenSearchVPCEndpoint = "vpc" + + // RDSEndpointTypePrimary is the endpoint that specifies the connection for + // the primary instance of the RDS cluster. + RDSEndpointTypePrimary = "primary" + // RDSEndpointTypeReader is the endpoint that load-balances connections + // across the Aurora Replicas that are available in an RDS cluster. + RDSEndpointTypeReader = "reader" + // RDSEndpointTypeCustom is the endpoint that specifies one of the custom + // endpoints associated with the RDS cluster. + RDSEndpointTypeCustom = "custom" + // RDSEndpointTypeInstance is the endpoint of an RDS DB instance. + RDSEndpointTypeInstance = "instance" ) // ParseElastiCacheEndpoint extracts the details from the provided diff --git a/api/utils/aws/endpoint_test.go b/api/utils/aws/endpoint_test.go index c627a75877be0..e2685b5862aa7 100644 --- a/api/utils/aws/endpoint_test.go +++ b/api/utils/aws/endpoint_test.go @@ -37,8 +37,9 @@ func TestParseRDSEndpoint(t *testing.T) { endpoint: "aurora-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", expectIsRDSEndpoint: true, expectDetails: &RDSEndpointDetails{ - InstanceID: "aurora-instance-1", - Region: "us-west-1", + InstanceID: "aurora-instance-1", + Region: "us-west-1", + EndpointType: "instance", }, }, { @@ -46,8 +47,9 @@ func TestParseRDSEndpoint(t *testing.T) { endpoint: "aurora-instance-2.abcdefghijklmnop.rds.cn-north-1.amazonaws.com.cn", expectIsRDSEndpoint: true, expectDetails: &RDSEndpointDetails{ - InstanceID: "aurora-instance-2", - Region: "cn-north-1", + InstanceID: "aurora-instance-2", + Region: "cn-north-1", + EndpointType: "instance", }, }, { @@ -55,8 +57,19 @@ func TestParseRDSEndpoint(t *testing.T) { endpoint: "my-cluster.cluster-abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", expectIsRDSEndpoint: true, expectDetails: &RDSEndpointDetails{ - ClusterID: "my-cluster", - Region: "us-west-1", + ClusterID: "my-cluster", + Region: "us-west-1", + EndpointType: "primary", + }, + }, + { + name: "RDS cluster reader", + endpoint: "my-cluster.cluster-ro-abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", + expectIsRDSEndpoint: true, + expectDetails: &RDSEndpointDetails{ + ClusterID: "my-cluster", + Region: "us-west-1", + EndpointType: "reader", }, }, { @@ -66,6 +79,7 @@ func TestParseRDSEndpoint(t *testing.T) { expectDetails: &RDSEndpointDetails{ ClusterCustomEndpointName: "my-custom", Region: "us-west-1", + EndpointType: "custom", }, }, { diff --git a/lib/services/database.go b/lib/services/database.go index 0c37c2f4c1ca2..243c912ade554 100644 --- a/lib/services/database.go +++ b/lib/services/database.go @@ -696,7 +696,7 @@ func labelsFromRDSV2Instance(rdsInstance *rdsTypesV2.DBInstance, meta *types.AWS labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsInstance.Engine) labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsInstance.EngineVersion) - labels[types.DiscoveryLabelEndpointType] = string(RDSEndpointTypeInstance) + labels[types.DiscoveryLabelEndpointType] = apiawsutils.RDSEndpointTypeInstance labels[types.DiscoveryLabelStatus] = aws.StringValue(rdsInstance.DBInstanceStatus) if rdsInstance.DBSubnetGroup != nil { labels[types.DiscoveryLabelVPCID] = aws.StringValue(rdsInstance.DBSubnetGroup.VpcId) @@ -723,7 +723,7 @@ func NewDatabaseFromRDSV2Cluster(cluster *rdsTypesV2.DBCluster, firstInstance *r return types.NewDatabaseV3( setAWSDBName(types.Metadata{ Description: fmt.Sprintf("Aurora cluster in %v", metadata.Region), - Labels: labelsFromRDSV2Cluster(cluster, metadata, RDSEndpointTypePrimary, firstInstance), + Labels: labelsFromRDSV2Cluster(cluster, metadata, apiawsutils.RDSEndpointTypePrimary, firstInstance), }, aws.StringValue(cluster.DBClusterIdentifier)), types.DatabaseSpecV3{ Protocol: protocol, @@ -780,11 +780,11 @@ func MetadataFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, rdsInstance *rds // labelsFromRDSV2Cluster creates database labels for the provided RDS cluster. // It uses aws sdk v2. -func labelsFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, meta *types.AWS, endpointType RDSEndpointType, memberInstance *rdsTypesV2.DBInstance) map[string]string { +func labelsFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, meta *types.AWS, endpointType string, memberInstance *rdsTypesV2.DBInstance) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsCluster.Engine) labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsCluster.EngineVersion) - labels[types.DiscoveryLabelEndpointType] = string(endpointType) + labels[types.DiscoveryLabelEndpointType] = endpointType labels[types.DiscoveryLabelStatus] = aws.StringValue(rdsCluster.Status) if memberInstance != nil && memberInstance.DBSubnetGroup != nil { labels[types.DiscoveryLabelVPCID] = aws.StringValue(memberInstance.DBSubnetGroup.VpcId) @@ -805,7 +805,7 @@ func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DB return types.NewDatabaseV3( setAWSDBName(types.Metadata{ Description: fmt.Sprintf("Aurora cluster in %v", metadata.Region), - Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypePrimary, memberInstances), + Labels: labelsFromRDSCluster(cluster, metadata, apiawsutils.RDSEndpointTypePrimary, memberInstances), }, aws.StringValue(cluster.DBClusterIdentifier)), types.DatabaseSpecV3{ Protocol: protocol, @@ -826,9 +826,9 @@ func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInsta } return types.NewDatabaseV3( setAWSDBName(types.Metadata{ - Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, string(RDSEndpointTypeReader)), - Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeReader, memberInstances), - }, aws.StringValue(cluster.DBClusterIdentifier), string(RDSEndpointTypeReader)), + Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, apiawsutils.RDSEndpointTypeReader), + Labels: labelsFromRDSCluster(cluster, metadata, apiawsutils.RDSEndpointTypeReader, memberInstances), + }, aws.StringValue(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeReader), types.DatabaseSpecV3{ Protocol: protocol, URI: fmt.Sprintf("%v:%v", aws.StringValue(cluster.ReaderEndpoint), aws.Int64Value(cluster.Port)), @@ -864,9 +864,9 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns database, err := types.NewDatabaseV3( setAWSDBName(types.Metadata{ - Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, string(RDSEndpointTypeCustom)), - Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeCustom, memberInstances), - }, aws.StringValue(cluster.DBClusterIdentifier), string(RDSEndpointTypeCustom), endpointDetails.ClusterCustomEndpointName), + Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, apiawsutils.RDSEndpointTypeCustom), + Labels: labelsFromRDSCluster(cluster, metadata, apiawsutils.RDSEndpointTypeCustom, memberInstances), + }, aws.StringValue(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeCustom, endpointDetails.ClusterCustomEndpointName), types.DatabaseSpecV3{ Protocol: protocol, URI: fmt.Sprintf("%v:%v", aws.StringValue(endpoint), aws.Int64Value(cluster.Port)), @@ -1631,7 +1631,7 @@ func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[str labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsInstance.Engine) labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsInstance.EngineVersion) - labels[types.DiscoveryLabelEndpointType] = string(RDSEndpointTypeInstance) + labels[types.DiscoveryLabelEndpointType] = apiawsutils.RDSEndpointTypeInstance if rdsInstance.DBSubnetGroup != nil { labels[types.DiscoveryLabelVPCID] = aws.StringValue(rdsInstance.DBSubnetGroup.VpcId) } @@ -1639,11 +1639,11 @@ func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[str } // labelsFromRDSCluster creates database labels for the provided RDS cluster. -func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType RDSEndpointType, memberInstances []*rds.DBInstance) map[string]string { +func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType string, memberInstances []*rds.DBInstance) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsCluster.Engine) labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsCluster.EngineVersion) - labels[types.DiscoveryLabelEndpointType] = string(endpointType) + labels[types.DiscoveryLabelEndpointType] = endpointType if len(memberInstances) > 0 && memberInstances[0].DBSubnetGroup != nil { labels[types.DiscoveryLabelVPCID] = aws.StringValue(memberInstances[0].DBSubnetGroup.VpcId) } @@ -2022,21 +2022,6 @@ const ( RDSEngineAuroraPostgres = "aurora-postgresql" ) -// RDSEndpointType specifies the endpoint type for RDS clusters. -type RDSEndpointType string - -const ( - // RDSEndpointTypePrimary is the endpoint that specifies the connection for the primary instance of the RDS cluster. - RDSEndpointTypePrimary RDSEndpointType = "primary" - // RDSEndpointTypeReader is the endpoint that load-balances connections across the Aurora Replicas that are - // available in an RDS cluster. - RDSEndpointTypeReader RDSEndpointType = "reader" - // RDSEndpointTypeCustom is the endpoint that specifies one of the custom endpoints associated with the RDS cluster. - RDSEndpointTypeCustom RDSEndpointType = "custom" - // RDSEndpointTypeInstance is the endpoint of an RDS DB instance. - RDSEndpointTypeInstance RDSEndpointType = "instance" -) - const ( // RDSEngineModeProvisioned is the RDS engine mode for provisioned Aurora clusters RDSEngineModeProvisioned = "provisioned" diff --git a/lib/srv/db/mysql/autousers.go b/lib/srv/db/mysql/autousers.go index d174d490b7008..26e8fffda5975 100644 --- a/lib/srv/db/mysql/autousers.go +++ b/lib/srv/db/mysql/autousers.go @@ -34,6 +34,7 @@ import ( "golang.org/x/exp/slices" "github.com/gravitational/teleport/api/types" + apiawsutils "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/srv/db/common" ) @@ -111,6 +112,11 @@ func (e *Engine) ActivateUser(ctx context.Context, sessionCtx *common.Session) e return trace.BadParameter("Teleport does not have admin user configured for this database") } + if sessionCtx.Database.IsRDS() && + sessionCtx.Database.GetEndpointType() == apiawsutils.RDSEndpointTypeReader { + return trace.BadParameter("auto-user provisioning is not supported for RDS reader endpoints") + } + conn, err := e.connectAsAdminUser(ctx, sessionCtx) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/db/postgres/users.go b/lib/srv/db/postgres/users.go index 8a94f5dedcae7..861c70b697c21 100644 --- a/lib/srv/db/postgres/users.go +++ b/lib/srv/db/postgres/users.go @@ -27,6 +27,7 @@ import ( "github.com/jackc/pgx/v4" "github.com/gravitational/teleport/api/types" + apiawsutils "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/srv/db/common" ) @@ -47,6 +48,11 @@ func (e *Engine) ActivateUser(ctx context.Context, sessionCtx *common.Session) e return trace.BadParameter("Teleport does not have admin user configured for this database") } + if sessionCtx.Database.IsRDS() && + sessionCtx.Database.GetEndpointType() == apiawsutils.RDSEndpointTypeReader { + return trace.BadParameter("auto-user provisioning is not supported for RDS reader endpoints") + } + conn, err := e.connectAsAdmin(ctx, sessionCtx) if err != nil { return trace.Wrap(err)