diff --git a/build.gradle b/build.gradle index 47c39b36..f2375f12 100644 --- a/build.gradle +++ b/build.gradle @@ -62,6 +62,13 @@ dependencies { implementation files("${checkerJar}") implementation group: 'com.google.errorprone', name: 'javac', version: "$errorproneJavacVersion" + if (isJava8) { + implementation 'io.jenetics:jenetics:5.2.0' + } + else { + implementation 'io.jenetics:jenetics:6.3.0' + } + implementation 'org.plumelib:options:1.0.5' implementation 'org.plumelib:plume-util:1.8.1' diff --git a/src/checkers/inference/solver/backend/geneticmaxsat/GeneticMaxSatSolver.java b/src/checkers/inference/solver/backend/geneticmaxsat/GeneticMaxSatSolver.java new file mode 100644 index 00000000..5c2d4bd3 --- /dev/null +++ b/src/checkers/inference/solver/backend/geneticmaxsat/GeneticMaxSatSolver.java @@ -0,0 +1,112 @@ +package checkers.inference.solver.backend.geneticmaxsat; + +import checkers.inference.model.Constraint; +import checkers.inference.model.Slot; +import checkers.inference.solver.backend.maxsat.MaxSatFormatTranslator; +import checkers.inference.solver.backend.maxsat.MaxSatSolver; +import checkers.inference.solver.frontend.Lattice; +import checkers.inference.solver.util.FileUtils; +import checkers.inference.solver.util.SolverEnvironment; + +import javax.lang.model.element.AnnotationMirror; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +/** + * GeneticMaxSatSolver adds support to use Genetic Algorithm to optimize the {@link checkers.inference.model.PreferenceConstraint} weights + * + */ +public abstract class GeneticMaxSatSolver extends MaxSatSolver { + + public int allSoftWeightsCount = 0; + public String wcnfFileContent; + + public GeneticMaxSatSolver(SolverEnvironment solverEnvironment, Collection slots, + Collection constraints, MaxSatFormatTranslator formatTranslator, + Lattice lattice) { + super(solverEnvironment, slots, constraints, formatTranslator, lattice); + } + + @Override + public Map solve() { + Map superSolutions = super.solve(); // to create initial set of constraints + List allLines = null; + + try { + allLines = Files.readAllLines(Paths.get("cnfData/wcnfdata.wcnf")); // read from the wcnf file created by super + } catch (IOException ioException) { + ioException.printStackTrace(); + } + assert allLines != null; + wcnfFileContent = String.join("\n", allLines.toArray(new String[0])); + + softWeightCounter(); + fit(); + + return superSolutions; + } + + public String changeSoftWeights(int[] newSoftWeights, String wcnfFileContent, boolean writeToFile){ + int oldTop = 0; + int wtIndex = 0; + + String[] wcnfContentSplit = wcnfFileContent.split("\n"); + + StringBuilder WCNFModInput = new StringBuilder(); + + for (String line : wcnfContentSplit) { + + String[] trimAndSplit = line.trim().split(" "); + + if (trimAndSplit[0].equals("p")) { + oldTop = Integer.parseInt(trimAndSplit[4]); + trimAndSplit[4] = String.valueOf(Arrays.stream(newSoftWeights).sum()); // replacing the top value with current sum of soft weights + } else if (oldTop != 0 && Integer.parseInt(trimAndSplit[0]) < oldTop) { + trimAndSplit[0] = String.valueOf(newSoftWeights[wtIndex]); + wtIndex++; + } + + WCNFModInput.append(String.join(" ", trimAndSplit)); + WCNFModInput.append("\n"); + } + + WCNFModInput.setLength(WCNFModInput.length() - 1); // to prevent unwanted character at the end of file + + if (writeToFile){ + File WCNFData = new File(new File("").getAbsolutePath() + "/cnfData"); + FileUtils.writeFile(new File(WCNFData.getAbsolutePath() + "/" + "wcnfdata_modified.wcnf"), WCNFModInput.toString()); + } + + return WCNFModInput.toString(); + } + + public void softWeightCounter(){ + allSoftWeightsCount = 0; + int top = 0; + String[] wcnfContentSplit = wcnfFileContent.split("\n"); + + for (String line : wcnfContentSplit) { + + String[] trimAndSplit = line.trim().split(" "); + + if (trimAndSplit[0].equals("p")) { + top = Integer.parseInt(trimAndSplit[4]); + } else if (top != 0 && Integer.parseInt(trimAndSplit[0]) < top) { + allSoftWeightsCount++; + } + } + } + + /** + * Override this method to declare an {@link io.jenetics.engine.Engine} builder and create a fitness function. + * For reference, please look at Universe type system. + */ + public abstract void fit(); + +} diff --git a/src/checkers/inference/solver/backend/maxsat/MaxSatSolver.java b/src/checkers/inference/solver/backend/maxsat/MaxSatSolver.java index 2986f7b2..0b549881 100644 --- a/src/checkers/inference/solver/backend/maxsat/MaxSatSolver.java +++ b/src/checkers/inference/solver/backend/maxsat/MaxSatSolver.java @@ -1,6 +1,7 @@ package checkers.inference.solver.backend.maxsat; import java.io.File; +import java.math.BigInteger; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; @@ -20,6 +21,7 @@ import org.sat4j.pb.IPBSolver; import org.sat4j.specs.ContradictionException; import org.sat4j.specs.IConstr; +import org.sat4j.specs.TimeoutException; import org.sat4j.tools.xplain.DeletionStrategy; import org.sat4j.tools.xplain.Xplain; @@ -58,14 +60,16 @@ protected enum MaxSatSolverArg implements SolverArg { private MaxSATUnsatisfiableConstraintExplainer unsatisfiableConstraintExplainer; protected final File CNFData = new File(new File("").getAbsolutePath() + "/cnfData"); protected StringBuilder CNFInput = new StringBuilder(); + protected StringBuilder WCNFInput = new StringBuilder(); + private int sumSoftConstraintWeights = 0; private long serializationStart; private long serializationEnd; protected long solvingStart; protected long solvingEnd; public MaxSatSolver(SolverEnvironment solverEnvironment, Collection slots, - Collection constraints, MaxSatFormatTranslator formatTranslator, Lattice lattice) { + Collection constraints, MaxSatFormatTranslator formatTranslator, Lattice lattice) { super(solverEnvironment, slots, constraints, formatTranslator, lattice); this.slotManager = InferenceMain.getInstance().getSlotManager(); @@ -85,12 +89,15 @@ public Map solve() { this.serializationStart = System.currentTimeMillis(); // Serialization step: encodeAllConstraints(); + solver.setTopWeight(BigInteger.valueOf(sumSoftConstraintWeights + 1)); encodeWellFormednessRestriction(); this.serializationEnd = System.currentTimeMillis(); if (shouldOutputCNF()) { buildCNFInput(); writeCNFInput(); + buildWCNFInput(); + writeWCNFInput(); } // printClauses(); configureSatSolver(solver); @@ -145,7 +152,9 @@ public void encodeAllConstraints() { for (VecInt res : encoding) { if (res != null && res.size() != 0) { if (constraint instanceof PreferenceConstraint) { - softClauses.add(IPair.of(res, ((PreferenceConstraint) constraint).getWeight())); + int constraintWeight = ((PreferenceConstraint) constraint).getWeight(); + sumSoftConstraintWeights += constraintWeight; + softClauses.add(IPair.of(res, constraintWeight).getWeight()); } else { hardClauses.add(res); } @@ -251,21 +260,71 @@ protected void buildCNFInput() { private void buildCNFInputHelper(VecInt clause) { int[] literals = clause.toArray(); - for (int i = 0; i < literals.length; i++) { - CNFInput.append(literals[i]); + for (int literal : literals) { + CNFInput.append(literal); CNFInput.append(" "); } CNFInput.append("0\n"); } protected void writeCNFInput() { - writeCNFInput("cnfdata.txt"); + writeCNFInput("cnfdata.cnf"); } protected void writeCNFInput(String file) { + CNFInput.setLength(CNFInput.length() - 1); // to prevent unwanted character at the end of file FileUtils.writeFile(new File(CNFData.getAbsolutePath() + "/" + file), CNFInput.toString()); } + /** + * Write WCNF clauses into a string. + */ + protected void buildWCNFInput() { + + final int totalClauses = softClauses.size() + hardClauses.size() + wellFormednessClauses.size(); + final int totalVars = slotManager.getNumberOfSlots() * lattice.numTypes; + final int topWeight = sumSoftConstraintWeights + 1; + + WCNFInput.append("c This is the WCNF input\n"); + WCNFInput.append("p wcnf "); + WCNFInput.append(totalVars); + WCNFInput.append(" "); + WCNFInput.append(totalClauses); + WCNFInput.append(" "); + WCNFInput.append(topWeight); + WCNFInput.append("\n"); + + for (VecInt hardClause : hardClauses) { + buildWCNFInputHelper(hardClause, topWeight); + } + for (VecInt wellFormedNessClause: wellFormednessClauses) { + buildWCNFInputHelper(wellFormedNessClause, topWeight); + } + for (Pair softclause : softClauses) { + buildWCNFInputHelper(softclause.a, softclause.b); + } + } + + private void buildWCNFInputHelper(VecInt clause, int weight) { + int[] literals = clause.toArray(); + WCNFInput.append(weight); + WCNFInput.append(" "); + for (int literal : literals) { + WCNFInput.append(literal); + WCNFInput.append(" "); + } + WCNFInput.append("0\n"); + } + + protected void writeWCNFInput() { + writeWCNFInput("wcnfdata.wcnf"); + } + + protected void writeWCNFInput(String file) { + WCNFInput.setLength(WCNFInput.length() - 1); // to prevent unwanted character at the end of file + FileUtils.writeFile(new File(CNFData.getAbsolutePath() + "/" + file), WCNFInput.toString()); + } + /** * print all soft and hard clauses for testing. */ @@ -370,7 +429,7 @@ public Collection minimumUnsatisfiableConstraints() { } private void configureExplanationSolver(final List hardClauses, final List wellformedness, - final SlotManager slotManager, final Lattice lattice, final Xplain explainer) { + final SlotManager slotManager, final Lattice lattice, final Xplain explainer) { int numberOfNewVars = slotManager.getNumberOfSlots() * lattice.numTypes; System.out.println("Number of variables: " + numberOfNewVars); int numberOfClauses = hardClauses.size() + wellformedness.size();