Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor] fix future regressor #1585

Merged
merged 18 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion neuralprophet/components/future_regressors/linear.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn

from neuralprophet import utils
from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter

Expand Down
2 changes: 1 addition & 1 deletion neuralprophet/components/future_regressors/neural_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn as nn

from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter, interprete_model
from neuralprophet.utils_torch import interprete_model

# from neuralprophet.utils_torch import init_parameter

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from collections import Counter, OrderedDict
from collections import Counter

import torch
import torch.nn as nn

from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter, interprete_model
from neuralprophet.utils_torch import interprete_model

# from neuralprophet.utils_torch import init_parameter

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from collections import Counter, OrderedDict
from collections import Counter

import torch
import torch.nn as nn

from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter, interprete_model
from neuralprophet.utils_torch import interprete_model

# from neuralprophet.utils_torch import init_parameter

Expand Down
11 changes: 6 additions & 5 deletions neuralprophet/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,17 @@
self.trend_global_local = "global"

# If trend_local_reg < 0
if self.trend_local_reg < 0:

Check failure on line 308 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg))
self.trend_local_reg = False

# If trend_local_reg = True
if self.trend_local_reg == True:
if self.trend_local_reg:
log.error("trend_local_reg = True. Default trend_local_reg value set to 1")
self.trend_local_reg = 1

# If Trend modelling is global.
if self.trend_global_local == "global" and self.trend_local_reg != False:
if self.trend_global_local == "global" and self.trend_local_reg:
log.error("Trend modeling is '{}'. Setting the trend_local_reg to False".format(self.trend_global_local))
self.trend_local_reg = False

Expand Down Expand Up @@ -355,13 +355,13 @@
log.error("Invalid global_local mode '{}'. Set to 'global'".format(self.global_local))
self.global_local = "global"

