Skip to content

Commit

Permalink
init Stacker in forecaster
Browse files Browse the repository at this point in the history
  • Loading branch information
ourownstory committed Aug 30, 2024
1 parent 5d0d59a commit e0623ed
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 31 deletions.
34 changes: 24 additions & 10 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@
from matplotlib.axes import Axes
from torch.utils.data import DataLoader

from neuralprophet import configure, df_utils, np_types, time_dataset, time_net, utils, utils_lightning, utils_metrics
from neuralprophet import (
configure,
df_utils,
np_types,
time_dataset,
time_net,
utils,
utils_lightning,
utils_metrics,
utils_time_dataset,
)
from neuralprophet.data.process import (
_check_dataframe,
_convert_raw_predictions_to_raw_df,
Expand Down Expand Up @@ -1156,13 +1166,15 @@ def fit(
# Set up DataLoaders: Train
# Create TimeDataset
# Note: _create_dataset() needs to be called after set_auto_seasonalities()
train_components_stacker = _create_components_stacker(
train_components_stacker = utils_time_dataset.ComponentStacker(
n_lags=self.n_lags,
max_lags=self.max_lags,
n_forecasts=self.n_forecasts,
max_lags=self.max_lags,
config_seasonality=self.config_seasonality,
config_lagged_regressors=self.config_lagged_regressors,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
)

dataset = _create_dataset(
self,
df,
Expand Down Expand Up @@ -1443,12 +1455,13 @@ def test(self, df: pd.DataFrame, verbose: bool = True):
)
df, _, _, _ = df_utils.prep_or_copy_df(df)
df = _normalize(df=df, config_normalization=self.config_normalization)
components_stacker = _create_components_stacker(
components_stacker = utils_time_dataset.ComponentStacker(
n_lags=self.n_lags,
max_lags=self.max_lags,
n_forecasts=self.n_forecasts,
max_lags=self.max_lags,
config_seasonality=self.config_seasonality,
config_lagged_regressors=self.config_lagged_regressors,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
)
dataset = _create_dataset(self, df, predict_mode=False, components_stacker=components_stacker)
self.model.set_components_stacker(components_stacker, mode="test")
Expand Down Expand Up @@ -2928,12 +2941,13 @@ def _predict_raw(self, df, df_name, include_components=False, prediction_frequen
assert len(df["ID"].unique()) == 1
if "y_scaled" not in df.columns or "t" not in df.columns:
raise ValueError("Received unprepared dataframe to predict. " "Please call predict_dataframe_to_predict.")
components_stacker = _create_components_stacker(
components_stacker = utils_time_dataset.ComponentStacker(
n_lags=self.n_lags,
max_lags=self.max_lags,
n_forecasts=self.n_forecasts,
max_lags=self.max_lags,
config_seasonality=self.config_seasonality,
config_lagged_regressors=self.config_lagged_regressors,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
)
dataset = _create_dataset(
self,
Expand Down
24 changes: 3 additions & 21 deletions neuralprophet/time_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def __init__(
self.config_regressors
)

if self.config_seasonality is not None and hasattr(self.config_seasonality, "periods"):
self.calculate_seasonalities()

# skipping col "ID" is string type that is interpreted as object by torch (self.df[col].dtype == "O")
# "ID" is stored in self.meta["df_name"]
skip_cols = ["ID", "ds"]
Expand All @@ -112,30 +115,9 @@ def __init__(
self.df["ds"] = self.df["ds"].apply(lambda x: x.timestamp()) # Convert to Unix timestamp in seconds
self.df_tensors["ds"] = torch.tensor(self.df["ds"].values, dtype=torch.int64)

if self.additive_event_and_holiday_names:
self.df_tensors["additive_event_and_holiday"] = torch.stack(
[self.df_tensors[name] for name in self.additive_event_and_holiday_names], dim=1
)
if self.multiplicative_event_and_holiday_names:
self.df_tensors["multiplicative_event_and_holiday"] = torch.stack(
[self.df_tensors[name] for name in self.multiplicative_event_and_holiday_names], dim=1
)

if self.additive_regressors_names:
self.df_tensors["additive_regressors"] = torch.stack(
[self.df_tensors[name] for name in self.additive_regressors_names], dim=1
)
if self.multiplicative_regressors_names:
self.df_tensors["multiplicative_regressors"] = torch.stack(
[self.df_tensors[name] for name in self.multiplicative_regressors_names], dim=1
)

# Construct index map
self.sample2index_map, self.length = self.create_sample2index_map(self.df, self.df_tensors)

if self.config_seasonality is not None and hasattr(self.config_seasonality, "periods"):
self.calculate_seasonalities()

self.components_stacker = components_stacker

self.stack_all_features()
Expand Down

0 comments on commit e0623ed

Please sign in to comment.