diff --git a/.circleci/config.yml b/.circleci/config.yml index 4ff7f81f2e..2b42816e9e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -42,6 +42,25 @@ jobs: make all ctest --output-on-failure + debug_tests: + docker: + - image: gudhi/ci_for_gudhi:latest + steps: + - checkout + - run: + name: Checkout submodules + command: | + git submodule sync + git submodule update --init + - run: + name: Build and test unitary tests + command: | + mkdir build + cd build + cmake -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS=-fsanitize=address -DWITH_GUDHI_EXAMPLE=OFF -DWITH_GUDHI_TEST=ON -DWITH_GUDHI_UTILITIES=OFF -DWITH_GUDHI_PYTHON=OFF .. + make all + ctest --output-on-failure + utils: docker: - image: gudhi/ci_for_gudhi:latest @@ -322,6 +341,7 @@ workflows: - python_without_cgal - examples - tests + - debug_tests - utils - python - doxygen diff --git a/scripts/build_osx_universal_gmpfr.sh b/scripts/build_osx_universal_gmpfr.sh index 3dafa3ce10..516730c61d 100755 --- a/scripts/build_osx_universal_gmpfr.sh +++ b/scripts/build_osx_universal_gmpfr.sh @@ -5,15 +5,16 @@ set -e # Assumes that the user has enough rights to run brew fetch # Downloading +SED_PGM='s/^Downloaded to: |^Already downloaded: //p' mkdir deps-amd64 cd deps-amd64 -tar xf "`brew fetch --bottle-tag=big_sur gmp | sed -ne 's/^Downloaded to: //p'`" -tar xf "`brew fetch --bottle-tag=big_sur mpfr | sed -ne 's/^Downloaded to: //p'`" +tar xf "`brew fetch --bottle-tag=x86_64_monterey gmp | grep -F bottle.tar.gz | sed -Ene "$SED_PGM"`" +tar xf "`brew fetch --bottle-tag=x86_64_monterey mpfr | grep -F bottle.tar.gz | sed -Ene "$SED_PGM"`" cd .. mkdir deps-arm64 cd deps-arm64 -tar xf "`brew fetch --bottle-tag=arm64_big_sur gmp | sed -ne 's/^Downloaded to: //p'`" -tar xf "`brew fetch --bottle-tag=arm64_big_sur mpfr | sed -ne 's/^Downloaded to: //p'`" +tar xf "`brew fetch --bottle-tag=arm64_monterey gmp | grep -F bottle.tar.gz | sed -Ene "$SED_PGM"`" +tar xf "`brew fetch --bottle-tag=arm64_monterey mpfr | grep -F bottle.tar.gz | sed -Ene "$SED_PGM"`" cd .. # Merging @@ -33,9 +34,10 @@ install_name_tool -id $PWD/deps-uni/lib/$GMP deps-uni/lib/$GMP install_name_tool -id $PWD/deps-uni/lib/$GMPXX deps-uni/lib/$GMPXX install_name_tool -id $PWD/deps-uni/lib/$MPFR deps-uni/lib/$MPFR # Also fix dependencies -BADGMP=`otool -L deps-uni/lib/$MPFR|sed -ne 's/[[:space:]]*\(.*libgmp\..*dylib\).*/\1/p'` +# otool gives twice the same dependency, keep only one (a loop would be safer...) +BADGMP=`otool -L deps-uni/lib/$MPFR|sed -ne 's/[[:space:]]*\(.*libgmp\..*dylib\).*/\1/p'|uniq` install_name_tool -change $BADGMP $PWD/deps-uni/lib/$GMP deps-uni/lib/$MPFR -BADGMP=`otool -L deps-uni/lib/$GMPXX|sed -ne 's/[[:space:]]*\(.*libgmp\..*dylib\).*/\1/p'` +BADGMP=`otool -L deps-uni/lib/$GMPXX|sed -ne 's/[[:space:]]*\(.*libgmp\..*dylib\).*/\1/p'|uniq` install_name_tool -change $BADGMP $PWD/deps-uni/lib/$GMP deps-uni/lib/$GMPXX ln -s $GMP deps-uni/lib/libgmp.dylib diff --git a/src/Persistence_representations/include/gudhi/Persistence_landscape.h b/src/Persistence_representations/include/gudhi/Persistence_landscape.h index ce4065b8a0..a384bc17f8 100644 --- a/src/Persistence_representations/include/gudhi/Persistence_landscape.h +++ b/src/Persistence_representations/include/gudhi/Persistence_landscape.h @@ -388,6 +388,7 @@ class Persistence_landscape { nextLevelMerge.swap(nextNextLevelMerge); } (*this) = (*nextLevelMerge[0]); + if (!is_this_first_level) delete nextLevelMerge[0]; (*this) *= 1 / static_cast(to_average.size()); } diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree.h b/src/Simplex_tree/include/gudhi/Simplex_tree.h index 9f93672d63..ae4545876c 100644 --- a/src/Simplex_tree/include/gudhi/Simplex_tree.h +++ b/src/Simplex_tree/include/gudhi/Simplex_tree.h @@ -5,8 +5,10 @@ * Copyright (C) 2014 Inria * * Modification(s): - * - 2020/09 Clément Maria: option to link all simplex tree nodes with same label in an intrusive list. * - 2023/02 Vincent Rouvreau: Add de/serialize methods for pickle feature + * - 2023/07 Clément Maria: Option to link all simplex tree nodes with same label in an intrusive list + * - 2023/05 Clément Maria: Edge insertion method for flag complexes + * - 2023/05 Hannah Schreiber: Factorization of expansion methods * - 2023/08 Hannah Schreiber (& Clément Maria): Add possibility of stable simplex handles. * - YYYY/MM Author: Description of the modification */ @@ -38,6 +40,7 @@ #include #include +#include #ifdef GUDHI_USE_TBB #include @@ -540,8 +543,11 @@ class Simplex_tree { } public: + template friend class Simplex_tree; + /** \brief Checks if two simplex trees are equal. */ - bool operator==(Simplex_tree& st2) { + template + bool operator==(Simplex_tree& st2) { if ((null_vertex_ != st2.null_vertex_) || (dimension_ != st2.dimension_ && !dimension_to_be_lowered_ && !st2.dimension_to_be_lowered_)) return false; @@ -549,17 +555,21 @@ class Simplex_tree { } /** \brief Checks if two simplex trees are different. */ - bool operator!=(Simplex_tree& st2) { + template + bool operator!=(Simplex_tree& st2) { return (!(*this == st2)); } private: /** rec_equal: Checks recursively whether or not two simplex trees are equal, using depth first search. */ - bool rec_equal(Siblings* s1, Siblings* s2) { + template + bool rec_equal(Siblings* s1, OtherSiblings* s2) { if (s1->members().size() != s2->members().size()) return false; - for (auto sh1 = s1->members().begin(), sh2 = s2->members().begin(); - (sh1 != s1->members().end() && sh2 != s2->members().end()); ++sh1, ++sh2) { + auto sh2 = s2->members().begin(); + for (auto sh1 = s1->members().begin(); + (sh1 != s1->members().end() && sh2 != s2->members().end()); + ++sh1, ++sh2) { if (sh1->first != sh2->first || sh1->second.filtration() != sh2->second.filtration()) return false; if (has_children(sh1) != has_children(sh2)) @@ -672,6 +682,21 @@ class Simplex_tree { return simplices_number; } + /** + * @brief Returns the dimension of the given sibling simplices. + * + * @param curr_sib Pointer to the sibling container. + * @return Height of the siblings in the tree (root counts as zero to make the height correspond to the dimension). + */ + int dimension(Siblings* curr_sib) { + int dim = -1; + while (curr_sib != nullptr) { + ++dim; + curr_sib = curr_sib->oncles(); + } + return dim; + } + public: /** \brief Returns the number of simplices of each dimension in the simplex tree. */ std::vector num_simplices_by_dimension() { @@ -696,13 +721,7 @@ class Simplex_tree { * * Must be different from null_simplex().*/ int dimension(Simplex_handle sh) { - Siblings * curr_sib = self_siblings(sh); - int dim = 0; - while (curr_sib != nullptr) { - ++dim; - curr_sib = curr_sib->oncles(); - } - return dim - 1; + return dimension(self_siblings(sh)); } /** \brief Returns an upper bound on the dimension of the simplicial complex. */ @@ -721,7 +740,7 @@ class Simplex_tree { } /** \brief Returns true if the node in the simplex tree pointed by - * sh has children.*/ + * the given simplex handle has children.*/ template bool has_children(SimplexHandle sh) const { // Here we rely on the root using null_vertex(), which cannot match any real vertex. @@ -1354,11 +1373,253 @@ class Simplex_tree { dimension_ = max_dim - dimension_; } + /** + * @brief Adds a new vertex or a new edge in a flag complex, as well as all + * simplices of its star, defined to maintain the property + * of the complex to be a flag complex, truncated at dimension dim_max. + * To insert a new edge, the two given vertex handles have to correspond + * to the two end points of the edge. To insert a new vertex, the handles + * have to be twice the same and correspond to the number you want assigned + * to it. I.e., to insert vertex \f$i\f$, give \f$u = v = i\f$. + * The method assumes that the given edge was not already contained in + * the simplex tree, so the behaviour is undefined if called on an existing + * edge. Also, the vertices of an edge have to be inserted before the edge. + * + * @param[in] u,v Vertex_handle representing the new edge + * (@p v != @p u) or the new vertex (@p v == @p u). + * @param[in] fil Filtration value of the edge. + * @param[in] dim_max Maximal dimension of the expansion. + * If set to -1, the expansion goes as far as possible. + * @param[out] added_simplices Contains at the end all new + * simplices induced by the insertion of the edge. + * The container is not emptied and new simplices are + * appended at the end. + * + * @pre `SimplexTreeOptions::link_nodes_by_label` must be true. + * @pre When inserting the edge `[u,v]`, the vertices @p u and @p v have to be + * already inserted in the simplex tree. + * + * @warning If the edges and vertices are not inserted in the order of their + * filtration values, the method `make_filtration_non_decreasing()` has to be + * called at the end of the insertions to restore the intended filtration. + * Note that even then, an edge has to be inserted after its vertices. + * @warning The method assumes that the given edge or vertex was not already + * contained in the simplex tree, so the behaviour is undefined if called on + * an existing simplex. + */ + void insert_edge_as_flag( Vertex_handle u + , Vertex_handle v + , Filtration_value fil + , int dim_max + , std::vector& added_simplices) + { + /** + * In term of edges in the graph, inserting edge `[u,v]` only affects + * the subtree rooted at @p u. + * + * For a new node with label @p v, we first do a local expansion for + * computing the children of this new node, and then a standard expansion + * for its children. + * Nodes with label @p v (and their subtrees) already in the tree + * do not get affected. + * + * Nodes with label @p u get affected only if a Node with label @p v is in their same + * siblings set. + * We then try to insert "ponctually" @p v all over the subtree rooted + * at `Node(u)`. Each insertion of a Node with @p v label induces a local + * expansion at this Node (as explained above) and a sequence of "ponctual" + * insertion of `Node(v)` in the subtree rooted at sibling nodes of the new node, + * on its left. + */ + + static_assert(Options::link_nodes_by_label, "Options::link_nodes_by_label must be true"); + + if (u == v) { // Are we inserting a vertex? + auto res_ins = root_.members().emplace(u, Node(&root_,fil)); + if (res_ins.second) { //if the vertex is not in the complex, insert it + added_simplices.push_back(res_ins.first); //no more insert in root_.members() + update_simplex_tree_after_node_insertion(res_ins.first); + if (dimension_ == -1) dimension_ = 0; + } + return; //because the vertex is isolated, no more insertions. + } + // else, we are inserting an edge: ensure that u < v + if (v < u) { std::swap(u,v); } + + //Note that we copy Simplex_handle (aka map iterators) in added_simplices + //while we are still modifying the Simplex_tree. Insertions in siblings may + //invalidate Simplex_handles; we take care of this fact by first doing all + //insertion in a Sibling, then inserting all handles in added_simplices. + +#ifdef GUDHI_DEBUG + //check whether vertices u and v are in the tree. If not, return an error. + auto sh_u = root_.members().find(u); + GUDHI_CHECK(sh_u != root_.members().end() && + root_.members().find(v) != root_.members().end(), + std::invalid_argument( + "Simplex_tree::insert_edge_as_flag - inserts an edge whose vertices are not in the complex") + ); + GUDHI_CHECK(!has_children(sh_u) || + sh_u->second.children()->members().find(v) == sh_u->second.children()->members().end(), + std::invalid_argument( + "Simplex_tree::insert_edge_as_flag - inserts an already existing edge") + ); +#endif + + // to update dimension + const auto tmp_dim = dimension_; + auto tmp_max_dim = dimension_; + + //for all siblings containing a Node labeled with u (including the root), run + //compute_punctual_expansion + //todo parallelise + List_max_vertex* nodes_with_label_u = nodes_by_label(u);//all Nodes with u label + + GUDHI_CHECK(nodes_with_label_u != nullptr, + "Simplex_tree::insert_edge_as_flag - cannot find the list of Nodes with label u"); + + for (auto&& node_as_hook : *nodes_with_label_u) + { + Node& node_u = static_cast(node_as_hook); //corresponding node, has label u + Simplex_handle sh_u = simplex_handle_from_node(node_u); + Siblings * sib_u = self_siblings(sh_u); + if (sib_u->members().find(v) != sib_u->members().end()) { //v is the label of a sibling of node_u + int curr_dim = dimension(sib_u); + if (dim_max == -1 || curr_dim < dim_max){ + if (!has_children(sh_u)) { + //then node_u was a leaf and now has a new child Node labeled v + //the child v is created in compute_punctual_expansion + node_u.assign_children(new Siblings(sib_u, u)); + } + dimension_ = dim_max - curr_dim - 1; + compute_punctual_expansion( + v, + node_u.children(), + fil, + dim_max - curr_dim - 1, //>= 0 if dim_max >= 0, <0 otherwise + added_simplices ); + dimension_ = dim_max - dimension_; + if (dimension_ > tmp_max_dim) tmp_max_dim = dimension_; + } + } + } + if (tmp_dim <= tmp_max_dim){ + dimension_ = tmp_max_dim; + dimension_to_be_lowered_ = false; + } else { + dimension_ = tmp_dim; + } + } + private: - /** \brief Recursive expansion of the simplex tree.*/ + /** \brief Inserts a Node with label @p v in the set of siblings sib, and percolate the + * expansion on the subtree rooted at sib. Sibling sib must not contain + * @p v. + * The percolation of the expansion is twofold: + * 1- the newly inserted Node labeled @p v in sib has a subtree computed + * via create_local_expansion. + * 2- All Node in the members of sib, with label @p x and @p x < @p v, + * need in turn a local_expansion by @p v iff N^+(x) contains @p v. + */ + void compute_punctual_expansion( Vertex_handle v + , Siblings * sib + , Filtration_value fil + , int k //k == dim_max - dimension simplices in sib + , std::vector& added_simplices ) + { //insertion always succeeds because the edge {u,v} used to not be here. + auto res_ins_v = sib->members().emplace(v, Node(sib,fil)); + added_simplices.push_back(res_ins_v.first); //no more insertion in sib + update_simplex_tree_after_node_insertion(res_ins_v.first); + + if (k == 0) { // reached the maximal dimension. if max_dim == -1, k is never equal to 0. + dimension_ = 0; // to keep track of the max height of the recursion tree + return; + } + + //create the subtree of new Node(v) + create_local_expansion( res_ins_v.first + , sib + , fil + , k + , added_simplices ); + + //punctual expansion in nodes on the left of v, i.e. with label x < v + for (auto sh = sib->members().begin(); sh != res_ins_v.first; ++sh) + { //if v belongs to N^+(x), punctual expansion + Simplex_handle root_sh = find_vertex(sh->first); //Node(x), x < v + if (has_children(root_sh) && + root_sh->second.children()->members().find(v) != root_sh->second.children()->members().end()) + { //edge {x,v} is in the complex + if (!has_children(sh)){ + sh->second.assign_children(new Siblings(sib, sh->first)); + } + //insert v in the children of sh, and expand. + compute_punctual_expansion( v + , sh->second.children() + , fil + , k-1 + , added_simplices ); + } + } + } + + /** \brief After the insertion of edge `{u,v}`, expansion of a subtree rooted at @p v, where the + * Node with label @p v has just been inserted, and its parent is a Node labeled with + * @p u. sh has no children here. + * + * k must be > 0 + */ + void create_local_expansion( + Simplex_handle sh_v //Node with label v which has just been inserted + , Siblings * curr_sib //Siblings containing the node sh_v + , Filtration_value fil_uv //Fil value of the edge uv in the zz filtration + , int k //Stopping condition for recursion based on max dim + , std::vector &added_simplices) //range of all new simplices + { //pick N^+(v) + //intersect N^+(v) with labels y > v in curr_sib + Simplex_handle next_it = sh_v; + ++next_it; + + if (dimension_ > k) { + dimension_ = k; //to keep track of the max height of the recursion tree + } + + create_expansion(curr_sib, sh_v, next_it, fil_uv, k, &added_simplices); + } + //TODO boost::container::ordered_unique_range_t in the creation of a Siblings + + /** \brief Global expansion of a subtree in the simplex tree. + * + * The filtration value is absolute and defined by `Filtration_value fil`. + * The new Node are also connected appropriately in the coface + * data structure. + * + * Only called in the case of `void insert_edge_as_flag(...)`. + */ + void siblings_expansion( + Siblings * siblings // must contain elements + , Filtration_value fil + , int k // == max_dim expansion - dimension curr siblings + , std::vector & added_simplices ) + { + if (dimension_ > k) { + dimension_ = k; //to keep track of the max height of the recursion tree + } + if (k == 0) { return; } //max dimension + Dictionary_it next = ++(siblings->members().begin()); + + for (Dictionary_it s_h = siblings->members().begin(); + next != siblings->members().end(); ++s_h, ++next) + { //find N^+(s_h) + create_expansion(siblings, s_h, next, fil, k, &added_simplices); + } + } + + /** \brief Recursive expansion of the simplex tree. + * Only called in the case of `void expansion(int max_dim)`. */ void siblings_expansion(Siblings * siblings, // must contain elements int k) { - if (dimension_ > k) { + if (k >= 0 && dimension_ > k) { dimension_ = k; } if (k == 0) @@ -1366,40 +1627,65 @@ class Simplex_tree { Dictionary_it next = siblings->members().begin(); ++next; - thread_local std::vector > inter; for (Dictionary_it s_h = siblings->members().begin(); - s_h != siblings->members().end(); ++s_h, ++next) { - Simplex_handle root_sh = find_vertex(s_h->first); - if (has_children(root_sh)) { - intersection( - inter, // output intersection - next, // begin - siblings->members().end(), // end - root_sh->second.children()->members().begin(), - root_sh->second.children()->members().end(), - s_h->second.filtration()); - if (inter.size() != 0) { - Siblings * new_sib = new Siblings(siblings, // oncles - s_h->first, // parent - inter); // boost::container::ordered_unique_range_t - for (auto it = new_sib->members().begin(); it != new_sib->members().end(); ++it) { - update_simplex_tree_after_node_insertion(it); - } + s_h != siblings->members().end(); ++s_h, ++next) + { + create_expansion(siblings, s_h, next, s_h->second.filtration(), k); + } + } - inter.clear(); - s_h->second.assign_children(new_sib); - siblings_expansion(new_sib, k - 1); - } else { - // ensure the children property - s_h->second.assign_children(siblings); - inter.clear(); + /** \brief Recursive expansion of the simplex tree. + * The method is used with `force_filtration_value == true` by `void insert_edge_as_flag(...)` and with + * `force_filtration_value == false` by `void expansion(int max_dim)`. Therefore, `added_simplices` is assumed + * to bon non-null in the first case and null in the second.*/ + template + void create_expansion(Siblings * siblings, + Dictionary_it& s_h, + Dictionary_it& next, + Filtration_value fil, + int k, + std::vector* added_simplices = nullptr) + { + Simplex_handle root_sh = find_vertex(s_h->first); + thread_local std::vector > inter; + + if (!has_children(root_sh)) return; + + intersection( + inter, // output intersection + next, // begin + siblings->members().end(), // end + root_sh->second.children()->members().begin(), + root_sh->second.children()->members().end(), + fil); + if (inter.size() != 0) { + Siblings * new_sib = new Siblings(siblings, // oncles + s_h->first, // parent + inter); // boost::container::ordered_unique_range_t + for (auto it = new_sib->members().begin(); it != new_sib->members().end(); ++it) { + update_simplex_tree_after_node_insertion(it); + if constexpr (force_filtration_value){ + //the way create_expansion is used, added_simplices != nullptr when force_filtration_value == true + added_simplices->push_back(it); } } + inter.clear(); + s_h->second.assign_children(new_sib); + if constexpr (force_filtration_value){ + siblings_expansion(new_sib, fil, k - 1, *added_simplices); + } else { + siblings_expansion(new_sib, k - 1); + } + } else { + // ensure the children property + s_h->second.assign_children(siblings); + inter.clear(); } } /** \brief Intersects Dictionary 1 [begin1;end1) with Dictionary 2 [begin2,end2) * and assigns the maximal possible Filtration_value to the Nodes. */ + template static void intersection(std::vector >& intersection, Dictionary_it begin1, Dictionary_it end1, Dictionary_it begin2, Dictionary_it end2, @@ -1408,8 +1694,12 @@ class Simplex_tree { return; // ----->> while (true) { if (begin1->first == begin2->first) { - Filtration_value filt = (std::max)({begin1->second.filtration(), begin2->second.filtration(), filtration_}); - intersection.emplace_back(begin1->first, Node(nullptr, filt)); + if constexpr (force_filtration_value){ + intersection.emplace_back(begin1->first, Node(nullptr, filtration_)); + } else { + Filtration_value filt = (std::max)({begin1->second.filtration(), begin2->second.filtration(), filtration_}); + intersection.emplace_back(begin1->first, Node(nullptr, filt)); + } if (++begin1 == end1 || ++begin2 == end2) return; // ----->> } else if (begin1->first < begin2->first) { @@ -1484,9 +1774,10 @@ class Simplex_tree { } if (intersection.size() != 0) { // Reverse the order to insert - Siblings * new_sib = new Siblings(siblings, // oncles - simplex->first, // parent - boost::adaptors::reverse(intersection)); // boost::container::ordered_unique_range_t + Siblings * new_sib = new Siblings( + siblings, // oncles + simplex->first, // parent + boost::adaptors::reverse(intersection)); // boost::container::ordered_unique_range_t simplex->second.assign_children(new_sib); std::vector blocked_new_sib_vertex_list; // As all intersections are inserted, we can call the blocker function on all new_sib members @@ -1662,16 +1953,38 @@ class Simplex_tree { private: bool rec_prune_above_filtration(Siblings* sib, Filtration_value filt) { auto&& list = sib->members(); - auto last = std::remove_if(list.begin(), list.end(), [this,filt](Dit_value_t& simplex) { - if (simplex.second.filtration() <= filt) return false; - if (has_children(&simplex)) rec_delete(simplex.second.children()); - // dimension may need to be lowered - dimension_to_be_lowered_ = true; - return true; - }); + bool modified = false; + bool emptied = false; + Simplex_handle last; + + auto to_remove = [this, filt](Dit_value_t& simplex) { + if (simplex.second.filtration() <= filt) return false; + if (has_children(&simplex)) rec_delete(simplex.second.children()); + // dimension may need to be lowered + dimension_to_be_lowered_ = true; + return true; + }; - bool modified = (last != list.end()); - if (last == list.begin() && sib != root()) { + //TODO: `if constexpr` replacable by `std::erase_if` in C++20? Has a risk of additional runtime, + //so to benchmark first. + if constexpr (Options::stable_simplex_handles) { + modified = false; + for (auto sh = list.begin(); sh != list.end();) { + if (to_remove(*sh)) { + sh = list.erase(sh); + modified = true; + } else { + ++sh; + } + } + emptied = (list.empty() && sib != root()); + } else { + last = std::remove_if(list.begin(), list.end(), to_remove); + modified = (last != list.end()); + emptied = (last == list.begin() && sib != root()); + } + + if (emptied) { // Removing the whole siblings, parent becomes a leaf. sib->oncles()->members()[sib->parent()].assign_children(sib->oncles()); delete sib; @@ -1680,11 +1993,11 @@ class Simplex_tree { return true; } else { // Keeping some elements of siblings. Remove the others, and recurse in the remaining ones. - list.erase(last, list.end()); + if constexpr (!Options::stable_simplex_handles) list.erase(last, list.end()); for (auto&& simplex : list) - if (has_children(&simplex)) - modified |= rec_prune_above_filtration(simplex.second.children(), filt); + if (has_children(&simplex)) modified |= rec_prune_above_filtration(simplex.second.children(), filt); } + return modified; } @@ -2220,11 +2533,12 @@ class Simplex_tree { } } Vertex_handle child_size; - for (auto& map_el : sib->members()) { + for (auto sh = sib->members().begin(); sh != sib->members().end(); ++sh) { + update_simplex_tree_after_node_insertion(sh); ptr = Gudhi::simplex_tree::deserialize_trivial(child_size, ptr); if (child_size > 0) { - Siblings* child = new Siblings(sib, map_el.first); - map_el.second.assign_children(child); + Siblings* child = new Siblings(sib, sh->first); + sh->second.assign_children(child); ptr = rec_deserialize(child, child_size, ptr, dim + 1); } } diff --git a/src/Simplex_tree/test/CMakeLists.txt b/src/Simplex_tree/test/CMakeLists.txt index e4535e1b7e..5153ca0ccf 100644 --- a/src/Simplex_tree/test/CMakeLists.txt +++ b/src/Simplex_tree/test/CMakeLists.txt @@ -47,8 +47,14 @@ if(TARGET TBB::tbb) endif() gudhi_add_boost_test(Simplex_tree_serialization_test_unit) +add_executable ( Simplex_tree_edge_expansion_unit_test simplex_tree_edge_expansion_unit_test.cpp ) +if(TARGET TBB::tbb) + target_link_libraries(Simplex_tree_edge_expansion_unit_test TBB::tbb) +endif() +gudhi_add_boost_test(Simplex_tree_edge_expansion_unit_test) + add_executable ( Simplex_tree_extended_filtration_test_unit simplex_tree_extended_filtration_unit_test.cpp ) if(TARGET TBB::tbb) target_link_libraries(Simplex_tree_extended_filtration_test_unit TBB::tbb) endif() -gudhi_add_boost_test(Simplex_tree_extended_filtration_test_unit) \ No newline at end of file +gudhi_add_boost_test(Simplex_tree_extended_filtration_test_unit) diff --git a/src/Simplex_tree/test/simplex_tree_edge_expansion_unit_test.cpp b/src/Simplex_tree/test/simplex_tree_edge_expansion_unit_test.cpp new file mode 100644 index 0000000000..ed7d08f391 --- /dev/null +++ b/src/Simplex_tree/test/simplex_tree_edge_expansion_unit_test.cpp @@ -0,0 +1,365 @@ +/* This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + * See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + * Author(s): Hannah Schreiber + * + * Copyright (C) 2023 Inria + * + * Modification(s): + * - YYYY/MM Author: Description of the modification + */ + +#include +#include +#include + +#define BOOST_TEST_DYN_LINK +#define BOOST_TEST_MODULE "simplex_tree_edge_expansion" +#include +#include + +#include "gudhi/Simplex_tree.h" +#include +#include +#include +#include + +using namespace Gudhi; + +struct Simplex_tree_options_stable_simplex_handles { + typedef linear_indexing_tag Indexing_tag; + typedef int Vertex_handle; + typedef double Filtration_value; + typedef std::uint32_t Simplex_key; + static const bool store_key = true; + static const bool store_filtration = true; + static const bool contiguous_vertices = false; + static const bool link_nodes_by_label = true; + static const bool stable_simplex_handles = true; +}; + +using Point = std::vector; + +std::vector build_point_cloud(unsigned int numberOfPoints, int seed){ + std::vector finalPoints; + std::set points; + std::random_device dev; + std::mt19937 rng(dev()); + if (seed > -1) rng.seed(seed); + std::uniform_real_distribution dist(0,10); + + for (unsigned int i = 0; i < numberOfPoints; ++i){ + auto res = points.insert({dist(rng), dist(rng)}); + while(!res.second){ + res = points.insert({dist(rng), dist(rng)}); + } + finalPoints.push_back(*res.first); + } + + return finalPoints; +} + +template +void test_insert_as_flag(ST_type& simplex_tree, int maxDim = -1) { + Simplex_tree st_to_test; + std::vector::Simplex_handle> added_simplices; + + for (auto& sh : simplex_tree.filtration_simplex_range()) { + if (simplex_tree.dimension(sh) == 0) { + auto u = *simplex_tree.simplex_vertex_range(sh).begin(); + st_to_test.insert_edge_as_flag(u, u, simplex_tree.filtration(sh), maxDim, added_simplices); + } else if (simplex_tree.dimension(sh) == 1) { + auto it = simplex_tree.simplex_vertex_range(sh).begin(); + auto u = *it; + auto v = *(++it); + st_to_test.insert_edge_as_flag(u, v, simplex_tree.filtration(sh), maxDim, added_simplices); + } + } + + BOOST_CHECK(added_simplices.size() == simplex_tree.num_simplices()); + BOOST_CHECK(st_to_test.dimension() == simplex_tree.dimension()); + BOOST_CHECK(st_to_test == simplex_tree); +} + +template +void test_unordered_insert_as_flag(ST_type& simplex_tree, int maxDim = -1) { + Simplex_tree st_to_test; + std::vector::Simplex_handle> added_simplices; + + // complex_simplex_range gives a lexicographical order where the word "12" is shorter than "1", + // so vertices have to be inserted separately + for (auto& sh : simplex_tree.skeleton_simplex_range(0)) { + auto u = *simplex_tree.simplex_vertex_range(sh).begin(); + st_to_test.insert_edge_as_flag(u, u, simplex_tree.filtration(sh), maxDim, added_simplices); + } + + for (auto& sh : simplex_tree.complex_simplex_range()) { + if (simplex_tree.dimension(sh) == 1) { + auto it = simplex_tree.simplex_vertex_range(sh).begin(); + auto u = *it; + auto v = *(++it); + st_to_test.insert_edge_as_flag(u, v, simplex_tree.filtration(sh), maxDim, added_simplices); + } + } + // as the insertions are not in the order of the filtration, + // the filtration values for higher dimensional simplices has to be restored + st_to_test.make_filtration_non_decreasing(); + + BOOST_CHECK(added_simplices.size() == simplex_tree.num_simplices()); + BOOST_CHECK(st_to_test.dimension() == simplex_tree.dimension()); + BOOST_CHECK(st_to_test == simplex_tree); +} + +BOOST_AUTO_TEST_CASE(simplex_tree_random_rips_expansion) { + using Filtration_value = typename Simplex_tree<>::Filtration_value; + using Rips_complex = Gudhi::rips_complex::Rips_complex; + + for (unsigned int i = 10; i < 100; i += 5){ + Simplex_tree<> simplex_tree; + Rips_complex rips_complex(build_point_cloud(i, -1), 3, Gudhi::Euclidean_distance()); + rips_complex.create_complex(simplex_tree, 100); + + std::clog << i << " ********************************************************************\n"; + std::clog << "simplex_tree_random_rips_expansion\n"; + std::clog << "Test tree 1: insertion by filtration values\n"; + std::clog << "************************************************************************\n"; + test_insert_as_flag,Simplex_tree_options_fast_cofaces>(simplex_tree); + test_insert_as_flag,Simplex_tree_options_stable_simplex_handles>(simplex_tree); + + std::clog << i << " ********************************************************************\n"; + std::clog << "simplex_tree_random_rips_expansion\n"; + std::clog << "Test tree 2: insertion by lexicographical order\n"; //equivalent to random order + std::clog << "************************************************************************\n"; + test_unordered_insert_as_flag,Simplex_tree_options_fast_cofaces>(simplex_tree); + test_unordered_insert_as_flag,Simplex_tree_options_stable_simplex_handles>(simplex_tree); + } +} + +BOOST_AUTO_TEST_CASE(simplex_tree_random_rips_expansion_with_max_dim) { + using Filtration_value = typename Simplex_tree<>::Filtration_value; + using Rips_complex = Gudhi::rips_complex::Rips_complex; + + for (unsigned int i = 10; i < 100; i += 5){ + int maxDim = i / 5; + Simplex_tree<> simplex_tree; + Rips_complex rips_complex(build_point_cloud(i, -1), 3, Gudhi::Euclidean_distance()); + rips_complex.create_complex(simplex_tree, maxDim); + + std::clog << i << " ********************************************************************\n"; + std::clog << "simplex_tree_random_rips_expansion\n"; + std::clog << "Test tree 3: insertion by filtration values and max dimension = " << maxDim << "\n"; + std::clog << "************************************************************************\n"; + test_insert_as_flag,Simplex_tree_options_fast_cofaces>(simplex_tree, maxDim); + test_insert_as_flag,Simplex_tree_options_stable_simplex_handles>(simplex_tree, maxDim); + + std::clog << i << " ********************************************************************\n"; + std::clog << "simplex_tree_random_rips_expansion\n"; + std::clog << "Test tree 4: insertion by lexicographical order and max dimension = " << maxDim << "\n"; //equivalent to random order + std::clog << "************************************************************************\n"; + test_unordered_insert_as_flag,Simplex_tree_options_fast_cofaces>(simplex_tree, maxDim); + test_unordered_insert_as_flag,Simplex_tree_options_stable_simplex_handles>(simplex_tree, maxDim); + } +} + +BOOST_AUTO_TEST_CASE(flag_expansion) { + using typeST = Simplex_tree; + { + std::cout << "************************************************************************" << std::endl; + std::cout << "Test flag expansion 1" << std::endl; + std::cout << "************************************************************************" << std::endl; + + typeST st; + std::vector added_simplices; + + st.insert_edge_as_flag(0,0,0,6,added_simplices); + st.insert_edge_as_flag(1,1,0,6,added_simplices); + st.insert_edge_as_flag(2,2,0,6,added_simplices); + st.insert_edge_as_flag(3,3,0,6,added_simplices); + st.insert_edge_as_flag(4,4,0,6,added_simplices); + st.insert_edge_as_flag(5,5,0,6,added_simplices); + st.insert_edge_as_flag(6,6,0,6,added_simplices); + st.insert_edge_as_flag(0,1,0,6,added_simplices); + st.insert_edge_as_flag(0,2,0,6,added_simplices); + st.insert_edge_as_flag(0,3,0,6,added_simplices); + st.insert_edge_as_flag(0,4,0,6,added_simplices); + st.insert_edge_as_flag(0,5,0,6,added_simplices); + st.insert_edge_as_flag(0,6,0,6,added_simplices); + st.insert_edge_as_flag(1,2,0,6,added_simplices); + st.insert_edge_as_flag(1,3,0,6,added_simplices); + st.insert_edge_as_flag(1,4,0,6,added_simplices); + st.insert_edge_as_flag(1,5,0,6,added_simplices); + st.insert_edge_as_flag(1,6,0,6,added_simplices); + st.insert_edge_as_flag(2,3,0,6,added_simplices); + st.insert_edge_as_flag(2,4,0,6,added_simplices); + st.insert_edge_as_flag(2,5,0,6,added_simplices); + st.insert_edge_as_flag(2,6,0,6,added_simplices); + st.insert_edge_as_flag(3,4,0,6,added_simplices); + st.insert_edge_as_flag(3,5,0,6,added_simplices); + st.insert_edge_as_flag(3,6,0,6,added_simplices); + st.insert_edge_as_flag(4,5,0,6,added_simplices); + st.insert_edge_as_flag(4,6,0,6,added_simplices); + st.insert_edge_as_flag(5,6,0,6,added_simplices); + + st.assign_filtration(st.find({0,2,4}), 10); + st.assign_filtration(st.find({1,5}), 20); + st.assign_filtration(st.find({1,2,4}), 30); + st.assign_filtration(st.find({3}), 5); + st.make_filtration_non_decreasing(); + + BOOST_CHECK_EQUAL(added_simplices.size(), 127); + BOOST_CHECK(st.filtration(st.find({1,2}))==0); + BOOST_CHECK(st.filtration(st.find({0,1,2,3,4}))==30); + BOOST_CHECK(st.minimal_simplex_with_same_filtration(st.find({0,1,2,3,4,5}))==st.find({1,2,4})); + BOOST_CHECK(st.minimal_simplex_with_same_filtration(st.find({0,2,3}))==st.find({3})); + auto s=st.minimal_simplex_with_same_filtration(st.find({0,2,6})); + BOOST_CHECK(s==st.find({0})||s==st.find({2})||s==st.find({6})); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({2}))==2); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({1,5}))==st.null_vertex()); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({5,6}))>=5); + } + { + std::cout << "************************************************************************" << std::endl; + std::cout << "Test flag expansion 2" << std::endl; + std::cout << "************************************************************************" << std::endl; + + typeST st; + std::vector added_simplices; + + st.insert_edge_as_flag(0,0,0,6,added_simplices); + st.insert_edge_as_flag(1,1,0,6,added_simplices); + st.insert_edge_as_flag(2,2,0,6,added_simplices); + st.insert_edge_as_flag(4,4,0,6,added_simplices); + st.insert_edge_as_flag(5,5,0,6,added_simplices); + st.insert_edge_as_flag(6,6,0,6,added_simplices); + + st.insert_edge_as_flag(0,1,0,6,added_simplices); + st.insert_edge_as_flag(0,5,0,6,added_simplices); + st.insert_edge_as_flag(0,6,0,6,added_simplices); + st.insert_edge_as_flag(1,6,0,6,added_simplices); + st.insert_edge_as_flag(2,5,0,6,added_simplices); + st.insert_edge_as_flag(2,6,0,6,added_simplices); + st.insert_edge_as_flag(4,5,0,6,added_simplices); + st.insert_edge_as_flag(4,6,0,6,added_simplices); + st.insert_edge_as_flag(5,6,0,6,added_simplices); + + st.insert_edge_as_flag(3,3,5,6,added_simplices); + st.insert_edge_as_flag(0,3,5,6,added_simplices); + st.insert_edge_as_flag(1,3,5,6,added_simplices); + st.insert_edge_as_flag(2,3,5,6,added_simplices); + st.insert_edge_as_flag(3,4,5,6,added_simplices); + st.insert_edge_as_flag(3,5,5,6,added_simplices); + st.insert_edge_as_flag(3,6,5,6,added_simplices); + + st.insert_edge_as_flag(0,2,10,6,added_simplices); + st.insert_edge_as_flag(0,4,10,6,added_simplices); + st.insert_edge_as_flag(1,5,20,6,added_simplices); + st.insert_edge_as_flag(1,2,30,6,added_simplices); + st.insert_edge_as_flag(1,4,30,6,added_simplices); + st.insert_edge_as_flag(2,4,30,6,added_simplices); + + BOOST_CHECK_EQUAL(added_simplices.size(), 127); + BOOST_CHECK(st.filtration(st.find({1,2}))==30); + BOOST_CHECK(st.filtration(st.find({0,1,2,3,4}))==30); + BOOST_CHECK(st.minimal_simplex_with_same_filtration(st.find({0,1,2,3,4,5}))==st.find({1,2})); + BOOST_CHECK(st.minimal_simplex_with_same_filtration(st.find({0,2,3}))==st.find({0,2})); + auto s=st.minimal_simplex_with_same_filtration(st.find({0,2,6})); + BOOST_CHECK(s==st.find({0,2})); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({2}))==2); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({1,5}))==st.null_vertex()); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({5,6}))>=5); + } + { + std::cout << "************************************************************************" << std::endl; + std::cout << "Test flag expansion 3" << std::endl; + std::cout << "************************************************************************" << std::endl; + + typeST st; + std::vector added_simplices; + + st.insert_edge_as_flag(0,0,0,6,added_simplices); + + st.insert_edge_as_flag(1,1,0,6,added_simplices); + st.insert_edge_as_flag(0,1,0,6,added_simplices); + + st.insert_edge_as_flag(2,2,0,6,added_simplices); + st.insert_edge_as_flag(0,2,10,6,added_simplices); + st.insert_edge_as_flag(1,2,30,6,added_simplices); + + st.insert_edge_as_flag(3,3,5,6,added_simplices); + st.insert_edge_as_flag(0,3,5,6,added_simplices); + st.insert_edge_as_flag(1,3,5,6,added_simplices); + st.insert_edge_as_flag(2,3,5,6,added_simplices); + + st.insert_edge_as_flag(4,4,0,6,added_simplices); + st.insert_edge_as_flag(0,4,10,6,added_simplices); + st.insert_edge_as_flag(1,4,30,6,added_simplices); + st.insert_edge_as_flag(2,4,30,6,added_simplices); + st.insert_edge_as_flag(3,4,5,6,added_simplices); + + st.insert_edge_as_flag(5,5,0,6,added_simplices); + st.insert_edge_as_flag(0,5,0,6,added_simplices); + st.insert_edge_as_flag(1,5,20,6,added_simplices); + st.insert_edge_as_flag(2,5,0,6,added_simplices); + st.insert_edge_as_flag(3,5,5,6,added_simplices); + st.insert_edge_as_flag(4,5,0,6,added_simplices); + + st.insert_edge_as_flag(6,6,0,6,added_simplices); + st.insert_edge_as_flag(0,6,0,6,added_simplices); + st.insert_edge_as_flag(1,6,0,6,added_simplices); + st.insert_edge_as_flag(2,6,0,6,added_simplices); + st.insert_edge_as_flag(3,6,5,6,added_simplices); + st.insert_edge_as_flag(4,6,0,6,added_simplices); + st.insert_edge_as_flag(5,6,0,6,added_simplices); + + st.make_filtration_non_decreasing(); + + BOOST_CHECK_EQUAL(added_simplices.size(), 127); + BOOST_CHECK(st.filtration(st.find({1,2}))==30); + BOOST_CHECK(st.filtration(st.find({0,1,2,3,4}))==30); + BOOST_CHECK(st.minimal_simplex_with_same_filtration(st.find({0,1,2,3,4,5}))==st.find({1,2})); + BOOST_CHECK(st.minimal_simplex_with_same_filtration(st.find({0,2,3}))==st.find({0,2})); + auto s=st.minimal_simplex_with_same_filtration(st.find({0,2,6})); + BOOST_CHECK(s==st.find({0,2})); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({2}))==2); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({1,5}))==st.null_vertex()); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({5,6}))>=5); + } + { + std::cout << "************************************************************************" << std::endl; + std::cout << "Test flag expansion 4" << std::endl; + std::cout << "************************************************************************" << std::endl; + + typeST st; + std::vector added_simplices; + + st.insert_edge_as_flag(2,2,2,50,added_simplices); + st.insert_edge_as_flag(5,5,2,50,added_simplices); + st.insert_edge_as_flag(2,5,2,50,added_simplices); + + st.insert_edge_as_flag(0,0,3,50,added_simplices); + st.insert_edge_as_flag(0,5,3,50,added_simplices); + + st.insert_edge_as_flag(1,1,4,50,added_simplices); + st.insert_edge_as_flag(1,5,4,50,added_simplices); + + st.insert_edge_as_flag(1,2,5,50,added_simplices); + + st.insert_edge_as_flag(3,3,6,50,added_simplices); + st.insert_edge_as_flag(4,4,6,50,added_simplices); + st.insert_edge_as_flag(3,4,6,50,added_simplices); + + st.insert_edge_as_flag(0,1,8,50,added_simplices); + + st.insert_edge_as_flag(1,3,9,50,added_simplices); + + st.insert_edge_as_flag(0,2,10,50,added_simplices); + + BOOST_CHECK_EQUAL(added_simplices.size(), 19); + BOOST_CHECK(st.edge_with_same_filtration(st.find({0,1,2,5}))==st.find({0,2})); + BOOST_CHECK(st.edge_with_same_filtration(st.find({1,5}))==st.find({1,5})); + } +} + + + + + diff --git a/src/Simplex_tree/test/simplex_tree_serialization_unit_test.cpp b/src/Simplex_tree/test/simplex_tree_serialization_unit_test.cpp index c18e522ec9..451418dfcc 100644 --- a/src/Simplex_tree/test/simplex_tree_serialization_unit_test.cpp +++ b/src/Simplex_tree/test/simplex_tree_serialization_unit_test.cpp @@ -279,3 +279,44 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(simplex_tree_non_empty_deserialize_throw, Stree, l } #endif +BOOST_AUTO_TEST_CASE_TEMPLATE(simplex_tree_serialization_and_cofaces, Stree, list_of_tested_variants) { + std::clog << "********************************************************************" << std::endl; + std::clog << "BASIC SIMPLEX TREE SERIALIZATION/DESERIALIZATION" << std::endl; + Stree st; + + st.insert_simplex_and_subfaces({2, 1, 0}); + st.insert_simplex_and_subfaces({3, 0}); + st.insert_simplex_and_subfaces({3, 4, 5}); + st.insert_simplex_and_subfaces({0, 1, 6, 7}); + /* Inserted simplex: */ + /* 1 6 */ + /* o---o */ + /* /X\7/ */ + /* o---o---o---o */ + /* 2 0 3\X/4 */ + /* o */ + /* 5 */ + + const std::size_t stree_buffer_size = st.get_serialization_size(); + std::clog << "Serialization (from the simplex tree) size in bytes = " << stree_buffer_size << std::endl; + char* stree_buffer = new char[stree_buffer_size]; + + st.serialize(stree_buffer, stree_buffer_size); + + Stree st_from_buffer; + st_from_buffer.deserialize(stree_buffer, stree_buffer_size); + delete[] stree_buffer; + + int num_stars = 0; + for (auto coface : st_from_buffer.star_simplex_range(st.find({1, 0}))) { + std::clog << "coface"; + for (auto vertex : st_from_buffer.simplex_vertex_range(coface)) { + std::clog << " " << vertex; + } + std::clog << "\n"; + num_stars++; + } + // [([0, 1], 0.0), ([0, 1, 2], 0.0), ([0, 1, 6], 0.0), ([0, 1, 6, 7], 0.0), ([0, 1, 7], 0.0)] + BOOST_CHECK(num_stars == 5); + +} diff --git a/src/Tangential_complex/include/gudhi/Tangential_complex.h b/src/Tangential_complex/include/gudhi/Tangential_complex.h index fb2a6ceead..4031459f02 100644 --- a/src/Tangential_complex/include/gudhi/Tangential_complex.h +++ b/src/Tangential_complex/include/gudhi/Tangential_complex.h @@ -1580,7 +1580,7 @@ class Tangential_complex { typename K::Scaled_vector_d k_scaled_vec = m_k.scaled_vector_d_object(); CGAL::Random_points_in_ball_d tr_point_in_ball_generator( - m_intrinsic_dim, m_random_generator.get_double(0., max_perturb)); + m_intrinsic_dim, CGAL::get_default_random().get_double(0., max_perturb)); Tr_point local_random_transl = local_tr_traits.construct_weighted_point_d_object()(*tr_point_in_ball_generator++, 0); @@ -2013,8 +2013,6 @@ class Tangential_complex { Points m_points_for_tse; Points_ds m_points_ds_for_tse; #endif - - mutable CGAL::Random m_random_generator; }; // /class Tangential_complex } // end namespace tangential_complex diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index 5041629783..71fc07b6d1 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -10,6 +10,7 @@ # - 2021/11 Vincent Rouvreau: factorize _automatic_sample_range import numpy as np +from scipy.spatial.distance import cdist from sklearn.base import BaseEstimator, TransformerMixin from sklearn.exceptions import NotFittedError from sklearn.preprocessing import MinMaxScaler, MaxAbsScaler @@ -681,16 +682,19 @@ def __call__(self, diag): return self.transform([diag])[0,:] def _lapl_contrast(measure, centers, inertias): - """contrast function for vectorising `measure` in ATOL""" - return np.exp(-pairwise.pairwise_distances(measure, Y=centers) / inertias) + """contrast function for vectorising `measure` in ATOL + we use cdist so as to accept 'inf' values instead of raising ValueError with sklearn.pairwise""" + return np.exp(-cdist(XA=measure, XB=centers) / inertias) def _gaus_contrast(measure, centers, inertias): - """contrast function for vectorising `measure` in ATOL""" - return np.exp(-pairwise.pairwise_distances(measure, Y=centers, squared=True) / inertias**2) + """contrast function for vectorising `measure` in ATOL + we use cdist so as to accept 'inf' values instead of raising ValueError with sklearn.pairwise""" + return np.exp(-cdist(XA=measure, XB=centers, metric="sqeuclidean") / inertias**2) def _indicator_contrast(diags, centers, inertias): - """contrast function for vectorising `measure` in ATOL""" - robe_curve = np.clip(2-pairwise.pairwise_distances(diags, Y=centers)/inertias, 0, 1) + """contrast function for vectorising `measure` in ATOL + we use cdist so as to accept 'inf' values instead of raising ValueError with sklearn.pairwise""" + robe_curve = np.clip(2-cdist(XA=diags, XB=centers)/inertias, 0, 1) return robe_curve def _cloud_weighting(measure): @@ -745,15 +749,21 @@ def __init__(self, quantiser, weighting_method="cloud", contrast="gaussian"): (default: gaussian contrast function, see page 3 in the ATOL paper). """ self.quantiser = quantiser - self.contrast = { + self.contrast = contrast + self.weighting_method = weighting_method + + def get_contrast(self): + return { "gaussian": _gaus_contrast, "laplacian": _lapl_contrast, "indicator": _indicator_contrast, - }.get(contrast, _gaus_contrast) - self.weighting_method = { + }.get(self.contrast, _gaus_contrast) + + def get_weighting_method(self): + return { "cloud" : _cloud_weighting, "iidproba": _iidproba_weighting, - }.get(weighting_method, _cloud_weighting) + }.get(self.weighting_method, _cloud_weighting) def fit(self, X, y=None, sample_weight=None): """ @@ -771,18 +781,24 @@ def fit(self, X, y=None, sample_weight=None): """ if not hasattr(self.quantiser, 'fit'): raise TypeError("quantiser %s has no `fit` attribute." % (self.quantiser)) + + # In fitting we remove infinite death time points so that every center is finite + X = [dgm[~np.isinf(dgm).any(axis=1), :] for dgm in X] + if sample_weight is None: - sample_weight = np.concatenate([self.weighting_method(measure) for measure in X]) + sample_weight = [self.get_weighting_method()(measure) for measure in X] measures_concat = np.concatenate(X) - self.quantiser.fit(X=measures_concat, sample_weight=sample_weight) + weights_concat = np.concatenate(sample_weight) + self.quantiser.fit(X=measures_concat, sample_weight=weights_concat) self.centers = self.quantiser.cluster_centers_ # Hack, but some people are unhappy if the order depends on the version of sklearn self.centers = self.centers[np.lexsort(self.centers.T)] if self.quantiser.n_clusters == 1: dist_centers = pairwise.pairwise_distances(measures_concat) np.fill_diagonal(dist_centers, 0) - self.inertias = np.array([np.max(dist_centers)/2]) + best_inertia = np.max(dist_centers)/2 if np.max(dist_centers)/2 > 0 else 1 + self.inertias = np.array([best_inertia]) else: dist_centers = pairwise.pairwise_distances(self.centers) dist_centers[dist_centers == 0] = np.inf @@ -800,8 +816,8 @@ def __call__(self, measure, sample_weight=None): numpy array in R^self.quantiser.n_clusters. """ if sample_weight is None: - sample_weight = self.weighting_method(measure) - return np.sum(sample_weight * self.contrast(measure, self.centers, self.inertias.T).T, axis=1) + sample_weight = self.get_weighting_method()(measure) + return np.sum(sample_weight * self.get_contrast()(measure, self.centers, self.inertias.T).T, axis=1) def transform(self, X, sample_weight=None): """ @@ -817,5 +833,5 @@ def transform(self, X, sample_weight=None): numpy array with shape (number of measures) x (self.quantiser.n_clusters). """ if sample_weight is None: - sample_weight = [self.weighting_method(measure) for measure in X] + sample_weight = [self.get_weighting_method()(measure) for measure in X] return np.stack([self(measure, sample_weight=weight) for measure, weight in zip(X, sample_weight)])