Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| from typing import Any, Callable, Tuple | |
| import torch | |
| from captum._utils.common import ( | |
| _format_additional_forward_args, | |
| _format_output, | |
| _format_tensor_into_tuples, | |
| _is_tuple, | |
| _select_targets, | |
| ) | |
| from captum._utils.gradient import ( | |
| apply_gradient_requirements, | |
| compute_gradients, | |
| undo_gradient_requirements, | |
| ) | |
| from captum._utils.typing import TensorOrTupleOfTensorsGeneric | |
| from captum.robust._core.perturbation import Perturbation | |
| from torch import Tensor | |
class FGSM(Perturbation):
    r"""
    Fast Gradient Sign Method is a one-step method that can generate
    adversarial examples.

    For non-targeted attack, the formulation is

        x' = x + epsilon * sign(gradient of L(theta, x, y)).

    For targeted attack on t, the formulation is

        x' = x - epsilon * sign(gradient of L(theta, x, t)).

    L(theta, x, y) is the model's loss function with respect to model
    parameters, inputs and labels.

    More details on Fast Gradient Sign Method can be found in the original
    paper:
    https://arxiv.org/pdf/1412.6572.pdf
    """

    def __init__(
        self,
        forward_func: Callable,
        loss_func: Callable = None,
        lower_bound: float = float("-inf"),
        upper_bound: float = float("inf"),
    ) -> None:
        r"""
        Args:
            forward_func (callable): The pytorch model for which the attack is
                        computed.
            loss_func (callable, optional): Loss function of which the gradient
                        is computed. The loss function should take in outputs
                        of the model and labels, and return a loss tensor.
                        If None, the negative log of the model outputs is
                        used, with the entries picked by ``target``.
                        Default: None.
            lower_bound (float, optional): Lower bound of input values.
                        Default: ``float("-inf")``.
            upper_bound (float, optional): Upper bound of input values.
                        e.g. image pixels must be in the range 0-255
                        Default: ``float("inf")``.

        Attributes:
            bound (Callable): A function that bounds the input values based on
                        given lower_bound and upper_bound. Can be overwritten
                        for custom use cases if necessary.
            zero_thresh (float): The threshold below which gradient will be
                        treated as zero. Can be modified for custom use cases
                        if necessary.
        """
        super().__init__()
        self.forward_func = forward_func
        self.loss_func = loss_func
        # The bounds are captured by the closure; replace `self.bound` to
        # change the clamping behaviour after construction.
        self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)
        self.zero_thresh = 10 ** -6

    def perturb(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        epsilon: float,
        target: Any,
        additional_forward_args: Any = None,
        targeted: bool = False,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        This method computes and returns the perturbed input for each input
        tensor. It supports both targeted and non-targeted attacks.

        Args:
            inputs (tensor or tuple of tensors): Input for which adversarial
                        attack is computed. It can be provided as a single
                        tensor or a tuple of multiple tensors. If multiple
                        input tensors are provided, the batch sizes must be
                        aligned across all tensors.
            epsilon (float): Step size of perturbation.
            target (any): True labels of inputs if non-targeted attack is
                        desired. Target class of inputs if targeted attack
                        is desired. Target will be passed to the loss function
                        to compute loss, so the type needs to match the
                        argument type of the loss function.

                        If using the default negative log as loss function,
                        labels should be of type int, tuple, tensor or list.
                        For general 2D outputs, labels can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples
                        - a list of integers or a 1D tensor, with length
                          matching the number of examples in inputs (dim 0).
                          Each integer is applied as the label for the
                          corresponding example.

                        For outputs with > 2 dimensions, labels can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This label index is applied to all
                          examples.
                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple
                          containing #output_dims - 1 elements. Each tuple is
                          applied as the label for the corresponding example.
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs
                        for which attributions should not be computed, this
                        argument can be provided. These arguments are provided
                        to forward_func in order following the arguments in
                        inputs.
                        Default: None.
            targeted (bool, optional): If attack should be targeted.
                        Default: False.

        Returns:
            - **perturbed inputs** (*tensor* or tuple of *tensors*):
                        Perturbed input for each
                        input tensor. The perturbed inputs have the same shape
                        and dimensionality as the inputs.
                        If a single tensor is provided as inputs, a single
                        tensor is returned. If a tuple is provided for inputs,
                        a tuple of corresponding sized tensors is returned.
        """
        is_inputs_tuple = _is_tuple(inputs)
        input_tensors: Tuple[Tensor, ...] = _format_tensor_into_tuples(inputs)
        gradient_mask = apply_gradient_requirements(input_tensors)

        # Must stay a zero-argument callable: it is handed to
        # compute_gradients as the forward function and closes over
        # everything it needs.
        def _forward_with_loss() -> Tensor:
            additional_inputs = _format_additional_forward_args(additional_forward_args)
            forward_args = (
                (*input_tensors, *additional_inputs)
                if additional_inputs is not None
                else input_tensors
            )
            outputs = self.forward_func(*forward_args)  # type: ignore
            if self.loss_func is not None:
                return self.loss_func(outputs, target)
            # Default loss: negative log of the model outputs, reduced to the
            # entries picked by `target`.
            # NOTE(review): this presumably expects forward_func to return
            # probabilities (positive values) - confirm against callers.
            return _select_targets(-torch.log(outputs), target)

        try:
            grads = compute_gradients(_forward_with_loss, input_tensors)
        finally:
            # Always hand the mask back to undo_gradient_requirements, even
            # when the forward/loss call raises, so the caller's tensors are
            # not left in the state set up by apply_gradient_requirements.
            undo_gradient_requirements(input_tensors, gradient_mask)

        perturbed_inputs = self._perturb(input_tensors, grads, epsilon, targeted)
        perturbed_inputs = tuple(self.bound(perturbed) for perturbed in perturbed_inputs)
        return _format_output(is_inputs_tuple, perturbed_inputs)

    def _perturb(
        self,
        inputs: Tuple,
        grads: Tuple,
        epsilon: float,
        targeted: bool,
    ) -> Tuple:
        r"""
        A helper function to calculate the perturbed inputs given original
        inputs, gradient of loss function and epsilon. The calculation is
        different for targeted vs. non-targeted attacks as described in the
        class docstring: the step is subtracted for targeted attacks and
        added otherwise.

        Args:
            inputs (tuple): Original input tensors.
            grads (tuple): Gradients of the loss, one per input tensor and
                        paired with ``inputs`` by position.
            epsilon (float): Step size of perturbation.
            targeted (bool): If attack should be targeted.

        Returns:
            tuple: One perturbed tensor per (gradient, input) pair.
        """
        signed_step = (-1 if targeted else 1) * epsilon
        return tuple(
            torch.where(
                # Elements whose gradient magnitude does not exceed
                # zero_thresh are treated as zero-gradient and left untouched.
                torch.abs(grad) > self.zero_thresh,
                inp + signed_step * torch.sign(grad),
                inp,
            )
            for grad, inp in zip(grads, inputs)
        )