import torch


def get_attention_mask(sequence_length, device, mask_type="block-causal", **kwargs):
    """Dispatch to the requested mask builder; returns None when no mask is needed."""
    # Check for None before calling .lower(), otherwise a None mask_type raises AttributeError.
    if mask_type is None or mask_type.lower() == 'none':
        return None
    elif mask_type.lower() == 'block-causal':
        return _block_causal_mask_impl(sequence_length, device, **kwargs)
    elif mask_type.lower() == 'causal':
        return _causal_mask_impl(sequence_length, device, **kwargs)
    else:
        raise NotImplementedError(f"Mask type {mask_type} not implemented")


def _block_causal_mask_impl(sequence_length, device, block_size=16, **kwargs):
    """
    Create a block-causal mask: each position may attend to every position in its own
    block and to all positions in earlier blocks. Returns a boolean matrix of shape
    (sequence_length, sequence_length) where True marks positions to disable.
    """
    assert sequence_length % block_size == 0, "for block-causal masks, sequence length must be divisible by block size"
    # One (block_size x block_size) block of ones per block in the sequence.
    blocks = torch.ones(sequence_length // block_size, block_size, block_size, device=device)
    # Block-diagonal matrix enabling full attention within each block.
    block_diag_enable_mask = torch.block_diag(*blocks)
    # Lower-triangular matrix enabling standard causal attention.
    causal_enable_mask = torch.ones(sequence_length, sequence_length, device=device).tril_(0)
    # A position is disabled only if neither the block-diagonal nor the causal part enables it.
    disable_mask = (block_diag_enable_mask + causal_enable_mask) < 0.5
    return disable_mask


def _causal_mask_impl(sequence_length, device, **kwargs):
    """
    Create a standard causal mask as an additive float matrix: 0 on and below the
    diagonal, -inf above it (future positions).
    """
    causal_disable_mask = torch.triu(
        torch.full((sequence_length, sequence_length), float('-inf'), dtype=torch.float32, device=device),
        diagonal=1,
    )
    return causal_disable_mask
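

# The sketch below is a hypothetical usage example, not part of the original module: it
# shows one way these masks could be fed to torch.nn.functional.scaled_dot_product_attention.
# Note the two conventions: the block-causal builder returns a boolean *disable* mask
# (True = masked out), which SDPA expects inverted (True = may attend), while the causal
# builder returns an additive float mask that SDPA consumes as-is.
def _example_scaled_dot_product_attention(q, k, v, mask):
    if mask is None:
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)
    if mask.dtype == torch.bool:
        # Boolean disable mask -> invert so True means "attend".
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=~mask)
    # Additive float mask (-inf on disabled positions) is passed through unchanged.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)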


if __name__ == '__main__':
    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mask = get_attention_mask(9, device, mask_type="block-causal", block_size=3)
    print(mask)
    mask = get_attention_mask(9, device, mask_type="causal")
    print(mask)