diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index 784e00194..f64315e26 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -5,7 +5,7 @@ import types from collections import OrderedDict from dataclasses import dataclass, field -from typing import Callable, Iterable, List, Optional +from typing import Callable, List, Optional from typing import OrderedDict as OrderedDictType from typing import Type, Union @@ -13,7 +13,7 @@ import pandas as pd import torch -from neuralprophet import df_utils, np_types, utils, utils_torch +from neuralprophet import df_utils, np_types, utils_torch from neuralprophet.custom_loss_metrics import PinballLoss from neuralprophet.hdays_utils import get_holidays_from_country diff --git a/neuralprophet/custom_loss_metrics.py b/neuralprophet/custom_loss_metrics.py index 452bf93c0..3ce334862 100644 --- a/neuralprophet/custom_loss_metrics.py +++ b/neuralprophet/custom_loss_metrics.py @@ -33,15 +33,15 @@ def forward(self, outputs, target): """ target = target.repeat(1, 1, len(self.quantiles)) # increase the quantile dimension of the targets differences = target - outputs - base_losses = self.loss_func(outputs, target) # dimensions - [n_batch, n_forecasts, no. of quantiles] - positive_losses = ( - torch.tensor(self.quantiles, device=target.device).unsqueeze(dim=0).unsqueeze(dim=0) * base_losses + base_losses = self.loss_func(outputs, target).float() # dimensions - [n_batch, n_forecasts, no. of quantiles] + quantiles_tensor = ( + torch.tensor(self.quantiles, device=target.device, dtype=torch.float32).unsqueeze(dim=0).unsqueeze(dim=0) ) - negative_losses = ( - 1 - torch.tensor(self.quantiles, device=target.device).unsqueeze(dim=0).unsqueeze(dim=0) - ) * base_losses + positive_losses = quantiles_tensor * base_losses + negative_losses = (1 - quantiles_tensor) * base_losses + differences = differences.float() pinball_losses = torch.where(differences >= 0, positive_losses, negative_losses) - multiplier = torch.ones(size=(1, 1, len(self.quantiles)), device=target.device) + multiplier = torch.ones(size=(1, 1, len(self.quantiles)), device=target.device, dtype=torch.float32) multiplier[:, :, 0] = 2 pinball_losses = multiplier * pinball_losses # double the loss for the median quantile return pinball_losses diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index dfc2504a9..d80fcef14 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -11,6 +11,7 @@ import torch from matplotlib import pyplot from matplotlib.axes import Axes +from pytorch_lightning.tuner.tuning import Tuner from torch.utils.data import DataLoader from neuralprophet import configure, df_utils, np_types, time_dataset, time_net, utils, utils_metrics @@ -2756,6 +2757,8 @@ def _train( else: self.model = self._init_model() + self.model.train_loader = train_loader + # Init the Trainer self.trainer, checkpoint_callback = utils.configure_trainer( config_train=self.config_train, @@ -2780,8 +2783,9 @@ def _train( # Set parameters for the learning rate finder self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader)) # Find suitable learning rate - lr_finder = self.trainer.tuner.lr_find( - self.model, + tuner = Tuner(self.trainer) + lr_finder = tuner.lr_find( + model=self.model, train_dataloaders=train_loader, val_dataloaders=val_loader, **self.config_train.lr_finder_args, @@ -2802,8 +2806,9 @@ def _train( # Set parameters for the learning rate finder self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader)) # Find suitable learning rate - lr_finder = self.trainer.tuner.lr_find( - self.model, + tuner = Tuner(self.trainer) + lr_finder = tuner.lr_find( + model=self.model, train_dataloaders=train_loader, **self.config_train.lr_finder_args, ) @@ -2831,7 +2836,6 @@ def _train( if not metrics_enabled: return None - # Return metrics collected in logger as dataframe metrics_df = pd.DataFrame(self.metrics_logger.history) return metrics_df diff --git a/neuralprophet/logger.py b/neuralprophet/logger.py index 28a72829a..5a34628d2 100644 --- a/neuralprophet/logger.py +++ b/neuralprophet/logger.py @@ -52,6 +52,7 @@ class ProgressBar(TQDMProgressBar): def __init__(self, *args, **kwargs): self.epochs = kwargs.pop("epochs") super().__init__(*args, **kwargs) + self.main_progress_bar = super().init_train_tqdm() def on_train_epoch_start(self, trainer: "pl.Trainer", *_) -> None: self.main_progress_bar.reset(self.epochs) diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index f2b7388eb..ea3c4b2f3 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -795,7 +795,8 @@ def training_step(self, batch, batch_idx): scheduler.step() # Manually track the loss for the lr finder - self.trainer.fit_loop.running_loss.append(loss) + self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log("reg_loss", reg_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) # Metrics if self.metrics_enabled: @@ -983,6 +984,9 @@ def denormalize(self, ts): ts = scale_y * ts + shift_y return ts + def train_dataloader(self): + return self.train_loader + class FlatNet(nn.Module): """ diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py index fb3f016e5..e7f86ead0 100644 --- a/neuralprophet/utils.py +++ b/neuralprophet/utils.py @@ -5,7 +5,7 @@ import os import sys from collections import OrderedDict -from typing import IO, TYPE_CHECKING, BinaryIO, Iterable, Optional, Union +from typing import IO, TYPE_CHECKING, BinaryIO, Optional, Union import numpy as np import pandas as pd @@ -13,7 +13,6 @@ import torch from neuralprophet import utils_torch -from neuralprophet.hdays_utils import get_country_holidays from neuralprophet.logger import ProgressBar if TYPE_CHECKING: @@ -856,10 +855,6 @@ def configure_trainer( """ config = config.copy() - # Enable Learning rate finder if not learning rate provided - if config_train.learning_rate is None: - config["auto_lr_find"] = True - # Set max number of epochs if hasattr(config_train, "epochs"): if config_train.epochs is not None: diff --git a/poetry.lock b/poetry.lock index 681594e43..dcfb37096 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2760,13 +2760,13 @@ tenacity = ">=6.2.0" [[package]] name = "plotly-resampler" -version = "0.9.2" +version = "0.10.0" description = "Visualizing large time series with plotly" optional = true -python-versions = ">=3.7.1,<4.0.0" +python-versions = "<4.0.0,>=3.7.1" files = [ - {file = "plotly_resampler-0.9.2-py3-none-any.whl", hash = "sha256:72ced21696de5eb08ee6dae53eacd279eaac83d5500a04c410c3262c3bec7462"}, - {file = "plotly_resampler-0.9.2.tar.gz", hash = "sha256:abddac809931c157f5982d3398b68cc256786e75de33878a84060f57d886b471"}, + {file = "plotly_resampler-0.10.0-py3-none-any.whl", hash = "sha256:4d695557fe8a718b4a3f45a5381fae6faca0e441d04b6d83d6b43ce998013355"}, + {file = "plotly_resampler-0.10.0.tar.gz", hash = "sha256:e1063d6d00aa4aedeb8c2c204de8661751b2145fd06682cd8fead9983c9a8334"}, ] [package.dependencies] @@ -2778,7 +2778,7 @@ numpy = [ orjson = ">=3.8.0,<4.0.0" pandas = ">=1" plotly = ">=5.5.0,<6.0.0" -tsdownsample = "0.1.2" +tsdownsample = ">=0.1.3" [package.extras] inline-persistent = ["Flask-Cors (>=3.0.10,<4.0.0)", "jupyter-dash (>=0.4.2)", "kaleido (==0.2.1)"] @@ -3023,38 +3023,34 @@ six = ">=1.5" [[package]] name = "pytorch-lightning" -version = "1.9.5" +version = "2.3.0" description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pytorch-lightning-1.9.5.tar.gz", hash = "sha256:925fe7b80ddf04859fa385aa493b260be4000b11a2f22447afb4a932d1f07d26"}, - {file = "pytorch_lightning-1.9.5-py3-none-any.whl", hash = "sha256:06821558158623c5d2ecf5d3d0374dc8bd661e0acd3acf54a6d6f71737c156c5"}, + {file = "pytorch-lightning-2.3.0.tar.gz", hash = "sha256:89caf90e3543b314508493f26e0eca8d5e10e43e3d9e6c143acd8ddceb584ce2"}, + {file = "pytorch_lightning-2.3.0-py3-none-any.whl", hash = "sha256:b8eec361f4342ca628d0d8e6985511c9515435e4db62c5e982bb1c53a5a5140a"}, ] [package.dependencies] -fsspec = {version = ">2021.06.0", extras = ["http"]} -lightning-utilities = ">=0.6.0.post0" +fsspec = {version = ">=2022.5.0", extras = ["http"]} +lightning-utilities = ">=0.8.0" numpy = ">=1.17.2" -packaging = ">=17.1" +packaging = ">=20.0" PyYAML = ">=5.4" -torch = ">=1.10.0" +torch = ">=2.0.0" torchmetrics = ">=0.7.0" tqdm = ">=4.57.0" -typing-extensions = ">=4.0.0" +typing-extensions = ">=4.4.0" [package.extras] -all = ["colossalai (>=0.2.0)", "deepspeed (>=0.6.0)", "fairscale (>=0.4.5)", "gym[classic-control] (>=0.17.0)", "hivemind (==1.1.5)", "horovod (>=0.21.2,!=0.24.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.7.1)", "jsonargparse[signatures] (>=4.18.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.11.1)"] -colossalai = ["colossalai (>=0.2.0)"] -deepspeed = ["deepspeed (>=0.6.0)"] -dev = ["cloudpickle (>=1.3)", "codecov (==2.1.12)", "colossalai (>=0.2.0)", "coverage (==6.5.0)", "deepspeed (>=0.6.0)", "fairscale (>=0.4.5)", "fastapi (<0.87.0)", "gym[classic-control] (>=0.17.0)", "hivemind (==1.1.5)", "horovod (>=0.21.2,!=0.24.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.7.1)", "jsonargparse[signatures] (>=4.18.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "onnx (<1.14.0)", "onnxruntime (<1.14.0)", "pandas (>1.0)", "pre-commit (==2.20.0)", "protobuf (<=3.20.1)", "psutil (<5.9.5)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "rich (>=10.14.0,!=10.15.0.a)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.11.1)", "uvicorn (<0.19.1)"] -examples = ["gym[classic-control] (>=0.17.0)", "ipython[all] (<8.7.1)", "torchmetrics (>=0.10.0)", "torchvision (>=0.11.1)"] -extra = ["hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)"] -fairscale = ["fairscale (>=0.4.5)"] -hivemind = ["hivemind (==1.1.5)"] -horovod = ["horovod (>=0.21.2,!=0.24.0)"] -strategies = ["colossalai (>=0.2.0)", "deepspeed (>=0.6.0)", "fairscale (>=0.4.5)", "hivemind (==1.1.5)", "horovod (>=0.21.2,!=0.24.0)"] -test = ["cloudpickle (>=1.3)", "codecov (==2.1.12)", "coverage (==6.5.0)", "fastapi (<0.87.0)", "onnx (<1.14.0)", "onnxruntime (<1.14.0)", "pandas (>1.0)", "pre-commit (==2.20.0)", "protobuf (<=3.20.1)", "psutil (<5.9.5)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn (<0.19.1)"] +all = ["bitsandbytes (>=0.42.0)", "deepspeed (>=0.8.2,<=0.9.3)", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.27.7)", "lightning-utilities (>=0.8.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "requests (<2.32.0)", "rich (>=12.3.0)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.15.0)"] +deepspeed = ["deepspeed (>=0.8.2,<=0.9.3)"] +dev = ["bitsandbytes (>=0.42.0)", "cloudpickle (>=1.3)", "coverage (==7.3.1)", "deepspeed (>=0.8.2,<=0.9.3)", "fastapi", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.27.7)", "lightning-utilities (>=0.8.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "requests (<2.32.0)", "rich (>=12.3.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.15.0)", "uvicorn"] +examples = ["ipython[all] (<8.15.0)", "lightning-utilities (>=0.8.0)", "requests (<2.32.0)", "torchmetrics (>=0.10.0)", "torchvision (>=0.15.0)"] +extra = ["bitsandbytes (>=0.42.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.27.7)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=12.3.0)", "tensorboardX (>=2.2)"] +strategies = ["deepspeed (>=0.8.2,<=0.9.3)"] +test = ["cloudpickle (>=1.3)", "coverage (==7.3.1)", "fastapi", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn"] [[package]] name = "pytz" @@ -3926,72 +3922,84 @@ tutorials = ["matplotlib", "pandas", "tabulate", "torch"] [[package]] name = "tsdownsample" -version = "0.1.2" +version = "0.1.3" description = "Time series downsampling in rust" optional = true python-versions = ">=3.7" files = [ - {file = "tsdownsample-0.1.2-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:f5b15ce2be18816849f1b09ac35dc287233ca5da3b343cc1cddb55164b6fdc49"}, - {file = "tsdownsample-0.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3af463173003c6ca1598ca44ef5e6d9d57988bd5576ae48300cd2b8af8ce3cef"}, - {file = "tsdownsample-0.1.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:616be5dfbbd486677fbc158403ba08d2c683cdd6bbe78c9ea3221248d45e8b98"}, - {file = "tsdownsample-0.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec9ec1e54e6fb4e79ae577908e4741d2716bd0cea0423c6ad3f7cebb75430d0c"}, - {file = "tsdownsample-0.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3297fdde74e0694c2a47d2655a668a2ed82ea194e8eae53786b0edaadea268d"}, - {file = "tsdownsample-0.1.2-cp310-cp310-manylinux_2_24_armv7l.whl", hash = "sha256:67b13aa3747013fc6cd0e7e3d4a18771fd3c430c5ff490a215566d97abf34570"}, - {file = "tsdownsample-0.1.2-cp310-cp310-manylinux_2_24_ppc64le.whl", hash = "sha256:4b1d916289f5a1eaf820f67764fcf65adcdc752abc82f5e14f30fe831f5a8687"}, - {file = "tsdownsample-0.1.2-cp310-cp310-manylinux_2_24_s390x.whl", hash = "sha256:eec6a306a70bf7e803563496e5e832fc43cfee81769aa9629acba7f04e693f58"}, - {file = "tsdownsample-0.1.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:98c0218ddeadf27331b60893ca5e575ad39b60da59b556bf01cb51939cfc1178"}, - {file = "tsdownsample-0.1.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:62dad5e7b615d7a270a7c220d6b874c515b6917a907fa4e2aea8ce20f29399e6"}, - {file = "tsdownsample-0.1.2-cp310-none-win32.whl", hash = "sha256:c29d82e21dbad01fdc97a873fe37c80d801c6d4635127e6de72cc8dac4e154af"}, - {file = "tsdownsample-0.1.2-cp310-none-win_amd64.whl", hash = "sha256:91d732304eec05f4e79be58b76986220c156c0602e1faa668feafdceb8bf8a1a"}, - {file = "tsdownsample-0.1.2-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:0406fb5d33f881f9d06f974ec8c915c0c848dd867422d4d11be87d7e536385ca"}, - {file = "tsdownsample-0.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:16458e6e82ffa34d900f431aba48918f5c44feda1c335080a54877dec9c16e60"}, - {file = "tsdownsample-0.1.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:480799c7f1ccff0c1b18b6c469558c9c774450bb631a4d6c2fb98c43756ef8e9"}, - {file = "tsdownsample-0.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9df5d888c79ac7e3129ad61becd9070814560c6607c759b300302086b19e639e"}, - {file = "tsdownsample-0.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9e55d3ab482247f76a0edc29378ccdf3bc2a8e20b9b303b5e6da74eae953423"}, - {file = "tsdownsample-0.1.2-cp311-cp311-manylinux_2_24_armv7l.whl", hash = "sha256:f8892cee114f7d2d8b59f127b3f7ae63173ad7abc4ef2c7486854055d7c79f5f"}, - {file = "tsdownsample-0.1.2-cp311-cp311-manylinux_2_24_ppc64le.whl", hash = "sha256:aadec728b972f4c2bf19c27f39824118971ae36983e7c8d09269a18a15f6f194"}, - {file = "tsdownsample-0.1.2-cp311-cp311-manylinux_2_24_s390x.whl", hash = "sha256:12ca7dce72aa6c12cb65f690d1179b0b140e0236202add0601e56c417a9cb9ad"}, - {file = "tsdownsample-0.1.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2b14bb78ab1a375edc87e7535aedd6e5d3d0f0ebf4cb98b2bff2c3a183797187"}, - {file = "tsdownsample-0.1.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:476ecc03172aad8076143d6a8b347dc2214c2f858a6231fab863ba50be4a01e9"}, - {file = "tsdownsample-0.1.2-cp311-none-win32.whl", hash = "sha256:91efdbccbaf0e6db51f3c3fc6191063033c86a718beef6c2c4b66593180525fb"}, - {file = "tsdownsample-0.1.2-cp311-none-win_amd64.whl", hash = "sha256:8295452fd49ca445907ab5314f3c9ac80ca02bb2cf079d278502dfeadaa28dca"}, - {file = "tsdownsample-0.1.2-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:79e67275c3033b84b0a1f0ea4d7bd9891084f61c5545cb73aee8cd0923692930"}, - {file = "tsdownsample-0.1.2-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:d8a5718b23d1f424a5574d13a804e4cd11a0803ff954b93812a0ece6153ad209"}, - {file = "tsdownsample-0.1.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:332b99ad652ce7b2e523903179d9209323eb365e8474a895d3f6d91a9a4a064d"}, - {file = "tsdownsample-0.1.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afb8c7406572520f33c8fe47fa71dfb956aa96ffe15c8c6c83b1c6dae2c97c44"}, - {file = "tsdownsample-0.1.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d19e9dfddca2ed0ddfdf9cf57f27c51a1f0f731efdf2c0e8503bd8a255866c43"}, - {file = "tsdownsample-0.1.2-cp37-cp37m-manylinux_2_24_armv7l.whl", hash = "sha256:729391d09cb7ee7f7295f045e3afc39ecb6eed0ecc20bf0c21ee80d31b2e8126"}, - {file = "tsdownsample-0.1.2-cp37-cp37m-manylinux_2_24_ppc64le.whl", hash = "sha256:fc771c5444954e5e69b0f2c6ed9a7605be6792d87a241f6593e6d5341caef83e"}, - {file = "tsdownsample-0.1.2-cp37-cp37m-manylinux_2_24_s390x.whl", hash = "sha256:8f9961335dc5885c7118872d1f0053fc087007f9a617a42281e51eb5c5f6ecff"}, - {file = "tsdownsample-0.1.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:519bb2fe44ed38ad7c04193b79f2820d6719c5eb530905155c97692d6e095758"}, - {file = "tsdownsample-0.1.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8cbd37143e38c9854490fea5e741a2740cca01742010581e8de766f12d4075cb"}, - {file = "tsdownsample-0.1.2-cp37-none-win32.whl", hash = "sha256:5650b70299cd80703efd40cf80d17c0aea7ee4c8af875fafa64dd4957120d93b"}, - {file = "tsdownsample-0.1.2-cp37-none-win_amd64.whl", hash = "sha256:671b9b642551807f89f50f1b0f1a671dd5bb04726fa382f15b05123edc599a99"}, - {file = "tsdownsample-0.1.2-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:e677cbd4b5d7553502deb0b3126e13e80cfea31053b7e9dfa1ed738339078aa6"}, - {file = "tsdownsample-0.1.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f315aa893b9c6bd1ca08c58ec70df9583c7854bc227c2ab3957ee9c496b30002"}, - {file = "tsdownsample-0.1.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:96719a42e16d7579cfb06f2b82c54a03dcfdaf99d54b63a7e3abb848a722a006"}, - {file = "tsdownsample-0.1.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b378f2339875a7e0fa05489170fe018da7cf0066183fd695a1d32f08bf51d9de"}, - {file = "tsdownsample-0.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:299b48bf047ca6c7a12b6fd42a62f6f365e134d805817e3caf71a3edd6d681d0"}, - {file = "tsdownsample-0.1.2-cp38-cp38-manylinux_2_24_armv7l.whl", hash = "sha256:bdf6c9c92699ed86d3b312a556a097fb836fb33dac3cb21a061af7327c341083"}, - {file = "tsdownsample-0.1.2-cp38-cp38-manylinux_2_24_ppc64le.whl", hash = "sha256:8de913d552c73baadd915151c59315bc56fa1c07dbf8f020ccdefcde6160dfb5"}, - {file = "tsdownsample-0.1.2-cp38-cp38-manylinux_2_24_s390x.whl", hash = "sha256:624671e516d64c38a1fad868861d64a66b0f15f3af68f2eab2f804d965d1c906"}, - {file = "tsdownsample-0.1.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6eaa885a26db9b2d784269c9d002770a9c7a58b34a8ebb8b78edcdc0b5bbb8cd"}, - {file = "tsdownsample-0.1.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9df0fd38963c37c65a69b51fd59463504f0a0f28bbbafab1759c0ee448e7526f"}, - {file = "tsdownsample-0.1.2-cp38-none-win32.whl", hash = "sha256:dae9c3d00aa12d3ed1108bebf434c35dd9611b432b937a0bb8338daa879d8834"}, - {file = "tsdownsample-0.1.2-cp38-none-win_amd64.whl", hash = "sha256:d0a086caf3121b566f98591cdf07ed2db697d8e72c1cd9dba7075798cc8a4550"}, - {file = "tsdownsample-0.1.2-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:3863c9638e2110018222bfc9b9c079d12426d4ef19d9de9fa759b1db3256e09d"}, - {file = "tsdownsample-0.1.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:29a8baf0393c46c867b4a9f625adb3ec7526015d8083948ec6ead703294d113f"}, - {file = "tsdownsample-0.1.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4eee158d2e96ce1dca455277ef16036d704397f0d17f408e584d6eaca36cc381"}, - {file = "tsdownsample-0.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:235e7f2fc52b8257e451e47169b2e8377c48e1dd470985af99ac90913bcaa351"}, - {file = "tsdownsample-0.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cade03d838b82450736d23b0ce49dba1a782440f31b873d1d2d0c03fe458ef4f"}, - {file = "tsdownsample-0.1.2-cp39-cp39-manylinux_2_24_armv7l.whl", hash = "sha256:1c20d90000cb12af5b857a11fc5eca226fcfbcf2b4e5930f9f13412809b4b3ff"}, - {file = "tsdownsample-0.1.2-cp39-cp39-manylinux_2_24_ppc64le.whl", hash = "sha256:9c7f8fdb2713c64c1bf40327d2a64448b2a5a2f161238191e8105c9aa16d2e3d"}, - {file = "tsdownsample-0.1.2-cp39-cp39-manylinux_2_24_s390x.whl", hash = "sha256:3ced17341a23c568d843b5a3af6c5e9aacb5ae748d9c2ce8dc1f46788ad54f23"}, - {file = "tsdownsample-0.1.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a89196d433d0325c0fd52bf002acc7da809eb270053ce116c38dcef6ced8d704"}, - {file = "tsdownsample-0.1.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:80af8a7e0c93acf4879a9b7ffb8b027de316a9fe6605e1d2cd9deaa54ebf0600"}, - {file = "tsdownsample-0.1.2-cp39-none-win32.whl", hash = "sha256:a105372c5e795ac72ac7aad80604fbbc410be40e59d209d055c994adf4e38c60"}, - {file = "tsdownsample-0.1.2-cp39-none-win_amd64.whl", hash = "sha256:bfc1ee990ebeed9865596edb09c7386590713b2cbbfaa1cfda2499f68cc94f5d"}, - {file = "tsdownsample-0.1.2.tar.gz", hash = "sha256:9da0b8f8859e9651910cc150bc5115bf5d1890d8ebee12bf57bd1bb016a0112a"}, + {file = "tsdownsample-0.1.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:e1e8b04a17efb6f25a730467bedd0a1ceda165149707305309f9456041cf4e49"}, + {file = "tsdownsample-0.1.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1791225e0e610b8c883fd7c8237756901bd10af42240a98e747bdb1085ee4f7e"}, + {file = "tsdownsample-0.1.3-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c1cfb42437732825af4b4fd6964ba8632c3b7a8094648ea9fb940412c62973ae"}, + {file = "tsdownsample-0.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:721ae2a9a385e36fe688d43c877852ba0bab2056b6875cf86aee1d6dec8b567c"}, + {file = "tsdownsample-0.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66941f131ec096478483fdd9a80a408a12a0c01c85b0ebf9eb41445c2721eca6"}, + {file = "tsdownsample-0.1.3-cp310-cp310-manylinux_2_24_armv7l.whl", hash = "sha256:0bce3ae95aa104ec0fbc4c37db5694ad3fa9bd09f5509f49215de157b78a99b2"}, + {file = "tsdownsample-0.1.3-cp310-cp310-manylinux_2_24_ppc64le.whl", hash = "sha256:bab1f0258f41cff4f3968076946ff468641f2b6c32c624a278c2cdf47a5b3eff"}, + {file = "tsdownsample-0.1.3-cp310-cp310-manylinux_2_24_s390x.whl", hash = "sha256:d8d41957d44593c8fee21f89454303db9a9eb2cf0e556c0138c84157702b39a4"}, + {file = "tsdownsample-0.1.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6ef3ad8f8f0fac177bab09fe2e916034d8d3d64e1b0531f2b3df9ecc1b36f84b"}, + {file = "tsdownsample-0.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e268741155eff05125bd45f5ab32fd20daed293503c3749765eafab8fe3e12ef"}, + {file = "tsdownsample-0.1.3-cp310-none-win32.whl", hash = "sha256:0cf6695ecf63ab7114a18ebeb12f722811ec455842391ac216bb698803abdaeb"}, + {file = "tsdownsample-0.1.3-cp310-none-win_amd64.whl", hash = "sha256:d5493f021f96db43e35a37bca70ac13460e49de264085643b1bac88c136574d3"}, + {file = "tsdownsample-0.1.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:539bae2e1675af94c26c8eed67f1bb8ebfc72fe44fa17965fb324ef38cb9e70d"}, + {file = "tsdownsample-0.1.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f547ad4b796097d314fdc5538c50db8b073b110aae13e540cd8db4802b0317f6"}, + {file = "tsdownsample-0.1.3-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8762d057a1b335fe44eec345b36b4d3b68ed369ca1214724a1c376546f49dce9"}, + {file = "tsdownsample-0.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ec8f06d6ad9f26a52b34f229ed5647ee3a2d373014a1c29903e1d57208b7ded"}, + {file = "tsdownsample-0.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08ed25561d504aad77ba815e4d176efe80a37b8761a53bf77bb087e08ae8f54b"}, + {file = "tsdownsample-0.1.3-cp311-cp311-manylinux_2_24_armv7l.whl", hash = "sha256:757b62df7fa170ada3748f2ee6682948be9bbd83370dba4bf83929dfd98570a3"}, + {file = "tsdownsample-0.1.3-cp311-cp311-manylinux_2_24_ppc64le.whl", hash = "sha256:c5d0ab1a46caf68764c36e6749d56d3c02ab39eb2f61ad700d8369dfe2861ad5"}, + {file = "tsdownsample-0.1.3-cp311-cp311-manylinux_2_24_s390x.whl", hash = "sha256:8d5bdcf9e09ee58411299ed5d47b1e9cdfaab61a8293cc2167af0e666aafcd4c"}, + {file = "tsdownsample-0.1.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0884215aae9b75107f3400b43800ce7b61ce573e2f19e8fb155fbd2918b0d0b3"}, + {file = "tsdownsample-0.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4deb0c331cb95ee634b6bd605b44e7936d5bf50e22f4efda15fa719ddf181951"}, + {file = "tsdownsample-0.1.3-cp311-none-win32.whl", hash = "sha256:3f0b70794d6ae79efc213ab4c384bbd3404b700c508d3873bbe63db3c48ab156"}, + {file = "tsdownsample-0.1.3-cp311-none-win_amd64.whl", hash = "sha256:f97a858a855d7d84c96d3b89daa2237c80da8da12ff7a3a3d13e029d6c15e69d"}, + {file = "tsdownsample-0.1.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:0822465103cde9688ecbbdad6642f3415871479bd62682bf85a55f0d156482c0"}, + {file = "tsdownsample-0.1.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bb4ad7c0fea8267156e14c0c1a8f027b5c0831e9067a436b7888c1085e569aac"}, + {file = "tsdownsample-0.1.3-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4eab94d2e392470e6e26bebdcbd7e600ad1e0edca4f88c160ed8e72e280c87aa"}, + {file = "tsdownsample-0.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58788587348e88064bfdb88acbb1d51078c5b1fff934ed1433355d1dbf939641"}, + {file = "tsdownsample-0.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d59c184cac10edae160b1908c9ee579345033aaefe455581f5bd4375caada0b"}, + {file = "tsdownsample-0.1.3-cp312-cp312-manylinux_2_24_armv7l.whl", hash = "sha256:35b0eb80cd5d1eaacab1818b6fa3374a815787c844e385fc57f29785714a357b"}, + {file = "tsdownsample-0.1.3-cp312-cp312-manylinux_2_24_ppc64le.whl", hash = "sha256:acd0bae11eb3777e47de41783b590c1d0f6bc25c85870cb8a5d9fd47797f3f8d"}, + {file = "tsdownsample-0.1.3-cp312-cp312-manylinux_2_24_s390x.whl", hash = "sha256:09d705fe7a5a73aa97a486dbef8ed4b0b9a59e679ad2189eb9923c02c1321000"}, + {file = "tsdownsample-0.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:76107f22e01135bc97f493083c623bab88c12a79b909ee931efa585f6543d3c9"}, + {file = "tsdownsample-0.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a966c97138809f2c3fe627b3acfa041983123342f92f9782d4573d16155d6c77"}, + {file = "tsdownsample-0.1.3-cp312-none-win32.whl", hash = "sha256:63b730c96a539abd71863166c203310383b0f04b268e935e69fd68a2b7a4d27a"}, + {file = "tsdownsample-0.1.3-cp312-none-win_amd64.whl", hash = "sha256:55479a6e1b1fa60092b6d344a99d374ccba57a40d23f9fd4d2fcdd5faabf7d40"}, + {file = "tsdownsample-0.1.3-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:3caa8906b87efa2277584dde851f214db4a3aa4d6d5c631b0ec40e7978271ac8"}, + {file = "tsdownsample-0.1.3-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:982c4e885785f261d500c2d36f7054544313d0caacfc53c63fc47571f28194f3"}, + {file = "tsdownsample-0.1.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:fe138a297a504b2121c56e32afb2e645028177faef1a5152a3b080fe932bd458"}, + {file = "tsdownsample-0.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3157795be6fbd2de9e936a242aa7c00ce8ab2def12fdec49d74502a216b403e"}, + {file = "tsdownsample-0.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e069bcf40b0ebf80d9f3a7f827cf26c60dde9b1474268a27e54f37daade1017d"}, + {file = "tsdownsample-0.1.3-cp37-cp37m-manylinux_2_24_armv7l.whl", hash = "sha256:5dd63a6a5358f52257775c83f66111fcdbd07cf70603a7d6271f5783b70bfaef"}, + {file = "tsdownsample-0.1.3-cp37-cp37m-manylinux_2_24_ppc64le.whl", hash = "sha256:5d77f9fd7d69a44e0935ad5f32a02b76b53c9aed325824565948b2a0bb6d2e08"}, + {file = "tsdownsample-0.1.3-cp37-cp37m-manylinux_2_24_s390x.whl", hash = "sha256:c4fa2115a8f46ff2d99958cd6652db84d09f4a1ac7cbafb62288e300df528b48"}, + {file = "tsdownsample-0.1.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9bc4d095c43947f6499603210270e1dd29919b9fb692655acbd5b1e7d7667102"}, + {file = "tsdownsample-0.1.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a1f89a77976f2cd24b19fa15dd98fa1aec2acef83743d9d9e31bc5daad7de8a3"}, + {file = "tsdownsample-0.1.3-cp37-none-win32.whl", hash = "sha256:9afe1856a18cbd58726614805cac01b2a7259a8f31396413e7c3c0629fe1f72c"}, + {file = "tsdownsample-0.1.3-cp37-none-win_amd64.whl", hash = "sha256:213d10479f06e98bb351bcf2f9b6b2c58632f6e4d27e200f224e7379b93775c2"}, + {file = "tsdownsample-0.1.3-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:c22103ec83b363b9195cb62533ddf76eaff7b198014a4dd40a8d1a0a2e8c6fa7"}, + {file = "tsdownsample-0.1.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:85a703c2d29dd0c7923b5b7d3c0eccc639a4087fbe8d0a0290506abcb8ef48bf"}, + {file = "tsdownsample-0.1.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6ccff0710fb2dac229f810877cb308087f1456610f1aa272524b69f551642ee2"}, + {file = "tsdownsample-0.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a198ab5f2ce1c9d30077b13393295dfbecaaddee6d2b5ffee7439d454c108f3"}, + {file = "tsdownsample-0.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:468752a26958c12a95ce583f46347e0c3eacb579323c3a230af27d50c3cabac5"}, + {file = "tsdownsample-0.1.3-cp38-cp38-manylinux_2_24_armv7l.whl", hash = "sha256:5293167ece8428f664ecd69302ecda32a6ad635d9aff6c21d1796198169e56bc"}, + {file = "tsdownsample-0.1.3-cp38-cp38-manylinux_2_24_ppc64le.whl", hash = "sha256:23557544e1d8e767606b134b55c5c598bcc368d73ba904ff4fe91b93d8935531"}, + {file = "tsdownsample-0.1.3-cp38-cp38-manylinux_2_24_s390x.whl", hash = "sha256:e0791339d9b8ddb78d1082a31ae880d00e4fa375147b8f2fdebc0915782e38ee"}, + {file = "tsdownsample-0.1.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:51d7a39b60d622f12beb1c04b404c7ab7743653eb74089b9663504675a9bfc62"}, + {file = "tsdownsample-0.1.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9f2ba2ad61db5ef9617b1d6e764e9aabde63a391b1c8ee97098d348a9b9d883a"}, + {file = "tsdownsample-0.1.3-cp38-none-win32.whl", hash = "sha256:e69e6d066c30e946aeb7c19ae7d2364ee6f1a072c8e47dee76825d86b5ad84be"}, + {file = "tsdownsample-0.1.3-cp38-none-win_amd64.whl", hash = "sha256:9491a03ec700ad5ca0f555553b4da9b8285fc461c28f7970a61acf6b05d8f3ad"}, + {file = "tsdownsample-0.1.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:23d3b2d24b36fa5d05bf7e2e2748565522f3c8afd29a706dbb22a8c6b2dc4a75"}, + {file = "tsdownsample-0.1.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1715fa981e5406c0c7adef04cdcc4b1a4c88e422946dc104b82f1669605c6da7"}, + {file = "tsdownsample-0.1.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5a44a101a1b7fa6b67d185a9f905bc0a9e0dcac6537b653a71f14dce71a451fa"}, + {file = "tsdownsample-0.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:372b09627f39899b90605bd71ac481a9052f135b8c96d431255c3af6c383d0d4"}, + {file = "tsdownsample-0.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac2d1141b0c3899bac018a6d8ed4b1baca2fb88ba8d8c9c7b1a4116f548c11a8"}, + {file = "tsdownsample-0.1.3-cp39-cp39-manylinux_2_24_armv7l.whl", hash = "sha256:ab6e9780f5a9d64b4692ac70f9a0aaf5a7bd499bc82e56b15e42e1a40e9f24a1"}, + {file = "tsdownsample-0.1.3-cp39-cp39-manylinux_2_24_ppc64le.whl", hash = "sha256:dc348f7802e18c33a6d8e0386debfa92ffa679591551fece9f2d49c0501b5c35"}, + {file = "tsdownsample-0.1.3-cp39-cp39-manylinux_2_24_s390x.whl", hash = "sha256:0d67e05dc61002c9672582f516ecf9473a6eba710be8580e9c9f37b187d83762"}, + {file = "tsdownsample-0.1.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1fb3e646206593ee92630e27a725cf00a9e70e20af5a7032baa3de2d2943bc6"}, + {file = "tsdownsample-0.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:20bf49fec6a4371fd417ca75a4096bafce2a3d3263a89f89d5bfba70bd78e409"}, + {file = "tsdownsample-0.1.3-cp39-none-win32.whl", hash = "sha256:74854a1b0c0a7a6581402769de12559cd9efeb1bb12ace831701aacfc0ddc140"}, + {file = "tsdownsample-0.1.3-cp39-none-win_amd64.whl", hash = "sha256:4d25567c0f15ca9e3f9d822934f60e5258e23a06d5515d12617518dc9f99f26f"}, + {file = "tsdownsample-0.1.3.tar.gz", hash = "sha256:5268d0ab5e8572138871feff389440a0c59d5e0fe02c0fa1cf975d74ba33b933"}, ] [package.dependencies] @@ -4221,5 +4229,5 @@ plotly-resampler = ["plotly-resampler"] [metadata] lock-version = "2.0" -python-versions = "^3.9" -content-hash = "76f8a9aa49b88fd68090fdef697200d6ab2d133b9bacb27f360227ea98d57d2d" +python-versions = ">=3.9,<=3.12" +content-hash = "6c3d1bac3597032afab291d8f490e6f6b21f11e3698b0ac88991fe5fde0d3859" diff --git a/pyproject.toml b/pyproject.toml index cfe62544c..9e0345635 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,28 +10,29 @@ classifiers = [ "Natural Language :: English", "Operating System :: OS Independent", "Operating System :: POSIX :: Linux", + "Operating System :: MacOS :: MacOS X", ] [tool.poetry.urls] Homepage = "https://github.com/ourownstory/neural_prophet" [tool.poetry.dependencies] -python = "^3.9" -typing-extensions = "^4.5.0" -numpy = "^1.25.0" -pandas = "^2.0.0" +python = ">=3.9,<=3.12" +numpy = ">=1.25.0,<2.0.0" +pandas = ">=2.0.0" +torch = ">=2.0.0" # Note: torch defaults to already installed version or installs CUDA version # If you want CPU-only torch, install that before installing neuralprophet. -torch = "^2.0.0" -pytorch-lightning = "^1.9.4" # TODO: move to ^2.0.0 -tensorboard = "^2.11.2" -torchmetrics = "^1.0.0" +pytorch-lightning = ">=2.0.0" +tensorboard = ">=2.11.2" +torchmetrics = ">=1.0.0" +typing-extensions = ">=4.5.0" holidays = ">=0.41" captum = ">=0.6.0" -matplotlib = "^3.5.3" -plotly = "^5.13.1" +matplotlib = ">=3.5.3" +plotly = ">=5.13.1" kaleido = "0.2.1" # required for plotly static image export -plotly-resampler = { version = "^0.9.2", python = "<3.12", optional = true } +plotly-resampler = { version = ">=0.9.2", optional = true } livelossplot = { version = ">=0.5.5", optional = true } [tool.poetry.extras] @@ -39,8 +40,8 @@ plotly-resampler = ["plotly-resampler"] live = ["livelossplot"] [tool.poetry.group.dev.dependencies] # For dev involving notebooks -ipykernel = "^6.29.2" -nbformat = "^5.8.0" +ipykernel = ">=6.29.2" +nbformat = ">=5.8.0" [tool.poetry.group.pytest] # pytest dev setup and CI optional = true @@ -121,4 +122,4 @@ exclude = [ [tool.ruff] line-length = 120 -typing-modules = ["neuralprophet.np_types"] +typing-modules = ["neuralprophet.np_types"] \ No newline at end of file diff --git a/tests/test_glocal.py b/tests/test_glocal.py index c25a649c3..e631b616d 100644 --- a/tests/test_glocal.py +++ b/tests/test_glocal.py @@ -207,7 +207,7 @@ def test_wrong_option_global_local_modeling(): def test_different_seasonality_modeling(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) @@ -235,7 +235,7 @@ def test_different_seasonality_modeling(): def test_adding_new_global_seasonality(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) @@ -264,7 +264,7 @@ def test_adding_new_global_seasonality(): def test_adding_new_local_seasonality(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) @@ -285,7 +285,7 @@ def test_adding_new_local_seasonality(): def test_trend_local_reg(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) @@ -315,7 +315,7 @@ def test_trend_local_reg(): def test_glocal_seasonality_reg(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) @@ -344,7 +344,7 @@ def test_glocal_seasonality_reg(): def test_trend_local_reg_if_global(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) @@ -373,7 +373,7 @@ def test_trend_local_reg_if_global(): def test_different_seasonality_modeling(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) @@ -401,7 +401,7 @@ def test_different_seasonality_modeling(): def test_adding_new_global_seasonality(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) @@ -430,7 +430,7 @@ def test_adding_new_global_seasonality(): def test_adding_new_local_seasonality(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) @@ -451,7 +451,7 @@ def test_adding_new_local_seasonality(): def test_trend_local_reg(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) @@ -481,7 +481,7 @@ def test_trend_local_reg(): def test_glocal_seasonality_reg(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) @@ -510,7 +510,7 @@ def test_glocal_seasonality_reg(): def test_trend_local_reg_if_global(): - ### SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES + # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) df1_0 = df.iloc[:128, :].copy(deep=True) diff --git a/tests/test_uncertainty.py b/tests/test_uncertainty.py index 039128cb1..6bbdd98e9 100644 --- a/tests/test_uncertainty.py +++ b/tests/test_uncertainty.py @@ -71,8 +71,8 @@ def test_uncertainty_estimation_peyton_manning(): if m.n_lags > 0: df["A"] = df["y"].rolling(7, min_periods=1).mean() df["B"] = df["y"].rolling(30, min_periods=1).mean() - m = m.add_lagged_regressor(name="A") - m = m.add_lagged_regressor(name="B", only_last_value=True) + m = m.add_lagged_regressor(names="A") + m = m.add_lagged_regressor(names="B") # add events m = m.add_events(["superbowl", "playoff"], lower_window=-1, upper_window=1, regularization=0.1)