# Transformers from Scratch using the "Attention is All You Need" paper
# Modelling Scaled Dot-Product Attention, Multi-Head Attention, and Position-wise Feed-Forward Networks.

# Import Modules
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
import math

# Making Single and Multi-Head Attention modules from scratch using pure PyTorch

# Initialise the seed for reproducibility
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Self-Attention Mechanism: Single Head
embdim = 256   # D
headdim = 64   # internal D
tokens = torch.randn(1, 5, embdim)  # batch, tokens, embedding

# Defining the weights associated with query, key, value
Wq = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wk = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wv = torch.randn(embdim, embdim) / math.sqrt(embdim)

# Query, Key, Value
qis = torch.einsum("BSE,EH->BSH", tokens, Wq)  # batch x seqlen x headdim; queries, (1, 5, 64)
kis = torch.einsum("BTE,EH->BTH", tokens, Wk)  # batch x seqlen x headdim; keys
vis = torch.einsum("BTE,EF->BTF", tokens, Wv)  # batch x seqlen x embdim; values

# Start: testing code
random_mat1 = torch.randn(2, 5, 4)  # batch, tokens, dimensions
random_mat2 = torch.randn(2, 5, 4)
torch.matmul(random_mat1, random_mat2.transpose(1, 2))  # (2, 5, 4) @ (2, 4, 5) -> (2, 5, 5)
print(qis.shape)
print(kis.shape)
# (Q) N, D @ (K^T) D, N -> N, N
# End: testing code

scoremat = torch.matmul(qis, kis.transpose(1, 2))  # output: batch x seqlen (query) x seqlen (key)
attmat = F.softmax(scoremat / math.sqrt(headdim), dim=2)  # attention matrix

# Output of the attention mechanism
zis = torch.einsum("BST,BTF->BSF", attmat, vis)

# We can verify the output with PyTorch's scaled dot-product attention
attn_torch = F.scaled_dot_product_attention(qis, kis, vis)
assert torch.allclose(attn_torch, zis, atol=1e-6, rtol=1e-6)  # True

# Multi-Head Attention
embdim = 768
headcnt = 12
headdim = embdim // headcnt
# print(headdim)
assert headdim * headcnt == embdim
tokens = torch.randn(1, 5, embdim)  # batch, tokens, embedding

# The full 768-dim embedding is used, split as 64 (headdim) * 12 (heads)
Wq = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)  # heads packed into a single dim
Wk = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)  # heads packed into a single dim
Wv = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)  # heads packed into a single dim
print(Wq.shape)
print(Wk.shape)
print(Wv.shape)

batch, token_num, _ = tokens.shape  # batch, tokens (N), embedding dim

# tokens: (B, N, E); Wq/Wk/Wv: (E, H) where H = headcnt * headdim
qis = torch.einsum("BSE,EH->BSH", tokens, Wq)  # (B, N, H) ~ (1, 5, 768)
kis = torch.einsum("BTE,EH->BTH", tokens, Wk)  # (B, N, H)
vis = torch.einsum("BTE,EH->BTH", tokens, Wv)  # (B, N, H)

# Split the single hidden dim into the heads,
# converting dimensions from (B, N, H) to (B, N, HC, HW).
# Now for each batch and each token, every head gets its own headdim-sized slice.
qis_mh = qis.view(batch, token_num, headcnt, headdim)  # (B, N, HC, HW)
kis_mh = kis.view(batch, token_num, headcnt, headdim)
vis_mh = vis.view(batch, token_num, headcnt, headdim)

scoremat_mh = torch.einsum("BSHC,BTHC->BHST", qis_mh, kis_mh)  # inputs: (B, N, HC, HW); output: (B, HC, N_query, N_key)
print(scoremat_mh.shape)  # (1, 12, 5, 5)
# Now there are 12 heads, each producing an attention score matrix of shape 5 x 5.
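
# A minimal cross-check (not part of the original walkthrough): the same per-head
# score tensor can also be obtained by moving the head axis in front of the sequence
# axis and using a plain batched matmul, which is the layout most implementations use.
# The name scoremat_mh_alt and the loose tolerances are illustrative assumptions.
scoremat_mh_alt = torch.matmul(
    qis_mh.permute(0, 2, 1, 3),                    # (B, HC, N, HW)
    kis_mh.permute(0, 2, 1, 3).transpose(-2, -1),  # (B, HC, HW, N)
)
assert torch.allclose(scoremat_mh, scoremat_mh_alt, atol=1e-5, rtol=1e-5)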
attmat_mh = F.softmax(scoremat_mh / math.sqrt(headdim), dim=-1)  # batch x headcnt x seqlen (query) x seqlen (key)
zis_mh = torch.einsum("BCST,BTCH->BSCH", attmat_mh, vis_mh)  # batch x seqlen (query) x headcnt x headdim
zis = zis_mh.reshape(batch, token_num, headcnt * headdim)
# Note: this block does not apply the final concatenation + output linear projection of multi-head attention.

# We can verify the output with nn.MultiheadAttention
mha = nn.MultiheadAttention(embdim, headcnt, batch_first=True)
print(mha.in_proj_weight.shape)  # (3 * embdim, embdim)
mha.in_proj_weight.data = torch.cat([Wq, Wk, Wv], dim=1).T
attn_out, attn_weights = mha(tokens, tokens, tokens, average_attn_weights=False)
# attn_weights is the same as attmat_mh
assert torch.allclose(attmat_mh, attn_weights, atol=1e-6, rtol=1e-6)  # True
print(attn_weights.shape)  # batch, heads, tokens, tokens
print(attn_out.shape)

# Causal Mask from Scratch
# The causal mask is described in the paper: in the decoder we do not want to attend to future tokens.
attn_mask = torch.ones(token_num, token_num)
attn_mask = -1E4 * torch.triu(attn_mask, 1)
print(attn_mask)
scoremat_mh_msk = torch.einsum("BSCH,BTCH->BCST", qis_mh, kis_mh)  # batch x headcnt x seqlen (query) x seqlen (key)
scoremat_mh_msk += attn_mask  # add the attn mask to the scores before softmax; -1e4 stays effectively -inf even after the scaling below
attmat_mh_msk = F.softmax(scoremat_mh_msk / math.sqrt(headdim), dim=-1)
zis_mh_msk = torch.einsum("BCST,BTCH->BSCH", attmat_mh_msk, vis_mh)  # batch x seqlen (query) x headcnt x headdim
zis_msk = zis_mh_msk.reshape(batch, token_num, headcnt * headdim)

attn_out_causal, attn_weights_causal = mha(tokens, tokens, tokens, average_attn_weights=False, attn_mask=attn_mask)

# Plotting all heads of the attention mechanism
plt.figure()
for head in range(headcnt):
    plt.subplot(3, 4, head + 1)
    plt.imshow(attn_weights_causal[0, head].detach().numpy())
    plt.title(f"head {head}")
    plt.axis("off")
plt.show()


# Transformer Block from Scratch
# Modelling the Transformer block from scratch using PyTorch.
# The Transformer block contains:
# - Layer norm
# - Skip connections
# - Multi-head attention
# - MLP / feed-forward net
class TransformerBlock(nn.Module):
    def __init__(self, embdim: int, headcnt: int, *args, dropout=0.0, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.ln1 = nn.LayerNorm(embdim)
        self.ln2 = nn.LayerNorm(embdim)
        self.attn = nn.MultiheadAttention(embdim, headcnt, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embdim, 4 * embdim),
            nn.GELU(),
            nn.Linear(4 * embdim, embdim),
            nn.Dropout(dropout),
        )

    def forward(self, x, is_causal=True):
        """
        The input is a matrix of shape (B, S, E); we can therefore assume that
        the token and positional embeddings have already been added.
        """
        batch, token_num, hidden_dim = x.shape
        if is_causal:
            attn_mask = torch.ones(token_num, token_num)
            attn_mask = -1E4 * torch.triu(attn_mask, 1)
        else:
            attn_mask = None
        residue = x
        attn_output, attn_weights = self.attn(x, x, x, attn_mask=attn_mask, average_attn_weights=False)
        x = self.ln1(residue + attn_output)
        residue = x
        ffn_output = self.ffn(x)
        output = self.ln2(residue + ffn_output)
        return output


if __name__ == "__main__":
    # Testing the Transformer Block
    print("Testing the Transformer Block")
    transformer_block = TransformerBlock(embdim, headcnt)
    tokens = torch.randn(1, 5, embdim)
    output = transformer_block(tokens)
    print(output.shape)
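

# A minimal sketch (not part of the original walkthrough) of how several
# TransformerBlock instances could be stacked into a small decoder-style model.
# The class name MiniTransformer, the default depth of 4 blocks, and the final
# LayerNorm are illustrative assumptions, not something prescribed by the code above.
class MiniTransformer(nn.Module):
    def __init__(self, embdim: int, headcnt: int, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(
            [TransformerBlock(embdim, headcnt) for _ in range(depth)]
        )
        self.ln_final = nn.LayerNorm(embdim)

    def forward(self, x, is_causal=True):
        # x: (B, S, E) with token + positional embeddings already added,
        # matching the assumption documented in TransformerBlock.forward.
        for block in self.blocks:
            x = block(x, is_causal=is_causal)
        return self.ln_final(x)

# Example usage (hypothetical):
# model = MiniTransformer(embdim, headcnt)
# out = model(torch.randn(1, 5, embdim))  # -> torch.Size([1, 5, 768])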