Skip to content

Commit

Permalink
Merge pull request #805 from PowerGridModel/feature/ub-update-unknown-id
Browse files Browse the repository at this point in the history
UB update unknown id repro case
  • Loading branch information
Jerry-Jinfeng-Guo authored Oct 26, 2024
2 parents d0e8c1f + 250b09f commit 384342d
Show file tree
Hide file tree
Showing 8 changed files with 336 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,34 @@ template <dataset_type_tag dataset_type_> class Dataset {
template <class StructType>
using DataStruct = std::conditional_t<is_data_mutable_v<dataset_type>, StructType, StructType const>;

// for columnar buffers, Data* data is empty and attributes is filled
// for columnar buffers, Data* data is empty and attributes.data is filled
// for uniform buffers, indptr is empty
struct Buffer {
using Data = Dataset::Data;

Data* data{nullptr};
std::vector<AttributeBuffer<Data>> attributes{};
std::span<Indptr> indptr{};
Idx find_attribute(std::string_view attr_name) const {
if (data == nullptr && std::ranges::all_of(attributes, [](auto const& x) { return x.data == nullptr; })) {
return invalid_index;
}

auto const found = std::ranges::find_if(
attributes, [attr_name](auto const& x) { return x.meta_attribute->name == attr_name; });

if (found == attributes.cend()) {
return invalid_index;
}
return std::distance(attributes.cbegin(), found);
}
template <typename T> T* get_col_data_at_index(Idx index) const {
assert(data == nullptr);
if (data != nullptr) {
throw std::runtime_error("Buffer access by index not supported for row based data!\n");
}
return const_cast<T*>(reinterpret_cast<T const*>(attributes[index].data));
}
};

template <class StructType>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ template <class Enum>
inline bool is_nan(Enum x) {
return static_cast<IntS>(x) == na_IntS;
}
inline bool is_nan(Idx x) { return x == na_Idx; }

// is normal
inline auto is_normal(std::floating_point auto value) { return std::isnormal(value); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ class MainModel {
};
~MainModel() { impl_.reset(); }

static bool is_update_independent(ConstDataset const& update_data) {
return Impl::is_update_independent(update_data);
}
bool is_update_independent(ConstDataset const& update_data) { return impl().is_update_independent(update_data); }

std::map<std::string, Idx, std::less<>> all_component_count() const { return impl().all_component_count(); }
void get_indexer(std::string_view component_type, ID const* id_begin, Idx size, Idx* indexer_begin) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,21 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
bool is_columnar{false}; // if the component is columnar
bool ids_match{false}; // if the ids match
Idx elements_ps_in_update{0}; // count of elements for this component per scenario in update
Idx elements_ps_in_base{0}; // count of elements for this component per scenario in input
Idx elements_in_base{0}; // count of elements for this component per scenario in input
inline bool qualify_for_optional_id() const { return !has_id && ids_match && ids_all_na && !ids_part_na; }
bool is_independent() const {
bool const provided_ids_valid = has_id && ids_match && !ids_all_na && !ids_part_na;
return qualify_for_optional_id() || provided_ids_valid;
}
Idx get_n_elements() const {
return qualify_for_optional_id() ? (uniform ? elements_ps_in_update : elements_in_base) : invalid_index;
}
};
using UpdateCompIndependence = std::vector<UpdateCompProperties>;
using ComponentCountInBase = std::pair<std::string, Idx>;

static constexpr Idx ignore_output{-1};
static constexpr Idx invalid_index{-1};

protected:
// run functors with all component types
Expand Down Expand Up @@ -401,7 +410,7 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
auto const get_seq_idx_func = [&state = this->state_, &update_data, scenario_idx, &process_buffer_span,
&comp_indenpendence]<typename CT>() -> std::vector<Idx2D> {
// TODO: (jguo) this function could be encapsulated in UpdateCompIndependence in update.hpp
auto const get_n_comp_elements = [&comp_indenpendence]() {
Idx const n_comp_elements = [&comp_indenpendence]() {
if (!comp_indenpendence.empty()) {
auto const comp_idx = std::ranges::find_if(comp_indenpendence.begin(), comp_indenpendence.end(),
[](auto const& comp) { return comp.name == CT::name; });
Expand All @@ -410,16 +419,14 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}
auto const& comp = *comp_idx;
if (comp.is_columnar && (!comp.has_id || comp.ids_all_na)) {
return comp.elements_ps_in_update;
return comp.get_n_elements();
}
if (!comp.is_columnar && (!comp.has_id && comp.ids_all_na)) {
return comp.elements_ps_in_update;
return comp.get_n_elements();
}
}
return na_Idx;
};

Idx n_comp_elements = get_n_comp_elements();
}();

auto const get_sequence = [&state, &n_comp_elements](auto const& it_begin, auto const& it_end) {
return main_core::get_component_sequence<CT>(state, it_begin, it_end, n_comp_elements);
Expand All @@ -438,21 +445,13 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
// get sequence idx map of an entire batch for fast caching of component sequences
// (only applicable for independent update dataset)
SequenceIdx get_sequence_idx_map(ConstDataset const& update_data) const {
auto update_components_independence = check_components_independence(update_data);
auto const update_components_independence = check_components_independence(update_data);
assert(std::ranges::all_of(update_components_independence,
[](auto const& comp) { return !comp.has_id || comp.ids_match; }));
[](auto const& comp) { return comp.is_independent(); }));

// TODO: (jguo) this function could be encapsulated in UpdateCompIndependence in update.hpp
auto const all_comp_count_in_base = this->all_component_count();
for (auto& comp : update_components_independence) {
if (auto it =
std::ranges::find_if(all_comp_count_in_base,
[&comp](const ComponentCountInBase& pair) { return pair.first == comp.name; });
it != all_comp_count_in_base.end()) {
comp.elements_ps_in_base = it->second;
}
validate_update_data_independence(comp);
}
std::for_each(update_components_independence.begin(), update_components_independence.end(),
[this](auto& comp) { validate_update_data_independence(comp); });

return get_sequence_idx_map(update_data, 0, update_components_independence);
}
Expand Down Expand Up @@ -767,7 +766,10 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
public:
template <class Component> using UpdateType = typename Component::UpdateType;

