# VTBench/src/vqvaes/bsqvit/attention_mask.py
import torch


def get_attention_mask(sequence_length, device, mask_type="block-causal", **kwargs):
    # Treat a missing mask type the same as the string "none": no mask.
    # (Check for None first so that .lower() is never called on None.)
    if mask_type is None or mask_type.lower() == 'none':
        return None
    elif mask_type.lower() == 'block-causal':
        return _block_causal_mask_impl(sequence_length, device, **kwargs)
    elif mask_type.lower() == 'causal':
        return _causal_mask_impl(sequence_length, device, **kwargs)
    else:
        raise NotImplementedError(f"Mask type {mask_type} not implemented")


def _block_causal_mask_impl(sequence_length, device, block_size=16, **kwargs):
    """
    Create a block-causal mask: each position may attend to every position in
    its own block (including later ones) and to all positions in earlier
    blocks. Returns a boolean (sequence_length, sequence_length) tensor where
    True means "attention disabled".
    """
    assert sequence_length % block_size == 0, \
        "for block-causal masks, sequence length must be divisible by block size"
    # Block-diagonal ones: enable attention within each block.
    blocks = torch.ones(sequence_length // block_size, block_size, block_size, device=device)
    block_diag_enable_mask = torch.block_diag(*blocks)
    # Lower-triangular ones: enable attention to all earlier positions.
    causal_enable_mask = torch.ones(sequence_length, sequence_length, device=device).tril_(0)
    # A position is disabled only if neither rule enables it.
    disable_mask = ((block_diag_enable_mask + causal_enable_mask) < 0.5)
    return disable_mask
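
# For example, with sequence_length = 4 and block_size = 2 the helper above
# produces the following disable mask (True = attention blocked):
#     [[False, False,  True,  True],
#      [False, False,  True,  True],
#      [False, False, False, False],
#      [False, False, False, False]]
# i.e. each position sees its whole block plus every earlier block.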


def _causal_mask_impl(sequence_length, device, **kwargs):
    """
    Create a standard causal mask. Unlike the block-causal variant, this is an
    additive float mask: 0 on and below the diagonal, -inf above it.
    """
    causal_disable_mask = torch.triu(
        torch.full((sequence_length, sequence_length), float('-inf'), dtype=torch.float32, device=device),
        diagonal=1,
    )
    return causal_disable_mask
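
# For example, with sequence_length = 3 the helper above returns
#     [[0., -inf, -inf],
#      [0.,   0., -inf],
#      [0.,   0.,   0.]]
# which is meant to be added to the attention logits before the softmax.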


if __name__ == '__main__':
    # Fall back to CPU when CUDA is unavailable so the demo stays runnable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mask = get_attention_mask(9, device, mask_type="block-causal", block_size=3)
    print(mask)
    mask = get_attention_mask(9, device, mask_type="causal")
    print(mask)
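
    # Minimal usage sketch (assumes PyTorch >= 2.0 with
    # torch.nn.functional.scaled_dot_product_attention; how the surrounding
    # BSQ-ViT code actually consumes these masks may differ). Note the two
    # conventions: the block-causal mask is boolean with True = "disable", so
    # it must be inverted for SDPA (which expects True = "attend"), while the
    # causal mask is additive and can be passed as-is.
    q = k = v = torch.randn(1, 1, 9, 8, device=device)  # (batch, heads, seq, head_dim)
    block_disable = get_attention_mask(9, device, mask_type="block-causal", block_size=3)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=~block_disable)
    additive = get_attention_mask(9, device, mask_type="causal")
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=additive)
    print(out.shape)  # torch.Size([1, 1, 9, 8])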