diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index f64315e26..054d92cd1 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -390,12 +390,10 @@ def __post_init__(self): } ) - if self.seasonality_local_reg < 0: - log.error("Invalid negative seasonality_local_reg '{}'. Set to False".format(self.seasonality_local_reg)) - self.seasonality_local_reg = False + assert self.seasonality_local_reg >= 0, "Invalid seasonality_local_reg '{}'.".format(self.seasonality_local_reg) if self.seasonality_local_reg is True: - log.error("seasonality_local_reg = True. Default seasonality_local_reg value set to 1") + log.warning("seasonality_local_reg = True. Default seasonality_local_reg value set to 1") self.seasonality_local_reg = 1 # If Season modelling is global but local regularization is set. diff --git a/tests/test_glocal.py b/tests/test_glocal.py index 5c171d597..73048a01a 100644 --- a/tests/test_glocal.py +++ b/tests/test_glocal.py @@ -5,6 +5,7 @@ import pathlib import pandas as pd +import pytest from neuralprophet import NeuralProphet @@ -341,7 +342,8 @@ def test_glocal_seasonality_reg(): df2_0["ID"] = "df2" df3_0 = df.iloc[256:384, :].copy(deep=True) df3_0["ID"] = "df3" - for _ in [-30, 0, False, True]: + for coef_i in [0, 1.5, False, True]: + m = NeuralProphet( n_forecasts=1, epochs=EPOCHS, @@ -349,6 +351,7 @@ def test_glocal_seasonality_reg(): learning_rate=LR, season_global_local="local", yearly_seasonality_glocal_mode="global", + seasonality_local_reg=coef_i, ) m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="global") @@ -359,6 +362,24 @@ def test_glocal_seasonality_reg(): metrics = m.test(test_df) log.info(f"forecast = {forecast}, metrics = {metrics}") + with pytest.raises(AssertionError, match="Invalid seasonality_local_reg"): + m = NeuralProphet( + n_forecasts=1, + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + season_global_local="local", + yearly_seasonality_glocal_mode="global", + seasonality_local_reg=-324, + ) + + m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="global") + train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) + m.fit(train_df) + future = m.make_future_dataframe(test_df, n_historic_predictions=True) + forecast = m.predict(future) + metrics = m.test(test_df) + def test_trend_local_reg_if_global(): # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES