Uncertainty: Conformal Prediction V1.1 - support for multiple forecast steps (#1073)

* Replaced the hardcoded yhat1 with the forecast step_number in the conformal prediction logic.

* Moved conformal_predict() method logic into conformal_prediction.py and added plot_interval_width_per_timestep() method for multiple timesteps.

* Changed self.config_train.quantiles to quantiles in conformal_prediction.py.

* Removed q_hats in _conformalize().

* Added Conformal dataclass from PR #1073.

* Added conformal.py to replace conformal_prediction.py.

* Fixed self.method in ValueError for conformal.py.

* Reformatted the Conformal instantiation in forecaster.py to span multiple lines.

* Uncommented the auto-regression section of test_plot_conformal_prediction and changed the split frequencies from "MS" to "D" to match the daily Peyton Manning dataset.

* Uncommented m.plot_latest_forecast in test_plot_conformal_prediction in test_plotting.py.

* Added step_number to the docstring of the _get_nonconformity_scores() method in conformal.py.

* Uncommented the tests in test_model_performance.py.

* Added plot_interval_width_per_timestep() to plot_forecast_plotly.py and modified plot() in conformal.py to enable this method for plotting_backend='plotly'.

* Modified test_plot_conformal_prediction in test_plotting.py by passing the plotting_backend param to the m.conformal_predict() call.
Kevin-Chen0 authored Jan 13, 2023
1 parent 05d5e10 commit 3d46b1e
Showing 6 changed files with 233 additions and 125 deletions.
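
For orientation, here is a minimal end-to-end sketch of how the multi-step conformal prediction added in this commit might be used. The dataset path, column layout, and hyperparameter values are illustrative assumptions, not taken from this commit; the call signatures mirror those exercised in tests/test_plotting.py below.

    import pandas as pd
    from neuralprophet import NeuralProphet

    # Assumed daily dataset with the usual NeuralProphet columns "ds" and "y".
    df = pd.read_csv("example_daily_data.csv")

    m = NeuralProphet(
        n_forecasts=7,           # multiple forecast steps -> one q-hat per step
        n_lags=14,
        quantiles=[0.05, 0.95],  # needed for the "cqr" method
    )
    train_df, test_df = m.split_df(df, freq="D", valid_p=0.2)
    train_df, cal_df = m.split_df(train_df, freq="D", valid_p=0.15)
    m.fit(train_df, freq="D")

    future = m.make_future_dataframe(test_df, periods=m.n_forecasts, n_historic_predictions=10)
    # alpha=0.1 targets 90% coverage; with n_forecasts > 1 the resulting plot
    # shows the one-sided interval width per forecast step instead of the scores.
    forecast = m.conformal_predict(
        future, calibration_df=cal_df, alpha=0.1, method="cqr", plotting_backend="matplotlib"
    )

The returned forecast dataframe carries per-step interval columns built from the f-strings in conformal.py below, e.g. "qhat1", "yhat1 5.0% - qhat1", and "yhat1 95.0% + qhat1" for the CQR method.
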
95 changes: 65 additions & 30 deletions neuralprophet/conformal.py
@@ -5,7 +5,10 @@
import numpy as np
import pandas as pd

from neuralprophet.plot_forecast_matplotlib import plot_nonconformity_scores
from neuralprophet.plot_forecast_matplotlib import plot_interval_width_per_timestep, plot_nonconformity_scores
from neuralprophet.plot_forecast_plotly import (
plot_interval_width_per_timestep as plot_interval_width_per_timestep_plotly,
)
from neuralprophet.plot_forecast_plotly import plot_nonconformity_scores as plot_nonconformity_scores_plotly


@@ -23,13 +26,16 @@ class Conformal:
Options
* ``naive``: Naive or Absolute Residual
* ``cqr``: Conformalized Quantile Regression
n_forecasts : int
optional, number of steps ahead of prediction time step to forecast
quantiles : list
optional, list of quantiles for quantile regression uncertainty estimate
"""

alpha: float
method: str
n_forecasts: int = 1
quantiles: Optional[List[float]] = None

def predict(self, df: pd.DataFrame, df_cal: pd.DataFrame) -> pd.DataFrame:
@@ -48,34 +54,50 @@ def predict(self, df: pd.DataFrame, df_cal: pd.DataFrame) -> pd.DataFrame:
test dataframe with uncertainty prediction intervals
"""
# conformalize
self.noncon_scores = self._get_nonconformity_scores(df_cal)
self.q_hat = self._get_q_hat(df_cal)
df["qhat1"] = self.q_hat
if self.method == "naive":
df["yhat1 - qhat1"] = df["yhat1"] - self.q_hat
df["yhat1 + qhat1"] = df["yhat1"] + self.q_hat
elif self.method == "cqr":
quantile_hi = str(max(self.quantiles) * 100)
quantile_lo = str(min(self.quantiles) * 100)
df[f"yhat1 {quantile_hi}% - qhat1"] = df[f"yhat1 {quantile_hi}%"] - self.q_hat
df[f"yhat1 {quantile_hi}% + qhat1"] = df[f"yhat1 {quantile_hi}%"] + self.q_hat
df[f"yhat1 {quantile_lo}% - qhat1"] = df[f"yhat1 {quantile_lo}%"] - self.q_hat
df[f"yhat1 {quantile_lo}% + qhat1"] = df[f"yhat1 {quantile_lo}%"] + self.q_hat
else:
raise ValueError(
f"Unknown conformal prediction method '{self.method}'. Please input either 'naive' or 'cqr'."
)
self.q_hats = []
for step_number in range(1, self.n_forecasts + 1):
# conformalize
noncon_scores = self._get_nonconformity_scores(df_cal, step_number)
q_hat = self._get_q_hat(df_cal, noncon_scores)
df[f"qhat{step_number}"] = q_hat
if self.method == "naive":
df[f"yhat{step_number} - qhat{step_number}"] = df[f"yhat{step_number}"] - q_hat
df[f"yhat{step_number} + qhat{step_number}"] = df[f"yhat{step_number}"] + q_hat
elif self.method == "cqr":
quantile_hi = str(max(self.quantiles) * 100)
quantile_lo = str(min(self.quantiles) * 100)
df[f"yhat{step_number} {quantile_hi}% - qhat{step_number}"] = (
df[f"yhat{step_number} {quantile_hi}%"] - q_hat
)
df[f"yhat{step_number} {quantile_hi}% + qhat{step_number}"] = (
df[f"yhat{step_number} {quantile_hi}%"] + q_hat
)
df[f"yhat{step_number} {quantile_lo}% - qhat{step_number}"] = (
df[f"yhat{step_number} {quantile_lo}%"] - q_hat
)
df[f"yhat{step_number} {quantile_lo}% + qhat{step_number}"] = (
df[f"yhat{step_number} {quantile_lo}%"] + q_hat
)
else:
raise ValueError(
f"Unknown conformal prediction method '{self.method}'. Please input either 'naive' or 'cqr'."
)
if step_number == 1:
# save nonconformity scores of the first timestep
self.noncon_scores = noncon_scores
self.q_hats.append(q_hat)

return df
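
A minimal sketch of driving this loop directly with the naive method, assuming the class is importable from neuralprophet.conformal as the new module name suggests. The toy dataframes are invented; in practice both come out of NeuralProphet.predict().

    import pandas as pd
    from neuralprophet.conformal import Conformal

    # Toy calibration and test frames that already carry point predictions
    # yhat1, yhat2 next to the observed target y.
    df_cal = pd.DataFrame({
        "y":     [10.0, 11.0, 12.0, 13.0, 14.0],
        "yhat1": [10.5, 10.8, 12.4, 12.7, 14.3],
        "yhat2": [10.9, 11.5, 11.6, 13.5, 13.4],
    })
    df_test = pd.DataFrame({"y": [15.0], "yhat1": [14.6], "yhat2": [15.8]})

    c = Conformal(alpha=0.2, method="naive", n_forecasts=2)
    out = c.predict(df=df_test.copy(), df_cal=df_cal)
    # out now also holds "qhat1", "yhat1 - qhat1", "yhat1 + qhat1",
    # "qhat2", "yhat2 - qhat2", "yhat2 + qhat2" -- one q-hat per forecast step.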

def _get_nonconformity_scores(self, df_cal: pd.DataFrame) -> np.ndarray:
def _get_nonconformity_scores(self, df_cal: pd.DataFrame, step_number: int) -> np.ndarray:
"""Get the nonconformity scores using the given conformal prediction technique.
Parameters
----------
df_cal : pd.DataFrame
calibration dataframe
step_number : int
i-th step ahead forecast
Returns
-------
@@ -89,35 +111,40 @@ def _get_nonconformity_scores(self, df_cal: pd.DataFrame) -> np.ndarray:
quantile_lo = str(min(self.quantiles) * 100)
cqr_scoring_func = (
lambda row: [None, None]
if row[f"yhat1 {quantile_lo}%"] is None or row[f"yhat1 {quantile_hi}%"] is None
if row[f"yhat{step_number} {quantile_lo}%"] is None or row[f"yhat{step_number} {quantile_hi}%"] is None
else [
max(
row[f"yhat1 {quantile_lo}%"] - row["y"],
row["y"] - row[f"yhat1 {quantile_hi}%"],
row[f"yhat{step_number} {quantile_lo}%"] - row["y"],
row["y"] - row[f"yhat{step_number} {quantile_hi}%"],
),
0 if row[f"yhat1 {quantile_lo}%"] - row["y"] > row["y"] - row[f"yhat1 {quantile_hi}%"] else 1,
0
if row[f"yhat{step_number} {quantile_lo}%"] - row["y"]
> row["y"] - row[f"yhat{step_number} {quantile_hi}%"]
else 1,
]
)
scores_df = df_cal.apply(cqr_scoring_func, axis=1, result_type="expand")
scores_df.columns = ["scores", "arg"]
noncon_scores = scores_df["scores"].values
else: # self.method == "naive"
# Naive nonconformity scoring function
noncon_scores = abs(df_cal["y"] - df_cal["yhat1"]).values
noncon_scores = abs(df_cal["y"] - df_cal[f"yhat{step_number}"]).values
# Remove NaN values
noncon_scores = noncon_scores[~pd.isnull(noncon_scores)]
# Sort
noncon_scores.sort()

return noncon_scores
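
To make the two scoring rules above concrete, here is a standalone sketch of the same computations on an invented calibration frame for step_number=1 (the None handling and the arg column that records which quantile was violated are omitted for brevity):

    import numpy as np
    import pandas as pd

    df_cal = pd.DataFrame({
        "y":           [10.0, 11.0, 12.0],
        "yhat1":       [10.5, 10.4, 12.8],
        "yhat1 5.0%":  [9.0, 10.6, 11.1],
        "yhat1 95.0%": [11.2, 11.3, 13.0],
    })

    # naive: absolute residual of the point forecast
    naive_scores = np.sort(np.abs(df_cal["y"] - df_cal["yhat1"]).values)

    # cqr: distance by which y falls outside the [lo, hi] quantile band
    # (negative when y lies inside the band)
    cqr_scores = np.sort(np.maximum(
        df_cal["yhat1 5.0%"] - df_cal["y"],
        df_cal["y"] - df_cal["yhat1 95.0%"],
    ).values)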

def _get_q_hat(self, df_cal: pd.DataFrame) -> float:
def _get_q_hat(self, df_cal: pd.DataFrame, noncon_scores: np.ndarray) -> float:
"""Get the q_hat that is derived from the nonconformity scores.
Parameters
----------
df_cal : pd.DataFrame
calibration dataframe
noncon_scores : np.ndarray
nonconformity scores
Returns
-------
@@ -126,8 +153,8 @@ def _get_q_hat(self, df_cal: pd.DataFrame) -> float:
"""
# Get the q-hat index and value
q_hat_idx = int(len(self.noncon_scores) * self.alpha)
q_hat = self.noncon_scores[-q_hat_idx]
q_hat_idx = int(len(noncon_scores) * self.alpha)
q_hat = noncon_scores[-q_hat_idx]

return q_hat
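
A quick numeric sketch of this selection with invented scores: the index counts from the top of the sorted array, so the chosen q-hat is the int(n * alpha)-th largest calibration score.

    import numpy as np

    alpha = 0.1
    noncon_scores = np.sort(np.array([0.2, 0.5, 0.7, 1.1, 1.3, 1.6, 2.0, 2.4, 3.0, 3.5]))

    q_hat_idx = int(len(noncon_scores) * alpha)  # int(10 * 0.1) -> 1
    q_hat = noncon_scores[-q_hat_idx]            # 1st score from the top -> 3.5

The resulting q_hat is then used as the one-sided half-width added to and subtracted from the point (naive) or quantile (CQR) forecasts in predict() above.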

@@ -146,8 +173,16 @@ def plot(self, plotting_backend: str):
"""
method = self.method.upper() if "cqr" in self.method.lower() else self.method.title()
if plotting_backend == "plotly":
fig = plot_nonconformity_scores_plotly(self.noncon_scores, self.alpha, self.q_hat, method)
if self.n_forecasts == 1:
# includes nonconformity scores of the first timestep
fig = plot_nonconformity_scores_plotly(self.noncon_scores, self.alpha, self.q_hats[0], method)
else:
fig = plot_interval_width_per_timestep_plotly(self.q_hats, method)
elif plotting_backend == "matplotlib":
fig = plot_nonconformity_scores(self.noncon_scores, self.alpha, self.q_hat, method)
if self.n_forecasts == 1:
# includes nonconformity scores of the first timestep
fig = plot_nonconformity_scores(self.noncon_scores, self.alpha, self.q_hats[0], method)
else:
fig = plot_interval_width_per_timestep(self.q_hats, method)
if plotting_backend in ["matplotlib", "plotly"] and matplotlib.is_interactive():
fig.show()
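
A brief usage sketch of this dispatch, assuming a Conformal instance c whose predict() has already been called so that noncon_scores and q_hats are populated:

    # With n_forecasts == 1 this draws the nonconformity-score curve with the
    # chosen q; with n_forecasts > 1 it draws one q-hat per forecast step.
    c.plot(plotting_backend="matplotlib")
    # or, for an interactive figure:
    c.plot(plotting_backend="plotly")
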
7 changes: 6 additions & 1 deletion neuralprophet/forecaster.py
@@ -3148,7 +3148,12 @@ def conformal_predict(
# get predictions for test dataframe
df = self.predict(df, **kwargs)
# initiate Conformal instance
c = Conformal(alpha=alpha, method=method, quantiles=self.config_train.quantiles)
c = Conformal(
alpha=alpha,
method=method,
n_forecasts=self.n_forecasts,
quantiles=self.config_train.quantiles,
)
# call Conformal's predict to output test df with conformal prediction intervals
df = c.predict(df=df, df_cal=df_cal)
# plot one-sided prediction interval width with q
33 changes: 31 additions & 2 deletions neuralprophet/plot_forecast_matplotlib.py
@@ -443,7 +443,7 @@ def plot_multiforecast_component(


def plot_nonconformity_scores(scores, alpha, q, method):
"""Plot the NeuralProphet forecast components.
"""Plot the nonconformity scores as well as the one-sided interval width (q).
Parameters
----------
@@ -470,8 +470,37 @@ def plot_nonconformity_scores(scores, alpha, q, method):
ax.plot(confidence_levels, scores, label="score")
ax.axvline(x=1 - alpha, color="g", linestyle="-", label=f"(1-alpha) = {1-alpha}", linewidth=1)
ax.axhline(y=q, color="r", linestyle="-", label=f"q1 = {round(q, 2)}", linewidth=1)
ax.set_title(f"{method} One-Sided Interval Width with q")
ax.set_xlabel("Confidence Level")
ax.set_ylabel("One-Sided Interval Width")
ax.set_title(f"{method} One-Sided Interval Width with q")
ax.legend()
return fig


def plot_interval_width_per_timestep(q_hats, method):
"""Plot the nonconformity scores as well as the one-sided interval width (q).
Parameters
----------
q_hats : list
prediction interval widths (or q) for each timestep
method : str
name of conformal prediction technique used
Options
* (default) ``naive``: Naive or Absolute Residual
* ``cqr``: Conformalized Quantile Regression
Returns
-------
matplotlib.pyplot.figure
Figure showing the q-values for each timestep
"""
fig, ax = plt.subplots()
ax.plot(range(1, len(q_hats) + 1), q_hats)
ax.set_title(f"{method} One-Sided Interval Width with q per Timestep")
ax.set_xlabel("Timestep Number")
ax.set_ylabel("One-Sided Interval Width")
# ax.set_xlim(left=1)
ax.set_ylim(bottom=0)
return fig
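
A small standalone usage sketch of this helper (the q-hat values are invented):

    import matplotlib.pyplot as plt
    from neuralprophet.plot_forecast_matplotlib import plot_interval_width_per_timestep

    q_hats = [0.42, 0.55, 0.61, 0.66, 0.70, 0.73, 0.75]  # one width per forecast step
    fig = plot_interval_width_per_timestep(q_hats, method="Naive")
    plt.show()
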
32 changes: 32 additions & 0 deletions neuralprophet/plot_forecast_plotly.py
@@ -723,3 +723,35 @@ def plot_nonconformity_scores(scores, alpha, q, method):
)
fig.update_layout(margin=dict(l=70, r=70, t=60, b=50))
return fig


def plot_interval_width_per_timestep(q_hats, method):
"""Plot the nonconformity scores as well as the one-sided interval width (q).
Parameters
----------
q_hats : list
prediction interval widths (or q) for each timestep
method : str
name of conformal prediction technique used
Options
* (default) ``naive``: Naive or Absolute Residual
* ``cqr``: Conformalized Quantile Regression
Returns
-------
plotly.graph_objects.Figure
Figure showing the q-values for each timestep
"""
timestep_numbers = list(range(1, len(q_hats) + 1))
fig = px.line(
pd.DataFrame({"Timestep Number": timestep_numbers, "One-Sided Interval Width": q_hats}),
x="Timestep Number",
y="One-Sided Interval Width",
title=f"{method} One-Sided Interval Width with q per Timestep",
width=600,
height=400,
)
fig.update_layout(margin=dict(l=70, r=70, t=60, b=50))
return fig
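
And a sketch of the plotly counterpart, which returns a plotly Figure that can be further customized or exported (q-hat values invented; plotly must be installed):

    from neuralprophet.plot_forecast_plotly import plot_interval_width_per_timestep

    q_hats = [0.42, 0.55, 0.61, 0.66, 0.70, 0.73, 0.75]
    fig = plot_interval_width_per_timestep(q_hats, method="CQR")
    fig.update_layout(width=800)  # optional: override the default 600px width
    fig.show()
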
79 changes: 43 additions & 36 deletions tests/test_plotting.py
@@ -551,52 +551,59 @@ def test_plot_conformal_prediction(plotting_backend):
batch_size=BATCH_SIZE,
learning_rate=LR,
)
train_df, test_df = m.split_df(df, freq="MS", valid_p=0.2)
train_df, cal_df = m.split_df(train_df, freq="MS", valid_p=0.15)
train_df, test_df = m.split_df(df, freq="D", valid_p=0.2)
train_df, cal_df = m.split_df(train_df, freq="D", valid_p=0.15)
metrics_df = m.fit(train_df, freq="D")
alpha = 0.1
for method in ["naive", "cqr"]: # Naive and CQR SCP methods
future = m.make_future_dataframe(test_df, periods=m.n_forecasts, n_historic_predictions=10)
forecast = m.conformal_predict(future, calibration_df=cal_df, alpha=alpha, method=method)
forecast = m.conformal_predict(
future, calibration_df=cal_df, alpha=alpha, method=method, plotting_backend=plotting_backend
)
m.highlight_nth_step_ahead_of_each_forecast(m.n_forecasts)
fig0 = m.plot(forecast, plotting_backend="matplotlib")
fig1 = m.plot_components(forecast, plotting_backend="matplotlib")
fig2 = m.plot_parameters(plotting_backend="matplotlib")
fig0 = m.plot(forecast, plotting_backend=plotting_backend)
fig1 = m.plot_components(forecast, plotting_backend=plotting_backend)
fig2 = m.plot_parameters(plotting_backend=plotting_backend)
if PLOT:
fig0.show()
fig1.show()
fig2.show()
# With auto-regression enabled
# TO-DO: Fix Assertion error n_train >= 1
# m = NeuralProphet(
# n_forecasts=7,
# n_lags=14,
# quantiles=[0.05, 0.95],
# epochs=EPOCHS,
# batch_size=BATCH_SIZE,
# learning_rate=LR,
# )
# train_df, test_df = m.split_df(df, freq="MS", valid_p=0.2)
# train_df, cal_df = m.split_df(train_df, freq="MS", valid_p=0.15)
# metrics_df = m.fit(train_df, freq="D")
# alpha = 0.1
# for method in ["naive", "cqr"]: # Naive and CQR SCP methods
# future = m.make_future_dataframe(df, periods=m.n_forecasts, n_historic_predictions=10)
# forecast = m.conformal_predict(future, calibration_df=cal_df, alpha=alpha, method=method)
# m.highlight_nth_step_ahead_of_each_forecast(m.n_forecasts)
# fig0 = m.plot(forecast)
# fig1 = m.plot_latest_forecast(forecast, include_previous_forecasts=10, plotting_backend="matplotlib")
# fig2 = m.plot_latest_forecast(forecast, include_previous_forecasts=10, plot_history_data=True, plotting_backend="matplotlib")
# fig3 = m.plot_latest_forecast(forecast, include_previous_forecasts=10, plot_history_data=False, plotting_backend="matplotlib")
# fig4 = m.plot_components(forecast, plotting_backend="matplotlib")
# fig5 = m.plot_parameters(plotting_backend="matplotlib")
# if PLOT:
# fig0.show()
# fig1.show()
# fig2.show()
# fig3.show()
# fig4.show()
# fig5.show()
m = NeuralProphet(
n_forecasts=7,
n_lags=14,
quantiles=[0.05, 0.95],
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
)
train_df, test_df = m.split_df(df, freq="D", valid_p=0.2)
train_df, cal_df = m.split_df(train_df, freq="D", valid_p=0.15)
metrics_df = m.fit(train_df, freq="D")
alpha = 0.1
for method in ["naive", "cqr"]: # Naive and CQR SCP methods
future = m.make_future_dataframe(df, periods=m.n_forecasts, n_historic_predictions=10)
forecast = m.conformal_predict(
future, calibration_df=cal_df, alpha=alpha, method=method, plotting_backend=plotting_backend
)
m.highlight_nth_step_ahead_of_each_forecast(m.n_forecasts)
fig0 = m.plot(forecast)
fig1 = m.plot_latest_forecast(forecast, include_previous_forecasts=10, plotting_backend=plotting_backend)
fig2 = m.plot_latest_forecast(
forecast, include_previous_forecasts=10, plot_history_data=True, plotting_backend=plotting_backend
)
fig3 = m.plot_latest_forecast(
forecast, include_previous_forecasts=10, plot_history_data=False, plotting_backend=plotting_backend
)
fig4 = m.plot_components(forecast, plotting_backend=plotting_backend)
fig5 = m.plot_parameters(plotting_backend=plotting_backend)
if PLOT:
fig0.show()
fig1.show()
fig2.show()
fig3.show()
fig4.show()
fig5.show()


@pytest.mark.parametrize(*decorator_input)