Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
MdcUtils:
- Added public utility method callWithContext for running and returning a callable with MDC's context map temporarily overwritten
- Remove public modifier on overwriteContext in favor of the above

QueueMessageReady:
- Use MdcUtils.callWithContext to simplify process definition
- Revert process return strategy to return booleans directly rather than setting a boolean to return later
- Remove unnecessary callingThreadContext setter (left existing setters untouched even though they can likely be removed: OOS)
- Left public callingThreadContext getter unchanged: it is required to be public for serde operations

Added unit test coverage for all changes.
  • Loading branch information
okotsopoulos committed Jul 15, 2024
1 parent b6aa844 commit d8fd414
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 34 deletions.
29 changes: 27 additions & 2 deletions stairway/src/main/java/bio/terra/stairway/impl/MdcUtils.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package bio.terra.stairway.impl;

import bio.terra.stairway.FlightContext;
import bio.terra.stairway.exception.StairwayExecutionException;
import java.util.Map;
import java.util.concurrent.Callable;
import org.slf4j.MDC;

/**
* Utility methods to make Stairway flight runnables context-aware, using mapped diagnostic context
* (MDC).
*/
public class MdcUtils {

/** ID of the flight */
static final String FLIGHT_ID_KEY = "flightId";

Expand All @@ -25,12 +26,36 @@ public class MdcUtils {
/** The step's execution order */
static final String FLIGHT_STEP_NUMBER_KEY = "flightStepNumber";

/**
* Run and return the result of the callable with MDC's context map temporarily overwritten during
* computation. The initial context map is then restored after computation.
*
* @param context to override MDC's context map
* @param callable to call and return
*/
public static <T> T callWithContext(Map<String, String> context, Callable<T> callable)
throws InterruptedException {
// Save the initial thread context so that it can be restored
Map<String, String> initialContext = MDC.getCopyOfContextMap();
try {
MdcUtils.overwriteContext(context);
System.out.println(MDC.getCopyOfContextMap());
return callable.call();
} catch (InterruptedException ex) {
throw ex;
} catch (Exception ex) {
throw new StairwayExecutionException("Unexpected exception " + ex.getMessage(), ex);
} finally {
MdcUtils.overwriteContext(initialContext);
}
}

/**
* Null-safe utility method for overwriting the current thread's MDC.
*
* @param context to set as MDC, if null then MDC will be cleared.
*/
public static void overwriteContext(Map<String, String> context) {
static void overwriteContext(Map<String, String> context) {
MDC.clear();
if (context != null) {
MDC.setContextMap(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,26 @@ public QueueMessageReady(String flightId) {

@Override
public boolean process(StairwayImpl stairwayImpl) throws InterruptedException {
boolean processed = false;
// Save the initial thread context so that it can be restored
Map<String, String> initialContext = MDC.getCopyOfContextMap();
try {
MdcUtils.overwriteContext(callingThreadContext);
// Resumed is false if the flight is not found in the Ready state. We still call that
// a complete processing of the message and return true. We assume that some this is a
// duplicate message or that some other Stairway found the ready flight on recovery.
boolean resumed = stairwayImpl.resume(flightId);
logger.info(
"Stairway "
+ stairwayImpl.getStairwayName()
+ (resumed ? " resumed flight: " : " did not find flight to resume: ")
+ flightId);
processed = true;
} catch (DatabaseOperationException ex) {
logger.error("Unexpected stairway error, leaving %s on the queue".formatted(flightId), ex);
} finally {
MdcUtils.overwriteContext(initialContext);
}
return processed;
return MdcUtils.callWithContext(
callingThreadContext,
() -> {
try {
// Resumed is false if the flight is not found in the Ready state. We still call that
// a complete processing of the message and return true. We assume that some this is a
// duplicate message or that some other Stairway found the ready flight on recovery.
boolean resumed = stairwayImpl.resume(flightId);
logger.info(
"Stairway "
+ stairwayImpl.getStairwayName()
+ (resumed ? " resumed flight: " : " did not find flight to resume: ")
+ flightId);
return true;
} catch (DatabaseOperationException ex) {
logger.error(
"Unexpected stairway error, leaving %s on the queue".formatted(flightId), ex);
return false;
}
});
}

public QueueMessageType getType() {
Expand All @@ -74,8 +73,4 @@ public void setFlightId(String flightId) {
public Map<String, String> getCallingThreadContext() {
return callingThreadContext;
}

public void setCallingThreadContext(Map<String, String> callingThreadContext) {
this.callingThreadContext = callingThreadContext;
}
}
57 changes: 57 additions & 0 deletions stairway/src/test/java/bio/terra/stairway/impl/MdcUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,27 @@

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertThrows;

import bio.terra.stairway.Direction;
import bio.terra.stairway.exception.StairwayExecutionException;
import bio.terra.stairway.fixtures.TestFlightContext;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.MDC;

@Tag("unit")
class MdcUtilsTest {
private static final Map<String, String> INITIAL_CONTEXT = Map.of("initial", "context");
private static final Map<String, String> FOO_BAR = Map.of("foo", "bar");
private static final String FLIGHT_ID = "flightId" + UUID.randomUUID();
private static final String FLIGHT_CLASS = "flightClass" + UUID.randomUUID();
Expand Down Expand Up @@ -52,6 +58,57 @@ static Stream<Map<String, String>> contextMap() {
return Stream.of(null, Map.of(), FOO_BAR);
}

@ParameterizedTest
@MethodSource("contextMap")
void callWithContext(Map<String, String> newContext) throws InterruptedException {
MDC.setContextMap(INITIAL_CONTEXT);
Boolean result =
MdcUtils.callWithContext(
newContext,
() -> {
assertThat(
"Context is overwritten during computation",
MDC.getCopyOfContextMap(),
equalTo(newContext));
return true;
});
assertThat("Result of computation is returned", result, equalTo(true));
assertThat("Initial context is restored", MDC.getCopyOfContextMap(), equalTo(INITIAL_CONTEXT));
}

static Stream<Arguments> callWithContext_exception() {
List<Arguments> arguments = new ArrayList<>();
for (var newContext : contextMap().toList()) {
arguments.add(
Arguments.of(
newContext, new InterruptedException("interrupted"), InterruptedException.class));
arguments.add(
Arguments.of(
newContext, new RuntimeException("unexpected"), StairwayExecutionException.class));
}
return arguments.stream();
}

@ParameterizedTest
@MethodSource
<T extends Throwable> void callWithContext_exception(
Map<String, String> newContext, Exception exception, Class<T> expectedExceptionClass) {
MDC.setContextMap(INITIAL_CONTEXT);
assertThrows(
expectedExceptionClass,
() ->
MdcUtils.callWithContext(
newContext,
() -> {
assertThat(
"Context is overwritten during computation",
MDC.getCopyOfContextMap(),
equalTo(newContext));
throw exception;
}));
assertThat("Initial context is restored", MDC.getCopyOfContextMap(), equalTo(INITIAL_CONTEXT));
}

@ParameterizedTest
@MethodSource("contextMap")
void overwriteContext(Map<String, String> newContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,19 @@ void beforeEach() {
MDC.clear();
}

private QueueMessageReady createQueueMessageWithContext(Map<String, String> expectedMdc)
throws InterruptedException {
return MdcUtils.callWithContext(expectedMdc, () -> new QueueMessageReady(FLIGHT_ID));
}

private static Stream<Map<String, String>> message_serde() {
return Stream.of(null, CALLING_THREAD_CONTEXT);
}

@ParameterizedTest
@MethodSource
void message_serde(Map<String, String> expectedMdc) {
MdcUtils.overwriteContext(expectedMdc);
QueueMessageReady messageReady = new QueueMessageReady(FLIGHT_ID);
void message_serde(Map<String, String> expectedMdc) throws InterruptedException {
QueueMessageReady messageReady = createQueueMessageWithContext(expectedMdc);
WorkQueueProcessor workQueueProcessor = new WorkQueueProcessor(stairway);

// Now we add something else to the MDC, but it won't show up in our deserialized queue message.
Expand All @@ -69,8 +73,7 @@ void message_serde(Map<String, String> expectedMdc) {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void process(boolean resumeAnswer) throws InterruptedException {
QueueMessageReady messageReady = new QueueMessageReady(FLIGHT_ID);
messageReady.setCallingThreadContext(CALLING_THREAD_CONTEXT);
QueueMessageReady messageReady = createQueueMessageWithContext(CALLING_THREAD_CONTEXT);

when(stairway.resume(FLIGHT_ID))
.thenAnswer(
Expand All @@ -92,8 +95,7 @@ void process(boolean resumeAnswer) throws InterruptedException {

@Test
void process_DatabaseOperationException() throws InterruptedException {
QueueMessageReady messageReady = new QueueMessageReady(FLIGHT_ID);
messageReady.setCallingThreadContext(CALLING_THREAD_CONTEXT);
QueueMessageReady messageReady = createQueueMessageWithContext(CALLING_THREAD_CONTEXT);

doThrow(DatabaseOperationException.class).when(stairway).resume(FLIGHT_ID);

Expand Down

0 comments on commit d8fd414

Please sign in to comment.