Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DCJ-507] Stairway with work queue enabled should be context-aware #152

Merged
merged 5 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions stairway/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ dependencies {

// File handling during testing
testImplementation group: 'commons-io', name: 'commons-io', version: '2.16.1'

// Mocks during testing
testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.12.0'
}

apply from: "$rootDir/gradle/test.gradle"
28 changes: 26 additions & 2 deletions stairway/src/main/java/bio/terra/stairway/impl/MdcUtils.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package bio.terra.stairway.impl;

import bio.terra.stairway.FlightContext;
import bio.terra.stairway.exception.StairwayExecutionException;
import java.util.Map;
import java.util.concurrent.Callable;
import org.slf4j.MDC;

/**
* Utility methods to make Stairway flight runnables context-aware, using mapped diagnostic context
* (MDC).
*/
class MdcUtils {

public class MdcUtils {
/** ID of the flight */
static final String FLIGHT_ID_KEY = "flightId";

Expand All @@ -25,6 +26,29 @@ class MdcUtils {
/** The step's execution order */
static final String FLIGHT_STEP_NUMBER_KEY = "flightStepNumber";

/**
* Run and return the result of the callable with MDC's context map temporarily overwritten during
* computation. The initial context map is then restored after computation.
*
* @param context to override MDC's context map
* @param callable to call and return
*/
public static <T> T callWithContext(Map<String, String> context, Callable<T> callable)
throws InterruptedException {
// Save the initial thread context so that it can be restored
Map<String, String> initialContext = MDC.getCopyOfContextMap();
try {
MdcUtils.overwriteContext(context);
return callable.call();
} catch (InterruptedException ex) {
throw ex;
} catch (Exception ex) {
throw new StairwayExecutionException("Unexpected exception " + ex.getMessage(), ex);
} finally {
MdcUtils.overwriteContext(initialContext);
}
}

/**
* Null-safe utility method for overwriting the current thread's MDC.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import jakarta.annotation.Nullable;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package bio.terra.stairway.queue;

import bio.terra.stairway.exception.DatabaseOperationException;
import bio.terra.stairway.impl.MdcUtils;
import bio.terra.stairway.impl.StairwayImpl;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

/**
* The Ready message communicates that a flight, identified by flightId, is ready for execution. By
Expand All @@ -16,32 +19,39 @@ class QueueMessageReady extends QueueMessage {

private QueueMessageType type;
private String flightId;
private Map<String, String> callingThreadContext;

private QueueMessageReady() {}

public QueueMessageReady(String flightId) {
this.type =
new QueueMessageType(QueueMessage.FORMAT_VERSION, QueueMessageEnum.QUEUE_MESSAGE_READY);
this.flightId = flightId;
this.callingThreadContext = MDC.getCopyOfContextMap();
}

@Override
public boolean process(StairwayImpl stairwayImpl) throws InterruptedException {
try {
// Resumed is false if the flight is not found in the Ready state. We still call that
// a complete processing of the message and return true. We assume that some this is a
// duplicate message or that some other Stairway found the ready flight on recovery.
boolean resumed = stairwayImpl.resume(flightId);
logger.info(
"Stairway "
+ stairwayImpl.getStairwayName()
+ (resumed ? " resumed flight: " : " did not find flight to resume: ")
+ flightId);
return true;
okotsopoulos marked this conversation as resolved.
Show resolved Hide resolved
} catch (DatabaseOperationException ex) {
logger.error("Unexpected stairway error", ex);
return false; // leave the message on the queue
}
return MdcUtils.callWithContext(
callingThreadContext,
() -> {
try {
// Resumed is false if the flight is not found in the Ready state. We still call that
// a complete processing of the message and return true. We assume that some this is a
// duplicate message or that some other Stairway found the ready flight on recovery.
boolean resumed = stairwayImpl.resume(flightId);
logger.info(
"Stairway "
+ stairwayImpl.getStairwayName()
+ (resumed ? " resumed flight: " : " did not find flight to resume: ")
+ flightId);
return true;
} catch (DatabaseOperationException ex) {
logger.error(
"Unexpected stairway error, leaving %s on the queue".formatted(flightId), ex);
return false;
}
});
}

public QueueMessageType getType() {
Expand All @@ -59,4 +69,8 @@ public String getFlightId() {
public void setFlightId(String flightId) {
this.flightId = flightId;
}

public Map<String, String> getCallingThreadContext() {
okotsopoulos marked this conversation as resolved.
Show resolved Hide resolved
return callingThreadContext;
}
}
57 changes: 57 additions & 0 deletions stairway/src/test/java/bio/terra/stairway/impl/MdcUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,27 @@

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertThrows;

import bio.terra.stairway.Direction;
import bio.terra.stairway.exception.StairwayExecutionException;
import bio.terra.stairway.fixtures.TestFlightContext;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.MDC;

@Tag("unit")
class MdcUtilsTest {
private static final Map<String, String> INITIAL_CONTEXT = Map.of("initial", "context");
private static final Map<String, String> FOO_BAR = Map.of("foo", "bar");
private static final String FLIGHT_ID = "flightId" + UUID.randomUUID();
private static final String FLIGHT_CLASS = "flightClass" + UUID.randomUUID();
Expand Down Expand Up @@ -52,6 +58,57 @@ static Stream<Map<String, String>> contextMap() {
return Stream.of(null, Map.of(), FOO_BAR);
}

@ParameterizedTest
@MethodSource("contextMap")
void callWithContext(Map<String, String> newContext) throws InterruptedException {
MDC.setContextMap(INITIAL_CONTEXT);
Boolean result =
MdcUtils.callWithContext(
newContext,
() -> {
assertThat(
"Context is overwritten during computation",
MDC.getCopyOfContextMap(),
equalTo(newContext));
return true;
});
assertThat("Result of computation is returned", result, equalTo(true));
assertThat("Initial context is restored", MDC.getCopyOfContextMap(), equalTo(INITIAL_CONTEXT));
}

static Stream<Arguments> callWithContext_exception() {
List<Arguments> arguments = new ArrayList<>();
for (var newContext : contextMap().toList()) {
arguments.add(
Arguments.of(
newContext, new InterruptedException("interrupted"), InterruptedException.class));
arguments.add(
Arguments.of(
newContext, new RuntimeException("unexpected"), StairwayExecutionException.class));
}
return arguments.stream();
}

@ParameterizedTest
@MethodSource
<T extends Throwable> void callWithContext_exception(
Map<String, String> newContext, Exception exception, Class<T> expectedExceptionClass) {
MDC.setContextMap(INITIAL_CONTEXT);
assertThrows(
expectedExceptionClass,
() ->
MdcUtils.callWithContext(
newContext,
() -> {
assertThat(
"Context is overwritten during computation",
MDC.getCopyOfContextMap(),
equalTo(newContext));
throw exception;
}));
assertThat("Initial context is restored", MDC.getCopyOfContextMap(), equalTo(INITIAL_CONTEXT));
}

@ParameterizedTest
@MethodSource("contextMap")
void overwriteContext(Map<String, String> newContext) {
Expand Down
114 changes: 95 additions & 19 deletions stairway/src/test/java/bio/terra/stairway/queue/QueueMessageTest.java
Original file line number Diff line number Diff line change
@@ -1,32 +1,108 @@
package bio.terra.stairway.queue;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.when;

import bio.terra.stairway.exception.DatabaseOperationException;
import bio.terra.stairway.impl.MdcUtils;
import bio.terra.stairway.impl.StairwayImpl;
import java.util.Map;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.stubbing.Answer;
import org.slf4j.MDC;

@Tag("unit")
public class QueueMessageTest {
@ExtendWith(MockitoExtension.class)
class QueueMessageTest {

@Mock private StairwayImpl stairway;
private static final String FLIGHT_ID = "flight-abc";
private static final Map<String, String> CALLING_THREAD_CONTEXT =
Map.of("requestId", "request-abc");

@BeforeEach
void beforeEach() {
MDC.clear();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to use a static mock instead of the actual MDC API? We don't use mocks yet for MDC, so it may be too big a change for this PR, but it's something to consider if we'd like to isolate our tests from the MDC implementation, and might make it easier to check if our code is doing what we expect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, yeah this feels like too big a change for this PR. But you make good points about the benefits of test isolation.

}

private QueueMessageReady createQueueMessageWithContext(Map<String, String> expectedMdc)
throws InterruptedException {
return MdcUtils.callWithContext(expectedMdc, () -> new QueueMessageReady(FLIGHT_ID));
}

private static Stream<Map<String, String>> message_serde() {
return Stream.of(null, CALLING_THREAD_CONTEXT);
}

@ParameterizedTest
@MethodSource
void message_serde(Map<String, String> expectedMdc) throws InterruptedException {
QueueMessageReady messageReady = createQueueMessageWithContext(expectedMdc);
WorkQueueProcessor workQueueProcessor = new WorkQueueProcessor(stairway);

// Now we add something else to the MDC, but it won't show up in our deserialized queue message.
MDC.put("another-key", "another-value");

String serialized = workQueueProcessor.serialize(messageReady);
QueueMessage deserialized = workQueueProcessor.deserialize(serialized);
assertThat(deserialized, instanceOf(QueueMessageReady.class));

QueueMessageReady messageReadyCopy = (QueueMessageReady) deserialized;
assertThat(messageReadyCopy.getFlightId(), equalTo(messageReady.getFlightId()));
assertThat(
messageReadyCopy.getType().getMessageEnum(),
equalTo(messageReady.getType().getMessageEnum()));
assertThat(
messageReadyCopy.getType().getVersion(), equalTo(messageReady.getType().getVersion()));
assertThat(messageReadyCopy.getFlightId(), equalTo(messageReady.getFlightId()));
assertThat(messageReadyCopy.getCallingThreadContext(), equalTo(expectedMdc));
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void process(boolean resumeAnswer) throws InterruptedException {
QueueMessageReady messageReady = createQueueMessageWithContext(CALLING_THREAD_CONTEXT);

when(stairway.resume(FLIGHT_ID))
.thenAnswer(
(Answer<Boolean>)
invocation -> {
assertThat(
"MDC is set during processing",
MDC.getCopyOfContextMap(),
equalTo(CALLING_THREAD_CONTEXT));
return resumeAnswer;
});

assertThat(
"Message is considered processed when stairway.resume returns " + resumeAnswer,
messageReady.process(stairway),
equalTo(true));
assertThat("MDC is reverted after processing", MDC.getCopyOfContextMap(), equalTo(null));
}

@Test
public void messageTest() throws Exception {
WorkQueueProcessor queueProcessor = new WorkQueueProcessor(null);
QueueMessageReady messageReady = new QueueMessageReady("abcde");
String serialized = queueProcessor.serialize(messageReady);
QueueMessage deserialized = queueProcessor.deserialize(serialized);
if (deserialized instanceof QueueMessageReady) {
QueueMessageReady messageReadyCopy = (QueueMessageReady) deserialized;
assertThat(messageReadyCopy.getFlightId(), equalTo(messageReady.getFlightId()));
assertThat(
messageReadyCopy.getType().getMessageEnum(),
equalTo(messageReady.getType().getMessageEnum()));
assertThat(
messageReadyCopy.getType().getVersion(), equalTo(messageReady.getType().getVersion()));
assertThat(messageReadyCopy.getFlightId(), equalTo(messageReady.getFlightId()));
} else {
fail();
}
void process_DatabaseOperationException() throws InterruptedException {
QueueMessageReady messageReady = createQueueMessageWithContext(CALLING_THREAD_CONTEXT);

doThrow(DatabaseOperationException.class).when(stairway).resume(FLIGHT_ID);

assertThat(
"Message is left on the queue when stairway.resume throws",
messageReady.process(stairway),
equalTo(false));
assertThat(MDC.getCopyOfContextMap(), equalTo(null));
}
}