Commit
[Minor] Fix neural nets regressor shape (#1589)
* fixed tensor shapes in regressor

* add test

* add shared NN test and add note to forecaster

* add tests for shared and coef NN and fix implementations

* move future regressor test with NN to separate file and separate test into smaller tests

* cleanup

* fix SharedNeuralNetsCoefFutureRegressors

* isolate issue

---------

Co-authored-by: Oskar Triebe <ourownstory@users.noreply.github.com>
MaiBe-ctrl and ourownstory authored Jun 21, 2024
1 parent f1a3820 commit 0aafec9
Showing 6 changed files with 211 additions and 107 deletions.
13 changes: 3 additions & 10 deletions neuralprophet/components/future_regressors/neural_nets.py
@@ -33,8 +33,8 @@ def __init__(self, config, id_list, quantiles, n_forecasts, device, config_trend
             for i in range(self.num_hidden_layers_regressors):
                 regressor_net.append(nn.Linear(d_inputs, self.d_hidden_regressors, bias=True))
                 d_inputs = self.d_hidden_regressors
-            # final layer has input size d_inputs and output size equal to no. of forecasts * no. of quantiles
-            regressor_net.append(nn.Linear(d_inputs, self.n_forecasts * len(self.quantiles), bias=False))
+            # final layer has input size d_inputs and output size equal to no. of quantiles
+            regressor_net.append(nn.Linear(d_inputs, len(self.quantiles), bias=False))
             for lay in regressor_net:
                 nn.init.kaiming_normal_(lay.weight, mode="fan_in")
             self.regressor_nets[regressor] = regressor_net
@@ -84,8 +84,6 @@ def regressor(self, regressor_input, name):
                 x = nn.functional.relu(x)
             x = self.regressor_nets[name][i](x)
 
-        # segment the last dimension to match the quantiles
-        x = x.reshape(x.shape[0], self.n_forecasts, len(self.quantiles))
         return x
 
     def all_regressors(self, regressor_inputs, mode):
@@ -123,9 +121,4 @@ def forward(self, inputs, mode, indeces=None):
         torch.Tensor
             Forecast component of dims (batch, n_forecasts, no_quantiles)
         """
-
-        if "additive" == mode:
-            f_r = self.all_regressors(inputs, mode="additive")
-        if "multiplicative" == mode:
-            f_r = self.all_regressors(inputs, mode="multiplicative")
-        return f_r
+        return self.all_regressors(inputs, mode)
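For intuition about this change, here is a minimal shape sketch in plain PyTorch with made-up dimensions (not the library's actual config objects): `nn.Linear` maps only the last dimension, so once the regressor input carries a forecast-step axis, a final layer of width `len(quantiles)` already produces the documented `(batch, n_forecasts, n_quantiles)` output, and the old reshape becomes unnecessary.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
batch, n_forecasts, n_quantiles, d_hidden = 8, 3, 5, 4

# Old final layer: one output per (forecast step, quantile), later reshaped.
old_final = nn.Linear(d_hidden, n_forecasts * n_quantiles, bias=False)
# New final layer: one output per quantile; nn.Linear broadcasts over all
# leading dimensions, so the step dimension comes from the input itself.
new_final = nn.Linear(d_hidden, n_quantiles, bias=False)

x = torch.randn(batch, n_forecasts, d_hidden)  # hidden features per step
print(new_final(x).shape)  # torch.Size([8, 3, 5]) == (batch, n_forecasts, n_quantiles)
```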
15 changes: 5 additions & 10 deletions neuralprophet/components/future_regressors/shared_neural_nets.py
@@ -33,8 +33,8 @@ def __init__(self, config, id_list, quantiles, n_forecasts, device, config_trend
             for i in range(self.num_hidden_layers_regressors):
                 regressor_net.append(nn.Linear(d_inputs, self.d_hidden_regressors, bias=True))
                 d_inputs = self.d_hidden_regressors
-            # final layer has input size d_inputs and output size equal to no. of forecasts * no. of quantiles
-            regressor_net.append(nn.Linear(d_inputs, self.n_forecasts * len(self.quantiles), bias=False))
+            # final layer has input size d_inputs and output size equal to no. of quantiles
+            regressor_net.append(nn.Linear(d_inputs, len(self.quantiles), bias=False))
             for lay in regressor_net:
                 nn.init.kaiming_normal_(lay.weight, mode="fan_in")
             self.regressor_nets[net_i] = regressor_net
@@ -67,7 +67,7 @@ def get_reg_weights(self, name):
         regressor_index = self.regressors_dims[name]["regressor_index"]
         return reg_attributions[:, regressor_index].unsqueeze(-1)
 
-    def regressors_net(self, regressor_inputs, mode):
+    def regressors(self, regressor_inputs, mode):
         """Compute single regressor component.
         Parameters
         ----------
@@ -87,7 +87,7 @@ def regressors_net(self, regressor_inputs, mode):
             x = self.regressor_nets[mode][i](x)
 
         # segment the last dimension to match the quantiles
-        x = x.reshape(x.shape[0], self.n_forecasts, len(self.quantiles))
+        # x = x.reshape(x.shape[0], self.n_forecasts, len(self.quantiles))  # causes error in case of multiple forecast targets, possibly unneeded/wrong
         return x
 
     def forward(self, inputs, mode, indeces=None):
@@ -103,9 +103,4 @@ def forward(self, inputs, mode, indeces=None):
         torch.Tensor
             Forecast component of dims (batch, n_forecasts, no_quantiles)
         """
-
-        if "additive" == mode:
-            f_r = self.regressors_net(inputs, mode="additive")
-        if "multiplicative" == mode:
-            f_r = self.regressors_net(inputs, mode="multiplicative")
-        return f_r
+        return self.regressors(inputs, mode)
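A hedged note on the reshape commented out above: if the network output already has the documented dims, that reshape is a no-op, and whenever the element count differs (as the in-code note suggests happens with multiple forecast targets), `reshape` raises rather than fixing anything. A tiny sketch with invented sizes:

```python
import torch

batch, n_forecasts, n_quantiles = 8, 3, 5  # made-up sizes
x = torch.randn(batch, n_forecasts, n_quantiles)

# No-op when the output already has the documented dims:
assert x.reshape(batch, n_forecasts, n_quantiles).shape == x.shape

# With a different element count, the same call raises instead:
y = torch.randn(batch, n_forecasts + 1, n_quantiles)
try:
    y.reshape(batch, n_forecasts, n_quantiles)
except RuntimeError as err:
    print(err)  # shape '[8, 3, 5]' is invalid for input of size 160
```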
neuralprophet/components/future_regressors/shared_neural_nets_coef.py
@@ -33,8 +33,8 @@ def __init__(self, config, id_list, quantiles, n_forecasts, device, config_trend
             for i in range(self.num_hidden_layers_regressors):
                 regressor_net.append(nn.Linear(d_inputs, self.d_hidden_regressors, bias=True))
                 d_inputs = self.d_hidden_regressors
-            # final layer has input size d_inputs and output size equal to no. of forecasts * no. of quantiles
-            regressor_net.append(nn.Linear(d_inputs, size_i * self.n_forecasts * len(self.quantiles), bias=False))
+            # final layer has input size d_inputs and output size equal to no. of regressors * no. of quantiles
+            regressor_net.append(nn.Linear(d_inputs, size_i * len(self.quantiles), bias=False))
             for lay in regressor_net:
                 nn.init.kaiming_normal_(lay.weight, mode="fan_in")
             self.regressor_nets[net_i] = regressor_net
@@ -68,7 +68,7 @@ def get_reg_weights(self, name):
         regressor_index = self.regressors_dims[name]["regressor_index"]
         return reg_attributions[:, regressor_index].unsqueeze(-1)
 
-    def regressors_net(self, regressor_inputs, mode):
+    def regressors(self, regressor_inputs, mode):
         """Compute single regressor component.
         Parameters
         ----------
@@ -88,6 +88,7 @@ def regressors(self, regressor_inputs, mode):
             x = self.regressor_nets[mode][i](x)
 
         # segment the last dimension to match the quantiles
+        # causes error in case of multiple forecast targets and lags, likely wrong, but needed with no lags
         x = x.reshape(x.shape[0], self.n_forecasts, regressor_inputs.shape[-1], len(self.quantiles))
         x = (regressor_inputs.unsqueeze(-1) * x).sum(-2)
         return x
@@ -105,9 +106,4 @@ def forward(self, inputs, mode, indeces=None):
         torch.Tensor
             Forecast component of dims (batch, n_forecasts, no_quantiles)
         """
-
-        if "additive" == mode:
-            f_r = self.regressors_net(inputs, mode="additive")
-        if "multiplicative" == mode:
-            f_r = self.regressors_net(inputs, mode="multiplicative")
-        return f_r
+        return self.regressors(inputs, mode)
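The coef variant keeps its reshape because the network emits one coefficient per (regressor, quantile) pair, which is then weighted by the regressor values and summed over the regressor axis. A minimal sketch of that shape arithmetic with invented sizes (the in-code note cautions this path may still be wrong with lags):

```python
import torch

# Invented sizes; n_forecasts=1 mirrors the no-lags case the note says this serves.
batch, n_forecasts, n_regressors, n_quantiles = 8, 1, 4, 5
regressor_inputs = torch.randn(batch, n_forecasts, n_regressors)

# Net output: one coefficient per (regressor, quantile) pair for each step.
coefs = torch.randn(batch, n_forecasts, n_regressors * n_quantiles)
coefs = coefs.reshape(batch, n_forecasts, n_regressors, n_quantiles)

# Weight coefficients by regressor values, then sum over the regressor axis.
out = (regressor_inputs.unsqueeze(-1) * coefs).sum(-2)
print(out.shape)  # torch.Size([8, 1, 5]) == (batch, n_forecasts, n_quantiles)
```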
4 changes: 3 additions & 1 deletion neuralprophet/forecaster.py
@@ -195,6 +195,8 @@ class NeuralProphet:
             Options
                 * (default) ``linear``
                 * ``neural_nets``
+                * ``shared_neural_nets``
+                * ``shared_neural_nets_coef``
         future_regressors_d_hidden: int
             Dimension of the hidden layers in the neural network model for future regressors.
@@ -1192,7 +1194,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a
                 config_lagged_regressors=self.config_lagged_regressors,
             )
             if auto_extend and periods_added[df_name] > 0:
-                fcst = fcst[:-periods_added[df_name]]
+                fcst = fcst[: -periods_added[df_name]]
             forecast = pd.concat((forecast, fcst), ignore_index=True)
 
         df = df_utils.return_df_in_original_format(forecast, received_ID_col, received_single_time_series)
195 changes: 195 additions & 0 deletions tests/test_future_regressor_nn.py
@@ -0,0 +1,195 @@
#!/usr/bin/env python3

import logging
import os
import pathlib

import matplotlib.pyplot as plt
import pandas as pd

from neuralprophet import NeuralProphet

log = logging.getLogger("NP.test")
log.setLevel("DEBUG")
log.parent.setLevel("WARNING")

DIR = pathlib.Path(__file__).parent.parent.absolute()
DATA_DIR = os.path.join(DIR, "tests", "test-data")
PEYTON_FILE = os.path.join(DATA_DIR, "wp_log_peyton_manning.csv")

TUTORIAL_FILE = "https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial04.csv"
NROWS = 1028
EPOCHS = 2
BATCH_SIZE = 128
LR = 1.0

PLOT = False


def test_future_reg_nn():
    log.info("testing: Future Regressors modelled with NNs")
    df = pd.read_csv(PEYTON_FILE, nrows=NROWS + 50)
    m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR, future_regressors_model="neural_nets")
    df["A"] = df["y"].rolling(7, min_periods=1).mean()
    df["B"] = df["y"].rolling(30, min_periods=1).mean()
    df["C"] = df["y"].rolling(7, min_periods=1).mean()
    df["D"] = df["y"].rolling(30, min_periods=1).mean()

    regressors_df_future = pd.DataFrame(
        data={"A": df["A"][-50:], "B": df["B"][-50:], "C": df["C"][-50:], "D": df["D"][-50:]}
    )
    df = df[:-50]
    m = m.add_future_regressor(name="A")
    m = m.add_future_regressor(name="B", mode="additive")
    m = m.add_future_regressor(name="C", mode="multiplicative")
    m = m.add_future_regressor(name="D", mode="multiplicative")
    m.fit(df, freq="D")
    future = m.make_future_dataframe(df=df, regressors_df=regressors_df_future, n_historic_predictions=10, periods=50)
    forecast = m.predict(df=future)
    if PLOT:
        m.plot(forecast)
        m.plot_components(forecast)
        m.plot_parameters()
        plt.show()


def test_future_reg_nn_shared():
    log.info("testing: Future Regressors modelled with NNs shared")
    df = pd.read_csv(PEYTON_FILE, nrows=NROWS + 50)
    m = NeuralProphet(
        epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR, future_regressors_model="shared_neural_nets"
    )
    df["A"] = df["y"].rolling(7, min_periods=1).mean()
    df["B"] = df["y"].rolling(30, min_periods=1).mean()
    df["C"] = df["y"].rolling(7, min_periods=1).mean()
    df["D"] = df["y"].rolling(30, min_periods=1).mean()

    regressors_df_future = pd.DataFrame(
        data={"A": df["A"][-50:], "B": df["B"][-50:], "C": df["C"][-50:], "D": df["D"][-50:]}
    )
    df = df[:-50]
    m = m.add_future_regressor(name="A")
    m = m.add_future_regressor(name="B", mode="additive")
    m = m.add_future_regressor(name="C", mode="multiplicative")
    m = m.add_future_regressor(name="D", mode="multiplicative")
    m.fit(df, freq="D")
    future = m.make_future_dataframe(df=df, regressors_df=regressors_df_future, n_historic_predictions=10, periods=50)
    forecast = m.predict(df=future)
    if PLOT:
        m.plot(forecast)
        m.plot_components(forecast)
        m.plot_parameters()
        plt.show()


def test_future_reg_nn_shared_coef():
    log.info("testing: Future Regressors modelled with NNs shared coef")
    df = pd.read_csv(PEYTON_FILE, nrows=NROWS + 50)
    m = NeuralProphet(
        epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR, future_regressors_model="shared_neural_nets_coef"
    )
    df["A"] = df["y"].rolling(7, min_periods=1).mean()
    df["B"] = df["y"].rolling(30, min_periods=1).mean()
    df["C"] = df["y"].rolling(7, min_periods=1).mean()
    df["D"] = df["y"].rolling(30, min_periods=1).mean()

    regressors_df_future = pd.DataFrame(
        data={"A": df["A"][-50:], "B": df["B"][-50:], "C": df["C"][-50:], "D": df["D"][-50:]}
    )
    df = df[:-50]
    m = m.add_future_regressor(name="A")
    m = m.add_future_regressor(name="B", mode="additive")
    m = m.add_future_regressor(name="C", mode="multiplicative")
    m = m.add_future_regressor(name="D", mode="multiplicative")
    m.fit(df, freq="D")
    future = m.make_future_dataframe(df=df, regressors_df=regressors_df_future, n_historic_predictions=10, periods=50)
    forecast = m.predict(df=future)
    if PLOT:
        m.plot(forecast)
        m.plot_components(forecast)
        m.plot_parameters()
        plt.show()


def test_future_regressor_nn_2():
    log.info("future regressor with NN")

    df = pd.read_csv(TUTORIAL_FILE, nrows=NROWS)

    m = NeuralProphet(
        yearly_seasonality=False,
        weekly_seasonality=False,
        daily_seasonality=True,
        future_regressors_model="neural_nets",  # 'linear' default or 'neural_nets'
        future_regressors_d_hidden=4,  # (int)
        future_regressors_num_hidden_layers=2,  # (int)
        n_forecasts=3,
        n_lags=5,
        drop_missing=True,
        # trainer_config={"accelerator": "gpu"},
    )
    df_train, df_val = m.split_df(df, freq="H", valid_p=0.2)

    # Use static plotly in notebooks
    # m.set_plotting_backend("plotly")

    # Add the new future regressor
    m.add_future_regressor("temperature")

    # Add country holidays
    m.add_country_holidays("IT", mode="additive", lower_window=-1, upper_window=1)

    metrics = m.fit(
        df_train, validation_df=df_val, freq="H", epochs=EPOCHS, learning_rate=LR, early_stopping=True, progress=False
    )


def test_future_regressor_nn_shared_2():
    log.info("future regressor with NN shared 2")

    df = pd.read_csv(TUTORIAL_FILE, nrows=NROWS)

    m = NeuralProphet(
        yearly_seasonality=False,
        weekly_seasonality=False,
        daily_seasonality=True,
        future_regressors_model="shared_neural_nets",
        future_regressors_d_hidden=4,
        future_regressors_num_hidden_layers=2,
        n_forecasts=3,
        n_lags=5,
        drop_missing=True,
    )
    df_train, df_val = m.split_df(df, freq="H", valid_p=0.2)

    # Add the new future regressor
    m.add_future_regressor("temperature")

    metrics = m.fit(
        df_train, validation_df=df_val, freq="H", epochs=EPOCHS, learning_rate=LR, early_stopping=True, progress=False
    )


# def test_future_regressor_nn_shared_coef_2():
#     log.info("future regressor with NN shared coef 2")

#     df = pd.read_csv(TUTORIAL_FILE, nrows=NROWS)

#     m = NeuralProphet(
#         yearly_seasonality=False,
#         weekly_seasonality=False,
#         daily_seasonality=True,
#         future_regressors_model="shared_neural_nets_coef",
#         future_regressors_d_hidden=4,
#         future_regressors_num_hidden_layers=2,
#         n_forecasts=3,
#         n_lags=5,
#         drop_missing=True,
#     )
#     df_train, df_val = m.split_df(df, freq="H", valid_p=0.2)

#     # Add the new future regressor
#     m.add_future_regressor("temperature")

#     metrics = m.fit(
#         df_train, validation_df=df_val, freq="H", epochs=EPOCHS, learning_rate=LR, early_stopping=True, progress=False
#     )