Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate instance provider library to use aws sdk v2 #2676

Merged
merged 1 commit into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions libs/java/instance_provider/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
<dependencyManagement>
<dependencies>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-bom</artifactId>
<version>${aws.version}</version>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bom</artifactId>
<version>${aws2.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
Expand Down Expand Up @@ -137,8 +137,8 @@
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-sts</artifactId>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sts</artifactId>
</dependency>
<dependency>
<groupId>com.yahoo.athenz</groupId>
Expand All @@ -156,8 +156,8 @@
<version>${gcp.api-client.version}</version>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-iam</artifactId>
<groupId>software.amazon.awssdk</groupId>
<artifactId>iam</artifactId>
</dependency>
</dependencies>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,6 @@
*/
package com.yahoo.athenz.instance.provider.impl;

import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.identitymanagement.AmazonIdentityManagement;
import com.amazonaws.services.identitymanagement.AmazonIdentityManagementClientBuilder;
import com.amazonaws.services.identitymanagement.model.ListOpenIDConnectProvidersRequest;
import com.amazonaws.services.identitymanagement.model.ListOpenIDConnectProvidersResult;
import com.amazonaws.services.identitymanagement.model.OpenIDConnectProviderListEntry;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleResult;
import com.yahoo.athenz.auth.Authorizer;
import com.yahoo.athenz.auth.Principal;
import com.yahoo.athenz.auth.impl.SimplePrincipal;
Expand All @@ -45,6 +33,18 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.iam.model.ListOpenIdConnectProvidersRequest;
import software.amazon.awssdk.services.iam.model.ListOpenIdConnectProvidersResponse;
import software.amazon.awssdk.services.iam.model.OpenIDConnectProviderListEntry;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.services.iam.IamClient;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
import software.amazon.awssdk.services.sts.model.AssumeRoleResponse;

import static com.yahoo.athenz.common.server.util.config.ConfigManagerSingleton.CONFIG_MANAGER;
import static com.yahoo.athenz.instance.provider.InstanceProvider.ZTS_INSTANCE_AWS_ACCOUNT;
import static com.yahoo.athenz.instance.provider.impl.InstanceAWSProvider.*;
Expand All @@ -61,7 +61,7 @@ public class DefaultAWSElasticKubernetesServiceValidator extends CommonKubernete
private static final String ASSUME_ROLE_NAME = System.getProperty(ZTS_PROP_K8S_PROVIDER_ATTESTATION_AWS_ASSUME_ROLE_NAME, "oidc-issuers-reader");
static final String ZTS_PROP_K8S_PROVIDER_AWS_ATTR_VALIDATOR_FACTORY_CLASS = "athenz.zts.k8s_provider_aws_attr_validator_factory_class";

AWSSecurityTokenService stsClient;
StsClient stsClient;
String serverRegion;

Set<String> awsDNSSuffixes = new HashSet<>();
Expand All @@ -75,6 +75,7 @@ public class DefaultAWSElasticKubernetesServiceValidator extends CommonKubernete
public static DefaultAWSElasticKubernetesServiceValidator getInstance() {
return INSTANCE;
}

private DefaultAWSElasticKubernetesServiceValidator() {
}

Expand Down Expand Up @@ -105,10 +106,8 @@ public void initialize(final SSLContext sslContext, Authorizer authorizer) {

if (useIamRoleForIssuerValidation()) {
// Create an STS client using default credentials
stsClient = AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(serverRegion)
.withCredentials(DefaultAWSCredentialsProviderChain.getInstance())
.build();
stsClient = StsClient.builder().credentialsProvider(DefaultCredentialsProvider.builder().build())
.region(Region.of(serverRegion)).build();
}
final String dnsSuffix = System.getProperty(AWS_PROP_DNS_SUFFIX);
if (!StringUtil.isEmpty(dnsSuffix)) {
Expand All @@ -121,6 +120,7 @@ public void initialize(final SSLContext sslContext, Authorizer authorizer) {

this.attrValidator = newAttrValidator(sslContext);
}

@Override
public String validateIssuer(InstanceConfirmation confirmation, IdTokenAttestationData attestationData, StringBuilder errMsg) {

Expand Down Expand Up @@ -173,37 +173,48 @@ public String validateIssuer(InstanceConfirmation confirmation, IdTokenAttestati
return issuer;
}

boolean verifyIssuerPresenceInDomainAWSAccount(final String issuer,
final String awsAccount) {
boolean result = false;
IamClient getIamClient(final String awsAccount) {

String roleArn = String.format("arn:aws:iam::%s:role/%s", awsAccount, ASSUME_ROLE_NAME);
String roleSessionName = ASSUME_ROLE_NAME + "-Session";
final String roleArn = String.format("arn:aws:iam::%s:role/%s", awsAccount, ASSUME_ROLE_NAME);
final String roleSessionName = ASSUME_ROLE_NAME + "-Session";

// Assume the role in the target AWS account
AssumeRoleRequest assumeRoleRequest = new AssumeRoleRequest()
.withRoleArn(roleArn)
.withRoleSessionName(roleSessionName);
AssumeRoleResult assumeRoleResult = stsClient.assumeRole(assumeRoleRequest);
BasicSessionCredentials sessionCredentials = new BasicSessionCredentials(
assumeRoleResult.getCredentials().getAccessKeyId(),
assumeRoleResult.getCredentials().getSecretAccessKey(),
assumeRoleResult.getCredentials().getSessionToken()
);

AmazonIdentityManagement iamClient = AmazonIdentityManagementClientBuilder.standard()
.withRegion(serverRegion)
.withCredentials(new AWSStaticCredentialsProvider(sessionCredentials))

AssumeRoleRequest assumeRoleRequest = AssumeRoleRequest.builder()
.roleArn(roleArn).roleSessionName(roleSessionName).build();
AssumeRoleResponse assumeRoleResponse = stsClient.assumeRole(assumeRoleRequest);

AwsBasicCredentials credentials = AwsBasicCredentials.builder()
.accessKeyId(assumeRoleResponse.credentials().accessKeyId())
.secretAccessKey(assumeRoleResponse.credentials().secretAccessKey())
.build();

// Create Static Credentials Provider

StaticCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(credentials);

// Create IAM Client

return IamClient.builder().credentialsProvider(credentialsProvider).region(Region.of(serverRegion)).build();
}

boolean verifyIssuerPresenceInDomainAWSAccount(final String issuer, final String awsAccount) {

boolean result = false;

// get our IAM Client

IamClient iamClient = getIamClient(awsAccount);

// Call the IAM API to get the list of OIDC issuers
ListOpenIDConnectProvidersRequest listRequest = new ListOpenIDConnectProvidersRequest();
ListOpenIDConnectProvidersResult listResult = iamClient.listOpenIDConnectProviders(listRequest);
List<OpenIDConnectProviderListEntry> oidcIssuers = listResult.getOpenIDConnectProviderList();

ListOpenIdConnectProvidersRequest request = ListOpenIdConnectProvidersRequest.builder().build();
ListOpenIdConnectProvidersResponse response = iamClient.listOpenIDConnectProviders(request);
List<OpenIDConnectProviderListEntry> oidcIssuers = response.openIDConnectProviderList();
if (oidcIssuers != null) {
String issuerWithoutProtocol = issuer.replaceFirst("^https://", "");
for (OpenIDConnectProviderListEntry oidcIssuer : oidcIssuers) {
if (oidcIssuer != null && oidcIssuer.getArn() != null && oidcIssuer.getArn().endsWith(issuerWithoutProtocol)) {
if (oidcIssuer != null && oidcIssuer.arn() != null && oidcIssuer.arn().endsWith(issuerWithoutProtocol)) {
result = true;
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.amazonaws.regions.Regions;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.GetCallerIdentityRequest;
import com.amazonaws.services.securitytoken.model.GetCallerIdentityResult;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityRequest;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse;
import com.yahoo.athenz.auth.KeyStore;
import com.yahoo.athenz.instance.provider.InstanceConfirmation;
import com.yahoo.athenz.instance.provider.InstanceProvider;
Expand Down Expand Up @@ -412,8 +411,8 @@ protected void setConfirmationAttributes(InstanceConfirmation confirmation, bool
}
confirmation.setAttributes(attributes);
}
AWSSecurityTokenService getInstanceClient(AWSAttestationData info) {

StsClient getInstanceClient(AWSAttestationData info) {

String access = info.getAccess();
if (access == null || access.isEmpty()) {
Expand All @@ -432,36 +431,41 @@ AWSSecurityTokenService getInstanceClient(AWSAttestationData info) {
LOGGER.error("getInstanceClient: No token available in instance document");
return null;
}

BasicSessionCredentials creds = new BasicSessionCredentials(access, secret, token);

return AWSSecurityTokenServiceClientBuilder.standard()
.withCredentials(new AWSStaticCredentialsProvider(creds))
.withRegion(Regions.fromName(awsRegion))
AwsBasicCredentials credentials = AwsBasicCredentials.builder()
.accessKeyId(access)
.secretAccessKey(secret)
.build();

// Create Static Credentials Provider

StaticCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(credentials);

// Create STS Client

return StsClient.builder().credentialsProvider(credentialsProvider).region(Region.of(awsRegion)).build();
}

boolean verifyInstanceIdentity(AWSAttestationData info, final String awsAccount) {

GetCallerIdentityRequest req = new GetCallerIdentityRequest();


try {
AWSSecurityTokenService client = getInstanceClient(info);
if (client == null) {
StsClient stsClient = getInstanceClient(info);
if (stsClient == null) {
LOGGER.error("verifyInstanceIdentity - unable to get AWS STS client object");
return false;
}

GetCallerIdentityResult res = client.getCallerIdentity(req);
if (res == null) {

GetCallerIdentityRequest request = GetCallerIdentityRequest.builder().build();
GetCallerIdentityResponse response = stsClient.getCallerIdentity(request);
if (response == null) {
LOGGER.error("verifyInstanceIdentity - unable to get caller identity");
return false;
}

String arn = "arn:aws:sts::" + awsAccount + ":assumed-role/" + info.getRole() + "/";
if (!res.getArn().startsWith(arn)) {
LOGGER.error("verifyInstanceIdentity - ARN mismatch - request: {} caller-idenity: {}",
arn, res.getArn());
if (!response.arn().startsWith(arn)) {
LOGGER.error("verifyInstanceIdentity - ARN mismatch - request: {} caller-identity: {}",
arn, response.arn());
return false;
}

Expand Down
Loading
Loading