static UpdateCompIndependence check_components_independence(ConstDataset const& update_data) {
UpdateCompIndependence check_components_independence(ConstDataset const& update_data) const {
auto const all_comp_count_in_base = this->all_component_count();
Idx const n_scenarios = update_data.batch_size();

auto check_ids_na = [](auto const& all_spans) {
std::vector<std::vector<bool>> ids_na{};
for (const auto& span : all_spans) {
Expand Down Expand Up @@ -798,8 +800,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
std::ranges::all_of(all_spans, [n_elements = result.elements_ps_in_update](auto const& span) {
return static_cast<Idx>(std::size(span)) == n_elements;
});
// Remember the begin iterator of the first scenario, then loop over the remaining scenarios and check the
// ids
// Remember the begin iterator of the first scenario, then loop over the remaining scenarios and check
// the ids
auto const first_span = all_spans[0];
// check the subsequent scenarios
// only return true if all scenarios match the ids of the first batch
Expand All @@ -811,44 +813,89 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
});
};

auto const check_each_component = [&update_data, &process_buffer_span]<typename CT>() -> UpdateCompProperties {
auto const check_each_component = [&update_data, &process_buffer_span, &all_comp_count_in_base,
&n_scenarios]<typename CT>() -> UpdateCompProperties {
// get span of all the update data
auto const comp_index = update_data.find_component(CT::name, false);
UpdateCompProperties result;
result.name = CT::name;
result.is_columnar = update_data.is_columnar(result.name);
if (comp_index >= 0) {
result.elements_ps_in_update = update_data.get_component_info(comp_index).elements_per_scenario;
if (auto it = std::ranges::find_if(
all_comp_count_in_base,
[&result](const ComponentCountInBase& pair) { return pair.first == result.name; });
it != all_comp_count_in_base.end()) {
result.elements_in_base = it->second;
}
if (result.is_columnar) {
process_buffer_span.template operator()<CT>(
update_data.get_columnar_buffer_span_all_scenarios<meta_data::update_getter_s, CT>(), result);
if (result.is_columnar && comp_index != invalid_index) {
auto const& comp_buffer = update_data.get_buffer(comp_index);
auto const& comp_info = update_data.get_component_info(comp_index);
auto const id_idx = comp_buffer.find_attribute("id");

result.has_id = id_idx != invalid_index;
result.uniform = (comp_info.elements_per_scenario == result.elements_in_base) &&
(comp_info.elements_per_scenario != invalid_index);

if (result.has_id) {
auto const* id_buffer = comp_buffer.template get_col_data_at_index<ID>(id_idx);

result.ids_all_na =
std::all_of(id_buffer, id_buffer + comp_info.total_elements, [](ID id) { return is_nan(id); });
result.ids_part_na = std::any_of(id_buffer, id_buffer + comp_info.total_elements,
[](ID id) { return is_nan(id); }) &&
!result.ids_all_na;
result.ids_match = [id_buffer, elements_ps = result.elements_in_base, n_scenarios]() {
bool all_match = true;
for (Idx i = 0; i < elements_ps; ++i) {
for (Idx j = 0; j < n_scenarios; ++j) {
if (id_buffer[i] != id_buffer[i + elements_ps * j]) {
all_match = false;
break;
}
}
if (!all_match) {
break;
}
}
return all_match;
}();
if (comp_info.elements_per_scenario != invalid_index) {
result.elements_ps_in_update = comp_info.elements_per_scenario;
}
} else { // no id,
result.ids_all_na = true; // no id, all NA
result.ids_part_na = false; // no id, no part NA
result.ids_match = true; // no id, all match
}
} else {
process_buffer_span.template operator()<CT>(
update_data.get_buffer_span_all_scenarios<meta_data::update_getter_s, CT>(), result);
}
if (comp_index >= 0 && result.uniform) {
result.elements_ps_in_update =
update_data.get_component_info(comp_index).elements_per_scenario; // -1 for sparse
}
return result;
};

// check and return indenpendence of all components
return run_functor_with_all_types_return_vector(check_each_component);
}

static bool is_update_independent(ConstDataset const& update_data) {
bool is_update_independent(ConstDataset const& update_data) {
// If the batch size is (0 or) 1, then the update data for this component is 'independent'
if (update_data.batch_size() <= 1) {
return true;
}
auto const all_comp_update_independence = check_components_independence(update_data);
return std::ranges::all_of(all_comp_update_independence,
[](auto const& comp) { return !comp.has_id || comp.ids_match; });
[](auto const& comp) { return comp.is_independent(); });
}

void validate_update_data_independence(UpdateCompProperties const& comp) const {
if (!comp.has_id && comp.ids_all_na) {
return; // empty dataset is still supported
}
if (comp.elements_ps_in_base < comp.elements_ps_in_update) {
if (comp.elements_in_base < comp.elements_ps_in_update) {
throw DatasetError("Update data has more elements per scenario than input data for component " + comp.name +
"!");
}
Expand All @@ -864,7 +911,7 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
" is not uniform!");
}
}
if (comp.elements_ps_in_base != comp.elements_ps_in_update) {
if (comp.elements_in_base != comp.elements_ps_in_update) {
if (comp.is_columnar && !comp.has_id) {
throw DatasetError("Columnar input data for component " + comp.name +
" has different number of elements per scenario in update and input data!");
Expand Down
Loading

0 comments on commit 384342d

Please sign in to comment.