Skip to content

Commit

Permalink
processed some comments, other awaits further input
Browse files Browse the repository at this point in the history
Signed-off-by: Martijn Govers <Martijn.Govers@Alliander.com>
  • Loading branch information
mgovers committed Jul 16, 2024
1 parent eba031e commit 4d39b51
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,12 @@ template <dataset_type_tag dataset_type_> class Dataset {
template <template <class> class type_getter, class ComponentType,
class StructType = DataStruct<typename type_getter<ComponentType>::type>>
std::span<StructType> get_buffer_span(Idx scenario = invalid_index) const {
assert(scenario < batch_size());

if (!is_batch() && scenario > 0) {
throw DatasetError{"Cannot export a single dataset with specified scenario\n"};
}
if (scenario >= batch_size()) {
throw DatasetError{"Scenario cannot be greater or equal to batch size!\n"};
}

Idx const idx = find_component(ComponentType::name, false);
return get_buffer_span_impl<StructType>(scenario, idx);
}
Expand All @@ -189,9 +189,8 @@ template <dataset_type_tag dataset_type_> class Dataset {
Dataset get_individual_scenario(Idx scenario)
requires(!is_indptr_mutable_v<dataset_type>)
{
if (scenario < 0 || scenario >= batch_size()) {
throw DatasetError{"Scenario cannot be less than 0 or greater or equal to batch size!\n"};
}
assert(0 <= scenario && scenario < batch_size());

Dataset result{false, 1, dataset().name, meta_data()};
for (Idx i{}; i != n_components(); ++i) {
auto const& buffer = get_buffer(i);
Expand Down
71 changes: 27 additions & 44 deletions tests/cpp_unit_tests/test_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,6 @@ TEST_CASE_TEMPLATE("Test dataset (common)", DatasetType, ConstDataset, MutableDa

CHECK(dataset.template get_buffer_span<input_getter_s, A>(0).data() == a_buffer.data());
CHECK(dataset.template get_buffer_span<input_getter_s, A>(0).size() == elements_per_scenario);
CHECK_THROWS_AS((dataset.template get_buffer_span<input_getter_s, A>(1)), DatasetError);
CHECK_THROWS_AS((dataset.template get_buffer_span<input_getter_s, A>(2)), DatasetError);

auto const all_scenario_spans = dataset.template get_buffer_span_all_scenarios<input_getter_s, A>();
CHECK(all_scenario_spans.size() == 1);
Expand Down Expand Up @@ -358,9 +356,6 @@ TEST_CASE_TEMPLATE("Test dataset (common)", DatasetType, ConstDataset, MutableDa
CHECK(scenario_span.size() == elements_per_scenario);
CHECK(all_scenario_spans[scenario].data() == scenario_span.data());
CHECK(all_scenario_spans[scenario].size() == scenario_span.size());
} else {
CHECK_THROWS_AS((dataset.template get_buffer_span<input_getter_s, A>(scenario)),
DatasetError);
}
}
}
Expand Down Expand Up @@ -388,8 +383,6 @@ TEST_CASE_TEMPLATE("Test dataset (common)", DatasetType, ConstDataset, MutableDa

CHECK(dataset.template get_buffer_span<input_getter_s, A>(0).data() == a_buffer.data());
CHECK(dataset.template get_buffer_span<input_getter_s, A>(0).size() == total_elements);
CHECK_THROWS_AS((dataset.template get_buffer_span<input_getter_s, A>(1)), DatasetError);
CHECK_THROWS_AS((dataset.template get_buffer_span<input_getter_s, A>(2)), DatasetError);

auto const all_scenario_spans = dataset.template get_buffer_span_all_scenarios<input_getter_s, A>();
CHECK(all_scenario_spans.size() == 1);
Expand Down Expand Up @@ -438,9 +431,6 @@ TEST_CASE_TEMPLATE("Test dataset (common)", DatasetType, ConstDataset, MutableDa
CHECK(scenario_span.size() == elements_per_scenarios[scenario]);
CHECK(all_scenario_spans[scenario].data() == scenario_span.data());
CHECK(all_scenario_spans[scenario].size() == scenario_span.size());
} else {
CHECK_THROWS_AS((dataset.template get_buffer_span<input_getter_s, A>(scenario)),
DatasetError);
}
}
}
Expand Down Expand Up @@ -489,40 +479,33 @@ TEST_CASE_TEMPLATE("Test dataset (common)", DatasetType, ConstDataset, MutableDa
add_inhomogeneous_buffer(dataset, B::name, b_buffer.size(), b_indptr.data(),
static_cast<void*>(b_buffer.data()));

for (auto scenario = -1; scenario <= batch_size; ++scenario) {
if (scenario >= 0 && scenario < batch_size) {
auto const scenario_dataset = dataset.get_individual_scenario(scenario);

CHECK(&scenario_dataset.meta_data() == &dataset.meta_data());
CHECK(!scenario_dataset.empty());
CHECK(scenario_dataset.is_batch() == false);
CHECK(scenario_dataset.batch_size() == 1);
CHECK(scenario_dataset.n_components() == dataset.n_components());

CHECK(scenario_dataset.get_component_info(A::name).component ==
&dataset_type.get_component(A::name));
CHECK(scenario_dataset.get_component_info(A::name).elements_per_scenario ==
a_elements_per_scenario);
CHECK(scenario_dataset.get_component_info(A::name).total_elements == a_elements_per_scenario);

CHECK(scenario_dataset.get_component_info(B::name).component ==
&dataset_type.get_component(B::name));
CHECK(scenario_dataset.get_component_info(B::name).elements_per_scenario ==
dataset.template get_buffer_span<input_getter_s, B>(scenario).size());
CHECK(scenario_dataset.get_component_info(B::name).total_elements ==
scenario_dataset.get_component_info(B::name).elements_per_scenario);

auto const scenario_span_a = scenario_dataset.template get_buffer_span<input_getter_s, A>();
auto const scenario_span_b = scenario_dataset.template get_buffer_span<input_getter_s, B>();
auto const dataset_span_a = dataset.template get_buffer_span<input_getter_s, A>(scenario);
auto const dataset_span_b = dataset.template get_buffer_span<input_getter_s, B>(scenario);
CHECK(scenario_span_a.data() == dataset_span_a.data());
CHECK(scenario_span_a.size() == dataset_span_a.size());
CHECK(scenario_span_b.data() == dataset_span_b.data());
CHECK(scenario_span_b.size() == dataset_span_b.size());
} else {
CHECK_THROWS_AS(dataset.get_individual_scenario(scenario), DatasetError);
}
for (auto scenario = 0; scenario < batch_size; ++scenario) {
auto const scenario_dataset = dataset.get_individual_scenario(scenario);

CHECK(&scenario_dataset.meta_data() == &dataset.meta_data());
CHECK(!scenario_dataset.empty());
CHECK(scenario_dataset.is_batch() == false);
CHECK(scenario_dataset.batch_size() == 1);
CHECK(scenario_dataset.n_components() == dataset.n_components());

CHECK(scenario_dataset.get_component_info(A::name).component == &dataset_type.get_component(A::name));
CHECK(scenario_dataset.get_component_info(A::name).elements_per_scenario == a_elements_per_scenario);
CHECK(scenario_dataset.get_component_info(A::name).total_elements == a_elements_per_scenario);

CHECK(scenario_dataset.get_component_info(B::name).component == &dataset_type.get_component(B::name));
CHECK(scenario_dataset.get_component_info(B::name).elements_per_scenario ==
dataset.template get_buffer_span<input_getter_s, B>(scenario).size());
CHECK(scenario_dataset.get_component_info(B::name).total_elements ==
scenario_dataset.get_component_info(B::name).elements_per_scenario);

auto const scenario_span_a = scenario_dataset.template get_buffer_span<input_getter_s, A>();
auto const scenario_span_b = scenario_dataset.template get_buffer_span<input_getter_s, B>();
auto const dataset_span_a = dataset.template get_buffer_span<input_getter_s, A>(scenario);
auto const dataset_span_b = dataset.template get_buffer_span<input_getter_s, B>(scenario);
CHECK(scenario_span_a.data() == dataset_span_a.data());
CHECK(scenario_span_a.size() == dataset_span_a.size());
CHECK(scenario_span_b.data() == dataset_span_b.data());
CHECK(scenario_span_b.size() == dataset_span_b.size());
}
}
}
Expand Down

0 comments on commit 4d39b51

Please sign in to comment.