Skip to content

Commit

Permalink
Separate numerator / denominator in ASCI scores to allow for reuse in…
Browse files Browse the repository at this point in the history
… PT2 calculation
  • Loading branch information
wavefunction91 committed Aug 16, 2023
1 parent c103c74 commit 1bf63c8
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 24 deletions.
17 changes: 10 additions & 7 deletions include/macis/asci/determinant_contributions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ namespace macis {
template <typename WfnT>
struct asci_contrib {
WfnT state;
double rv;
double c_times_matel;
double h_diag;

auto rv() const { return c_times_matel / h_diag; }
};

template <typename WfnT>
Expand Down Expand Up @@ -57,10 +60,10 @@ void append_singles_asci_contributions(
// Calculate fast diagonal matrix element
auto h_diag =
ham_gen.fast_diag_single(eps_same[i], eps_same[a], i, a, root_diag);
h_el /= (E0 - h_diag);
//h_el /= (E0 - h_diag);

// Append to return values
asci_contributions.push_back({ex_det, coeff * h_el});
asci_contributions.push_back({ex_det, coeff * h_el, E0 - h_diag});

} // Loop over single extitations
}
Expand Down Expand Up @@ -110,10 +113,10 @@ void append_ss_doubles_asci_contributions(
auto h_diag =
ham_gen.fast_diag_ss_double(eps_same[i], eps_same[j], eps_same[a],
eps_same[b], i, j, a, b, root_diag);
h_el /= (E0 - h_diag);
//h_el /= (E0 - h_diag);

// Append {det, c*h_el}
asci_contributions.push_back({ex_det, coeff * h_el});
asci_contributions.push_back({ex_det, coeff * h_el, E0 - h_diag});

} // Restricted BJ loop
} // AI Loop
Expand Down Expand Up @@ -152,9 +155,9 @@ void append_os_doubles_asci_contributions(
auto h_diag = ham_gen.fast_diag_os_double(eps_alpha[i], eps_beta[j],
eps_alpha[a], eps_beta[b],
i, j, a, b, root_diag);
h_el /= (E0 - h_diag);
//h_el /= (E0 - h_diag);

asci_contributions.push_back({ex_det, coeff * h_el});
asci_contributions.push_back({ex_det, coeff * h_el, E0 - h_diag});
} // BJ loop
} // AI loop
}
Expand Down
19 changes: 11 additions & 8 deletions include/macis/asci/determinant_search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ template <typename WfnT>
struct asci_contrib_topk_comparator {
using type = asci_contrib<WfnT>;
constexpr bool operator()(const type& a, const type& b) const {
return std::abs(a.rv) > std::abs(b.rv);
return std::abs(a.rv()) > std::abs(b.rv());
}
};

