diff --git a/xplique/metrics/fidelity.py b/xplique/metrics/fidelity.py
index d5132144..dd3c1b74 100644
--- a/xplique/metrics/fidelity.py
+++ b/xplique/metrics/fidelity.py
@@ -9,7 +9,7 @@
 from scipy.stats import spearmanr
 
 from .base import ExplanationMetric
-from ..commons import batch_predictions_one_hot
+from ..commons import batch_predictions
 from ..types import Union, Callable, Optional, Dict
 
 
@@ -83,8 +83,8 @@ def __init__(self,
                             (self.nb_samples, self.grid_size, self.grid_size, 1))
         self.subset_masks = tf.image.resize(subset_masks, inputs.shape[1:-1], method="nearest")
 
-        self.base_predictions = batch_predictions_one_hot(self.model, inputs,
-                                                          targets, self.batch_size)
+        self.base_predictions = batch_predictions(self.model, inputs,
+                                                  targets, self.batch_size)
 
     def evaluate(self,
                  explanations: Union[tf.Tensor, np.ndarray]) -> float:
@@ -117,8 +117,8 @@ def evaluate(self,
             # use the masks to set the selected subsets to baseline state
             degraded_inputs = inp * self.subset_masks + (1.0 - self.subset_masks) * baseline
             # measure the two terms that should be correlated
-            preds = base - batch_predictions_one_hot(self.model, degraded_inputs,
-                                                     label, self.batch_size)
+            preds = base - batch_predictions(self.model, degraded_inputs,
+                                             label, self.batch_size)
             attrs = tf.reduce_sum(phi * (1.0 - self.subset_masks), (1, 2, 3))
             corr_score = spearmanr(preds, attrs)[0]
 
@@ -276,8 +276,8 @@ def detailed_evaluate(self,
 
             batch_inputs = batch_inputs.reshape((-1, *self.inputs.shape[1:]))
 
-            predictions = batch_predictions_one_hot(self.model, batch_inputs,
-                                                    self.targets, self.batch_size)
+            predictions = batch_predictions(self.model, batch_inputs,
+                                            self.targets, self.batch_size)
 
             scores_dict[step] = np.mean(predictions)
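Note on the renamed helper: every call site above passes the same arguments, `(model, inputs, targets, batch_size)`, so the diff suggests `batch_predictions` keeps the contract of `batch_predictions_one_hot` — per-sample scores for the (one-hot) targets, computed batch by batch. The snippet below is a minimal sketch of what such a helper could look like under that assumption; it is illustrative only, not the actual `xplique.commons` implementation.

```python
# Hypothetical sketch (assumption), not the actual xplique.commons code:
# returns, for each input, the model score of its one-hot target, batch by batch.
import tensorflow as tf

def batch_predictions(model, inputs, targets, batch_size):
    scores = []
    dataset = tf.data.Dataset.from_tensor_slices((inputs, targets)).batch(batch_size)
    for batch_inputs, batch_targets in dataset:
        # keep only the score of the class designated by the one-hot target
        batch_scores = tf.reduce_sum(model(batch_inputs) * batch_targets, axis=-1)
        scores.append(batch_scores)
    return tf.concat(scores, axis=0)
```

With a helper of this shape, the metric code above stays a drop-in replacement: `evaluate` correlates the drop in these per-target scores with the attribution mass of the removed subsets, and `detailed_evaluate` averages them at each perturbation step.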