Skip to content

Commit

Permalink
fixed components computation
Browse files Browse the repository at this point in the history
  • Loading branch information
leoniewgnr authored Jun 14, 2023
1 parent 8040314 commit cfab7f3
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions neuralprophet/time_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def _forward(self, inputs: Dict, meta: Dict = None, non_stationary_only: bool =
) # dimensions - [batch, n_forecasts, no_quantiles]
return out

def forward(self, inputs: Dict, meta: Dict = None) -> Dict:
def forward(self, inputs: Dict, meta: Dict = None, compute_components: bool = False) -> Dict:
"""
Forward pass of the model to compute predictions based on the provided inputs and meta data.
This method fits non-stationary components first, substracts them from the present "lags" and in a
Expand All @@ -617,11 +617,15 @@ def forward(self, inputs: Dict, meta: Dict = None) -> Dict:
"regressors", "regressors_lagged", and "predict_mode".
meta : Dict, optional
Dictionary containing additional meta data for the forward pass, by default None.
compute_components : bool, optional
If True, the method returns additional components, by default False.
Returns
-------
Dict
Dictionary containing the prediction results with quantiles.
dict
Containing forecast coomponents with elements of dims (batch, n_forecasts)
"""
if "lags" in inputs:
_inputs = inputs.copy()
Expand All @@ -648,7 +652,15 @@ def forward(self, inputs: Dict, meta: Dict = None) -> Dict:
else:
predict_mode = False
prediction_with_quantiles = self._compute_quantile_forecasts_from_diffs(prediction, predict_mode)
return prediction_with_quantiles

# compute components
corrected_inputs = inputs if corrected_inputs is None else corrected_inputs

if compute_components:
components = self.compute_components(corrected_inputs, meta)
return prediction_with_quantiles, components
else:
return prediction_with_quantiles

def compute_components(self, inputs: Dict, meta: Dict) -> Dict:
"""This method returns the values of each model component.
Expand Down Expand Up @@ -841,12 +853,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
# Add predict_mode flag to dataset
inputs["predict_mode"] = True
# Run forward calculation
prediction = self.forward(inputs, meta_name_tensor)
# Calculate components (if requested)
if self.compute_components_flag:
components = self.compute_components(inputs, meta_name_tensor)
else:
components = None
prediction, components = self.forward(inputs, meta_name_tensor, compute_components=self.compute_components_flag)
return prediction, components

def configure_optimizers(self):
Expand Down

0 comments on commit cfab7f3

Please sign in to comment.