Skip to content

Commit

Permalink
jfr-connection: jfr-connection: wrong parameter sent to JFR Diagnosti…
Browse files Browse the repository at this point in the history
…cCommand (#1492)
  • Loading branch information
dsgrieve authored Oct 16, 2024
1 parent dc66266 commit 9e68011
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@

import java.io.IOException;
import java.io.InputStream;
import java.lang.management.ManagementFactory;
import java.lang.management.RuntimeMXBean;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.management.InstanceNotFoundException;
import javax.management.IntrospectionException;
import javax.management.MBeanException;
import javax.management.MBeanInfo;
import javax.management.MBeanOperationInfo;
import javax.management.MBeanServerConnection;
import javax.management.MalformedObjectNameException;
import javax.management.ObjectInstance;
Expand All @@ -37,6 +40,8 @@ final class FlightRecorderDiagnosticCommandConnection implements FlightRecorderC
"com.sun.management:type=DiagnosticCommand";
private static final String JFR_START_REGEX = "Started recording (\\d+?)\\.";
private static final Pattern JFR_START_PATTERN = Pattern.compile(JFR_START_REGEX, Pattern.DOTALL);
private static final String JFR_CHECK_REGEX = "(?:recording|name)=(\\d+)";
private static final Pattern JFR_CHECK_PATTERN = Pattern.compile(JFR_CHECK_REGEX, Pattern.DOTALL);

// All JFR commands take String[] parameters
private static final String[] signature = new String[] {"[Ljava.lang.String;"};
Expand All @@ -59,9 +64,7 @@ static FlightRecorderConnection connect(MBeanServerConnection mBeanServerConnect
mBeanServerConnection.getObjectInstance(new ObjectName(DIAGNOSTIC_COMMAND_OBJECT_NAME));
ObjectName objectName = objectInstance.getObjectName();

if (jdkHasUnlockCommercialFeatures(mBeanServerConnection)) {
assertCommercialFeaturesUnlocked(mBeanServerConnection, objectName);
}
assertCommercialFeaturesUnlocked(mBeanServerConnection, objectName);

return new FlightRecorderDiagnosticCommandConnection(
mBeanServerConnection, objectInstance.getObjectName());
Expand Down Expand Up @@ -123,21 +126,22 @@ public long startRecording(
Object[] params = formOptions(recordingOptions, recordingConfiguration);

// jfrStart returns "Started recording 2." and some more stuff, but all we care about is the
// name of the recording.
// id of the recording.
String jfrStart;
try {
String jfrStart =
(String) mBeanServerConnection.invoke(objectName, "jfrStart", params, signature);
String name;
jfrStart = (String) mBeanServerConnection.invoke(objectName, "jfrStart", params, signature);
Matcher matcher = JFR_START_PATTERN.matcher(jfrStart);
if (matcher.find()) {
name = matcher.group(1);
return Long.parseLong(name);
String id = matcher.group(1);
return Long.parseLong(id);
}
} catch (InstanceNotFoundException | ReflectionException | MBeanException e) {
throw JfrConnectionException.canonicalJfrConnectionException(getClass(), "startRecording", e);
}
throw JfrConnectionException.canonicalJfrConnectionException(
getClass(), "startRecording", new IllegalStateException("Failed to parse jfrStart output"));
getClass(),
"startRecording",
new IllegalStateException("Failed to parse: '" + jfrStart + "'"));
}

private static Object[] formOptions(
Expand All @@ -156,10 +160,33 @@ private static Object[] formOptions(
return mkParamsArray(params);
}

//
// Whether to use the 'name' or 'recording' parameter depends on the JVM.
// Use JFR.check to determine which one to use.
//
private String getRecordingParam(long recordingId) throws JfrConnectionException, IOException {
String jfrCheck;
try {
Object[] params = new Object[] {new String[] {}};
jfrCheck = (String) mBeanServerConnection.invoke(objectName, "jfrCheck", params, signature);
Matcher matcher = JFR_CHECK_PATTERN.matcher(jfrCheck);
while (matcher.find()) {
String id = matcher.group(1);
if (id.equals(Long.toString(recordingId))) {
return matcher.group(0);
}
}
} catch (InstanceNotFoundException | MBeanException | ReflectionException e) {
throw JfrConnectionException.canonicalJfrConnectionException(getClass(), "jfrCheck", e);
}
throw JfrConnectionException.canonicalJfrConnectionException(
getClass(), "jfrCheck", new IllegalStateException("Failed to parse: '" + jfrCheck + "'"));
}

@Override
public void stopRecording(long id) throws JfrConnectionException {
try {
Object[] params = mkParams("name=" + id);
Object[] params = mkParams(getRecordingParam(id));
mBeanServerConnection.invoke(objectName, "jfrStop", params, signature);
} catch (InstanceNotFoundException | MBeanException | ReflectionException | IOException e) {
throw JfrConnectionException.canonicalJfrConnectionException(getClass(), "stopRecording", e);
Expand All @@ -169,7 +196,7 @@ public void stopRecording(long id) throws JfrConnectionException {
@Override
public void dumpRecording(long id, String outputFile) throws IOException, JfrConnectionException {
try {
Object[] params = mkParams("filename=" + outputFile, "name=" + id);
Object[] params = mkParams("filename=" + outputFile, getRecordingParam(id));
mBeanServerConnection.invoke(objectName, "jfrDump", params, signature);
} catch (InstanceNotFoundException | MBeanException | ReflectionException e) {
throw JfrConnectionException.canonicalJfrConnectionException(getClass(), "dumpRecording", e);
Expand Down Expand Up @@ -197,41 +224,34 @@ public void closeRecording(long id) {
"closeRecording not available through the DiagnosticCommand connection");
}

// Do this check separate from assertCommercialFeatures because reliance
// on System properties makes it difficult to test.
static boolean jdkHasUnlockCommercialFeatures(MBeanServerConnection mBeanServerConnection) {
try {
RuntimeMXBean runtimeMxBean =
ManagementFactory.getPlatformMXBean(mBeanServerConnection, RuntimeMXBean.class);
String javaVmVendor = runtimeMxBean.getVmVendor();
String javaVersion = runtimeMxBean.getVmVersion();
return javaVmVendor.contains("Oracle Corporation")
&& javaVersion.matches("(?:^1\\.8|9|10).*");
} catch (IOException e) {
return false;
}
}

// visible for testing
static void assertCommercialFeaturesUnlocked(
MBeanServerConnection mBeanServerConnection, ObjectName objectName)
throws IOException, JfrConnectionException {

try {
Object unlockedMessage =
mBeanServerConnection.invoke(objectName, "vmCheckCommercialFeatures", null, null);
if (unlockedMessage instanceof String) {
boolean unlocked = ((String) unlockedMessage).contains("unlocked");
if (!unlocked) {
throw JfrConnectionException.canonicalJfrConnectionException(
FlightRecorderDiagnosticCommandConnection.class,
"assertCommercialFeaturesUnlocked",
new UnsupportedOperationException(
"Unlocking commercial features may be required. This must be explicitly enabled by adding -XX:+UnlockCommercialFeatures"));
}
Object[] params = new Object[] {new String[] {}};
MBeanInfo mBeanInfo = mBeanServerConnection.getMBeanInfo(objectName);
if (mBeanInfo == null) {
throw JfrConnectionException.canonicalJfrConnectionException(
FlightRecorderDiagnosticCommandConnection.class,
"assertCommercialFeaturesUnlocked",
new NullPointerException("Could not get MBeanInfo for " + objectName));
}
Optional<MBeanOperationInfo> operation =
Arrays.stream(mBeanInfo.getOperations())
.filter(it -> "vmUnlockCommercialFeatures".equals(it.getName()))
.findFirst();

if (operation.isPresent()) {
mBeanServerConnection.invoke(objectName, "vmUnlockCommercialFeatures", params, signature);
}
} catch (InstanceNotFoundException | MBeanException | ReflectionException ignored) {
// If the MBean doesn't have the vmCheckCommercialFeatures method, then we can't check it.
} catch (InstanceNotFoundException
| IntrospectionException
| MBeanException
| ReflectionException e) {
throw JfrConnectionException.canonicalJfrConnectionException(
FlightRecorderDiagnosticCommandConnection.class, "assertCommercialFeaturesUnlocked", e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,74 +7,26 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;

import com.google.errorprone.annotations.Keep;
import java.lang.management.ManagementFactory;
import java.lang.management.RuntimeMXBean;
import java.util.stream.Stream;
import java.nio.file.Files;
import java.nio.file.Path;
import javax.management.MBeanServer;
import javax.management.MBeanServerConnection;
import javax.management.ObjectName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.MockedStatic;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

class FlightRecorderDiagnosticCommandConnectionTest {

@Keep
static Stream<Arguments> assertJdkHasUnlockCommercialFeatures() {
return Stream.of(
Arguments.of("Oracle Corporation", "1.8.0_401", true),
Arguments.of("AdoptOpenJDK", "1.8.0_282", false),
Arguments.of("Oracle Corporation", "10.0.2", true),
Arguments.of("Oracle Corporation", "9.0.4", true),
Arguments.of("Oracle Corporation", "11.0.22", false),
Arguments.of("Microsoft", "11.0.13", false),
Arguments.of("Microsoft", "17.0.3", false),
Arguments.of("Oracle Corporation", "21.0.3", false));
}

@ParameterizedTest
@MethodSource
void assertJdkHasUnlockCommercialFeatures(String vmVendor, String vmVersion, boolean expected)
throws Exception {

MBeanServerConnection mBeanServerConnection = mock(MBeanServerConnection.class);

try (MockedStatic<ManagementFactory> mockedStatic = mockStatic(ManagementFactory.class)) {
mockedStatic
.when(
() -> ManagementFactory.getPlatformMXBean(mBeanServerConnection, RuntimeMXBean.class))
.thenAnswer(
new Answer<RuntimeMXBean>() {
@Override
public RuntimeMXBean answer(InvocationOnMock invocation) {
RuntimeMXBean mockedRuntimeMxBean = mock(RuntimeMXBean.class);
when(mockedRuntimeMxBean.getVmVendor()).thenReturn(vmVendor);
when(mockedRuntimeMxBean.getVmVersion()).thenReturn(vmVersion);
return mockedRuntimeMxBean;
}
});

boolean actual =
FlightRecorderDiagnosticCommandConnection.jdkHasUnlockCommercialFeatures(
mBeanServerConnection);
assertEquals(expected, actual, "Expected " + expected + " for " + vmVendor + " " + vmVersion);
}
}

@Test
void assertCommercialFeaturesUnlocked() throws Exception {
ObjectName objectName = mock(ObjectName.class);
MBeanServerConnection mBeanServerConnection = mockMbeanServer(objectName, "unlocked");
MBeanServer mBeanServerConnection = ManagementFactory.getPlatformMBeanServer();
ObjectName objectName = new ObjectName("com.sun.management:type=DiagnosticCommand");
FlightRecorderDiagnosticCommandConnection.assertCommercialFeaturesUnlocked(
mBeanServerConnection, objectName);
}
Expand Down Expand Up @@ -124,6 +76,36 @@ void startRecordingParsesIdCorrectly() throws Exception {
assertEquals(id, 99);
}

@Test
void endToEndTest() throws Exception {

MBeanServerConnection mBeanServer = ManagementFactory.getPlatformMBeanServer();
FlightRecorderConnection flightRecorderConnection =
FlightRecorderDiagnosticCommandConnection.connect(mBeanServer);
RecordingOptions recordingOptions =
new RecordingOptions.Builder().disk("true").duration("5s").build();
RecordingConfiguration recordingConfiguration = RecordingConfiguration.PROFILE_CONFIGURATION;
Path tempFile = Files.createTempFile("recording", ".jfr");

try (Recording recording =
flightRecorderConnection.newRecording(recordingOptions, recordingConfiguration)) {

recording.start();
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
recording.dump(tempFile.toString());
recording.stop();
} finally {
if (!Files.exists(tempFile)) {
fail("Recording file not found");
}
Files.deleteIfExists(tempFile);
}
}

MBeanServerConnection mockMbeanServer(
ObjectName objectName, String vmCheckCommercialFeaturesResponse) throws Exception {
MBeanServerConnection mBeanServerConnection = mock(MBeanServerConnection.class);
Expand Down

0 comments on commit 9e68011

Please sign in to comment.