diff --git a/src/approxmc.cpp b/src/approxmc.cpp index 2405e34..0cf6cb7 100644 --- a/src/approxmc.cpp +++ b/src/approxmc.cpp @@ -45,8 +45,9 @@ using namespace AppMCInt; namespace ApproxMC { struct AppMCPrivateData { - Counter counter; + AppMCPrivateData(): counter(conf) {} Config conf; + Counter counter; }; } @@ -231,7 +232,7 @@ DLL_PUBLIC ApproxMC::SolCount AppMC::count() } setup_sampling_vars(data); - SolCount sol_count = data->counter.solve(data->conf); + SolCount sol_count = data->counter.solve(); return sol_count; } @@ -269,12 +270,12 @@ DLL_PUBLIC bool AppMC::add_red_clause(const vector& lits) DLL_PUBLIC bool AppMC::add_clause(const vector& lits) { - return data->counter.solver->add_clause(lits); + return data->counter.solver_add_clause(lits); } DLL_PUBLIC bool AppMC::add_xor_clause(const vector& vars, bool rhs) { - return data->counter.solver->add_xor_clause(vars, rhs); + return data->counter.solver_add_xor_clause(vars, rhs); } DLL_PUBLIC bool AppMC::add_bnn_clause( @@ -282,6 +283,10 @@ DLL_PUBLIC bool AppMC::add_bnn_clause( signed cutoff, Lit out) { + if (data->conf.dump_intermediary_cnf) { + cout << "ERROR: BNNs not supported when dumping" << endl; + exit(-1); + } return data->counter.solver->add_bnn_clause(lits, cutoff, out); } diff --git a/src/counter.cpp b/src/counter.cpp index 319cdaa..45f4de1 100644 --- a/src/counter.cpp +++ b/src/counter.cpp @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -60,8 +61,20 @@ using std::cout; using std::endl; using std::map; +using std::make_pair; using namespace AppMCInt; +bool Counter::solver_add_clause(const vector& cl) { + if (conf.dump_intermediary_cnf) cls_in_solver.push_back(cl); + return solver->add_clause(cl); +} + + +bool Counter::solver_add_xor_clause(const vector& vars, const bool rhs) { + if (conf.dump_intermediary_cnf) xors_in_solver.push_back(make_pair(vars, rhs)); + return solver->add_xor_clause(vars, rhs); +} + Hash Counter::add_hash(uint32_t hash_index, SparseData& sparse_data) { const string randomBits = @@ -80,7 +93,7 @@ Hash Counter::add_hash(uint32_t hash_index, SparseData& sparse_data) Hash h(act_var, vars, rhs); vars.push_back(act_var); - solver->add_xor_clause(vars, rhs); + solver_add_xor_clause(vars, rhs); if (conf.verb_cls) { print_xor(vars, rhs); } @@ -95,7 +108,7 @@ void Counter::ban_one(const uint32_t act_var, const vector& model) for (const uint32_t var: conf.sampling_set) { lits.push_back(Lit(var, model[var] == l_True)); } - solver->add_clause(lits); + solver_add_clause(lits); } ///adding banning clauses for repeating solutions @@ -149,51 +162,28 @@ uint64_t Counter::add_glob_banning_cls( return repeat; } -void Counter::dump_cnf_from_solver(const vector& assumps, const uint32_t iter, const lbool result) -{ - vector> cnf; - solver->start_getting_small_clauses( - std::numeric_limits::max(), - std::numeric_limits::max(), - false, true); - - uint32_t maxvars = 0; - bool ret = true; - vector cl; - while(ret) { - ret = solver->get_next_small_clause(cl); - if (!ret) { - continue; - } - cnf.push_back(cl); - for(const auto& l: cl) { - maxvars = std::max(l.var(), maxvars); - } - } - - for(const auto& l: assumps) { - maxvars = std::max(l.var(), maxvars); - } - - solver->end_getting_small_clauses(); - - std::stringstream ss; +void Counter::dump_cnf_from_solver(const vector& assumps, const uint32_t iter, const lbool result) { std::string result_str; if (result == l_True) result_str = "SAT"; else if (result == l_False) result_str = "UNSAT"; else assert(false && "Should not be called with unknown!"); + + std::stringstream ss; ss << "cnfdump" << "-res-" << result_str << "-iter-" << iter << "-active-xors-" << assumps.size() << "-out-" << cnf_dump_no++ << ".cnf"; std::ofstream f; f.open(ss.str(), std::ios::out); - f << "p cnf " << maxvars+1 << " " << cnf.size()+assumps.size() << "\n"; - for(const auto& l: assumps) { - f << l << " 0\n"; - } - - for(const auto& c: cnf) { - f << c << " 0\n"; + f << "p cnf " << solver->nVars()+1 << " " << cls_in_solver.size()+xors_in_solver.size()+assumps.size() << endl; + for(const auto& l: assumps) f << l << " 0" << endl; + for(const auto& cl: cls_in_solver) f << cl << " 0" << endl; + for(const auto& x: xors_in_solver) { + f << "x"; + for(uint32_t i = 0; i < x.first.size(); i++) { + if (i == 0 && !x.second) f << "-"; + f << (x.first[i]+1) << " "; + } + f << "0" << endl; } f.close(); } @@ -284,7 +274,7 @@ SolNum Counter::bounded_sol_count( if (conf.verb_cls) { cout << "c [appmc] Adding banning clause: " << lits << endl; } - solver->add_clause(lits); + solver_add_clause(lits); } //Save global models @@ -297,7 +287,7 @@ SolNum Counter::bounded_sol_count( //Remove solution banning vector cl_that_removes; cl_that_removes.push_back(Lit(sol_ban_var, false)); - solver->add_clause(cl_that_removes); + solver_add_clause(cl_that_removes); return SolNum(solutions, repeat); } @@ -313,9 +303,7 @@ void Counter::print_final_count_stats(ApproxMC::SolCount solCount) } } -ApproxMC::SolCount Counter::solve(Config _conf) -{ - conf = _conf; +ApproxMC::SolCount Counter::solve() { orig_num_vars = solver->nVars(); startTime = cpuTimeTotal(); diff --git a/src/counter.h b/src/counter.h index 5f69577..2079527 100644 --- a/src/counter.h +++ b/src/counter.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include "approxmc.h" @@ -42,6 +43,7 @@ using std::string; using std::vector; using std::map; +using std::pair; using namespace CMSat; namespace AppMCInt { @@ -108,7 +110,8 @@ struct SparseData { class Counter { public: - ApproxMC::SolCount solve(Config _conf); + Counter(Config& _conf) : conf(_conf) {} + ApproxMC::SolCount solve(); string gen_rnd_bits(const uint32_t size, const uint32_t numhashes, SparseData& sparse_data); string binary(const uint32_t x, const uint32_t length); @@ -120,9 +123,11 @@ class Counter { ApproxMC::SolCount calc_est_count(); void print_final_count_stats(ApproxMC::SolCount sol_count); const Constants constants; + bool solver_add_clause(const vector& cl); + bool solver_add_xor_clause(const vector& vars, const bool rhs); private: - Config conf; + Config& conf; ApproxMC::SolCount count(); void add_appmc_options(); bool ScalCounter(ApproxMC::SolCount& count); @@ -197,6 +202,8 @@ class Counter { double total_inter_simp_time = 0; uint32_t threshold; //precision, it's computed uint32_t cnf_dump_no = 0; + vector> cls_in_solver; // needed for accurate dumping + vector, bool>> xors_in_solver; // needed for accurate dumping int argc; char** argv; diff --git a/src/main.cpp b/src/main.cpp index 9135d91..e77d60d 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -395,8 +395,7 @@ void print_num_solutions(uint32_t cellSolCount, uint32_t hashCount) } -void get_cnf_from_arjun() -{ +void get_cnf_from_arjun() { bool ret = true; const uint32_t orig_num_vars = arjun->get_orig_num_vars(); appmc->new_vars(orig_num_vars); @@ -407,21 +406,13 @@ void get_cnf_from_arjun() vector clause; while (ret) { ret = arjun->get_next_small_clause(clause); - if (!ret) { - break; - } + if (!ret) break; bool ok = true; for(auto l: clause) { - if (l.var() >= orig_num_vars) { - ok = false; - break; - } - } - - if (ok) { - appmc->add_clause(clause); + if (l.var() >= orig_num_vars) { ok = false; break; } } + if (ok) appmc->add_clause(clause); } arjun->end_getting_small_clauses(); @@ -548,7 +539,6 @@ int main(int argc, char** argv) cout << appmc->get_version_info(); cout << "c executed with command line: " << command_line << endl; } - set_approxmc_options(); uint32_t offset_count_by_2_pow = 0;