Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug] Fix correct definition of torchmetrics inside pytorch lightning module #1365

Merged
merged 14 commits into from
Aug 14, 2023
9 changes: 9 additions & 0 deletions neuralprophet/time_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,17 @@ def __init__(
self.batch_size = self.config_train.batch_size

# Metrics Config
METRICS = {
leoniewgnr marked this conversation as resolved.
Show resolved Hide resolved
"MAE": torchmetrics.MeanAbsoluteError(),
"MSE": torchmetrics.MeanSquaredError(squared=True),
"RMSE": torchmetrics.MeanSquaredError(squared=False),
}

self.metrics_enabled = bool(metrics) # yields True if metrics is not an empty dictionary
if self.metrics_enabled:
# only convert to dict if metrics is a list (metrics were not set in utils_metrics)
if isinstance(metrics, list):
metrics = {metric: METRICS[metric] for metric in metrics}
leoniewgnr marked this conversation as resolved.
Show resolved Hide resolved
self.log_args = {
"on_step": False,
"on_epoch": True,
Expand Down
14 changes: 7 additions & 7 deletions neuralprophet/utils_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def get_metrics(metric_input):
"""
Returns a list of metrics.
Returns a dict or list of metrics.

Parameters
----------
Expand All @@ -22,21 +22,21 @@ def get_metrics(metric_input):

Returns
-------
dict
Dict of torchmetrics.Metric metrics.
dict or list
leoniewgnr marked this conversation as resolved.
Show resolved Hide resolved
Dict of torchmetrics.Metric metrics or list of strings of metrics to use.
"""
if metric_input is None:
return {}
return []
elif metric_input is True:
return {k: v for k, v in METRICS.items() if k in ["MAE", "RMSE"]}
return ["MAE", "RMSE"]
elif isinstance(metric_input, str):
if metric_input.upper() in METRICS.keys():
return {metric_input: METRICS[metric_input]}
return [metric_input]
else:
raise ValueError("Received unsupported argument for collect_metrics.")
elif isinstance(metric_input, list):
if all([m.upper() in METRICS.keys() for m in metric_input]):
return {k: v for k, v in METRICS.items() if k in metric_input}
return metric_input
else:
raise ValueError("Received unsupported argument for collect_metrics.")
elif isinstance(metric_input, dict):
leoniewgnr marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
47 changes: 41 additions & 6 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,25 +1367,60 @@ def test_get_latest_forecast():
def test_metrics():
    """Integration test for every supported ``collect_metrics`` argument type.

    Exercises the list, dict, string, ``False`` and ``None`` forms of
    ``collect_metrics`` and checks that ``fit`` returns a metrics DataFrame
    with the expected columns — or ``None`` when metrics are disabled.
    """
    # Fixed: message previously said "testing: Plotting" (copy-paste slip
    # from a plotting test).
    log.info("testing: Metrics")
    df = pd.read_csv(PEYTON_FILE, nrows=NROWS)

    # list of metric-name strings
    m_list = NeuralProphet(
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LR,
        collect_metrics=["MAE", "MSE", "RMSE"],
    )
    metrics_df = m_list.fit(df, freq="D")
    assert all(metric in metrics_df.columns for metric in ["MAE", "MSE", "RMSE"])
    m_list.predict(df)

    # dict mapping custom names to torchmetrics.Metric instances
    m_dict = NeuralProphet(
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LR,
        collect_metrics={"ABC": torchmetrics.MeanAbsoluteError()},
    )
    metrics_df = m_dict.fit(df, freq="D")
    assert "ABC" in metrics_df.columns
    m_dict.predict(df)

    # single metric name as a string
    m_string = NeuralProphet(
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LR,
        collect_metrics="MAE",
    )
    metrics_df = m_string.fit(df, freq="D")
    assert "MAE" in metrics_df.columns
    m_string.predict(df)

    # False disables metric collection entirely
    m_false = NeuralProphet(
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LR,
        collect_metrics=False,
    )
    metrics_df = m_false.fit(df, freq="D")
    assert metrics_df is None
    m_false.predict(df)

    # None likewise disables metric collection
    m_none = NeuralProphet(
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LR,
        collect_metrics=None,
    )
    metrics_df = m_none.fit(df, freq="D")
    assert metrics_df is None
    m_none.predict(df)


def test_progress_display():
Expand Down
Loading