import numpy as np | |
import torch | |
def mask(size):
    """Build a look-ahead (causal) mask for a sequence of ``size`` tokens.

    Entry ``[0, i, j]`` is 1 when ``j <= i`` (position ``i`` may look at
    position ``j``) and 0 when ``j > i`` (a future position that must stay
    hidden).

    size: number of tokens in the sequence.

    Returns a ``torch`` tensor of shape ``(1, size, size)`` and dtype
    ``uint8``.
    """
    # Leading axis of length 1 lets the mask broadcast over a batch dimension.
    shape = (1, size, size)
    # Keeping the lower triangle (diagonal included) of an all-ones array is
    # the complement of the strictly-upper triangle, i.e. the visible positions.
    visible = np.tril(np.ones(shape, dtype='uint8'))
    return torch.from_numpy(visible)