Skip to content

Commit

Permalink
Merge pull request #776 from PowerGridModel/featue/check-compatible-type
Browse files Browse the repository at this point in the history
Add dtype compatibility check
  • Loading branch information
mgovers authored Oct 14, 2024
2 parents 68d007b + 2737a55 commit f295da5
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/power_grid_model/core/buffer_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Power grid model buffer handler
"""

import warnings
from dataclasses import dataclass
from typing import cast

Expand Down Expand Up @@ -83,6 +84,8 @@ def _get_raw_data_view(data: np.ndarray, dtype: np.dtype) -> VoidPtr:
Returns:
a raw view on the data set.
"""
if data.dtype != dtype:
warnings.warn("Data type does not match schema. {VALIDATOR_MSG}", DeprecationWarning)
return np.ascontiguousarray(data, dtype=dtype).ctypes.data_as(VoidPtr)


Expand Down Expand Up @@ -115,6 +118,8 @@ def _get_raw_attribute_data_view(data: np.ndarray, schema: ComponentMetaData, at
Returns:
a raw view on the data set.
"""
if schema.dtype[attribute].shape == (3,) and data.shape[-1] != 3:
raise ValueError("Given data has a different schema than supported.")
return _get_raw_data_view(data, dtype=schema.dtype[attribute].base)


Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
#
# SPDX-License-Identifier: MPL-2.0

import warnings

import numpy as np
import pytest

from power_grid_model._utils import is_columnar
from power_grid_model.core.data_handling import create_output_data
from power_grid_model.core.dataset_definitions import ComponentType as CT, DatasetType as DT
from power_grid_model.core.power_grid_core import VoidPtr
from power_grid_model.core.power_grid_dataset import CMutableDataset
from power_grid_model.core.power_grid_meta import initialize_array


Expand Down Expand Up @@ -69,3 +73,35 @@ def test_create_output_data(output_component_types, is_batch, expected):
else:
assert actual[comp].keys() == expected[comp].keys()
assert all(actual[comp][attr].dtype == expected[comp][attr].dtype for attr in expected[comp])


def test_dtype_compatibility_check_normal():
nodes = initialize_array(DT.sym_output, CT.node, (1, 2))
nodes_ptr = nodes.ctypes.data_as(VoidPtr)

data = {CT.node: nodes}
mutable_dataset = CMutableDataset(data, DT.sym_output)
buffer_views = mutable_dataset.get_buffer_views()

assert buffer_views[0].data.value == nodes_ptr.value


def test_dtype_compatibility_check_compatible():
nodes = initialize_array(DT.sym_output, CT.node, 4)
nodes = nodes[::2]
nodes_ptr = nodes.ctypes.data_as(VoidPtr)

data = {CT.node: nodes}
with warnings.catch_warnings():
warnings.simplefilter("error")
mutable_dataset = CMutableDataset(data, DT.sym_output)
buffer_views = mutable_dataset.get_buffer_views()

assert buffer_views[0].data.value != nodes_ptr.value


def test_dtype_compatibility_check__error():
nodes = initialize_array(DT.sym_output, CT.node, (1, 2))
data = {CT.node: nodes.astype(nodes.dtype.newbyteorder("S"))}
with pytest.warns(DeprecationWarning):
CMutableDataset(data, DT.sym_output)

0 comments on commit f295da5

Please sign in to comment.