Skip to content

Commit

Permalink
fix stacker in forecaster
Browse files Browse the repository at this point in the history
  • Loading branch information
ourownstory committed Sep 3, 2024
1 parent 344bc5c commit b1f4fd1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
14 changes: 7 additions & 7 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,8 +1221,8 @@ def fit(
# df_val, _, _, _ = df_utils.prep_or_copy_df(df_val)
df_val = _normalize(df=df_val, config_normalization=self.config_normalization)
val_components_stacker = utils_time_dataset.ComponentStacker(
n_lags=self.n_lags,
max_lags=self.max_lags,
n_lags=self.config_ar.n_lags,
max_lags=self.config_model.max_lags,
n_forecasts=self.n_forecasts,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
Expand Down Expand Up @@ -1461,9 +1461,9 @@ def test(self, df: pd.DataFrame, verbose: bool = True):
df, _, _, _ = df_utils.prep_or_copy_df(df)
df = _normalize(df=df, config_normalization=self.config_normalization)
components_stacker = utils_time_dataset.ComponentStacker(
n_lags=self.n_lags,
n_lags=self.config_ar.n_lags,
n_forecasts=self.n_forecasts,
max_lags=self.max_lags,
max_lags=self.config_model.max_lags,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
Expand Down Expand Up @@ -2116,7 +2116,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
max_lags=0,
n_forecasts=1,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
lagged_regressor_config=None,
)
dataset = time_dataset.TimeDataset(
df=df_i,
Expand Down Expand Up @@ -2959,9 +2959,9 @@ def _predict_raw(self, df, df_name, include_components=False, prediction_frequen
if "y_scaled" not in df.columns or "t" not in df.columns:
raise ValueError("Received unprepared dataframe to predict. " "Please call predict_dataframe_to_predict.")
components_stacker = utils_time_dataset.ComponentStacker(
n_lags=self.n_lags,
n_lags=self.config_ar.n_lags,
n_forecasts=self.n_forecasts,
max_lags=self.max_lags,
max_lags=self.config_model.max_lags,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
Expand Down
2 changes: 2 additions & 0 deletions neuralprophet/utils_time_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections import OrderedDict
from datetime import datetime

import numpy as np
import torch


Expand Down

0 comments on commit b1f4fd1

Please sign in to comment.