Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OCM-7529 Setup VPC and AWS client interface #53

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 172 additions & 18 deletions pkg/aws/aws_client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ package aws_client
import (
"context"
"os"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/cloudformation"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
elbtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types"
"github.com/aws/aws-sdk-go-v2/service/iam"
iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types"
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go-v2/service/sts"

Expand All @@ -22,7 +26,148 @@ import (
CON "github.com/openshift-online/ocm-common/pkg/aws/consts"
)

type AWSClient struct {
type AWSClient interface {
GetAWSAccountID() string
GetRegion() string

EC2() *ec2.Client
Route53() *route53.Client
CloudFormation() *cloudformation.Client
ELB() *elb.Client

RetrieveAccessKey() (*AccessKeyMod, error)

DescribeLogGroupsByName(logGroupName string) (*cloudwatchlogs.DescribeLogGroupsOutput, error)
DescribeLogStreamByName(logGroupName string) (*cloudwatchlogs.DescribeLogStreamsOutput, error)
DeleteLogGroupByName(logGroupName string) (*cloudwatchlogs.DeleteLogGroupOutput, error)

AllocateEIPAddress() (*ec2.AllocateAddressOutput, error)
DisassociateAddress(associateID string) (*ec2.DisassociateAddressOutput, error)
AllocateEIPAndAssociateInstance(instanceID string) (string, error)
ReleaseAddress(allocationID string) (*ec2.ReleaseAddressOutput, error)

DescribeLoadBalancers(vpcID string) ([]elbtypes.LoadBalancerDescription, error)
DeleteELB(ELB elbtypes.LoadBalancerDescription) error

CopyImage(sourceImageID string, sourceRegion string, name string) (string, error)
DescribeImage(imageIDs []string, filters ...map[string][]string) (*ec2.DescribeImagesOutput, error)

LaunchInstance(subnetID string, imageID string, count int, instanceType string, keyName string, securityGroupIds []string, wait bool) (*ec2.RunInstancesOutput, error)
ListInstances(instanceIDs []string, filters ...map[string][]string) ([]types.Instance, error)
WaitForInstanceReady(instanceID string, timeout time.Duration) error
CheckInstanceState(instanceIDs ...string) (*ec2.DescribeInstanceStatusOutput, error)
WaitForInstancesRunning(instanceIDs []string, timeout time.Duration) (allRunning bool, err error)
WaitForInstancesTerminated(instanceIDs []string, timeout time.Duration) (allTerminated bool, err error)
ListAvaliableInstanceTypesForRegion(region string, availabilityZones ...string) ([]string, error)
ListAvaliableZonesForRegion(region string, zoneType string) ([]string, error)
TerminateInstances(instanceIDs []string, wait bool, timeout time.Duration) error
WaitForInstanceTerminated(instanceIDs []string, timeout time.Duration) error
GetTagsOfInstanceProfile(instanceProfileName string) ([]iamtypes.Tag, error)
GetInstancesByInfraID(infraID string) ([]types.Instance, error)
ListAvaliableRegionsFromAWS() ([]types.Region, error)

CreateInternetGateway() (*ec2.CreateInternetGatewayOutput, error)
AttachInternetGateway(internetGatewayID string, vpcID string) (*ec2.AttachInternetGatewayOutput, error)
DetachInternetGateway(internetGatewayID string, vpcID string) (*ec2.DetachInternetGatewayOutput, error)
ListInternetGateWay(vpcID string) ([]types.InternetGateway, error)
DeleteInternetGateway(internetGatewayID string) (*ec2.DeleteInternetGatewayOutput, error)

CreateKeyPair(keyName string) (*ec2.CreateKeyPairOutput, error)
DeleteKeyPair(keyName string) (*ec2.DeleteKeyPairOutput, error)

CreateKMSKeys(tagKey string, tagValue string, description string, policy string, multiRegion bool) (keyID string, keyArn string, err error)
DescribeKMSKeys(keyID string) (kms.DescribeKeyOutput, error)
ScheduleKeyDeletion(kmsKeyId string, pendingWindowInDays int32) (*kms.ScheduleKeyDeletionOutput, error)
GetKMSPolicy(keyID string, policyName string) (kms.GetKeyPolicyOutput, error)
PutKMSPolicy(keyID string, policyName string, policy string) (kms.PutKeyPolicyOutput, error)
TagKeys(kmsKeyId string, tagKey string, tagValue string) (*kms.TagResourceOutput, error)

CreateNatGateway(subnetID string, allocationID string, vpcID string) (*ec2.CreateNatGatewayOutput, error)
DeleteNatGateway(natGatewayID string, timeout ...int) (*ec2.DeleteNatGatewayOutput, error)
ListNatGateWays(vpcID string) ([]types.NatGateway, error)

ListNetWorkAcls(vpcID string) ([]types.NetworkAcl, error)
AddNetworkAclEntry(networkAclId string, egress bool, protocol string, ruleAction string, ruleNumber int32, fromPort int32, toPort int32, cidrBlock string) (*ec2.CreateNetworkAclEntryOutput, error)
DeleteNetworkAclEntry(networkAclId string, egress bool, ruleNumber int32) (*ec2.DeleteNetworkAclEntryOutput, error)

DescribeNetWorkInterface(vpcID string) ([]types.NetworkInterface, error)
DeleteNetworkInterface(networkinterface types.NetworkInterface) error

DeleteOIDCProvider(providerArn string) error

CreateIAMPolicy(policyName string, policyDocument string, tags map[string]string) (*iamtypes.Policy, error)
GetIAMPolicy(policyArn string) (*iamtypes.Policy, error)
DeleteIAMPolicy(arn string) error
AttachIAMPolicy(roleName string, policyArn string) error
DetachIAMPolicy(roleAName string, policyArn string) error
GetCustomerIAMPolicies() ([]iamtypes.Policy, error)
FilterNeedCleanPolicies(cleanRule func(iamtypes.Policy) bool) ([]iamtypes.Policy, error)
DeletePolicy(arn string) error
DeletePolicyVersions(policyArn string) error
CleanPolicies(cleanRule func(iamtypes.Policy) bool) error

ResourceExisting(resourceID string) bool
ResourceDeleted(resourceID string) bool
WaitForResourceExisting(resourceID string, timeout int) error
WaitForResourceDeleted(resourceID string, timeout int) error

CreateRole(roleName string, assumeRolePolicyDocument string, permissionBoundry string, tags map[string]string, path string) (iamtypes.Role, error)
GetRole(roleName string) (*iamtypes.Role, error)
DeleteRole(roleName string) error
DeleteRoleAndPolicy(roleName string, managedPolicy bool) error
ListRoles() ([]iamtypes.Role, error)
IsPolicyAttachedToRole(roleName string, policyArn string) (bool, error)
ListAttachedRolePolicies(roleName string) ([]iamtypes.AttachedPolicy, error)
DetachRolePolicies(roleName string) error
DeleteRoleInstanceProfiles(roleName string) error
CreateIAMRole(roleName string, ProdENVTrustedRole string, StageENVTrustedRole string, StageIssuerTrustedRole string, externalID ...string) (iamtypes.Role, error)
CreateRegularRole(roleName string) (iamtypes.Role, error)
CreateRoleForAuditLogForward(roleName, awsAccountID string, oidcEndpointURL string) (iamtypes.Role, error)
CreatePolicy(policyName string, statements ...map[string]interface{}) (string, error)
CreatePolicyForAuditLogForward(policyName string) (string, error)

CreateRouteTable(vpcID string) (*ec2.CreateRouteTableOutput, error)
AssociateRouteTable(routeTableID string, subnetID string, vpcID string) (*ec2.AssociateRouteTableOutput, error)
ListCustomerRouteTables(vpcID string) ([]types.RouteTable, error)
ListRTAssociations(routeTableID string) ([]string, error)
DisassociateRouteTableAssociation(associationID string) (*ec2.DisassociateRouteTableOutput, error)
DisassociateRouteTableAssociations(routeTableID string) error
CreateRoute(routeTableID string, targetID string) (*types.Route, error)
DeleteRouteTable(routeTableID string) error
DeleteRouteTableChain(routeTableID string) error

CreateHostedZone(hostedZoneName string, vpcID string, private bool) (*route53.CreateHostedZoneOutput, error)
GetHostedZone(hostedZoneID string) (*route53.GetHostedZoneOutput, error)
ListHostedZoneByDNSName(hostedZoneName string) (*route53.ListHostedZonesByNameOutput, error)

ListSecurityGroups(vpcID string) ([]types.SecurityGroup, error)
ReleaseInboundOutboundRules(sgID string) error
DeleteSecurityGroup(groupID string) (*ec2.DeleteSecurityGroupOutput, error)
AuthorizeSecurityGroupIngress(groupID string, cidr string, protocol string, fromPort int32, toPort int32) (*ec2.AuthorizeSecurityGroupIngressOutput, error)
CreateSecurityGroup(vpcID string, groupName string, sgDescription string) (*ec2.CreateSecurityGroupOutput, error)
GetSecurityGroupWithID(sgID string) (*ec2.DescribeSecurityGroupsOutput, error)

CreateSubnet(vpcID string, zone string, subnetCidr string) (*types.Subnet, error)
ListSubnetByVpcID(vpcID string) ([]types.Subnet, error)
DeleteSubnet(subnetID string) (*ec2.DeleteSubnetOutput, error)
ListSubnetDetail(subnetIDs ...string) ([]types.Subnet, error)
ListSubnetsByFilter(filter []types.Filter) ([]types.Subnet, error)

TagResource(resourceID string, tags map[string]string) (*ec2.CreateTagsOutput, error)
RemoveResourceTag(resourceID string, tagKey string, tagValue string) (*ec2.DeleteTagsOutput, error)

DescribeVolumeByID(volumeID string) (*ec2.DescribeVolumesOutput, error)

ListVPCByName(vpcName string) ([]types.Vpc, error)
CreateVpc(cidr string, name ...string) (*ec2.CreateVpcOutput, error)
ModifyVpcDnsAttribute(vpcID string, dnsAttribute string, status bool) (*ec2.ModifyVpcAttributeOutput, error)
DeleteVpc(vpcID string) (*ec2.DeleteVpcOutput, error)
DescribeVPC(vpcID string) (types.Vpc, error)
ListEndpointAssociation(vpcID string) ([]types.VpcEndpoint, error)
DeleteVPCEndpoints(vpcID string) error
}

type awsClient struct {
Ec2Client *ec2.Client
Route53Client *route53.Client
StackFormationClient *cloudformation.Client
Expand All @@ -31,7 +176,6 @@ type AWSClient struct {
Region string
IamClient *iam.Client
ClientContext context.Context
AccountID string
KmsClient *kms.Client
CloudWatchLogsClient *cloudwatchlogs.Client
AWSConfig *aws.Config
Expand All @@ -42,7 +186,7 @@ type AccessKeyMod struct {
SecretAccessKey string `ini:"aws_secret_access_key,omitempty"`
}

func CreateAWSClient(profileName string, region string) (*AWSClient, error) {
func CreateAWSClient(profileName string, region string) (AWSClient, error) {
var cfg aws.Config
var err error

Expand Down Expand Up @@ -77,7 +221,7 @@ func CreateAWSClient(profileName string, region string) (*AWSClient, error) {
return nil, err
}

awsClient := &AWSClient{
awsClient := &awsClient{
Ec2Client: ec2.NewFromConfig(cfg),
Route53Client: route53.NewFromConfig(cfg),
StackFormationClient: cloudformation.NewFromConfig(cfg),
Expand All @@ -89,11 +233,13 @@ func CreateAWSClient(profileName string, region string) (*AWSClient, error) {
KmsClient: kms.NewFromConfig(cfg),
AWSConfig: &cfg,
}
awsClient.AccountID = awsClient.GetAWSAccountID()
return awsClient, nil
}
func (client *awsClient) GetRegion() string {
return client.Region
}

func (client *AWSClient) GetAWSAccountID() string {
func (client *awsClient) GetAWSAccountID() string {
input := &sts.GetCallerIdentityInput{}
out, err := client.StsClient.GetCallerIdentity(client.ClientContext, input)
if err != nil {
Expand All @@ -102,45 +248,53 @@ func (client *AWSClient) GetAWSAccountID() string {
return *out.Account
}

func (client *AWSClient) EC2() *ec2.Client {
func (client *awsClient) EC2() *ec2.Client {
return client.Ec2Client
}

func (client *AWSClient) Route53() *route53.Client {
func (client *awsClient) Route53() *route53.Client {
return client.Route53Client
}
func (client *AWSClient) CloudFormation() *cloudformation.Client {
func (client *awsClient) CloudFormation() *cloudformation.Client {
return client.StackFormationClient
}
func (client *AWSClient) ELB() *elb.Client {
func (client *awsClient) ELB() *elb.Client {
return client.ElbClient
}

func GrantValidAccessKeys(userName string) (*AccessKeyMod, error) {
var cre aws.Credentials
var keysMod *AccessKeyMod
var err error
retryTimes := 3
for retryTimes > 0 {
if cre.AccessKeyID != "" {
if keysMod.AccessKeyId != "" {
break
}
client, err := CreateAWSClient(userName, CON.DefaultAWSRegion)
if err != nil {
return nil, err
}

cre, err = client.AWSConfig.Credentials.Retrieve(client.ClientContext)
keysMod, err = client.RetrieveAccessKey()
if err != nil {
return nil, err
}
log.LogInfo(">>> Access key grant successfully")

keysMod = &AccessKeyMod{
AccessKeyId: cre.AccessKeyID,
SecretAccessKey: cre.SecretAccessKey,
}
retryTimes--
}
return keysMod, err
}

func (client *awsClient) RetrieveAccessKey() (*AccessKeyMod, error) {
cre, err := client.AWSConfig.Credentials.Retrieve(client.ClientContext)
if err != nil {
return nil, err
}
log.LogInfo(">>> Access key grant successfully")

keysMod := &AccessKeyMod{
AccessKeyId: cre.AccessKeyID,
SecretAccessKey: cre.SecretAccessKey,
}
return keysMod, nil
}
6 changes: 3 additions & 3 deletions pkg/aws/aws_client/cloudwatch_logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/openshift-online/ocm-common/pkg/log"
)

func (client *AWSClient) DescribeLogGroupsByName(logGroupName string) (*cloudwatchlogs.DescribeLogGroupsOutput, error) {
func (client *awsClient) DescribeLogGroupsByName(logGroupName string) (*cloudwatchlogs.DescribeLogGroupsOutput, error) {
output, err := client.CloudWatchLogsClient.DescribeLogGroups(context.TODO(), &cloudwatchlogs.DescribeLogGroupsInput{
LogGroupNamePrefix: &logGroupName,
})
Expand All @@ -17,7 +17,7 @@ func (client *AWSClient) DescribeLogGroupsByName(logGroupName string) (*cloudwat
return output, err
}

func (client *AWSClient) DescribeLogStreamByName(logGroupName string) (*cloudwatchlogs.DescribeLogStreamsOutput, error) {
func (client *awsClient) DescribeLogStreamByName(logGroupName string) (*cloudwatchlogs.DescribeLogStreamsOutput, error) {
output, err := client.CloudWatchLogsClient.DescribeLogStreams(context.TODO(), &cloudwatchlogs.DescribeLogStreamsInput{
LogGroupName: &logGroupName,
})
Expand All @@ -27,7 +27,7 @@ func (client *AWSClient) DescribeLogStreamByName(logGroupName string) (*cloudwat
return output, err
}

func (client *AWSClient) DeleteLogGroupByName(logGroupName string) (*cloudwatchlogs.DeleteLogGroupOutput, error) {
func (client *awsClient) DeleteLogGroupByName(logGroupName string) (*cloudwatchlogs.DeleteLogGroupOutput, error) {
output, err := client.CloudWatchLogsClient.DeleteLogGroup(context.TODO(), &cloudwatchlogs.DeleteLogGroupInput{
LogGroupName: &logGroupName,
})
Expand Down
8 changes: 4 additions & 4 deletions pkg/aws/aws_client/eip.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/openshift-online/ocm-common/pkg/log"
)

func (client *AWSClient) AllocateEIPAddress() (*ec2.AllocateAddressOutput, error) {
func (client *awsClient) AllocateEIPAddress() (*ec2.AllocateAddressOutput, error) {
inputs := &ec2.AllocateAddressInput{
Address: nil,
CustomerOwnedIpv4Pool: nil,
Expand All @@ -28,7 +28,7 @@ func (client *AWSClient) AllocateEIPAddress() (*ec2.AllocateAddressOutput, error
return respEIP, err
}

func (client *AWSClient) DisassociateAddress(associateID string) (*ec2.DisassociateAddressOutput, error) {
func (client *awsClient) DisassociateAddress(associateID string) (*ec2.DisassociateAddressOutput, error) {
inputDisassociate := &ec2.DisassociateAddressInput{
AssociationId: aws.String(associateID),
DryRun: nil,
Expand All @@ -44,7 +44,7 @@ func (client *AWSClient) DisassociateAddress(associateID string) (*ec2.Disassoci
return respDisassociate, err
}

func (client *AWSClient) AllocateEIPAndAssociateInstance(instanceID string) (string, error) {
func (client *awsClient) AllocateEIPAndAssociateInstance(instanceID string) (string, error) {
allocRes, err := client.AllocateEIPAddress()
if err != nil {
log.LogError("Failed allocated EIP: %s", err)
Expand Down Expand Up @@ -72,7 +72,7 @@ func (client *AWSClient) AllocateEIPAndAssociateInstance(instanceID string) (str
return *allocRes.PublicIp, nil
}

func (client *AWSClient) ReleaseAddress(allocationID string) (*ec2.ReleaseAddressOutput, error) {
func (client *awsClient) ReleaseAddress(allocationID string) (*ec2.ReleaseAddressOutput, error) {
inputRelease := &ec2.ReleaseAddressInput{
AllocationId: aws.String(allocationID),
DryRun: nil,
Expand Down
4 changes: 2 additions & 2 deletions pkg/aws/aws_client/elb.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/openshift-online/ocm-common/pkg/log"
)

func (client *AWSClient) DescribeLoadBalancers(vpcID string) ([]elbtypes.LoadBalancerDescription, error) {
func (client *awsClient) DescribeLoadBalancers(vpcID string) ([]elbtypes.LoadBalancerDescription, error) {

listenedELB := []elbtypes.LoadBalancerDescription{}
input := &elb.DescribeLoadBalancersInput{}
Expand All @@ -30,7 +30,7 @@ func (client *AWSClient) DescribeLoadBalancers(vpcID string) ([]elbtypes.LoadBal
return listenedELB, err
}

func (client *AWSClient) DeleteELB(ELB elbtypes.LoadBalancerDescription) error {
func (client *awsClient) DeleteELB(ELB elbtypes.LoadBalancerDescription) error {
log.LogInfo("Goint to delete ELB %s", *ELB.LoadBalancerName)

deleteELBInput := &elb.DeleteLoadBalancerInput{
Expand Down
4 changes: 2 additions & 2 deletions pkg/aws/aws_client/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/openshift-online/ocm-common/pkg/log"
)

func (client *AWSClient) CopyImage(sourceImageID string, sourceRegion string, name string) (string, error) {
func (client *awsClient) CopyImage(sourceImageID string, sourceRegion string, name string) (string, error) {
copyImageInput := &ec2.CopyImageInput{
Name: &name,
SourceImageId: &sourceImageID,
Expand All @@ -22,7 +22,7 @@ func (client *AWSClient) CopyImage(sourceImageID string, sourceRegion string, na
return *output.ImageId, nil
}

func (client *AWSClient) DescribeImage(imageIDs []string, filters ...map[string][]string) (*ec2.DescribeImagesOutput, error) {
func (client *awsClient) DescribeImage(imageIDs []string, filters ...map[string][]string) (*ec2.DescribeImagesOutput, error) {
filterInput := []types.Filter{}
for _, filter := range filters {
for k, v := range filter {
Expand Down
Loading