Skip to content

Commit

Permalink
Add Expression::remove_variables() method
Browse files Browse the repository at this point in the history
  • Loading branch information
arcondello committed Jul 5, 2023
1 parent 3873226 commit ad549c5
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
37 changes: 37 additions & 0 deletions dimod/include/dimod/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -267,6 +268,10 @@ class Expression : public abc::QuadraticModelBase<Bias, Index> {

virtual void remove_variable(index_type v);

/// Remove several variables from the expression.
template<class Iter>
void remove_variables(Iter first, Iter last);

/// Set the linear bias of variable `v`.
void set_linear(index_type v, bias_type bias);

Expand Down Expand Up @@ -624,6 +629,38 @@ void Expression<bias_type, index_type>::remove_variable(index_type v) {
}
}

template <class bias_type, class index_type>
template <class Iter>
void Expression<bias_type, index_type>::remove_variables(Iter first, Iter last) {
std::unordered_set<index_type> to_remove;
for (auto it = first; it != last; ++it) {
if (indices_.find(*it) != indices_.end()) {
to_remove.emplace(*it);
}
}

if (!to_remove.size()) {
return; // nothing to remove
}

// now remove any variables found in to_remove
size_type i = 0;
while (i < this->num_variables()) {
if (to_remove.count(variables_[i])) {
base_type::remove_variable(i);
variables_.erase(variables_.begin() + i);
} else {
++i;
}
}

// finally fix the indices by rebuilding from scratch
indices_.clear();
for (size_type i = 0, end = variables_.size(); i < end; ++i) {
indices_[variables_[i]] = i;
}
}

template <class bias_type, class index_type>
void Expression<bias_type, index_type>::set_linear(index_type v, bias_type bias) {
base_type::set_linear(enforce_variable(v), bias);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Add ``dimod::Expression::remove_variables()`` C++ method.
This method is useful for removing variables in bulk from the objective
or constraints of a constrained quadratic model.
54 changes: 54 additions & 0 deletions testscpp/tests/test_constrained_quadratic_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,4 +901,58 @@ TEST_CASE("Test Expression::add_quadratic()") {
}
}

TEST_CASE("Test Expression::remove_variables()") {
GIVEN("A CQM with five variables, an objective, and one constraint") {
auto cqm = ConstrainedQuadraticModel<double>();

cqm.add_variables(Vartype::BINARY, 10);
cqm.objective.set_linear(2, 1);
cqm.objective.set_linear(3, 2);
cqm.objective.set_linear(5, 3);
cqm.objective.set_linear(8, 4);
cqm.objective.set_quadratic(5, 8, 5);

auto c = cqm.add_linear_constraint({2, 3, 5, 8}, {1, 2, 3, 4}, Sense::EQ, 1);
cqm.constraint_ref(c).set_quadratic(5, 8, 5);

std::vector<int> fixed = {3, 8};

WHEN("We remove several variables from the objective") {
cqm.objective.remove_variables(fixed.cbegin(), fixed.cend()); // works with cbegin/cend

THEN("they are removed as expected") {
CHECK(cqm.num_variables() == 10); // not changed

CHECK(cqm.objective.num_variables() == 2);
CHECK(cqm.objective.num_interactions() == 0);

CHECK(cqm.objective.linear(2) == 1);
CHECK(cqm.objective.linear(3) == 0); // removed
CHECK(cqm.objective.linear(5) == 3);
CHECK(cqm.objective.linear(8) == 0); // removed

CHECK(cqm.objective.variables() == std::vector<int>{2, 5});
}
}

WHEN("We remove several variables from the constraint") {
cqm.constraint_ref(c).remove_variables(fixed.begin(), fixed.end());

THEN("they are removed as expected") {
CHECK(cqm.num_variables() == 10); // not changed

CHECK(cqm.constraint_ref(c).num_variables() == 2);
CHECK(cqm.constraint_ref(c).num_interactions() == 0);

CHECK(cqm.constraint_ref(c).linear(2) == 1);
CHECK(cqm.constraint_ref(c).linear(3) == 0); // removed
CHECK(cqm.constraint_ref(c).linear(5) == 3);
CHECK(cqm.constraint_ref(c).linear(8) == 0); // removed

CHECK(cqm.constraint_ref(c).variables() == std::vector<int>{2, 5});
}
}
}
}

} // namespace dimod

0 comments on commit ad549c5

Please sign in to comment.