Skip to content

Commit

Permalink
Add operator[] for MarginalDistribution (#487)
Browse files Browse the repository at this point in the history
This change is in support of the PIC implementation.
  • Loading branch information
peddie authored Aug 2, 2024
1 parent 5265add commit 7e27b1f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
11 changes: 11 additions & 0 deletions include/albatross/src/core/distribution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ inline bool operator==(const albatross::DiagonalMatrixXd &x,

namespace albatross {

struct MarginalDistribution;

constexpr double cDefaultApproximatelyEqualEpsilon = 1e-3;

template <typename Derived> struct DistributionBase {
Expand Down Expand Up @@ -48,6 +50,8 @@ template <typename Derived> struct DistributionBase {
return derived().get_diagonal(i);
}

MarginalDistribution operator[](std::size_t index) const;

Eigen::VectorXd mean;
std::map<std::string, std::string> metadata;

Expand Down Expand Up @@ -245,6 +249,13 @@ inline void set_subset(const DistributionBase<DistributionType> &from,
to->derived().set_subset(from, indices);
}

template <typename Derived>
MarginalDistribution
DistributionBase<Derived>::operator[](std::size_t index) const {
return MarginalDistribution(mean[cast::to_index(index)],
get_diagonal(cast::to_index(index)));
}

inline MarginalDistribution
concatenate_marginals(const MarginalDistribution &x,
const MarginalDistribution &y) {
Expand Down
15 changes: 14 additions & 1 deletion tests/test_core_distribution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,27 @@ TYPED_TEST_P(DistributionTest, test_subtract) {
EXPECT_EQ(actual, expected);
};

TYPED_TEST_P(DistributionTest, test_operator_indexing) {

TypeParam test_case;
const auto dist = test_case.create();

for (std::size_t idx = 0; idx < dist.size(); ++idx) {
MarginalDistribution m = dist[idx];

EXPECT_EQ(m.mean[0], dist.mean[cast::to_index(idx)]);
EXPECT_EQ(m.get_diagonal(0), dist.get_diagonal(cast::to_index(idx)));
}
};

REGISTER_TYPED_TEST_SUITE_P(DistributionTest, test_subset,
test_multiply_with_matrix_joint,
test_multiply_with_matrix_marginal,
test_multiply_with_sparse_matrix_joint,
test_multiply_with_sparse_matrix_marginal,
test_multiply_with_vector, test_multiply_by_scalar,
test_equal, test_approximately_equal, test_add,
test_subtract);
test_subtract, test_operator_indexing);

Eigen::VectorXd arange(Eigen::Index k = 5) {
Eigen::VectorXd mean(k);
Expand Down

0 comments on commit 7e27b1f

Please sign in to comment.