Skip to content

Commit

Permalink
[DCJ-507] Stairway with work queue enabled should be context-aware (#152
Browse files Browse the repository at this point in the history
)

* [DCJ-507] Flights queued in GCP Pub-Sub include calling thread context

Context-aware Stairway previously assumed that the calling thread would have mapped diagnostic context (MDC) to persist to the child threads spawned for flight execution.

This was not the case when Stairway is configured with a work queue enabled (e.g. GCP Pub-Sub). Here, flight information was written to the pub-sub topic without the calling thread's context, so when a message was deserialized and processed, any logs emitted from that flight would be missing the context set by the original calling thread (e.g. request ID).

- Expanded QueueMessageReady to store calling thread context on construction
- QueueMessageReady.process sets the MDC using its stored calling thread context, which then makes the flight context-aware
- Expanded unit tests

* Sonar: JUnit 5 test classes, methods should have default visibility

* PR feedback

MdcUtils:
- Added public utility method callWithContext, which runs a callable with MDC's context map temporarily overwritten and returns the callable's result
- Remove public modifier on overwriteContext in favor of the above

QueueMessageReady:
- Use MdcUtils.callWithContext to simplify process definition
- Revert process return strategy to return booleans directly rather than setting a boolean to return later
- Remove unnecessary callingThreadContext setter (left existing setters untouched even though they can likely be removed: OOS)
- Left public callingThreadContext getter unchanged: it is required to be public for serde operations

Added unit test coverage for all changes.

* Remove debugging print statement errantly checked in
  • Loading branch information
okotsopoulos authored Jul 15, 2024
1 parent 1410310 commit 6f74895
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 37 deletions.
3 changes: 3 additions & 0 deletions stairway/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ dependencies {

// File handling during testing
testImplementation group: 'commons-io', name: 'commons-io', version: '2.16.1'

// Mocks during testing
testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.12.0'
}

apply from: "$rootDir/gradle/test.gradle"
28 changes: 26 additions & 2 deletions stairway/src/main/java/bio/terra/stairway/impl/MdcUtils.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package bio.terra.stairway.impl;

import bio.terra.stairway.FlightContext;
import bio.terra.stairway.exception.StairwayExecutionException;
import java.util.Map;
import java.util.concurrent.Callable;
import org.slf4j.MDC;

