Skip to content

Commit

Permalink
Merge pull request GUDHI#1143 from hschreiber/matrix_vine_opt
Browse files Browse the repository at this point in the history
[Persistence matrix] Better additions in U for RU matrices in Z2
  • Loading branch information
VincentRouvreau authored Oct 24, 2024
2 parents 6391ac2 + e1b4634 commit 1cf5a3d
Show file tree
Hide file tree
Showing 14 changed files with 178 additions and 73 deletions.
9 changes: 9 additions & 0 deletions src/Persistence_matrix/concept/PersistenceMatrixColumn.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,15 @@ class PersistenceMatrixColumn :
template <class Entry_range>
PersistenceMatrixColumn& multiply_source_and_add(const Entry_range& column, const Field_element& val);

/**
* @brief Adds a copy of the given entry at the end of the column. It is therefore assumed that the row index
* of the entry is higher than the current pivot of the column. Not available for @ref chainmatrix "chain matrices"
* and is only needed for @ref rumatrix "RU matrices".
*
* @param entry Entry to push back.
*/
void push_back(const Entry& entry);

/**
* @brief Equality comparator. Equal in the sense that what is "supposed" to be contained in the columns is equal,
* not what is actually stored in the underlying container. For example, the underlying container of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ inline void Base_matrix<Master_matrix>::print()
if constexpr (Master_matrix::Option_list::has_row_access) {
std::cout << "Row Matrix:\n";
for (Index i = 0; i < nextInsertIndex_; ++i) {
const auto& row = RA_opt::rows_[i];
const auto& row = (*RA_opt::rows_)[i];
for (const auto& entry : row) {
std::cout << entry.get_column_index() << " ";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,8 +729,8 @@ inline void Boundary_matrix<Master_matrix>::print()
if constexpr (Master_matrix::Option_list::has_row_access) {
std::cout << "Row Matrix:\n";
for (ID_index i = 0; i < nextInsertIndex_; ++i) {
const auto& row = RA_opt::rows_[i];
for (const auto& entry : row) {
const auto& row = (*RA_opt::rows_)[i];
for (const typename Column::Entry& entry : row) {
std::cout << entry.get_column_index() << " ";
}
std::cout << "(" << i << ")\n";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ inline void RU_matrix<Master_matrix>::insert_boundary(ID_index cellIndex,
if constexpr (Master_matrix::Option_list::has_vine_update) {
if (cellIndex != nextEventIndex_) {
Swap_opt::_positionToRowIdx().emplace(nextEventIndex_, cellIndex);
if (Master_matrix::Option_list::has_column_pairings) {
if constexpr (Master_matrix::Option_list::has_column_pairings) {
Swap_opt::template RU_pairing<Master_matrix>::idToPosition_.emplace(cellIndex, nextEventIndex_);
}
}
Expand Down Expand Up @@ -638,7 +638,7 @@ inline void RU_matrix<Master_matrix>::add_to(Index sourceColumnIndex, Index targ
{
reducedMatrixR_.add_to(sourceColumnIndex, targetColumnIndex);
// U transposed to avoid row operations
if constexpr (Master_matrix::Option_list::has_vine_update)
if constexpr (Master_matrix::Option_list::is_z2)
mirrorMatrixU_.add_to(targetColumnIndex, sourceColumnIndex);
else
mirrorMatrixU_.add_to(sourceColumnIndex, targetColumnIndex);
Expand All @@ -649,6 +649,8 @@ inline void RU_matrix<Master_matrix>::multiply_target_and_add_to(Index sourceCol
const Field_element& coefficient,
Index targetColumnIndex)
{
static_assert(!Master_matrix::Option_list::is_z2,
"Multiplication with something else than the identity is not allowed with Z2 coefficients.");
reducedMatrixR_.multiply_target_and_add_to(sourceColumnIndex, coefficient, targetColumnIndex);
mirrorMatrixU_.multiply_target_and_add_to(sourceColumnIndex, coefficient, targetColumnIndex);
}
Expand All @@ -658,6 +660,8 @@ inline void RU_matrix<Master_matrix>::multiply_source_and_add_to(const Field_ele
Index sourceColumnIndex,
Index targetColumnIndex)
{
static_assert(!Master_matrix::Option_list::is_z2,
"Multiplication with something else than the identity is not allowed with Z2 coefficients.");
reducedMatrixR_.multiply_source_and_add_to(coefficient, sourceColumnIndex, targetColumnIndex);
mirrorMatrixU_.multiply_source_and_add_to(coefficient, sourceColumnIndex, targetColumnIndex);
}
Expand Down Expand Up @@ -844,19 +848,17 @@ inline void RU_matrix<Master_matrix>::_reduce_column_by(Index target, Index sour
curr += reducedMatrixR_.get_column(source);
// to avoid having to do line operations during vineyards, U is transposed
// TODO: explain this somewhere in the documentation...
if constexpr (Master_matrix::Option_list::has_vine_update)
mirrorMatrixU_.get_column(source) += mirrorMatrixU_.get_column(target);
else
mirrorMatrixU_.get_column(target) += mirrorMatrixU_.get_column(source);
mirrorMatrixU_.get_column(source).push_back(*mirrorMatrixU_.get_column(target).begin());
} else {
Column& toadd = reducedMatrixR_.get_column(source);
Field_element coef = toadd.get_pivot_value();
coef = operators_->get_inverse(coef);
operators_->multiply_inplace(coef, operators_->get_characteristic() - curr.get_pivot_value());

curr.multiply_source_and_add(toadd, coef);
// but no transposition for Zp, careful if there will be vineyard or rep cycles in Zp one day
// TODO: explain this somewhere in the documentation...
mirrorMatrixU_.multiply_source_and_add_to(coef, source, target);
// mirrorMatrixU_.get_column(target).multiply_source_and_add(mirrorMatrixU_.get_column(source), coef);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <utility> //std::swap, std::move & std::exchange

#include <boost/iterator/indirect_iterator.hpp>
#include "gudhi/Debug_utils.h"

#include <gudhi/Persistence_matrix/allocators/entry_constructors.h>

Expand Down Expand Up @@ -134,6 +135,8 @@ class Heap_column : public Master_matrix::Column_dimension_option, public Master
Heap_column& multiply_source_and_add(const Entry_range& column, const Field_element& val);
Heap_column& multiply_source_and_add(Heap_column& column, const Field_element& val);

void push_back(const Entry& entry);

std::size_t compute_hash_value();

friend bool operator==(const Heap_column& c1, const Heap_column& c2) {
Expand Down Expand Up @@ -868,6 +871,21 @@ inline Heap_column<Master_matrix>& Heap_column<Master_matrix>::multiply_source_a
return *this;
}

template <class Master_matrix>
inline void Heap_column<Master_matrix>::push_back(const Entry& entry)
{
static_assert(Master_matrix::Option_list::is_of_boundary_type, "`push_back` is not available for Chain matrices.");

GUDHI_CHECK(entry.get_row_index() > get_pivot(), "The new row index has to be higher than the current pivot.");

Entry* newEntry = entryPool_->construct(entry.get_row_index());
if constexpr (!Master_matrix::Option_list::is_z2) {
newEntry->set_element(operators_->get_value(entry.get_element()));
}
column_.push_back(newEntry);
std::push_heap(column_.begin(), column_.end(), entryPointerComp_);
}

template <class Master_matrix>
inline Heap_column<Master_matrix>& Heap_column<Master_matrix>::operator=(const Heap_column& other)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class Intrusive_list_column : public Master_matrix::Row_access_option,
Intrusive_list_column& multiply_source_and_add(const Entry_range& column, const Field_element& val);
Intrusive_list_column& multiply_source_and_add(Intrusive_list_column& column, const Field_element& val);

void push_back(const Entry& entry);

friend bool operator==(const Intrusive_list_column& c1, const Intrusive_list_column& c2) {
if (&c1 == &c2) return true;

Expand Down Expand Up @@ -821,6 +823,20 @@ inline Intrusive_list_column<Master_matrix>& Intrusive_list_column<Master_matrix
return *this;
}

template <class Master_matrix>
inline void Intrusive_list_column<Master_matrix>::push_back(const Entry& entry)
{
static_assert(Master_matrix::Option_list::is_of_boundary_type, "`push_back` is not available for Chain matrices.");

GUDHI_CHECK(entry.get_row_index() > get_pivot(), "The new row index has to be higher than the current pivot.");

if constexpr (Master_matrix::Option_list::is_z2) {
_insert_entry(entry.get_row_index(), column_.end());
} else {
_insert_entry(entry.get_element(), entry.get_row_index(), column_.end());
}
}

template <class Master_matrix>
inline Intrusive_list_column<Master_matrix>& Intrusive_list_column<Master_matrix>::operator=(
const Intrusive_list_column& other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ class Intrusive_set_column : public Master_matrix::Row_access_option,
Intrusive_set_column& multiply_source_and_add(const Entry_range& column, const Field_element& val);
Intrusive_set_column& multiply_source_and_add(Intrusive_set_column& column, const Field_element& val);

void push_back(const Entry& entry);

friend bool operator==(const Intrusive_set_column& c1, const Intrusive_set_column& c2) {
if (&c1 == &c2) return true;

Expand Down Expand Up @@ -821,6 +823,20 @@ inline Intrusive_set_column<Master_matrix>& Intrusive_set_column<Master_matrix>:
return *this;
}

template <class Master_matrix>
inline void Intrusive_set_column<Master_matrix>::push_back(const Entry& entry)
{
static_assert(Master_matrix::Option_list::is_of_boundary_type, "`push_back` is not available for Chain matrices.");

GUDHI_CHECK(entry.get_row_index() > get_pivot(), "The new row index has to be higher than the current pivot.");

if constexpr (Master_matrix::Option_list::is_z2) {
_insert_entry(entry.get_row_index(), column_.end());
} else {
_insert_entry(entry.get_element(), entry.get_row_index(), column_.end());
}
}

template <class Master_matrix>
inline Intrusive_set_column<Master_matrix>& Intrusive_set_column<Master_matrix>::operator=(
const Intrusive_set_column& other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ class List_column : public Master_matrix::Row_access_option,
List_column& multiply_source_and_add(const Entry_range& column, const Field_element& val);
List_column& multiply_source_and_add(List_column& column, const Field_element& val);

void push_back(const Entry& entry);

friend bool operator==(const List_column& c1, const List_column& c2) {
if (&c1 == &c2) return true;

Expand Down Expand Up @@ -805,6 +807,20 @@ inline List_column<Master_matrix>& List_column<Master_matrix>::multiply_source_a
return *this;
}

template <class Master_matrix>
inline void List_column<Master_matrix>::push_back(const Entry& entry)
{
static_assert(Master_matrix::Option_list::is_of_boundary_type, "`push_back` is not available for Chain matrices.");

GUDHI_CHECK(entry.get_row_index() > get_pivot(), "The new row index has to be higher than the current pivot.");

if constexpr (Master_matrix::Option_list::is_z2) {
_insert_entry(entry.get_row_index(), column_.end());
} else {
_insert_entry(entry.get_element(), entry.get_row_index(), column_.end());
}
}

template <class Master_matrix>
inline List_column<Master_matrix>& List_column<Master_matrix>::operator=(const List_column& other)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ class Naive_vector_column : public Master_matrix::Row_access_option,
Naive_vector_column& multiply_source_and_add(const Entry_range& column, const Field_element& val);
Naive_vector_column& multiply_source_and_add(Naive_vector_column& column, const Field_element& val);

void push_back(const Entry& entry);

friend bool operator==(const Naive_vector_column& c1, const Naive_vector_column& c2) {
if (&c1 == &c2) return true;
if (c1.column_.size() != c2.column_.size()) return false;
Expand Down Expand Up @@ -801,6 +803,20 @@ inline Naive_vector_column<Master_matrix>& Naive_vector_column<Master_matrix>::m
return *this;
}

template <class Master_matrix>
inline void Naive_vector_column<Master_matrix>::push_back(const Entry& entry)
{
static_assert(Master_matrix::Option_list::is_of_boundary_type, "`push_back` is not available for Chain matrices.");

GUDHI_CHECK(entry.get_row_index() > get_pivot(), "The new row index has to be higher than the current pivot.");

if constexpr (Master_matrix::Option_list::is_z2) {
_insert_entry(entry.get_row_index(), column_);
} else {
_insert_entry(entry.get_element(), entry.get_row_index(), column_);
}
}

template <class Master_matrix>
inline Naive_vector_column<Master_matrix>& Naive_vector_column<Master_matrix>::operator=(
const Naive_vector_column& other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ class Set_column : public Master_matrix::Row_access_option,
Set_column& multiply_source_and_add(const Entry_range& column, const Field_element& val);
Set_column& multiply_source_and_add(Set_column& column, const Field_element& val);

void push_back(const Entry& entry);

friend bool operator==(const Set_column& c1, const Set_column& c2) {
if (&c1 == &c2) return true;

Expand Down Expand Up @@ -797,6 +799,20 @@ inline Set_column<Master_matrix>& Set_column<Master_matrix>::multiply_source_and
return *this;
}

template <class Master_matrix>
inline void Set_column<Master_matrix>::push_back(const Entry& entry)
{
static_assert(Master_matrix::Option_list::is_of_boundary_type, "`push_back` is not available for Chain matrices.");

GUDHI_CHECK(entry.get_row_index() > get_pivot(), "The new row index has to be higher than the current pivot.");

if constexpr (Master_matrix::Option_list::is_z2) {
_insert_entry(entry.get_row_index(), column_.end());
} else {
_insert_entry(entry.get_element(), entry.get_row_index(), column_.end());
}
}

template <class Master_matrix>
inline Set_column<Master_matrix>& Set_column<Master_matrix>::operator=(const Set_column& other)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ class Unordered_set_column : public Master_matrix::Row_access_option,
Unordered_set_column& multiply_source_and_add(const Entry_range& column, const Field_element& val);
Unordered_set_column& multiply_source_and_add(Unordered_set_column& column, const Field_element& val);

void push_back(const Entry& entry);

friend bool operator==(const Unordered_set_column& c1, const Unordered_set_column& c2) {
if (&c1 == &c2) return true;
if (c1.column_.size() != c2.column_.size()) return false;
Expand Down Expand Up @@ -800,6 +802,20 @@ inline Unordered_set_column<Master_matrix>& Unordered_set_column<Master_matrix>:
return *this;
}

template <class Master_matrix>
inline void Unordered_set_column<Master_matrix>::push_back(const Entry& entry)
{
static_assert(Master_matrix::Option_list::is_of_boundary_type, "`push_back` is not available for Chain matrices.");

GUDHI_CHECK(entry.get_row_index() > get_pivot(), "The new row index has to be higher than the current pivot.");

if constexpr (Master_matrix::Option_list::is_z2) {
_insert_entry(entry.get_row_index());
} else {
_insert_entry(entry.get_element(), entry.get_row_index());
}
}

template <class Master_matrix>
inline Unordered_set_column<Master_matrix>& Unordered_set_column<Master_matrix>::operator=(
const Unordered_set_column& other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class Vector_column : public Master_matrix::Row_access_option,
Vector_column& multiply_source_and_add(const Entry_range& column, const Field_element& val);
Vector_column& multiply_source_and_add(Vector_column& column, const Field_element& val);

void push_back(const Entry& entry);

std::size_t compute_hash_value();

friend bool operator==(const Vector_column& c1, const Vector_column& c2) {
Expand Down Expand Up @@ -903,6 +905,20 @@ inline Vector_column<Master_matrix>& Vector_column<Master_matrix>::multiply_sour
return *this;
}

template <class Master_matrix>
inline void Vector_column<Master_matrix>::push_back(const Entry& entry)
{
static_assert(Master_matrix::Option_list::is_of_boundary_type, "`push_back` is not available for Chain matrices.");

GUDHI_CHECK(entry.get_row_index() > get_pivot(), "The new row index has to be higher than the current pivot.");

if constexpr (Master_matrix::Option_list::is_z2) {
_insert_entry(entry.get_row_index(), column_);
} else {
_insert_entry(entry.get_element(), entry.get_row_index(), column_);
}
}

template <class Master_matrix>
inline Vector_column<Master_matrix>& Vector_column<Master_matrix>::operator=(const Vector_column& other)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ inline RU_representative_cycles<Master_matrix>::RU_representative_cycles(
template <class Master_matrix>
inline void RU_representative_cycles<Master_matrix>::update_representative_cycles()
{
if constexpr (Master_matrix::Option_list::has_vine_update) {
if constexpr (Master_matrix::Option_list::is_z2) {
birthToCycle_.clear();
birthToCycle_.resize(_matrix()->reducedMatrixR_.get_number_of_columns(), -1);
Index c = 0;
Expand Down
Loading

0 comments on commit 1cf5a3d

Please sign in to comment.