data unrolling: better handling of batched tf.Dataset
fel-thomas committed Dec 14, 2022
1 parent 27e995f commit 337a3b6
Showing 2 changed files with 15 additions and 13 deletions.
18 changes: 11 additions & 7 deletions tests/attributions/test_common.py
@@ -57,6 +57,9 @@ def test_common():
         # all explanations returned must be either a tf.Tensor or ndarray
         assert isinstance(explanations, (tf.Tensor, np.ndarray))
 
+        # we should have one explanation for each input
+        assert len(explanations) == len(inputs_np)
+
 
 def test_batch_size():
     """Ensure the functioning of attributions for special batch size cases"""
@@ -89,12 +92,13 @@ def test_batch_size():
         ]
 
         for method in methods:
-            try:
-                explanations = method.explain(inputs, targets)
-            except:
-                raise AssertionError(
-                    "Explanation failed for method ", method.__class__.__name__,
-                    " batch size ", bs)
+            explanations = method.explain(inputs, targets)
+
+            # all explanations returned must be either a tf.Tensor or ndarray
+            assert isinstance(explanations, (tf.Tensor, np.ndarray))
+
+            # we should have one explanation for each input
+            assert len(explanations) == len(inputs)
 
 
 def test_model_caching():
@@ -118,4 +122,4 @@ def test_model_caching():
 
     # ensure that no more than one key has been added
     assert (len(
-        BlackBoxExplainer._cache_models) == cache_len_before + 1) # pylint: disable=protected-access
+        BlackBoxExplainer._cache_models) == cache_len_before + 1) # pylint: disable=protected-access
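
The rewritten test checks the actual contract instead of merely catching exceptions: explain must return a tf.Tensor or ndarray with exactly one explanation per input, whatever the batch size. A minimal sketch of that invariant with a toy gradient explainer (the model, shapes, and toy_explain helper below are illustrative assumptions, not Xplique's API):

    import numpy as np
    import tensorflow as tf

    # toy setup: 10 samples of 32x32 RGB images, 10 classes (shapes are assumptions)
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10)
    ])
    inputs = np.random.rand(10, 32, 32, 3).astype(np.float32)
    targets = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, 10)]

    def toy_explain(inputs, targets, batch_size):
        # hypothetical stand-in for method.explain: batched gradient saliency
        gradients = []
        dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
        for batch_x, batch_y in dataset.batch(batch_size or len(inputs)):
            with tf.GradientTape() as tape:
                tape.watch(batch_x)
                score = tf.reduce_sum(model(batch_x) * batch_y)
            gradients.append(tape.gradient(score, batch_x))
        # one explanation per sample, regardless of how they were batched
        return tf.concat(gradients, axis=0)

    # the invariant now asserted for every batch size, including None and
    # sizes that do not divide the number of samples
    for bs in [None, 1, 3, len(inputs), len(inputs) + 1]:
        explanations = toy_explain(inputs, targets, bs)
        assert isinstance(explanations, (tf.Tensor, np.ndarray))
        assert len(explanations) == len(inputs)

Asserting the output length directly also catches silent truncation, e.g. an explainer that drops the last partial batch, which a bare try/except around explain would never notice.
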
10 changes: 4 additions & 6 deletions xplique/commons/data_conversion.py
@@ -31,15 +31,13 @@ def tensor_sanitize(inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
 
     # deal with tf.data.Dataset
     if isinstance(inputs, tf.data.Dataset):
-        # if the dataset has 4 dimensions, assume it is batched
-        dataset_shape = inputs.element_spec[0].shape
-        if len(dataset_shape) == 4:
+        # try to detect whether the dataset is batched; if so, unbatch it
+        if hasattr(inputs, '_batch_size'):
             inputs = inputs.unbatch()
         # unpack the dataset, assume we have tuples of (input, target)
-        targets = [target for inp, target in inputs]
-        inputs = [inp for inp, target in inputs]
+        targets = [target for _, target in inputs]
+        inputs = [inp for inp, _ in inputs]
 
     # deal with numpy array
     inputs = tf.cast(inputs, tf.float32)
     targets = tf.cast(targets, tf.float32)

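The old heuristic inferred batching from element rank: rank-4 elements were assumed to be batched images. That misfires in both directions, e.g. a dataset of rank-4 video or volume samples would be wrongly unbatched, while a batched dataset of flat feature vectors (rank 2) would be missed. The new check keys on the private _batch_size attribute that tf.data's BatchDataset instances carry. A small sketch of the difference (shapes are illustrative):

    import tensorflow as tf

    x = tf.random.uniform((8, 32, 32, 3))
    y = tf.one_hot(tf.random.uniform((8,), maxval=10, dtype=tf.int32), depth=10)

    unbatched = tf.data.Dataset.from_tensor_slices((x, y))  # image elements, rank 3
    batched = unbatched.batch(4)                            # batch elements, rank 4

    # old heuristic: element rank == 4 means "batched"
    print(len(unbatched.element_spec[0].shape))  # 3
    print(len(batched.element_spec[0].shape))    # 4

    # new heuristic: only a BatchDataset carries the private _batch_size attribute
    print(hasattr(unbatched, '_batch_size'))     # False
    print(hasattr(batched, '_batch_size'))       # True

One caveat of relying on a private attribute: transformations applied after batch(), such as prefetch() or map(), wrap the BatchDataset, so hasattr(dataset, '_batch_size') is False again and such a pipeline would be treated as unbatched.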
