Debugged adding transit: passing test and ready for further testing.
- add default service_id
- update validation errors that are raised to TableValidationError
- add NodeNotFoundError
- debug node_coords()
e-lo committed Sep 9, 2024
1 parent fd6d352 commit 62c27f2
Showing 11 changed files with 70 additions and 43 deletions.
2 changes: 1 addition & 1 deletion network_wrangler/models/gtfs/tables.py
@@ -224,7 +224,7 @@ class TripsTable(pa.DataFrameModel):
     direction_id: Series[Category] = pa.Field(
         dtype_kwargs={"categories": DirectionID}, coerce=True, nullable=False, default=0
     )
-    service_id: Series[str] = pa.Field(nullable=False, coerce=True)
+    service_id: Series[str] = pa.Field(nullable=False, coerce=True, default="1")
     route_id: Series[str] = pa.Field(nullable=False, coerce=True)

     # Optional Fields
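
For context, a minimal sketch of what the new default means for callers (illustrative only: the column names are the required ones visible in this hunk, and it assumes pandera fills null values with a field's default during validation):

import pandas as pd
from network_wrangler.models.gtfs.tables import TripsTable

# Hypothetical trips frame in which service_id is present but unset.
trips = pd.DataFrame(
    {"trip_id": ["t1"], "route_id": ["r1"], "direction_id": [0], "service_id": [None]}
)
validated = TripsTable.validate(trips)
# Expected under this change: the null service_id comes back as the default "1".
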
4 changes: 2 additions & 2 deletions network_wrangler/models/projects/roadway_deletion.py
@@ -28,6 +28,6 @@ class RoadwayDeletion(RecordModel):
     def set_to_all_modes(cls, links: Optional[SelectLinksDict] = None):
         """Set the search mode to 'any' if not specified explicitly."""
         if links is not None:
-            if links.mode == DEFAULT_SEARCH_MODES:
-                links.mode = DEFAULT_DELETE_MODES
+            if links.modes == DEFAULT_SEARCH_MODES:
+                links.modes = DEFAULT_DELETE_MODES
         return links
8 changes: 4 additions & 4 deletions network_wrangler/models/roadway/converters.py
@@ -158,12 +158,12 @@ def _v1_to_v0_scoped_link_property(v1_row: Series, prop: str) -> dict:
     """
     v0_item_list = []
     for v1_item in v1_row[prop]:
-        v0_item = {"value": v1_item.value}
+        v0_item = {"value": v1_item["value"]}
         if "timespan" in v1_item:
             # time is a tuple of seconds from midnight from a tuple of "HH:MM"
-            v0_item["time"] = tuple([str_to_seconds_from_midnight(t) for t in v1_item.timespan])
-        if v1_item.category != DEFAULT_CATEGORY:
-            v0_item["category"] = [v1_item.category]
+            v0_item["time"] = tuple([str_to_seconds_from_midnight(t) for t in v1_item["timespan"]])
+        if "category" in v1_item and v1_item.get("category") != DEFAULT_CATEGORY:
+            v0_item["category"] = [v1_item["category"]]
         v0_item_list.append(v0_item)

     default_prop = prop[3:]
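
To illustrate the dict-style access this hunk switches to, here is a hypothetical v1 scoped item and the v0 shape the function above should produce for it (values are invented; 21600 and 32400 are 06:00 and 09:00 as seconds from midnight):

# Hypothetical v1 scoped-property entry, accessed as a dict per the change above.
v1_item = {"value": 2, "timespan": ["06:00", "09:00"], "category": "hov"}

# Expected v0 form after conversion:
# {"value": 2, "time": (21600, 32400), "category": ["hov"]}
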
14 changes: 7 additions & 7 deletions network_wrangler/roadway/io.py
@@ -189,26 +189,26 @@ def load_roadway_from_dir(

 def write_roadway(
     net: Union[RoadwayNetwork, ModelRoadwayNetwork],
-    convert_complex_link_properties_to_single_field: bool = False,
     out_dir: Union[Path, str] = ".",
     prefix: str = "",
     file_format: GeoFileTypes = "geojson",
+    convert_complex_link_properties_to_single_field: bool = False,
     overwrite: bool = True,
     true_shape: bool = False,
 ) -> None:
     """Writes a network in the roadway network standard.

     Args:
-        net: RoadwayNetwork or ModelRoadwayNetwork instance to write out
+        net: RoadwayNetwork or ModelRoadwayNetwork instance to write out.
+        out_dir: the path were the output will be saved. Defaults to ".".
+        prefix: the name prefix of the roadway files that will be generated.
+        file_format: the format of the output files. Defaults to "geojson".
         convert_complex_link_properties_to_single_field: if True, will convert complex link
             properties to a single column consistent with v0 format. This format is NOT valid
             with parquet and many other softwares. Defaults to False.
-        out_dir: the path were the output will be saved
-        prefix: the name prefix of the roadway files that will be generated
-        file_format: the format of the output files. Defaults to "geojson"
-        overwrite: if True, will overwrite the files if they already exist. Defaults to True
+        overwrite: if True, will overwrite the files if they already exist. Defaults to True.
         true_shape: if True, will write the true shape of the links as found from shapes.
-            Defaults to False
+            Defaults to False.
     """
     out_dir = Path(out_dir)
     if not out_dir.is_dir():
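
A minimal usage sketch of the reordered signature; callers that pass everything by keyword are unaffected. The directory path is invented, and it assumes load_roadway_from_dir accepts the directory holding the network files, as its name suggests:

from network_wrangler.roadway.io import load_roadway_from_dir, write_roadway

# Hypothetical input directory; load_roadway_from_dir's exact signature is assumed here.
net = load_roadway_from_dir("examples/small")
write_roadway(
    net,
    out_dir="out",
    prefix="small",
    file_format="geojson",
    overwrite=True,
)
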
6 changes: 6 additions & 0 deletions network_wrangler/roadway/links/links.py
@@ -19,6 +19,12 @@ class NotLinksError(Exception):
     pass


+class LinkNotFoundError(Exception):
+    """Raised when a link is not found in the links table."""
+
+    pass
+
+
 def node_ids_in_links(
     links_df: DataFrame[RoadLinksTable], nodes_df: Optional[DataFrame[RoadNodesTable]] = None
 ) -> list[int]:
10 changes: 7 additions & 3 deletions network_wrangler/roadway/network.py
@@ -59,7 +59,7 @@
 from .links.filters import filter_links_to_ids, filter_links_to_node_ids
 from .links.delete import delete_links_by_ids
 from .links.edit import edit_link_geometry_from_nodes
-from .nodes.nodes import node_ids_without_links
+from .nodes.nodes import node_ids_without_links, NodeNotFoundError
 from .nodes.filters import filter_nodes_to_links
 from .nodes.delete import delete_nodes_by_ids
 from .nodes.edit import edit_node_geometry
@@ -384,8 +384,12 @@ def nodes_in_links(self) -> DataFrame[RoadNodesTable]:

     def node_coords(self, model_node_id: int) -> tuple:
         """Return coordinates (x, y) of a node based on model_node_id."""
-        node = self.nodes_df[self.nodes_df.model_node_id == model_node_id]
-        return node.geometry.x[0], node.geometry.y[0]
+        try:
+            node = self.nodes_df[self.nodes_df.model_node_id == model_node_id]
+        except ValueError:
+            WranglerLogger.error(f"Node with model_node_id {model_node_id} not found.")
+            raise NodeNotFoundError(f"Node with model_node_id {model_node_id} not found.")
+        return node.geometry.x.values[0], node.geometry.y.values[0]

     def add_links(self, add_links_df: Union[pd.DataFrame, DataFrame[RoadLinksTable]]):
         """Validate combined links_df with LinksSchema before adding to self.links_df.
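
A short sketch of how a caller might use the new error type (the node ID is made up, and net is assumed to be an already-loaded RoadwayNetwork):

from network_wrangler.roadway.nodes.nodes import NodeNotFoundError

# `net` is assumed to be a RoadwayNetwork instance; 12345 is a hypothetical model_node_id.
try:
    x, y = net.node_coords(12345)
except NodeNotFoundError:
    # Handle the missing node explicitly instead of failing on a bare indexing error.
    x, y = None, None
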
6 changes: 6 additions & 0 deletions network_wrangler/roadway/nodes/nodes.py
@@ -12,6 +12,12 @@ class NotNodesError(Exception):
     pass


+class NodeNotFoundError(Exception):
+    """Raised when a node is not found in the nodes table."""
+
+    pass
+
+
 def node_ids_without_links(
     nodes_df: DataFrame[RoadNodesTable], links_df: DataFrame[RoadLinksTable]
 ) -> list[int]:
28 changes: 19 additions & 9 deletions network_wrangler/transit/projects/add_route.py
@@ -10,9 +10,10 @@
 import pandas as pd
 from pandera.typing import DataFrame as paDataFrame

-from ...utils.time import str_to_time_list, TimeString
+from ...utils.time import str_to_time_list
 from ...utils.utils import fill_str_ids
 from ...utils.models import fill_df_with_defaults_from_model
+from ...models._base.types import TimeString
 from ...models.gtfs.tables import (
     TripsTable,
     WranglerShapesTable,
@@ -85,14 +86,16 @@ def _add_route_to_feed(
         Feed: transit feed.
     """
     WranglerLogger.debug(f"Adding route {len(add_routes)} to feed.")
-    add_routes_df = pd.DataFrame({k: v for r in add_routes for k, v in r.items() if k != "trips"})
+    add_routes_df = pd.DataFrame(
+        [{k: v for k, v in r.items() if k != "trips"} for r in add_routes]
+    )
     routes_df = pd.concat([feed.routes, add_routes_df], ignore_index=True, sort=False)

     for route in add_routes:
         WranglerLogger.debug(
             f"Adding {len(route['trips'])} trips for route {route['route_id']} to feed."
         )
-        add_trips_df = _create_new_trips(route, feed.shapes)
+        add_trips_df = _create_new_trips(route["trips"], route["route_id"], feed.shapes)
         trips_df = pd.concat([feed.trips, add_trips_df], ignore_index=True, sort=False)

         for i, trip in enumerate(route["trips"]):
@@ -117,18 +120,18 @@
             [feed.stop_times, add_stop_times_df], ignore_index=True, sort=False
         )

-    feed.stops = stops_df
     feed.routes = routes_df
     feed.shapes = shapes_df
     feed.trips = trips_df
     feed.stop_times = stop_times_df
+    feed.stops = stops_df
     feed.frequencies = frequencies_df

     return feed


 def _create_new_trips(
Check failure on line 133 in network_wrangler/transit/projects/add_route.py (GitHub Actions / build (3.8), Ruff D417): Missing argument description in the docstring for `_create_new_trips`: `route_id`
-    trips: list[dict], shapes_df: paDataFrame[WranglerShapesTable]
+    trips: list[dict], route_id: str, shapes_df: paDataFrame[WranglerShapesTable]
 ) -> paDataFrame[TripsTable]:
     """Create new trips for a route.

@@ -137,8 +140,15 @@ def _create_new_trips(
         shapes_df: Shapes dataframe to get shape_id from.
     """
     FILTER_OUT = ["routing", "headway_secs"]
-    add_trips_df = pd.DataFrame({k: v for r in trips for k, v in r.items() if k not in FILTER_OUT})
-    add_trips_df = fill_df_with_defaults_from_model(add_trips_df, TripsTable)
+    add_trips_df = pd.DataFrame(
+        [{k: v for k, v in r.items() if k not in FILTER_OUT} for r in trips]
+    )
+    add_trips_df["route_id"] = route_id
+    if "shape_id" not in add_trips_df.columns:
+        add_trips_df["shape_id"] = None
+
+    if "trip_id" not in add_trips_df.columns:
+        add_trips_df["trip_id"] = None
     add_trips_df["shape_id"] = fill_str_ids(add_trips_df["shape_id"], shapes_df["shape_id"])
     add_trips_df["trip_id"] = add_trips_df["trip_id"].fillna(
         add_trips_df["shape_id"].apply(lambda x: f"tr_shp{x}")
@@ -201,14 +211,14 @@ def _get_stops_from_routing(routing: list[Union[dict, int]]) -> list[dict]:
         if isinstance(i, dict):
             stop_d = {}
             stop_info = list(i.values())[0]  # dict with stop, board, alight
-            if stop_info.get("stop", False):
+            if not stop_info.get("stop", False):
                 continue
             stop_d["stop_id"] = int(list(i.keys())[0])
             # Default for board and alight is True unless specified to be False
             stop_d["pickup_type"] = 0 if stop_info.get("board", True) else 1
             stop_d["drop_off_type"] = 0 if stop_info.get("alight", True) else 1
             stop_d.update({k: v for k, v in stop_info.items() if k not in FILTER_OUT})
-            stop_dicts.append(stop_info)
+            stop_dicts.append(stop_d)
     return stop_dicts

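
To make the board/alight handling in _get_stops_from_routing concrete, here is a hypothetical routing entry and the core fields the corrected logic should produce for it (GTFS convention: 0 means allowed, 1 means not allowed):

# Hypothetical routing entry: node 4 is a stop where boarding is not allowed.
routing_entry = {4: {"stop": True, "board": False}}

# Expected core fields from the corrected logic above:
# stop_id=4, pickup_type=1 (no boarding), drop_off_type=0 (alighting allowed)
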
6 changes: 3 additions & 3 deletions network_wrangler/utils/utils.py
@@ -253,12 +253,12 @@ def fill_str_ids(
         str_prefix (str, optional): Prefix to add to the new ID. Defaults to "".
         str_suffix (str, optional): Suffix to add to the new ID. Defaults to "".
     """
-    if taken_ids_s.iloc[0] != str:
+    if not isinstance(taken_ids_s.iloc[0], str):
         raise ValueError("taken_ids_s must be a series of strings.")

     n_ids = id_s.isna().sum()
     start_id = _get_max_int_id_within_string_ids(taken_ids_s, str_prefix, str_suffix) + 1
-    new_ids = [f"{str_prefix}i{str_suffix}" for i in range(start_id, start_id + n_ids)]
+    new_ids = [f"{str_prefix}{i}{str_suffix}" for i in range(start_id, start_id + n_ids)]
     id_s.loc[id_s.isna()] = new_ids
     return id_s

@@ -270,7 +270,7 @@ def fill_int_ids(id_s: pd.Series, taken_ids_s: pd.Series) -> pd.Series:
         id_s (pd.Series): Series of IDs to fill.
         taken_ids_s (pd.Series): Series of IDs that are already taken.
     """
-    if taken_ids_s.iloc[0] != int:
+    if not isinstance(taken_ids_s.iloc[0], int):
         raise ValueError("id_s must be a series of integers.")
     n_ids = id_s.isna().sum()
     start_id = max(set(taken_ids_s.dropna())) + 1
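
A small sketch of the corrected ID generation (values are invented; it assumes _get_max_int_id_within_string_ids returns the largest integer embedded in the taken IDs, here 2):

import pandas as pd
from network_wrangler.utils.utils import fill_str_ids

shape_ids = pd.Series(["shp1", None, None])  # IDs to fill
taken = pd.Series(["shp1", "shp2"])          # IDs already in use
filled = fill_str_ids(shape_ids, taken, str_prefix="shp")
# With the fixed f-string the expected result is ["shp1", "shp3", "shp4"],
# rather than the literal "shpi" repeated for every missing value.
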
4 changes: 2 additions & 2 deletions tests/test_models/test_db.py
@@ -3,7 +3,7 @@
 import pandas as pd
 import pytest
 from pandera import DataFrameModel
-from pandera.errors import SchemaErrors
+from network_wrangler.utils.models import TableValidationError
 from network_wrangler.models._base.db import DBModelMixin, ForeignKeyValueError


@@ -41,5 +41,5 @@ def test_validate_db_table():
     with pytest.raises(ForeignKeyValueError):
         db.table_b = pd.DataFrame({"B_ID": [4, 5, 6], "a_value": [3, 4, 5]})

-    with pytest.raises(SchemaErrors):
+    with pytest.raises(TableValidationError):
         db.table_a = pd.DataFrame({"B_ID": ["hi", "there", "buddy"], "a_value": [3, 4, 5]})
25 changes: 13 additions & 12 deletions tests/test_transit/test_changes/test_transit_add_route.py
@@ -54,17 +54,18 @@ def test_add_route_to_feed_dict(
 ):
     WranglerLogger.info(f"--Starting: {request.node.name}")
     small_transit_net = copy.deepcopy(small_transit_net)
-    updated_feed = apply_transit_route_addition(
+    updated_transit_net = apply_transit_route_addition(
         small_transit_net, add_route_change["transit_route_addition"], small_net
     )
+    updated_feed = updated_transit_net.feed

     # check trips
-    new_trips = updated_feed.trips.trip_id.loc[updated_feed.trips.route_id.isin("abc")]
+    new_trips = updated_feed.trips.loc[updated_feed.trips.route_id.isin(["abc"])]
     new_trip_ids = new_trips.trip_id.to_list()
     assert len(new_trip_ids) == 1

     # check routes
-    new_routes = updated_feed.routes.loc[updated_feed.routes.route_id.isin("abc")]
+    new_routes = updated_feed.routes.loc[updated_feed.routes.route_id.isin(["abc"])]
     assert len(new_routes) == 1

     assert new_routes.route_long_name.loc[new_routes.route_long_name == "green_line"].all()
@@ -82,27 +83,27 @@ def test_add_route_to_feed_dict(
         updated_feed.stop_times.trip_id.isin(new_trip_ids)
     ]
     assert new_stop_times.stop_id.isin(new_stops).all()
-    assert new_stop_times.at[new_stop_times.stop_id == 4, "drop_off_type"] == 1
-    assert new_stop_times.at[new_stop_times.stop_id == 4, "pickup_type"] == 0
-    assert new_stop_times.at[new_stop_times.stop_id == 6, "drop_off_type"] == 0
-    assert new_stop_times.at[new_stop_times.stop_id == 6, "pickup_type"] == 0
-    assert new_stop_times.at[new_stop_times.stop_id == 1, "drop_off_type"] == 0
-    assert new_stop_times.at[new_stop_times.stop_id == 1, "pickup_type"] == 0
+    assert new_stop_times.loc[new_stop_times.stop_id == 4, "drop_off_type"].iloc[0] == 1
+    assert new_stop_times.loc[new_stop_times.stop_id == 4, "pickup_type"].iloc[0] == 0
+    assert new_stop_times.loc[new_stop_times.stop_id == 6, "drop_off_type"].iloc[0] == 0
+    assert new_stop_times.loc[new_stop_times.stop_id == 6, "pickup_type"].iloc[0] == 0
+    assert new_stop_times.loc[new_stop_times.stop_id == 1, "drop_off_type"].iloc[0] == 0
+    assert new_stop_times.loc[new_stop_times.stop_id == 1, "pickup_type"].iloc[0] == 0

     # check shapes
     new_shape_ids = new_trips.shape_id.unique()
     new_shapes = updated_feed.shapes.loc[updated_feed.shapes.shape_id.isin(new_shape_ids)]
     assert len(new_shapes) == 6
     expected_shape_modeL_node_ids = [1, 2, 3, 4, 5, 6]
-    new_shapes.shape_modeL_node_id.isin(expected_shape_modeL_node_ids).all()
+    assert new_shapes.shape_model_node_id.isin(expected_shape_modeL_node_ids).all()

     # check frequencies
     new_frequencies = updated_feed.frequencies.loc[
         updated_feed.frequencies.trip_id.isin(new_trip_ids)
     ]
     assert len(new_frequencies) == 2
-    assert new_frequencies.headway_secs.isin([600, 900])
-    assert new_frequencies.start_time.isin([str_to_time("6:00"), str_to_time("12:00")])
+    assert new_frequencies.headway_secs.isin([600, 900]).all()
+    assert new_frequencies.start_time.isin([str_to_time("6:00"), str_to_time("12:00")]).all()
     WranglerLogger.info(f"--Finished: {request.node.name}")
