import torch


def get_attention_mask(sequence_length, device, mask_type="block-causal", **kwargs):
    # Check for None before calling .lower(), otherwise a None mask_type raises AttributeError.
    if mask_type is None or mask_type.lower() == 'none':
        return None
    elif mask_type.lower() == 'block-causal':
        return _block_causal_mask_impl(sequence_length, device, **kwargs)
    elif mask_type.lower() == 'causal':
        return _causal_mask_impl(sequence_length, device, **kwargs)
    else:
        raise NotImplementedError(f"Mask type {mask_type} not implemented")
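# Note: the two helpers below use different conventions. _block_causal_mask_impl
# returns a boolean mask where True marks a disallowed query/key pair, while
# _causal_mask_impl returns an additive float mask with -inf at disallowed
# positions; callers need to apply each accordingly.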


def _block_causal_mask_impl(sequence_length, device, block_size=16, **kwargs):
    """
    Create a block-causal mask: True marks positions that must not be attended to.
    """
    assert sequence_length % block_size == 0, \
        "for block-causal masks, sequence_length must be divisible by block_size"
    # One (block_size x block_size) block of ones per block in the sequence.
    blocks = torch.ones(sequence_length // block_size, block_size, block_size, device=device)
    # Full bidirectional attention within each block ...
    block_diag_enable_mask = torch.block_diag(*blocks)
    # ... plus causal attention to everything that comes before.
    causal_enable_mask = torch.ones(sequence_length, sequence_length, device=device).tril_(0)
    # Disable every position that neither rule enables.
    disable_mask = (block_diag_enable_mask + causal_enable_mask) < 0.5
    return disable_mask
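# For example, with sequence_length=4 and block_size=2 the boolean disable mask is:
#   [[False, False,  True,  True],
#    [False, False,  True,  True],
#    [False, False, False, False],
#    [False, False, False, False]]
# i.e. each token attends to every token in its own block and in all earlier
# blocks; True marks pairs that must be masked out.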


def _causal_mask_impl(sequence_length, device, **kwargs):
    """
    Create an additive causal mask: -inf above the diagonal, 0 elsewhere.
    """
    causal_disable_mask = torch.triu(
        torch.full((sequence_length, sequence_length), float('-inf'), dtype=torch.float32, device=device),
        diagonal=1,
    )
    return causal_disable_mask
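# The mask above is additive, so it is meant to be added to the raw attention
# scores before the softmax. A minimal usage sketch (illustrative only; `scores`
# is a hypothetical (seq, seq) tensor of attention logits):
#
#     masked_scores = scores + _causal_mask_impl(scores.size(-1), scores.device)
#     weights = torch.softmax(masked_scores, dim=-1)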


if __name__ == '__main__':
    # Use CUDA when available so the demo also runs on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mask = get_attention_mask(9, device, mask_type="block-causal", block_size=3)
    print(mask)
    mask = get_attention_mask(9, device, mask_type="causal")
    print(mask)
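    # Illustrative sketch (assumes PyTorch 2.0+ for scaled_dot_product_attention;
    # not part of the original demo): feeding these masks to SDPA. SDPA's boolean
    # attn_mask uses True = "may attend", the opposite of the boolean block-causal
    # mask above, so it is inverted with `~`; the additive float causal mask can
    # be passed as-is.
    q = torch.randn(1, 1, 9, 8, device=device)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 1, 9, 8, device=device)
    v = torch.randn(1, 1, 9, 8, device=device)
    block_mask = get_attention_mask(9, device, mask_type="block-causal", block_size=3)
    print(torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=~block_mask).shape)
    causal_mask = get_attention_mask(9, device, mask_type="causal")
    print(torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).shape)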