Expand Down Expand Up @@ -122,7 +122,7 @@ asci_contrib_container<wfn_t<N>> asci_contributions_standard(
// Remove small contributions
auto it = std::partition(
asci_pairs.begin(), asci_pairs.end(), [=](const auto& x) {
return std::abs(x.rv) > asci_settings.rv_prune_tol;
return std::abs(x.rv()) > asci_settings.rv_prune_tol;
});
asci_pairs.erase(it, asci_pairs.end());
logger->info(" * Pruning at DET = {} NSZ = {}", i, asci_pairs.size());
Expand Down Expand Up @@ -356,7 +356,7 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
// Remove small contributions
auto it = std::partition(
asci_pairs.begin(), asci_pairs.end(), [=](const auto& x) {
return std::abs(x.rv) > asci_settings.rv_prune_tol;
return std::abs(x.rv()) > asci_settings.rv_prune_tol;
});
asci_pairs.erase(it, asci_pairs.end());

Expand Down Expand Up @@ -527,7 +527,10 @@ std::vector<wfn_t<N>> asci_search(

auto keep_large_st = clock_type::now();
// Finalize scores
for(auto& x : asci_pairs) x.rv = -std::abs(x.rv);
for(auto& x : asci_pairs) {
x.c_times_matel = -std::abs(x.c_times_matel);
x.h_diag = std::abs(x.h_diag);
}

// Insert all dets with their coefficients as seeds
for(size_t i = 0; i < ncdets; ++i) {
Expand All @@ -540,7 +543,7 @@ std::vector<wfn_t<N>> asci_search(
keep_only_largest_copy_asci_pairs(asci_pairs);

asci_pairs.erase(std::partition(asci_pairs.begin(), asci_pairs.end(),
[](const auto& p) { return p.rv < 0.0; }),
[](const auto& p) { return p.rv() < 0.0; }),
asci_pairs.end());

// Only do top-K on (ndets_max - ncdets) b/c CDETS will be added later
Expand Down Expand Up @@ -569,7 +572,7 @@ std::vector<wfn_t<N>> asci_search(
// Strip scores
std::vector<double> scores(asci_pairs.size());
std::transform(asci_pairs.begin(), asci_pairs.end(), scores.begin(),
[](const auto& p) { return std::abs(p.rv); });
[](const auto& p) { return std::abs(p.rv()); });

// Determine kth-ranked scores
auto kth_score =
Expand All @@ -579,8 +582,8 @@ std::vector<wfn_t<N>> asci_search(
// Partition local pairs into less / eq batches
auto [g_begin, e_begin, l_begin, _end] = leg_partition(
asci_pairs.begin(), asci_pairs.end(), kth_score,
[=](const auto& p, const auto& s) { return std::abs(p.rv) > s; },
[=](const auto& p, const auto& s) { return std::abs(p.rv) == s; });
[=](const auto& p, const auto& s) { return std::abs(p.rv()) > s; },
[=](const auto& p, const auto& s) { return std::abs(p.rv()) == s; });

// Determine local counts
size_t n_greater = std::distance(g_begin, e_begin);
Expand Down
6 changes: 3 additions & 3 deletions include/macis/asci/determinant_sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ PairIterator sort_and_accumulate_asci_pairs(PairIterator pairs_begin,

// Accumulate
else {
cur_it->rv += it->rv;
it->rv = 0; // Zero out to expose potential bugs
cur_it->c_times_matel += it->c_times_matel;
it->c_times_matel = 0; // Zero out to expose potential bugs
}
}

Expand Down Expand Up @@ -108,7 +108,7 @@ void keep_only_largest_copy_asci_pairs(

// Keep only max value
else {
cur_it->rv = std::max(cur_it->rv, it->rv);
cur_it->c_times_matel = std::max(cur_it->c_times_matel, it->c_times_matel);
}
}

Expand Down
12 changes: 6 additions & 6 deletions include/macis/asci/mask_constraints.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,9 @@ void generate_constraint_singles_contributions_ss(

// Compute Fast Diagonal Matrix Element
auto h_diag = ham_gen.fast_diag_single(eps[i], eps[a], i, a, root_diag);
h_el /= (E0 - h_diag);
//h_el /= (E0 - h_diag);

asci_contributions.push_back({ex_det, coeff * h_el});
asci_contributions.push_back({ex_det, coeff * h_el, E0 - h_diag});
}
}
}
Expand Down Expand Up @@ -336,9 +336,9 @@ void generate_constraint_doubles_contributions_ss(
// Evaluate fast diagonal matrix element
auto h_diag = ham_gen.fast_diag_ss_double(eps[i], eps[j], eps[a], eps[b],
i, j, a, b, root_diag);
h_el /= (E0 - h_diag);
//h_el /= (E0 - h_diag);

asci_contributions.push_back({full_ex, coeff * h_el});
asci_contributions.push_back({full_ex, coeff * h_el, E0 - h_diag});
}
}
}
Expand Down Expand Up @@ -394,9 +394,9 @@ void generate_constraint_doubles_contributions_os(
auto h_diag =
ham_gen.fast_diag_os_double(eps_same[i], eps_othr[j], eps_same[a],
eps_othr[b], i, j, a, b, root_diag);
h_el /= (E0 - h_diag);
//h_el /= (E0 - h_diag);

asci_contributions.push_back({ex_det, coeff * h_el});
asci_contributions.push_back({ex_det, coeff * h_el, E0 - h_diag});
} // BJ

} // A
Expand Down
1 change: 1 addition & 0 deletions tests/asci.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ TEST_CASE("ASCI") {
asci_settings, mcscf_settings, E0, std::move(dets), std::move(C), ham_gen,
norb MACIS_MPI_CODE(, MPI_COMM_WORLD));

std::cout << E0 - -8.542926243842e+01 << std::endl;
REQUIRE(E0 == Approx(-8.542926243842e+01));
REQUIRE(dets.size() == 10000);
REQUIRE(C.size() == 10000);
Expand Down

0 comments on commit 1bf63c8

Please sign in to comment.