Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/resolve comments optional ids #818

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ template <dataset_type_tag dataset_type_> class Dataset {
if (data == nullptr && std::ranges::all_of(attributes, [](auto const& x) { return x.data == nullptr; })) {
return invalid_index;
}
assert(data == nullptr); // assume colomnar buffer
assert(data == nullptr); // assume columnar buffer

auto const found = std::ranges::find_if(
attributes, [attr_name](auto const& x) { return x.meta_attribute->name == attr_name; });
Expand Down Expand Up @@ -275,6 +275,63 @@ template <dataset_type_tag dataset_type_> class Dataset {
return !is_row_based(buffer) && !(with_attribute_buffers && buffer.attributes.empty());
}

constexpr bool is_dense(std::string_view component) const {
Idx const idx = find_component(component, false);
if (idx == invalid_index) {
return true; // by definition
}
return is_dense(idx);
}
constexpr bool is_dense(Idx const i) const { return is_dense(buffers_[i]); }
constexpr bool is_dense(Buffer const& buffer) const { return buffer.indptr.empty(); }
constexpr bool is_sparse(std::string_view component, bool with_attribute_buffers = false) const {
Idx const idx = find_component(component, false);
if (idx == invalid_index) {
return false;
}
return is_sparse(idx, with_attribute_buffers);
}
constexpr bool is_sparse(Idx const i, bool with_attribute_buffers = false) const {
return is_sparse(buffers_[i], with_attribute_buffers);
}
constexpr bool is_sparse(Buffer const& buffer) const { return !is_dense(buffer); }

constexpr bool is_uniform(std::string_view component) const {
Idx const idx = find_component(component, false);
if (idx == invalid_index) {
return true; // by definition
}
return is_uniform(idx);
}
constexpr bool is_uniform(Idx const i) const { return is_uniform(buffers_[i]); }
constexpr bool is_uniform(Buffer const& buffer) const {
if (is_dense(buffer)) {
return true;
}
assert(buffer.indptr.size() > 1);
auto const first_scenario_size = buffer.indptr[1] - buffer.indptr[0];
return std::ranges::adjacent_find(buffer.indptr, [first_scenario_size](Idx start, Idx stop) {
return stop - start != first_scenario_size;
}) == buffer.indptr.end();
}

constexpr Idx uniform_elements_per_scenario(std::string_view component) const {
Idx const idx = find_component(component, false);
if (idx == invalid_index) {
return 0;
}
return uniform_elements_per_scenario(idx);
}
constexpr Idx uniform_elements_per_scenario(Idx const i) const {
assert(is_uniform(i));
if (is_dense(i)) {
return get_component_info(i).elements_per_scenario;
}
auto const& indptr = buffers_[i].indptr;
assert(indptr.size() > 1);
return indptr[1] - indptr[0];
}

Idx find_component(std::string_view component, bool required = false) const {
auto const found = std::ranges::find_if(dataset_info_.component_info, [component](ComponentInfo const& x) {
return x.component->name == component;
Expand All @@ -291,7 +348,7 @@ template <dataset_type_tag dataset_type_> class Dataset {
bool contains_component(std::string_view component) const { return find_component(component) >= 0; }

ComponentInfo const& get_component_info(std::string_view component) const {
return dataset_info_.component_info[find_component(component, true)];
return get_component_info(find_component(component, true));
}

void add_component_info(std::string_view component, Idx elements_per_scenario, Idx total_elements)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,23 +147,31 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis

struct UpdateCompProperties {
std::string name{};
bool has_id{false}; // if the component has id
bool ids_all_na{false}; // if all ids are all NA
bool ids_part_na{false}; // if some ids are NA but some are not
bool uniform{false}; // if the component is uniform
bool is_columnar{false}; // if the component is columnar
bool ids_match{false}; // if the ids match
Idx elements_ps_in_update{0}; // count of elements for this component per scenario in update
Idx elements_in_base{0}; // count of elements for this component per scenario in input
constexpr bool no_id_col() const { return is_columnar && (!has_id || ids_all_na); }
constexpr bool no_id_row() const { return !is_columnar && (!has_id && ids_all_na); }
constexpr bool qualify_for_optional_id() const { return ids_match && ids_all_na && !ids_part_na; }
constexpr bool provided_ids_valid() const { return has_id && ids_match && !ids_all_na && !ids_part_na; }
constexpr bool is_empty_component() const { return !has_id && ids_all_na; }
bool has_any_elements{false}; // whether the component has any elements in the update data
bool ids_all_na{false}; // whether all ids are all NA
bool ids_part_na{false}; // whether some ids are NA but some are not
bool dense{false}; // whether the component is dense
bool uniform{false}; // whether the component is uniform
bool is_columnar{false}; // whether the component is columnar
bool update_ids_match{false}; // whether the ids match
Idx elements_ps_in_update{invalid_index}; // count of elements for this component per scenario in update
Idx elements_in_base{invalid_index}; // count of elements for this component per scenario in input

constexpr bool no_id() const { return !has_any_elements || ids_all_na; }
constexpr bool no_id_col() const { return is_columnar && no_id(); }
constexpr bool no_id_row() const { return !is_columnar && no_id(); }
Jerry-Jinfeng-Guo marked this conversation as resolved.
Show resolved Hide resolved
constexpr bool qualify_for_optional_id() const {
return update_ids_match && ids_all_na && uniform && elements_ps_in_update == elements_in_base;
}
constexpr bool provided_ids_valid() const {
return is_empty_component() || (update_ids_match && !(ids_all_na || ids_part_na));
}
constexpr bool is_empty_component() const { return !has_any_elements; }
constexpr bool is_independent() const { return qualify_for_optional_id() || provided_ids_valid(); }
Idx get_n_elements() const {
auto const prov_n_elements = uniform ? elements_ps_in_update : elements_in_base;
return qualify_for_optional_id() ? prov_n_elements : invalid_index;
constexpr Idx get_n_elements() const {
assert(uniform || elements_ps_in_update == invalid_index);

return qualify_for_optional_id() ? elements_ps_in_update : invalid_index;
}
};
using UpdateCompIndependence = std::vector<UpdateCompProperties>;
Expand Down Expand Up @@ -767,114 +775,73 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
auto const all_comp_count_in_base = this->all_component_count();
Idx const n_scenarios = update_data.batch_size();

auto check_ids_na = [](auto const& all_spans) {
std::vector<std::vector<bool>> ids_na{};
for (auto const& span : all_spans) {
std::vector<bool> id_na{};
if constexpr (requires { span.front().id; }) {
for (auto const& obj : span) {
id_na.emplace_back(is_nan(obj.id));
}
}
ids_na.emplace_back(std::move(id_na));
auto const check_id_na = [](auto const& obj) -> bool {
if constexpr (requires { obj.id; }) {
return is_nan(obj.id);
} else if constexpr (requires { obj.get().id; }) {
return is_nan(obj.get().id);
} else {
throw UnreachableHit{"check_components_independence", "This cannot exist"};
}
return ids_na;
};

auto process_buffer_span = [check_ids_na]<typename CT>(auto const& all_spans, UpdateCompProperties& result) {
// Remember the first batch size, then loop over the remaining batches and check if they are of the same
// length
std::vector<std::vector<bool>> const ids_na = check_ids_na(all_spans);
result.has_id = !std::ranges::all_of(ids_na, [](std::vector<bool> const& vec) { return vec.empty(); });
result.ids_all_na = std::ranges::all_of(ids_na, [](std::vector<bool> const& vec) {
return std::ranges::all_of(vec, [](bool const& obj) { return obj; });
});
result.ids_part_na = std::ranges::any_of(ids_na, [](std::vector<bool> const& vec) {
return std::ranges::any_of(vec, [](bool const& obj) { return obj; }) &&
std::ranges::any_of(vec, [](bool const& obj) { return !obj; });
});
result.uniform = std::ranges::all_of(
all_spans, [n_elements = static_cast<Idx>(all_spans.front().size())](auto const& span) {
return static_cast<Idx>(std::size(span)) == n_elements;
});
auto const process_buffer_span = [check_id_na]<typename CT>(auto const& all_spans,
UpdateCompProperties& result) {
result.ids_all_na = std::ranges::all_of(
all_spans, [&check_id_na](auto const& vec) { return std::ranges::all_of(vec, check_id_na); });
result.ids_part_na =
std::ranges::any_of(
all_spans, [&check_id_na](auto const& vec) { return std::ranges::any_of(vec, check_id_na); }) &&
!result.ids_all_na;

// Remember the begin iterator of the first scenario, then loop over the remaining scenarios and check
// the ids
auto const first_span = all_spans[0];
// check the subsequent scenarios
// only return true if all scenarios match the ids of the first batch
result.ids_match =
std::ranges::all_of(all_spans.cbegin() + 1, all_spans.cend(), [&first_span](auto const& current_span) {
return std::ranges::equal(
current_span, first_span,
[](UpdateType<CT> const& obj, UpdateType<CT> const& first) { return obj.id == first.id; });
});
if (all_spans.empty()) {
result.update_ids_match = true;
} else {
auto const first_span = all_spans[0];
// check the subsequent scenarios
// only return true if all scenarios match the ids of the first batch
result.update_ids_match = std::ranges::all_of(
all_spans.cbegin() + 1, all_spans.cend(), [&first_span](auto const& current_span) {
return std::ranges::equal(
current_span, first_span,
[](UpdateType<CT> const& obj, UpdateType<CT> const& first) { return obj.id == first.id; });
});
}
};

auto const check_each_component = [&update_data, &process_buffer_span, &all_comp_count_in_base,
&n_scenarios]<typename CT>() -> UpdateCompProperties {
auto const check_each_component = [&update_data, &process_buffer_span,
&all_comp_count_in_base]<typename CT>() -> UpdateCompProperties {
// get span of all the update data
auto const comp_index = update_data.find_component(CT::name, false);
UpdateCompProperties result;
result.name = CT::name;
result.is_columnar = update_data.is_columnar(result.name);
result.dense = update_data.is_dense(result.name);
result.uniform = update_data.is_uniform(result.name);
result.has_any_elements =
comp_index != invalid_index && update_data.get_component_info(comp_index).total_elements > 0;

result.elements_ps_in_update =
result.uniform ? update_data.uniform_elements_per_scenario(result.name) : invalid_index;

if (auto it = std::ranges::find_if(
all_comp_count_in_base,
[&result](ComponentCountInBase const& pair) { return pair.first == result.name; });
it != all_comp_count_in_base.end()) {
result.elements_in_base = it->second;
} else {
result.elements_in_base = 0;
}
if (result.is_columnar && comp_index != invalid_index) {
auto const& comp_buffer = update_data.get_buffer(comp_index);
auto const& comp_info = update_data.get_component_info(comp_index);
auto const id_idx = comp_buffer.find_attribute("id");

result.has_id = id_idx != invalid_index;
result.uniform = comp_info.elements_per_scenario == result.elements_in_base;

if (result.has_id) {
auto const* id_buffer = comp_buffer.template get_col_data_at_index<ID>(id_idx);

result.ids_all_na =
std::all_of(id_buffer, id_buffer + comp_info.total_elements, [](ID id) { return is_nan(id); });
result.ids_part_na = std::any_of(id_buffer, id_buffer + comp_info.total_elements,
[](ID id) { return is_nan(id); }) &&
!result.ids_all_na;
assert(result.elements_in_base > 0);
auto const n_scenario_comp_in_update =
result.elements_in_base * n_scenarios > comp_info.total_elements
? comp_info.total_elements / result.elements_in_base
: n_scenarios;
result.ids_match = [id_buffer, elements_ps = result.elements_in_base, n_scenario_comp_in_update]() {
bool all_match = true;
for (Idx i = 0; i < elements_ps; ++i) {
for (Idx j = 0; j < n_scenario_comp_in_update; ++j) {
if (id_buffer[i] != id_buffer[i + elements_ps * j]) {
all_match = false;
break;
}
}
if (!all_match) {
break;
}
}
return all_match;
}();
if (comp_info.elements_per_scenario != invalid_index) {
result.elements_ps_in_update = comp_info.elements_per_scenario;
}
} else { // no id,
result.ids_all_na = true; // no id, all NA
result.ids_part_na = false; // no id, no part NA
result.ids_match = true; // no id, all match
}

if (result.is_columnar) {
process_buffer_span.template operator()<CT>(
update_data.get_columnar_buffer_span_all_scenarios<meta_data::update_getter_s, CT>(), result);
} else {
process_buffer_span.template operator()<CT>(
update_data.get_buffer_span_all_scenarios<meta_data::update_getter_s, CT>(), result);
}
if (comp_index >= 0 && result.uniform) {
result.elements_ps_in_update =
update_data.get_component_info(comp_index).elements_per_scenario; // -1 for sparse
}
return result;
};

Expand Down Expand Up @@ -903,31 +870,22 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
if (comp.is_empty_component()) {
return; // empty dataset is still supported
}
if (comp.elements_in_base < comp.elements_ps_in_update) {
auto const elements_ps = comp.get_n_elements();
assert(comp.uniform || elements_ps < 0);

if (elements_ps >= 0 && comp.elements_in_base < elements_ps) {
throw DatasetError("Update data has more elements per scenario than input data for component " + comp.name +
"!");
}
if (comp.ids_part_na) {
throw DatasetError("Some IDs are not valid for component " + comp.name + " in update data!");
}
if (!comp.uniform) {
if (comp.is_columnar && !comp.has_id) {
throw DatasetError("Columnar input data without IDs for component " + comp.name + " is not uniform!");
}
if (!comp.is_columnar && comp.ids_all_na) {
throw DatasetError("Row based input data with all NA IDs for component " + comp.name +
" is not uniform!");
}
if (comp.ids_all_na && comp.elements_in_base != elements_ps) {
throw DatasetError("Update data without IDs for component " + comp.name +
" has a different number of elements per scenario then input data!");
}
if (comp.elements_in_base != comp.elements_ps_in_update) {
if (comp.is_columnar && !comp.has_id) {
throw DatasetError("Columnar input data for component " + comp.name +
" has different number of elements per scenario in update and input data!");
}
if (!comp.is_columnar && comp.uniform && (comp.has_id && comp.ids_all_na)) {
throw DatasetError("Row based input data for component " + comp.name +
" has different number of elements per scenario in update and input data!");
}
if (!comp.is_columnar && comp.ids_all_na) {
throw DatasetError("Row based update data without IDs for component " + comp.name + " is not supported!");
}
}

Expand Down
Loading
Loading