Skip to content

Commit

Permalink
Implement functionality for project card to dictate how to handle sco…
Browse files Browse the repository at this point in the history
…ped conflicts

Fixes #373
  • Loading branch information
e-lo committed Sep 9, 2024
1 parent b7b3b91 commit 6e4f781
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 79 deletions.
2 changes: 1 addition & 1 deletion network_wrangler/models/_base/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
pa.String,
pa.Check.str_matches(r"^(?:[0-9]|[01]\d|2[0-3]):[0-5]\d(?::[0-5]\d)?$|^24:00(?::00)?$"),
coerce=True,
name=None # Name is set to None to ignore the Series name
name=None, # Name is set to None to ignore the Series name
)
19 changes: 6 additions & 13 deletions network_wrangler/models/gtfs/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,10 @@ class FrequenciesTable(pa.DataFrameModel):

trip_id: Series[str] = pa.Field(nullable=False, coerce=True)
start_time: Series[TimeString] = pa.Field(
nullable=False,
coerce=True,
default=DEFAULT_TIMESPAN[0]
nullable=False, coerce=True, default=DEFAULT_TIMESPAN[0]
)
end_time: Series[TimeString] = pa.Field(
nullable=False,
coerce=True,
default=DEFAULT_TIMESPAN[1]
nullable=False, coerce=True, default=DEFAULT_TIMESPAN[1]
)
headway_secs: Series[int] = pa.Field(
coerce=True,
Expand All @@ -306,16 +302,13 @@ class Config:

class WranglerFrequenciesTable(FrequenciesTable):
"""Wrangler flavor of GTFS FrequenciesTable."""

projects: Series[str] = pa.Field(coerce=True, default="")
start_time: Series = pa.Field(
nullable=False,
coerce=True,
default=str_to_time(DEFAULT_TIMESPAN[0])
nullable=False, coerce=True, default=str_to_time(DEFAULT_TIMESPAN[0])
)
end_time: Series = pa.Field(
nullable=False,
coerce=True,
default=str_to_time(DEFAULT_TIMESPAN[1])
nullable=False, coerce=True, default=str_to_time(DEFAULT_TIMESPAN[1])
)

class Config:
Expand Down Expand Up @@ -422,4 +415,4 @@ class Config:
"trip_id": ["trips", "trip_id"],
"stop_id": ["stops", "stop_id"],
}
unique = ["trip_id", "stop_sequence"]
unique = ["trip_id", "stop_sequence"]
5 changes: 4 additions & 1 deletion network_wrangler/models/projects/roadway_property_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import itertools

from typing import Optional, ClassVar, Any, Union
from typing import Optional, ClassVar, Any, Union, Literal
from datetime import datetime

