Skip to content

Commit

Permalink
Merge pull request #811 from PowerGridModel/feature/error-on-type-check
Browse files Browse the repository at this point in the history
Throw error on wrong dtype
  • Loading branch information
mgovers authored Nov 8, 2024
2 parents 2dfa39a + 5ba686b commit f9889a1
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 9 deletions.
3 changes: 1 addition & 2 deletions src/power_grid_model/core/buffer_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Power grid model buffer handler
"""

import warnings
from dataclasses import dataclass
from typing import cast

Expand Down Expand Up @@ -85,7 +84,7 @@ def _get_raw_data_view(data: np.ndarray, dtype: np.dtype) -> VoidPtr:
a raw view on the data set.
"""
if data.dtype != dtype:
warnings.warn(f"Data type does not match schema. {VALIDATOR_MSG}", DeprecationWarning)
raise ValueError(f"Data type does not match schema. {VALIDATOR_MSG}")
return np.ascontiguousarray(data, dtype=dtype).ctypes.data_as(VoidPtr)


Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,5 @@ def test_dtype_compatibility_check_compatible():
def test_dtype_compatibility_check__error():
nodes = initialize_array(DT.sym_output, CT.node, (1, 2))
data = {CT.node: nodes.astype(nodes.dtype.newbyteorder("S"))}
with pytest.warns(DeprecationWarning):
with pytest.raises(ValueError):
CMutableDataset(data, DT.sym_output)
36 changes: 30 additions & 6 deletions tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_const_dataset__conflicting_data():
with pytest.raises(PowerGridError):
CConstDataset(
data={
"node": np.zeros(1, dtype=power_grid_meta_data["input"]["node"]),
ComponentType.node: np.zeros(1, dtype=power_grid_meta_data["input"][ComponentType.node]),
"sym_load": np.zeros(1, dtype=power_grid_meta_data["update"]["sym_load"]),
}
)
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_const_dataset__sparse_batch_data(dataset_type):
components = {ComponentType.node: 3, ComponentType.sym_load: 2, ComponentType.asym_load: 4, ComponentType.link: 4}
data = {
ComponentType.node: {
"data": np.zeros(shape=3, dtype=power_grid_meta_data[dataset_type]["node"]),
"data": np.zeros(shape=3, dtype=power_grid_meta_data[dataset_type][ComponentType.node]),
"indptr": np.array([0, 2, 3, 3]),
},
ComponentType.sym_load: {
Expand Down Expand Up @@ -148,8 +148,8 @@ def test_const_dataset__sparse_batch_data(dataset_type):

def test_const_dataset__mixed_batch_size(dataset_type):
data = {
ComponentType.node: np.zeros(shape=(2, 3), dtype=power_grid_meta_data[dataset_type]["node"]),
ComponentType.line: np.zeros(shape=(3, 3), dtype=power_grid_meta_data[dataset_type]["line"]),
ComponentType.node: np.zeros(shape=(2, 3), dtype=power_grid_meta_data[dataset_type][ComponentType.node]),
ComponentType.line: np.zeros(shape=(3, 3), dtype=power_grid_meta_data[dataset_type][ComponentType.line]),
}
with pytest.raises(ValueError):
CConstDataset(data, dataset_type)
Expand All @@ -158,10 +158,34 @@ def test_const_dataset__mixed_batch_size(dataset_type):
@pytest.mark.parametrize("bad_indptr", (np.ndarray([0, 1]), np.ndarray([0, 3, 2]), np.ndarray([0, 1, 2, 3, 4])))
def test_const_dataset__bad_sparse_data(dataset_type, bad_indptr):
data = {
"node": {
"data": np.zeros(shape=2, dtype=power_grid_meta_data[dataset_type]["node"]),
ComponentType.node: {
"data": np.zeros(shape=2, dtype=power_grid_meta_data[dataset_type][ComponentType.node]),
"indptr": bad_indptr,
},
}
with pytest.raises(TypeError):
CConstDataset(data, dataset_type)


@pytest.mark.parametrize(
("dtype", "supported"),
[
(power_grid_meta_data[DatasetType.input][ComponentType.node].dtype["id"], True),
("<i4", True),
("<i8", False),
("<i1", False),
("<f8", False),
],
)
def test_const_dataset__different_dtype(dataset_type, dtype, supported):
data = {
ComponentType.node: {
"id": np.zeros(shape=3, dtype=dtype),
}
}
if supported:
result = CConstDataset(data, dataset_type)
assert result.get_info().total_elements() == {ComponentType.node: 3}
else:
with pytest.raises(ValueError):
CConstDataset(data, dataset_type)

0 comments on commit f9889a1

Please sign in to comment.