Skip to content

Commit

Permalink
[v14] Improve error message when auto-user provisioning fails on read…
Browse files Browse the repository at this point in the history
…er endpoints (#37954)

* Improve error message when auto-user provisioning fails on reader endpoints

* fix lint
  • Loading branch information
greedy52 authored Feb 22, 2024
1 parent 3722a85 commit 5e48a04
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 46 deletions.
6 changes: 6 additions & 0 deletions api/types/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,12 @@ func (d *DatabaseV3) GetEndpointType() string {
return d.GetAWS().MemoryDB.EndpointType
case DatabaseTypeOpenSearch:
return d.GetAWS().OpenSearch.EndpointType
case DatabaseTypeRDS:
// If not available from discovery tags, get the endpoint type from the
// URL.
if details, err := awsutils.ParseRDSEndpoint(d.GetURI()); err == nil {
return details.EndpointType
}
}
return ""
}
Expand Down
53 changes: 46 additions & 7 deletions api/types/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ func TestDatabaseRDSEndpoint(t *testing.T) {
}

for _, tt := range []struct {
name string
spec DatabaseSpecV3
errorCheck require.ErrorAssertionFunc
expectedAWS AWS
name string
labels map[string]string
spec DatabaseSpecV3
errorCheck require.ErrorAssertionFunc
expectedAWS AWS
expectedEndpointType string
}{
{
name: "aurora instance",
Expand All @@ -53,6 +55,7 @@ func TestDatabaseRDSEndpoint(t *testing.T) {
InstanceID: "aurora-instance-1",
},
},
expectedEndpointType: "instance",
},
{
name: "invalid account id",
Expand All @@ -69,7 +72,7 @@ func TestDatabaseRDSEndpoint(t *testing.T) {
name: "valid account id",
spec: DatabaseSpecV3{
Protocol: "postgres",
URI: "marcotest-db001.abcdefghijklmnop.us-east-1.rds.amazonaws.com:5432",
URI: "marcotest-db001.cluster-ro-abcdefghijklmnop.us-east-1.rds.amazonaws.com:5432",
AWS: AWS{
AccountID: "123456789012",
},
Expand All @@ -78,17 +81,52 @@ func TestDatabaseRDSEndpoint(t *testing.T) {
expectedAWS: AWS{
Region: "us-east-1",
RDS: RDS{
InstanceID: "marcotest-db001",
ClusterID: "marcotest-db001",
},
AccountID: "123456789012",
},
expectedEndpointType: "reader",
},
{
name: "discovered instance",
labels: map[string]string{
"account-id": "123456789012",
"endpoint-type": "primary",
"engine": "aurora-postgresql",
"engine-version": "15.2",
"region": "us-west-1",
"teleport.dev/cloud": "AWS",
"teleport.dev/origin": "cloud",
"teleport.internal/discovered-name": "rds",
},
spec: DatabaseSpecV3{
Protocol: "postgres",
URI: "discovered.rds.com:5432",
AWS: AWS{
Region: "us-west-1",
RDS: RDS{
InstanceID: "aurora-instance-1",
IAMAuth: true,
},
},
},
errorCheck: require.NoError,
expectedAWS: AWS{
Region: "us-west-1",
RDS: RDS{
InstanceID: "aurora-instance-1",
IAMAuth: true,
},
},
expectedEndpointType: "primary",
},
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
database, err := NewDatabaseV3(
Metadata{
Name: "rds",
Labels: tt.labels,
Name: "rds",
},
tt.spec,
)
Expand All @@ -98,6 +136,7 @@ func TestDatabaseRDSEndpoint(t *testing.T) {
}

require.Equal(t, tt.expectedAWS, database.GetAWS())
require.Equal(t, tt.expectedEndpointType, database.GetEndpointType())
})
}
}
Expand Down
36 changes: 32 additions & 4 deletions api/utils/aws/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ type RDSEndpointDetails struct {
ProxyCustomEndpointName string
// Region is the AWS region the database resides in.
Region string
// EndpointType specifies the type of the endpoint, if available.
//
// Note that the endpoint type of RDS Proxies are determined by their
// targets, so the endpoint type will be empty for RDS Proxies here as it
// cannot be decided by the endpoint URL itself.
EndpointType string
}

