Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature / use span instead of reference wrapper for cache sequence #824

Merged
merged 5 commits into from
Nov 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,9 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis

// update all components
template <cache_type_c CacheType>
void
update_component(ConstDataset const& update_data, Idx pos,
std::array<std::reference_wrapper<std::vector<Idx2D> const>, n_types> const& sequence_idx_map) {
void update_component(ConstDataset const& update_data, Idx pos, SequenceIdxView const& sequence_idx_map) {
run_functor_with_all_types_return_void([this, pos, &update_data, &sequence_idx_map]<typename CT>() {
this->update_component<CT, CacheType>(update_data, pos,
std::get<index_of_component<CT>>(sequence_idx_map).get());
this->update_component<CT, CacheType>(update_data, pos, std::get<index_of_component<CT>>(sequence_idx_map));
});
}
template <cache_type_c CacheType>
Expand Down Expand Up @@ -432,36 +429,28 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
return get_sequence(buffer_span);
}
SequenceIdx get_sequence_idx_map(ConstDataset const& update_data, Idx scenario_idx,
ComponentFlags const& to_store) const {
// TODO: (jguo) this function could be encapsulated in UpdateCompIndependence in update.hpp
return run_functor_with_all_types_return_array([this, scenario_idx, &update_data, &to_store]<typename CT>() {
if (!to_store[index_of_component<CT>]) {
return std::vector<Idx2D>{};
}
auto const independence = check_components_independence<CT>(update_data);
validate_update_data_independence(independence);
return get_component_sequence<CT>(update_data, scenario_idx, independence);
});
}
// get sequence idx map of an entire batch for fast caching of component sequences
SequenceIdx get_sequence_idx_map(ConstDataset const& update_data, ComponentFlags const& to_store) const {
ComponentFlags const& components_to_store) const {
// TODO: (jguo) this function could be encapsulated in UpdateCompIndependence in update.hpp
return run_functor_with_all_types_return_array([this, &update_data, &to_store]<typename CT>() {
if (!to_store[index_of_component<CT>]) {
return std::vector<Idx2D>{};
}
auto const independence = check_components_independence<CT>(update_data);
validate_update_data_independence(independence);
return get_component_sequence<CT>(update_data, 0, independence);
});
return run_functor_with_all_types_return_array(
[this, scenario_idx, &update_data, &components_to_store]<typename CT>() {
if (!std::get<index_of_component<CT>>(components_to_store)) {
return std::vector<Idx2D>{};
}
auto const independence = check_components_independence<CT>(update_data);
validate_update_data_independence(independence);
return get_component_sequence<CT>(update_data, scenario_idx, independence);
});
}
// Get sequence idx map of an entire batch for fast caching of component sequences.
// The sequence idx map of the batch is the same as that of the first scenario in the batch (assuming homogeneity)
// This is the entry point for permanent updates.
SequenceIdx get_sequence_idx_map(ConstDataset const& update_data) const {
constexpr ComponentFlags all_true = [] {
ComponentFlags result{};
std::ranges::fill(result, true);
return result;
}();
return get_sequence_idx_map(update_data, all_true);
return get_sequence_idx_map(update_data, 0, all_true);
}

private:
Expand Down Expand Up @@ -621,9 +610,10 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
// const ref of current instance
MainModelImpl const& base_model = *this;

// cache component update order if possible
// cache component update order where possible.
// the order for a cacheable (independent) component by definition is the same across all scenarios
auto const is_independent = is_update_independent(update_data);
all_scenarios_sequence = get_sequence_idx_map(update_data, is_independent);
all_scenarios_sequence = get_sequence_idx_map(update_data, 0, is_independent);

return [&base_model, &exceptions, &infos, &calculation_fn, &result_data, &update_data,
&all_scenarios_sequence = std::as_const(all_scenarios_sequence),
Expand All @@ -640,15 +630,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
auto model = copy_model_functor(start);

SequenceIdx current_scenario_sequence_cache = SequenceIdx{};
std::array<std::reference_wrapper<std::vector<Idx2D> const>, n_types> const current_scenario_sequence =
run_functor_with_all_types_return_array(
[&is_independent, &all_scenarios_sequence, &current_scenario_sequence_cache]<typename CT>() {
constexpr auto comp_idx = index_of_component<CT>;
return is_independent[comp_idx] ? std::cref(all_scenarios_sequence[comp_idx])
: std::cref(current_scenario_sequence_cache[comp_idx]);
});
auto [setup, winddown] = scenario_update_restore(model, update_data, current_scenario_sequence,
current_scenario_sequence_cache, is_independent, infos);
auto [setup, winddown] = scenario_update_restore(model, update_data, is_independent, all_scenarios_sequence,
current_scenario_sequence_cache, infos);

auto calculate_scenario = MainModelImpl::call_with<Idx>(
[&model, &calculation_fn, &result_data, &infos](Idx scenario_idx) {
Expand Down Expand Up @@ -720,28 +703,41 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
};
}

static auto scenario_update_restore(
MainModelImpl& model, ConstDataset const& update_data,
std::array<std::reference_wrapper<std::vector<Idx2D> const>, n_types> const& scenario_sequence,
SequenceIdx& current_scenario_sequence_cache, ComponentFlags const& is_independent,
std::vector<CalculationInfo>& infos) noexcept {
static auto scenario_update_restore(MainModelImpl& model, ConstDataset const& update_data,
ComponentFlags const& is_independent, SequenceIdx const& all_scenario_sequence,
SequenceIdx& current_scenario_sequence_cache,
std::vector<CalculationInfo>& infos) noexcept {
auto do_update_cache = [&is_independent] {
ComponentFlags result;
std::ranges::transform(is_independent, result.begin(), std::logical_not<>{});
return result;
}();

auto const scenario_sequence = [&all_scenario_sequence, &current_scenario_sequence_cache,
&is_independent]() -> SequenceIdxView {
return run_functor_with_all_types_return_array(
[&all_scenario_sequence, &current_scenario_sequence_cache, &is_independent]<typename CT>() {
constexpr auto comp_idx = index_of_component<CT>;
if (std::get<comp_idx>(is_independent)) {
return std::span<Idx2D const>{std::get<comp_idx>(all_scenario_sequence)};
}
return std::span<Idx2D const>{std::get<comp_idx>(current_scenario_sequence_cache)};
});
};

figueroa1395 marked this conversation as resolved.
Show resolved Hide resolved
return std::make_pair(
[&model, &update_data, &scenario_sequence, &current_scenario_sequence_cache,
[&model, &update_data, scenario_sequence, &current_scenario_sequence_cache,
do_update_cache_ = std::move(do_update_cache), &infos](Idx scenario_idx) {
Timer const t_update_model(infos[scenario_idx], 1200, "Update model");
current_scenario_sequence_cache =
model.get_sequence_idx_map(update_data, scenario_idx, do_update_cache_);
model.template update_component<cached_update_t>(update_data, scenario_idx, scenario_sequence);

model.template update_component<cached_update_t>(update_data, scenario_idx, scenario_sequence());
},
[&model, &scenario_sequence, &current_scenario_sequence_cache, &infos](Idx scenario_idx) {
[&model, scenario_sequence, &current_scenario_sequence_cache, &infos](Idx scenario_idx) {
Timer const t_update_model(infos[scenario_idx], 1201, "Restore model");
model.restore_components(scenario_sequence);

model.restore_components(scenario_sequence());
std::ranges::for_each(current_scenario_sequence_cache,
[](auto& comp_seq_idx) { comp_seq_idx.clear(); });
});
Expand Down
Loading