Skip to content

Commit

Permalink
Merge branch 'main' into correct-definition-torchmetrics
Browse files Browse the repository at this point in the history
  • Loading branch information
leoniewgnr authored Aug 9, 2023
2 parents 5323f23 + 90f6e6e commit b1c50ee
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 177 deletions.
7 changes: 5 additions & 2 deletions neuralprophet/plot_model_parameters_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,12 @@ def plot_trend(m, quantile, ax=None, plot_name="Trend", figsize=(10, 6), df_name
trend_1 = trend_0
else:
if m.model.config_trend.trend_global_local == "local":
trend_1 = trend_0 + m.model.trend.trend_k0[quantile_index, m.model.id_dict[df_name]].detach().numpy()
trend_1 = (
trend_0
+ m.model.trend.trend_k0[quantile_index, m.model.id_dict[df_name]].detach().numpy().squeeze()
)
else:
trend_1 = trend_0 + m.model.trend.trend_k0[quantile_index, 0].detach().numpy()
trend_1 = trend_0 + m.model.trend.trend_k0[quantile_index, 0].detach().numpy().squeeze()

data_params = m.config_normalization.get_data_params(df_name)
shift = data_params["y"].shift
Expand Down
Loading

0 comments on commit b1c50ee

Please sign in to comment.