Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added GeneticMaxSatSolver #382

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
7 changes: 7 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ dependencies {
implementation files("${checkerJar}")
implementation group: 'com.google.errorprone', name: 'javac', version: "$errorproneJavacVersion"

if (isJava8) {
implementation 'io.jenetics:jenetics:5.2.0'
}
else {
implementation 'io.jenetics:jenetics:6.3.0'
}

implementation 'org.plumelib:options:1.0.5'
implementation 'org.plumelib:plume-util:1.8.1'

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package checkers.inference.solver.backend.geneticmaxsat;

import checkers.inference.model.Constraint;
import checkers.inference.model.Slot;
import checkers.inference.solver.backend.maxsat.MaxSatFormatTranslator;
import checkers.inference.solver.backend.maxsat.MaxSatSolver;
import checkers.inference.solver.frontend.Lattice;
import checkers.inference.solver.util.FileUtils;
import checkers.inference.solver.util.SolverEnvironment;

import javax.lang.model.element.AnnotationMirror;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
* GeneticMaxSatSolver adds support to use Genetic Algorithm to optimize the {@link checkers.inference.model.PreferenceConstraint} weights
*
*/
public abstract class GeneticMaxSatSolver extends MaxSatSolver {

public int allSoftWeightsCount = 0;
public String wcnfFileContent;

public GeneticMaxSatSolver(SolverEnvironment solverEnvironment, Collection<Slot> slots,
Collection<Constraint> constraints, MaxSatFormatTranslator formatTranslator,
Lattice lattice) {
super(solverEnvironment, slots, constraints, formatTranslator, lattice);
}

@Override
public Map<Integer, AnnotationMirror> solve() {
Map<Integer, AnnotationMirror> superSolutions = super.solve(); // to create initial set of constraints
List<String> allLines = null;

try {
allLines = Files.readAllLines(Paths.get("cnfData/wcnfdata.wcnf")); // read from the wcnf file created by super
} catch (IOException ioException) {
ioException.printStackTrace();
}
assert allLines != null;
wcnfFileContent = String.join("\n", allLines.toArray(new String[0]));

softWeightCounter();
fit();

return superSolutions;
}

public String changeSoftWeights(int[] newSoftWeights, String wcnfFileContent, boolean writeToFile){
int oldTop = 0;
int wtIndex = 0;

String[] wcnfContentSplit = wcnfFileContent.split("\n");

StringBuilder WCNFModInput = new StringBuilder();

for (String line : wcnfContentSplit) {

String[] trimAndSplit = line.trim().split(" ");

if (trimAndSplit[0].equals("p")) {
oldTop = Integer.parseInt(trimAndSplit[4]);
trimAndSplit[4] = String.valueOf(Arrays.stream(newSoftWeights).sum()); // replacing the top value with current sum of soft weights
} else if (oldTop != 0 && Integer.parseInt(trimAndSplit[0]) < oldTop) {
trimAndSplit[0] = String.valueOf(newSoftWeights[wtIndex]);
wtIndex++;
}

WCNFModInput.append(String.join(" ", trimAndSplit));
WCNFModInput.append("\n");
}

WCNFModInput.setLength(WCNFModInput.length() - 1); // to prevent unwanted character at the end of file

if (writeToFile){
File WCNFData = new File(new File("").getAbsolutePath() + "/cnfData");
FileUtils.writeFile(new File(WCNFData.getAbsolutePath() + "/" + "wcnfdata_modified.wcnf"), WCNFModInput.toString());
}

return WCNFModInput.toString();
}

public void softWeightCounter(){
allSoftWeightsCount = 0;
int top = 0;
String[] wcnfContentSplit = wcnfFileContent.split("\n");

for (String line : wcnfContentSplit) {

String[] trimAndSplit = line.trim().split(" ");

if (trimAndSplit[0].equals("p")) {
top = Integer.parseInt(trimAndSplit[4]);
} else if (top != 0 && Integer.parseInt(trimAndSplit[0]) < top) {
allSoftWeightsCount++;
}
}
}

/**
* Override this method to declare an {@link io.jenetics.engine.Engine} builder and create a fitness function.
* For reference, please look at Universe type system.
*/
public abstract void fit();

}
71 changes: 65 additions & 6 deletions src/checkers/inference/solver/backend/maxsat/MaxSatSolver.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package checkers.inference.solver.backend.maxsat;

import java.io.File;
import java.math.BigInteger;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -20,6 +21,7 @@
import org.sat4j.pb.IPBSolver;
import org.sat4j.specs.ContradictionException;
import org.sat4j.specs.IConstr;
import org.sat4j.specs.TimeoutException;
import org.sat4j.tools.xplain.DeletionStrategy;
import org.sat4j.tools.xplain.Xplain;

Expand Down Expand Up @@ -58,14 +60,16 @@ protected enum MaxSatSolverArg implements SolverArg {
private MaxSATUnsatisfiableConstraintExplainer unsatisfiableConstraintExplainer;
protected final File CNFData = new File(new File("").getAbsolutePath() + "/cnfData");
protected StringBuilder CNFInput = new StringBuilder();
protected StringBuilder WCNFInput = new StringBuilder();

private int sumSoftConstraintWeights = 0;
private long serializationStart;
private long serializationEnd;
protected long solvingStart;
protected long solvingEnd;

public MaxSatSolver(SolverEnvironment solverEnvironment, Collection<Slot> slots,
Collection<Constraint> constraints, MaxSatFormatTranslator formatTranslator, Lattice lattice) {
Collection<Constraint> constraints, MaxSatFormatTranslator formatTranslator, Lattice lattice) {
super(solverEnvironment, slots, constraints, formatTranslator,
lattice);
this.slotManager = InferenceMain.getInstance().getSlotManager();
Expand All @@ -85,12 +89,15 @@ public Map<Integer, AnnotationMirror> solve() {
this.serializationStart = System.currentTimeMillis();
// Serialization step:
encodeAllConstraints();
solver.setTopWeight(BigInteger.valueOf(sumSoftConstraintWeights + 1));
encodeWellFormednessRestriction();
this.serializationEnd = System.currentTimeMillis();

if (shouldOutputCNF()) {
buildCNFInput();
writeCNFInput();
buildWCNFInput();
writeWCNFInput();
}
// printClauses();
configureSatSolver(solver);
Expand Down Expand Up @@ -145,7 +152,9 @@ public void encodeAllConstraints() {
for (VecInt res : encoding) {
if (res != null && res.size() != 0) {
if (constraint instanceof PreferenceConstraint) {
softClauses.add(IPair.of(res, ((PreferenceConstraint) constraint).getWeight()));
int constraintWeight = ((PreferenceConstraint) constraint).getWeight();
sumSoftConstraintWeights += constraintWeight;
softClauses.add(IPair.of(res, constraintWeight).getWeight());
} else {
hardClauses.add(res);
}
Expand Down Expand Up @@ -251,21 +260,71 @@ protected void buildCNFInput() {

private void buildCNFInputHelper(VecInt clause) {
int[] literals = clause.toArray();
for (int i = 0; i < literals.length; i++) {
CNFInput.append(literals[i]);
for (int literal : literals) {
CNFInput.append(literal);
CNFInput.append(" ");
}
CNFInput.append("0\n");
}

protected void writeCNFInput() {
writeCNFInput("cnfdata.txt");
writeCNFInput("cnfdata.cnf");
}

protected void writeCNFInput(String file) {
CNFInput.setLength(CNFInput.length() - 1); // to prevent unwanted character at the end of file
FileUtils.writeFile(new File(CNFData.getAbsolutePath() + "/" + file), CNFInput.toString());
}

/**
* Write WCNF clauses into a string.
*/
protected void buildWCNFInput() {

final int totalClauses = softClauses.size() + hardClauses.size() + wellFormednessClauses.size();
final int totalVars = slotManager.getNumberOfSlots() * lattice.numTypes;
final int topWeight = sumSoftConstraintWeights + 1;

WCNFInput.append("c This is the WCNF input\n");
WCNFInput.append("p wcnf ");
WCNFInput.append(totalVars);
WCNFInput.append(" ");
WCNFInput.append(totalClauses);
WCNFInput.append(" ");
WCNFInput.append(topWeight);
WCNFInput.append("\n");

for (VecInt hardClause : hardClauses) {
buildWCNFInputHelper(hardClause, topWeight);
}
for (VecInt wellFormedNessClause: wellFormednessClauses) {
buildWCNFInputHelper(wellFormedNessClause, topWeight);
}
for (Pair<VecInt, Integer> softclause : softClauses) {
buildWCNFInputHelper(softclause.a, softclause.b);
}
}

private void buildWCNFInputHelper(VecInt clause, int weight) {
int[] literals = clause.toArray();
WCNFInput.append(weight);
WCNFInput.append(" ");
for (int literal : literals) {
WCNFInput.append(literal);
WCNFInput.append(" ");
}
WCNFInput.append("0\n");
}

protected void writeWCNFInput() {
writeWCNFInput("wcnfdata.wcnf");
}

protected void writeWCNFInput(String file) {
WCNFInput.setLength(WCNFInput.length() - 1); // to prevent unwanted character at the end of file
FileUtils.writeFile(new File(CNFData.getAbsolutePath() + "/" + file), WCNFInput.toString());
}

/**
* print all soft and hard clauses for testing.
*/
Expand Down Expand Up @@ -370,7 +429,7 @@ public Collection<Constraint> minimumUnsatisfiableConstraints() {
}

private void configureExplanationSolver(final List<VecInt> hardClauses, final List<VecInt> wellformedness,
final SlotManager slotManager, final Lattice lattice, final Xplain<IPBSolver> explainer) {
final SlotManager slotManager, final Lattice lattice, final Xplain<IPBSolver> explainer) {
int numberOfNewVars = slotManager.getNumberOfSlots() * lattice.numTypes;
System.out.println("Number of variables: " + numberOfNewVars);
int numberOfClauses = hardClauses.size() + wellformedness.size();
Expand Down
Loading