// IsProxy returns true if the RDS endpoint is an RDS Proxy.
Expand Down Expand Up @@ -189,12 +195,21 @@ func parseRDSWithoutSuffixes(endpoint string, parts []string, region string) (*R
return &RDSEndpointDetails{
ClusterCustomEndpointName: parts[0],
Region: region,
EndpointType: RDSEndpointTypeCustom,
}, nil

case strings.HasPrefix(parts[1], "cluster-ro-"):
return &RDSEndpointDetails{
ClusterID: parts[0],
Region: region,
EndpointType: RDSEndpointTypeReader,
}, nil

case strings.HasPrefix(parts[1], "cluster-"):
return &RDSEndpointDetails{
ClusterID: parts[0],
Region: region,
ClusterID: parts[0],
Region: region,
EndpointType: RDSEndpointTypePrimary,
}, nil

case strings.HasPrefix(parts[1], "proxy-"):
Expand All @@ -205,8 +220,9 @@ func parseRDSWithoutSuffixes(endpoint string, parts []string, region string) (*R

default:
return &RDSEndpointDetails{
InstanceID: parts[0],
Region: region,
InstanceID: parts[0],
Region: region,
EndpointType: RDSEndpointTypeInstance,
}, nil
}

Expand Down Expand Up @@ -364,6 +380,18 @@ const (
OpenSearchCustomEndpoint = "custom"
// OpenSearchVPCEndpoint is the VPC endpoint for domain.
OpenSearchVPCEndpoint = "vpc"

// RDSEndpointTypePrimary is the endpoint that specifies the connection for
// the primary instance of the RDS cluster.
RDSEndpointTypePrimary = "primary"
// RDSEndpointTypeReader is the endpoint that load-balances connections
// across the Aurora Replicas that are available in an RDS cluster.
RDSEndpointTypeReader = "reader"
// RDSEndpointTypeCustom is the endpoint that specifies one of the custom
// endpoints associated with the RDS cluster.
RDSEndpointTypeCustom = "custom"
// RDSEndpointTypeInstance is the endpoint of an RDS DB instance.
RDSEndpointTypeInstance = "instance"
)

// ParseElastiCacheEndpoint extracts the details from the provided
Expand Down
26 changes: 20 additions & 6 deletions api/utils/aws/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,39 @@ func TestParseRDSEndpoint(t *testing.T) {
endpoint: "aurora-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432",
expectIsRDSEndpoint: true,
expectDetails: &RDSEndpointDetails{
InstanceID: "aurora-instance-1",
Region: "us-west-1",
InstanceID: "aurora-instance-1",
Region: "us-west-1",
EndpointType: "instance",
},
},
{
name: "RDS instance in cn-north-1",
endpoint: "aurora-instance-2.abcdefghijklmnop.rds.cn-north-1.amazonaws.com.cn",
expectIsRDSEndpoint: true,
expectDetails: &RDSEndpointDetails{
InstanceID: "aurora-instance-2",
Region: "cn-north-1",
InstanceID: "aurora-instance-2",
Region: "cn-north-1",
EndpointType: "instance",
},
},
{
name: "RDS cluster",
endpoint: "my-cluster.cluster-abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432",
expectIsRDSEndpoint: true,
expectDetails: &RDSEndpointDetails{
ClusterID: "my-cluster",
Region: "us-west-1",
ClusterID: "my-cluster",
Region: "us-west-1",
EndpointType: "primary",
},
},
{
name: "RDS cluster reader",
endpoint: "my-cluster.cluster-ro-abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432",
expectIsRDSEndpoint: true,
expectDetails: &RDSEndpointDetails{
ClusterID: "my-cluster",
Region: "us-west-1",
EndpointType: "reader",
},
},
{
Expand All @@ -66,6 +79,7 @@ func TestParseRDSEndpoint(t *testing.T) {
expectDetails: &RDSEndpointDetails{
ClusterCustomEndpointName: "my-custom",
Region: "us-west-1",
EndpointType: "custom",
},
},
{
Expand Down
43 changes: 14 additions & 29 deletions lib/services/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ func labelsFromRDSV2Instance(rdsInstance *rdsTypesV2.DBInstance, meta *types.AWS
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsInstance.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsInstance.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(RDSEndpointTypeInstance)
labels[types.DiscoveryLabelEndpointType] = apiawsutils.RDSEndpointTypeInstance
labels[types.DiscoveryLabelStatus] = aws.StringValue(rdsInstance.DBInstanceStatus)
if rdsInstance.DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(rdsInstance.DBSubnetGroup.VpcId)
Expand All @@ -723,7 +723,7 @@ func NewDatabaseFromRDSV2Cluster(cluster *rdsTypesV2.DBCluster, firstInstance *r
return types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v", metadata.Region),
Labels: labelsFromRDSV2Cluster(cluster, metadata, RDSEndpointTypePrimary, firstInstance),
Labels: labelsFromRDSV2Cluster(cluster, metadata, apiawsutils.RDSEndpointTypePrimary, firstInstance),
}, aws.StringValue(cluster.DBClusterIdentifier)),
types.DatabaseSpecV3{
Protocol: protocol,
Expand Down Expand Up @@ -780,11 +780,11 @@ func MetadataFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, rdsInstance *rds

// labelsFromRDSV2Cluster creates database labels for the provided RDS cluster.
// It uses aws sdk v2.
func labelsFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, meta *types.AWS, endpointType RDSEndpointType, memberInstance *rdsTypesV2.DBInstance) map[string]string {
func labelsFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, meta *types.AWS, endpointType string, memberInstance *rdsTypesV2.DBInstance) map[string]string {
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsCluster.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsCluster.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(endpointType)
labels[types.DiscoveryLabelEndpointType] = endpointType
labels[types.DiscoveryLabelStatus] = aws.StringValue(rdsCluster.Status)
if memberInstance != nil && memberInstance.DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(memberInstance.DBSubnetGroup.VpcId)
Expand All @@ -805,7 +805,7 @@ func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DB
return types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v", metadata.Region),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypePrimary, memberInstances),
Labels: labelsFromRDSCluster(cluster, metadata, apiawsutils.RDSEndpointTypePrimary, memberInstances),
}, aws.StringValue(cluster.DBClusterIdentifier)),
types.DatabaseSpecV3{
Protocol: protocol,
Expand All @@ -826,9 +826,9 @@ func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInsta
}
return types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, string(RDSEndpointTypeReader)),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeReader, memberInstances),
}, aws.StringValue(cluster.DBClusterIdentifier), string(RDSEndpointTypeReader)),
Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, apiawsutils.RDSEndpointTypeReader),
Labels: labelsFromRDSCluster(cluster, metadata, apiawsutils.RDSEndpointTypeReader, memberInstances),
}, aws.StringValue(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeReader),
types.DatabaseSpecV3{
Protocol: protocol,
URI: fmt.Sprintf("%v:%v", aws.StringValue(cluster.ReaderEndpoint), aws.Int64Value(cluster.Port)),
Expand Down Expand Up @@ -864,9 +864,9 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns

database, err := types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, string(RDSEndpointTypeCustom)),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeCustom, memberInstances),
}, aws.StringValue(cluster.DBClusterIdentifier), string(RDSEndpointTypeCustom), endpointDetails.ClusterCustomEndpointName),
Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, apiawsutils.RDSEndpointTypeCustom),
Labels: labelsFromRDSCluster(cluster, metadata, apiawsutils.RDSEndpointTypeCustom, memberInstances),
}, aws.StringValue(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeCustom, endpointDetails.ClusterCustomEndpointName),
types.DatabaseSpecV3{
Protocol: protocol,
URI: fmt.Sprintf("%v:%v", aws.StringValue(endpoint), aws.Int64Value(cluster.Port)),
Expand Down Expand Up @@ -1631,19 +1631,19 @@ func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[str
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsInstance.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsInstance.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(RDSEndpointTypeInstance)
labels[types.DiscoveryLabelEndpointType] = apiawsutils.RDSEndpointTypeInstance
if rdsInstance.DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(rdsInstance.DBSubnetGroup.VpcId)
}
return addLabels(labels, libcloudaws.TagsToLabels(rdsInstance.TagList))
}

