import torch


def get_attention_mask(sequence_length, device, mask_type="block-causal", **kwargs):
    """
    Build an attention mask of the requested type, or return None when no
    masking is wanted.
    """
    if mask_type is None or mask_type.lower() == 'none':
        return None
    elif mask_type.lower() == 'block-causal':
        return _block_causal_mask_impl(sequence_length, device, **kwargs)
    elif mask_type.lower() == 'causal':
        return _causal_mask_impl(sequence_length, device, **kwargs)
    else:
        raise NotImplementedError(f"Mask type {mask_type} not implemented")


def _block_causal_mask_impl(sequence_length, device, block_size=16, **kwargs):
    """
    Create a block-causal disable mask (True = do not attend): each position can
    attend to every position in its own block and to all earlier positions.
    """
    assert sequence_length % block_size == 0, \
        "for block-causal masks the sequence length must be divisible by the block size"
    # Full (bidirectional) attention within each block along the diagonal.
    blocks = torch.ones(sequence_length // block_size, block_size, block_size, device=device)
    block_diag_enable_mask = torch.block_diag(*blocks)
    # Standard causal attention: each position sees itself and everything before it.
    causal_enable_mask = torch.ones(sequence_length, sequence_length, device=device).tril_(0)
    # Disable a position only if neither pattern allows it.
    disable_mask = ((block_diag_enable_mask + causal_enable_mask) < 0.5)
    return disable_mask
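
# Illustration: for sequence_length=6 and block_size=3, _block_causal_mask_impl
# returns the boolean disable mask below (True = masked out); rows are queries,
# columns are keys. Rows in the first block see only their own block, while rows
# in the second block additionally see the entire first block.
#   [[False, False, False,  True,  True,  True],
#    [False, False, False,  True,  True,  True],
#    [False, False, False,  True,  True,  True],
#    [False, False, False, False, False, False],
#    [False, False, False, False, False, False],
#    [False, False, False, False, False, False]]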


def _causal_mask_impl(sequence_length, device, **kwargs):
    """
    Create a standard causal mask as an additive float mask:
    -inf above the diagonal (future positions), 0 elsewhere.
    """
    causal_disable_mask = torch.triu(
        torch.full((sequence_length, sequence_length), float('-inf'), dtype=torch.float32, device=device),
        diagonal=1,
    )
    return causal_disable_mask


if __name__ == '__main__':
    mask = get_attention_mask(9, "cuda", mask_type="block-causal", block_size=3)
    print(mask)
    mask = get_attention_mask(9, "cuda", mask_type="causal")
    print(mask)
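
    # A minimal usage sketch, assuming PyTorch >= 2.0 for
    # torch.nn.functional.scaled_dot_product_attention (SDPA); the tensor shapes
    # are arbitrary illustration values. The block-causal helper returns a boolean
    # *disable* mask (True = do not attend), whereas SDPA's boolean attn_mask
    # expects True = attend, so it is inverted here. The causal helper returns an
    # additive float mask (-inf above the diagonal), which SDPA accepts directly.
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 4, 9, 8, device="cuda")  # (batch, heads, seq_len, head_dim)

    block_mask = get_attention_mask(9, "cuda", mask_type="block-causal", block_size=3)
    out_block = F.scaled_dot_product_attention(q, k, v, attn_mask=~block_mask)

    causal_mask = get_attention_mask(9, "cuda", mask_type="causal")
    out_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
    print(out_block.shape, out_causal.shape)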