diff --git a/src/main/java/com/meta/cp4m/S3PreProcessor.java b/src/main/java/com/meta/cp4m/S3PreProcessor.java index fddc9ef..e3733a9 100644 --- a/src/main/java/com/meta/cp4m/S3PreProcessor.java +++ b/src/main/java/com/meta/cp4m/S3PreProcessor.java @@ -11,13 +11,15 @@ import com.meta.cp4m.message.Message; import com.meta.cp4m.message.Payload; import com.meta.cp4m.message.ThreadState; +import java.nio.file.Path; import java.time.Instant; +import java.util.Objects; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.auth.credentials.WebIdentityTokenFileCredentialsProvider; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3Client; @@ -31,7 +33,6 @@ public class S3PreProcessor implements PreProcessor { private final String region; private final String bucket; private final @Nullable String textMessageAddition; - private final @Nullable AwsCredentialsProvider credentials; public S3PreProcessor( String awsAccessKeyID, @@ -44,14 +45,6 @@ public S3PreProcessor( this.region = region; this.bucket = bucket; this.textMessageAddition = textMessageAddition; - - if (!this.awsAccessKeyID.isEmpty() && !this.awsSecretAccessKey.isEmpty()) { - AwsSessionCredentials sessionCredentials = - AwsSessionCredentials.create(this.awsAccessKeyID, this.awsSecretAccessKey, ""); - this.credentials = StaticCredentialsProvider.create(sessionCredentials); - } else { - this.credentials = null; - } } @Override @@ -76,14 +69,33 @@ public ThreadState run(ThreadState in) { : in.with(in.newMessageFromUser(Instant.now(), textMessageAddition, Identifier.random())); } + private S3Client client() { + S3ClientBuilder clientBuilder = S3Client.builder().region(Region.of(this.region)); + if (!this.awsAccessKeyID.isEmpty() && !this.awsSecretAccessKey.isEmpty()) { + AwsSessionCredentials sessionCredentials = + AwsSessionCredentials.create(this.awsAccessKeyID, this.awsSecretAccessKey, ""); + clientBuilder.credentialsProvider(StaticCredentialsProvider.create(sessionCredentials)); + } else { + Path tokenFile = + Path.of( + Objects.requireNonNull(System.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")), + "AWS_WEB_IDENTITY_TOKEN_FILE is required"); + String arnRole = + Objects.requireNonNull(System.getenv("AWS_ROLE_ARN"), "AWS_ROLE_ARN is required"); + WebIdentityTokenFileCredentialsProvider webCredentials = + WebIdentityTokenFileCredentialsProvider.builder() + .webIdentityTokenFile(tokenFile) + .roleArn(arnRole) + .build(); + clientBuilder = clientBuilder.credentialsProvider(webCredentials); + } + return clientBuilder.build(); + } + public void sendRequest(byte[] media, String senderID, String extension, String mimeType) { String key = senderID + '_' + Instant.now().toEpochMilli() + '.' + extension; LOGGER.debug("attempting to upload \"" + key + "\" file to AWS S3"); - S3ClientBuilder clientBuilder = S3Client.builder().region(Region.of(this.region)); - if (this.credentials != null) { - clientBuilder = clientBuilder.credentialsProvider(this.credentials); - } - try (S3Client s3Client = clientBuilder.build()) { + try (S3Client s3Client = client()) { PutObjectRequest request = PutObjectRequest.builder().bucket(this.bucket).key(key).contentType(mimeType).build();