
Commit

Fix timespan queries; Add applied project to roadway and transit features (#369)

* Fix small bug in roadway deletion: use `modes`, not `mode`.
* Fix #282: log the applied project name as a feature-level variable
- updates transit and roadway data models to add `projects` as a string field
- updates add and change projects to append the project name to `projects`
- tests that it is appropriately logged
* Fix #368: resiliently query to update headway when there are multiple entries per trip_id
- Separately query timespans to enforce "or" logic
- Separate querying of frequencies from querying of stop times so frequencies can return more than one entry
- Separate GTFS and Wrangler Frequencies and StopTimes schemas for times: strings in HH:MM become datetime objects
- Add parser to convert strings to datetimes upon conversion to Wrangler table models; note that it will not correctly handle hours above 24 (see the illustration after this list)
- Fix parsing of times to have consistent days
- Make default arrival/departure and start/end times relate to DEFAULT_TIMESPAN
- Allow single-digit hours
- Enforce a string schema for time strings
- Raise TableValidationError for ValueError and TypeError
* Other:
- Reduce some debugging noise
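As a quick illustration of the hours-above-24 caveat (GTFS allows times such as 25:30:00 for trips that continue past midnight): standard %H-based parsing only accepts hours 00-23, so such strings cannot survive the string-to-datetime conversion. A minimal sketch, assuming the converter ultimately relies on pandas-style parsing:

```python
import pandas as pd

# Within-day times parse fine on the default reference day.
print(pd.to_datetime("23:30:00", format="%H:%M:%S"))

# GTFS-style "next day" times do not: %H only accepts 00-23.
try:
    pd.to_datetime("25:30:00", format="%H:%M:%S")
except ValueError as err:
    print(f"cannot parse hour > 24: {err}")
```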
e-lo authored Sep 3, 2024
1 parent 3b68343 commit b7b3b91
Showing 38 changed files with 649 additions and 525 deletions.
2 changes: 1 addition & 1 deletion examples/small/frequencies.txt
@@ -1,4 +1,4 @@
trip_id,headway_secs,start_time,end_time
blue-1,600,00:04:00,06:00:00
blue-1,600,04:00:00,06:00:00
blue-1,1800,06:00:00,09:00:00
blue-2,900,9:00:00,20:00:00
2 changes: 1 addition & 1 deletion examples/small/link.json
@@ -50,7 +50,7 @@
"price": 0,
"sc_price": [
{
"timespan": ["12:00","3:00"],
"timespan": ["12:00","15:00"],
"category": "HOV2",
"value": 1.0
}
9 changes: 9 additions & 0 deletions network_wrangler/models/_base/series.py
@@ -0,0 +1,9 @@
import pandera as pa

# Define a general schema for any Series with time strings in HH:MM or HH:MM:SS format
TimeStrSeriesSchema = pa.SeriesSchema(
pa.String,
pa.Check.str_matches(r"^(?:[0-9]|[01]\d|2[0-3]):[0-5]\d(?::[0-5]\d)?$|^24:00(?::00)?$"),
coerce=True,
name=None # Name is set to None to ignore the Series name
)
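A minimal usage sketch for the new series schema, assuming the module is importable at the path implied by the file location (pandera's `SeriesSchema.validate` returns the coerced series or raises `SchemaError`):

```python
import pandas as pd
import pandera as pa

from network_wrangler.models._base.series import TimeStrSeriesSchema

# Single-digit hours, optional seconds, and the 24:00 end-of-day sentinel all match.
TimeStrSeriesSchema.validate(pd.Series(["7:30", "07:30:15", "24:00"]))

# Anything outside the pattern (e.g. hours above 24) fails validation.
try:
    TimeStrSeriesSchema.validate(pd.Series(["25:00"]))
except pa.errors.SchemaError as err:
    print(err)
```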
4 changes: 3 additions & 1 deletion network_wrangler/models/_base/tables.py
@@ -2,7 +2,7 @@

import pandas as pd
from pandera.extensions import register_check_method
from pydantic import ValidationError
from pydantic import ValidationError, RootModel

from ...logger import WranglerLogger

@@ -18,6 +18,8 @@ def validate_list_of_pyd(item_list, pyd_model):

def validate_pyd(item, pyd_model):
try:
# if issubclass(pyd_model, RootModel):
# pyd_model(__root__=item)
pyd_model(item)
return True
except ValidationError:
2 changes: 1 addition & 1 deletion network_wrangler/models/gtfs/__init__.py
@@ -17,7 +17,7 @@ def __init__(self, **kwargs):
from .tables import (
StopsTable,
RoutesTable,
TripsTable,
WranglerTripsTable,
StopTimesTable,
ShapesTable,
FrequenciesTable,
114 changes: 104 additions & 10 deletions network_wrangler/models/gtfs/tables.py
@@ -53,6 +53,7 @@ def process_table(table: pa.DataFrameModel):
from typing import Optional

import pandera as pa
import pandas as pd

from pandas import Timestamp
from pandera.typing import Series, Category
@@ -67,6 +68,9 @@ def process_table(table: pa.DataFrameModel):
TimepointType,
)
from .table_types import HttpURL
from .._base.types import TimeString
from ...utils.time import str_to_time_series, str_to_time
from ...params import DEFAULT_TIMESPAN


class AgenciesTable(pa.DataFrameModel):
@@ -147,6 +151,7 @@ class WranglerStopsTable(StopsTable):
)
stop_lat: Series[float] = pa.Field(coerce=True, nullable=True, ge=-90, le=90)
stop_lon: Series[float] = pa.Field(coerce=True, nullable=True, ge=-180, le=180)
projects: Series[str] = pa.Field(coerce=True, default="")


class RoutesTable(pa.DataFrameModel):
@@ -210,6 +215,7 @@ class WranglerShapesTable(ShapesTable):
"""Wrangler flavor of GTFS ShapesTable."""

shape_model_node_id: Series[int] = pa.Field(coerce=True, nullable=False)
projects: Series[str] = pa.Field(coerce=True, default="")


class TripsTable(pa.DataFrameModel):
@@ -249,6 +255,20 @@ class Config:
_fk = {"route_id": ["routes", "route_id"]}


class WranglerTripsTable(TripsTable):
"""Represents the Trips table in the Wrangler feed, adding projects list."""

projects: Series[str] = pa.Field(coerce=True, default="")

class Config:
"""Config for the WranglerTripsTable data model."""

coerce = True
add_missing_columns = True
_pk = ["trip_id"]
_fk = {"route_id": ["routes", "route_id"]}


class FrequenciesTable(pa.DataFrameModel):
"""Represents the Frequency table in the GTFS dataset.
@@ -258,8 +278,16 @@
"""

trip_id: Series[str] = pa.Field(nullable=False, coerce=True)
start_time: Series[Timestamp] = pa.Field(nullable=False, coerce=True)
end_time: Series[Timestamp] = pa.Field(nullable=False, coerce=True)
start_time: Series[TimeString] = pa.Field(
nullable=False,
coerce=True,
default=DEFAULT_TIMESPAN[0]
)
end_time: Series[TimeString] = pa.Field(
nullable=False,
coerce=True,
default=DEFAULT_TIMESPAN[1]
)
headway_secs: Series[int] = pa.Field(
coerce=True,
ge=1,
@@ -276,6 +304,47 @@ class Config:
_fk = {"trip_id": ["trips", "trip_id"]}


class WranglerFrequenciesTable(FrequenciesTable):
"""Wrangler flavor of GTFS FrequenciesTable."""
projects: Series[str] = pa.Field(coerce=True, default="")
start_time: Series = pa.Field(
nullable=False,
coerce=True,
default=str_to_time(DEFAULT_TIMESPAN[0])
)
end_time: Series = pa.Field(
nullable=False,
coerce=True,
default=str_to_time(DEFAULT_TIMESPAN[1])
)

class Config:
"""Config for the FrequenciesTable data model."""

coerce = True
add_missing_columns = True
unique = ["trip_id", "start_time"]
_pk = ["trip_id", "start_time"]
_fk = {"trip_id": ["trips", "trip_id"]}

@pa.parser("start_time")
def st_to_timestamp(cls, series: Series) -> Series[Timestamp]:
"""Check that start time is timestamp."""
series = series.fillna(str_to_time(DEFAULT_TIMESPAN[0]))
if series.dtype == "datetime64[ns]":
return series
series = str_to_time_series(series)
return series.astype("datetime64[ns]")

@pa.parser("end_time")
def et_to_timestamp(cls, series: Series) -> Series[Timestamp]:
"""Check that start time is timestamp."""
series = series.fillna(str_to_time(DEFAULT_TIMESPAN[1]))
if series.dtype == "datetime64[ns]":
return series
return str_to_time_series(series)


class StopTimesTable(pa.DataFrameModel):
"""Represents the Stop Times table in the GTFS dataset.
@@ -297,8 +366,8 @@
nullable=True,
coerce=True,
)
arrival_time: Series[Timestamp] = pa.Field(nullable=False, coerce=True)
departure_time: Series[Timestamp] = pa.Field(nullable=False, coerce=True)
arrival_time: Series[TimeString] = pa.Field(nullable=True, coerce=True)
departure_time: Series[TimeString] = pa.Field(nullable=True, coerce=True)

# Optional
shape_dist_traveled: Optional[Series[float]] = pa.Field(coerce=True, nullable=True, ge=0)
@@ -323,9 +392,34 @@ class WranglerStopTimesTable(StopTimesTable):
"""Wrangler flavor of GTFS StopTimesTable."""

stop_id: Series[int] = pa.Field(nullable=False, coerce=True, description="The model_node_id.")
arrival_time: Optional[Series[Timestamp]] = pa.Field(
coerce=True, nullable=True, default=Timestamp("00:00:00")
)
departure_time: Optional[Series[Timestamp]] = pa.Field(
coerce=True, nullable=True, default=Timestamp("00:00:00")
)
arrival_time: Series[Timestamp] = pa.Field(nullable=True)
departure_time: Series[Timestamp] = pa.Field(nullable=True)
projects: Series[str] = pa.Field(coerce=True, default="")

@pa.parser("arrival_time")
def at_to_timestamp(cls, series: Series) -> Series[Timestamp]:
"""Check that arrival time timestamp."""
if series.dtype == "datetime64[ns]":
return series
series = str_to_time_series(series)
return series.astype("datetime64[ns]")

@pa.parser("departure_time")
def dt_to_timestamp(cls, series: Series) -> Series[Timestamp]:
"""Check that departure time is timestamp."""
if series.dtype == "datetime64[ns]":
return series
series = str_to_time_series(series)
return series.astype("datetime64[ns]")

class Config:
"""Config for the StopTimesTable data model."""

coerce = True
add_missing_columns = True
_pk = ["trip_id", "stop_sequence"]
_fk = {
"trip_id": ["trips", "trip_id"],
"stop_id": ["stops", "stop_id"],
}
unique = ["trip_id", "stop_sequence"]
4 changes: 2 additions & 2 deletions network_wrangler/models/projects/roadway_deletion.py
@@ -28,6 +28,6 @@ class RoadwayDeletion(RecordModel):
def set_to_all_modes(cls, links: Optional[SelectLinksDict] = None):
"""Set the search mode to 'any' if not specified explicitly."""
if links is not None:
if links.mode == DEFAULT_SEARCH_MODES:
links.mode = DEFAULT_DELETE_MODES
if links.modes == DEFAULT_SEARCH_MODES:
links.modes = DEFAULT_DELETE_MODES
return links
28 changes: 8 additions & 20 deletions network_wrangler/models/roadway/tables.py
@@ -45,6 +45,7 @@ class RoadLinksTable(DataFrameModel):
distance: Series[float] = pa.Field(coerce=True, nullable=False)

roadway: Series[str] = pa.Field(nullable=False, default="road")
projects: Series[str] = pa.Field(coerce=True, default="")
managed: Series[int] = pa.Field(coerce=True, nullable=False, default=0)

shape_id: Series[str] = pa.Field(coerce=True, nullable=True)
@@ -58,6 +59,7 @@
sc_lanes: Optional[Series[object]] = pa.Field(coerce=True, nullable=True, default=None)
sc_price: Optional[Series[object]] = pa.Field(coerce=True, nullable=True, default=None)

ML_projects: Series[str] = pa.Field(coerce=True, default="")
ML_lanes: Optional[Series[Int64]] = pa.Field(coerce=True, nullable=True, default=None)
ML_price: Optional[Series[float]] = pa.Field(coerce=True, nullable=True, default=0)
ML_access: Optional[Series[Any]] = pa.Field(coerce=True, nullable=True, default=True)
@@ -108,30 +110,16 @@ class Config:
coerce = True
unique = ["A", "B"]

@pa.dataframe_check
def check_scoped_fields(cls, df: pd.DataFrame) -> Series[bool]:
@pa.check("sc_*", regex=True, element_wise=True)
def check_scoped_fields(cls, scoped_value: Series) -> Series[bool]:
"""Checks that all fields starting with 'sc_' or 'sc_ML_' are valid ScopedLinkValueList.
Custom check to validate fields starting with 'sc_' or 'sc_ML_'
against a ScopedLinkValueItem model, handling both mandatory and optional fields.
"""
scoped_fields = [
col for col in df.columns if col.startswith("sc_") or col.startswith("sc_ML")
]
results = []
# WranglerLogger.debug(f"Checking scoped fields: {scoped_fields}")
# WranglerLogger.debug(f"{df[scoped_fields]}")
for field in scoped_fields:
if df[field].notna().any():
results.append(
df[field].dropna().apply(validate_pyd, args=(ScopedLinkValueList,)).all()
)
else:
# Handling optional fields: Assume validation is true if the field is entirely NA
results.append(True)

# Combine all results: True if all fields pass validation
return pd.Series(all(results), index=df.index)
if not scoped_value or pd.isna(scoped_value):
return True
return validate_pyd(scoped_value, ScopedLinkValueList)


class RoadNodesTable(DataFrameModel):
@@ -149,7 +137,7 @@
nullable=True,
default="",
)

projects: Series[str] = pa.Field(coerce=True, default="")
inboundReferenceIds: Optional[Series[list[str]]] = pa.Field(coerce=True, nullable=True)
outboundReferenceIds: Optional[Series[list[str]]] = pa.Field(coerce=True, nullable=True)

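The scoped-field validation above is refactored from one dataframe-wide check into a per-column, element-wise check that pandera applies to every column matching the `sc_*` regex. A standalone sketch of that pattern with illustrative field names (not the real `RoadLinksTable`, and a simplified stand-in for the `ScopedLinkValueList` validation):

```python
import pandas as pd
import pandera as pa
from pandera.typing import Series


class ScopedFieldsExample(pa.DataFrameModel):
    """Illustrative model: regex-matched, element-wise check on scoped columns."""

    model_link_id: Series[int] = pa.Field(coerce=True)
    sc_lanes: Series[object] = pa.Field(nullable=True, default=None)
    sc_price: Series[object] = pa.Field(nullable=True, default=None)

    class Config:
        coerce = True
        add_missing_columns = True

    @pa.check("sc_*", regex=True, element_wise=True)
    def check_scoped_fields(cls, scoped_value) -> bool:
        # Empty or missing cells pass; populated cells must hold a list of scoped values.
        if scoped_value is None or (not isinstance(scoped_value, list) and pd.isna(scoped_value)):
            return True
        return isinstance(scoped_value, list)


df = pd.DataFrame(
    {
        "model_link_id": [1, 2],
        "sc_lanes": [None, [{"timespan": ["06:00", "09:00"], "value": 3}]],
    }
)
ScopedFieldsExample.validate(df)  # sc_price is added with its default and passes
```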
2 changes: 1 addition & 1 deletion network_wrangler/params.py
@@ -216,7 +216,7 @@ def unique_ids(self) -> list[str]:
"""
DEFAULT_SP_WEIGHT_COL = "i"

"""Default timespan for scoped values."""
"""Default timespan for scoped values and GTFS."""
DEFAULT_TIMESPAN = ["00:00", "24:00"]

"""Default category for scoped values."""