From ca0952fb2ffc41bb344dc9a3fe62c82fbf05c0fb Mon Sep 17 00:00:00 2001 From: Oskar Triebe Date: Tue, 3 Sep 2024 15:58:07 -0700 Subject: [PATCH] [Minor] Move max_lags, prediction_freq to config_model and n_lags to config_ar (#1644) * updated dataset get_item * fixed linting issues * make targets contiguous * fixed ruff warnings * Unpack incrementally when needed * adjust forecaster * separate unpacking logic * added featureExtractor class * separate packing logic * fixed liniting issues * fixed covariates * remove lagged_reg_layers from model_config * remove n_lags from forecaster * remove model.n_lags references * fix typo * fix 2 * fixes * do not init max_lags * set max lags in add_lagged_reg * fix test * fix testz * fix testz2 * fix predict_seasonality * improve predic_seasonal_components * uncomment None configs * save previous settings * fix failing test * move prediction_frequency to model_config * remove lagged layers merge issue * remove packer * rm Extractor * rm Extractor2 * fix stacker in forecaster * fix test * fix tests * remove commented code * cleanup * ruff * remove unused fourier funcs and document new func * fix new func * retain OG fourier * move fourier helper to plotting utils * ruff --------- Co-authored-by: MaiBe-ctrl Co-authored-by: Maisa Ben Salah <76703998+MaiBe-ctrl@users.noreply.github.com> --- neuralprophet/configure.py | 32 ++++- neuralprophet/data/process.py | 8 +- neuralprophet/df_utils.py | 23 ---- neuralprophet/forecaster.py | 173 ++++++++++++++------------ neuralprophet/plot_utils.py | 42 ++++++- neuralprophet/time_dataset.py | 91 +++++--------- neuralprophet/time_net.py | 8 +- neuralprophet/utils_time_dataset.py | 1 + tests/test_integration.py | 39 ++++-- tests/test_plotting.py | 2 +- tests/test_uncertainty.py | 2 +- tests/test_unit.py | 16 ++- tests/utils/benchmark_time_dataset.py | 10 +- 13 files changed, 241 insertions(+), 206 deletions(-) diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index 0c2a13b52..1e3287af8 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -5,7 +5,7 @@ import types from collections import OrderedDict from dataclasses import dataclass, field -from typing import Callable, List, Optional +from typing import Callable, Dict, List, Optional from typing import OrderedDict as OrderedDictType from typing import Type, Union @@ -23,8 +23,10 @@ @dataclass class Model: features_map: dict - lagged_reg_layers: Optional[List[int]] quantiles: Optional[List[float]] = None + prediction_frequency: Optional[Dict[str]] = None + features_map: Optional[dict] = field(default_factory=dict) + max_lags: Optional[int] = field(init=False) def setup_quantiles(self): # convert quantiles to empty list [] if None @@ -43,6 +45,32 @@ def setup_quantiles(self): # 0 is the median quantile index self.quantiles.insert(0, 0.5) + def set_max_num_lags(self, n_lags: int, config_lagged_regressors: Optional[ConfigLaggedRegressors] = None) -> int: + """Get the greatest number of lags between the autoregression lags and the covariates lags. + + Parameters + ---------- + n_lags : int + number of autoregressive lagged values of series to include as model inputs + config_lagged_regressors : configure.ConfigLaggedRegressors + Configurations for lagged regressors + + Returns + ------- + int + Maximum number of lags between the autoregression lags and the covariates lags. + """ + if ( + config_lagged_regressors is not None + and config_lagged_regressors.regressors is not None + and len(config_lagged_regressors.regressors) > 0 + ): + lagged_regressor_lags = [val.n_lags for key, val in config_lagged_regressors.regressors.items()] + max_lagged_regressor_lags = max(lagged_regressor_lags) + self.max_lags = max(n_lags, max_lagged_regressor_lags) + else: + self.max_lags = n_lags + ConfigModel = Model diff --git a/neuralprophet/data/process.py b/neuralprophet/data/process.py index 46e63a67b..8705c641a 100644 --- a/neuralprophet/data/process.py +++ b/neuralprophet/data/process.py @@ -276,12 +276,12 @@ def _prepare_dataframe_to_predict(model, df: pd.DataFrame, max_lags: int, freq: raise ValueError("only datestamps provided but y values needed for auto-regression.") df_i = _check_dataframe(model, df_i, check_y=False, exogenous=False) else: - df_i = _check_dataframe(model, df_i, check_y=model.max_lags > 0, exogenous=False) + df_i = _check_dataframe(model, df_i, check_y=model.config_model.max_lags > 0, exogenous=False) # fill in missing nans except for nans at end df_i = _handle_missing_data( df=df_i, freq=freq, - n_lags=model.n_lags, + n_lags=model.config_ar.n_lags, n_forecasts=model.n_forecasts, config_missing=model.config_missing, config_regressors=model.config_regressors, @@ -401,7 +401,7 @@ def _check_dataframe( pd.DataFrame checked dataframe """ - if len(df) < (model.n_forecasts + model.n_lags) and not future: + if len(df) < (model.n_forecasts + model.config_ar.n_lags) and not future: raise ValueError( "Dataframe has less than n_forecasts + n_lags rows. " "Forecasting not possible. Please either use a larger dataset, or adjust the model parameters." @@ -616,7 +616,7 @@ def _create_dataset(model, df, predict_mode, prediction_frequency=None, componen return time_dataset.GlobalTimeDataset( df, predict_mode=predict_mode, - n_lags=model.n_lags, + n_lags=model.config_ar.n_lags, n_forecasts=model.n_forecasts, prediction_frequency=prediction_frequency, predict_steps=model.predict_steps, diff --git a/neuralprophet/df_utils.py b/neuralprophet/df_utils.py index d00efff6f..1f390db4f 100644 --- a/neuralprophet/df_utils.py +++ b/neuralprophet/df_utils.py @@ -88,29 +88,6 @@ def return_df_in_original_format(df, received_ID_col=False, received_single_time return new_df -def get_max_num_lags(n_lags: int, config_lagged_regressors: Optional[ConfigLaggedRegressors]) -> int: - """Get the greatest number of lags between the autoregression lags and the covariates lags. - - Parameters - ---------- - n_lags : int - number of lagged values of series to include as model inputs - config_lagged_regressors : configure.ConfigLaggedRegressors - Configurations for lagged regressors - - Returns - ------- - int - Maximum number of lags between the autoregression lags and the covariates lags. - """ - if config_lagged_regressors is not None and config_lagged_regressors.regressors is not None: - # log.debug("config_lagged_regressors exists") - return max([n_lags] + [val.n_lags for key, val in config_lagged_regressors.regressors.items()]) - else: - # log.debug("config_lagged_regressors.regressors does not exist") - return n_lags - - def merge_dataframes(df: pd.DataFrame) -> pd.DataFrame: """Join dataframes for procedures such as splitting data, set auto seasonalities, and others. diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index a9fcecc75..e2c34f463 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -480,7 +480,14 @@ def __init__( # General self.name = "NeuralProphet" self.n_forecasts = n_forecasts - self.prediction_frequency = prediction_frequency + + # Model + self.config_model = configure.Model( + quantiles=quantiles, + prediction_frequency=prediction_frequency, + features_map={}, + ) + self.config_model.setup_quantiles() # Data Normalization settings self.config_normalization = configure.Normalization( @@ -509,16 +516,6 @@ def __init__( # AR self.config_ar = configure.AR(n_lags=n_lags, ar_reg=ar_reg, ar_layers=ar_layers) - self.n_lags = self.config_ar.n_lags - self.max_lags = self.n_lags - - # Model - self.config_model = configure.Model( - features_map={}, - lagged_reg_layers=lagged_reg_layers, - quantiles=quantiles, - ) - self.config_model.setup_quantiles() # Trend self.config_trend = configure.Trend( @@ -570,11 +567,15 @@ def __init__( self.config_lagged_regressors = configure.ConfigLaggedRegressors( layers=lagged_reg_layers, ) + # Update max_lags + self.config_model.set_max_num_lags( + n_lags=self.config_ar.n_lags, config_lagged_regressors=self.config_lagged_regressors + ) # Future Regressors self.config_regressors = configure.ConfigFutureRegressors( model=future_regressors_model, regressors_layers=future_regressors_layers, - ) # Optional[configure.ConfigFutureRegressors] = None + ) # set during fit() self.data_freq = None @@ -624,11 +625,11 @@ def add_lagged_regressor( f"Received n_lags {n_lags} for lagged regressor {names}. Please set n_lags > 0 or use options 'scalar' or 'auto'." ) if n_lags == "auto": - if self.n_lags is not None and self.n_lags > 0: - n_lags = self.n_lags + if self.config_ar.n_lags is not None and self.config_ar.n_lags > 0: + n_lags = self.config_ar.n_lags log.info( "n_lags = 'auto', number of lags for regressor is set to Autoregression number of lags " - + f"({self.n_lags})" + + f"({self.config_ar.n_lags})" ) else: n_lags = 1 @@ -661,6 +662,9 @@ def add_lagged_regressor( as_scalar=only_last_value, n_lags=n_lags, ) + self.config_model.set_max_num_lags( + n_lags=self.config_ar.n_lags, config_lagged_regressors=self.config_lagged_regressors + ) return self def parameters(self): @@ -1021,23 +1025,25 @@ def fit( if self.fitted: raise RuntimeError("Model has been fitted already.") - # Copy df and save list of unique time series IDs (the latter for global-local modelling if enabled) - df, _, _, self.id_list = df_utils.prep_or_copy_df(df) - df = _check_dataframe(self, df, check_y=True, exogenous=True) - # Infer from config if lags are activated - self.max_lags = df_utils.get_max_num_lags( - n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors + self.config_model.set_max_num_lags( + n_lags=self.config_ar.n_lags, config_lagged_regressors=self.config_lagged_regressors ) - if self.max_lags == 0 and self.n_forecasts > 1: + + if self.config_model.max_lags == 0 and self.n_forecasts > 1: self.n_forecasts = 1 self.predict_steps = 1 - log.warning( + log.error( "Changing n_forecasts to 1. Without lags, the forecast can be " "computed for any future time, independent of lagged values" ) + + # Copy df and save list of unique time series IDs (the latter for global-local modelling if enabled) + df, _, _, self.id_list = df_utils.prep_or_copy_df(df) + df = _check_dataframe(self, df, check_y=True, exogenous=True) + # Infer frequency from data - self.data_freq = df_utils.infer_frequency(df, n_lags=self.max_lags, freq=freq) + self.data_freq = df_utils.infer_frequency(df, n_lags=self.config_model.max_lags, freq=freq) # Setup Metrics if metrics is not None: @@ -1122,7 +1128,7 @@ def fit( df = _handle_missing_data( df=df, freq=self.data_freq, - n_lags=self.n_lags, + n_lags=self.config_ar.n_lags, n_forecasts=self.n_forecasts, config_missing=self.config_missing, config_regressors=self.config_regressors, @@ -1166,19 +1172,18 @@ def fit( # Create TimeDataset # Note: _create_dataset() needs to be called after set_auto_seasonalities() train_components_stacker = utils_time_dataset.ComponentStacker( - n_lags=self.n_lags, + n_lags=self.config_ar.n_lags, n_forecasts=self.n_forecasts, - max_lags=self.max_lags, + max_lags=self.config_model.max_lags, config_seasonality=self.config_seasonality, lagged_regressor_config=self.config_lagged_regressors, feature_indices={}, ) - dataset = _create_dataset( self, df, predict_mode=False, - prediction_frequency=self.prediction_frequency, + prediction_frequency=self.config_model.prediction_frequency, components_stacker=train_components_stacker, ) # Determine the max_number of epochs @@ -1204,7 +1209,7 @@ def fit( df_val = _handle_missing_data( df=df_val, freq=self.data_freq, - n_lags=self.n_lags, + n_lags=self.config_ar.n_lags, n_forecasts=self.n_forecasts, config_missing=self.config_missing, config_regressors=self.config_regressors, @@ -1216,8 +1221,8 @@ def fit( # df_val, _, _, _ = df_utils.prep_or_copy_df(df_val) df_val = _normalize(df=df_val, config_normalization=self.config_normalization) val_components_stacker = utils_time_dataset.ComponentStacker( - n_lags=self.n_lags, - max_lags=self.max_lags, + n_lags=self.config_ar.n_lags, + max_lags=self.config_model.max_lags, n_forecasts=self.n_forecasts, config_seasonality=self.config_seasonality, lagged_regressor_config=self.config_lagged_regressors, @@ -1375,21 +1380,21 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a df, periods_added = _maybe_extend_df( df=df, n_forecasts=self.n_forecasts, - max_lags=self.max_lags, + max_lags=self.config_model.max_lags, freq=self.data_freq, config_regressors=self.config_regressors, config_events=self.config_events, ) - df = _prepare_dataframe_to_predict(model=self, df=df, max_lags=self.max_lags, freq=self.data_freq) + df = _prepare_dataframe_to_predict(model=self, df=df, max_lags=self.config_model.max_lags, freq=self.data_freq) # normalize df = _normalize(df=df, config_normalization=self.config_normalization) forecast = pd.DataFrame() for df_name, df_i in df.groupby("ID"): dates, predicted, components = self._predict_raw( - df_i, df_name, include_components=decompose, prediction_frequency=self.prediction_frequency + df_i, df_name, include_components=decompose, prediction_frequency=self.config_model.prediction_frequency ) df_i = df_utils.drop_missing_from_df( - df_i, self.config_missing.drop_missing, self.predict_steps, self.n_lags + df_i, self.config_missing.drop_missing, self.predict_steps, self.config_ar.n_lags ) if raw: fcst = _convert_raw_predictions_to_raw_df( @@ -1406,10 +1411,10 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a df=df_i, predicted=predicted, components=components, - prediction_frequency=self.prediction_frequency, + prediction_frequency=self.config_model.prediction_frequency, dates=dates, n_forecasts=self.n_forecasts, - max_lags=self.max_lags, + max_lags=self.config_model.max_lags, freq=self.data_freq, quantiles=self.config_model.quantiles, config_lagged_regressors=self.config_lagged_regressors, @@ -1440,11 +1445,11 @@ def test(self, df: pd.DataFrame, verbose: bool = True): if self.fitted is False: log.warning("Model has not been fitted. Test results will be random.") df = _check_dataframe(self, df, check_y=True, exogenous=True) - freq = df_utils.infer_frequency(df, n_lags=self.max_lags, freq=self.data_freq) + freq = df_utils.infer_frequency(df, n_lags=self.config_model.max_lags, freq=self.data_freq) df = _handle_missing_data( df=df, freq=freq, - n_lags=self.n_lags, + n_lags=self.config_ar.n_lags, n_forecasts=self.n_forecasts, config_missing=self.config_missing, config_regressors=self.config_regressors, @@ -1456,9 +1461,9 @@ def test(self, df: pd.DataFrame, verbose: bool = True): df, _, _, _ = df_utils.prep_or_copy_df(df) df = _normalize(df=df, config_normalization=self.config_normalization) components_stacker = utils_time_dataset.ComponentStacker( - n_lags=self.n_lags, + n_lags=self.config_ar.n_lags, n_forecasts=self.n_forecasts, - max_lags=self.max_lags, + max_lags=self.config_model.max_lags, config_seasonality=self.config_seasonality, lagged_regressor_config=self.config_lagged_regressors, feature_indices={}, @@ -1592,11 +1597,11 @@ def split_df(self, df: pd.DataFrame, freq: str = "auto", valid_p: float = 0.2, l """ df, received_ID_col, received_single_time_series, _ = df_utils.prep_or_copy_df(df) df = _check_dataframe(self, df, check_y=False, exogenous=False) - freq = df_utils.infer_frequency(df, n_lags=self.max_lags, freq=freq) + freq = df_utils.infer_frequency(df, n_lags=self.config_model.max_lags, freq=freq) df = _handle_missing_data( df=df, freq=freq, - n_lags=self.n_lags, + n_lags=self.config_ar.n_lags, n_forecasts=self.n_forecasts, config_missing=self.config_missing, config_regressors=self.config_regressors, @@ -1607,7 +1612,7 @@ def split_df(self, df: pd.DataFrame, freq: str = "auto", valid_p: float = 0.2, l ) df_train, df_val = df_utils.split_df( df, - n_lags=self.max_lags, + n_lags=self.config_model.max_lags, n_forecasts=self.n_forecasts, valid_p=valid_p, inputs_overbleed=True, @@ -1781,11 +1786,11 @@ def crossvalidation_split_df( """ df, received_ID_col, received_single_time_series, _ = df_utils.prep_or_copy_df(df) df = _check_dataframe(self, df, check_y=False, exogenous=False) - freq = df_utils.infer_frequency(df, n_lags=self.max_lags, freq=freq) + freq = df_utils.infer_frequency(df, n_lags=self.config_model.max_lags, freq=freq) df = _handle_missing_data( df=df, freq=freq, - n_lags=self.n_lags, + n_lags=self.config_ar.n_lags, n_forecasts=self.n_forecasts, config_missing=self.config_missing, config_regressors=self.config_regressors, @@ -1796,7 +1801,7 @@ def crossvalidation_split_df( ) folds = df_utils.crossvalidation_split_df( df, - n_lags=self.max_lags, + n_lags=self.config_model.max_lags, n_forecasts=self.n_forecasts, k=k, fold_pct=fold_pct, @@ -1845,11 +1850,11 @@ def double_crossvalidation_split_df( """ df, _, _, _ = df_utils.prep_or_copy_df(df) df = _check_dataframe(self, df, check_y=False, exogenous=False) - freq = df_utils.infer_frequency(df, n_lags=self.max_lags, freq=freq) + freq = df_utils.infer_frequency(df, n_lags=self.config_model.max_lags, freq=freq) df = _handle_missing_data( df=df, freq=freq, - n_lags=self.n_lags, + n_lags=self.config_ar.n_lags, n_forecasts=self.n_forecasts, config_missing=self.config_missing, config_regressors=self.config_regressors, @@ -1860,7 +1865,7 @@ def double_crossvalidation_split_df( ) folds_val, folds_test = df_utils.double_crossvalidation_split_df( df, - n_lags=self.max_lags, + n_lags=self.config_model.max_lags, n_forecasts=self.n_forecasts, k=k, valid_pct=valid_pct, @@ -1979,7 +1984,7 @@ def make_future_dataframe( periods=periods, n_historic_predictions=n_historic_predictions, n_forecasts=self.n_forecasts, - max_lags=self.max_lags, + max_lags=self.config_model.max_lags, freq=self.data_freq, ) df_aux["ID"] = df_name @@ -2091,35 +2096,43 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5): """ if quantile is not None and not (0 < quantile < 1): raise ValueError("The quantile specified need to be a float in-between (0,1)") + if quantile not in self.config_model.quantiles: + raise ValueError("The quantile needs to have been specified in the model configuration.") + + df_seasonal = pd.DataFrame() + prev_n_forecasts = self.n_forecasts + prev_n_lags = self.config_ar.n_lags + prev_max_lags = self.config_model.max_lags + prev_features_map = {key: value for key, value in self.config_model.features_map.items()} + + self.config_model.max_lags = 0 df, received_ID_col, received_single_time_series, _ = df_utils.prep_or_copy_df(df) df = _check_dataframe(self, df, check_y=False, exogenous=False) df = _normalize(df=df, config_normalization=self.config_normalization) - df_seasonal = pd.DataFrame() for df_name, df_i in df.groupby("ID"): feature_unstackor = ComponentStacker( n_lags=0, max_lags=0, n_forecasts=1, config_seasonality=self.config_seasonality, - lagged_regressor_config=self.config_lagged_regressors, + lagged_regressor_config=None, ) dataset = time_dataset.TimeDataset( df=df_i, predict_mode=True, n_lags=0, n_forecasts=1, - prediction_frequency=self.prediction_frequency, - predict_steps=self.predict_steps, - config_seasonality=self.config_seasonality, - config_events=self.config_events, - config_country_holidays=self.config_country_holidays, - config_regressors=self.config_regressors, - config_lagged_regressors=self.config_lagged_regressors, + prediction_frequency=self.config_model.prediction_frequency, + predict_steps=1, config_missing=self.config_missing, config_model=self.config_model, + config_seasonality=self.config_seasonality, + config_events=None, + config_country_holidays=None, + config_regressors=None, + config_lagged_regressors=None, components_stacker=feature_unstackor, - # config_train=self.config_train, # no longer needed since JIT tabularization. ) self.model.set_components_stacker(feature_unstackor, mode="predict") loader = DataLoader(dataset, batch_size=min(4096, len(df)), shuffle=False, drop_last=False) @@ -2153,10 +2166,15 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5): if self.config_seasonality.mode == "additive": data_params = self.config_normalization.get_data_params(df_name) predicted[name] = predicted[name] * data_params["y"].scale - df_i = df_i[:: self.prediction_frequency].reset_index(drop=True) + df_i = df_i[:: self.config_model.prediction_frequency].reset_index(drop=True) df_aux = pd.DataFrame({"ds": df_i["ds"], "ID": df_i["ID"], **predicted}) df_seasonal = pd.concat((df_seasonal, df_aux), ignore_index=True) df = df_utils.return_df_in_original_format(df_seasonal, received_ID_col, received_single_time_series) + # reset possibly altered values + self.n_forecasts = prev_n_forecasts + self.config_ar.n_lags = prev_n_lags + self.config_model.max_lags = prev_max_lags + self.config_model.features_map = prev_features_map return df def set_true_ar_for_eval(self, true_ar_weights: np.ndarray): @@ -2274,13 +2292,13 @@ def plot( forecast_in_focus = self.highlight_forecast_step_n if len(self.config_model.quantiles) > 1: if (self.highlight_forecast_step_n) is None and ( - self.n_forecasts > 1 or self.n_lags > 0 + self.n_forecasts > 1 or self.config_model.max_lags > 0 ): # rather query if n_forecasts >1 than n_lags>1 raise ValueError( "Please specify step_number using the highlight_nth_step_ahead_of_each_forecast function" " for quantiles plotting when auto-regression enabled." ) - if (self.highlight_forecast_step_n or forecast_in_focus) is not None and self.n_lags == 0: + if (self.highlight_forecast_step_n or forecast_in_focus) is not None and self.config_model.max_lags == 0: log.warning("highlight_forecast_step_n is ignored since auto-regression not enabled.") self.highlight_forecast_step_n = None if forecast_in_focus is not None and forecast_in_focus > self.n_forecasts: @@ -2289,7 +2307,7 @@ def plot( "prediction time step to forecast " ) - if self.max_lags > 0: + if self.config_model.max_lags > 0: num_forecasts = sum(fcst["yhat1"].notna()) if num_forecasts < self.n_forecasts: log.warning( @@ -2372,7 +2390,7 @@ def get_latest_forecast( Historical data could be included, however be aware that the df could be large: >>> df_forecast = m.get_latest_forecast(forecast, include_history_data=True) """ - if self.max_lags == 0: + if self.config_model.max_lags == 0: raise ValueError("Use the standard plot function for models without lags.") fcst, received_ID_col, received_single_time_series, _ = df_utils.prep_or_copy_df(fcst) if not received_single_time_series: @@ -2386,7 +2404,7 @@ def get_latest_forecast( fcst = fcst[fcst["ID"] == df_name].copy(deep=True) log.info(f"Getting data from ID {df_name}") if include_history_data is None: - fcst = fcst[-(include_previous_forecasts + self.n_forecasts + self.max_lags) :] + fcst = fcst[-(include_previous_forecasts + self.n_forecasts + self.config_model.max_lags) :] elif include_history_data is False: fcst = fcst[-(include_previous_forecasts + self.n_forecasts) :] elif include_history_data is True: @@ -2448,7 +2466,7 @@ def plot_latest_forecast( matplotlib.axes.Axes plot of NeuralProphet forecasting """ - if self.max_lags == 0: + if self.config_model.max_lags == 0: raise ValueError("Use the standard plot function for models without lags.") fcst, received_ID_col, received_single_time_series, _ = df_utils.prep_or_copy_df(fcst) if not received_single_time_series: @@ -2467,7 +2485,7 @@ def plot_latest_forecast( " plots only the median quantile forecasts." ) if plot_history_data is None: - fcst = fcst[-(include_previous_forecasts + self.n_forecasts + self.max_lags) :] + fcst = fcst[-(include_previous_forecasts + self.n_forecasts + self.config_model.max_lags) :] elif plot_history_data is False: fcst = fcst[-(include_previous_forecasts + self.n_forecasts) :] elif plot_history_data is True: @@ -2856,8 +2874,7 @@ def _init_model(self): config_holidays=self.config_country_holidays, config_normalization=self.config_normalization, n_forecasts=self.n_forecasts, - n_lags=self.n_lags, - max_lags=self.max_lags, + n_lags=self.config_ar.n_lags, ar_layers=self.config_ar.ar_layers, metrics=self.metrics, id_list=self.id_list, @@ -2886,9 +2903,9 @@ def restore_trainer(self, accelerator: Optional[str] = None): ) def _eval_true_ar(self): - assert self.max_lags > 0 + assert self.config_model.max_lags > 0 if self.highlight_forecast_step_n is None: - if self.max_lags > 1: + if self.config_model.max_lags > 1: raise ValueError("Please define forecast_lag for sTPE computation") forecast_pos = 1 else: @@ -2942,9 +2959,9 @@ def _predict_raw(self, df, df_name, include_components=False, prediction_frequen if "y_scaled" not in df.columns or "t" not in df.columns: raise ValueError("Received unprepared dataframe to predict. " "Please call predict_dataframe_to_predict.") components_stacker = utils_time_dataset.ComponentStacker( - n_lags=self.n_lags, + n_lags=self.config_ar.n_lags, n_forecasts=self.n_forecasts, - max_lags=self.max_lags, + max_lags=self.config_model.max_lags, config_seasonality=self.config_seasonality, lagged_regressor_config=self.config_lagged_regressors, feature_indices={}, @@ -2959,9 +2976,9 @@ def _predict_raw(self, df, df_name, include_components=False, prediction_frequen self.model.set_components_stacker(components_stacker, mode="predict") loader = DataLoader(dataset, batch_size=min(1024, len(df)), shuffle=False, drop_last=False) if self.n_forecasts > 1: - dates = df["ds"].iloc[self.max_lags : -self.n_forecasts + 1] + dates = df["ds"].iloc[self.config_model.max_lags : -self.n_forecasts + 1] else: - dates = df["ds"].iloc[self.max_lags :] + dates = df["ds"].iloc[self.config_model.max_lags :] # Pass the include_components flag to the model if include_components: diff --git a/neuralprophet/plot_utils.py b/neuralprophet/plot_utils.py index 14b416bf3..e5e035902 100644 --- a/neuralprophet/plot_utils.py +++ b/neuralprophet/plot_utils.py @@ -1,16 +1,48 @@ import logging import warnings from collections import OrderedDict +from datetime import datetime from typing import Optional import numpy as np import torch -from neuralprophet import time_dataset, utils_torch +from neuralprophet import utils_torch log = logging.getLogger("NP.plotting") +def fourier_series_numpy(dates, period, series_order): + """Provides Fourier series components with the specified frequency and order. + Note + ---- + Identical to OG Prophet. + Parameters + ---------- + dates : pd.Series + Containing time stamps + period : float + Number of days of the period + series_order : int + Number of fourier components + Returns + ------- + np.array + Matrix with seasonality features + """ + # convert to days since epoch (numeric) + t = np.array((dates - datetime(1970, 1, 1)).dt.total_seconds().astype(np.float32)) / (3600 * 24.0) + features = fourier_series_numpy_numeric(t, period, series_order) + return features + + +def fourier_series_numpy_numeric(t, period, series_order): + features = np.column_stack( + [fun((2.0 * (i + 1) * np.pi * t / period)) for i in range(series_order) for fun in (np.sin, np.cos)] + ) + return features + + def log_warning_deprecation_plotly(plotting_backend): if plotting_backend == "matplotlib": log.warning( @@ -81,10 +113,8 @@ def predict_one_season(m, quantile, name, n_steps=100, df_name="__df__"): """ config = m.config_seasonality.periods[name] - t_i = np.arange(n_steps + 1) / float(n_steps) - features = time_dataset.fourier_series_t( - t=t_i * config.period, period=config.period, series_order=config.resolution - ) + t_i = np.arange(n_steps + 1) / float(n_steps) * config.period + features = fourier_series_numpy_numeric(t=t_i, period=config.period, series_order=config.resolution) features = torch.from_numpy(np.expand_dims(features, 1)) if df_name == "__df__": @@ -129,7 +159,7 @@ def predict_season_from_dates(m, dates, name, quantile, df_name="__df__"): presdicted seasonal component """ config = m.config_seasonality.periods[name] - features = time_dataset.fourier_series(dates=dates, period=config.period, series_order=config.resolution) + features = fourier_series_numpy(dates=dates, period=config.period, series_order=config.resolution) features = torch.from_numpy(np.expand_dims(features, 1)) if df_name == "__df__": meta_name_tensor = None diff --git a/neuralprophet/time_dataset.py b/neuralprophet/time_dataset.py index 39364a5b2..303d28d44 100644 --- a/neuralprophet/time_dataset.py +++ b/neuralprophet/time_dataset.py @@ -10,7 +10,6 @@ from torch.utils.data.dataset import Dataset from neuralprophet import configure, utils -from neuralprophet.df_utils import get_max_num_lags from neuralprophet.event_utils import get_all_holidays log = logging.getLogger("NP.time_dataset") @@ -79,8 +78,7 @@ def __init__( self.config_missing = config_missing self.config_model = config_model - self.max_lags = get_max_num_lags(n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors) - if self.max_lags == 0: + if self.config_model.max_lags == 0: assert self.n_forecasts == 1 self.two_level_inputs = ["seasonalities", "covariates", "events", "regressors"] @@ -162,19 +160,39 @@ def stack_all_features(self): self.all_features = torch.cat(feature_list, dim=1) # Concatenating along the second dimension def calculate_seasonalities(self): + """Computes Fourier series components with the specified frequency and order.""" self.seasonalities = OrderedDict({}) dates = self.df_tensors["ds"] t = (dates - torch.tensor(datetime(1900, 1, 1).timestamp())).float() / (3600 * 24.0) - def compute_fourier_features(t, period): - factor = 2.0 * np.pi / period.period - sin_terms = torch.sin(factor * t[:, None] * torch.arange(1, period.resolution + 1)) - cos_terms = torch.cos(factor * t[:, None] * torch.arange(1, period.resolution + 1)) - return torch.cat((sin_terms, cos_terms), dim=1) + def compute_fourier_features(t, period, resolution): + """Provides Fourier series components with the specified frequency and order. + Note + ---- + This function's calculation is identical to Meta AI's Prophet Library + Parameters + ---------- + t : pd.Series + Containing time as floating point number of days + period : float + Number of days of the period + resolution : int + Number of fourier components + Returns + ------- + tensor : torch.Tensor + Matrix with seasonality features, dims: (len(t), 2 * resolution) + """ + resolutions = torch.arange(1, resolution + 1) + factor = 2.0 * np.pi / period + periodicities = factor * resolutions * t[:, None] + features = torch.cat((torch.sin(periodicities), torch.cos(periodicities)), dim=1) + features.requires_grad = False + return features for name, period in self.config_seasonality.periods.items(): if period.resolution > 0: - features = compute_fourier_features(t, period) + features = compute_fourier_features(t, period.period, period.resolution) if period.condition_name is not None: condition_values = self.df_tensors[period.condition_name].unsqueeze(1) @@ -216,8 +234,8 @@ def __getitem__(self, index): df_index = self.sample_index_to_df_index(index) # Extract features from dataframe at given target index position - if self.max_lags > 0: - min_start_index = df_index - self.max_lags + 1 + if self.config_model.max_lags > 0: + min_start_index = df_index - self.config_model.max_lags + 1 max_end_index = df_index + self.n_forecasts + 1 inputs = self.all_features[min_start_index:max_end_index, :] else: @@ -242,7 +260,7 @@ def create_sample2index_map(self, df, df_tensors): # Limit target range due to input lags and number of forecasts df_length = len(df_tensors["ds"]) origin_start_end_mask = self.create_origin_start_end_mask( - df_length=df_length, max_lags=self.max_lags, n_forecasts=self.n_forecasts + df_length=df_length, max_lags=self.config_model.max_lags, n_forecasts=self.n_forecasts ) # Prediction Frequency @@ -258,7 +276,7 @@ def create_sample2index_map(self, df, df_tensors): nan_mask = self.create_nan_mask( df_tensors=df_tensors, predict_mode=self.predict_mode, - max_lags=self.max_lags, + max_lags=self.config_model.max_lags, n_lags=self.n_lags, n_forecasts=self.n_forecasts, config_lagged_regressors=self.config_lagged_regressors, @@ -636,50 +654,3 @@ def __getitem__(self, idx): df_name = self.global_sample_to_local_ID[idx] local_pos = self.global_sample_to_local_sample[idx] return self.datasets[df_name].__getitem__(local_pos) - - -def fourier_series(dates, period, series_order): - """Provides Fourier series components with the specified frequency and order. - Note - ---- - Identical to OG Prophet. - Parameters - ---------- - dates : pd.Series - Containing time stamps - period : float - Number of days of the period - series_order : int - Number of fourier components - Returns - ------- - np.array - Matrix with seasonality features - """ - # convert to days since epoch - t = np.array((dates - datetime(1970, 1, 1)).dt.total_seconds().astype(np.float32)) / (3600 * 24.0) - return fourier_series_t(t, period, series_order) - - -def fourier_series_t(t, period, series_order): - """Provides Fourier series components with the specified frequency and order. - Note - ---- - This function is identical to Meta AI's Prophet Library - Parameters - ---------- - t : pd.Series, float - Containing time as floating point number of days - period : float - Number of days of the period - series_order : int - Number of fourier components - Returns - ------- - np.array - Matrix with seasonality features - """ - features = np.column_stack( - [fun((2.0 * (i + 1) * np.pi * t / period)) for i in range(series_order) for fun in (np.sin, np.cos)] - ) - return features diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index 336d9bf76..a45fb9875 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -54,7 +54,6 @@ def __init__( config_holidays: Optional[configure.ConfigCountryHolidays] = None, n_forecasts: int = 1, n_lags: int = 0, - max_lags: int = 0, ar_layers: Optional[List[int]] = [], compute_components_flag: bool = False, metrics: Optional[np_types.CollectMetricsMode] = {}, @@ -92,9 +91,6 @@ def __init__( ---- The default value is ``0``, which initializes no auto-regression. - max_lags : int - Number of max. previous steps of time series used as input (aka AR-order). - ar_layers : list List of hidden layers (for AR-Net). @@ -267,7 +263,6 @@ def __init__( self.config_ar = config_ar self.n_lags = n_lags self.ar_layers = ar_layers - self.max_lags = max_lags if self.n_lags > 0: ar_net_layers = [] d_inputs = self.n_lags @@ -776,7 +771,6 @@ def training_step(self, batch, batch_idx): epoch_float = self.trainer.current_epoch + batch_idx / float(self.train_steps_per_epoch) self.train_progress = epoch_float / float(self.config_train.epochs) - targets = self.train_components_stacker.unstack_component("targets", batch_tensor=inputs_tensor) time = self.train_components_stacker.unstack_component("time", batch_tensor=inputs_tensor) # Global-local @@ -947,7 +941,7 @@ def _add_batch_regularizations(self, loss, progress): reg_loss = torch.zeros(1, dtype=torch.float, requires_grad=False, device=self.device) if delay_weight > 0: # Add regularization of AR weights - sparsify - if self.max_lags > 0 and self.config_ar.reg_lambda is not None: + if self.config_model.max_lags > 0 and self.config_ar.reg_lambda is not None: reg_ar = self.config_ar.regularize(self.ar_weights) reg_ar = torch.sum(reg_ar).squeeze() / self.n_forecasts reg_loss += self.config_ar.reg_lambda * reg_ar diff --git a/neuralprophet/utils_time_dataset.py b/neuralprophet/utils_time_dataset.py index 075dff8e1..de09b4d9b 100644 --- a/neuralprophet/utils_time_dataset.py +++ b/neuralprophet/utils_time_dataset.py @@ -128,6 +128,7 @@ def unstack_additive_events(self, batch_tensor): ] else: events_start_idx, events_end_idx = self.feature_indices["additive_events"] + return batch_tensor[:, events_start_idx : events_end_idx + 1].unsqueeze(1) def unstack_multiplicative_events(self, batch_tensor): diff --git a/tests/test_integration.py b/tests/test_integration.py index ee4a9028d..a6d727acc 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -59,7 +59,7 @@ def test_train_eval_test(): _handle_missing_data( df=df, freq="D", - n_lags=m.n_lags, + n_lags=m.config_ar.n_lags, n_forecasts=m.n_forecasts, config_missing=m.config_missing, config_regressors=m.config_regressors, @@ -463,7 +463,7 @@ def test_air_data(): learning_rate=LR, ) m.fit(df, freq="MS") - future = m.make_future_dataframe(df, periods=48, n_historic_predictions=len(df) - m.n_lags) + future = m.make_future_dataframe(df, periods=48, n_historic_predictions=len(df) - m.config_ar.n_lags) forecast = m.predict(future) if PLOT: m.plot(forecast) @@ -551,13 +551,14 @@ def check_cv(df, freq, n_lags, n_forecasts, k, fold_pct, fold_overlap_pct): learning_rate=LR, ) folds = m.crossvalidation_split_df(df, freq=freq, k=k, fold_pct=fold_pct, fold_overlap_pct=fold_overlap_pct) - total_samples = len(df) - m.n_lags + 2 - (2 * m.n_forecasts) + total_samples = len(df) - m.config_ar.n_lags + 2 - (2 * m.n_forecasts) per_fold = int(fold_pct * total_samples) not_overlap = per_fold - int(fold_overlap_pct * per_fold) - assert all([per_fold == len(val) - m.n_lags + 1 - m.n_forecasts for (train, val) in folds]) + assert all([per_fold == len(val) - m.config_ar.n_lags + 1 - m.n_forecasts for (train, val) in folds]) assert all( [ - total_samples - per_fold - (k - i - 1) * not_overlap == len(train) - m.n_lags + 1 - m.n_forecasts + total_samples - per_fold - (k - i - 1) * not_overlap + == len(train) - m.config_ar.n_lags + 1 - m.n_forecasts for i, (train, val) in enumerate(folds) ] ) @@ -762,27 +763,36 @@ def test_global_modeling_no_exogenous_variable(): trend_global_local="global", season_global_local="global", ) + print(m.config_model) m.fit(pd.concat((df1_0, df2_0)), freq="D") + # Set unknown_data_normalization to True - now there should be no errors + m.config_normalization.unknown_data_normalization = True + forecast = m.predict(df4_0) + # print(m.config_model) + m.test(df4_0) + m.predict_trend(df4_0) + m.predict_seasonal_components(df4_0) + m.plot_parameters(df_name="df1") + m.plot_parameters() + + # Set unknown_data_normalization to False - now there should be errors + m.config_normalization.unknown_data_normalization = False with pytest.raises(ValueError): forecast = m.predict(df4_0) + print(m.config_model) log.info("unknown_data_normalization was not set to True") with pytest.raises(ValueError): m.test(df4_0) + print(m.config_model) log.info("unknown_data_normalization was not set to True") with pytest.raises(ValueError): m.predict_trend(df4_0) + print(m.config_model) log.info("unknown_data_normalization was not set to True") with pytest.raises(ValueError): m.predict_seasonal_components(df4_0) + print(m.config_model) log.info("unknown_data_normalization was not set to True") - # Set unknown_data_normalization to True - now there should be no errors - m.config_normalization.unknown_data_normalization = True - forecast = m.predict(df4_0) - m.test(df4_0) - m.predict_trend(df4_0) - m.predict_seasonal_components(df4_0) - m.plot_parameters(df_name="df1") - m.plot_parameters() def test_global_modeling_validation_df(): @@ -1688,3 +1698,6 @@ def test_fit_twice_error(): _ = m.fit(df, freq="D") with pytest.raises(RuntimeError): _ = m.fit(df, freq="D") + + +test_global_modeling_no_exogenous_variable() diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 0d24f3530..7c1b1a256 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -409,7 +409,7 @@ def test_plot_trend(plotting_backend): seasonality_mode="multiplicative", ) m.fit(df, freq="MS") - future = m.make_future_dataframe(df, periods=48, n_historic_predictions=len(df) - m.n_lags) + future = m.make_future_dataframe(df, periods=48, n_historic_predictions=len(df) - m.config_ar.n_lags) forecast = m.predict(future) fig1 = m.plot(forecast, plotting_backend=plotting_backend) fig2 = m.plot_components(forecast, plotting_backend=plotting_backend) diff --git a/tests/test_uncertainty.py b/tests/test_uncertainty.py index 7fea7618b..3423edf8d 100644 --- a/tests/test_uncertainty.py +++ b/tests/test_uncertainty.py @@ -68,7 +68,7 @@ def test_uncertainty_estimation_peyton_manning(): ) # add lagged regressors - if m.n_lags > 0: + if m.config_ar.n_lags > 0: df["A"] = df["y"].rolling(7, min_periods=1).mean() df["B"] = df["y"].rolling(30, min_periods=1).mean() m = m.add_lagged_regressor(names="A") diff --git a/tests/test_unit.py b/tests/test_unit.py index 6c234ae57..182e2ad90 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -74,6 +74,8 @@ def test_timedataset_minimal(): log.debug(f"Infile shape: {df_in.shape}") valid_p = 0.2 for n_forecasts, n_lags in [(1, 0), (1, 5), (3, 5)]: + config_model = configure.Model() + config_model.set_max_num_lags(n_lags) config_missing = configure.MissingDataHandling() # config_train = configure.Train() df, df_val = df_utils.split_df(df_in, n_lags, n_forecasts, valid_p) @@ -118,7 +120,7 @@ def test_timedataset_minimal(): config_regressors=None, config_lagged_regressors=None, config_missing=config_missing, - config_model=None, + config_model=config_model, components_stacker=components_stacker, ) input, meta = dataset.__getitem__(0) @@ -694,9 +696,9 @@ def test_globaltimedataset(): m.config_normalization = config_normalization df_global = _normalize(df=df_global, config_normalization=m.config_normalization) components_stacker = utils_time_dataset.ComponentStacker( - n_lags=m.n_lags, + n_lags=m.config_ar.n_lags, n_forecasts=m.n_forecasts, - max_lags=m.max_lags, + max_lags=m.config_model.max_lags, config_seasonality=m.config_seasonality, lagged_regressor_config=m.config_lagged_regressors, ) @@ -724,9 +726,9 @@ def test_globaltimedataset(): m.config_normalization = config_normalization df4 = _normalize(df=df4, config_normalization=m.config_normalization) components_stacker = utils_time_dataset.ComponentStacker( - n_lags=m.n_lags, + n_lags=m.config_ar.n_lags, n_forecasts=m.n_forecasts, - max_lags=m.max_lags, + max_lags=m.config_model.max_lags, config_seasonality=m.config_seasonality, lagged_regressor_config=m.config_lagged_regressors, ) @@ -870,6 +872,8 @@ def test_make_future(): def test_too_many_NaN(): n_lags = 12 n_forecasts = 1 + config_model = configure.Model() + config_model.set_max_num_lags(n_lags) config_missing = configure.MissingDataHandling( impute_missing=True, impute_linear=5, @@ -915,7 +919,7 @@ def test_too_many_NaN(): config_regressors=None, config_lagged_regressors=None, config_missing=config_missing, - config_model=None, + config_model=config_model, components_stacker=components_stacker, ) diff --git a/tests/utils/benchmark_time_dataset.py b/tests/utils/benchmark_time_dataset.py index 88d1a6f28..e2984b10a 100644 --- a/tests/utils/benchmark_time_dataset.py +++ b/tests/utils/benchmark_time_dataset.py @@ -70,11 +70,11 @@ def load(nrows=NROWS, epochs=EPOCHS, batch=BATCH_SIZE, season=True, iterations=1 df, _, _, m.id_list = df_utils.prep_or_copy_df(df) df = _check_dataframe(m, df, check_y=True, exogenous=True) - m.data_freq = df_utils.infer_frequency(df, n_lags=m.max_lags, freq=freq) + m.data_freq = df_utils.infer_frequency(df, n_lags=m.config_model.max_lags, freq=freq) df = _handle_missing_data( df=df, freq=m.data_freq, - n_lags=m.n_lags, + n_lags=m.config_ar.n_lags, n_forecasts=m.n_forecasts, config_missing=m.config_missing, config_regressors=m.config_regressors, @@ -99,7 +99,7 @@ def load(nrows=NROWS, epochs=EPOCHS, batch=BATCH_SIZE, season=True, iterations=1 m.config_country_holidays.init_holidays(df_merged) dataset = _create_dataset( - m, df, predict_mode=False, prediction_frequency=m.prediction_frequency + m, df, predict_mode=False, prediction_frequency=m.model_config.prediction_frequency ) # needs to be called after set_auto_seasonalities # Determine the max_number of epochs @@ -214,7 +214,7 @@ def peyton(nrows=NROWS, epochs=EPOCHS, batch=BATCH_SIZE, season=True): ) # add lagged regressors - # # if m.n_lags > 0: + # # if m.config_ar.n_lags > 0: # df["A"] = df["y"].rolling(7, min_periods=1).mean() # df["B"] = df["y"].rolling(30, min_periods=1).mean() # m = m.add_lagged_regressor(name="A", n_lags=10) @@ -264,7 +264,7 @@ def peyton_minus_events(nrows=NROWS, epochs=EPOCHS, batch=BATCH_SIZE, season=Tru ) # add lagged regressors - if m.n_lags > 0: + if m.config_ar.n_lags > 0: df["A"] = df["y"].rolling(7, min_periods=1).mean() df["B"] = df["y"].rolling(30, min_periods=1).mean() m = m.add_lagged_regressor(name="A")