Skip to content

Commit

Permalink
Merge pull request #779 from PowerGridModel/bug/convert-to-col-util
Browse files Browse the repository at this point in the history
Bug/Convert to/form row/col util function
  • Loading branch information
TonyXiang8787 authored Oct 11, 2024
2 parents e15588a + bf20758 commit ea6f61d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/power_grid_model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def compatibility_convert_row_columnar_dataset(
)

if is_sparse(data[comp_name]):
result_data[comp_name] = {"indptr": _extract_indptr(data), "data": converted_sub_data}
result_data[comp_name] = {"indptr": _extract_indptr(data[comp_name]), "data": converted_sub_data}
else:
result_data[comp_name] = converted_sub_data
return result_data
Expand Down
78 changes: 77 additions & 1 deletion tests/unit/test_internal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import pytest

from power_grid_model import initialize_array
from power_grid_model import ComponentType, DatasetType, initialize_array
from power_grid_model._utils import (
compatibility_convert_row_columnar_dataset,
convert_batch_dataset_to_batch_list,
Expand All @@ -17,6 +17,7 @@
get_batch_size,
get_dataset_type,
is_nan,
is_sparse,
process_data_filter,
split_dense_batch_data_in_batches,
split_sparse_batch_data_in_batches,
Expand Down Expand Up @@ -803,3 +804,78 @@ def test_get_dataset_type__conflicting_data():
else:
with pytest.raises(PowerGridError):
get_dataset_type(data=data)


@pytest.fixture
def row_dense_data():
source = initialize_array(DatasetType.update, ComponentType.source, (2, 3))
source["id"] = [[0, 2, 3], [0, 2, 3]]
source["u_ref"] = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

sym_load = initialize_array(DatasetType.update, ComponentType.sym_load, (2, 1))
sym_load["id"] = [[1], [1]]
sym_load["p_specified"] = [[100.0], [200.0]]

return {
ComponentType.source: source,
ComponentType.sym_load: sym_load,
}


@pytest.fixture
def row_sparse_data():
transformer = initialize_array(DatasetType.update, ComponentType.transformer, 1)
transformer["id"] = 1
transformer["tap_pos"] = 3

sym_gen = initialize_array(DatasetType.update, ComponentType.sym_gen, 8)
sym_gen["id"] = [5, 6, 7, 8, 5, 6, 7, 8]
sym_gen["q_specified"] = [1.1, 2.2, 3.3, 4.4, 4.4, 3.3, 2.2, 1.1]

return {
ComponentType.transformer: {
"data": transformer,
"indptr": np.array([0, 1, 1]),
},
ComponentType.sym_gen: {
"data": sym_gen,
"indptr": np.array([0, 4, 8]),
},
}


@pytest.fixture(params=["row_dense_data", "row_sparse_data"])
def row_data(request):
return request.getfixturevalue(request.param)


def compare_row_data(actual_row_data, desired_row_data):
assert actual_row_data.keys() == desired_row_data.keys()

for comp_name in actual_row_data.keys():
actual_component = actual_row_data[comp_name]
desired_component = desired_row_data[comp_name]
if is_sparse(actual_component):
assert actual_component.keys() == desired_component.keys()
assert np.array_equal(actual_component["indptr"], desired_component["indptr"])
actual_component = actual_component["data"]
desired_component = desired_component["data"]
assert actual_component.dtype == desired_component.dtype
assert actual_component.shape == desired_component.shape

for field in actual_component.dtype.names:
actual_attr = actual_component[field]
desired_attr = desired_component[field]
nan_a = np.isnan(actual_attr)
nan_b = np.isnan(desired_attr)
np.testing.assert_array_equal(nan_a, nan_b)
np.testing.assert_allclose(actual_attr[~nan_a], desired_attr[~nan_b], rtol=1e-15)


def test_dense_row_data_to_from_col_data(row_data):
# row data to columnar data and back
col_data = compatibility_convert_row_columnar_dataset(
row_data, ComponentAttributeFilterOptions.ALL, DatasetType.update
)
new_row_data = compatibility_convert_row_columnar_dataset(col_data, None, DatasetType.update)
compare_row_data(row_data, new_row_data)
2 changes: 1 addition & 1 deletion tests/unit/validation/test_batch_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_validate_batch_data(input_data, batch_data):


def test_validate_batch_data_input_error(input_data, batch_data):
if is_columnar(input_data):
if is_columnar(input_data["node"]):
input_data["node"]["id"][-1] = 123
input_data["line"]["id"][-1] = 123
else:
Expand Down

0 comments on commit ea6f61d

Please sign in to comment.