// labelsFromRDSCluster creates database labels for the provided RDS cluster.
func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType RDSEndpointType, memberInstances []*rds.DBInstance) map[string]string {
func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType string, memberInstances []*rds.DBInstance) map[string]string {
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsCluster.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsCluster.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(endpointType)
labels[types.DiscoveryLabelEndpointType] = endpointType
if len(memberInstances) > 0 && memberInstances[0].DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(memberInstances[0].DBSubnetGroup.VpcId)
}
Expand Down Expand Up @@ -2022,21 +2022,6 @@ const (
RDSEngineAuroraPostgres = "aurora-postgresql"
)

// RDSEndpointType specifies the endpoint type for RDS clusters.
type RDSEndpointType string

const (
// RDSEndpointTypePrimary is the endpoint that specifies the connection for the primary instance of the RDS cluster.
RDSEndpointTypePrimary RDSEndpointType = "primary"
// RDSEndpointTypeReader is the endpoint that load-balances connections across the Aurora Replicas that are
// available in an RDS cluster.
RDSEndpointTypeReader RDSEndpointType = "reader"
// RDSEndpointTypeCustom is the endpoint that specifies one of the custom endpoints associated with the RDS cluster.
RDSEndpointTypeCustom RDSEndpointType = "custom"
// RDSEndpointTypeInstance is the endpoint of an RDS DB instance.
RDSEndpointTypeInstance RDSEndpointType = "instance"
)

