Skip to content

Commit

Permalink
few optimisations for zp
Browse files Browse the repository at this point in the history
  • Loading branch information
hschreiber committed Jun 5, 2024
1 parent a86f72e commit abb93a6
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

#include <iostream> //print() only
#include <vector>
#include <set>
#include <utility> //std::swap, std::move & std::exchange

#include <boost/intrusive/set.hpp>
Expand Down Expand Up @@ -335,16 +334,6 @@ class Base_matrix_with_column_compression : protected Master_matrix::Matrix_row_
};

using ra_opt = typename Master_matrix::Matrix_row_access_option;
using cell_rep_type =
typename std::conditional<Master_matrix::Option_list::is_z2,
index,
std::pair<index, Field_element_type>
>::type;
using tmp_column_type = typename std::conditional<
Master_matrix::Option_list::is_z2,
std::set<index>,
std::set<std::pair<index, Field_element_type>, typename Master_matrix::CellPairComparator>
>::type;
using col_dict_type = boost::intrusive::set<Column_type, boost::intrusive::constant_time_size<false> >;

col_dict_type columnToRep_; /**< Map from a column to the index of its representative. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ inline void Base_pairing<Master_matrix>::_reduce()
curr += _matrix()->get_column(pivotsToColumn.at(pivot));
} else {
auto& toadd = _matrix()->get_column(pivotsToColumn.at(pivot));
typename Master_matrix::element_type coef = curr.get_pivot_value();
typename Master_matrix::element_type coef = toadd.get_pivot_value();
auto& operators = _matrix()->colSettings_->operators;
coef = operators.get_inverse(coef);
operators.multiply_inplace(coef, operators.get_characteristic() - toadd.get_pivot_value());
curr.multiply_target_and_add(coef, toadd);
operators.multiply_inplace(coef, operators.get_characteristic() - curr.get_pivot_value());
curr.multiply_source_and_add(toadd, coef);
}

pivot = curr.get_pivot();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <iostream> //print() only
#include <set>
#include <map>
#include <stdexcept>
#include <vector>
#include <utility> //std::swap, std::move & std::exchange
Expand Down Expand Up @@ -485,7 +486,7 @@ class Chain_matrix : public Master_matrix::Matrix_dimension_option,
using tmp_column_type = typename std::conditional<
Master_matrix::Option_list::is_z2,
std::set<id_index>,
std::set<std::pair<id_index, Field_element_type>, typename Master_matrix::CellPairComparator>
std::map<id_index, Field_element_type>
>::type;

matrix_type matrix_; /**< Column container. */
Expand Down Expand Up @@ -1256,18 +1257,14 @@ inline void Chain_matrix<Master_matrix>::_add_to(const Column_type& column,
} else {
auto& operators = colSettings_->operators;
for (const Cell& cell : column) {
std::pair<index, Field_element_type> p(cell.get_row_index(), cell.get_element());
auto res_it = set.find(p);

if (res_it != set.end()) {
operators.multiply_and_add_inplace_front(p.second, coef, res_it->second);
set.erase(res_it);
if (p.second != Field_operators::get_additive_identity()) {
set.insert(p);
}
auto res = set.emplace(cell.get_row_index(), cell.get_element());
if (res.second){
operators.multiply_inplace(res.first->second, coef);
} else {
operators.multiply_inplace(p.second, coef);
set.insert(p);
operators.multiply_and_add_inplace_back(cell.get_element(), coef, res.first->second);
if (res.first->second == Field_operators::get_additive_identity()) {
set.erase(res.first);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -819,18 +819,20 @@ inline void RU_matrix<Master_matrix>::_reduce()
while (pivot != static_cast<id_index>(-1) && currIndex != static_cast<index>(-1)) {
if constexpr (Master_matrix::Option_list::is_z2) {
curr += reducedMatrixR_.get_column(currIndex);
//to avoid having to do line operations during vineyards, U is transposed
//TODO: explain this somewhere in the documentation...
if constexpr (Master_matrix::Option_list::has_vine_update)
mirrorMatrixU_.get_column(currIndex) += mirrorMatrixU_.get_column(i);
else
mirrorMatrixU_.get_column(i) += mirrorMatrixU_.get_column(currIndex);
} else {
Column_type& toadd = reducedMatrixR_.get_column(currIndex);
Field_element_type coef = curr.get_pivot_value();
Field_element_type coef = toadd.get_pivot_value();
coef = operators_->get_inverse(coef);
operators_->multiply_inplace(coef, operators_->get_characteristic() - toadd.get_pivot_value());
operators_->multiply_inplace(coef, operators_->get_characteristic() - curr.get_pivot_value());

curr.multiply_target_and_add(coef, toadd);
mirrorMatrixU_.multiply_target_and_add_to(currIndex, coef, i);
curr.multiply_source_and_add(toadd, coef);
mirrorMatrixU_.multiply_source_and_add_to(coef, currIndex, i);
}

pivot = curr.get_pivot();
Expand Down Expand Up @@ -881,18 +883,20 @@ inline void RU_matrix<Master_matrix>::_reduce_last_column(index lastIndex)
while (pivot != static_cast<id_index>(-1) && currIndex != static_cast<index>(-1)) {
if constexpr (Master_matrix::Option_list::is_z2) {
curr += reducedMatrixR_.get_column(currIndex);
//to avoid having to do line operations during vineyards, U is transposed
//TODO: explain this somewhere in the documentation...
if constexpr (Master_matrix::Option_list::has_vine_update)
mirrorMatrixU_.get_column(currIndex) += mirrorMatrixU_.get_column(lastIndex);
else
mirrorMatrixU_.get_column(lastIndex) += mirrorMatrixU_.get_column(currIndex);
} else {
Column_type& toadd = reducedMatrixR_.get_column(currIndex);
Field_element_type coef = curr.get_pivot_value();
Field_element_type coef = toadd.get_pivot_value();
coef = operators_->get_inverse(coef);
operators_->multiply_inplace(coef, operators_->get_characteristic() - toadd.get_pivot_value());
operators_->multiply_inplace(coef, operators_->get_characteristic() - curr.get_pivot_value());

curr.multiply_target_and_add(coef, toadd);
mirrorMatrixU_.get_column(lastIndex).multiply_target_and_add(coef, mirrorMatrixU_.get_column(currIndex));
curr.multiply_source_and_add(toadd, coef);
mirrorMatrixU_.get_column(lastIndex).multiply_source_and_add(mirrorMatrixU_.get_column(currIndex), coef);
}

pivot = curr.get_pivot();
Expand Down
11 changes: 0 additions & 11 deletions src/Persistence_matrix/include/gudhi/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,17 +268,6 @@ class Matrix {
std::pair<id_index, element_type>
>::type;

/**
* @brief Compaires two pairs, representing a cell (first = row index, second = value),
* by their position in the column and not their values.
* The two represented cells are therefore assumed to be in the same column.
*/
struct CellPairComparator {
bool operator()(const std::pair<id_index, element_type>& p1, const std::pair<id_index, element_type>& p2) const {
return p1.first < p2.first;
};
};

/**
* @brief Compaires two cells by their position in the row. They are assume to be in the same row.
*/
Expand Down
18 changes: 9 additions & 9 deletions src/Persistence_matrix/test/pm_matrix_tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ void test_ru_u_access(){
uColumns[2] = {{2,1}};
uColumns[3] = {{3,1}};
uColumns[4] = {{4,1}};
uColumns[5] = {{3,1},{4,1},{5,4}};
uColumns[5] = {{3,4},{4,4},{5,1}};
uColumns[6] = {{6,1}};
}

Expand Down Expand Up @@ -961,9 +961,9 @@ void test_ru_u_row_access(){
rows.push_back({{0,1}});
rows.push_back({{1,1}});
rows.push_back({{2,1}});
rows.push_back({{3,1},{5,1}});
rows.push_back({{4,1},{5,1}});
rows.push_back({{5,4}});
rows.push_back({{3,1},{5,4}});
rows.push_back({{4,1},{5,4}});
rows.push_back({{5,1}});
rows.push_back({{6,1}});
}

Expand Down Expand Up @@ -1319,7 +1319,7 @@ void test_ru_operation(){
uColumns[2] = {{2,1}};
uColumns[3] = {{3,1}};
uColumns[4] = {{4,1}};
uColumns[5] = {{3,1},{4,1},{5,4}};
uColumns[5] = {{3,4},{4,4},{5,1}};
uColumns[6] = {{6,1}};
}

Expand All @@ -1340,7 +1340,7 @@ void test_ru_operation(){
uColumns[5] = {4,5};
} else {
columns[5] = {{0,1},{1,4}};
uColumns[5] = {{3,2},{4,1},{5,4}};
uColumns[5] = {{4,4},{5,1}};
}
test_content_equality(columns, m);
if constexpr (is_indexed_by_position<Matrix>()){
Expand All @@ -1359,7 +1359,7 @@ void test_ru_operation(){
uColumns[5] = {5};
} else {
columns[5] = {{0,1},{2,4}};
uColumns[5] = {{3,2},{4,2},{5,4}};
uColumns[5] = {{5,1}};
}
test_content_equality(columns, m);
if constexpr (is_indexed_by_position<Matrix>()){
Expand All @@ -1376,7 +1376,7 @@ void test_ru_operation(){
uColumns[3] = {3,5};
} else {
columns[3] = {{0,4},{1,2},{2,4}};
uColumns[3] = {{4,2},{5,4}};
uColumns[3] = {{3,3},{5,1}};
}
test_content_equality(columns, m);
if constexpr (is_indexed_by_position<Matrix>()){
Expand All @@ -1392,7 +1392,7 @@ void test_ru_operation(){
uColumns[4] = {4};
} else {
columns[4] = {{0,4},{1,1}};
uColumns[4] = {{3,3},{4,4},{5,1}};
uColumns[4] = {{4,1},{5,4}};
}
test_content_equality(columns, m);
if constexpr (is_indexed_by_position<Matrix>()){
Expand Down

0 comments on commit abb93a6

Please sign in to comment.