#!/usr/bin/env python3
import threading
import typing
import warnings
from collections import defaultdict
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
    _reduce_list,
    _run_forward,
    _sort_key_list,
    _verify_select_neuron,
)
from captum._utils.sample_gradient import SampleGradientWrapper
from captum._utils.typing import (
    Literal,
    ModuleOrModuleList,
    TargetType,
    TensorOrTupleOfTensorsGeneric,
)
from torch import device, Tensor
from torch.nn import Module
def apply_gradient_requirements(
    inputs: Tuple[Tensor, ...], warn: bool = True
) -> List[bool]:
    """
    Iterates through the tuple of input tensors and sets requires_grad to True on
    each Tensor. To ensure that the input can be returned to its initial state,
    a list of flags representing whether or not each tensor originally required
    grad is returned.
    """
    assert isinstance(
        inputs, tuple
    ), "Inputs should be wrapped in a tuple prior to preparing for gradients"
    grad_required = []
    for index, input in enumerate(inputs):
        assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
        grad_required.append(input.requires_grad)
        inputs_dtype = input.dtype
        # Note: torch 1.2 doesn't support is_complex for dtype, which is why we
        # check for the existence of the is_complex attribute instead.
        if not inputs_dtype.is_floating_point and not (
            hasattr(inputs_dtype, "is_complex") and inputs_dtype.is_complex
        ):
            if warn:
                warnings.warn(
                    """Input Tensor %d has a dtype of %s.
                    Gradients cannot be activated
                    for these data types."""
                    % (index, str(inputs_dtype))
                )
        elif not input.requires_grad:
            if warn:
                warnings.warn(
                    "Input Tensor %d did not already require gradients, "
                    "requires_grad has been set automatically." % index
                )
            input.requires_grad_()
    return grad_required
def undo_gradient_requirements(
    inputs: Tuple[Tensor, ...], grad_required: List[bool]
) -> None:
    """
    Iterates through the list of tensors and sets requires_grad to False on each
    tensor whose corresponding entry in grad_required is False.
    This method is used to undo the effects of apply_gradient_requirements, making
    grads not required for any input tensor that did not initially require
    gradients.
    """
    assert isinstance(
        inputs, tuple
    ), "Inputs should be wrapped in a tuple prior to preparing for gradients."
    assert len(inputs) == len(
        grad_required
    ), "Input tuple length should match gradient mask."
    for index, input in enumerate(inputs):
        assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
        if not grad_required[index]:
            input.requires_grad_(False)
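# Illustrative usage sketch (not part of the original module): shows the intended
# apply/undo pairing around a manual gradient computation. The toy tensors and the
# `_example_gradient_requirements` name are assumptions made for this example only.
def _example_gradient_requirements() -> None:
    x = torch.rand(4, 3)  # does not require grad initially
    y = torch.rand(4, 3, requires_grad=True)
    inputs = (x, y)
    # Record the original requires_grad flags and enable gradients everywhere.
    initial_flags = apply_gradient_requirements(inputs, warn=False)  # [False, True]
    out = (inputs[0] * inputs[1]).sum()
    grads = torch.autograd.grad(out, inputs)
    # Restore the tensors to their original requires_grad state.
    undo_gradient_requirements(inputs, initial_flags)
    assert not x.requires_grad and y.requires_grad
    assert len(grads) == 2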
def compute_gradients(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
) -> Tuple[Tensor, ...]:
    r"""
    Computes gradients of the output with respect to inputs for an
    arbitrary forward function.

    Args:

        forward_fn: forward function. This can be for example model's
                    forward function.
        inputs:     Input at which gradients are evaluated,
                    will be passed to forward_fn.
        target_ind: Index of the target class for which gradients
                    must be computed (classification only).
        additional_forward_args: Additional input arguments that forward
                    function requires. It takes an empty tuple (no additional
                    arguments) if no additional arguments are required.
    """
    with torch.autograd.set_grad_enabled(True):
        # runs forward pass
        outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
        assert outputs[0].numel() == 1, (
            "Target not provided when necessary, cannot"
            " take gradient with respect to multiple outputs."
        )
        # torch.unbind(forward_out) is a list of scalar tensor tuples and
        # contains batch_size * #steps elements
        grads = torch.autograd.grad(torch.unbind(outputs), inputs)
    return grads
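# Illustrative usage sketch (not part of the original module): computes input
# gradients of a single-output selection from a toy model via compute_gradients.
# The model, shapes, and target index are assumptions made for this example only.
def _example_compute_gradients() -> None:
    model = torch.nn.Linear(3, 2)
    inp = torch.rand(1, 3, requires_grad=True)
    # target_ind=1 selects output column 1, so the selected output is a single scalar.
    grads = compute_gradients(model, (inp,), target_ind=1)
    assert grads[0].shape == inp.shape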
def _neuron_gradients(
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    saved_layer: Dict[device, Tuple[Tensor, ...]],
    key_list: List[device],
    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
) -> Tuple[Tensor, ...]:
    with torch.autograd.set_grad_enabled(True):
        gradient_tensors = []
        for key in key_list:
            current_out_tensor = _verify_select_neuron(
                saved_layer[key], gradient_neuron_selector
            )
            gradient_tensors.append(
                torch.autograd.grad(
                    torch.unbind(current_out_tensor)
                    if current_out_tensor.numel() > 1
                    else current_out_tensor,
                    inputs,
                )
            )
        _total_gradients = _reduce_list(gradient_tensors, sum)
    return _total_gradients
@typing.overload
def _forward_layer_eval(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: Module,
    additional_forward_args: Any = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    grad_enabled: bool = False,
) -> Tuple[Tensor, ...]:
    ...


@typing.overload
def _forward_layer_eval(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: List[Module],
    additional_forward_args: Any = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    grad_enabled: bool = False,
) -> List[Tuple[Tensor, ...]]:
    ...


def _forward_layer_eval(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: ModuleOrModuleList,
    additional_forward_args: Any = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    grad_enabled: bool = False,
) -> Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]:
    return _forward_layer_eval_with_neuron_grads(
        forward_fn,
        inputs,
        layer,
        additional_forward_args=additional_forward_args,
        gradient_neuron_selector=None,
        grad_enabled=grad_enabled,
        device_ids=device_ids,
        attribute_to_layer_input=attribute_to_layer_input,
    )
@typing.overload
def _forward_layer_distributed_eval(
    forward_fn: Callable,
    inputs: Any,
    layer: ModuleOrModuleList,
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    attribute_to_layer_input: bool = False,
    forward_hook_with_return: Literal[False] = False,
    require_layer_grads: bool = False,
) -> Dict[Module, Dict[device, Tuple[Tensor, ...]]]:
    ...


@typing.overload
def _forward_layer_distributed_eval(
    forward_fn: Callable,
    inputs: Any,
    layer: ModuleOrModuleList,
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    attribute_to_layer_input: bool = False,
    *,
    forward_hook_with_return: Literal[True],
    require_layer_grads: bool = False,
) -> Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor]:
    ...


def _forward_layer_distributed_eval(
    forward_fn: Callable,
    inputs: Any,
    layer: ModuleOrModuleList,
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    attribute_to_layer_input: bool = False,
    forward_hook_with_return: bool = False,
    require_layer_grads: bool = False,
) -> Union[
    Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor],
    Dict[Module, Dict[device, Tuple[Tensor, ...]]],
]:
    r"""
    A helper function that sets a hook on the model's `layer`, runs the forward
    pass, and returns the intermediate layer results stored in a dictionary,
    as well as, optionally, the output of the forward function. The results are
    keyed by module and then by device; the values are the corresponding
    intermediate layer results, either the inputs or the outputs of the layer,
    depending on whether `attribute_to_layer_input` is set to True or False.
    This is especially useful when the forward pass is executed in a distributed
    setting, for example with a `DataParallel` model.
    """
    saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]] = defaultdict(dict)
    lock = threading.Lock()
    all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer

    # Set a forward hook on specified module and run forward pass to
    # get layer output tensor(s).
    # For DataParallel models, each partition adds entry to dictionary
    # with key as device and value as corresponding Tensor.
    def hook_wrapper(original_module):
        def forward_hook(module, inp, out=None):
            eval_tsrs = inp if attribute_to_layer_input else out
            is_eval_tuple = isinstance(eval_tsrs, tuple)
            if not is_eval_tuple:
                eval_tsrs = (eval_tsrs,)
            if require_layer_grads:
                apply_gradient_requirements(eval_tsrs, warn=False)
            with lock:
                nonlocal saved_layer
                # Note that cloning behaviour of `eval_tsr` is different
                # when `forward_hook_with_return` is set to True. This is because
                # otherwise `backward()` on the last output layer won't execute.
                if forward_hook_with_return:
                    saved_layer[original_module][eval_tsrs[0].device] = eval_tsrs
                    eval_tsrs_to_return = tuple(
                        eval_tsr.clone() for eval_tsr in eval_tsrs
                    )
                    if not is_eval_tuple:
                        eval_tsrs_to_return = eval_tsrs_to_return[0]
                    return eval_tsrs_to_return
                else:
                    saved_layer[original_module][eval_tsrs[0].device] = tuple(
                        eval_tsr.clone() for eval_tsr in eval_tsrs
                    )

        return forward_hook

    all_hooks = []
    try:
        for single_layer in all_layers:
            if attribute_to_layer_input:
                all_hooks.append(
                    single_layer.register_forward_pre_hook(hook_wrapper(single_layer))
                )
            else:
                all_hooks.append(
                    single_layer.register_forward_hook(hook_wrapper(single_layer))
                )
        output = _run_forward(
            forward_fn,
            inputs,
            target=target_ind,
            additional_forward_args=additional_forward_args,
        )
    finally:
        for hook in all_hooks:
            hook.remove()

    if len(saved_layer) == 0:
        raise AssertionError("Forward hook did not obtain any outputs for given layer")

    if forward_hook_with_return:
        return saved_layer, output
    return saved_layer
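# Illustrative usage sketch (not part of the original module): captures the output
# of one layer of a toy model on a single (CPU) device. The two-layer Sequential
# model and shapes are assumptions made for this example only.
def _example_forward_layer_distributed_eval() -> None:
    model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.ReLU())
    target_layer = model[0]
    inp = torch.rand(2, 3)
    saved_layer = _forward_layer_distributed_eval(model, inp, target_layer)
    # One entry per hooked module; the inner dict is keyed by the device that
    # produced the activation (a single key for a non-DataParallel model).
    activations = saved_layer[target_layer]
    (device_key,) = activations.keys()
    assert activations[device_key][0].shape == (2, 4)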
def _gather_distributed_tensors(
    saved_layer: Dict[device, Tuple[Tensor, ...]],
    device_ids: Union[None, List[int]] = None,
    key_list: Union[None, List[device]] = None,
) -> Tuple[Tensor, ...]:
    r"""
    A helper function to concatenate intermediate layer results stored on
    different devices in `saved_layer`. `saved_layer` is a dictionary that
    contains `device_id` as a key and intermediate layer results (either
    the input or the output of the layer) stored on the device corresponding to
    the key.
    `key_list` is a list of devices in the appropriate ordering for concatenation;
    if it is not provided, keys are sorted based on device ids.
    If only one key exists (standard model), the key list simply has one element.
    """
    if key_list is None:
        key_list = _sort_key_list(list(saved_layer.keys()), device_ids)
    return _reduce_list([saved_layer[device_id] for device_id in key_list])
def _extract_device_ids(
    forward_fn: Callable,
    saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]],
    device_ids: Union[None, List[int]],
) -> Union[None, List[int]]:
    r"""
    A helper function to extract device_ids from `forward_fn` in case it is
    provided as part of a `DataParallel` model or if it is accessible from
    `forward_fn`.
    If the input device_ids is not None, this function returns that value.
    """
    # Multiple devices / keys implies a DataParallel model, so we look for
    # device IDs if given or available from forward function
    # (DataParallel model object).
    if (
        max(len(saved_layer[single_layer]) for single_layer in saved_layer) > 1
        and device_ids is None
    ):
        if (
            hasattr(forward_fn, "device_ids")
            and cast(Any, forward_fn).device_ids is not None
        ):
            device_ids = cast(Any, forward_fn).device_ids
        else:
            raise AssertionError(
                "Layer tensors are saved on multiple devices, however unable to access"
                " device ID list from the `forward_fn`. Device ID list must be"
                " accessible from `forward_fn`. For example, they can be retrieved"
                " if `forward_fn` is a model of type `DataParallel`. It is used"
                " for identifying device batch ordering."
            )
    return device_ids
@typing.overload
def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: Module,
    additional_forward_args: Any = None,
    *,
    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    ...


@typing.overload
def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: Module,
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> Tuple[Tensor, ...]:
    ...


@typing.overload
def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: List[Module],
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> List[Tuple[Tensor, ...]]:
    ...


def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: ModuleOrModuleList,
    additional_forward_args: Any = None,
    gradient_neuron_selector: Union[
        None, int, Tuple[Union[int, slice], ...], Callable
    ] = None,
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> Union[
    Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
    Tuple[Tensor, ...],
    List[Tuple[Tensor, ...]],
]:
    """
    This method computes forward evaluation for a particular layer using a
    forward hook. If a gradient_neuron_selector is provided, then gradients with
    respect to that neuron in the layer output are also returned.

    These functionalities are combined due to the behavior of DataParallel models
    with hooks, in which hooks are executed once per device. We need to internally
    combine the separated tensors from devices by concatenating based on device_ids.
    Any necessary gradients must be taken with respect to each independent batched
    tensor, so the gradients are computed and combined appropriately.

    More information regarding the behavior of forward hooks with DataParallel models
    can be found in the PyTorch data parallel documentation. We maintain the separate
    evals in a dictionary protected by a lock, analogous to the gather implementation
    for the core PyTorch DataParallel implementation.
    """
    grad_enabled = True if gradient_neuron_selector is not None else grad_enabled

    with torch.autograd.set_grad_enabled(grad_enabled):
        saved_layer = _forward_layer_distributed_eval(
            forward_fn,
            inputs,
            layer,
            additional_forward_args=additional_forward_args,
            attribute_to_layer_input=attribute_to_layer_input,
        )
    device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids)
    # Identifies correct device ordering based on device ids.
    # key_list is a list of devices in appropriate ordering for concatenation.
    # If only one key exists (standard model), key list simply has one element.
    key_list = _sort_key_list(list(next(iter(saved_layer.values())).keys()), device_ids)
    if gradient_neuron_selector is not None:
        assert isinstance(
            layer, Module
        ), "Cannot compute neuron gradients for multiple layers simultaneously!"
        inp_grads = _neuron_gradients(
            inputs, saved_layer[layer], key_list, gradient_neuron_selector
        )
        return (
            _gather_distributed_tensors(saved_layer[layer], key_list=key_list),
            inp_grads,
        )
    else:
        if isinstance(layer, Module):
            return _gather_distributed_tensors(saved_layer[layer], key_list=key_list)
        else:
            return [
                _gather_distributed_tensors(saved_layer[curr_layer], key_list=key_list)
                for curr_layer in layer
            ]
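# Illustrative usage sketch (not part of the original module): evaluates a layer
# and the gradients of one of its neurons with respect to the inputs. The toy
# model, shapes, and neuron index are assumptions made for this example only.
def _example_neuron_grads() -> None:
    model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.Linear(4, 2))
    inp = torch.rand(2, 3, requires_grad=True)
    layer_eval, inp_grads = _forward_layer_eval_with_neuron_grads(
        model,
        (inp,),
        model[0],
        gradient_neuron_selector=1,  # neuron index within the layer output
    )
    assert layer_eval[0].shape == (2, 4)
    assert inp_grads[0].shape == inp.shape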
@typing.overload
def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    *,
    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    ...


@typing.overload
def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: List[Module],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]]:
    ...


@typing.overload
def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    ...


def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: ModuleOrModuleList,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    gradient_neuron_selector: Union[
        None, int, Tuple[Union[int, slice], ...], Callable
    ] = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Union[
    Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
    Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]],
    Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]],
]:
    r"""
    Computes gradients of the output with respect to a given layer as well
    as the output evaluation of the layer for an arbitrary forward function
    and given input.

    For data parallel models, hooks are executed once per device, so we
    need to internally combine the separated tensors from devices by
    concatenating based on device_ids. Any necessary gradients must be taken
    with respect to each independent batched tensor, so the gradients are
    computed and combined appropriately.

    More information regarding the behavior of forward hooks with DataParallel
    models can be found in the PyTorch data parallel documentation. We maintain
    the separate inputs in a dictionary protected by a lock, analogous to the
    gather implementation for the core PyTorch DataParallel implementation.

    NOTE: To properly handle inplace operations, a clone of the layer output
    is stored. This structure inhibits execution of a backward hook on the last
    module for the layer output when computing the gradient with respect to
    the input, since we store an intermediate clone, as opposed to the true
    module output. If backward module hooks are necessary for the final module
    when computing input gradients, utilize
    _forward_layer_eval_with_neuron_grads instead.

    Args:

        forward_fn: forward function. This can be for example model's
                    forward function.
        layer:      Layer for which gradients / output will be evaluated.
        inputs:     Input at which gradients are evaluated,
                    will be passed to forward_fn.
        target_ind: Index of the target class for which gradients
                    must be computed (classification only).
        output_fn:  An optional function that is applied to the layer inputs or
                    outputs depending on whether `attribute_to_layer_input` is
                    set to `True` or `False`.
        additional_forward_args: Additional input arguments that forward
                    function requires. It takes an empty tuple (no additional
                    arguments) if no additional arguments are required.

    Returns:
        2-element tuple of **gradients**, **evals**:
        - **gradients**:
            Gradients of output with respect to target layer output.
        - **evals**:
            Target layer output for given input.
        If a `gradient_neuron_selector` is provided, the gradients of the selected
        neuron with respect to the inputs are returned as an additional third
        element.
    """
    with torch.autograd.set_grad_enabled(True):
        # saved_layer is a dictionary mapping device to a tuple of
        # layer evaluations on that device.
        saved_layer, output = _forward_layer_distributed_eval(
            forward_fn,
            inputs,
            layer,
            target_ind=target_ind,
            additional_forward_args=additional_forward_args,
            attribute_to_layer_input=attribute_to_layer_input,
            forward_hook_with_return=True,
            require_layer_grads=True,
        )
        assert output[0].numel() == 1, (
            "Target not provided when necessary, cannot"
            " take gradient with respect to multiple outputs."
        )

        device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids)

        # Identifies correct device ordering based on device ids.
        # key_list is a list of devices in appropriate ordering for concatenation.
        # If only one key exists (standard model), key list simply has one element.
        key_list = _sort_key_list(
            list(next(iter(saved_layer.values())).keys()), device_ids
        )
        all_outputs: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]
        if isinstance(layer, Module):
            all_outputs = _reduce_list(
                [
                    saved_layer[layer][device_id]
                    if output_fn is None
                    else output_fn(saved_layer[layer][device_id])
                    for device_id in key_list
                ]
            )
        else:
            all_outputs = [
                _reduce_list(
                    [
                        saved_layer[single_layer][device_id]
                        if output_fn is None
                        else output_fn(saved_layer[single_layer][device_id])
                        for device_id in key_list
                    ]
                )
                for single_layer in layer
            ]
        all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer
        grad_inputs = tuple(
            layer_tensor
            for single_layer in all_layers
            for device_id in key_list
            for layer_tensor in saved_layer[single_layer][device_id]
        )
        saved_grads = torch.autograd.grad(torch.unbind(output), grad_inputs)

        offset = 0
        all_grads: List[Tuple[Tensor, ...]] = []
        for single_layer in all_layers:
            num_tensors = len(next(iter(saved_layer[single_layer].values())))
            curr_saved_grads = [
                saved_grads[i : i + num_tensors]
                for i in range(
                    offset, offset + len(key_list) * num_tensors, num_tensors
                )
            ]
            offset += len(key_list) * num_tensors
            if output_fn is not None:
                curr_saved_grads = [
                    output_fn(curr_saved_grad) for curr_saved_grad in curr_saved_grads
                ]
            all_grads.append(_reduce_list(curr_saved_grads))

        layer_grads: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]
        layer_grads = all_grads
        if isinstance(layer, Module):
            layer_grads = all_grads[0]

    if gradient_neuron_selector is not None:
        assert isinstance(
            layer, Module
        ), "Cannot compute neuron gradients for multiple layers simultaneously!"
        inp_grads = _neuron_gradients(
            inputs, saved_layer[layer], key_list, gradient_neuron_selector
        )
        return (
            cast(Tuple[Tensor, ...], layer_grads),
            cast(Tuple[Tensor, ...], all_outputs),
            inp_grads,
        )
    return layer_grads, all_outputs  # type: ignore
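# Illustrative usage sketch (not part of the original module): computes gradients
# of a selected model output with respect to an intermediate layer, together with
# that layer's activations. The toy model, shapes, and target index are assumptions
# made for this example only.
def _example_layer_gradients_and_eval() -> None:
    model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.Linear(4, 2))
    inp = torch.rand(5, 3)
    layer_grads, layer_evals = compute_layer_gradients_and_eval(
        model, model[0], inp, target_ind=0
    )
    # One gradient / eval tensor per tensor produced by the hooked layer.
    assert layer_grads[0].shape == (5, 4)
    assert layer_evals[0].shape == (5, 4)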
def construct_neuron_grad_fn(
    layer: Module,
    neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
    device_ids: Union[None, List[int]] = None,
    attribute_to_neuron_input: bool = False,
) -> Callable:
    def grad_fn(
        forward_fn: Callable,
        inputs: TensorOrTupleOfTensorsGeneric,
        target_ind: TargetType = None,
        additional_forward_args: Any = None,
    ) -> Tuple[Tensor, ...]:
        _, grads = _forward_layer_eval_with_neuron_grads(
            forward_fn,
            inputs,
            layer,
            additional_forward_args,
            gradient_neuron_selector=neuron_selector,
            device_ids=device_ids,
            attribute_to_layer_input=attribute_to_neuron_input,
        )
        return grads

    return grad_fn
def _compute_jacobian_wrt_params(
    model: Module,
    inputs: Tuple[Any, ...],
    labels: Optional[Tensor] = None,
    loss_fn: Optional[Union[Module, Callable]] = None,
) -> Tuple[Tensor, ...]:
    r"""
    Computes the Jacobian of a batch of test examples given a model, and optional
    loss function and target labels. This method is equivalent to calculating the
    gradient for every individual example in the minibatch.

    Args:
        model (torch.nn.Module): The trainable model providing the forward pass
        inputs (tuple of Any): The minibatch for which the forward pass is computed.
                It is unpacked before passing to `model`, so it must be a tuple. The
                individual elements of `inputs` can be anything.
        labels (Tensor or None): Labels for input if computing a loss function.
        loss_fn (torch.nn.Module or Callable or None): The loss function. If a
                library-defined loss function is provided, it is expected to be a
                torch.nn.Module. If a custom loss is provided, it can be either type,
                but must behave as a library loss function would if `reduction='none'`.

    Returns:
        grads (tuple of Tensor): Returns the Jacobian for the minibatch as a
                tuple of gradients corresponding to the tuple of trainable parameters
                returned by `model.parameters()`. Each element grads[i] contains the
                gradients for the parameters in the i-th trainable layer of the model.
                Each grads[i] object is a tensor with the gradients for the `inputs`
                batch. For example, grads[i][j] would reference the gradients for the
                parameters of the i-th layer, for the j-th member of the minibatch.
    """
    with torch.autograd.set_grad_enabled(True):
        out = model(*inputs)
        assert out.dim() != 0, "Please ensure model output has at least one dimension."

        if labels is not None and loss_fn is not None:
            loss = loss_fn(out, labels)
            if hasattr(loss_fn, "reduction"):
                msg0 = "Please ensure loss_fn.reduction is set to `none`"
                assert loss_fn.reduction == "none", msg0  # type: ignore
            else:
                msg1 = (
                    "Loss function is applying a reduction. Please ensure "
                    f"Output shape: {out.shape} and Loss shape: {loss.shape} "
                    "are matching."
                )
                assert loss.dim() != 0, msg1
                assert out.shape[0] == loss.shape[0], msg1
            out = loss

        grads_list = [
            torch.autograd.grad(
                outputs=out[i],
                inputs=model.parameters(),  # type: ignore
                grad_outputs=torch.ones_like(out[i]),
                retain_graph=True,
            )
            for i in range(out.shape[0])
        ]

        grads = tuple([torch.stack(x) for x in zip(*grads_list)])

        return tuple(grads)
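# Illustrative usage sketch (not part of the original module): per-sample parameter
# gradients for a toy linear model with an unreduced (`reduction='none'`) loss.
# The model, batch size, and loss are assumptions made for this example only.
def _example_jacobian_wrt_params() -> None:
    model = torch.nn.Linear(3, 2)
    batch = torch.rand(4, 3)
    labels = torch.rand(4, 2)
    loss_fn = torch.nn.MSELoss(reduction="none")
    grads = _compute_jacobian_wrt_params(model, (batch,), labels, loss_fn)
    # One entry per parameter tensor (weight, bias), each with a leading batch dim.
    weight_grads, bias_grads = grads
    assert weight_grads.shape == (4, 2, 3)
    assert bias_grads.shape == (4, 2)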
def _compute_jacobian_wrt_params_with_sample_wise_trick(
    model: Module,
    inputs: Tuple[Any, ...],
    labels: Optional[Tensor] = None,
    loss_fn: Optional[Union[Module, Callable]] = None,
    reduction_type: Optional[str] = "sum",
) -> Tuple[Any, ...]:
    r"""
    Computes the Jacobian of a batch of test examples given a model, and optional
    loss function and target labels. This method uses the sample-wise gradient
    trick to fully vectorize the Jacobian calculation over the batch. Currently,
    only linear and conv2d layers are supported.
    Sample-gradient hooks are added to (and removed from) the model internally
    via `SampleGradientWrapper`.

    Args:
        model (torch.nn.Module): The trainable model providing the forward pass
        inputs (tuple of Any): The minibatch for which the forward pass is computed.
                It is unpacked before passing to `model`, so it must be a tuple. The
                individual elements of `inputs` can be anything.
        labels (Tensor or None): Labels for input if computing a loss function.
        loss_fn (torch.nn.Module or Callable or None): The loss function. If a
                library-defined loss function is provided, it is expected to be a
                torch.nn.Module. If a custom loss is provided, it can be either type,
                but must behave as a library loss function would if `reduction='sum'`
                or `reduction='mean'`.
        reduction_type (str): The type of reduction applied. If a loss_fn is passed,
                this should match `loss_fn.reduction`. Else if gradients are being
                computed on direct model outputs (scores), then 'sum' should be used.
                Defaults to 'sum'.

    Returns:
        grads (tuple of Tensor): Returns the Jacobian for the minibatch as a
                tuple of gradients corresponding to the tuple of trainable parameters
                returned by `model.parameters()`. Each element grads[i] contains the
                gradients for the parameters in the i-th trainable layer of the model.
                Each grads[i] object is a tensor with the gradients for the `inputs`
                batch. For example, grads[i][j] would reference the gradients for the
                parameters of the i-th layer, for the j-th member of the minibatch.
    """
    with torch.autograd.set_grad_enabled(True):
        sample_grad_wrapper = SampleGradientWrapper(model)
        try:
            sample_grad_wrapper.add_hooks()

            out = model(*inputs)
            assert (
                out.dim() != 0
            ), "Please ensure model output has at least one dimension."

            if labels is not None and loss_fn is not None:
                loss = loss_fn(out, labels)
                # TODO: allow loss_fn to be Callable
                if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
                    msg0 = (
                        "Please ensure that loss_fn.reduction is set to `sum` or `mean`"
                    )
                    assert loss_fn.reduction != "none", msg0
                    msg1 = (
                        f"loss_fn.reduction ({loss_fn.reduction}) does not match"
                        f" reduction type ({reduction_type}). Please ensure they are"
                        " matching."
                    )
                    assert loss_fn.reduction == reduction_type, msg1
                msg2 = (
                    "Please ensure custom loss function is applying either a "
                    "sum or mean reduction."
                )
                assert out.shape != loss.shape, msg2

                if reduction_type != "sum" and reduction_type != "mean":
                    raise ValueError(
                        f"{reduction_type} is not a valid value for reduction_type. "
                        "Must be either 'sum' or 'mean'."
                    )
                out = loss

            sample_grad_wrapper.compute_param_sample_gradients(
                out, loss_mode=reduction_type
            )

            grads = tuple(
                param.sample_grad  # type: ignore
                for param in model.parameters()
                if hasattr(param, "sample_grad")
            )
        finally:
            sample_grad_wrapper.remove_hooks()

        return grads
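# Illustrative usage sketch (not part of the original module): the same per-sample
# gradients as above, computed with the vectorized sample-gradient trick and a
# sum-reduced loss. The toy model, shapes, and loss are assumptions made for this
# example only.
def _example_jacobian_sample_wise_trick() -> None:
    model = torch.nn.Linear(3, 2)
    batch = torch.rand(4, 3)
    labels = torch.rand(4, 2)
    loss_fn = torch.nn.MSELoss(reduction="sum")
    grads = _compute_jacobian_wrt_params_with_sample_wise_trick(
        model, (batch,), labels, loss_fn, reduction_type="sum"
    )
    weight_grads, bias_grads = grads
    assert weight_grads.shape == (4, 2, 3)
    assert bias_grads.shape == (4, 2)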