Skip to content

Commit

Permalink
AWS: Fix S3InputStream retry policy (#11335)
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarRd authored Oct 17, 2024
1 parent 9d58865 commit bbbfd1e
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 8 deletions.
26 changes: 22 additions & 4 deletions aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.Arrays;
import java.util.List;
import javax.net.ssl.SSLException;
import org.apache.iceberg.exceptions.NotFoundException;
import org.apache.iceberg.io.FileIOMetricsContext;
Expand All @@ -35,6 +36,7 @@
import org.apache.iceberg.metrics.Counter;
import org.apache.iceberg.metrics.MetricsContext;
import org.apache.iceberg.metrics.MetricsContext.Unit;
import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
import org.apache.iceberg.relocated.com.google.common.base.Joiner;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
Expand All @@ -50,6 +52,9 @@
class S3InputStream extends SeekableInputStream implements RangeReadable {
private static final Logger LOG = LoggerFactory.getLogger(S3InputStream.class);

private static final List<Class<? extends Throwable>> RETRYABLE_EXCEPTIONS =
ImmutableList.of(SSLException.class, SocketTimeoutException.class, SocketException.class);

private final StackTraceElement[] createStack;
private final S3Client s3;
private final S3URI location;
Expand All @@ -66,10 +71,18 @@ class S3InputStream extends SeekableInputStream implements RangeReadable {
private int skipSize = 1024 * 1024;
private RetryPolicy<Object> retryPolicy =
RetryPolicy.builder()
.handle(
ImmutableList.of(
SSLException.class, SocketTimeoutException.class, SocketException.class))
.onFailure(failure -> openStream(true))
.handle(RETRYABLE_EXCEPTIONS)
.onRetry(
e -> {
LOG.warn(
"Retrying read from S3, reopening stream (attempt {})", e.getAttemptCount());
resetForRetry();
})
.onFailure(
e ->
LOG.error(
"Failed to read from S3 input stream after exhausting all retries",
e.getException()))
.withMaxRetries(3)
.build();

Expand Down Expand Up @@ -230,6 +243,11 @@ private void openStream(boolean closeQuietly) throws IOException {
}
}

@VisibleForTesting
void resetForRetry() throws IOException {
openStream(true);
}

private void closeStream(boolean closeQuietly) throws IOException {
if (stream != null) {
// if we aren't at the end of the stream, and the stream is abortable, then
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.iceberg.aws.s3;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
Expand All @@ -29,6 +30,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import javax.net.ssl.SSLException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
Expand All @@ -49,10 +51,29 @@

public class TestFlakyS3InputStream extends TestS3InputStream {

private AtomicInteger resetForRetryCounter;

@BeforeEach
public void setupTest() {
resetForRetryCounter = new AtomicInteger(0);
}

@Override
S3InputStream newInputStream(S3Client s3Client, S3URI uri) {
return new S3InputStream(s3Client, uri) {
@Override
void resetForRetry() throws IOException {
resetForRetryCounter.incrementAndGet();
super.resetForRetry();
}
};
}

@ParameterizedTest
@MethodSource("retryableExceptions")
public void testReadWithFlakyStreamRetrySucceed(IOException exception) throws Exception {
testRead(flakyStreamClient(new AtomicInteger(3), exception));
assertThat(resetForRetryCounter.get()).isEqualTo(2);
}

@ParameterizedTest
Expand All @@ -61,6 +82,7 @@ public void testReadWithFlakyStreamExhaustedRetries(IOException exception) {
assertThatThrownBy(() -> testRead(flakyStreamClient(new AtomicInteger(5), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
assertThat(resetForRetryCounter.get()).isEqualTo(3);
}

@ParameterizedTest
Expand All @@ -69,12 +91,14 @@ public void testReadWithFlakyStreamNonRetryableException(IOException exception)
assertThatThrownBy(() -> testRead(flakyStreamClient(new AtomicInteger(3), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
assertThat(resetForRetryCounter.get()).isEqualTo(0);
}

@ParameterizedTest
@MethodSource("retryableExceptions")
public void testSeekWithFlakyStreamRetrySucceed(IOException exception) throws Exception {
testSeek(flakyStreamClient(new AtomicInteger(3), exception));
assertThat(resetForRetryCounter.get()).isEqualTo(2);
}

@ParameterizedTest
Expand All @@ -83,6 +107,7 @@ public void testSeekWithFlakyStreamExhaustedRetries(IOException exception) {
assertThatThrownBy(() -> testSeek(flakyStreamClient(new AtomicInteger(5), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
assertThat(resetForRetryCounter.get()).isEqualTo(3);
}

@ParameterizedTest
Expand All @@ -91,6 +116,7 @@ public void testSeekWithFlakyStreamNonRetryableException(IOException exception)
assertThatThrownBy(() -> testSeek(flakyStreamClient(new AtomicInteger(3), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
assertThat(resetForRetryCounter.get()).isEqualTo(0);
}

private static Stream<Arguments> retryableExceptions() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,18 @@ public void testRead() throws Exception {
testRead(s3);
}

S3InputStream newInputStream(S3Client s3Client, S3URI uri) {
return new S3InputStream(s3Client, uri);
}

protected void testRead(S3Client s3Client) throws Exception {
S3URI uri = new S3URI("s3://bucket/path/to/read.dat");
int dataSize = 1024 * 1024 * 10;
byte[] data = randomData(dataSize);

writeS3Data(uri, data);

try (SeekableInputStream in = new S3InputStream(s3Client, uri)) {
try (SeekableInputStream in = newInputStream(s3Client, uri)) {
int readSize = 1024;
readAndCheck(in, in.getPos(), readSize, data, false);
readAndCheck(in, in.getPos(), readSize, data, true);
Expand Down Expand Up @@ -128,7 +132,7 @@ protected void testRangeRead(S3Client s3Client) throws Exception {

writeS3Data(uri, expected);

try (RangeReadable in = new S3InputStream(s3Client, uri)) {
try (RangeReadable in = newInputStream(s3Client, uri)) {
// first 1k
position = 0;
offset = 0;
Expand Down Expand Up @@ -160,7 +164,7 @@ private void readAndCheckRanges(
@Test
public void testClose() throws Exception {
S3URI uri = new S3URI("s3://bucket/path/to/closed.dat");
SeekableInputStream closed = new S3InputStream(s3, uri);
SeekableInputStream closed = newInputStream(s3, uri);
closed.close();
assertThatThrownBy(() -> closed.seek(0))
.isInstanceOf(IllegalStateException.class)
Expand All @@ -178,7 +182,7 @@ protected void testSeek(S3Client s3Client) throws Exception {

writeS3Data(uri, expected);

try (SeekableInputStream in = new S3InputStream(s3Client, uri)) {
try (SeekableInputStream in = newInputStream(s3Client, uri)) {
in.seek(expected.length / 2);
byte[] actual = new byte[expected.length / 2];
IOUtil.readFully(in, actual, 0, expected.length / 2);
Expand Down

0 comments on commit bbbfd1e

Please sign in to comment.