diff --git a/stairway/build.gradle b/stairway/build.gradle index 02418096..184b7616 100644 --- a/stairway/build.gradle +++ b/stairway/build.gradle @@ -33,6 +33,9 @@ dependencies { // File handling during testing testImplementation group: 'commons-io', name: 'commons-io', version: '2.16.1' + + // Mocks during testing + testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.12.0' } apply from: "$rootDir/gradle/test.gradle" diff --git a/stairway/src/main/java/bio/terra/stairway/impl/MdcUtils.java b/stairway/src/main/java/bio/terra/stairway/impl/MdcUtils.java index 5d8e01b1..555d5114 100644 --- a/stairway/src/main/java/bio/terra/stairway/impl/MdcUtils.java +++ b/stairway/src/main/java/bio/terra/stairway/impl/MdcUtils.java @@ -1,15 +1,16 @@ package bio.terra.stairway.impl; import bio.terra.stairway.FlightContext; +import bio.terra.stairway.exception.StairwayExecutionException; import java.util.Map; +import java.util.concurrent.Callable; import org.slf4j.MDC; /** * Utility methods to make Stairway flight runnables context-aware, using mapped diagnostic context * (MDC). */ -class MdcUtils { - +public class MdcUtils { /** ID of the flight */ static final String FLIGHT_ID_KEY = "flightId"; @@ -25,6 +26,29 @@ class MdcUtils { /** The step's execution order */ static final String FLIGHT_STEP_NUMBER_KEY = "flightStepNumber"; + /** + * Run and return the result of the callable with MDC's context map temporarily overwritten during + * computation. The initial context map is then restored after computation. + * + * @param context to override MDC's context map + * @param callable to call and return + */ + public static T callWithContext(Map context, Callable callable) + throws InterruptedException { + // Save the initial thread context so that it can be restored + Map initialContext = MDC.getCopyOfContextMap(); + try { + MdcUtils.overwriteContext(context); + return callable.call(); + } catch (InterruptedException ex) { + throw ex; + } catch (Exception ex) { + throw new StairwayExecutionException("Unexpected exception " + ex.getMessage(), ex); + } finally { + MdcUtils.overwriteContext(initialContext); + } + } + /** * Null-safe utility method for overwriting the current thread's MDC. * diff --git a/stairway/src/main/java/bio/terra/stairway/impl/StairwayImpl.java b/stairway/src/main/java/bio/terra/stairway/impl/StairwayImpl.java index 0f6052aa..a3f5672d 100644 --- a/stairway/src/main/java/bio/terra/stairway/impl/StairwayImpl.java +++ b/stairway/src/main/java/bio/terra/stairway/impl/StairwayImpl.java @@ -26,7 +26,6 @@ import jakarta.annotation.Nullable; import java.time.Duration; import java.util.List; -import java.util.Optional; import java.util.UUID; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledThreadPoolExecutor; diff --git a/stairway/src/main/java/bio/terra/stairway/queue/QueueMessageReady.java b/stairway/src/main/java/bio/terra/stairway/queue/QueueMessageReady.java index 8eea93fc..59414185 100644 --- a/stairway/src/main/java/bio/terra/stairway/queue/QueueMessageReady.java +++ b/stairway/src/main/java/bio/terra/stairway/queue/QueueMessageReady.java @@ -1,9 +1,12 @@ package bio.terra.stairway.queue; import bio.terra.stairway.exception.DatabaseOperationException; +import bio.terra.stairway.impl.MdcUtils; import bio.terra.stairway.impl.StairwayImpl; +import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.slf4j.MDC; /** * The Ready message communicates that a flight, identified by flightId, is ready for execution. By @@ -16,6 +19,7 @@ class QueueMessageReady extends QueueMessage { private QueueMessageType type; private String flightId; + private Map callingThreadContext; private QueueMessageReady() {} @@ -23,25 +27,31 @@ public QueueMessageReady(String flightId) { this.type = new QueueMessageType(QueueMessage.FORMAT_VERSION, QueueMessageEnum.QUEUE_MESSAGE_READY); this.flightId = flightId; + this.callingThreadContext = MDC.getCopyOfContextMap(); } @Override public boolean process(StairwayImpl stairwayImpl) throws InterruptedException { - try { - // Resumed is false if the flight is not found in the Ready state. We still call that - // a complete processing of the message and return true. We assume that some this is a - // duplicate message or that some other Stairway found the ready flight on recovery. - boolean resumed = stairwayImpl.resume(flightId); - logger.info( - "Stairway " - + stairwayImpl.getStairwayName() - + (resumed ? " resumed flight: " : " did not find flight to resume: ") - + flightId); - return true; - } catch (DatabaseOperationException ex) { - logger.error("Unexpected stairway error", ex); - return false; // leave the message on the queue - } + return MdcUtils.callWithContext( + callingThreadContext, + () -> { + try { + // Resumed is false if the flight is not found in the Ready state. We still call that + // a complete processing of the message and return true. We assume that some this is a + // duplicate message or that some other Stairway found the ready flight on recovery. + boolean resumed = stairwayImpl.resume(flightId); + logger.info( + "Stairway " + + stairwayImpl.getStairwayName() + + (resumed ? " resumed flight: " : " did not find flight to resume: ") + + flightId); + return true; + } catch (DatabaseOperationException ex) { + logger.error( + "Unexpected stairway error, leaving %s on the queue".formatted(flightId), ex); + return false; + } + }); } public QueueMessageType getType() { @@ -59,4 +69,8 @@ public String getFlightId() { public void setFlightId(String flightId) { this.flightId = flightId; } + + public Map getCallingThreadContext() { + return callingThreadContext; + } } diff --git a/stairway/src/test/java/bio/terra/stairway/impl/MdcUtilsTest.java b/stairway/src/test/java/bio/terra/stairway/impl/MdcUtilsTest.java index be36c155..1182c43e 100644 --- a/stairway/src/test/java/bio/terra/stairway/impl/MdcUtilsTest.java +++ b/stairway/src/test/java/bio/terra/stairway/impl/MdcUtilsTest.java @@ -2,21 +2,27 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertThrows; import bio.terra.stairway.Direction; +import bio.terra.stairway.exception.StairwayExecutionException; import bio.terra.stairway.fixtures.TestFlightContext; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.MDC; @Tag("unit") class MdcUtilsTest { + private static final Map INITIAL_CONTEXT = Map.of("initial", "context"); private static final Map FOO_BAR = Map.of("foo", "bar"); private static final String FLIGHT_ID = "flightId" + UUID.randomUUID(); private static final String FLIGHT_CLASS = "flightClass" + UUID.randomUUID(); @@ -52,6 +58,57 @@ static Stream> contextMap() { return Stream.of(null, Map.of(), FOO_BAR); } + @ParameterizedTest + @MethodSource("contextMap") + void callWithContext(Map newContext) throws InterruptedException { + MDC.setContextMap(INITIAL_CONTEXT); + Boolean result = + MdcUtils.callWithContext( + newContext, + () -> { + assertThat( + "Context is overwritten during computation", + MDC.getCopyOfContextMap(), + equalTo(newContext)); + return true; + }); + assertThat("Result of computation is returned", result, equalTo(true)); + assertThat("Initial context is restored", MDC.getCopyOfContextMap(), equalTo(INITIAL_CONTEXT)); + } + + static Stream callWithContext_exception() { + List arguments = new ArrayList<>(); + for (var newContext : contextMap().toList()) { + arguments.add( + Arguments.of( + newContext, new InterruptedException("interrupted"), InterruptedException.class)); + arguments.add( + Arguments.of( + newContext, new RuntimeException("unexpected"), StairwayExecutionException.class)); + } + return arguments.stream(); + } + + @ParameterizedTest + @MethodSource + void callWithContext_exception( + Map newContext, Exception exception, Class expectedExceptionClass) { + MDC.setContextMap(INITIAL_CONTEXT); + assertThrows( + expectedExceptionClass, + () -> + MdcUtils.callWithContext( + newContext, + () -> { + assertThat( + "Context is overwritten during computation", + MDC.getCopyOfContextMap(), + equalTo(newContext)); + throw exception; + })); + assertThat("Initial context is restored", MDC.getCopyOfContextMap(), equalTo(INITIAL_CONTEXT)); + } + @ParameterizedTest @MethodSource("contextMap") void overwriteContext(Map newContext) { diff --git a/stairway/src/test/java/bio/terra/stairway/queue/QueueMessageTest.java b/stairway/src/test/java/bio/terra/stairway/queue/QueueMessageTest.java index 211b0e2e..02d350ac 100644 --- a/stairway/src/test/java/bio/terra/stairway/queue/QueueMessageTest.java +++ b/stairway/src/test/java/bio/terra/stairway/queue/QueueMessageTest.java @@ -1,32 +1,108 @@ package bio.terra.stairway.queue; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.when; +import bio.terra.stairway.exception.DatabaseOperationException; +import bio.terra.stairway.impl.MdcUtils; +import bio.terra.stairway.impl.StairwayImpl; +import java.util.Map; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.stubbing.Answer; +import org.slf4j.MDC; @Tag("unit") -public class QueueMessageTest { +@ExtendWith(MockitoExtension.class) +class QueueMessageTest { + + @Mock private StairwayImpl stairway; + private static final String FLIGHT_ID = "flight-abc"; + private static final Map CALLING_THREAD_CONTEXT = + Map.of("requestId", "request-abc"); + + @BeforeEach + void beforeEach() { + MDC.clear(); + } + + private QueueMessageReady createQueueMessageWithContext(Map expectedMdc) + throws InterruptedException { + return MdcUtils.callWithContext(expectedMdc, () -> new QueueMessageReady(FLIGHT_ID)); + } + + private static Stream> message_serde() { + return Stream.of(null, CALLING_THREAD_CONTEXT); + } + + @ParameterizedTest + @MethodSource + void message_serde(Map expectedMdc) throws InterruptedException { + QueueMessageReady messageReady = createQueueMessageWithContext(expectedMdc); + WorkQueueProcessor workQueueProcessor = new WorkQueueProcessor(stairway); + + // Now we add something else to the MDC, but it won't show up in our deserialized queue message. + MDC.put("another-key", "another-value"); + + String serialized = workQueueProcessor.serialize(messageReady); + QueueMessage deserialized = workQueueProcessor.deserialize(serialized); + assertThat(deserialized, instanceOf(QueueMessageReady.class)); + + QueueMessageReady messageReadyCopy = (QueueMessageReady) deserialized; + assertThat(messageReadyCopy.getFlightId(), equalTo(messageReady.getFlightId())); + assertThat( + messageReadyCopy.getType().getMessageEnum(), + equalTo(messageReady.getType().getMessageEnum())); + assertThat( + messageReadyCopy.getType().getVersion(), equalTo(messageReady.getType().getVersion())); + assertThat(messageReadyCopy.getFlightId(), equalTo(messageReady.getFlightId())); + assertThat(messageReadyCopy.getCallingThreadContext(), equalTo(expectedMdc)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void process(boolean resumeAnswer) throws InterruptedException { + QueueMessageReady messageReady = createQueueMessageWithContext(CALLING_THREAD_CONTEXT); + + when(stairway.resume(FLIGHT_ID)) + .thenAnswer( + (Answer) + invocation -> { + assertThat( + "MDC is set during processing", + MDC.getCopyOfContextMap(), + equalTo(CALLING_THREAD_CONTEXT)); + return resumeAnswer; + }); + + assertThat( + "Message is considered processed when stairway.resume returns " + resumeAnswer, + messageReady.process(stairway), + equalTo(true)); + assertThat("MDC is reverted after processing", MDC.getCopyOfContextMap(), equalTo(null)); + } @Test - public void messageTest() throws Exception { - WorkQueueProcessor queueProcessor = new WorkQueueProcessor(null); - QueueMessageReady messageReady = new QueueMessageReady("abcde"); - String serialized = queueProcessor.serialize(messageReady); - QueueMessage deserialized = queueProcessor.deserialize(serialized); - if (deserialized instanceof QueueMessageReady) { - QueueMessageReady messageReadyCopy = (QueueMessageReady) deserialized; - assertThat(messageReadyCopy.getFlightId(), equalTo(messageReady.getFlightId())); - assertThat( - messageReadyCopy.getType().getMessageEnum(), - equalTo(messageReady.getType().getMessageEnum())); - assertThat( - messageReadyCopy.getType().getVersion(), equalTo(messageReady.getType().getVersion())); - assertThat(messageReadyCopy.getFlightId(), equalTo(messageReady.getFlightId())); - } else { - fail(); - } + void process_DatabaseOperationException() throws InterruptedException { + QueueMessageReady messageReady = createQueueMessageWithContext(CALLING_THREAD_CONTEXT); + + doThrow(DatabaseOperationException.class).when(stairway).resume(FLIGHT_ID); + + assertThat( + "Message is left on the queue when stairway.resume throws", + messageReady.process(stairway), + equalTo(false)); + assertThat(MDC.getCopyOfContextMap(), equalTo(null)); } }