const (
// RDSEngineModeProvisioned is the RDS engine mode for provisioned Aurora clusters
RDSEngineModeProvisioned = "provisioned"
Expand Down
6 changes: 6 additions & 0 deletions lib/srv/db/mysql/autousers.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"golang.org/x/exp/slices"

"github.com/gravitational/teleport/api/types"
apiawsutils "github.com/gravitational/teleport/api/utils/aws"
"github.com/gravitational/teleport/lib/srv/db/common"
)

Expand Down Expand Up @@ -111,6 +112,11 @@ func (e *Engine) ActivateUser(ctx context.Context, sessionCtx *common.Session) e
return trace.BadParameter("Teleport does not have admin user configured for this database")
}

if sessionCtx.Database.IsRDS() &&
sessionCtx.Database.GetEndpointType() == apiawsutils.RDSEndpointTypeReader {
return trace.BadParameter("auto-user provisioning is not supported for RDS reader endpoints")
}

conn, err := e.connectAsAdminUser(ctx, sessionCtx)
if err != nil {
return trace.Wrap(err)
Expand Down
6 changes: 6 additions & 0 deletions lib/srv/db/postgres/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/jackc/pgx/v4"

"github.com/gravitational/teleport/api/types"
apiawsutils "github.com/gravitational/teleport/api/utils/aws"
"github.com/gravitational/teleport/lib/srv/db/common"
)

Expand All @@ -47,6 +48,11 @@ func (e *Engine) ActivateUser(ctx context.Context, sessionCtx *common.Session) e
return trace.BadParameter("Teleport does not have admin user configured for this database")
}

if sessionCtx.Database.IsRDS() &&
sessionCtx.Database.GetEndpointType() == apiawsutils.RDSEndpointTypeReader {
return trace.BadParameter("auto-user provisioning is not supported for RDS reader endpoints")
}

conn, err := e.connectAsAdmin(ctx, sessionCtx)
if err != nil {
return trace.Wrap(err)
Expand Down

0 comments on commit 5e48a04

Please sign in to comment.