From 22bb2a464fae3d4627fb3bf537d8015d696ec262 Mon Sep 17 00:00:00 2001 From: Jerry Guo Date: Thu, 17 Oct 2024 13:20:53 +0200 Subject: [PATCH 1/2] Capital enums made lower case Signed-off-by: Jerry Guo --- src/power_grid_model/core/power_grid_dataset.py | 4 ++-- src/power_grid_model/enum.py | 4 ++-- tests/unit/test_internal_utils.py | 17 ++++++++++------- tests/unit/test_power_grid_model.py | 6 +++--- tests/unit/test_serialization.py | 12 ++++++------ tests/unit/validation/test_batch_validation.py | 8 ++++---- tests/unit/validation/test_input_validation.py | 4 ++-- 7 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/power_grid_model/core/power_grid_dataset.py b/src/power_grid_model/core/power_grid_dataset.py index 05db79d4b..48136e5cd 100644 --- a/src/power_grid_model/core/power_grid_dataset.py +++ b/src/power_grid_model/core/power_grid_dataset.py @@ -467,7 +467,7 @@ def _filter_attributes(self, attributes): del attributes[key] def _filter_with_option(self): - if self._data_filter is ComponentAttributeFilterOptions.RELEVANT: + if self._data_filter is ComponentAttributeFilterOptions.relevant: for attributes in self._data.values(): self._filter_attributes(attributes) @@ -475,7 +475,7 @@ def _filter_with_mapping(self): for component_type, attributes in self._data.items(): if component_type in self._data_filter: filter_option = self._data_filter[component_type] - if filter_option is ComponentAttributeFilterOptions.RELEVANT: + if filter_option is ComponentAttributeFilterOptions.relevant: self._filter_attributes(attributes) def _post_filtering(self): diff --git a/src/power_grid_model/enum.py b/src/power_grid_model/enum.py index 92fc5f7aa..d6ae35600 100644 --- a/src/power_grid_model/enum.py +++ b/src/power_grid_model/enum.py @@ -208,7 +208,7 @@ class _ExperimentalFeatures(IntEnum): class ComponentAttributeFilterOptions(IntEnum): """Filter option component or attribute""" - ALL = 0 + everything = 0 """Filter all components/attributes""" - RELEVANT = 1 + relevant = 1 """Filter only non-empty components/attributes that contain non-NaN values""" diff --git a/tests/unit/test_internal_utils.py b/tests/unit/test_internal_utils.py index 1e876f7ea..7764ce6b8 100644 --- a/tests/unit/test_internal_utils.py +++ b/tests/unit/test_internal_utils.py @@ -612,15 +612,18 @@ def test_convert_batch_dataset_to_batch_list_invalid_type_sparse(_mock: MagicMoc convert_batch_dataset_to_batch_list(update_data) -DATA_FILTER_ALL = ComponentAttributeFilterOptions.ALL -DATA_FILTER_RELEVANT = ComponentAttributeFilterOptions.RELEVANT +DATA_FILTER_EVERYTHING = ComponentAttributeFilterOptions.everything +DATA_FILTER_RELEVANT = ComponentAttributeFilterOptions.relevant @pytest.mark.parametrize( ("data_filter", "expected"), [ (None, {CT.node: None, CT.sym_load: None, CT.source: None}), - (DATA_FILTER_ALL, {CT.node: DATA_FILTER_ALL, CT.sym_load: DATA_FILTER_ALL, CT.source: DATA_FILTER_ALL}), + ( + DATA_FILTER_EVERYTHING, + {CT.node: DATA_FILTER_EVERYTHING, CT.sym_load: DATA_FILTER_EVERYTHING, CT.source: DATA_FILTER_EVERYTHING}, + ), ( DATA_FILTER_RELEVANT, {CT.node: DATA_FILTER_RELEVANT, CT.sym_load: DATA_FILTER_RELEVANT, CT.source: DATA_FILTER_RELEVANT}, @@ -630,11 +633,11 @@ def test_convert_batch_dataset_to_batch_list_invalid_type_sparse(_mock: MagicMoc ({CT.node: [], CT.sym_load: []}, {CT.node: [], CT.sym_load: []}), ({CT.node: [], CT.sym_load: ["p"]}, {CT.node: [], CT.sym_load: ["p"]}), ({CT.node: None, CT.sym_load: ["p"]}, {CT.node: None, CT.sym_load: ["p"]}), - ({CT.node: DATA_FILTER_ALL, CT.sym_load: ["p"]}, {CT.node: DATA_FILTER_ALL, CT.sym_load: ["p"]}), + ({CT.node: DATA_FILTER_EVERYTHING, CT.sym_load: ["p"]}, {CT.node: DATA_FILTER_EVERYTHING, CT.sym_load: ["p"]}), ({CT.node: DATA_FILTER_RELEVANT, CT.sym_load: ["p"]}, {CT.node: DATA_FILTER_RELEVANT, CT.sym_load: ["p"]}), ( - {CT.node: DATA_FILTER_ALL, CT.sym_load: DATA_FILTER_ALL}, - {CT.node: DATA_FILTER_ALL, CT.sym_load: DATA_FILTER_ALL}, + {CT.node: DATA_FILTER_EVERYTHING, CT.sym_load: DATA_FILTER_EVERYTHING}, + {CT.node: DATA_FILTER_EVERYTHING, CT.sym_load: DATA_FILTER_EVERYTHING}, ), ( {CT.node: DATA_FILTER_RELEVANT, CT.sym_load: DATA_FILTER_RELEVANT}, @@ -875,7 +878,7 @@ def compare_row_data(actual_row_data, desired_row_data): def test_dense_row_data_to_from_col_data(row_data): # row data to columnar data and back col_data = compatibility_convert_row_columnar_dataset( - row_data, ComponentAttributeFilterOptions.ALL, DatasetType.update + row_data, ComponentAttributeFilterOptions.everything, DatasetType.update ) new_row_data = compatibility_convert_row_columnar_dataset(col_data, None, DatasetType.update) compare_row_data(row_data, new_row_data) diff --git a/tests/unit/test_power_grid_model.py b/tests/unit/test_power_grid_model.py index ed5c2b265..91ec78eb2 100644 --- a/tests/unit/test_power_grid_model.py +++ b/tests/unit/test_power_grid_model.py @@ -72,7 +72,7 @@ def input_row(): @pytest.fixture def input_col(input_row): return compatibility_convert_row_columnar_dataset( - input_row, ComponentAttributeFilterOptions.RELEVANT, DatasetType.input + input_row, ComponentAttributeFilterOptions.relevant, DatasetType.input ) @@ -117,7 +117,7 @@ def update_batch_row(): @pytest.fixture def update_batch_col(update_batch_row): return compatibility_convert_row_columnar_dataset( - update_batch_row, ComponentAttributeFilterOptions.RELEVANT, DatasetType.update + update_batch_row, ComponentAttributeFilterOptions.relevant, DatasetType.update ) @@ -163,7 +163,7 @@ def test_update_error(model: PowerGridModel): with pytest.raises(PowerGridError, match="The id cannot be found:"): model.update(update_data=update_data) update_data_col = compatibility_convert_row_columnar_dataset( - update_data, ComponentAttributeFilterOptions.RELEVANT, DatasetType.update + update_data, ComponentAttributeFilterOptions.relevant, DatasetType.update ) with pytest.raises(PowerGridError, match="The id cannot be found:"): model.update(update_data=update_data_col) diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py index 3d770d428..e7015bcba 100644 --- a/tests/unit/test_serialization.py +++ b/tests/unit/test_serialization.py @@ -377,17 +377,17 @@ def serialized_data(request): @pytest.fixture( params=[ pytest.param(None, id="All row filter"), - pytest.param(ComponentAttributeFilterOptions.ALL, id="All columnar filter"), - pytest.param(ComponentAttributeFilterOptions.RELEVANT, id="All relevant columnar filter"), + pytest.param(ComponentAttributeFilterOptions.everything, id="All columnar filter"), + pytest.param(ComponentAttributeFilterOptions.relevant, id="All relevant columnar filter"), pytest.param({"node": ["id"], "sym_load": ["id"]}, id="columnar filter"), pytest.param({"node": ["id"], "sym_load": None}, id="mixed columnar/row filter"), pytest.param({"node": ["id"], "shunt": None}, id="unused component filter"), pytest.param( { "node": ["id"], - "line": ComponentAttributeFilterOptions.ALL, + "line": ComponentAttributeFilterOptions.everything, "sym_load": None, - "asym_load": ComponentAttributeFilterOptions.RELEVANT, + "asym_load": ComponentAttributeFilterOptions.relevant, }, id="mixed filter", ), @@ -629,9 +629,9 @@ def _check_only_relevant_attributes_present(component_values) -> bool: def assert_deserialization_filtering_correct(deserialized_dataset: Dataset, data_filter) -> bool: - if data_filter is ComponentAttributeFilterOptions.ALL: + if data_filter is ComponentAttributeFilterOptions.everything: return True - if data_filter is ComponentAttributeFilterOptions.RELEVANT: + if data_filter is ComponentAttributeFilterOptions.relevant: for component_values in deserialized_dataset.values(): if not _check_only_relevant_attributes_present(component_values): return False diff --git a/tests/unit/validation/test_batch_validation.py b/tests/unit/validation/test_batch_validation.py index ae02a4e53..26520dae9 100644 --- a/tests/unit/validation/test_batch_validation.py +++ b/tests/unit/validation/test_batch_validation.py @@ -45,14 +45,14 @@ def original_input_data() -> dict[str, np.ndarray]: @pytest.fixture def original_input_data_columnar_all(original_input_data): return compatibility_convert_row_columnar_dataset( - original_input_data, ComponentAttributeFilterOptions.ALL, DatasetType.input + original_input_data, ComponentAttributeFilterOptions.everything, DatasetType.input ) @pytest.fixture def original_input_data_columnar_relevant(original_input_data): return compatibility_convert_row_columnar_dataset( - original_input_data, ComponentAttributeFilterOptions.RELEVANT, DatasetType.input + original_input_data, ComponentAttributeFilterOptions.relevant, DatasetType.input ) @@ -79,14 +79,14 @@ def original_batch_data() -> dict[str, np.ndarray]: @pytest.fixture def original_batch_data_columnar_all(original_batch_data): return compatibility_convert_row_columnar_dataset( - original_batch_data, ComponentAttributeFilterOptions.ALL, DatasetType.update + original_batch_data, ComponentAttributeFilterOptions.everything, DatasetType.update ) @pytest.fixture def original_batch_data_columnar_relevant(original_batch_data): return compatibility_convert_row_columnar_dataset( - original_batch_data, ComponentAttributeFilterOptions.RELEVANT, DatasetType.update + original_batch_data, ComponentAttributeFilterOptions.relevant, DatasetType.update ) diff --git a/tests/unit/validation/test_input_validation.py b/tests/unit/validation/test_input_validation.py index 3aea93e71..383e1448e 100644 --- a/tests/unit/validation/test_input_validation.py +++ b/tests/unit/validation/test_input_validation.py @@ -278,14 +278,14 @@ def original_data() -> dict[ComponentType, np.ndarray]: @pytest.fixture def original_data_columnar_all(original_data): return compatibility_convert_row_columnar_dataset( - original_data, ComponentAttributeFilterOptions.ALL, DatasetType.input + original_data, ComponentAttributeFilterOptions.everything, DatasetType.input ) @pytest.fixture def original_data_columnar_relevant(original_data): return compatibility_convert_row_columnar_dataset( - original_data, ComponentAttributeFilterOptions.RELEVANT, DatasetType.input + original_data, ComponentAttributeFilterOptions.relevant, DatasetType.input ) From 057b95df3d83cd60c7302411974b7d6db71403c9 Mon Sep 17 00:00:00 2001 From: Jerry Guo Date: Fri, 18 Oct 2024 10:32:03 +0200 Subject: [PATCH 2/2] fix coverage Signed-off-by: Jerry Guo --- src/power_grid_model/core/power_grid_dataset.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/power_grid_model/core/power_grid_dataset.py b/src/power_grid_model/core/power_grid_dataset.py index 48136e5cd..f689e533c 100644 --- a/src/power_grid_model/core/power_grid_dataset.py +++ b/src/power_grid_model/core/power_grid_dataset.py @@ -466,11 +466,6 @@ def _filter_attributes(self, attributes): for key in keys_to_remove: del attributes[key] - def _filter_with_option(self): - if self._data_filter is ComponentAttributeFilterOptions.relevant: - for attributes in self._data.values(): - self._filter_attributes(attributes) - def _filter_with_mapping(self): for component_type, attributes in self._data.items(): if component_type in self._data_filter: @@ -479,9 +474,7 @@ def _filter_with_mapping(self): self._filter_attributes(attributes) def _post_filtering(self): - if isinstance(self._data_filter, ComponentAttributeFilterOptions): - self._filter_with_option() - elif isinstance(self._data_filter, dict): + if isinstance(self._data_filter, dict): self._filter_with_mapping()