diff --git a/src/Persistence_matrix/include/gudhi/Persistence_matrix/base_matrix_with_column_compression.h b/src/Persistence_matrix/include/gudhi/Persistence_matrix/base_matrix_with_column_compression.h index 826949eee3..d1b66456e0 100644 --- a/src/Persistence_matrix/include/gudhi/Persistence_matrix/base_matrix_with_column_compression.h +++ b/src/Persistence_matrix/include/gudhi/Persistence_matrix/base_matrix_with_column_compression.h @@ -19,7 +19,6 @@ #include //print() only #include -#include #include //std::swap, std::move & std::exchange #include @@ -335,16 +334,6 @@ class Base_matrix_with_column_compression : protected Master_matrix::Matrix_row_ }; using ra_opt = typename Master_matrix::Matrix_row_access_option; - using cell_rep_type = - typename std::conditional - >::type; - using tmp_column_type = typename std::conditional< - Master_matrix::Option_list::is_z2, - std::set, - std::set, typename Master_matrix::CellPairComparator> - >::type; using col_dict_type = boost::intrusive::set >; col_dict_type columnToRep_; /**< Map from a column to the index of its representative. */ diff --git a/src/Persistence_matrix/include/gudhi/Persistence_matrix/base_pairing.h b/src/Persistence_matrix/include/gudhi/Persistence_matrix/base_pairing.h index 31c95e35d1..8cda78dc62 100644 --- a/src/Persistence_matrix/include/gudhi/Persistence_matrix/base_pairing.h +++ b/src/Persistence_matrix/include/gudhi/Persistence_matrix/base_pairing.h @@ -162,11 +162,11 @@ inline void Base_pairing::_reduce() curr += _matrix()->get_column(pivotsToColumn.at(pivot)); } else { auto& toadd = _matrix()->get_column(pivotsToColumn.at(pivot)); - typename Master_matrix::element_type coef = curr.get_pivot_value(); + typename Master_matrix::element_type coef = toadd.get_pivot_value(); auto& operators = _matrix()->colSettings_->operators; coef = operators.get_inverse(coef); - operators.multiply_inplace(coef, operators.get_characteristic() - toadd.get_pivot_value()); - curr.multiply_target_and_add(coef, toadd); + operators.multiply_inplace(coef, operators.get_characteristic() - curr.get_pivot_value()); + curr.multiply_source_and_add(toadd, coef); } pivot = curr.get_pivot(); diff --git a/src/Persistence_matrix/include/gudhi/Persistence_matrix/chain_matrix.h b/src/Persistence_matrix/include/gudhi/Persistence_matrix/chain_matrix.h index 82bdefdd3e..0faa86d5f5 100644 --- a/src/Persistence_matrix/include/gudhi/Persistence_matrix/chain_matrix.h +++ b/src/Persistence_matrix/include/gudhi/Persistence_matrix/chain_matrix.h @@ -19,6 +19,7 @@ #include //print() only #include +#include #include #include #include //std::swap, std::move & std::exchange @@ -485,7 +486,7 @@ class Chain_matrix : public Master_matrix::Matrix_dimension_option, using tmp_column_type = typename std::conditional< Master_matrix::Option_list::is_z2, std::set, - std::set, typename Master_matrix::CellPairComparator> + std::map >::type; matrix_type matrix_; /**< Column container. */ @@ -1256,18 +1257,14 @@ inline void Chain_matrix::_add_to(const Column_type& column, } else { auto& operators = colSettings_->operators; for (const Cell& cell : column) { - std::pair p(cell.get_row_index(), cell.get_element()); - auto res_it = set.find(p); - - if (res_it != set.end()) { - operators.multiply_and_add_inplace_front(p.second, coef, res_it->second); - set.erase(res_it); - if (p.second != Field_operators::get_additive_identity()) { - set.insert(p); - } + auto res = set.emplace(cell.get_row_index(), cell.get_element()); + if (res.second){ + operators.multiply_inplace(res.first->second, coef); } else { - operators.multiply_inplace(p.second, coef); - set.insert(p); + operators.multiply_and_add_inplace_back(cell.get_element(), coef, res.first->second); + if (res.first->second == Field_operators::get_additive_identity()) { + set.erase(res.first); + } } } } diff --git a/src/Persistence_matrix/include/gudhi/Persistence_matrix/ru_matrix.h b/src/Persistence_matrix/include/gudhi/Persistence_matrix/ru_matrix.h index c2bbdea440..35ed360824 100644 --- a/src/Persistence_matrix/include/gudhi/Persistence_matrix/ru_matrix.h +++ b/src/Persistence_matrix/include/gudhi/Persistence_matrix/ru_matrix.h @@ -819,18 +819,20 @@ inline void RU_matrix::_reduce() while (pivot != static_cast(-1) && currIndex != static_cast(-1)) { if constexpr (Master_matrix::Option_list::is_z2) { curr += reducedMatrixR_.get_column(currIndex); + //to avoid having to do line operations during vineyards, U is transposed + //TODO: explain this somewhere in the documentation... if constexpr (Master_matrix::Option_list::has_vine_update) mirrorMatrixU_.get_column(currIndex) += mirrorMatrixU_.get_column(i); else mirrorMatrixU_.get_column(i) += mirrorMatrixU_.get_column(currIndex); } else { Column_type& toadd = reducedMatrixR_.get_column(currIndex); - Field_element_type coef = curr.get_pivot_value(); + Field_element_type coef = toadd.get_pivot_value(); coef = operators_->get_inverse(coef); - operators_->multiply_inplace(coef, operators_->get_characteristic() - toadd.get_pivot_value()); + operators_->multiply_inplace(coef, operators_->get_characteristic() - curr.get_pivot_value()); - curr.multiply_target_and_add(coef, toadd); - mirrorMatrixU_.multiply_target_and_add_to(currIndex, coef, i); + curr.multiply_source_and_add(toadd, coef); + mirrorMatrixU_.multiply_source_and_add_to(coef, currIndex, i); } pivot = curr.get_pivot(); @@ -881,18 +883,20 @@ inline void RU_matrix::_reduce_last_column(index lastIndex) while (pivot != static_cast(-1) && currIndex != static_cast(-1)) { if constexpr (Master_matrix::Option_list::is_z2) { curr += reducedMatrixR_.get_column(currIndex); + //to avoid having to do line operations during vineyards, U is transposed + //TODO: explain this somewhere in the documentation... if constexpr (Master_matrix::Option_list::has_vine_update) mirrorMatrixU_.get_column(currIndex) += mirrorMatrixU_.get_column(lastIndex); else mirrorMatrixU_.get_column(lastIndex) += mirrorMatrixU_.get_column(currIndex); } else { Column_type& toadd = reducedMatrixR_.get_column(currIndex); - Field_element_type coef = curr.get_pivot_value(); + Field_element_type coef = toadd.get_pivot_value(); coef = operators_->get_inverse(coef); - operators_->multiply_inplace(coef, operators_->get_characteristic() - toadd.get_pivot_value()); + operators_->multiply_inplace(coef, operators_->get_characteristic() - curr.get_pivot_value()); - curr.multiply_target_and_add(coef, toadd); - mirrorMatrixU_.get_column(lastIndex).multiply_target_and_add(coef, mirrorMatrixU_.get_column(currIndex)); + curr.multiply_source_and_add(toadd, coef); + mirrorMatrixU_.get_column(lastIndex).multiply_source_and_add(mirrorMatrixU_.get_column(currIndex), coef); } pivot = curr.get_pivot(); diff --git a/src/Persistence_matrix/include/gudhi/matrix.h b/src/Persistence_matrix/include/gudhi/matrix.h index 22e54c5b88..0f28b16818 100644 --- a/src/Persistence_matrix/include/gudhi/matrix.h +++ b/src/Persistence_matrix/include/gudhi/matrix.h @@ -268,17 +268,6 @@ class Matrix { std::pair >::type; - /** - * @brief Compaires two pairs, representing a cell (first = row index, second = value), - * by their position in the column and not their values. - * The two represented cells are therefore assumed to be in the same column. - */ - struct CellPairComparator { - bool operator()(const std::pair& p1, const std::pair& p2) const { - return p1.first < p2.first; - }; - }; - /** * @brief Compaires two cells by their position in the row. They are assume to be in the same row. */ diff --git a/src/Persistence_matrix/test/pm_matrix_tests.h b/src/Persistence_matrix/test/pm_matrix_tests.h index f0650bfba9..ca6f6168ff 100644 --- a/src/Persistence_matrix/test/pm_matrix_tests.h +++ b/src/Persistence_matrix/test/pm_matrix_tests.h @@ -799,7 +799,7 @@ void test_ru_u_access(){ uColumns[2] = {{2,1}}; uColumns[3] = {{3,1}}; uColumns[4] = {{4,1}}; - uColumns[5] = {{3,1},{4,1},{5,4}}; + uColumns[5] = {{3,4},{4,4},{5,1}}; uColumns[6] = {{6,1}}; } @@ -961,9 +961,9 @@ void test_ru_u_row_access(){ rows.push_back({{0,1}}); rows.push_back({{1,1}}); rows.push_back({{2,1}}); - rows.push_back({{3,1},{5,1}}); - rows.push_back({{4,1},{5,1}}); - rows.push_back({{5,4}}); + rows.push_back({{3,1},{5,4}}); + rows.push_back({{4,1},{5,4}}); + rows.push_back({{5,1}}); rows.push_back({{6,1}}); } @@ -1319,7 +1319,7 @@ void test_ru_operation(){ uColumns[2] = {{2,1}}; uColumns[3] = {{3,1}}; uColumns[4] = {{4,1}}; - uColumns[5] = {{3,1},{4,1},{5,4}}; + uColumns[5] = {{3,4},{4,4},{5,1}}; uColumns[6] = {{6,1}}; } @@ -1340,7 +1340,7 @@ void test_ru_operation(){ uColumns[5] = {4,5}; } else { columns[5] = {{0,1},{1,4}}; - uColumns[5] = {{3,2},{4,1},{5,4}}; + uColumns[5] = {{4,4},{5,1}}; } test_content_equality(columns, m); if constexpr (is_indexed_by_position()){ @@ -1359,7 +1359,7 @@ void test_ru_operation(){ uColumns[5] = {5}; } else { columns[5] = {{0,1},{2,4}}; - uColumns[5] = {{3,2},{4,2},{5,4}}; + uColumns[5] = {{5,1}}; } test_content_equality(columns, m); if constexpr (is_indexed_by_position()){ @@ -1376,7 +1376,7 @@ void test_ru_operation(){ uColumns[3] = {3,5}; } else { columns[3] = {{0,4},{1,2},{2,4}}; - uColumns[3] = {{4,2},{5,4}}; + uColumns[3] = {{3,3},{5,1}}; } test_content_equality(columns, m); if constexpr (is_indexed_by_position()){ @@ -1392,7 +1392,7 @@ void test_ru_operation(){ uColumns[4] = {4}; } else { columns[4] = {{0,4},{1,1}}; - uColumns[4] = {{3,3},{4,4},{5,1}}; + uColumns[4] = {{4,1},{5,4}}; } test_content_equality(columns, m); if constexpr (is_indexed_by_position()){