Commit e984fdc

make validation core functions private

Signed-off-by: Martijn Govers <Martijn.Govers@Alliander.com>

mgovers committed Nov 7, 2024
1 parent faec0f6 commit e984fdc

Showing 6 changed files with 280 additions and 282 deletions.
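This commit is a rename-only refactor: the helpers in power_grid_model.validation.utils gain a leading underscore, which marks them as internal by Python convention, while the public rule functions keep their names. A minimal sketch of the resulting import boundary (assumes the package is installed; picking all_valid_ids as the public example is arbitrary):

# Public validation API: names unchanged by this commit.
from power_grid_model.validation.rules import all_valid_ids

# Internal helper, now underscore-prefixed. Python does not block this import,
# but the underscore signals that the function may change or disappear without
# a deprecation cycle, so code outside the package should not rely on it.
from power_grid_model.validation.utils import _get_indexer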
38 changes: 19 additions & 19 deletions src/power_grid_model/validation/rules.py
@@ -70,12 +70,12 @@
     ValidationError,
 )
 from power_grid_model.validation.utils import (
-    eval_expression,
-    get_indexer,
-    get_mask,
-    get_valid_ids,
-    nan_type,
-    set_default_value,
+    _eval_expression,
+    _get_indexer,
+    _get_mask,
+    _get_valid_ids,
+    _nan_type,
+    _set_default_value,
 )
 
 Error = TypeVar("Error", bound=ValidationError)
@@ -370,17 +370,17 @@ def none_match_comparison(  # pylint: disable=too-many-arguments
         where the value in the field of interest matched the comparison.
     """
     if default_value_1 is not None:
-        set_default_value(data=data, component=component, field=field, default_value=default_value_1)
+        _set_default_value(data=data, component=component, field=field, default_value=default_value_1)
     if default_value_2 is not None:
-        set_default_value(data=data, component=component, field=field, default_value=default_value_2)
+        _set_default_value(data=data, component=component, field=field, default_value=default_value_2)
     component_data = data[component]
     if not isinstance(component_data, np.ndarray):
         raise NotImplementedError()  # TODO(mgovers): add support for columnar data
 
     if isinstance(ref_value, tuple):
-        ref = tuple(eval_expression(component_data, v) for v in ref_value)
+        ref = tuple(_eval_expression(component_data, v) for v in ref_value)
     else:
-        ref = (eval_expression(component_data, ref_value),)
+        ref = (_eval_expression(component_data, ref_value),)
     matches = compare_fn(component_data[field], *ref)
     if matches.any():
         if matches.ndim > 1:
@@ -520,7 +520,7 @@ def all_valid_enum_values(
     """
     enums: list[Type[Enum]] = enum if isinstance(enum, list) else [enum]
 
-    valid = {nan_type(component, field)}
+    valid = {_nan_type(component, field)}
     for enum_type in enums:
         valid.update(list(enum_type))
 
@@ -557,13 +557,13 @@ def all_valid_associated_enum_values(  # pylint: disable=too-many-positional-arguments
     """
     enums: list[Type[Enum]] = enum if isinstance(enum, list) else [enum]
 
-    valid_ids = get_valid_ids(data=data, ref_components=ref_components)
+    valid_ids = _get_valid_ids(data=data, ref_components=ref_components)
     mask = np.logical_and(
-        get_mask(data=data, component=component, field=field, **filters),
+        _get_mask(data=data, component=component, field=field, **filters),
         np.isin(data[component][ref_object_id_field], valid_ids),
     )
 
-    valid = {nan_type(component, field)}
+    valid = {_nan_type(component, field)}
     for enum_type in enums:
         valid.update(list(enum_type))
 
@@ -596,8 +596,8 @@ def all_valid_ids(
         A list containing zero or one InvalidIdError, listing all ids where the value in the field of interest
         was not a valid object identifier.
     """
-    valid_ids = get_valid_ids(data=data, ref_components=ref_components)
-    mask = get_mask(data=data, component=component, field=field, **filters)
+    valid_ids = _get_valid_ids(data=data, ref_components=ref_components)
+    mask = _get_mask(data=data, component=component, field=field, **filters)
 
     # Find any values that can't be found in the set of ids
     invalid = np.logical_and(mask, np.isin(data[component][field], valid_ids, invert=True))
@@ -779,7 +779,7 @@ def none_missing(
     for field in fields:
         if isinstance(field, list):
             field = field[0]
-        nan = nan_type(component, field)
+        nan = _nan_type(component, field)
         if np.isnan(nan):
             invalid = np.isnan(data[component][field][index])
         else:
@@ -939,14 +939,14 @@ def all_supported_tap_control_side(  # pylint: disable=too-many-arguments
         A list containing zero or more InvalidAssociatedEnumValueErrors; listing all the ids
         of components where the field of interest was invalid, given the referenced object's field.
     """
-    mask = get_mask(data=data, component=component, field=control_side_field, **filters)
+    mask = _get_mask(data=data, component=component, field=control_side_field, **filters)
     values = data[component][control_side_field][mask]
 
     invalid = np.zeros_like(mask)
 
     for ref_component, ref_field in tap_side_fields:
         if ref_component in data:
-            indices = get_indexer(data[ref_component]["id"], data[component][regulated_object_field], default_value=-1)
+            indices = _get_indexer(data[ref_component]["id"], data[component][regulated_object_field], default_value=-1)
             found = indices != -1
             ref_comp_values = data[ref_component][ref_field][indices[found]]
             invalid[found] = np.logical_or(invalid[found], values[found] == ref_comp_values)
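The rules.py hunks above all follow the same shape: build a mask of relevant records with _get_mask, gather reference values with a helper such as _get_valid_ids or _get_indexer, and report the ids of records that fail the check. A compressed, self-contained sketch of the all_valid_ids shape (hypothetical ids and references; the helper bodies are inlined as plain NumPy):

import numpy as np

line_ids = np.array([101, 102, 103])
line_from_node = np.array([1, 2, 9])  # references to node ids
valid_ids = np.array([1, 2, 3, 4])    # ids collected from the referenced components

# Report every record whose reference is not a known id.
invalid = np.isin(line_from_node, valid_ids, invert=True)
print(line_ids[invalid])  # [103] -- from_node 9 does not exist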
36 changes: 18 additions & 18 deletions src/power_grid_model/validation/utils.py
@@ -22,7 +22,7 @@
 from power_grid_model.validation.errors import ValidationError
 
 
-def eval_expression(data: np.ndarray, expression: int | float | str) -> np.ndarray:
+def _eval_expression(data: np.ndarray, expression: int | float | str) -> np.ndarray:
     """
     Wrapper function that checks the type of the 'expression'. If the expression is a string, it is assumed to be a
     field expression and the expression is validated. Otherwise it is assumed to be a numerical value and the value
@@ -42,11 +42,11 @@
     """
     if isinstance(expression, str):
-        return eval_field_expression(data, expression)
+        return _eval_field_expression(data, expression)
     return np.array(expression)
 
 
-def eval_field_expression(data: np.ndarray, expression: str) -> np.ndarray:
+def _eval_field_expression(data: np.ndarray, expression: str) -> np.ndarray:
     """
     A field expression can either be the name of a field (e.g. 'field_x') in the data, or a ratio between two fields
     (e.g. 'field_x / field_y'). The expression is checked on validity and then the fields are checked to be present in
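For context on the docstring above: a field expression is either a plain field name or a ratio of two fields. A self-contained sketch of both forms on a NumPy structured array (this mirrors the documented behavior, not the exact validation logic of _eval_field_expression):

import numpy as np

data = np.array(
    [(10.0, 2.0), (9.0, 3.0)],
    dtype=[("field_x", "f8"), ("field_y", "f8")],
)

# Form 1: a bare field name selects that column.
print(data["field_x"])  # [10.  9.]

# Form 2: 'field_x / field_y' evaluates to an element-wise ratio,
# computed with np.true_divide as in the hunk below.
print(np.true_divide(data["field_x"], data["field_y"]))  # [5. 3.]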
@@ -92,18 +92,18 @@ def eval_field_expression(data: np.ndarray, expression: str) -> np.ndarray:
     return np.true_divide(data[fields[0]], data[fields[1]])
 
 
-def update_input_data(input_data: SingleDataset, update_data: SingleDataset):
+def _update_input_data(input_data: SingleDataset, update_data: SingleDataset):
     """
     Update the input data using the available non-nan values in the update data.
     """
 
     merged_data = {component: array.copy() for component, array in input_data.items()}
     for component in update_data.keys():
-        update_component_data(component, merged_data[component], update_data[component])
+        _update_component_data(component, merged_data[component], update_data[component])
     return merged_data
 
 
-def update_component_data(
+def _update_component_data(
     component: ComponentTypeLike, input_data: SingleComponentData, update_data: SingleComponentData
 ) -> None:
     """
@@ -133,7 +133,7 @@ def _update_component_array_data(
     for field in update_data.dtype.names:
         if field == "id":
             continue
-        nan = nan_type(component, field, DatasetType.update)
+        nan = _nan_type(component, field, DatasetType.update)
         if np.isnan(nan):
             mask = ~np.isnan(update_data[field])
         else:
@@ -143,12 +143,12 @@
             for phase in range(mask.shape[1]):
                 # find indexers of to-be-updated object
                 sub_mask = mask[:, phase]
-                idx = get_indexer(input_data["id"], update_data_ids[sub_mask])
+                idx = _get_indexer(input_data["id"], update_data_ids[sub_mask])
                 # update
                 input_data[field][idx, phase] = update_data[field][sub_mask, phase]
         else:
             # find indexers of to-be-updated object
-            idx = get_indexer(input_data["id"], update_data_ids[mask])
+            idx = _get_indexer(input_data["id"], update_data_ids[mask])
             # update
             input_data[field][idx] = update_data[field][mask]
@@ -190,15 +190,15 @@ def errors_to_string(
     return msg
 
 
-def nan_type(component: ComponentTypeLike, field: str, data_type: DatasetType = DatasetType.input):
+def _nan_type(component: ComponentTypeLike, field: str, data_type: DatasetType = DatasetType.input):
     """
     Helper function to retrieve the nan value for a certain field as defined in the power_grid_meta_data.
     """
     component = _str_to_component_type(component)
     return power_grid_meta_data[data_type][component].nans[field]
 
 
-def get_indexer(source: np.ndarray, target: np.ndarray, default_value: Optional[int] = None) -> np.ndarray:
+def _get_indexer(source: np.ndarray, target: np.ndarray, default_value: Optional[int] = None) -> np.ndarray:
     """
     Given array of values from a source and a target dataset.
     Find the position of each value in the target dataset in the context of the source dataset.
@@ -209,7 +209,7 @@ def get_indexer(source: np.ndarray, target: np.ndarray, default_value: Optional[int] = None) -> np.ndarray:
     >>> input_ids = [1, 2, 3, 4, 5]
     >>> update_ids = [3]
-    >>> assert get_indexer(input_ids, update_ids) == np.array([2])
+    >>> assert _get_indexer(input_ids, update_ids) == np.array([2])
 
     Args:
         source: array of values in the source dataset
@@ -238,7 +238,7 @@ def get_indexer(source: np.ndarray, target: np.ndarray, default_value: Optional[int] = None) -> np.ndarray:
     return np.where(source[clipped_indices] == target, permutation_sort[clipped_indices], default_value)
 
 
-def set_default_value(
+def _set_default_value(
     data: SingleDataset, component: ComponentTypeLike, field: str, default_value: int | float | np.ndarray
 ):
     """
@@ -256,17 +256,17 @@
     Returns:
 
     """
-    if np.isnan(nan_type(component, field)):
+    if np.isnan(_nan_type(component, field)):
         mask = np.isnan(data[component][field])
     else:
-        mask = data[component][field] == nan_type(component, field)
+        mask = data[component][field] == _nan_type(component, field)
     if isinstance(default_value, np.ndarray):
         data[component][field][mask] = default_value[mask]
     else:
         data[component][field][mask] = default_value
 
 
-def get_valid_ids(data: SingleDataset, ref_components: ComponentTypeLike | list[ComponentTypeVar]) -> list[int]:
+def _get_valid_ids(data: SingleDataset, ref_components: ComponentTypeLike | list[ComponentTypeVar]) -> list[int]:
     """
     This function returns the valid IDs specified by all ref_components
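_set_default_value, renamed in the hunk above, fills unspecified entries before a comparison rule runs: positions holding the field's NaN sentinel receive the default. A reduced sketch of the float case (integer fields use a designated sentinel value and an equality test instead of np.isnan, per the if/else in the hunk):

import numpy as np

values = np.array([0.9, np.nan, 1.1, np.nan])

mask = np.isnan(values)  # float fields: the sentinel is NaN
values[mask] = 1.0       # scalar default; an array default would be applied as default[mask]

print(values)  # [0.9 1.  1.1 1. ]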
Expand All @@ -286,7 +286,7 @@ def get_valid_ids(data: SingleDataset, ref_components: ComponentTypeLike | list[
valid_ids = set()
for ref_component in ref_components:
if ref_component in data:
nan = nan_type(ref_component, "id")
nan = _nan_type(ref_component, "id")
if np.isnan(nan):
mask = ~np.isnan(data[ref_component]["id"])
else:
@@ -296,7 +296,7 @@ def get_valid_ids(data: SingleDataset, ref_components: ComponentTypeLike | list[ComponentTypeVar]) -> list[int]:
     return list(valid_ids)
 
 
-def get_mask(data: SingleDataset, component: ComponentTypeLike, field: str, **filters: Any) -> np.ndarray:
+def _get_mask(data: SingleDataset, component: ComponentTypeLike, field: str, **filters: Any) -> np.ndarray:
     """
     Get a mask based on the specified filters. E.g. measured_terminal_type=MeasuredTerminalType.source.
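The filters mentioned in the _get_mask docstring are keyword/value pairs that restrict a rule to matching records, e.g. measured_terminal_type=MeasuredTerminalType.source. A sketch of how such a mask can be assembled (illustrative field names and data; not the library's exact implementation):

import numpy as np

records = np.array(
    [(1, 0), (2, 1), (3, 0)],
    dtype=[("id", "i4"), ("measured_terminal_type", "i4")],
)
filters = {"measured_terminal_type": 0}

# AND together one equality test per filter.
mask = np.ones(len(records), dtype=bool)
for key, value in filters.items():
    mask &= records[key] == value

print(records["id"][mask])  # [1 3]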