Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor] Improve Season glocal reg invalid parameter handling #1601

Merged
merged 3 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions neuralprophet/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@
log.error("Invalid growth for global_local mode '{}'. Set to 'global'".format(self.trend_global_local))
self.trend_global_local = "global"

if self.trend_local_reg < 0:

Check failure on line 308 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg))
self.trend_local_reg = False

Expand Down Expand Up @@ -354,13 +354,13 @@
log.error("Invalid global_local mode '{}'. Set to 'global'".format(self.global_local))
self.global_local = "global"

self.periods = OrderedDict(

Check failure on line 357 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

No overloads for "__init__" match the provided arguments (reportCallIssue)
{

Check failure on line 358 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "dict[str, Season]" cannot be assigned to parameter "iterable" of type "Iterable[list[bytes]]" in function "__init__" (reportArgumentType)
"yearly": Season(
resolution=6,
period=365.25,
arg=self.yearly_arg,
global_local=(

Check failure on line 363 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.yearly_global_local
if self.yearly_global_local in ["global", "local"]
else self.global_local
Expand All @@ -371,7 +371,7 @@
resolution=3,
period=7,
arg=self.weekly_arg,
global_local=(

Check failure on line 374 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.weekly_global_local
if self.weekly_global_local in ["global", "local"]
else self.global_local
Expand All @@ -382,7 +382,7 @@
resolution=6,
period=1,
arg=self.daily_arg,
global_local=(

Check failure on line 385 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.daily_global_local if self.daily_global_local in ["global", "local"] else self.global_local
),
condition_name=None,
Expand All @@ -390,12 +390,10 @@
}
)

if self.seasonality_local_reg < 0:
log.error("Invalid negative seasonality_local_reg '{}'. Set to False".format(self.seasonality_local_reg))
self.seasonality_local_reg = False
assert self.seasonality_local_reg >= 0, "Invalid seasonality_local_reg '{}'.".format(self.seasonality_local_reg)

Check failure on line 393 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator ">=" not supported for "None" (reportOptionalOperand)

if self.seasonality_local_reg is True:
log.error("seasonality_local_reg = True. Default seasonality_local_reg value set to 1")
log.warning("seasonality_local_reg = True. Default seasonality_local_reg value set to 1")
self.seasonality_local_reg = 1

# If Season modelling is global but local regularization is set.
Expand All @@ -410,7 +408,7 @@
resolution=resolution,
period=period,
arg=arg,
global_local=global_local if global_local in ["global", "local"] else self.global_local,

Check failure on line 411 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "str" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__"   Type "str" is incompatible with type "SeasonGlobalLocalMode"     "str" is incompatible with type "Literal['global']"     "str" is incompatible with type "Literal['local']"     "str" is incompatible with type "Literal['glocal']" (reportArgumentType)
condition_name=condition_name,
)

Expand Down Expand Up @@ -486,7 +484,7 @@
regressors: OrderedDict = field(init=False) # contains RegressorConfig objects

def __post_init__(self):
self.regressors = None

Check failure on line 487 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot assign to attribute "regressors" for class "ConfigFutureRegressors*"   "None" is incompatible with "OrderedDict[Unknown, Unknown]" (reportAttributeAccessIssue)


@dataclass
Expand Down
23 changes: 22 additions & 1 deletion tests/test_glocal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pathlib

import pandas as pd
import pytest

from neuralprophet import NeuralProphet

Expand Down Expand Up @@ -341,14 +342,16 @@ def test_glocal_seasonality_reg():
df2_0["ID"] = "df2"
df3_0 = df.iloc[256:384, :].copy(deep=True)
df3_0["ID"] = "df3"
for _ in [-30, 0, False, True]:
for coef_i in [0, 1.5, False, True]:

m = NeuralProphet(
n_forecasts=1,
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
season_global_local="local",
yearly_seasonality_glocal_mode="global",
seasonality_local_reg=coef_i,
)

m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="global")
Expand All @@ -359,6 +362,24 @@ def test_glocal_seasonality_reg():
metrics = m.test(test_df)
log.info(f"forecast = {forecast}, metrics = {metrics}")

with pytest.raises(AssertionError, match="Invalid seasonality_local_reg"):
m = NeuralProphet(
n_forecasts=1,
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
season_global_local="local",
yearly_seasonality_glocal_mode="global",
seasonality_local_reg=-324,
)

m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="global")
train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True)
m.fit(train_df)
future = m.make_future_dataframe(test_df, n_historic_predictions=True)
forecast = m.predict(future)
metrics = m.test(test_df)


def test_trend_local_reg_if_global():
# SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES
Expand Down
Loading