"""Base Module.""" # Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg. # SPDX-License-Identifier: MIT # # flake8: noqa # pylint: skip-file # type: ignore # pydocstyle: noqa from typing import Iterable, List, Tuple import torch.nn as nn from torch import Tensor class InvertibleModule(nn.Module): r"""Base class for all invertible modules in FrEIA. Given ``module``, an instance of some InvertibleModule. This ``module`` shall be invertible in its input dimensions, so that the input can be recovered by applying the module in backwards mode (``rev=True``), not to be confused with ``pytorch.backward()`` which computes the gradient of an operation:: x = torch.randn(BATCH_SIZE, DIM_COUNT) c = torch.randn(BATCH_SIZE, CONDITION_DIM) # Forward mode z, jac = module([x], [c], jac=True) # Backward mode x_rev, jac_rev = module(z, [c], rev=True) The ``module`` returns :math:`\\log \\det J = \\log \\left| \\det \\frac{\\partial f}{\\partial x} \\right|` of the operation in forward mode, and :math:`-\\log | \\det J | = \\log \\left| \\det \\frac{\\partial f^{-1}}{\\partial z} \\right| = -\\log \\left| \\det \\frac{\\partial f}{\\partial x} \\right|` in backward mode (``rev=True``). Then, ``torch.allclose(x, x_rev) == True`` and ``torch.allclose(jac, -jac_rev) == True``. """ def __init__(self, dims_in: Iterable[Tuple[int]], dims_c: Iterable[Tuple[int]] = None): """Initialize. Args: dims_in: list of tuples specifying the shape of the inputs to this operator: ``dims_in = [shape_x_0, shape_x_1, ...]`` dims_c: list of tuples specifying the shape of the conditions to this operator. """ super().__init__() if dims_c is None: dims_c = [] self.dims_in = list(dims_in) self.dims_c = list(dims_c) def forward( self, x_or_z: Iterable[Tensor], c: Iterable[Tensor] = None, rev: bool = False, jac: bool = True ) -> Tuple[Tuple[Tensor], Tensor]: r"""Forward/Backward Pass. Perform a forward (default, ``rev=False``) or backward pass (``rev=True``) through this module/operator. **Note to implementers:** - Subclasses MUST return a Jacobian when ``jac=True``, but CAN return a valid Jacobian when ``jac=False`` (not punished). The latter is only recommended if the computation of the Jacobian is trivial. - Subclasses MUST follow the convention that the returned Jacobian be consistent with the evaluation direction. Let's make this more precise: Let :math:`f` be the function that the subclass represents. Then: .. math:: J &= \\log \\det \\frac{\\partial f}{\\partial x} \\\\ -J &= \\log \\det \\frac{\\partial f^{-1}}{\\partial z}. Any subclass MUST return :math:`J` for forward evaluation (``rev=False``), and :math:`-J` for backward evaluation (``rev=True``). Args: x_or_z: input data (array-like of one or more tensors) c: conditioning data (array-like of none or more tensors) rev: perform backward pass jac: return Jacobian associated to the direction """ raise NotImplementedError(f"{self.__class__.__name__} does not provide forward(...) method") def log_jacobian(self, *args, **kwargs): """This method is deprecated, and does nothing except raise a warning.""" raise DeprecationWarning( "module.log_jacobian(...) is deprecated. " "module.forward(..., jac=True) returns a " "tuple (out, jacobian) now." ) def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]: """Use for shape inference during construction of the graph. MUST be implemented for each subclass of ``InvertibleModule``. Args: input_dims: A list with one entry for each input to the module. Even if the module only has one input, must be a list with one entry. Each entry is a tuple giving the shape of that input, excluding the batch dimension. For example for a module with one input, which receives a 32x32 pixel RGB image, ``input_dims`` would be ``[(3, 32, 32)]`` Returns: A list structured in the same way as ``input_dims``. Each entry represents one output of the module, and the entry is a tuple giving the shape of that output. For example if the module splits the image into a right and a left half, the return value should be ``[(3, 16, 32), (3, 16, 32)]``. It is up to the implementor of the subclass to ensure that the total number of elements in all inputs and all outputs is consistent. """ raise NotImplementedError(f"{self.__class__.__name__} does not provide output_dims(...)")