
[bug] Fix correct definition of torchmetrics inside pytorch lightning module #1365

Merged
merged 14 commits
Aug 14, 2023
7 changes: 5 additions & 2 deletions neuralprophet/forecaster.py
@@ -242,12 +242,15 @@
* (default) ``True``: [``mae``, ``rmse``]
* ``False``: No metrics
* ``list``: Valid options: [``mae``, ``rmse``, ``mse``]
* ``dict``: Collection of torchmetrics.Metric objects
* ``dict``: Collection of names of torchmetrics.Metric objects

Examples
--------
>>> from neuralprophet import NeuralProphet
>>> # compute MSE, MAE and RMSE
>>> m = NeuralProphet(collect_metrics=["MSE", "MAE", "RMSE"])
>>> # use custom torchmetrics names
>>> m = NeuralProphet(collect_metrics={"MAPE": "MeanAbsolutePercentageError", "MSLE": "MeanSquaredLogError"})

Uncertainty Estimation
@@ -366,7 +369,7 @@
impute_linear: int = 10,
impute_rolling: int = 10,
drop_missing: bool = False,
collect_metrics: np_types.CollectMetricsMode = True,
collect_metrics: Union[bool, list, dict] = True,
normalize: np_types.NormalizeMode = "auto",
global_normalization: bool = False,
global_time_normalization: bool = True,
@@ -1002,7 +1005,7 @@
# Only display the plot if the session is interactive, eg. do not show in github actions since it
# causes an error in the Windows and MacOS environment
if matplotlib.is_interactive():
fig

GitHub Actions / pyright — warning on line 1008 in neuralprophet/forecaster.py: Expression value is unused (reportUnusedExpression)

self.fitted = True
return metrics_df
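The docstring above introduces the dict-of-names form of `collect_metrics`. A minimal usage sketch of that option (the metric keys, epoch count, and CSV path here are illustrative, not part of the PR):

```python
import pandas as pd
from neuralprophet import NeuralProphet

# Metrics are now given as names of torchmetrics.Metric classes rather than
# instantiated objects; the dict keys become the column names in metrics_df.
m = NeuralProphet(
    epochs=5,  # illustrative value
    collect_metrics={"MAPE": "MeanAbsolutePercentageError", "MSLE": "MeanSquaredLogError"},
)

df = pd.read_csv("example.csv")  # hypothetical file with the usual "ds" and "y" columns
metrics_df = m.fit(df, freq="D")
print(metrics_df[["MAPE", "MSLE"]].tail())
```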
5 changes: 2 additions & 3 deletions neuralprophet/np_types.py
@@ -1,8 +1,7 @@
import sys
from typing import Dict, List, Union
from typing import Dict, Union

import torch
import torchmetrics

# Ensure compatibility with python 3.7
if sys.version_info >= (3, 8):
@@ -19,7 +18,7 @@

GrowthMode = Literal["off", "linear", "discontinuous"]

CollectMetricsMode = Union[List[str], bool, Dict[str, torchmetrics.Metric]]
CollectMetricsMode = Union[Dict, bool]

SeasonGlobalLocalMode = Literal["global", "local"]

1 change: 1 addition & 0 deletions neuralprophet/time_net.py
@@ -165,6 +165,7 @@ def __init__(
# Metrics Config
self.metrics_enabled = bool(metrics) # yields True if metrics is not an empty dictionary
if self.metrics_enabled:
metrics = {metric: torchmetrics.__dict__[metrics[metric][0]](**metrics[metric][1]) for metric in metrics}
self.log_args = {
"on_step": False,
"on_epoch": True,
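The added line resolves each `[class name, kwargs]` pair to a live metric object through the `torchmetrics` namespace. A standalone sketch of the same lookup (using `getattr` instead of `__dict__`, which behaves equivalently here):

```python
import torchmetrics

# Format produced by utils_metrics.get_metrics: short name -> [torchmetrics class name, kwargs]
metrics = {
    "MAE": ["MeanAbsoluteError", {}],
    "RMSE": ["MeanSquaredError", {"squared": False}],
}

# Look up each class by name and instantiate it with its keyword arguments,
# mirroring the dict comprehension added in __init__ above.
resolved = {name: getattr(torchmetrics, cls_name)(**kwargs) for name, (cls_name, kwargs) in metrics.items()}
print(resolved)  # e.g. {'MAE': MeanAbsoluteError(), 'RMSE': MeanSquaredError()}
```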
30 changes: 17 additions & 13 deletions neuralprophet/utils_metrics.py
@@ -5,15 +5,16 @@
log = logging.getLogger("NP.metrics")

METRICS = {
"MAE": torchmetrics.MeanAbsoluteError(),
"MSE": torchmetrics.MeanSquaredError(squared=True),
"RMSE": torchmetrics.MeanSquaredError(squared=False),
# "short_name": [torchmetrics.Metric name, {optional args}]
"MAE": ["MeanAbsoluteError", {}],
"MSE": ["MeanSquaredError", {"squared": True}],
"RMSE": ["MeanSquaredError", {"squared": False}],
}


def get_metrics(metric_input):
"""
Returns a list of metrics.
Returns a dict of metrics.

Parameters
----------
@@ -23,29 +24,32 @@
Returns
-------
dict
Dict of torchmetrics.Metric metrics.
Dict of names of torchmetrics.Metric metrics
"""
if metric_input is None:
return {}
elif metric_input is True:
return {k: v for k, v in METRICS.items() if k in ["MAE", "RMSE"]}
return {"MAE": METRICS["MAE"], "RMSE": METRICS["RMSE"]}
elif isinstance(metric_input, str):
if metric_input.upper() in METRICS.keys():
return {metric_input: METRICS[metric_input]}
return {metric_input: METRICS[metric_input.upper()]}
else:
raise ValueError("Received unsupported argument for collect_metrics.")
elif isinstance(metric_input, list):
if all([m.upper() in METRICS.keys() for m in metric_input]):
return {k: v for k, v in METRICS.items() if k in metric_input}
return {m: METRICS[m.upper()] for m in metric_input}
else:
raise ValueError("Received unsupported argument for collect_metrics.")
elif isinstance(metric_input, dict):
if all([isinstance(_metric, torchmetrics.Metric) for _, _metric in metric_input.items()]):
return metric_input
else:
# check if all values are names belonging to torchmetrics.Metric
try:
for _metric in metric_input.values():
torchmetrics.__dict__[_metric]()
except KeyError:

Codecov / codecov/patch — warning on line 48 in neuralprophet/utils_metrics.py: Added line #L48 was not covered by tests
raise ValueError(
"Received unsupported argument for collect_metrics. All metrics must be an instance of "
"torchmetrics.Metric."
"Received unsupported argument for collect_metrics."
"All metrics must be valid names of torchmetrics.Metric objects."
)
return {k: [v, {}] for k, v in metric_input.items()}
elif metric_input is not False:
raise ValueError("Received unsupported argument for collect_metrics.")
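Based on the branches above, `get_metrics` now maps every supported `collect_metrics` value to name/kwargs pairs. A sketch of the expected outputs (shown as comments; not part of the diff):

```python
from neuralprophet.utils_metrics import get_metrics

# default: MAE and RMSE
print(get_metrics(True))
# {'MAE': ['MeanAbsoluteError', {}], 'RMSE': ['MeanSquaredError', {'squared': False}]}

# list of short names (case-insensitive lookup, original spelling kept as key)
print(get_metrics(["mse"]))
# {'mse': ['MeanSquaredError', {'squared': True}]}

# dict: custom column name -> torchmetrics class name
print(get_metrics({"MAPE": "MeanAbsolutePercentageError"}))
# {'MAPE': ['MeanAbsolutePercentageError', {}]}

# no metrics
print(get_metrics(None))   # {}
print(get_metrics(False))  # None (falls through all branches)
```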
50 changes: 42 additions & 8 deletions tests/test_integration.py
@@ -10,7 +10,6 @@
import pandas as pd
import pytest
import torch
import torchmetrics

from neuralprophet import NeuralProphet, df_utils, set_random_seed
from neuralprophet.data.process import _handle_missing_data, _validate_column_name
@@ -1367,25 +1366,60 @@ def test_get_latest_forecast():
def test_metrics():
log.info("testing: Metrics")
df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
m = NeuralProphet(
# list
m_list = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
collect_metrics=["MAE", "MSE", "RMSE"],
)
metrics_df = m.fit(df, freq="D")
metrics_df = m_list.fit(df, freq="D")
assert all([metric in metrics_df.columns for metric in ["MAE", "MSE", "RMSE"]])
m.predict(df)
m_list.predict(df)

m2 = NeuralProphet(
# dict
m_dict = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
collect_metrics={"ABC": torchmetrics.MeanAbsoluteError()},
collect_metrics={"ABC": "MeanSquaredLogError"},
)
metrics_df = m2.fit(df, freq="D")
metrics_df = m_dict.fit(df, freq="D")
assert "ABC" in metrics_df.columns
m2.predict(df)
m_dict.predict(df)

# string
m_string = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
collect_metrics="MAE",
)
metrics_df = m_string.fit(df, freq="D")
assert "MAE" in metrics_df.columns
m_string.predict(df)

# False
m_false = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
collect_metrics=False,
)
metrics_df = m_false.fit(df, freq="D")
assert metrics_df is None
m_false.predict(df)

# None
m_none = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
collect_metrics=None,
)
metrics_df = m_none.fit(df, freq="D")
assert metrics_df is None
m_none.predict(df)


def test_progress_display():