from pandera import DataFrameModel, Field
Expand Down Expand Up @@ -44,6 +44,7 @@ class IndivScopedPropertySetItem(BaseModel):
timespan: Optional[TimespanString] = DEFAULT_TIMESPAN
set: Optional[Any] = None
existing: Optional[Any] = None
overwrite_conflicts: Optional[bool] = False
change: Optional[Union[int, float]] = None
_examples = [
{"category": "hov3", "timespan": ["6:00", "9:00"], "set": 2.0},
Expand Down Expand Up @@ -97,6 +98,7 @@ class GroupedScopedPropertySetItem(BaseModel):
categories: Optional[list[Any]] = []
timespans: Optional[list[TimespanString]] = []
set: Optional[Any] = None
overwrite_conflicts: Optional[bool] = False
existing: Optional[Any] = None
change: Optional[Union[int, float]] = None
_examples = [
Expand Down Expand Up @@ -227,6 +229,7 @@ class RoadPropertyChange(RecordModel):
change: Optional[Union[int, float]] = None
set: Optional[Any] = None
scoped: Optional[Union[None, ScopedPropertySetList]] = None
overwrite_scoped: Optional[Literal["conflicting", "all", False]] = False

require_one_of: ClassVar[OneOf] = [["change", "set"]]

Expand Down
40 changes: 11 additions & 29 deletions network_wrangler/roadway/links/edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import copy

from typing import Union, Any, Optional
from typing import Union, Any, Optional, Literal

import numpy as np

Expand Down Expand Up @@ -80,15 +80,15 @@ def _initialize_links_as_managed_lanes(
def _resolve_conflicting_scopes(
scoped_values: list[ScopedLinkValueItem],
scoped_item: IndivScopedPropertySetItem,
delete_conflicting: bool = True,
overwrite_conflicting: bool = True,
) -> list[ScopedLinkValueItem]:
conflicting_existing = _filter_to_conflicting_scopes(
scoped_values,
timespan=scoped_item.timespan,
category=scoped_item.category,
)
if conflicting_existing:
if delete_conflicting:
if overwrite_conflicting:
return [i for i in scoped_values not in conflicting_existing]
else:
WranglerLogger.error(
Expand Down Expand Up @@ -145,29 +145,26 @@ def _edit_scoped_link_property(
scoped_prop_value_list: Union[None, list[ScopedLinkValueItem]],
scoped_prop_set: ScopedPropertySetList,
default_value: Any = None,
overwrite_all: bool = False,
overwrite_conflicting: bool = True,
overwrite_scoped: Optional[Literal["conflicting", "all", False]] = False,
) -> list[ScopedLinkValueItem]:
"""Edit scoped property on a single link.
Args:
scoped_prop_value_list: List of scoped property values for the link.
scoped_prop_set: ScopedPropertySetList of changes to make.
default_value: Default value for the property if not set.
overwrite_all: If True, will overwrite all scoped values for link property with
scoped_prop_value_set. Defaults to False.
overwrite_conflicting: If True will overwrite any conflicting scopes. Otherwise, will
raise an Exception on conflicting, but not matching, scopes.
overwrite_scoped: If 'all', will overwrite all scoped properties. If 'conflicting',
will overwrite conflicting scoped properties. If False, will raise an error on
conflicting scoped properties. Defaults to False.
"""
msg = f"Setting scoped link property.\n\
- Current value:{scoped_prop_value_list}\n\
- Set value: {scoped_prop_set}\n\
- Default value: {default_value}\n\
- Overwrite all? {overwrite_all}\n\
- Overwrite conflicting? {overwrite_conflicting}"
- Overwrite scoped: {overwrite_scoped}"
# WranglerLogger.debug(msg)
# If None, or asked to overwrite all scopes, and return all set items
if overwrite_all or not scoped_prop_value_list:
if overwrite_scoped == "all" or not scoped_prop_value_list:
scoped_prop_value_list = [
_update_property_for_scope(i, default_value) for i in scoped_prop_set
]
Expand All @@ -184,7 +181,7 @@ def _edit_scoped_link_property(
updated_scoped_prop_value_list = _resolve_conflicting_scopes(
updated_scoped_prop_value_list,
scoped_prop_set,
delete_conflicting=overwrite_conflicting,
overwrite_conflicting=overwrite_scoped == "conflicting",
)

# find matching scopes
Expand Down Expand Up @@ -246,8 +243,6 @@ def _edit_link_property(
prop_name: str,
prop_change: RoadPropertyChange,
existing_value_conflict_error: bool = False,
overwrite_all_scoped: bool = False,
overwrite_conflicting_scoped: bool = True,
ml_link_offset_meters: float = LINK_ML_OFFSET_METERS,
project_name: Optional[str] = None,
) -> DataFrame[RoadLinksTable]:
Expand All @@ -267,10 +262,6 @@ def _edit_link_property(
existing_value_conflict_error: If True, will trigger an error if the existing
specified value in the project card doesn't match the value in links_df.
Otherwise, will only trigger a warning. Defaults to False.
overwrite_all_scoped: If True, will overwrite all scoped values for link property with
scoped_prop_value_set. Defaults to False.
overwrite_conflicting_scoped: If True will overwrite any conflicting scopes.
Otherwise, will raise an Exception on conflicting, but not matching, scopes.
ml_link_offset_meters: Offset in meters for managed lane geometry. If not set, will use
LINK_ML_OFFSET_METERS from params.py.
project_name: optional name of the project to be applied
Expand Down Expand Up @@ -335,8 +326,7 @@ def _edit_link_property(
links_df.at[idx, sc_prop_name],
prop_change.scoped,
links_df.at[idx, prop_name],
overwrite_all=overwrite_all_scoped,
overwrite_conflicting=overwrite_conflicting_scoped,
overwrite_scoped=prop_change.overwrite_scoped,
)
msg = f"idx:\n {idx}\n\
type: \n {type(links_df.at[idx, sc_prop_name])}\n\
Expand Down Expand Up @@ -368,12 +358,6 @@ def edit_link_property(
existing_value_conflict_error: If True, will trigger an error if the existing
specified value in the project card doesn't match the value in links_df.
Otherwise, will only trigger a warning. Defaults to False.
overwrite_all_scoped: If True, will overwrite all scoped values for link property with
scoped_prop_value_set. Defaults to False.
overwrite_conflicting_scoped: If True will overwrite any conflicting scopes.
Otherwise, will raise an Exception on conflicting, but not matching, scopes
Defaults to True.
"""
WranglerLogger.info(f"Editing Link Property {prop_name} for {len(link_idx)} links.")
WranglerLogger.debug(f"prop_dict: /n{prop_dict}")
Expand All @@ -386,8 +370,6 @@ def edit_link_property(
prop_name,
prop_change,
existing_value_conflict_error,
overwrite_all_scoped=prop_change.overwrite_all_scoped,
overwrite_conflicting_scoped=prop_change.overwrite_conflicting_scoped,
)
links_df = validate_df_to_model(links_df, RoadLinksTable)
WranglerLogger.debug(
Expand Down
4 changes: 1 addition & 3 deletions network_wrangler/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,4 @@ def overlaps(self, other: Timespan) -> bool:
Returns:
bool: True if the two timespans overlap, False otherwise.
"""
return (
self.start_time <= other.end_time and self.end_time >= other.start_time
)
return self.start_time <= other.end_time and self.end_time >= other.start_time
20 changes: 13 additions & 7 deletions network_wrangler/transit/projects/edit_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ def apply_transit_property_change(
WranglerLogger.debug("Applying transit property change project.")
for property, property_change in property_changes.items():
net = _apply_transit_property_change_to_table(
net, selection, property, property_change, project_name=project_name,
existing_value_conflict_error=existing_value_conflict_error
net,
selection,
property,
property_change,
project_name=project_name,
existing_value_conflict_error=existing_value_conflict_error,
)
return net

Expand All @@ -76,7 +80,8 @@ def _check_existing_value(existing_s: Series, expected_existing_val) -> bool:
f"Existing do not all match expected value of {expected_existing_val}."
)
WranglerLogger.debug(
f"Conflicting values values: {existing_s[existing_s != expected_existing_val]}")
f"Conflicting values values: {existing_s[existing_s != expected_existing_val]}"
)
return False
return True

Expand Down Expand Up @@ -104,7 +109,8 @@ def _apply_transit_property_change_to_table(

if "existing" in property_change:
existing_ok = _check_existing_value(
table_df.loc[update_idx, property], property_change["existing"])
table_df.loc[update_idx, property], property_change["existing"]
)
if not existing_ok:
WranglerLogger.warning(f"Existing {property} != {property_change['existing']}.")
if existing_value_conflict_error:
Expand All @@ -116,9 +122,9 @@ def _apply_transit_property_change_to_table(
if "set" in property_change:
set_df.loc[update_idx, property] = property_change["set"]
elif "change" in property_change:
set_df.loc[update_idx, property] = \
set_df.loc[update_idx, property] \
+ property_change["change"]
set_df.loc[update_idx, property] = (
set_df.loc[update_idx, property] + property_change["change"]
)
else:
raise ValueError("Property change must include 'set' or 'change'.")

Expand Down
19 changes: 10 additions & 9 deletions network_wrangler/utils/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ def str_to_time(time_str: TimeString, base_date: Optional[datetime.date] = None)

@validate_call(config=dict(arbitrary_types_allowed=True))
def str_to_time_series(
time_str_s: pd.Series,
base_date: Optional[Union[pd.Series, datetime.date]] = None
time_str_s: pd.Series, base_date: Optional[Union[pd.Series, datetime.date]] = None
) -> pd.Series:
"""Convert panda series of TimeString (HH:MM<:SS>) to datetime object.
Expand All @@ -71,7 +70,7 @@ def str_to_time_series(
Args:
time_str_s: Pandas Series of TimeStrings in HH:MM:SS or HH:MM format.
base_date: optional date to base the datetime on. Defaults to None.
If not provided, will use today. Can be either a single instance or a series of
If not provided, will use today. Can be either a single instance or a series of
same length as time_str_s
"""
TimeStrSeriesSchema.validate(time_str_s)
Expand Down Expand Up @@ -100,16 +99,18 @@ def str_to_time_series(
hours = hours % 24

# Combine the base date with the adjusted time and add the extra days if needed
combined_datetimes = pd.to_datetime(base_dates)\
+ pd.to_timedelta(days_to_add, unit='d')\
+ pd.to_timedelta(hours, unit='h')\
+ pd.to_timedelta(minutes, unit='m')\
+ pd.to_timedelta(seconds, unit='s')
combined_datetimes = (
pd.to_datetime(base_dates)
+ pd.to_timedelta(days_to_add, unit="d")
+ pd.to_timedelta(hours, unit="h")
+ pd.to_timedelta(minutes, unit="m")
+ pd.to_timedelta(seconds, unit="s")
)

# Combine the results back into the original series
result = time_str_s.copy()
result[is_string] = combined_datetimes
result = result.astype('datetime64[ns]')
result = result.astype("datetime64[ns]")
return result


Expand Down
8 changes: 5 additions & 3 deletions tests/test_transit/test_changes/test_transit_prop_changes.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,18 @@ def test_transit_property_change(request, small_transit_net):
# Filter the DataFrame correctly
WranglerLogger.debug(f"Result Frequencies: \n{net.feed.frequencies}")
target_df = net.feed.frequencies.loc[
(net.feed.frequencies.trip_id.isin(trip_ids)) &
(net.feed.frequencies.start_time.dt.strftime('%H:%M') == timespan[0])
(net.feed.frequencies.trip_id.isin(trip_ids))
& (net.feed.frequencies.start_time.dt.strftime("%H:%M") == timespan[0])
]

if not (target_df["headway_secs"] == new_headway).all():
WranglerLogger.error("Headway not changed as expected:")
WranglerLogger.debug(f"Targeted Frequencies: \n{target_df}")
assert False

unchanged_result_df = net.feed.frequencies.loc[~net.feed.frequencies.index.isin(target_df.index)]
unchanged_result_df = net.feed.frequencies.loc[
~net.feed.frequencies.index.isin(target_df.index)
]
unchanged_og_df_df = og_df.loc[~og_df.index.isin(target_df.index)]

try:
Expand Down
29 changes: 16 additions & 13 deletions tests/test_utils/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import pandas as pd

from network_wrangler.utils.time import (
str_to_time, timespans_overlap, filter_df_to_overlapping_timespans
str_to_time,
timespans_overlap,
filter_df_to_overlapping_timespans,
)
from network_wrangler.logger import WranglerLogger

Expand Down Expand Up @@ -62,18 +64,19 @@ def test_timespans_overlaps(case):
([["1:30:15", "19:30:15"]], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
]

overlap_df = pd.DataFrame([
[1, "14:30:15", "15:30:15"],
[2, "14:30:17", "16:30:15"],
[3, "14:00:15", "14:30:15"],
[4, "12:00:00", "13:00:00"],
[5, "12:00:00", "14:00:00"],
[6, "8:00:00", "10:00:00"],
[7, "7:00:00", "8:00:00"],
[8, "9:45:00", "11:00:00"],
[9, "4:00:00", "7:11:00"],
[10, "13:00:00", "14:00:00"],
],
overlap_df = pd.DataFrame(
[
[1, "14:30:15", "15:30:15"],
[2, "14:30:17", "16:30:15"],
[3, "14:00:15", "14:30:15"],
[4, "12:00:00", "13:00:00"],
[5, "12:00:00", "14:00:00"],
[6, "8:00:00", "10:00:00"],
[7, "7:00:00", "8:00:00"],
[8, "9:45:00", "11:00:00"],
[9, "4:00:00", "7:11:00"],
[10, "13:00:00", "14:00:00"],
],
columns=["id", "start_time", "end_time"],
).astype({"start_time": "datetime64[s]", "end_time": "datetime64[s]"})

Expand Down

0 comments on commit 6e4f781

Please sign in to comment.