# Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
"""
Functions for padding and unpadding 
"""
from typing import Tuple, cast
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
    """Autograd function that selects rows of a tensor along its first axis.

    The forward pass gathers the rows named by `indices`; the backward pass
    scatters the incoming gradient back into a zero tensor that has the
    first-axis length of the original input.

    NOTE: backward uses `scatter_` (not `scatter_add_`), so contributions of
    repeated entries in `indices` are not accumulated.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor,
                indices: torch.Tensor) -> torch.Tensor:
        """Return the rows of `input` selected by `indices`.

        Arguments:
            ctx: the autograd context object
            input: (b, ...) tensor with at least 2 dimensions
            indices: (num_idx) 1D tensor of row positions along the first axis
        Returns:
            (num_idx, ...) tensor holding the selected rows.
        """
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        trailing_shape = input.shape[1:]
        # Remembered so backward can rebuild a gradient of the input's length.
        ctx.first_axis_dim = input.shape[0]
        # Number of elements per row once the trailing dimensions are merged.
        row_width = trailing_shape.numel()
        flat_rows = rearrange(input, 'b ... -> b (...)')  # (b, row_width)
        # gather needs one index per output element, so broadcast each row
        # index across the whole row: (num_idx,) -> (num_idx, row_width).
        row_index = repeat(indices, 'z -> z d', d=row_width)
        # Original author's note (TD, 2022-03-04): torch.gather was observed
        # to be a bit faster than plain indexing here.
        gathered = torch.gather(flat_rows, 0, row_index)
        return gathered.reshape(-1, *trailing_shape)  # (num_idx, ...)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Scatter `grad_output` rows back to the positions they came from.

        Returns the gradient w.r.t. `input` (zeros for rows that were not
        selected) and `None` for `indices`.
        """
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        trailing_shape = grad_output.shape[1:]
        flat_grad = rearrange(grad_output, 'b ... -> b (...)')
        row_width = flat_grad.shape[1]
        grad_input = torch.zeros([ctx.first_axis_dim, row_width],
                                 device=flat_grad.device,
                                 dtype=flat_grad.dtype)
        row_index = repeat(indices, 'z -> z d', d=row_width)
        # Original author's note (TD, 2022-03-04): torch.scatter was observed
        # to be a bit faster than `grad_input[indices] = grad_output`.
        grad_input.scatter_(0, row_index, flat_grad)
        return grad_input.reshape(ctx.first_axis_dim, *trailing_shape), None
# Functional entry point: `index_first_axis(input, indices)` invokes
# IndexFirstAxis through autograd's `Function.apply`.
index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
    """Autograd function that writes rows into a zero tensor along axis 0.

    The forward pass places `values` at the row positions given by `indices`
    inside a freshly allocated zero tensor; the backward pass reads the
    gradient back out at those same positions.
    """

    @staticmethod
    def forward(ctx, values: torch.Tensor, indices: torch.Tensor,
                first_axis_dim) -> torch.Tensor:
        """Build a zero tensor with `values` written at rows `indices`.

        Arguments:
            ctx: the autograd context object
            values: (num_idx, ...) tensor with at least 2 dimensions
            indices: (num_idx) 1D tensor of destination rows
            first_axis_dim: length of the first axis of the result
        Returns:
            (first_axis_dim, ...) tensor with the dtype and device of
            `values`; rows not named in `indices` are zero.
        """
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        padded_shape = (first_axis_dim, *values.shape[1:])
        padded = torch.zeros(padded_shape,
                             device=values.device,
                             dtype=values.dtype)
        padded[indices] = values
        return padded

    @staticmethod
    def backward(ctx,
                 grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
        """Return the gradient rows at `indices`; other inputs get `None`."""
        (indices,) = ctx.saved_tensors
        return grad_output[indices], None, None
# Functional entry point: `index_put_first_axis(values, indices, first_axis_dim)`
# invokes IndexPutFirstAxis through autograd's `Function.apply`.
index_put_first_axis = IndexPutFirstAxis.apply
def unpad_input(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Strip padded positions from a batch of sequences.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Returns:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), positions of the selected tokens in the flattened (batch * seqlen) axis.
        cu_seqlens: (batch + 1), int32 cumulative sequence lengths with a leading 0, used to index into hidden_states.
        max_seqlen_in_batch: int, the largest number of valid tokens in any sequence.
    """
    # Per-sequence count of valid tokens.
    tokens_per_seq = attention_mask.sum(dim=-1, dtype=torch.int32)
    flat_mask = attention_mask.flatten()
    indices = torch.nonzero(flat_mask, as_tuple=False).flatten()
    # `.item()` pulls the scalar out of the tensor as a Python number.
    max_seqlen_in_batch = int(tokens_per_seq.max().item())
    running_total = torch.cumsum(tokens_per_seq, dim=0, dtype=torch.int32)
    # Prepend a single 0 so that entry i is the start offset of sequence i.
    cu_seqlens = F.pad(running_total, (1, 0))
    # Original author's note (TD, 2022-03-04): indexing with a bool mask is
    # avoided on purpose. PyTorch would expand the mask, call nonzero and then
    # index with the result, producing indices @dim times larger than needed.
    # Integer indices are faster and use less memory, and the custom
    # forward/backward in index_first_axis is a bit faster than torch's index.
    flat_states = rearrange(hidden_states, 'b s ... -> (b s) ...')
    unpadded = cast(torch.Tensor, index_first_axis(flat_states, indices))
    return unpadded, indices, cu_seqlens, max_seqlen_in_batch
def unpad_input_only(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Strip padded positions and return only the unpadded tensor.

    Same selection as unpad_input, but the sequence-length bookkeeping is
    skipped, which saves a small amount of overhead.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Returns:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
    """
    # Merge batch and sequence axes so each token is one row.
    flat_states = rearrange(hidden_states, 'b s ... -> (b s) ...')
    # Row positions of the valid tokens in that flattened layout.
    flat_mask = attention_mask.flatten()
    indices = torch.nonzero(flat_mask, as_tuple=False).flatten()
    return index_first_axis(flat_states, indices)  # type: ignore
def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int,
              seqlen: int) -> torch.Tensor:
    """Re-insert padding so unpadded tokens form a (batch, seqlen, ...) tensor.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), destination positions in the flattened (batch * seqlen) axis.
        batch: batch size of the padded result.
        seqlen: sequence length of the padded result.
    Returns:
        hidden_states: (batch, seqlen, ...)
    """
    total_slots = batch * seqlen
    # Place every token at its flattened position; the remaining rows are
    # whatever index_put_first_axis leaves there.
    flat_padded = index_put_first_axis(hidden_states, indices, total_slots)
    # Split the flattened axis back into separate batch and sequence axes.
    return rearrange(flat_padded, '(b s) ... -> b s ...',
                     b=batch)  # type: ignore
 | 
