diff --git a/activitysim/abm/models/vehicle_allocation.py b/activitysim/abm/models/vehicle_allocation.py index e4129e106..a341493ff 100644 --- a/activitysim/abm/models/vehicle_allocation.py +++ b/activitysim/abm/models/vehicle_allocation.py @@ -196,6 +196,17 @@ def vehicle_allocation( choosers = pd.merge(choosers, vehicles_wide, how="left", on="household_id") choosers.set_index("tour_id", inplace=True) + ## get categorical dtype for vehicle_type and use it to create new dtype for + ## vehicle_occup_* and selected_vehicle columns + veh_type_dtype = vehicles["vehicle_type"].dtype + if isinstance(veh_type_dtype, pd.CategoricalDtype): + veh_categories = list(veh_type_dtype.categories) + if "non_hh_veh" not in veh_categories: + veh_categories.append("non_hh_veh") + veh_choice_dtype = pd.CategoricalDtype(veh_categories, ordered=False) + else: + veh_choice_dtype = "category" + # ----- setup skim keys skims = get_skim_dict(network_los, choosers) locals_dict.update(skims) @@ -254,7 +265,7 @@ def vehicle_allocation( # creating a column for choice of each occupancy level tours_veh_occup_col = f"vehicle_occup_{occup}" tours[tours_veh_occup_col] = choices["choice"] - tours[tours_veh_occup_col] = tours[tours_veh_occup_col].astype("category") + tours[tours_veh_occup_col] = tours[tours_veh_occup_col].astype(veh_choice_dtype) tours_veh_occup_cols.append(tours_veh_occup_col) if estimator: diff --git a/activitysim/abm/models/vehicle_type_choice.py b/activitysim/abm/models/vehicle_type_choice.py index 8271ac6e8..813652459 100644 --- a/activitysim/abm/models/vehicle_type_choice.py +++ b/activitysim/abm/models/vehicle_type_choice.py @@ -393,7 +393,7 @@ def iterate_vehicle_type_choice( sorted(alts_cats_dict["fuel_type"]), ordered=False ) vehicle_type_cat = pd.api.types.CategoricalDtype( - [""] + sorted(set(alts_wide["vehicle_type"])), ordered=False + sorted(set(alts_wide["vehicle_type"])), ordered=False ) alts_wide["body_type"] = alts_wide["body_type"].astype(body_type_cat) @@ -403,7 +403,7 @@ def iterate_vehicle_type_choice( alts_wide = alts_long = None alts = model_spec.columns vehicle_type_cat = pd.api.types.CategoricalDtype( - [""] + sorted(set(alts)), ordered=False + sorted(set(alts)), ordered=False ) # alts preprocessor