File size: 514 Bytes
9622166
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
import numpy as np 
import torch

def mask(size):
    """
    Build a look-ahead mask so that tokens cannot see future tokens
    during training.

    size: number of tokens within the sequence

    Returns a torch uint8 tensor of shape (1, size, size). Row i holds 1 in
    columns 0..i and 0 in every later column, i.e. the matrix is
    lower-triangular with the main diagonal included.
    """
    all_ones = np.ones((1, size, size))
    # Keeping the diagonal and everything below it is the complement of the
    # strictly upper triangle, which is the part that must stay hidden.
    allowed = np.tril(all_ones).astype('uint8')
    return torch.from_numpy(allowed)