diff --git a/foolbox/attacks/__init__.py b/foolbox/attacks/__init__.py index 589bcf84..09088899 100644 --- a/foolbox/attacks/__init__.py +++ b/foolbox/attacks/__init__.py @@ -50,6 +50,7 @@ from .binarization import BinarizationRefinementAttack # noqa: F401 from .dataset_attack import DatasetAttack # noqa: F401 from .boundary_attack import BoundaryAttack # noqa: F401 +from .hop_skip_jump import HopSkipJump # noqa: F401 from .brendel_bethge import ( # noqa: F401 L0BrendelBethgeAttack, L1BrendelBethgeAttack, diff --git a/foolbox/attacks/hop_skip_jump.py b/foolbox/attacks/hop_skip_jump.py new file mode 100644 index 00000000..fef4599d --- /dev/null +++ b/foolbox/attacks/hop_skip_jump.py @@ -0,0 +1,366 @@ +import logging +from typing import Union, Any, Optional, Callable, List +from typing_extensions import Literal + +import math + +import eagerpy as ep +import numpy as np + +from foolbox.attacks import LinearSearchBlendedUniformNoiseAttack +from foolbox.tensorboard import TensorBoard +from ..models import Model + +from ..criteria import Criterion + +from ..distances import l1 + +from ..devutils import atleast_kd, flatten + +from .base import MinimizationAttack, get_is_adversarial +from .base import get_criterion +from .base import T +from .base import raise_if_kwargs +from ..distances import l2, linf + + +class HopSkipJump(MinimizationAttack): + """A powerful adversarial attack that requires neither gradients + nor probabilities [#Chen19]. + + Args: + init_attack : Attack to use to find a starting points. Defaults to + LinearSearchBlendedUniformNoiseAttack. Only used if starting_points is None. + steps : Number of optimization steps within each binary search step. + initial_gradient_eval_steps: Initial number of evaluations for gradient estimation. + Larger initial_num_evals increases time efficiency, but + may decrease query efficiency. + max_gradient_eval_steps : Maximum number of evaluations for gradient estimation. + stepsize_search : How to search for stepsize; choices are 'geometric_progression', + 'grid_search'. 'geometric progression' initializes the stepsize + by ||x_t - x||_p / sqrt(iteration), and keep decreasing by half + until reaching the target side of the boundary. 'grid_search' + chooses the optimal epsilon over a grid, in the scale of + ||x_t - x||_p. + gamma : The binary search threshold theta is gamma / d^1.5 for + l2 attack and gamma / d^2 for linf attack. + tensorboard : The log directory for TensorBoard summaries. If False, TensorBoard + summaries will be disabled (default). If None, the logdir will be + runs/CURRENT_DATETIME_HOSTNAME. + constraint : Norm to minimize, either "l2" or "linf" + + References: + .. [#Chen19] Jianbo Chen, Michael I. Jordan, Martin J. Wainwright, + "HopSkipJumpAttack: A Query-Efficient Decision-Based Attack", + https://arxiv.org/abs/1904.02144 + """ + + distance = l1 + + def __init__( + self, + init_attack: Optional[MinimizationAttack] = None, + steps: int = 64, + initial_gradient_eval_steps: int = 100, + max_gradient_eval_steps: int = 10000, + stepsize_search: Union[ + Literal["geometric_progression"], Literal["grid_search"] + ] = "geometric_progression", + gamma: float = 1.0, + tensorboard: Union[Literal[False], None, str] = False, + constraint: Union[Literal["linf"], Literal["l2"]] = "l2", + ): + if init_attack is not None and not isinstance(init_attack, MinimizationAttack): + raise NotImplementedError + self.init_attack = init_attack + self.steps = steps + self.initial_num_evals = initial_gradient_eval_steps + self.max_num_evals = max_gradient_eval_steps + self.stepsize_search = stepsize_search + self.gamma = gamma + self.tensorboard = tensorboard + self.constraint = constraint + + assert constraint in ("l2", "linf") + if constraint == "l2": + self.distance = l2 + else: + self.distance = linf + + def run( + self, + model: Model, + inputs: T, + criterion: Union[Criterion, T], + *, + early_stop: Optional[float] = None, + starting_points: Optional[T] = None, + **kwargs: Any, + ) -> T: + raise_if_kwargs(kwargs) + originals, restore_type = ep.astensor_(inputs) + del inputs, kwargs + + criterion = get_criterion(criterion) + is_adversarial = get_is_adversarial(criterion, model) + + if starting_points is None: + init_attack: MinimizationAttack + if self.init_attack is None: + init_attack = LinearSearchBlendedUniformNoiseAttack(steps=50) + logging.info( + f"Neither starting_points nor init_attack given. Falling" + f" back to {init_attack!r} for initialization." + ) + else: + init_attack = self.init_attack + # TODO: use call and support all types of attacks (once early_stop is + # possible in __call__) + x_advs = init_attack.run(model, originals, criterion, early_stop=early_stop) + else: + x_advs = ep.astensor(starting_points) + + is_adv = is_adversarial(x_advs) + if not is_adv.all(): + failed = is_adv.logical_not().float32().sum() + if starting_points is None: + raise ValueError( + f"init_attack failed for {failed} of {len(is_adv)} inputs" + ) + else: + raise ValueError( + f"{failed} of {len(is_adv)} starting_points are not adversarial" + ) + del starting_points + + tb = TensorBoard(logdir=self.tensorboard) + + # Project the initialization to the boundary. + x_advs = self._binary_search(is_adversarial, originals, x_advs) + + assert ep.all(is_adversarial(x_advs)) + + distances = self.distance(originals, x_advs) + + for step in range(self.steps): + delta = self.select_delta(originals, distances, step) + + # Choose number of gradient estimation steps. + num_gradient_estimation_steps = int( + min([self.initial_num_evals * math.sqrt(step + 1), self.max_num_evals]) + ) + + gradients = self.approximate_gradients( + is_adversarial, x_advs, num_gradient_estimation_steps, delta + ) + + if self.constraint == "linf": + update = ep.sign(gradients) + else: + update = gradients + + if self.stepsize_search == "geometric_progression": + # find step size. + epsilons = distances / math.sqrt(step + 1) + + while True: + x_advs_proposals = ep.clip( + x_advs + atleast_kd(epsilons, x_advs.ndim) * update, 0, 1 + ) + success = is_adversarial(x_advs_proposals) + epsilons = ep.where(success, epsilons, epsilons / 2.0) + + if ep.all(success): + break + + # Update the sample. + x_advs = ep.clip( + x_advs + atleast_kd(epsilons, update.ndim) * update, 0, 1 + ) + + assert ep.all(is_adversarial(x_advs)) + + # Binary search to return to the boundary. + x_advs = self._binary_search(is_adversarial, originals, x_advs) + + assert ep.all(is_adversarial(x_advs)) + + elif self.stepsize_search == "grid_search": + # Grid search for stepsize. + epsilons_grid = ep.expand_dims( + ep.from_numpy( + distances, + np.logspace(-4, 0, num=20, endpoint=True, dtype=np.float32), + ), + 1, + ) * ep.expand_dims(distances, 0) + + proposals_list = [] + + for epsilons in epsilons_grid: + x_advs_proposals = ( + x_advs + atleast_kd(epsilons, update.ndim) * update + ) + x_advs_proposals = ep.clip(x_advs_proposals, 0, 1) + + mask = is_adversarial(x_advs_proposals) + + x_advs_proposals = self._binary_search( + is_adversarial, originals, x_advs_proposals + ) + + # only use new values where initial guess was already adversarial + x_advs_proposals = ep.where( + atleast_kd(mask, x_advs.ndim), x_advs_proposals, x_advs + ) + + proposals_list.append(x_advs_proposals) + + proposals = ep.stack(proposals_list, 0) + proposals_distances = self.distance( + ep.expand_dims(originals, 0), proposals + ) + minimal_idx = ep.argmin(proposals_distances, 0) + + x_advs = proposals[minimal_idx] + + distances = self.distance(originals, x_advs) + + # log stats + tb.histogram("norms", distances, step) + + return restore_type(x_advs) + + def approximate_gradients( + self, + is_adversarial: Callable[[ep.Tensor], ep.Tensor], + x_advs: ep.Tensor, + steps: int, + delta: ep.Tensor, + ) -> ep.Tensor: + # (steps, bs, ...) + noise_shape = tuple([steps] + list(x_advs.shape)) + if self.constraint == "l2": + rv = ep.normal(x_advs, noise_shape) + elif self.constraint == "linf": + rv = ep.uniform(x_advs, low=-1, high=1, shape=noise_shape) + rv /= atleast_kd(ep.norms.l2(flatten(rv, keep=1), -1), rv.ndim) + 1e-12 + + scaled_rv = atleast_kd(ep.expand_dims(delta, 0), rv.ndim) * rv + + perturbed = ep.expand_dims(x_advs, 0) + scaled_rv + perturbed = ep.clip(perturbed, 0, 1) + + rv = (perturbed - x_advs) / 2 + + multipliers_list: List[ep.Tensor] = [] + for step in range(steps): + decision = is_adversarial(perturbed[step]) + multipliers_list.append( + ep.where( + decision, + ep.ones(x_advs, (len(x_advs,))), + -ep.ones(x_advs, (len(decision,))), + ) + ) + # (steps, bs, ...) + multipliers = ep.stack(multipliers_list, 0) + + vals = ep.where( + ep.abs(ep.mean(multipliers, axis=0, keepdims=True)) == 1, + multipliers, + multipliers - ep.mean(multipliers, axis=0, keepdims=True), + ) + grad = ep.mean(atleast_kd(vals, rv.ndim) * rv, axis=0) + + grad /= ep.norms.l2(atleast_kd(flatten(grad), grad.ndim)) + 1e-12 + + return grad + + def _project( + self, originals: ep.Tensor, perturbed: ep.Tensor, epsilons: ep.Tensor + ) -> ep.Tensor: + """Clips the perturbations to epsilon and returns the new perturbed + + Args: + originals: A batch of reference inputs. + perturbed: A batch of perturbed inputs. + epsilons: A batch of norm values to project to. + Returns: + A tensor like perturbed but with the perturbation clipped to epsilon. + """ + epsilons = atleast_kd(epsilons, originals.ndim) + if self.constraint == "linf": + perturbation = perturbed - originals + + # ep.clip does not support tensors as min/max + clipped_perturbed = ep.where( + perturbation > epsilons, originals + epsilons, perturbed + ) + clipped_perturbed = ep.where( + perturbation < -epsilons, originals - epsilons, clipped_perturbed + ) + return clipped_perturbed + else: + return (1.0 - epsilons) * originals + epsilons * perturbed + + def _binary_search( + self, + is_adversarial: Callable[[ep.Tensor], ep.Tensor], + originals: ep.Tensor, + perturbed: ep.Tensor, + ) -> ep.Tensor: + # Choose upper thresholds in binary search based on constraint. + d = np.prod(perturbed.shape[1:]) + if self.constraint == "linf": + highs = linf(originals, perturbed) + + # TODO: Check if the threshold is correct + # empirically this seems to be too low + thresholds = highs * self.gamma / (d * d) + else: + highs = ep.ones(perturbed, len(perturbed)) + thresholds = self.gamma / (d * math.sqrt(d)) + + lows = ep.zeros_like(highs) + + # use this variable to check when mids stays constant and the BS has converged + old_mids = highs + + while ep.any(highs - lows > thresholds): + mids = (lows + highs) / 2 + mids_perturbed = self._project(originals, perturbed, mids) + is_adversarial_ = is_adversarial(mids_perturbed) + + highs = ep.where(is_adversarial_, mids, highs) + lows = ep.where(is_adversarial_, lows, mids) + + # check of there is no more progress due to numerical imprecision + reached_numerical_precision = (old_mids == mids).all() + old_mids = mids + + if reached_numerical_precision: + # TODO: warn user + break + + res = self._project(originals, perturbed, highs) + + return res + + def select_delta( + self, originals: ep.Tensor, distances: ep.Tensor, step: int + ) -> ep.Tensor: + result: ep.Tensor + if step == 0: + result = 0.1 * ep.ones_like(distances) + else: + d = np.prod(originals.shape[1:]) + + if self.constraint == "linf": + theta = self.gamma / (d * d) + result = d * theta * distances + else: + theta = self.gamma / (d * np.sqrt(d)) + result = np.sqrt(d) * theta * distances + + return result diff --git a/tests/test_hsj_attack.py b/tests/test_hsj_attack.py new file mode 100644 index 00000000..d231c0f3 --- /dev/null +++ b/tests/test_hsj_attack.py @@ -0,0 +1,134 @@ +from typing import Tuple, Union, List, Any +import eagerpy as ep + +import foolbox as fbn +import foolbox.attacks as fa +from foolbox.devutils import flatten +from foolbox.attacks.brendel_bethge import BrendelBethgeAttack +import pytest + + +def get_attack_id(x: Tuple[BrendelBethgeAttack, Union[int, float]]) -> str: + return repr(x[0]) + + +attacks: List[Tuple[fa.Attack, Union[int, float]]] = [ + ( + fa.HopSkipJump( + steps=1, + constraint="linf", + initial_gradient_eval_steps=100, + max_gradient_eval_steps=100, + ), + ep.inf, + ), + ( + fa.HopSkipJump( + steps=1, + constraint="l2", + initial_gradient_eval_steps=100, + max_gradient_eval_steps=100, + ), + 2, + ), + ( + fa.HopSkipJump( + steps=1, + constraint="l2", + initial_gradient_eval_steps=100, + max_gradient_eval_steps=100, + stepsize_search="grid_search", + gamma=1e5, + ), + 2, + ), +] + + +@pytest.mark.parametrize("attack_and_p", attacks, ids=get_attack_id) +def test_hsj_untargeted_attack( + request: Any, + fmodel_and_data_ext_for_attacks: Tuple[ + Tuple[fbn.Model, ep.Tensor, ep.Tensor], bool + ], + attack_and_p: Tuple[fa.HopSkipJump, Union[int, float]], +) -> None: + if request.config.option.skipslow: + pytest.skip() + + (fmodel, x, y), real = fmodel_and_data_ext_for_attacks + + if isinstance(x, ep.NumPyTensor): + pytest.skip() + + x = (x - fmodel.bounds.lower) / (fmodel.bounds.upper - fmodel.bounds.lower) + fmodel = fmodel.transform_bounds((0, 1)) + + init_attack = fa.DatasetAttack() + init_attack.feed(fmodel, x) + init_advs = init_attack.run(fmodel, x, y) + + attack, p = attack_and_p + advs = attack.run(fmodel, x, y, starting_points=init_advs) + + init_norms = ep.norms.lp(flatten(init_advs - x), p=p, axis=-1) + norms = ep.norms.lp(flatten(advs - x), p=p, axis=-1) + + is_smaller = norms < init_norms + + assert fbn.accuracy(fmodel, advs, y) < fbn.accuracy(fmodel, x, y) + assert fbn.accuracy(fmodel, advs, y) <= fbn.accuracy(fmodel, init_advs, y) + assert is_smaller.any() + + +@pytest.mark.parametrize("attack_and_p", attacks, ids=get_attack_id) +def test_hsj_targeted_attack( + request: Any, + fmodel_and_data_ext_for_attacks: Tuple[ + Tuple[fbn.Model, ep.Tensor, ep.Tensor], bool + ], + attack_and_p: Tuple[fa.HopSkipJump, Union[int, float]], +) -> None: + if request.config.option.skipslow: + pytest.skip() + + (fmodel, x, y), real = fmodel_and_data_ext_for_attacks + + if isinstance(x, ep.NumPyTensor): + pytest.skip() + + if not real: + pytest.skip() + + x = (x - fmodel.bounds.lower) / (fmodel.bounds.upper - fmodel.bounds.lower) + fmodel = fmodel.transform_bounds((0, 1)) + + num_classes = fmodel(x).shape[-1] + y_np = y.numpy() + target_classes_np = (y_np + 1) % num_classes + for i in range(len(target_classes_np)): + while target_classes_np[i] not in y_np: + target_classes_np[i] = (target_classes_np[i] + 1) % num_classes + target_classes = ep.from_numpy(y, target_classes_np) + criterion = fbn.TargetedMisclassification(target_classes) + + init_attack = fa.DatasetAttack() + init_attack.feed(fmodel, x) + init_advs = init_attack.run(fmodel, x, criterion) + + attack, p = attack_and_p + advs = attack.run(fmodel, x, criterion, starting_points=init_advs) + + init_norms = ep.norms.lp(flatten(init_advs - x), p=p, axis=-1) + norms = ep.norms.lp(flatten(advs - x), p=p, axis=-1) + + is_smaller = norms < init_norms + + assert fbn.accuracy(fmodel, advs, y) < fbn.accuracy(fmodel, x, y) + assert fbn.accuracy(fmodel, advs, target_classes) > fbn.accuracy( + fmodel, x, target_classes + ) + assert fbn.accuracy(fmodel, advs, target_classes) >= fbn.accuracy( + fmodel, init_advs, target_classes + ) + assert is_smaller.any()