Skip to content

Commit

Permalink
Fix one-to-many update function
Browse files Browse the repository at this point in the history
- add tests for it
- update to use indices propertly for one-to-many property updates
  • Loading branch information
e-lo committed Jun 11, 2024
1 parent 036efc7 commit bdef653
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 3 deletions.
11 changes: 9 additions & 2 deletions network_wrangler/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,13 @@ def _update_props_from_one_to_many(
Allows 1:many between source and destination relationship via `join_col`.
"""
destination_df.set_index(join_col, inplace=True)
source_df.set_index(join_col, inplace=True)

merged_df = destination_df.merge(
source_df[[join_col] + properties],
on=join_col,
source_df[properties],
left_index=True,
right_index=True,
how="left",
suffixes=("", "_new"),
)
Expand All @@ -175,6 +179,9 @@ def _update_props_from_one_to_many(
if len(update_idx) == 1:
update_vals = update_vals.values[0]
destination_df.loc[update_idx, prop] = update_vals

destination_df.reset_index(inplace=True)
source_df.reset_index(inplace=True)
return destination_df


Expand Down
21 changes: 20 additions & 1 deletion tests/test_transit/test_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from pathlib import Path

import pytest

from pandera.errors import SchemaErrors

from network_wrangler import load_transit, write_transit
from network_wrangler.transit.network import TransitNetwork
from network_wrangler import WranglerLogger
Expand Down Expand Up @@ -70,5 +74,20 @@ def test_write_feed_geo(request, small_transit_net, small_net, test_out_dir):
small_transit_net.feed,
ref_nodes_df=small_net.nodes_df,
out_dir=test_out_dir,
out_prefix="write_feed_geo",
out_prefix="write_feed_geo_small",
)
assert Path(test_out_dir / "write_feed_geo_small_trn_stops.geojson").exists()
assert Path(test_out_dir / "write_feed_geo_small_trn_shapes.geojson").exists()


def test_write_feed_geo_w_shapes(request, stpaul_transit_net, stpaul_net, test_out_dir):
from network_wrangler.transit.io import write_feed_geo

write_feed_geo(
stpaul_transit_net.feed,
ref_nodes_df=stpaul_net.nodes_df,
out_dir=test_out_dir,
out_prefix="write_feed_geo_stpaul",
)
assert Path(test_out_dir / "write_feed_geo_stpaul_trn_stops.geojson").exists()
assert Path(test_out_dir / "write_feed_geo_stpaul_trn_shapes.geojson").exists()
29 changes: 29 additions & 0 deletions tests/test_utils/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,32 @@ def test_segment_list_by_list(request, ref_list, item_list, expected_result):
else:
calc_answer = segment_data_by_selection(item_list, ref_list)
assert expected_result == calc_answer


def test_update_props_from_one_to_many():
# Create destination_df
from network_wrangler.utils.data import _update_props_from_one_to_many
destination_df = pd.DataFrame(
{
"trip_id": [2, 2, 3, 4],
"property1": [10, 20, 30, 40],
"property2": [100, 200, 300, 400],
}
)
# Create source_df
source_df = pd.DataFrame(
{"trip_id": [2, 3], "property1": [25, pd.NA], "property2": [None, 350]}
)
# Expected updated_df
expected_df = pd.DataFrame(
{
"trip_id": [2, 2, 3, 4],
"property1": [25, 25, 30, 40],
"property2": [100, 200, 350, 400],
}
)
# Call the function
updated_df = _update_props_from_one_to_many(
destination_df, source_df, "trip_id", ["property1", "property2"])
# Check if the updated_df matches the expected_df
pd.testing.assert_frame_equal(updated_df, expected_df)

0 comments on commit bdef653

Please sign in to comment.