Skip to content

Commit

Permalink
Prevent recipe execution exceptions with unknown inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
timtebeek committed Jul 1, 2024
1 parent 9d76140 commit ddf6c3d
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.openrewrite.Preconditions;
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.lang.Nullable;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
Expand All @@ -27,7 +28,10 @@
import org.openrewrite.java.tree.*;
import org.openrewrite.marker.Markers;

import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -58,33 +62,30 @@ public static boolean isStringLiteral(Expression expression) {
return expression instanceof J.Literal && TypeUtils.isString(((J.Literal) expression).getType());
}

private static Optional<String> getMethodIdentifier(String name) {
String newMethodName = null;
switch (name) {
@Nullable
private static String getMethodIdentifier(Expression levelArgument) {
String levelSimpleName = levelArgument instanceof J.FieldAccess ?
(((J.FieldAccess) levelArgument).getName().getSimpleName()) :
(((J.Identifier) levelArgument).getSimpleName());
switch (levelSimpleName) {
case "ALL":
case "FINEST":
case "FINER":
newMethodName = "trace";
break;
return "trace";
case "FINE":
newMethodName = "debug";
break;
return "debug";
case "CONFIG":
case "INFO":
newMethodName = "info";
break;
return "info";
case "WARNING":
newMethodName = "warn";
break;
return "warn";
case "SEVERE":
newMethodName = "error";
break;
return "error";
}

return Optional.ofNullable(newMethodName);
return null;
}

private static J.Literal buildString(String string) {
private static J.Literal buildStringLiteral(String string) {
return new J.Literal(randomId(), Space.EMPTY, Markers.EMPTY, string, String.format("\"%s\"", string), null, JavaType.Primitive.String);
}

Expand All @@ -93,27 +94,27 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
if (METHOD_MATCHER_ARRAY.matches(method) || METHOD_MATCHER_PARAM.matches(method)) {
List<Expression> originalArguments = method.getArguments();

Expression levelName = originalArguments.get(0);
String simpleName = ((J.FieldAccess) levelName).getName().getSimpleName();
Optional<String> newName = getMethodIdentifier(simpleName);
if (!newName.isPresent()) {
Expression levelArgument = originalArguments.get(0);
Expression messageArgument = originalArguments.get(1);

if (!(levelArgument instanceof J.FieldAccess || levelArgument instanceof J.Identifier) ||
!isStringLiteral(messageArgument)) {
return method;
}
J.Literal stringFormat = (J.Literal) originalArguments.get(1);
if (!isStringLiteral(stringFormat)) {
String newName = getMethodIdentifier(levelArgument);
if(newName == null) {
return method;
}

maybeRemoveImport("java.util.logging.Level");

String originalFormatString = Objects.requireNonNull((stringFormat).getValue()).toString();
String originalFormatString = Objects.requireNonNull((String) ((J.Literal) messageArgument).getValue());
List<Integer> originalIndices = originalLoggedArgumentIndices(originalFormatString);
List<Expression> originalParameters = originalParameters(originalArguments.get(2));

List<Expression> targetArguments = new ArrayList<>(2);
targetArguments.add(buildString(originalFormatString.replaceAll("\\{\\d*}", "{}")));
targetArguments.add(buildStringLiteral(originalFormatString.replaceAll("\\{\\d*}", "{}")));
originalIndices.forEach(i -> targetArguments.add(originalParameters.get(i)));
return JavaTemplate.builder(createTemplateString(newName.get(), targetArguments))
return JavaTemplate.builder(createTemplateString(newName, targetArguments))
.contextSensitive()
.javaParser(JavaParser.fromJavaVersion()
.classpathFromResources(ctx, "slf4j-api-2.1"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
import org.openrewrite.test.RecipeSpec;
import org.openrewrite.test.RewriteTest;

import java.util.logging.Level;
import java.util.logging.Logger;

import static org.openrewrite.java.Assertions.java;

class JulParameterizedArgumentsTest implements RewriteTest {
Expand Down Expand Up @@ -186,4 +183,64 @@ void method(Logger logger, String param1) {
)
);
}

@Test
void staticImportLevel() {
rewriteRun(
// language=java
java(
"""
import java.util.logging.Logger;
import static java.util.logging.Level.INFO;
class Test {
void method(Logger logger, String param1) {
logger.log(INFO, "INFO Log entry, param1: {0}", param1);
}
}
""",
"""
import org.slf4j.Logger;
class Test {
void method(Logger logger, String param1) {
logger.info("INFO Log entry, param1: {}", param1);
}
}
"""
)
);
}

@Test
void levelVariableLeadsToPartialConversion() {
rewriteRun(
// language=java
java(
"""
import java.util.logging.Logger;
import java.util.logging.Level;
class Test {
void method(Logger logger, Level level, String param1) {
// No way to determine the replacement logging method
logger.log(level, "INFO Log entry, param1: {0}", param1);
}
}
""",
"""
import org.slf4j.Logger;
import java.util.logging.Level;
class Test {
void method(Logger logger, Level level, String param1) {
// No way to determine the replacement logging method
logger.log(level, "INFO Log entry, param1: {0}", param1);
}
}
"""
)
);
}
}

0 comments on commit ddf6c3d

Please sign in to comment.