Skip to content

Commit

Permalink
[minor] add subdivisions to holidays (#1584)
Browse files Browse the repository at this point in the history
* add io buffer support

* added tests

* remove typealias

* add subdivisions to country

* remove subdivision variable

* adjust function signature

* move holidays logic to one class

* changed get_holidays_from_country

* fix types issues

* move hdays util from time_dataset to hdays_utils

* fix test

* fix test-2

* extend subdivision support

---------

Co-authored-by: Maisa Ben Salah <maisabensalah@AminsMBP131.attlocal.net>
Co-authored-by: ourownstory <ourownstory@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 21, 2024
1 parent d4e2553 commit 79ea9cd
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 81 deletions.
7 changes: 4 additions & 3 deletions neuralprophet/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import types
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Callable, List, Optional
from typing import Callable, Iterable, List, Optional

Check failure on line 8 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / flake8

'typing.Iterable' imported but unused
from typing import OrderedDict as OrderedDictType
from typing import Type, Union

Expand All @@ -15,6 +15,7 @@

from neuralprophet import df_utils, np_types, utils, utils_torch

Check failure on line 16 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / flake8

'neuralprophet.utils' imported but unused
from neuralprophet.custom_loss_metrics import PinballLoss
from neuralprophet.hdays_utils import get_holidays_from_country

log = logging.getLogger("NP.config")

Expand Down Expand Up @@ -505,15 +506,15 @@ class Event:

@dataclass
class Holidays:
country: Union[str, List[str]]
country: Union[str, List[str], dict]
lower_window: int
upper_window: int
mode: str = "additive"
reg_lambda: Optional[float] = None
holiday_names: set = field(init=False)

def init_holidays(self, df=None):
self.holiday_names = utils.get_holidays_from_country(self.country, df)
self.holiday_names = get_holidays_from_country(self.country, df)


ConfigCountryHolidays = Holidays
6 changes: 3 additions & 3 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def add_events(

def add_country_holidays(
self,
country_name: Union[str, list],
country_name: Union[str, list, dict],
lower_window: int = 0,
upper_window: int = 0,
regularization: Optional[float] = None,
Expand All @@ -764,8 +764,8 @@ def add_country_holidays(
Parameters
----------
country_name : str, list
name or list of names of the country
country_name : str, list, dict
name or list of names of the country or a dictionary where the key is the country name and the value is a subdivision
lower_window : int
the lower window for all the country holidays
upper_window : int
Expand Down
88 changes: 86 additions & 2 deletions neuralprophet/hdays_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from collections import defaultdict
from typing import Iterable, Optional, Union

import holidays
import numpy as np
import pandas as pd


def get_country_holidays(country: str, years: Optional[Union[int, Iterable[int]]] = None):
def get_country_holidays(
country: str, years: Optional[Union[int, Iterable[int]]] = None, subdivision: Optional[str] = None
):
"""
Helper function to get holidays for a country.
Expand All @@ -13,6 +18,8 @@ def get_country_holidays(country: str, years: Optional[Union[int, Iterable[int]]
Country name to retrieve country specific holidays
years : int, list
Year or list of years to retrieve holidays for
subdivision : str
Subdivision name to retrieve subdivision specific holidays
Returns
-------
Expand All @@ -27,5 +34,82 @@ def get_country_holidays(country: str, years: Optional[Union[int, Iterable[int]]
country = substitutions.get(country, country)
if not hasattr(holidays, country):
raise AttributeError(f"Holidays in {country} are not currently supported!")
if subdivision:
holiday_obj = getattr(holidays, country)(years=years, subdiv=subdivision)
else:
holiday_obj = getattr(holidays, country)(years=years)

return getattr(holidays, country)(years=years)
return holiday_obj


def get_holidays_from_country(country: Union[str, Iterable[str], dict], df=None):
    """
    Return all possible holiday names of given countries

    Parameters
    ----------
        country : str, iterable of str, dict
            A single country name, an iterable of country names (list, tuple, set, ...), or
            a dictionary where the key is the country name and the value is a subdivision
            (e.g., province or state), or ``None`` for no subdivision
        df : pd.Dataframe
            Dataframe from which datestamps will be retrieved from (column ``ds``).
            If ``None``, the years 1995 to 2044 are used.

    Returns
    -------
        set
            All possible holiday names of given countries
    """
    if df is None:
        years = np.arange(1995, 2045)
    else:
        dates = df["ds"].copy(deep=True)
        years = list({x.year for x in dates})

    # Normalize every accepted input form to a {country: subdivision} mapping.
    if isinstance(country, str):
        country = {country: None}
    elif not hasattr(country, "items"):
        # Any non-mapping iterable of names: no subdivision for any country.
        country = dict.fromkeys(country)

    unique_holidays = {}
    for single_country, subdivision in country.items():
        holidays_country = get_country_holidays(single_country, years, subdivision)
        for date, name in holidays_country.items():
            # When several countries share a date, the first one listed keeps its name.
            if date not in unique_holidays:
                unique_holidays[date] = name

    return set(unique_holidays.values())


def make_country_specific_holidays(year_list, country):
    """
    Create dict of holiday names and dates for given years and countries

    Parameters
    ----------
        year_list : list
            List of years
        country : str, iterable of str, dict
            A single country name, an iterable of country names (list, tuple, set, ...), or
            a dictionary where the key is the country name and the value is a subdivision
            (``None`` for no subdivision)

    Returns
    -------
        dict
            holiday names as keys and lists of dates (converted with ``pd.to_datetime``) as values
    """
    # Normalize every accepted input form to a {country: subdivision} mapping.
    if isinstance(country, str):
        country = {country: None}
    elif not hasattr(country, "items"):
        # Any non-mapping iterable of names: no subdivision for any country.
        country = dict.fromkeys(country)

    # iterate over countries and get holidays for each country
    country_specific_holidays = {}
    for single_country, subdivision in country.items():
        single_country_specific_holidays = get_country_holidays(single_country, year_list, subdivision)
        # NOTE: update() overwrites, so on a date shared by several countries the
        # holiday name of the country processed last is the one that is kept.
        country_specific_holidays.update(single_country_specific_holidays)

    # invert the date -> name mapping into name -> [dates]
    holidays_dates = defaultdict(list)
    for date, holiday in country_specific_holidays.items():
        holidays_dates[holiday].append(pd.to_datetime(date))
    return holidays_dates
35 changes: 3 additions & 32 deletions neuralprophet/time_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from collections import OrderedDict, defaultdict
from collections import OrderedDict
from datetime import datetime
from typing import Optional

Expand All @@ -10,7 +10,7 @@

from neuralprophet import configure, utils
from neuralprophet.df_utils import get_max_num_lags
from neuralprophet.hdays_utils import get_country_holidays
from neuralprophet.hdays_utils import make_country_specific_holidays

log = logging.getLogger("NP.time_dataset")

Expand Down Expand Up @@ -531,35 +531,6 @@ def fourier_series_t(t, period, series_order):
return features


def make_country_specific_holidays_df(year_list, country):
"""
Make dataframe of country specific holidays for given years and countries
Parameters
----------
year_list : list
List of years
country : str, list
List of country names
Returns
-------
pd.DataFrame
Containing country specific holidays df with columns 'ds' and 'holiday'
"""
# iterate over countries and get holidays for each country
# convert to list if not already
if isinstance(country, str):
country = [country]
country_specific_holidays = {}
for single_country in country:
single_country_specific_holidays = get_country_holidays(single_country, year_list)
# only add holiday if it is not already in the dict
country_specific_holidays.update(single_country_specific_holidays)
country_specific_holidays_dict = defaultdict(list)
for date, holiday in country_specific_holidays.items():
country_specific_holidays_dict[holiday].append(pd.to_datetime(date))
return country_specific_holidays_dict


def _create_event_offset_features(event, config, feature, additive_events, multiplicative_events):
"""
Create event offset features for the given event, config and feature
Expand Down Expand Up @@ -623,7 +594,7 @@ def make_events_features(df, config_events: Optional[configure.ConfigEvents] = N
# create all country specific holidays
if config_country_holidays is not None:
year_list = list({x.year for x in df.ds})
country_holidays_dict = make_country_specific_holidays_df(year_list, config_country_holidays.country)
country_holidays_dict = make_country_specific_holidays(year_list, config_country_holidays.country)
for holiday in config_country_holidays.holiday_names:
feature = pd.Series([0.0] * df.shape[0])
if holiday in country_holidays_dict.keys():
Expand Down
38 changes: 2 additions & 36 deletions neuralprophet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import sys
from collections import OrderedDict
from typing import TYPE_CHECKING, Iterable, Optional, Union, BinaryIO, IO
from typing import IO, TYPE_CHECKING, BinaryIO, Iterable, Optional, Union

import numpy as np
import pandas as pd
Expand All @@ -23,6 +23,7 @@

FILE_LIKE = Union[str, os.PathLike, BinaryIO, IO[bytes]]


def save(forecaster, path: FILE_LIKE):
"""Save a fitted Neural Prophet model to disk.
Expand Down Expand Up @@ -375,41 +376,6 @@ def config_seasonality_to_model_dims(config_seasonality: ConfigSeasonality):
return seasonal_dims


def get_holidays_from_country(country: Union[str, Iterable[str]], df=None):
"""
Return all possible holiday names of given country
Parameters
----------
country : str, list
List of country names to retrieve country specific holidays
df : pd.Dataframe
Dataframe from which datestamps will be retrieved from
Returns
-------
set
All possible holiday names of given country
"""
if df is None:
years = np.arange(1995, 2045)
else:
dates = df["ds"].copy(deep=True)
years = list({x.year for x in dates})
# support multiple countries
if isinstance(country, str):
country = [country]

unique_holidays = {}
for single_country in country:
holidays_country = get_country_holidays(single_country, years)
for date, name in holidays_country.items():
if date not in unique_holidays:
unique_holidays[date] = name
holiday_names = unique_holidays.values()
return set(holiday_names)


def config_events_to_model_dims(config_events: Optional[ConfigEvents], config_country_holidays):
"""
Convert user specified events configurations along with country specific
Expand Down
12 changes: 12 additions & 0 deletions tests/test_hdays_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,15 @@ def test_get_country_holidays():

with pytest.raises(AttributeError):
hdays_utils.get_country_holidays("NotSupportedCountry")


def test_get_country_holidays_with_subdivisions():
    """Requesting a subdivision yields a non-empty holiday object of the country's class."""
    # United States, California subdivision
    california = hdays_utils.get_country_holidays("US", years=2019, subdivision="CA")
    assert isinstance(california, holidays.countries.united_states.UnitedStates)
    assert len(california) > 0  # the 2019 calendar must not be empty

    # Canada, Ontario subdivision
    ontario = hdays_utils.get_country_holidays("CA", years=2019, subdivision="ON")
    assert isinstance(ontario, holidays.countries.canada.CA)
    assert len(ontario) > 0  # the 2019 calendar must not be empty
8 changes: 5 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python3

import io
import logging
import os
import pathlib
import io

import pandas as pd
import pytest
Expand Down Expand Up @@ -67,6 +67,7 @@ def test_save_load():
pd.testing.assert_frame_equal(forecast, forecast2)
pd.testing.assert_frame_equal(forecast, forecast3)


def test_save_load_io():
df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
m = NeuralProphet(
Expand All @@ -80,13 +81,13 @@ def test_save_load_io():
_ = m.fit(df, freq="D")
future = m.make_future_dataframe(df, periods=3)
forecast = m.predict(df=future)

# Save the model to an in-memory buffer
log.info("testing: save to buffer")
buffer = io.BytesIO()
save(m, buffer)
buffer.seek(0) # Reset buffer position to the beginning

log.info("testing: load from buffer")
m2 = load(buffer)
forecast2 = m2.predict(df=future)
Expand All @@ -99,6 +100,7 @@ def test_save_load_io():
pd.testing.assert_frame_equal(forecast, forecast2)
pd.testing.assert_frame_equal(forecast, forecast3)


# TODO: add functionality to continue training
# def test_continue_training():
# df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/dataset_generators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pd

from neuralprophet.time_dataset import make_country_specific_holidays_df
from neuralprophet.hdays_utils import make_country_specific_holidays


def generate_holiday_dataset(country="US", years=[2022], y_default=1, y_holiday=100, y_holidays_override={}):
Expand All @@ -11,7 +11,7 @@ def generate_holiday_dataset(country="US", years=[2022], y_default=1, y_holiday=
dates = pd.date_range("%i-01-01" % (years[0]), periods=periods, freq="D")
df = pd.DataFrame({"ds": dates, "y": y_default}, index=dates)

holidays = make_country_specific_holidays_df(years, country)
holidays = make_country_specific_holidays(years, country)
for holiday_name, timestamps in holidays.items():
df.loc[timestamps[0], "y"] = y_holidays_override.get(holiday_name, y_holiday)

Expand Down

0 comments on commit 79ea9cd

Please sign in to comment.