Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature / use span instead of reference wrapper for cache sequence #824

Merged
merged 5 commits into from
Nov 13, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,9 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis

// update all components
template <cache_type_c CacheType>
void
update_component(ConstDataset const& update_data, Idx pos,
std::array<std::reference_wrapper<std::vector<Idx2D> const>, n_types> const& sequence_idx_map) {
void update_component(ConstDataset const& update_data, Idx pos, SequenceIdxView const& sequence_idx_map) {
run_functor_with_all_types_return_void([this, pos, &update_data, &sequence_idx_map]<typename CT>() {
this->update_component<CT, CacheType>(update_data, pos,
std::get<index_of_component<CT>>(sequence_idx_map).get());
this->update_component<CT, CacheType>(update_data, pos, std::get<index_of_component<CT>>(sequence_idx_map));
});
}
template <cache_type_c CacheType>
Expand Down Expand Up @@ -435,7 +432,7 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
ComponentFlags const& to_store) const {
// TODO: (jguo) this function could be encapsulated in UpdateCompIndependence in update.hpp
return run_functor_with_all_types_return_array([this, scenario_idx, &update_data, &to_store]<typename CT>() {
if (!to_store[index_of_component<CT>]) {
if (!std::get<index_of_component<CT>>(to_store)) {
figueroa1395 marked this conversation as resolved.
Show resolved Hide resolved
return std::vector<Idx2D>{};
}
auto const independence = check_components_independence<CT>(update_data);
Expand All @@ -445,15 +442,7 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}
// get sequence idx map of an entire batch for fast caching of component sequences
SequenceIdx get_sequence_idx_map(ConstDataset const& update_data, ComponentFlags const& to_store) const {
// TODO: (jguo) this function could be encapsulated in UpdateCompIndependence in update.hpp
return run_functor_with_all_types_return_array([this, &update_data, &to_store]<typename CT>() {
if (!to_store[index_of_component<CT>]) {
return std::vector<Idx2D>{};
}
auto const independence = check_components_independence<CT>(update_data);
validate_update_data_independence(independence);
return get_component_sequence<CT>(update_data, 0, independence);
});
return get_sequence_idx_map(update_data, 0, to_store);
figueroa1395 marked this conversation as resolved.
Show resolved Hide resolved
}
SequenceIdx get_sequence_idx_map(ConstDataset const& update_data) const {
constexpr ComponentFlags all_true = [] {
Expand Down Expand Up @@ -640,15 +629,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
auto model = copy_model_functor(start);

SequenceIdx current_scenario_sequence_cache = SequenceIdx{};
std::array<std::reference_wrapper<std::vector<Idx2D> const>, n_types> const current_scenario_sequence =
run_functor_with_all_types_return_array(
[&is_independent, &all_scenarios_sequence, &current_scenario_sequence_cache]<typename CT>() {
constexpr auto comp_idx = index_of_component<CT>;
return is_independent[comp_idx] ? std::cref(all_scenarios_sequence[comp_idx])
: std::cref(current_scenario_sequence_cache[comp_idx]);
});
auto [setup, winddown] = scenario_update_restore(model, update_data, current_scenario_sequence,
current_scenario_sequence_cache, is_independent, infos);
auto [setup, winddown] = scenario_update_restore(model, update_data, is_independent, all_scenarios_sequence,
current_scenario_sequence_cache, infos);

auto calculate_scenario = MainModelImpl::call_with<Idx>(
[&model, &calculation_fn, &result_data, &infos](Idx scenario_idx) {
Expand Down Expand Up @@ -720,28 +702,41 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
};
}

static auto scenario_update_restore(
MainModelImpl& model, ConstDataset const& update_data,
std::array<std::reference_wrapper<std::vector<Idx2D> const>, n_types> const& scenario_sequence,
SequenceIdx& current_scenario_sequence_cache, ComponentFlags const& is_independent,
std::vector<CalculationInfo>& infos) noexcept {
static auto scenario_update_restore(MainModelImpl& model, ConstDataset const& update_data,
ComponentFlags const& is_independent, SequenceIdx const& all_scenario_sequence,
SequenceIdx& current_scenario_sequence_cache,
std::vector<CalculationInfo>& infos) noexcept {
auto do_update_cache = [&is_independent] {
ComponentFlags result;
std::ranges::transform(is_independent, result.begin(), std::logical_not<>{});
return result;
}();

auto const scenario_sequence = [&all_scenario_sequence, &current_scenario_sequence_cache,
&is_independent]() -> SequenceIdxView {
return run_functor_with_all_types_return_array(
[&all_scenario_sequence, &current_scenario_sequence_cache, &is_independent]<typename CT>() {
constexpr auto comp_idx = index_of_component<CT>;
if (std::get<comp_idx>(is_independent)) {
return std::span<Idx2D const>{std::get<comp_idx>(all_scenario_sequence)};
}
return std::span<Idx2D const>{std::get<comp_idx>(current_scenario_sequence_cache)};
});
};

figueroa1395 marked this conversation as resolved.
Show resolved Hide resolved
return std::make_pair(
[&model, &update_data, &scenario_sequence, &current_scenario_sequence_cache,
[&model, &update_data, scenario_sequence, &current_scenario_sequence_cache,
do_update_cache_ = std::move(do_update_cache), &infos](Idx scenario_idx) {
Timer const t_update_model(infos[scenario_idx], 1200, "Update model");
current_scenario_sequence_cache =
model.get_sequence_idx_map(update_data, scenario_idx, do_update_cache_);
model.template update_component<cached_update_t>(update_data, scenario_idx, scenario_sequence);

model.template update_component<cached_update_t>(update_data, scenario_idx, scenario_sequence());
},
[&model, &scenario_sequence, &current_scenario_sequence_cache, &infos](Idx scenario_idx) {
[&model, scenario_sequence, &current_scenario_sequence_cache, &infos](Idx scenario_idx) {
Timer const t_update_model(infos[scenario_idx], 1201, "Restore model");
model.restore_components(scenario_sequence);

model.restore_components(scenario_sequence());
std::ranges::for_each(current_scenario_sequence_cache,
[](auto& comp_seq_idx) { comp_seq_idx.clear(); });
});
Expand Down
Loading