Skip to content

Commit

Permalink
Inline Javatemplate in RefasterTemplateProcessor (#57)
Browse files Browse the repository at this point in the history
* Inline JavaTemplate into RefasterTemplateProcessor

* Adapt Refaster input and output tests

* Remove unused method statementType

* Resolve parameters, which now triggers NPE

* Revert JavacResolution for template parameters

* Change input to ParameterReuseRecipe

* Adapt three out of four remaining tests

* Fix remaining test

* Update new test too

* Add TODO

* Quick stash

* Fix compilation issue

* Add an inline test for RefaterTemplateProcessor as well

* Add FIXME

* Add more analysis of what's going wrong

* First working version

* Update reference outputs

* Handle recipes without parameters

* Drop now unused import

* Drop inline tests again while text blocks are unavailable

* Only pass through matching resolved parameter types

* Resolve all before templates at once
  • Loading branch information
timtebeek authored Mar 4, 2024
1 parent 4f8f759 commit ab2dee2
Show file tree
Hide file tree
Showing 22 changed files with 386 additions and 170 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.Log;
import org.openrewrite.internal.lang.Nullable;

import javax.tools.JavaFileObject;
import java.lang.reflect.Field;
Expand All @@ -50,7 +51,7 @@ public JavacResolution(Context context) {
this.log = Log.instance(context);
}

public Map<JCTree, JCTree> resolveAll(Context context, JCCompilationUnit cu, List<? extends Tree> trees) {
public @Nullable Map<JCTree, JCTree> resolveAll(Context context, JCCompilationUnit cu, List<? extends Tree> trees) {
AtomicReference<Map<JCTree, JCTree>> resolved = new AtomicReference<>();

new TreeScanner() {
Expand Down Expand Up @@ -203,6 +204,7 @@ public void visitMethodDef(JCMethodDecl tree) {
copyAt = tree;
}

@Override
public void visitVarDef(JCVariableDecl tree) {
if (copyAt != null) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ public static <T extends JCTree> String process(T tree, List<JCTree.JCVariableDe
}
}

public static String indent(String code, int width) {
char[] indent = new char[width];
Arrays.fill(indent, ' ');
String replacement = "$1" + new String(indent);
return code.replaceAll("(?m)(\\R)", replacement);
}

private static class TemplateCodePrinter extends Pretty {

private static final String PRIMITIVE_ANNOTATION = "org.openrewrite.java.template.Primitive";
Expand All @@ -87,8 +94,9 @@ public void visitIdent(JCIdent jcIdent) {
if (param.isPresent()) {
print("#{" + sym.name);
if (seenParameters.add(param.get())) {
String type = param.get().type.toString();
if (param.get().getModifiers().getAnnotations().stream().anyMatch(a -> a.attribute.type.tsym.getQualifiedName().toString().equals(PRIMITIVE_ANNOTATION))) {
String type = param.get().sym.type.toString();
if (param.get().getModifiers().getAnnotations().stream()
.anyMatch(a -> a.attribute.type.tsym.getQualifiedName().toString().equals(PRIMITIVE_ANNOTATION))) {
type = getUnboxedPrimitive(type);
}
print(":any(" + type + ")");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.util.Context;
import org.openrewrite.internal.lang.Nullable;
import org.openrewrite.java.template.internal.FQNPretty;
import org.openrewrite.java.template.internal.ImportDetector;
import org.openrewrite.java.template.internal.JavacResolution;
import org.openrewrite.java.template.internal.TemplateCode;
import org.openrewrite.java.template.internal.UsedMethodDetector;

import javax.annotation.processing.RoundEnvironment;
Expand Down Expand Up @@ -213,22 +213,15 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) {
for (Map.Entry<String, JCTree.JCMethodDecl> entry : beforeTemplates.entrySet()) {
recipe.append(" final JavaTemplate ")
.append(entry.getKey())
.append(" = Semantics.")
.append(statementType(entry.getValue()))
.append("(this, \"")
.append(entry.getKey()).append("\", ")
.append(toLambda(entry.getValue()))
.append(").build();\n");
.append(" = ")
.append(toJavaTemplateBuilder(entry.getValue(), descriptor.resolvedParameters))
.append("\n .build();\n");
}
recipe.append(" final JavaTemplate ")
.append(after)
.append(" = Semantics.")
.append(statementType(descriptor.afterTemplate))
.append("(this, \"")
.append(after)
.append("\", ")
.append(toLambda(descriptor.afterTemplate))
.append(").build();\n");
.append(" = ")
.append(toJavaTemplateBuilder(descriptor.afterTemplate, descriptor.resolvedParameters))
.append("\n .build();\n");
recipe.append("\n");

List<String> lstTypes = LST_TYPE_MAP.get(getType(descriptor.beforeTemplates.get(0)));
Expand Down Expand Up @@ -331,11 +324,11 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) {
out.write("import org.openrewrite.Recipe;\n");
out.write("import org.openrewrite.TreeVisitor;\n");
out.write("import org.openrewrite.internal.lang.NonNullApi;\n");
out.write("import org.openrewrite.java.JavaParser;\n");
out.write("import org.openrewrite.java.JavaTemplate;\n");
out.write("import org.openrewrite.java.JavaVisitor;\n");
out.write("import org.openrewrite.java.search.*;\n");
out.write("import org.openrewrite.java.template.Primitive;\n");
out.write("import org.openrewrite.java.template.Semantics;\n");
out.write("import org.openrewrite.java.template.function.*;\n");
out.write("import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor;\n");
out.write("import org.openrewrite.java.tree.*;\n");
Expand Down Expand Up @@ -400,6 +393,22 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) {
}
}

private String toJavaTemplateBuilder(JCTree.JCMethodDecl methodDecl,
Map<JCTree.JCVariableDecl, JCTree.JCVariableDecl> resolvedParameters) {
JCTree tree = methodDecl.getBody().getStatements().get(0);
if (tree instanceof JCTree.JCReturn) {
tree = ((JCTree.JCReturn) tree).getExpression();
}

List<JCTree.JCVariableDecl> mappedParameters = methodDecl.getParameters().stream()
.map(resolvedParameters::get)
.map(JCTree.JCVariableDecl.class::cast)
.collect(Collectors.toList());

String javaTemplateBuilder = TemplateCode.process(tree, mappedParameters, true);
return TemplateCode.indent(javaTemplateBuilder, 16);
}

private boolean simplifyBooleans(JCTree.JCMethodDecl template) {
if (template.getReturnType().type.getTag() == TypeTag.BOOLEAN) {
return true;
Expand Down Expand Up @@ -655,55 +664,6 @@ private JCTree.JCExpression getReturnExpression(JCTree.JCMethodDecl method) {
return null;
}

private String statementType(JCTree.JCMethodDecl method) {
// for now excluding assignment expressions and prefix and postfix -- and ++
Set<Class<? extends JCTree>> expressionStatementTypes = Stream.of(
JCTree.JCMethodInvocation.class,
JCTree.JCNewClass.class).collect(Collectors.toSet());

Class<? extends JCTree> type = getType(method);
if (expressionStatementTypes.contains(type)) {
if (type == JCTree.JCMethodInvocation.class
&& method.getBody().getStatements().last() instanceof JCTree.JCExpressionStatement
&& !(method.getReturnType().type instanceof Type.JCVoidType)) {
return "expression";
}
if (method.restype.type instanceof Type.JCVoidType || !JCTree.JCExpression.class.isAssignableFrom(type)) {
return "statement";
}
}
return "expression";
}

private String toLambda(JCTree.JCMethodDecl method) {
StringBuilder builder = new StringBuilder();

StringJoiner joiner = new StringJoiner(", ", "(", ")");
for (JCTree.JCVariableDecl parameter : method.getParameters()) {
String paramType = parameter.getType().type.toString();
if (!getBoxedPrimitive(paramType).equals(paramType)) {
paramType = "@Primitive " + getBoxedPrimitive(paramType);
} else if (paramType.startsWith("java.lang.")) {
paramType = paramType.substring("java.lang.".length());
}
joiner.add(paramType + " " + parameter.getName());
}
builder.append(joiner);
builder.append(" -> ");

JCTree.JCStatement statement = method.getBody().getStatements().get(0);
if (statement instanceof JCTree.JCReturn) {
builder.append(FQNPretty.toString(((JCTree.JCReturn) statement).getExpression()));
} else if (statement instanceof JCTree.JCThrow) {
String string = FQNPretty.toString(statement);
builder.append("{ ").append(string).append(" }");
} else {
String string = FQNPretty.toString(statement);
builder.append(string);
}
return builder.toString();
}

@Nullable
private TemplateDescriptor getTemplateDescriptor(JCTree.JCClassDecl tree, Context context, JCCompilationUnit cu) {
TemplateDescriptor result = new TemplateDescriptor(tree);
Expand All @@ -727,6 +687,7 @@ class TemplateDescriptor {
final JCTree.JCClassDecl classDecl;
final List<JCTree.JCMethodDecl> beforeTemplates = new ArrayList<>();
JCTree.JCMethodDecl afterTemplate;
Map<JCTree.JCVariableDecl, JCTree.JCVariableDecl> resolvedParameters = new IdentityHashMap<>();

public TemplateDescriptor(JCTree.JCClassDecl classDecl) {
this.classDecl = classDecl;
Expand Down Expand Up @@ -829,14 +790,30 @@ public void afterTemplate(JCTree.JCMethodDecl method) {
}

private boolean resolve(Context context, JCCompilationUnit cu) {
JavacResolution res = new JavacResolution(context);
try {
JavacResolution res = new JavacResolution(context);
beforeTemplates.replaceAll(key -> {
Map<JCTree, JCTree> resolved = res.resolveAll(context, cu, singletonList(key));
return (JCTree.JCMethodDecl) resolved.get(key);
});
Map<JCTree, JCTree> resolved = res.resolveAll(context, cu, singletonList(afterTemplate));
afterTemplate = (JCTree.JCMethodDecl) resolved.get(afterTemplate);
// Resolve parameters
for (JCTree.JCMethodDecl beforeTemplate : beforeTemplates) {
if (!beforeTemplate.getParameters().isEmpty()) {
for (Map.Entry<JCTree, JCTree> e : res.resolveAll(context, cu, beforeTemplate.getParameters()).entrySet()) {
if (e.getKey() instanceof JCTree.JCVariableDecl && e.getValue() instanceof JCTree.JCVariableDecl) {
resolvedParameters.put((JCTree.JCVariableDecl) e.getValue(), (JCTree.JCVariableDecl) e.getKey());
}
}
}
}
if (!afterTemplate.getParameters().isEmpty()) {
for (Map.Entry<JCTree, JCTree> e : res.resolveAll(context, cu, afterTemplate.getParameters()).entrySet()) {
if (e.getKey() instanceof JCTree.JCVariableDecl && e.getValue() instanceof JCTree.JCVariableDecl) {
resolvedParameters.put((JCTree.JCVariableDecl) e.getValue(), (JCTree.JCVariableDecl) e.getKey());
}
}
}

// Resolve templates
Map<JCTree, JCTree> resolvedBeforeTemplates = res.resolveAll(context, cu, beforeTemplates);
beforeTemplates.replaceAll(key -> (JCTree.JCMethodDecl) resolvedBeforeTemplates.get(key));
afterTemplate = (JCTree.JCMethodDecl) res.resolveAll(context, cu, singletonList(afterTemplate)).get(afterTemplate);
} catch (Throwable t) {
processingEnv.getMessager().printMessage(Kind.WARNING, "Had trouble type attributing the template.");
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ public void visitApply(JCTree.JCMethodInvocation tree) {
out.write(" * @return the JavaTemplate builder.\n");
out.write(" */\n");
out.write(" public static JavaTemplate.Builder getTemplate() {\n");
out.write(" return " + indent(templateCode, 12) + ";\n");
out.write(" return " + TemplateCode.indent(templateCode, 12) + ";\n");
out.write(" }\n");
out.write("}\n");
out.flush();
Expand All @@ -176,13 +176,6 @@ public void visitApply(JCTree.JCMethodInvocation tree) {

super.visitApply(tree);
}

private String indent(String code, int width) {
char[] indent = new char[width];
Arrays.fill(indent, ' ');
String replacement = "$1" + new String(indent);
return code.replaceAll("(?m)(\\R)", replacement);
}
}.scan(cu);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,19 @@ private static Compilation compileResource(String resourceName) {

static Compilation compileResource(String resourceName, TypeAwareProcessor processor) {
// As per https://github.com/google/compile-testing/blob/v0.21.0/src/main/java/com/google/testing/compile/package-info.java#L53-L55
return compileResource(JavaFileObjects.forResource(resourceName), processor);
return compile(JavaFileObjects.forResource(resourceName), processor);
}

@SuppressWarnings("unused") // use when text blocks are available
static Compilation compileSource(String fqn, @Language("java") String source) {
return compile(JavaFileObjects.forSourceString(fqn, source), new RefasterTemplateProcessor());
}

static Compilation compileSource(String fqn, @Language("java") String source, TypeAwareProcessor processor) {
return compileResource(JavaFileObjects.forSourceString(fqn, source), processor);
return compile(JavaFileObjects.forSourceString(fqn, source), processor);
}

static Compilation compileResource(JavaFileObject javaFileObject, TypeAwareProcessor processor) {
static Compilation compile(JavaFileObject javaFileObject, TypeAwareProcessor processor) {
return javac()
.withProcessors(processor)
.withClasspath(Arrays.asList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.openrewrite.java.template.processor.TemplateProcessor;

import static com.google.testing.compile.CompilationSubject.assertThat;
import static org.openrewrite.java.template.RefasterTemplateProcessorTest.compileResource;
import static org.openrewrite.java.template.RefasterTemplateProcessorTest.*;

class TemplateProcessorTest {
@ParameterizedTest
Expand Down
23 changes: 17 additions & 6 deletions src/test/resources/refaster/ArraysRecipe.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 the original author or authors.
* Copyright 2024 the original author or authors.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,11 +20,12 @@
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.lang.NonNullApi;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.search.*;
import org.openrewrite.java.template.Primitive;
import org.openrewrite.java.template.Semantics;

import org.openrewrite.java.template.function.*;
import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor;
import org.openrewrite.java.tree.*;
Expand All @@ -33,12 +34,18 @@

import static org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor.EmbeddingOption.*;


/**
* OpenRewrite recipe created for Refaster template {@code Arrays}.
*/
@SuppressWarnings("all")
@NonNullApi
public class ArraysRecipe extends Recipe {

public ArraysRecipe() {
}
/**
* Instantiates a new instance.
*/
public ArraysRecipe() {}

@Override
public String getDisplayName() {
Expand All @@ -53,8 +60,12 @@ public String getDescription() {
@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
JavaVisitor<ExecutionContext> javaVisitor = new AbstractRefasterJavaVisitor() {
final JavaTemplate before = Semantics.expression(this, "before", (String[] strings) -> String.join(", ", strings)).build();
final JavaTemplate after = Semantics.expression(this, "after", (String[] strings) -> String.join(":", strings)).build();
final JavaTemplate before = JavaTemplate
.builder("String.join(\", \", #{strings:any(java.lang.String[])})")
.build();
final JavaTemplate after = JavaTemplate
.builder("String.join(\":\", #{strings:any(java.lang.String[])})")
.build();

@Override
public J visitMethodInvocation(J.MethodInvocation elem, ExecutionContext ctx) {
Expand Down
14 changes: 9 additions & 5 deletions src/test/resources/refaster/CharacterEscapeAnnotationRecipe.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.lang.NonNullApi;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.search.*;
import org.openrewrite.java.template.Primitive;
import org.openrewrite.java.template.Semantics;
import org.openrewrite.java.template.function.*;
import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor;
import org.openrewrite.java.tree.*;
Expand All @@ -35,7 +35,7 @@


/**
* OpenRewrite recipe created for Refaster template {@code MultilineAnnotation}.
* OpenRewrite recipe created for Refaster template {@code CharacterEscapeAnnotation}.
*/
@SuppressWarnings("all")
@NonNullApi
Expand All @@ -44,7 +44,7 @@ public class CharacterEscapeAnnotationRecipe extends Recipe {
/**
* Instantiates a new instance.
*/
public MultilineAnnotationRecipe() {}
public CharacterEscapeAnnotationRecipe() {}

@Override
public String getDisplayName() {
Expand All @@ -64,8 +64,12 @@ public Set<String> getTags() {
@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
JavaVisitor<ExecutionContext> javaVisitor = new AbstractRefasterJavaVisitor() {
final JavaTemplate before = Semantics.expression(this, "before", () -> "The answer to life, the universe, and everything").build();
final JavaTemplate after = Semantics.expression(this, "after", () -> "42").build();
final JavaTemplate before = JavaTemplate
.builder("\"The answer to life, the universe, and everything\"")
.build();
final JavaTemplate after = JavaTemplate
.builder("\"42\"")
.build();

@Override
public J visitExpression(Expression elem, ExecutionContext ctx) {
Expand Down
Loading

0 comments on commit ab2dee2

Please sign in to comment.