Skip to content

Commit

Permalink
[CP-SAT] improve mod doc; improve precedences in scheduling; speed up…
Browse files Browse the repository at this point in the history
… circuit data structures
  • Loading branch information
lperron committed Jul 23, 2024
1 parent d0ed31d commit 458e2a1
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 108 deletions.
1 change: 1 addition & 0 deletions ortools/sat/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,7 @@ cc_library(
":sat_base",
":sat_solver",
":synchronization",
":util",
"//ortools/base",
"//ortools/base:stl_util",
"//ortools/base:strong_vector",
Expand Down
22 changes: 14 additions & 8 deletions ortools/sat/circuit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,15 @@ CircuitPropagator::CircuitPropagator(const int num_nodes,
values.reserve(num_arcs);

graph_.reserve(num_arcs);
self_arcs_.resize(num_nodes_,
model->GetOrCreate<IntegerEncoder>()->GetFalseLiteral());
self_arcs_.resize(num_nodes_, kFalseLiteralIndex);
for (int arc = 0; arc < num_arcs; ++arc) {
const int head = heads[arc];
const int tail = tails[arc];
const Literal literal = literals[arc];
if (assignment_.LiteralIsFalse(literal)) continue;

if (tail == head) {
self_arcs_[tail] = literal;
self_arcs_[tail] = literal.Index();
} else {
graph_[{tail, head}] = literal;
}
Expand Down Expand Up @@ -97,7 +96,8 @@ CircuitPropagator::CircuitPropagator(const int num_nodes,
watch_index_to_arcs_.ResetFromFlatMapping(keys, values);

for (int node = 0; node < num_nodes_; ++node) {
if (assignment_.LiteralIsFalse(self_arcs_[node])) {
if (self_arcs_[node] == kFalseLiteralIndex ||
assignment_.LiteralIsFalse(Literal(self_arcs_[node]))) {
// For the multiple_subcircuit_through_zero case, must_be_in_cycle_ will
// be const and only contains zero.
if (node == 0 || !options_.multiple_subcircuit_through_zero) {
Expand Down Expand Up @@ -280,7 +280,7 @@ bool CircuitPropagator::Propagate() {
const int node = must_be_in_cycle_[i];
if (!in_current_path_[node]) {
miss_some_nodes = true;
extra_reason = self_arcs_[node].Index();
extra_reason = self_arcs_[node];
break;
}
}
Expand Down Expand Up @@ -320,7 +320,10 @@ bool CircuitPropagator::Propagate() {
BooleanVariable variable_with_same_reason = kNoBooleanVariable;
for (int node = 0; node < num_nodes_; ++node) {
if (in_current_path_[node]) continue;
if (assignment_.LiteralIsTrue(self_arcs_[node])) continue;
if (self_arcs_[node] >= 0 &&
assignment_.LiteralIsTrue(Literal(self_arcs_[node]))) {
continue;
}

// This shouldn't happen because ExactlyOnePerRowAndPerColumn() should
// have executed first and propagated self_arcs_[node] to false.
Expand All @@ -329,9 +332,12 @@ bool CircuitPropagator::Propagate() {
// We should have detected that above (miss_some_nodes == true). But we
// still need this for corner cases where the same literal is used for
// many arcs, and we just propagated it here.
if (assignment_.LiteralIsFalse(self_arcs_[node])) {
if (self_arcs_[node] == kFalseLiteralIndex ||
assignment_.LiteralIsFalse(Literal(self_arcs_[node]))) {
FillReasonForPath(start_node, trail_->MutableConflict());
trail_->MutableConflict()->push_back(self_arcs_[node]);
if (self_arcs_[node] != kFalseLiteralIndex) {
trail_->MutableConflict()->push_back(Literal(self_arcs_[node]));
}
return false;
}

Expand Down
4 changes: 2 additions & 2 deletions ortools/sat/circuit.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace sat {
//
// Nodes that are not in the unique allowed sub-circuit must point to themseves.
// A nodes that has no self-arc must thus be inside the sub-circuit. If there is
// no self-arc at all, then this constaint forces the circuit to go through all
// no self-arc at all, then this constraint forces the circuit to go through all
// the nodes. Multi-arcs are NOT supported.
//
// Important: for correctness, this constraint requires that "exactly one"
Expand Down Expand Up @@ -87,7 +87,7 @@ class CircuitPropagator : PropagatorInterface, ReversibleInterface {
//
// TODO(user): for large dense graph, using a matrix is faster and uses less
// memory. If the need arise we can have the two implementations.
std::vector<Literal> self_arcs_;
std::vector<LiteralIndex> self_arcs_;
absl::flat_hash_map<std::pair<int, int>, Literal> graph_;

// Data used to interpret the watch indices passed to IncrementalPropagate().
Expand Down
17 changes: 17 additions & 0 deletions ortools/sat/cp_model_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,30 @@ void LoadVariables(const CpModelProto& model_proto,
// Compute the integer variable references used by the model.
absl::flat_hash_set<int> used_variables;

const bool some_linerization =
m->GetOrCreate<SatParameters>()->linearization_level() > 0;

IndexReferences refs;
for (int c = 0; c < model_proto.constraints_size(); ++c) {
const ConstraintProto& ct = model_proto.constraints(c);
refs = GetReferencesUsedByConstraint(ct);
for (const int ref : refs.variables) {
used_variables.insert(PositiveRef(ref));
}

// We always add a linear relaxation for circuit/route except for
// linearization level zero.
if (some_linerization) {
if (ct.constraint_case() == ConstraintProto::kCircuit) {
for (const int ref : ct.circuit().literals()) {
used_variables.insert(PositiveRef(ref));
}
} else if (ct.constraint_case() == ConstraintProto::kRoutes) {
for (const int ref : ct.routes().literals()) {
used_variables.insert(PositiveRef(ref));
}
}
}
}

// Add the objectives variables that needs to be referenceable as integer
Expand Down
6 changes: 6 additions & 0 deletions ortools/sat/cp_model_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,10 @@ class FullProblemSolver : public SubSolver {
}

bool IsDone() override {
// On large problem, deletion can take a while, so is is better to do it
// while waiting for the slower worker to finish.
if (shared_->SearchIsDone()) return true;

return stop_at_first_solution_ &&
shared_->response->first_solution_solvers_should_stop()->load();
}
Expand Down Expand Up @@ -983,6 +987,8 @@ class FeasibilityPumpSolver : public SubSolver {
shared_->stat_tables.AddTimingStat(*this);
}

bool IsDone() override { return shared_->SearchIsDone(); }

bool TaskIsAvailable() override {
if (shared_->SearchIsDone()) return false;
absl::MutexLock mutex_lock(&mutex_);
Expand Down
139 changes: 77 additions & 62 deletions ortools/sat/disjunctive.cc
Original file line number Diff line number Diff line change
Expand Up @@ -732,96 +732,111 @@ bool DisjunctiveSimplePrecedences::Propagate() {
return true;
}

bool DisjunctiveSimplePrecedences::Push(TaskTime before, int t) {
const int t_before = before.task_index;
DCHECK_NE(t_before, t);
helper_->ClearReason();
helper_->AddPresenceReason(t_before);
helper_->AddReasonForBeingBefore(t_before, t);
helper_->AddEndMinReason(t_before, before.time);
if (!helper_->IncreaseStartMin(t, before.time)) {
return false;
}
++stats_.num_propagations;
return true;
}

bool DisjunctiveSimplePrecedences::PropagateOneDirection() {
// We will loop in a decreasing way here.
// And add tasks that are present to the task_set_.
absl::Span<const TaskTime> task_by_decreasing_start_max =
helper_->TaskByDecreasingStartMax();

// We just keep amongst all the task before current_task, the one with the
// We just keep amongst all the task before current_end_min, the one with the
// highesh end-min.
TaskTime best_task_before = {-1, kMinIntegerValue};
to_propagate_.clear();

int blocking_task = -1;
processed_.assign(task_by_decreasing_start_max.size(), false);
for (const auto [current_task, current_end_min] :
helper_->TaskByIncreasingEndMin()) {
if (helper_->IsAbsent(current_task)) continue;
// We will loop in an increasing way here and consume task from beginning.
absl::Span<const TaskTime> task_by_increasing_end_min =
helper_->TaskByIncreasingEndMin();

for (; !task_by_increasing_end_min.empty();) {
// Skip absent task.
if (helper_->IsAbsent(task_by_increasing_end_min.front().task_index)) {
task_by_increasing_end_min.remove_prefix(1);
continue;
}

// Consider all task with a start_max < current_end_min.
int blocking_task = -1;
IntegerValue blocking_start_max;
IntegerValue current_end_min = task_by_increasing_end_min.front().time;
for (; true; task_by_decreasing_start_max.remove_suffix(1)) {
if (task_by_decreasing_start_max.empty()) {
// Small optim: this allows to process all remaining task rather than
// looping around are retesting this for all remaining tasks.
current_end_min = kMaxIntegerValue;
break;
}

for (; !task_by_decreasing_start_max.empty();
task_by_decreasing_start_max.remove_suffix(1)) {
const auto [t, start_max] = task_by_decreasing_start_max.back();
if (current_end_min <= start_max) break;
if (!helper_->IsPresent(t)) continue;

// If t has not been processed yet, it has a mandatory part, and we will
// delay all propagation until current_task is equal to this
// "blocking task".
// If t has a mandatory part, and extend further than current_end_min
// then we can push it first. All tasks for which their push is delayed
// are necessarily after this "blocking task".
//
// This idea is introduced in "Linear-Time Filtering Algorithms for the
// Disjunctive Constraints" Hamed Fahimi, Claude-Guy Quimper.
if (!processed_[t]) {
if (blocking_task != -1) {
// We have two blocking tasks, which means they are in conflict.
helper_->ClearReason();
helper_->AddPresenceReason(blocking_task);
helper_->AddPresenceReason(t);
helper_->AddReasonForBeingBefore(blocking_task, t);
helper_->AddReasonForBeingBefore(t, blocking_task);
return helper_->ReportConflict();
}
DCHECK_LT(start_max, helper_->ShiftedStartMin(t) + helper_->SizeMin(t))
<< " task should have mandatory part: "
<< helper_->TaskDebugString(t);
DCHECK(to_propagate_.empty());
const IntegerValue end_min = helper_->EndMin(t);
if (blocking_task == -1 && end_min >= current_end_min) {
DCHECK_LT(start_max, end_min) << " task should have mandatory part: "
<< helper_->TaskDebugString(t);
blocking_task = t;
to_propagate_.push_back(t);
} else {
const IntegerValue end_min = helper_->EndMin(t);
if (end_min > best_task_before.time) {
best_task_before = {t, end_min};
}
blocking_start_max = start_max;
current_end_min = end_min;
} else if (blocking_task != -1 && blocking_start_max < end_min) {
// Conflict! the task is after the blocking_task but also before.
helper_->ClearReason();
helper_->AddPresenceReason(blocking_task);
helper_->AddPresenceReason(t);
helper_->AddReasonForBeingBefore(blocking_task, t);
helper_->AddReasonForBeingBefore(t, blocking_task);
return helper_->ReportConflict();
} else if (end_min > best_task_before.time) {
best_task_before = {t, end_min};
}
}

// If we have a blocking task, we delay the propagation until current_task
// is the blocking task.
if (blocking_task != current_task) {
to_propagate_.push_back(current_task);
if (blocking_task != -1) continue;
// If we have a blocking task. We need to propagate it first.
if (blocking_task != -1) {
DCHECK(!helper_->IsAbsent(blocking_task));
if (best_task_before.time > helper_->StartMin(blocking_task)) {
if (!Push(best_task_before, blocking_task)) return false;
}

// Update best_task_before (it should likely be the blocking task).
const IntegerValue end_min = helper_->EndMin(blocking_task);
if (end_min > best_task_before.time) {
best_task_before = {blocking_task, end_min};
}
}

for (const int t : to_propagate_) {
DCHECK_NE(best_task_before.task_index, t);
DCHECK(!processed_[t]);
processed_[t] = true;
// Lets propagate all task after best_task_before.
for (; !task_by_increasing_end_min.empty();
task_by_increasing_end_min.remove_prefix(1)) {
const auto [t, end_min] = task_by_increasing_end_min.front();
if (end_min > current_end_min) break;
if (t == blocking_task) continue; // Already done.

// Lets propagate current_task.
if (best_task_before.time > helper_->StartMin(t)) {
// Corner case if a previous push from to_propagate_ caused a subsequent
// task to be absent.
// Corner case if a previous push caused a subsequent task to be absent.
if (helper_->IsAbsent(t)) continue;

const int t_before = best_task_before.task_index;
helper_->ClearReason();
helper_->AddPresenceReason(t_before);
helper_->AddReasonForBeingBefore(t_before, t);
helper_->AddEndMinReason(t_before, best_task_before.time);
if (!helper_->IncreaseStartMin(t, best_task_before.time)) {
return false;
}
++stats_.num_propagations;
}

if (t == blocking_task) {
blocking_task = -1;
const IntegerValue end_min = helper_->EndMin(t);
if (end_min > best_task_before.time) {
best_task_before = {t, end_min};
}
if (!Push(best_task_before, t)) return false;
}
}
to_propagate_.clear();
}
return true;
}
Expand Down
8 changes: 2 additions & 6 deletions ortools/sat/disjunctive.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,17 +213,13 @@ class DisjunctiveSimplePrecedences : public PropagatorInterface {
public:
explicit DisjunctiveSimplePrecedences(SchedulingConstraintHelper* helper,
Model* model = nullptr)
: helper_(helper), stats_("DisjunctiveSimplePrecedences", model) {
to_propagate_.ClearAndReserve(helper->NumTasks());
}
: helper_(helper), stats_("DisjunctiveSimplePrecedences", model) {}
bool Propagate() final;
int RegisterWith(GenericLiteralWatcher* watcher);

private:
bool PropagateOneDirection();

std::vector<bool> processed_;
FixedCapacityVector<int> to_propagate_;
bool Push(TaskTime before, int t);

SchedulingConstraintHelper* helper_;
PropagationStatistics stats_;
Expand Down
10 changes: 0 additions & 10 deletions ortools/sat/linear_relaxation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,6 @@ void AppendCircuitRelaxation(const ConstraintProto& ct, Model* model,
const Literal arc = mapping->Literal(ct.circuit().literals(i));
const int tail = ct.circuit().tails(i);
const int head = ct.circuit().heads(i);

// Make sure this literal has a view.
if (!mapping->IsInteger(PositiveRef(ct.circuit().literals(i)))) {
CreateNewIntegerVariableFromLiteral(arc, model);
}
outgoing_arc_constraints[tail].push_back(arc);
incoming_arc_constraints[head].push_back(arc);
}
Expand Down Expand Up @@ -545,11 +540,6 @@ void AppendRoutesRelaxation(const ConstraintProto& ct, Model* model,
const Literal arc = mapping->Literal(ct.routes().literals(i));
const int tail = ct.routes().tails(i);
const int head = ct.routes().heads(i);

// Make sure this literal has a view.
if (!mapping->IsInteger(PositiveRef(ct.routes().literals(i)))) {
CreateNewIntegerVariableFromLiteral(arc, model);
}
outgoing_arc_constraints[tail].push_back(arc);
incoming_arc_constraints[head].push_back(arc);
}
Expand Down
Loading

0 comments on commit 458e2a1

Please sign in to comment.