diff --git a/lib/auth/integration/integrationv1/awsoidc.go b/lib/auth/integration/integrationv1/awsoidc.go
index d903dd22914c6..1e64c68f28eca 100644
--- a/lib/auth/integration/integrationv1/awsoidc.go
+++ b/lib/auth/integration/integrationv1/awsoidc.go
@@ -495,6 +495,58 @@ func (s *AWSOIDCService) DeployDatabaseService(ctx context.Context, req *integra
}, nil
}
+// ListDeployedDatabaseServices deploys Database Services into Amazon ECS.
+func (s *AWSOIDCService) ListDeployedDatabaseServices(ctx context.Context, req *integrationpb.ListDeployedDatabaseServicesRequest) (*integrationpb.ListDeployedDatabaseServicesResponse, error) {
+ authCtx, err := s.authorizer.Authorize(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ if err := authCtx.CheckAccessToKind(types.KindIntegration, types.VerbUse); err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ clusterName, err := s.cache.GetClusterName()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ awsClientReq, err := s.awsClientReq(ctx, req.Integration, req.Region)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ listDatabaseServicesClient, err := awsoidc.NewListDeployedDatabaseServicesClient(ctx, awsClientReq)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ listDatabaseServicesResponse, err := awsoidc.ListDeployedDatabaseServices(ctx, listDatabaseServicesClient, awsoidc.ListDeployedDatabaseServicesRequest{
+ Integration: req.Integration,
+ TeleportClusterName: clusterName.GetClusterName(),
+ Region: req.Region,
+ NextToken: req.NextToken,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ deployedDatabaseServices := make([]*integrationpb.DeployedDatabaseService, 0, len(listDatabaseServicesResponse.DeployedDatabaseServices))
+ for _, deployedService := range listDatabaseServicesResponse.DeployedDatabaseServices {
+ deployedDatabaseServices = append(deployedDatabaseServices, &integrationpb.DeployedDatabaseService{
+ Name: deployedService.Name,
+ ServiceDashboardUrl: deployedService.ServiceDashboardURL,
+ ContainerEntryPoint: deployedService.ContainerEntryPoint,
+ ContainerCommand: deployedService.ContainerCommand,
+ })
+ }
+
+ return &integrationpb.ListDeployedDatabaseServicesResponse{
+ DeployedDatabaseServices: deployedDatabaseServices,
+ NextToken: listDatabaseServicesResponse.NextToken,
+ }, nil
+}
+
// EnrollEKSClusters enrolls EKS clusters into Teleport by installing teleport-kube-agent chart on the clusters.
func (s *AWSOIDCService) EnrollEKSClusters(ctx context.Context, req *integrationpb.EnrollEKSClustersRequest) (*integrationpb.EnrollEKSClustersResponse, error) {
authCtx, err := s.authorizer.Authorize(ctx)
diff --git a/lib/auth/integration/integrationv1/awsoidc_test.go b/lib/auth/integration/integrationv1/awsoidc_test.go
index f6cd0e925a48f..6a2497229ab38 100644
--- a/lib/auth/integration/integrationv1/awsoidc_test.go
+++ b/lib/auth/integration/integrationv1/awsoidc_test.go
@@ -423,6 +423,16 @@ func TestRBAC(t *testing.T) {
return err
},
},
+ {
+ name: "ListDeployedDatabaseServices",
+ fn: func() error {
+ _, err := awsoidService.ListDeployedDatabaseServices(userCtx, &integrationv1.ListDeployedDatabaseServicesRequest{
+ Integration: integrationName,
+ Region: "my-region",
+ })
+ return err
+ },
+ },
} {
t.Run(tt.name, func(t *testing.T) {
err := tt.fn()
diff --git a/lib/integrations/awsoidc/deployservice.go b/lib/integrations/awsoidc/deployservice.go
index b9fbc4b99c458..2e3572755b782 100644
--- a/lib/integrations/awsoidc/deployservice.go
+++ b/lib/integrations/awsoidc/deployservice.go
@@ -445,16 +445,18 @@ func DeployService(ctx context.Context, clt DeployServiceClient, req DeployServi
return nil, trace.Wrap(err)
}
- serviceDashboardURL := fmt.Sprintf("https://%s.console.aws.amazon.com/ecs/v2/clusters/%s/services/%s", req.Region, aws.ToString(req.ClusterName), aws.ToString(req.ServiceName))
-
return &DeployServiceResponse{
ClusterARN: aws.ToString(cluster.ClusterArn),
ServiceARN: aws.ToString(service.ServiceArn),
TaskDefinitionARN: taskDefinitionARN,
- ServiceDashboardURL: serviceDashboardURL,
+ ServiceDashboardURL: serviceDashboardURL(req.Region, aws.ToString(req.ClusterName), aws.ToString(req.ServiceName)),
}, nil
}
+func serviceDashboardURL(region, clusterName, serviceName string) string {
+ return fmt.Sprintf("https://%s.console.aws.amazon.com/ecs/v2/clusters/%s/services/%s", region, clusterName, serviceName)
+}
+
type upsertTaskRequest struct {
TaskName string
TaskRoleARN string
diff --git a/lib/integrations/awsoidc/listdeployeddatabaseservice.go b/lib/integrations/awsoidc/listdeployeddatabaseservice.go
new file mode 100644
index 0000000000000..c2894902f78fe
--- /dev/null
+++ b/lib/integrations/awsoidc/listdeployeddatabaseservice.go
@@ -0,0 +1,194 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package awsoidc
+
+import (
+ "context"
+ "log/slog"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/ecs"
+ ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types"
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/lib/integrations/awsoidc/tags"
+)
+
+// ListDeployedDatabaseServicesRequest contains the required fields to list the deployed database services in Amazon ECS.
+type ListDeployedDatabaseServicesRequest struct {
+ // Region is the AWS Region.
+ Region string
+ // Integration is the AWS OIDC Integration name
+ Integration string
+ // TeleportClusterName is the name of the Teleport Cluster.
+ // Used to uniquely identify the ECS Cluster in Amazon.
+ TeleportClusterName string
+ // NextToken is the token to be used to fetch the next page.
+ // If empty, the first page is fetched.
+ NextToken string
+}
+
+func (req *ListDeployedDatabaseServicesRequest) checkAndSetDefaults() error {
+ if req.Region == "" {
+ return trace.BadParameter("region is required")
+ }
+
+ if req.Integration == "" {
+ return trace.BadParameter("integration is required")
+ }
+
+ if req.TeleportClusterName == "" {
+ return trace.BadParameter("teleport cluster name is required")
+ }
+
+ return nil
+}
+
+// ListDeployedDatabaseServicesResponse contains a page of Deployed Database Services.
+type ListDeployedDatabaseServicesResponse struct {
+ // DeployedDatabaseServices contains the page of Deployed Database Services.
+ DeployedDatabaseServices []DeployedDatabaseService `json:"deployedDatabaseServices"`
+
+ // NextToken is used for pagination.
+ // If non-empty, it can be used to request the next page.
+ NextToken string `json:"nextToken"`
+}
+
+// DeployedDatabaseService contains a database service that was deployed to Amazon ECS.
+type DeployedDatabaseService struct {
+ // Name is the ECS Service name.
+ Name string
+ // ServiceDashboardURL is the Amazon Web Console URL for this ECS Service.
+ ServiceDashboardURL string
+ // ContainerEntryPoint is the entry point for the container 0 that is running in the ECS Task.
+ ContainerEntryPoint []string
+ // ContainerCommand is the list of arguments that are passed into the ContainerEntryPoint.
+ ContainerCommand []string
+}
+
+// ListDeployedDatabaseServicesClient describes the required methods to list AWS VPCs.
+type ListDeployedDatabaseServicesClient interface {
+ // ListServices returns a list of services.
+ ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error)
+ // DescribeServices returns ECS Services details.
+ DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error)
+ // DescribeTaskDefinition returns an ECS Task Definition.
+ DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error)
+}
+
+type defaultListDeployedDatabaseServicesClient struct {
+ *ecs.Client
+}
+
+// NewListDeployedDatabaseServicesClient creates a new ListDeployedDatabaseServicesClient using an AWSClientRequest.
+func NewListDeployedDatabaseServicesClient(ctx context.Context, req *AWSClientRequest) (ListDeployedDatabaseServicesClient, error) {
+ ecsClient, err := newECSClient(ctx, req)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return &defaultListDeployedDatabaseServicesClient{
+ Client: ecsClient,
+ }, nil
+}
+
+// ListDeployedDatabaseServices calls the following AWS API:
+// https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_ListServices.html
+// https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeServices.html
+// https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeTaskDefinition.html
+// It returns a list of ECS Services running Teleport Database Service and an optional NextToken that can be used to fetch the next page.
+func ListDeployedDatabaseServices(ctx context.Context, clt ListDeployedDatabaseServicesClient, req ListDeployedDatabaseServicesRequest) (*ListDeployedDatabaseServicesResponse, error) {
+ if err := req.checkAndSetDefaults(); err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ clusterName := normalizeECSClusterName(req.TeleportClusterName)
+
+ log := slog.With(
+ "integration", req.Integration,
+ "aws_region", req.Region,
+ "ecs_cluster", clusterName,
+ )
+
+ // Do not increase this value because ecs.DescribeServices only allows up to 10 services per API call.
+ maxServicesPerPage := aws.Int32(10)
+ listServicesInput := &ecs.ListServicesInput{
+ Cluster: &clusterName,
+ MaxResults: maxServicesPerPage,
+ LaunchType: ecstypes.LaunchTypeFargate,
+ }
+ if req.NextToken != "" {
+ listServicesInput.NextToken = &req.NextToken
+ }
+
+ listServicesOutput, err := clt.ListServices(ctx, listServicesInput)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ describeServicesOutput, err := clt.DescribeServices(ctx, &ecs.DescribeServicesInput{
+ Services: listServicesOutput.ServiceArns,
+ Include: []ecstypes.ServiceField{ecstypes.ServiceFieldTags},
+ Cluster: &clusterName,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ ownershipTags := tags.DefaultResourceCreationTags(req.TeleportClusterName, req.Integration)
+
+ deployedDatabaseServices := []DeployedDatabaseService{}
+ for _, ecsService := range describeServicesOutput.Services {
+ log := log.With("ecs_service", aws.ToString(ecsService.ServiceName))
+ if !ownershipTags.MatchesECSTags(ecsService.Tags) {
+ log.WarnContext(ctx, "Missing ownership tags in ECS Service, skipping")
+ continue
+ }
+
+ taskDefinitionOut, err := clt.DescribeTaskDefinition(ctx, &ecs.DescribeTaskDefinitionInput{
+ TaskDefinition: ecsService.TaskDefinition,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ if len(taskDefinitionOut.TaskDefinition.ContainerDefinitions) == 0 {
+ log.WarnContext(ctx, "Task has no containers defined, skipping",
+ "ecs_task_family", aws.ToString(taskDefinitionOut.TaskDefinition.Family),
+ "ecs_task_revision", taskDefinitionOut.TaskDefinition.Revision,
+ )
+ continue
+ }
+
+ entryPoint := taskDefinitionOut.TaskDefinition.ContainerDefinitions[0].EntryPoint
+ command := taskDefinitionOut.TaskDefinition.ContainerDefinitions[0].Command
+
+ deployedDatabaseServices = append(deployedDatabaseServices, DeployedDatabaseService{
+ Name: aws.ToString(ecsService.ServiceName),
+ ServiceDashboardURL: serviceDashboardURL(req.Region, clusterName, aws.ToString(ecsService.ServiceName)),
+ ContainerEntryPoint: entryPoint,
+ ContainerCommand: command,
+ })
+ }
+
+ return &ListDeployedDatabaseServicesResponse{
+ DeployedDatabaseServices: deployedDatabaseServices,
+ NextToken: aws.ToString(listServicesOutput.NextToken),
+ }, nil
+}
diff --git a/lib/integrations/awsoidc/listdeployeddatabaseservice_test.go b/lib/integrations/awsoidc/listdeployeddatabaseservice_test.go
new file mode 100644
index 0000000000000..67f332d495c2b
--- /dev/null
+++ b/lib/integrations/awsoidc/listdeployeddatabaseservice_test.go
@@ -0,0 +1,360 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package awsoidc
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/ecs"
+ ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types"
+ "github.com/google/go-cmp/cmp"
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+)
+
+func TestListDeployedDatabaseServicesRequest(t *testing.T) {
+ isBadParamErrFn := func(tt require.TestingT, err error, i ...any) {
+ require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
+ }
+
+ baseReqFn := func() ListDeployedDatabaseServicesRequest {
+ return ListDeployedDatabaseServicesRequest{
+ TeleportClusterName: "mycluster",
+ Region: "eu-west-2",
+ Integration: "my-integration",
+ }
+ }
+
+ for _, tt := range []struct {
+ name string
+ req func() ListDeployedDatabaseServicesRequest
+ errCheck require.ErrorAssertionFunc
+ reqWithDefaults ListDeployedDatabaseServicesRequest
+ }{
+ {
+ name: "no fields",
+ req: func() ListDeployedDatabaseServicesRequest {
+ return ListDeployedDatabaseServicesRequest{}
+ },
+ errCheck: isBadParamErrFn,
+ },
+ {
+ name: "missing teleport cluster name",
+ req: func() ListDeployedDatabaseServicesRequest {
+ r := baseReqFn()
+ r.TeleportClusterName = ""
+ return r
+ },
+ errCheck: isBadParamErrFn,
+ },
+ {
+ name: "missing region",
+ req: func() ListDeployedDatabaseServicesRequest {
+ r := baseReqFn()
+ r.Region = ""
+ return r
+ },
+ errCheck: isBadParamErrFn,
+ },
+ {
+ name: "missing integration",
+ req: func() ListDeployedDatabaseServicesRequest {
+ r := baseReqFn()
+ r.Integration = ""
+ return r
+ },
+ errCheck: isBadParamErrFn,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ r := tt.req()
+ err := r.checkAndSetDefaults()
+ tt.errCheck(t, err)
+
+ if err != nil {
+ return
+ }
+
+ require.Empty(t, cmp.Diff(tt.reqWithDefaults, r))
+ })
+ }
+}
+
+type mockListECSClient struct {
+ pageSize int
+
+ clusterName string
+ services []*ecstypes.Service
+ mapServices map[string]ecstypes.Service
+ taskDefinition map[string]*ecstypes.TaskDefinition
+}
+
+func (m *mockListECSClient) ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) {
+ ret := &ecs.ListServicesOutput{}
+ if aws.ToString(params.Cluster) != m.clusterName {
+ return ret, nil
+ }
+
+ requestedPage := 1
+
+ totalEndpoints := len(m.services)
+
+ if params.NextToken != nil {
+ currentMarker, err := strconv.Atoi(*params.NextToken)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ requestedPage = currentMarker
+ }
+
+ sliceStart := m.pageSize * (requestedPage - 1)
+ sliceEnd := m.pageSize * requestedPage
+ if sliceEnd > totalEndpoints {
+ sliceEnd = totalEndpoints
+ }
+
+ for _, service := range m.services[sliceStart:sliceEnd] {
+ ret.ServiceArns = append(ret.ServiceArns, aws.ToString(service.ServiceArn))
+ }
+
+ if sliceEnd < totalEndpoints {
+ nextToken := strconv.Itoa(requestedPage + 1)
+ ret.NextToken = &nextToken
+ }
+
+ return ret, nil
+}
+
+func (m *mockListECSClient) DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) {
+ ret := &ecs.DescribeServicesOutput{}
+ if aws.ToString(params.Cluster) != m.clusterName {
+ return ret, nil
+ }
+
+ for _, serviceARN := range params.Services {
+ ret.Services = append(ret.Services, m.mapServices[serviceARN])
+ }
+ return ret, nil
+}
+
+func (m *mockListECSClient) DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) {
+ ret := &ecs.DescribeTaskDefinitionOutput{}
+ ret.TaskDefinition = m.taskDefinition[aws.ToString(params.TaskDefinition)]
+
+ return ret, nil
+}
+
+func dummyServiceTask(idx int) (ecstypes.Service, *ecstypes.TaskDefinition) {
+ taskName := fmt.Sprintf("task-family-name-%d", idx)
+ serviceARN := fmt.Sprintf("arn:eks:service-%d", idx)
+
+ ecsTask := &ecstypes.TaskDefinition{
+ Family: aws.String(taskName),
+ ContainerDefinitions: []ecstypes.ContainerDefinition{{
+ EntryPoint: []string{"teleport"},
+ Command: []string{"start"},
+ }},
+ }
+
+ ecsService := ecstypes.Service{
+ ServiceArn: aws.String(serviceARN),
+ ServiceName: aws.String(fmt.Sprintf("database-service-vpc-%d", idx)),
+ TaskDefinition: aws.String(taskName),
+ Tags: []ecstypes.Tag{
+ {Key: aws.String("teleport.dev/cluster"), Value: aws.String("my-cluster")},
+ {Key: aws.String("teleport.dev/integration"), Value: aws.String("my-integration")},
+ {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")},
+ },
+ }
+
+ return ecsService, ecsTask
+}
+
+func TestListDeployedDatabaseServices(t *testing.T) {
+ ctx := context.Background()
+
+ const pageSize = 100
+ t.Run("pagination", func(t *testing.T) {
+ totalServices := 203
+
+ allServices := make([]*ecstypes.Service, 0, totalServices)
+ mapServices := make(map[string]ecstypes.Service, totalServices)
+ allTasks := make(map[string]*ecstypes.TaskDefinition, totalServices)
+ for i := 0; i < totalServices; i++ {
+ ecsService, ecsTask := dummyServiceTask(i)
+ allTasks[aws.ToString(ecsTask.Family)] = ecsTask
+ mapServices[aws.ToString(ecsService.ServiceArn)] = ecsService
+ allServices = append(allServices, &ecsService)
+ }
+
+ mockListClient := &mockListECSClient{
+ pageSize: pageSize,
+ clusterName: "my-cluster-teleport",
+ mapServices: mapServices,
+ services: allServices,
+ taskDefinition: allTasks,
+ }
+
+ // First page must return pageSize number of Endpoints
+ resp, err := ListDeployedDatabaseServices(ctx, mockListClient, ListDeployedDatabaseServicesRequest{
+ Integration: "my-integration",
+ TeleportClusterName: "my-cluster",
+ Region: "us-east-1",
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, resp.NextToken)
+ require.Len(t, resp.DeployedDatabaseServices, pageSize)
+ require.Equal(t, "database-service-vpc-0", resp.DeployedDatabaseServices[0].Name)
+ require.Equal(t, "https://us-east-1.console.aws.amazon.com/ecs/v2/clusters/my-cluster-teleport/services/database-service-vpc-0", resp.DeployedDatabaseServices[0].ServiceDashboardURL)
+ require.Equal(t, []string{"teleport"}, resp.DeployedDatabaseServices[0].ContainerEntryPoint)
+ require.Equal(t, []string{"start"}, resp.DeployedDatabaseServices[0].ContainerCommand)
+
+ // Second page must return pageSize number of Endpoints
+ nextPageToken := resp.NextToken
+ resp, err = ListDeployedDatabaseServices(ctx, mockListClient, ListDeployedDatabaseServicesRequest{
+ Integration: "my-integration",
+ TeleportClusterName: "my-cluster",
+ Region: "us-east-1",
+ NextToken: nextPageToken,
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, resp.NextToken)
+ require.Len(t, resp.DeployedDatabaseServices, pageSize)
+ require.Equal(t, "https://us-east-1.console.aws.amazon.com/ecs/v2/clusters/my-cluster-teleport/services/database-service-vpc-100", resp.DeployedDatabaseServices[0].ServiceDashboardURL)
+
+ // Third page must return only the remaining Endpoints and an empty nextToken
+ nextPageToken = resp.NextToken
+ resp, err = ListDeployedDatabaseServices(ctx, mockListClient, ListDeployedDatabaseServicesRequest{
+ Integration: "my-integration",
+ TeleportClusterName: "my-cluster",
+ Region: "us-east-1",
+ NextToken: nextPageToken,
+ })
+ require.NoError(t, err)
+ require.Empty(t, resp.NextToken)
+ require.Len(t, resp.DeployedDatabaseServices, 3)
+ })
+
+ for _, tt := range []struct {
+ name string
+ req ListDeployedDatabaseServicesRequest
+ mockClient func() *mockListECSClient
+ errCheck require.ErrorAssertionFunc
+ respCheck func(*testing.T, *ListDeployedDatabaseServicesResponse)
+ }{
+ {
+ name: "ignores ECS Services without ownership tags",
+ req: ListDeployedDatabaseServicesRequest{
+ Integration: "my-integration",
+ TeleportClusterName: "my-cluster",
+ Region: "us-east-1",
+ },
+ mockClient: func() *mockListECSClient {
+ ret := &mockListECSClient{
+ pageSize: 10,
+ clusterName: "my-cluster-teleport",
+ }
+ ecsService, ecsTask := dummyServiceTask(0)
+
+ ecsServiceAnotherIntegration, ecsTaskAnotherIntegration := dummyServiceTask(1)
+ ecsServiceAnotherIntegration.Tags = []ecstypes.Tag{{Key: aws.String("teleport.dev/integration"), Value: aws.String("another-integration")}}
+
+ ret.taskDefinition = map[string]*ecstypes.TaskDefinition{
+ aws.ToString(ecsTask.Family): ecsTask,
+ aws.ToString(ecsTaskAnotherIntegration.Family): ecsTaskAnotherIntegration,
+ }
+ ret.mapServices = map[string]ecstypes.Service{
+ aws.ToString(ecsService.ServiceArn): ecsService,
+ aws.ToString(ecsServiceAnotherIntegration.ServiceArn): ecsServiceAnotherIntegration,
+ }
+ ret.services = append(ret.services, &ecsService)
+ ret.services = append(ret.services, &ecsServiceAnotherIntegration)
+ return ret
+ },
+ respCheck: func(t *testing.T, resp *ListDeployedDatabaseServicesResponse) {
+ require.Len(t, resp.DeployedDatabaseServices, 1, "expected 1 service, got %d", len(resp.DeployedDatabaseServices))
+ require.Empty(t, resp.NextToken, "expected an empty NextToken")
+
+ expectedService := DeployedDatabaseService{
+ Name: "database-service-vpc-0",
+ ServiceDashboardURL: "https://us-east-1.console.aws.amazon.com/ecs/v2/clusters/my-cluster-teleport/services/database-service-vpc-0",
+ ContainerEntryPoint: []string{"teleport"},
+ ContainerCommand: []string{"start"},
+ }
+ require.Empty(t, cmp.Diff(expectedService, resp.DeployedDatabaseServices[0]))
+ },
+ errCheck: require.NoError,
+ },
+ {
+ name: "ignores ECS Services without containers",
+ req: ListDeployedDatabaseServicesRequest{
+ Integration: "my-integration",
+ TeleportClusterName: "my-cluster",
+ Region: "us-east-1",
+ },
+ mockClient: func() *mockListECSClient {
+ ret := &mockListECSClient{
+ pageSize: 10,
+ clusterName: "my-cluster-teleport",
+ }
+ ecsService, ecsTask := dummyServiceTask(0)
+
+ ecsServiceWithoutContainers, ecsTaskWithoutContainers := dummyServiceTask(1)
+ ecsTaskWithoutContainers.ContainerDefinitions = []ecstypes.ContainerDefinition{}
+
+ ret.taskDefinition = map[string]*ecstypes.TaskDefinition{
+ aws.ToString(ecsTask.Family): ecsTask,
+ aws.ToString(ecsTaskWithoutContainers.Family): ecsTaskWithoutContainers,
+ }
+ ret.mapServices = map[string]ecstypes.Service{
+ aws.ToString(ecsService.ServiceArn): ecsService,
+ aws.ToString(ecsServiceWithoutContainers.ServiceArn): ecsServiceWithoutContainers,
+ }
+ ret.services = append(ret.services, &ecsService)
+ ret.services = append(ret.services, &ecsServiceWithoutContainers)
+ return ret
+ },
+ respCheck: func(t *testing.T, resp *ListDeployedDatabaseServicesResponse) {
+ require.Len(t, resp.DeployedDatabaseServices, 1, "expected 1 service, got %d", len(resp.DeployedDatabaseServices))
+ require.Empty(t, resp.NextToken, "expected an empty NextToken")
+
+ expectedService := DeployedDatabaseService{
+ Name: "database-service-vpc-0",
+ ServiceDashboardURL: "https://us-east-1.console.aws.amazon.com/ecs/v2/clusters/my-cluster-teleport/services/database-service-vpc-0",
+ ContainerEntryPoint: []string{"teleport"},
+ ContainerCommand: []string{"start"},
+ }
+ require.Empty(t, cmp.Diff(expectedService, resp.DeployedDatabaseServices[0]))
+ },
+ errCheck: require.NoError,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ resp, err := ListDeployedDatabaseServices(ctx, tt.mockClient(), tt.req)
+ tt.errCheck(t, err)
+ if tt.respCheck != nil {
+ tt.respCheck(t, resp)
+ }
+ })
+ }
+}