Skip to content

Commit

Permalink
extract holiday getter helper to single function
Browse files Browse the repository at this point in the history
  • Loading branch information
noxan committed Dec 14, 2022
1 parent 1037990 commit fbf54a5
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
15 changes: 15 additions & 0 deletions neuralprophet/hdays_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import holidays as pyholidays

from neuralprophet import hdays as hdays_part2


def get_country_holidays(country, years):
try:
holidays_country = getattr(hdays_part2, country)(years=years)
except AttributeError:
try:
holidays_country = getattr(pyholidays, country)(years=years)
except AttributeError:
raise AttributeError(f"Holidays in {country} are not currently supported!")

return holidays_country
14 changes: 3 additions & 11 deletions neuralprophet/time_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
from datetime import datetime
from typing import Optional

import holidays as hdays_part1
import numpy as np
import pandas as pd
import torch
from torch.utils.data.dataset import Dataset

from neuralprophet import configure
from neuralprophet import hdays as hdays_part2
from neuralprophet import utils
from neuralprophet import configure, utils
from neuralprophet.df_utils import get_max_num_lags
from neuralprophet.hdays_utils import get_country_holidays

log = logging.getLogger("NP.time_dataset")

Expand Down Expand Up @@ -473,13 +471,7 @@ def make_country_specific_holidays_df(year_list, country):
country = [country]
country_specific_holidays = {}
for single_country in country:
try:
single_country_specific_holidays = getattr(hdays_part2, single_country)(years=year_list)
except AttributeError:
try:
single_country_specific_holidays = getattr(hdays_part1, single_country)(years=year_list)
except AttributeError:
raise AttributeError(f"Holidays in {single_country} are not currently supported!")
single_country_specific_holidays = get_country_holidays(single_country, year_list)
# only add holiday if it is not already in the dict
country_specific_holidays.update(single_country_specific_holidays)
country_specific_holidays_dict = defaultdict(list)
Expand Down
11 changes: 2 additions & 9 deletions neuralprophet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional

import holidays as pyholidays
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch

from neuralprophet import hdays as hdays_part2
from neuralprophet import utils_torch
from neuralprophet.hdays_utils import get_country_holidays
from neuralprophet.logger import ProgressBar

if TYPE_CHECKING:
Expand Down Expand Up @@ -298,13 +297,7 @@ def get_holidays_from_country(country, df=None):

holidays = {}
for single_country in country:
try:
holidays_country = getattr(hdays_part2, single_country)(years=years)
except AttributeError:
try:
holidays_country = getattr(pyholidays, single_country)(years=years)
except AttributeError:
raise AttributeError(f"Holidays in {single_country} are not currently supported!")
holidays_country = get_country_holidays(single_country, years)
# only add holiday if it is not already in the dict
holidays.update(holidays_country)
holiday_names = holidays.values()
Expand Down

0 comments on commit fbf54a5

Please sign in to comment.