/**
* Utility methods to make Stairway flight runnables context-aware, using mapped diagnostic context
* (MDC).
*/
class MdcUtils {

public class MdcUtils {
/** ID of the flight */
static final String FLIGHT_ID_KEY = "flightId";

Expand All @@ -25,6 +26,29 @@ class MdcUtils {
/** The step's execution order */
static final String FLIGHT_STEP_NUMBER_KEY = "flightStepNumber";

/**
 * Invokes {@code callable} while the current thread's MDC context map is replaced by {@code
 * context}, then reinstates the previous context map however the call ends.
 *
 * @param context context map installed in MDC for the duration of the call
 * @param callable work to run under the temporary context
 * @return the value produced by {@code callable}
 * @throws InterruptedException if {@code callable} throws it; rethrown unchanged
 * @throws StairwayExecutionException wrapping any other exception thrown by {@code callable}
 */
public static <T> T callWithContext(Map<String, String> context, Callable<T> callable)
    throws InterruptedException {
  // Snapshot taken before the swap so the finally block can put it back.
  final Map<String, String> savedContext = MDC.getCopyOfContextMap();
  try {
    overwriteContext(context);
    return callable.call();
  } catch (Exception ex) {
    // Interruption must reach the caller as-is; everything else is wrapped.
    if (ex instanceof InterruptedException) {
      throw (InterruptedException) ex;
    }
    throw new StairwayExecutionException("Unexpected exception " + ex.getMessage(), ex);
  } finally {
    overwriteContext(savedContext);
  }
}

/**
* Null-safe utility method for overwriting the current thread's MDC.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import jakarta.annotation.Nullable;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package bio.terra.stairway.queue;

import bio.terra.stairway.exception.DatabaseOperationException;
import bio.terra.stairway.impl.MdcUtils;
import bio.terra.stairway.impl.StairwayImpl;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

/**
* The Ready message communicates that a flight, identified by flightId, is ready for execution. By
Expand All @@ -16,32 +19,39 @@ class QueueMessageReady extends QueueMessage {

private QueueMessageType type;
private String flightId;
private Map<String, String> callingThreadContext;

// No-arg constructor. NOTE(review): presumably here so the serde layer can instantiate the
// message before populating its fields — confirm against WorkQueueProcessor.
private QueueMessageReady() {}

/**
 * Builds a ready message for the given flight and snapshots the constructing thread's MDC.
 *
 * @param flightId id of the flight that is ready for execution
 */
public QueueMessageReady(String flightId) {
  // Copy of the caller's MDC at construction time; process() later runs under this context.
  this.callingThreadContext = MDC.getCopyOfContextMap();
  this.flightId = flightId;
  this.type =
      new QueueMessageType(QueueMessage.FORMAT_VERSION, QueueMessageEnum.QUEUE_MESSAGE_READY);
}

/**
* Process this ready message by asking Stairway to resume the flight identified by flightId.
*
* @param stairwayImpl Stairway instance used to resume the flight; its name is included in the log
* @return true when the message is considered processed, including when no flight was found to
*     resume; false when a DatabaseOperationException was caught
* @throws InterruptedException declared on the signature; MdcUtils.callWithContext declares it too
*/
@Override
public boolean process(StairwayImpl stairwayImpl) throws InterruptedException {
try {
// Resumed is false if the flight is not found in the Ready state. We still call that
// a complete processing of the message and return true. We assume that this is a
// duplicate message or that some other Stairway found the ready flight on recovery.
boolean resumed = stairwayImpl.resume(flightId);
logger.info(
"Stairway "
+ stairwayImpl.getStairwayName()
+ (resumed ? " resumed flight: " : " did not find flight to resume: ")
+ flightId);
return true;
} catch (DatabaseOperationException ex) {
logger.error("Unexpected stairway error", ex);
return false; // leave the message on the queue
}
// Run the resume logic with MDC set to the context stored on this message.
return MdcUtils.callWithContext(
callingThreadContext,
() -> {
try {
// Resumed is false if the flight is not found in the Ready state. We still call that
// a complete processing of the message and return true. We assume that this is a
// duplicate message or that some other Stairway found the ready flight on recovery.
boolean resumed = stairwayImpl.resume(flightId);
logger.info(
"Stairway "
+ stairwayImpl.getStairwayName()
+ (resumed ? " resumed flight: " : " did not find flight to resume: ")
+ flightId);
return true;
} catch (DatabaseOperationException ex) {
logger.error(
"Unexpected stairway error, leaving %s on the queue".formatted(flightId), ex);
return false;
}
});
}

public QueueMessageType getType() {
Expand All @@ -59,4 +69,8 @@ public String getFlightId() {
public void setFlightId(String flightId) {
this.flightId = flightId;
}

/**
* Returns the stored calling-thread context map (the public constructor fills it from MDC).
* NOTE(review): the commit notes state this getter must remain public for serde operations.
*/
public Map<String, String> getCallingThreadContext() {
return callingThreadContext;
}
}
57 changes: 57 additions & 0 deletions stairway/src/test/java/bio/terra/stairway/impl/MdcUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,27 @@

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertThrows;

import bio.terra.stairway.Direction;
import bio.terra.stairway.exception.StairwayExecutionException;
import bio.terra.stairway.fixtures.TestFlightContext;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.MDC;

@Tag("unit")
class MdcUtilsTest {
private static final Map<String, String> INITIAL_CONTEXT = Map.of("initial", "context");
private static final Map<String, String> FOO_BAR = Map.of("foo", "bar");
private static final String FLIGHT_ID = "flightId" + UUID.randomUUID();
private static final String FLIGHT_CLASS = "flightClass" + UUID.randomUUID();
Expand Down Expand Up @@ -52,6 +58,57 @@ static Stream<Map<String, String>> contextMap() {
return Stream.of(null, Map.of(), FOO_BAR);
}

/**
* Verifies that callWithContext exposes the supplied context to the callable, returns the
* callable's result, and restores the initial MDC afterwards.
*/
@ParameterizedTest
@MethodSource("contextMap")
void callWithContext(Map<String, String> newContext) throws InterruptedException {
// Seed a known context so restoration can be checked after the call.
MDC.setContextMap(INITIAL_CONTEXT);
Boolean result =
MdcUtils.callWithContext(
newContext,
() -> {
assertThat(
"Context is overwritten during computation",
MDC.getCopyOfContextMap(),
equalTo(newContext));
return true;
});
assertThat("Result of computation is returned", result, equalTo(true));
assertThat("Initial context is restored", MDC.getCopyOfContextMap(), equalTo(INITIAL_CONTEXT));
}

/**
 * Argument source for the exception test: for each context from contextMap(), one case where the
 * callable throws InterruptedException (expected unchanged) and one where it throws
 * RuntimeException (expected as StairwayExecutionException).
 */
static Stream<Arguments> callWithContext_exception() {
  return contextMap()
      .flatMap(
          candidateContext ->
              Stream.of(
                  Arguments.of(
                      candidateContext,
                      new InterruptedException("interrupted"),
                      InterruptedException.class),
                  Arguments.of(
                      candidateContext,
                      new RuntimeException("unexpected"),
                      StairwayExecutionException.class)));
}

/**
* Verifies that when the callable throws, callWithContext surfaces the expected exception type
* and still restores the initial MDC.
*/
@ParameterizedTest
@MethodSource
<T extends Throwable> void callWithContext_exception(
Map<String, String> newContext, Exception exception, Class<T> expectedExceptionClass) {
// Seed a known context so restoration can be checked after the failure.
MDC.setContextMap(INITIAL_CONTEXT);
assertThrows(
expectedExceptionClass,
() ->
MdcUtils.callWithContext(
newContext,
() -> {
assertThat(
"Context is overwritten during computation",
MDC.getCopyOfContextMap(),
equalTo(newContext));
throw exception;
}));
assertThat("Initial context is restored", MDC.getCopyOfContextMap(), equalTo(INITIAL_CONTEXT));
}

@ParameterizedTest
@MethodSource("contextMap")
void overwriteContext(Map<String, String> newContext) {
Expand Down
114 changes: 95 additions & 19 deletions stairway/src/test/java/bio/terra/stairway/queue/QueueMessageTest.java
Original file line number Diff line number Diff line change
@@ -1,32 +1,108 @@
package bio.terra.stairway.queue;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.when;

import bio.terra.stairway.exception.DatabaseOperationException;
import bio.terra.stairway.impl.MdcUtils;
import bio.terra.stairway.impl.StairwayImpl;
import java.util.Map;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.stubbing.Answer;
import org.slf4j.MDC;

@Tag("unit")
public class QueueMessageTest {
@ExtendWith(MockitoExtension.class)
class QueueMessageTest {

@Mock private StairwayImpl stairway;
private static final String FLIGHT_ID = "flight-abc";
private static final Map<String, String> CALLING_THREAD_CONTEXT =
Map.of("requestId", "request-abc");

// Clear the MDC before every test so each one starts without leftover context.
@BeforeEach
void beforeEach() {
MDC.clear();
}

/**
* Constructs a QueueMessageReady for FLIGHT_ID while MDC is temporarily set to the given map, so
* the message's constructor captures that map as its calling-thread context.
*/
private QueueMessageReady createQueueMessageWithContext(Map<String, String> expectedMdc)
throws InterruptedException {
return MdcUtils.callWithContext(expectedMdc, () -> new QueueMessageReady(FLIGHT_ID));
}

// Argument source for the message_serde test: a null context and a populated one.
private static Stream<Map<String, String>> message_serde() {
return Stream.of(null, CALLING_THREAD_CONTEXT);
}

/**
 * Round-trips a ready message through WorkQueueProcessor serialize/deserialize and checks that
 * the flight id, message type, and calling-thread context all survive.
 *
 * @param expectedMdc context active when the message is constructed; may be null
 */
@ParameterizedTest
@MethodSource
void message_serde(Map<String, String> expectedMdc) throws InterruptedException {
  QueueMessageReady messageReady = createQueueMessageWithContext(expectedMdc);
  WorkQueueProcessor workQueueProcessor = new WorkQueueProcessor(stairway);

  // Add to the MDC after the message exists: the message carries the copy taken at construction,
  // so this entry must not show up in the deserialized queue message.
  MDC.put("another-key", "another-value");

  String serialized = workQueueProcessor.serialize(messageReady);
  QueueMessage deserialized = workQueueProcessor.deserialize(serialized);
  assertThat(deserialized, instanceOf(QueueMessageReady.class));

  QueueMessageReady messageReadyCopy = (QueueMessageReady) deserialized;
  assertThat(messageReadyCopy.getFlightId(), equalTo(messageReady.getFlightId()));
  assertThat(
      messageReadyCopy.getType().getMessageEnum(),
      equalTo(messageReady.getType().getMessageEnum()));
  assertThat(
      messageReadyCopy.getType().getVersion(), equalTo(messageReady.getType().getVersion()));
  assertThat(messageReadyCopy.getCallingThreadContext(), equalTo(expectedMdc));
}

/**
* Verifies that process() runs stairway.resume with MDC set to the message's calling-thread
* context, returns true for either resume outcome, and leaves MDC empty afterwards.
*/
@ParameterizedTest
@ValueSource(booleans = {true, false})
void process(boolean resumeAnswer) throws InterruptedException {
QueueMessageReady messageReady = createQueueMessageWithContext(CALLING_THREAD_CONTEXT);

// The stubbed answer runs inside process(), so it can observe the MDC in effect at that moment.
when(stairway.resume(FLIGHT_ID))
.thenAnswer(
(Answer<Boolean>)
invocation -> {
assertThat(
"MDC is set during processing",
MDC.getCopyOfContextMap(),
equalTo(CALLING_THREAD_CONTEXT));
return resumeAnswer;
});

assertThat(
"Message is considered processed when stairway.resume returns " + resumeAnswer,
messageReady.process(stairway),
equalTo(true));
assertThat("MDC is reverted after processing", MDC.getCopyOfContextMap(), equalTo(null));
}

@Test
public void messageTest() throws Exception {
WorkQueueProcessor queueProcessor = new WorkQueueProcessor(null);
QueueMessageReady messageReady = new QueueMessageReady("abcde");
String serialized = queueProcessor.serialize(messageReady);
QueueMessage deserialized = queueProcessor.deserialize(serialized);
if (deserialized instanceof QueueMessageReady) {
QueueMessageReady messageReadyCopy = (QueueMessageReady) deserialized;
assertThat(messageReadyCopy.getFlightId(), equalTo(messageReady.getFlightId()));
assertThat(
messageReadyCopy.getType().getMessageEnum(),
equalTo(messageReady.getType().getMessageEnum()));
assertThat(
messageReadyCopy.getType().getVersion(), equalTo(messageReady.getType().getVersion()));
assertThat(messageReadyCopy.getFlightId(), equalTo(messageReady.getFlightId()));
} else {
fail();
}
// Verifies that process() returns false when stairway.resume throws DatabaseOperationException,
// and that MDC is empty again afterwards.
void process_DatabaseOperationException() throws InterruptedException {
QueueMessageReady messageReady = createQueueMessageWithContext(CALLING_THREAD_CONTEXT);

doThrow(DatabaseOperationException.class).when(stairway).resume(FLIGHT_ID);

assertThat(
"Message is left on the queue when stairway.resume throws",
messageReady.process(stairway),
equalTo(false));
assertThat(MDC.getCopyOfContextMap(), equalTo(null));
}
}

0 comments on commit 6f74895

Please sign in to comment.