self.periods = OrderedDict(

Check failure on line 358 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

No overloads for "__init__" match the provided arguments (reportCallIssue)
{

Check failure on line 359 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "dict[str, Season]" cannot be assigned to parameter "iterable" of type "Iterable[list[bytes]]" in function "__init__" (reportArgumentType)
"yearly": Season(
resolution=6,
period=365.25,
arg=self.yearly_arg,
global_local=(

Check failure on line 364 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.yearly_global_local
if self.yearly_global_local in ["global", "local"]
else self.global_local
Expand All @@ -372,7 +372,7 @@
resolution=3,
period=7,
arg=self.weekly_arg,
global_local=(

Check failure on line 375 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.weekly_global_local
if self.weekly_global_local in ["global", "local"]
else self.global_local
Expand All @@ -383,7 +383,7 @@
resolution=6,
period=1,
arg=self.daily_arg,
global_local=(

Check failure on line 386 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.daily_global_local if self.daily_global_local in ["global", "local"] else self.global_local
),
condition_name=None,
Expand All @@ -392,17 +392,17 @@
)

# If seasonality_local_reg < 0
if self.seasonality_local_reg < 0:

Check failure on line 395 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative seasonality_local_reg '{}'. Set to False".format(self.seasonality_local_reg))
self.seasonality_local_reg = False

# If seasonality_local_reg = True
if self.seasonality_local_reg == True:
if self.seasonality_local_reg:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also seems to come from another PR

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, this change will always override a set seasonality_local_reg and overwrite it to be 1 which does not appear intentional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No was trying to solve the flake8 errors - doesn't make any difference anyways

log.error("seasonality_local_reg = True. Default seasonality_local_reg value set to 1")
self.seasonality_local_reg = 1

# If Season modelling is global.
if self.global_local == "global" and self.seasonality_local_reg != False:
if self.global_local == "global" and self.seasonality_local_reg:
log.error(
"Seasonality modeling is '{}'. Setting the seasonality_local_reg to False".format(self.global_local)
)
Expand All @@ -413,7 +413,7 @@
resolution=resolution,
period=period,
arg=arg,
global_local=global_local if global_local in ["global", "local"] else self.global_local,

Check failure on line 416 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "str" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__"   Type "str" is incompatible with type "SeasonGlobalLocalMode"     "str" is incompatible with type "Literal['global']"     "str" is incompatible with type "Literal['local']"     "str" is incompatible with type "Literal['glocal']" (reportArgumentType)
condition_name=condition_name,
)

Expand Down Expand Up @@ -487,7 +487,7 @@
regressors: OrderedDict = field(init=False) # contains RegressorConfig objects

def __post_init__(self):
self.regressors = None

Check failure on line 490 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot assign to attribute "regressors" for class "ConfigFutureRegressors*"   "None" is incompatible with "OrderedDict[Unknown, Unknown]" (reportAttributeAccessIssue)


@dataclass
Expand All @@ -507,11 +507,12 @@
lower_window: int
upper_window: int
mode: str = "additive"
subdivision: Optional[Union[str, dict]] = (None,)
reg_lambda: Optional[float] = None
holiday_names: set = field(init=False)

def init_holidays(self, df=None):
self.holiday_names = utils.get_holidays_from_country(self.country, df)
self.holiday_names = utils.get_holidays_from_country(self.country, self.subdivision, df)


ConfigCountryHolidays = Holidays
14 changes: 10 additions & 4 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,7 @@
def add_country_holidays(
self,
country_name: Union[str, list],
subdivision_name: Optional[Union[str, dict]] = None,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like the subdivision commits slipped in here.
we can either remove those commits or merge that PR first and be careful that these lines do not longer reappear due to commit squashing.

lower_window: int = 0,
upper_window: int = 0,
regularization: Optional[float] = None,
Expand All @@ -766,6 +767,9 @@
----------
country_name : str, list
name or list of names of the country
subdivision_name : str, dict
name or list of names of the subdivisions (e.g., provinces or states) or
a dictionary where the key is the country name and the value is a list of subdivisions
lower_window : int
the lower window for all the country holidays
upper_window : int
Expand All @@ -789,6 +793,7 @@
regularization = None
self.config_country_holidays = configure.Holidays(
country=country_name,
subdivision=subdivision_name,
lower_window=lower_window,
upper_window=upper_window,
reg_lambda=regularization,
Expand Down Expand Up @@ -1102,12 +1107,12 @@
# Only display the plot if the session is interactive, eg. do not show in github actions since it
# causes an error in the Windows and MacOS environment
if matplotlib.is_interactive():
fig

Check warning on line 1110 in neuralprophet/forecaster.py

View workflow job for this annotation

GitHub Actions / pyright

Expression value is unused (reportUnusedExpression)

self.fitted = True
return metrics_df

def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False):
def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, auto_extend=False):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid the linked issues being raised, auto_extend should default to True, unless my logic is off?

Copy link
Collaborator Author

@MaiBe-ctrl MaiBe-ctrl Jun 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto_extend for me means that we don't do the cutting at the end. So we cut all the time except for the case where it will cause the whole dataframe to be nan. Can still reverse it

"""Runs the model to make predictions.

Expects all data needed to be present in dataframe.
Expand Down Expand Up @@ -1176,7 +1181,7 @@
quantiles=self.config_train.quantiles,
components=components,
)
if periods_added[df_name] > 0:
if not auto_extend and periods_added[df_name] > 0:
fcst = fcst[:-1]
else:
fcst = _reshape_raw_predictions_to_forecst_df(
Expand All @@ -1191,9 +1196,10 @@
quantiles=self.config_train.quantiles,
config_lagged_regressors=self.config_lagged_regressors,
)
if periods_added[df_name] > 0:
fcst = fcst[: -periods_added[df_name]]
if not auto_extend and periods_added[df_name] > 0:
fcst = fcst[:-1]
forecast = pd.concat((forecast, fcst), ignore_index=True)

df = df_utils.return_df_in_original_format(forecast, received_ID_col, received_single_time_series)
self.predict_steps = self.n_forecasts
return df
Expand Down
12 changes: 10 additions & 2 deletions neuralprophet/hdays_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import holidays


def get_country_holidays(country: str, years: Optional[Union[int, Iterable[int]]] = None):
def get_country_holidays(
country: str, years: Optional[Union[int, Iterable[int]]] = None, subdivision: Optional[str] = None
):
"""
Helper function to get holidays for a country.

Expand All @@ -13,6 +15,8 @@ def get_country_holidays(country: str, years: Optional[Union[int, Iterable[int]]
Country name to retrieve country specific holidays
years : int, list
Year or list of years to retrieve holidays for
subdivision : str
Subdivision name to retrieve subdivision specific holidays

Returns
-------
Expand All @@ -27,5 +31,9 @@ def get_country_holidays(country: str, years: Optional[Union[int, Iterable[int]]
country = substitutions.get(country, country)
if not hasattr(holidays, country):
raise AttributeError(f"Holidays in {country} are not currently supported!")
if subdivision:
holiday_obj = getattr(holidays, country)(years=years, subdiv=subdivision)
else:
holiday_obj = getattr(holidays, country)(years=years)

return getattr(holidays, country)(years=years)
return holiday_obj
4 changes: 2 additions & 2 deletions neuralprophet/time_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@ def _add_batch_regularizations(self, loss, epoch, progress):
trend_glocal_loss = torch.zeros(1, dtype=torch.float, requires_grad=False)
# Glocal Trend
if self.config_trend is not None:
if self.config_trend.trend_global_local == "local" and self.config_trend.trend_local_reg != False:
if self.config_trend.trend_global_local == "local" and self.config_trend.trend_local_reg:
trend_glocal_loss = reg_func_trend_glocal(
self.trend.trend_k0, self.trend.trend_deltas, self.config_trend.trend_local_reg
)
Expand All @@ -949,7 +949,7 @@ def _add_batch_regularizations(self, loss, epoch, progress):
if self.config_seasonality is not None:
if (
self.config_seasonality.global_local in ["local", "glocal"]
and self.config_seasonality.seasonality_local_reg != False
and self.config_seasonality.seasonality_local_reg
):
seasonality_glocal_loss = reg_func_seasonality_glocal(
self.seasonality.season_params, self.config_seasonality.seasonality_local_reg
Expand Down
20 changes: 16 additions & 4 deletions neuralprophet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import sys
from collections import OrderedDict
from typing import TYPE_CHECKING, Iterable, Optional, Union, BinaryIO, IO
from typing import IO, TYPE_CHECKING, BinaryIO, Iterable, Optional, Union

import numpy as np
import pandas as pd
Expand All @@ -23,6 +23,7 @@

FILE_LIKE = Union[str, os.PathLike, BinaryIO, IO[bytes]]


def save(forecaster, path: FILE_LIKE):
"""Save a fitted Neural Prophet model to disk.

Expand Down Expand Up @@ -375,14 +376,19 @@ def config_seasonality_to_model_dims(config_seasonality: ConfigSeasonality):
return seasonal_dims


def get_holidays_from_country(country: Union[str, Iterable[str]], df=None):
def get_holidays_from_country(
country: Union[str, Iterable[str]], subdivision: Optional[Union[str, dict]] = None, df=None
):
"""
Return all possible holiday names of given country

Parameters
----------
country : str, list
List of country names to retrieve country specific holidays
subdivision : str, dict
a single subdivision (e.g., province or state) as a string or
a dictionary where the key is the country name and the value is a subdivision
df : pd.Dataframe
Dataframe from which datestamps will be retrieved from

Expand All @@ -399,10 +405,16 @@ def get_holidays_from_country(country: Union[str, Iterable[str]], df=None):
# support multiple countries
if isinstance(country, str):
country = [country]

# support subdivisions
if subdivision is not None:
if isinstance(subdivision, str):
if isinstance(country, list):
raise ValueError("If country_name is a list, subdivisions must be a dictionary.")
subdivision = {country: subdivision}
unique_holidays = {}
for single_country in country:
holidays_country = get_country_holidays(single_country, years)
subdivision = subdivision.get(single_country) if subdivision else None
holidays_country = get_country_holidays(single_country, years, subdivision)
for date, name in holidays_country.items():
if date not in unique_holidays:
unique_holidays[date] = name
Expand Down
Loading
Loading