# Transformers from Scratch, following the "Attention is All You Need" paper
# Modelling Scaled Dot-Product Attention, Multi-Head Attention, and Position-wise Feed-Forward Networks.
# Import Modules
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
import math
# Making Single and Multi-Head Attention modules from scratch using Pure PyTorch
# Initialise the seed for reproducibility
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Self-Attention Mechanism: Single Head
embdim = 256 # D
headdim = 64 # Internal D
tokens = torch.randn(1, 5, embdim) # batch, tokens, embedding
# Defining the weights associated with query, key, value
Wq = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wk = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wv = torch.randn(embdim, embdim) / math.sqrt(embdim)
# Query, Key, Value
qis = torch.einsum("BSE,EH->BSH", tokens, Wq) # batch x seqlen x headdim; queries, (1, 5, 64)
kis = torch.einsum("BTE,EH->BTH", tokens, Wk) # batch x seqlen x headdim; keys
vis = torch.einsum("BTE,EF->BTF", tokens, Wv) # batch x seqlen x embeddim; values
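# Sanity check (an addition, not in the original walkthrough): the einsum projection is just a
# batched matrix product, so plain matmul with broadcasting gives the same result.
qis_matmul = tokens @ Wq  # (1, 5, 256) @ (256, 64) -> (1, 5, 64)
assert torch.allclose(qis_matmul, qis, atol=1E-6, rtol=1E-6)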
# Start: Testing Code
random_mat1 = torch.randn(2, 5, 4) # BATCH, TOKENS, DIMENSIONS
random_mat2 = torch.randn(2, 5, 4)
# (2, 5, 4) @ (2, 4, 5) -> (2, 5, 5)
torch.matmul(random_mat1, random_mat2.transpose(1, 2)) # 2, 5, 5
print(qis.shape)
print(kis.shape)
# (Q) N, D * (K^T) D, N -> N, N
# End: Testing Code
scoremat = torch.matmul(qis, kis.transpose(1, 2)) # output: batch x seqlen (Query) x seqlen (Key)
attmat = F.softmax(scoremat / math.sqrt(headdim), dim=2) # softmax over the key dim gives the attention matrix
# Output of the attention mechanism
zis = torch.einsum("BST,BTF->BSF", attmat, vis)
# We can verify the output against PyTorch's scaled dot-product attention
attn_torch = F.scaled_dot_product_attention(qis, kis, vis)
assert (torch.allclose(attn_torch, zis, atol=1E-6, rtol=1E-6)) # True
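# A minimal sketch (an addition to the walkthrough) that wraps the steps above into a reusable
# single-head attention helper; the weight shapes follow the definitions above.
def single_head_attention(x, Wq, Wk, Wv):
    q = torch.einsum("BSE,EH->BSH", x, Wq)
    k = torch.einsum("BTE,EH->BTH", x, Wk)
    v = torch.einsum("BTE,EF->BTF", x, Wv)
    scores = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(q.shape[-1])  # scaled dot-product
    weights = F.softmax(scores, dim=-1)  # rows sum to 1 over the key dimension
    return torch.einsum("BST,BTF->BSF", weights, v)

assert torch.allclose(single_head_attention(tokens, Wq, Wk, Wv), zis, atol=1E-6, rtol=1E-6)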
# Multi-Head Attention
embdim = 768
headcnt = 12
headdim = embdim // headcnt
# print(headdim)
assert headdim * headcnt == embdim
tokens = torch.randn(1, 5, embdim) # batch, tokens, embedding
# The hidden dim packs all heads into one matrix: 768 = 64 (headdim) * 12 (heads)
Wq = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim
Wk = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim
Wv = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim
print(Wq.shape)
print(Wk.shape)
print(Wv.shape)
batch, token_num, _ = tokens.shape # batch, tokens (n), embedding shape.
# tokens: (B, N, E)
# Wq: (E, HC * HD), heads packed along the output dim
qis = torch.einsum("BSE,EH->BSH", tokens, Wq) # Batch, N, H ~ 1, 5, 768
kis = torch.einsum("BTE,EH->BTH", tokens, Wk) # Batch N, H
vis = torch.einsum("BTE,EH->BTH", tokens, Wv) # Batch, N, H
# split the single hidden dim into the heads
# Converting dimensions from (B, N, H) to (B, N, HC, HW)
# So for each batch, each token, and each head, there is now a separate headdim-sized feature vector.
qis_mh = qis.view(batch, token_num, headcnt, headdim) # B, N, HC, HW
kis_mh = kis.view(batch, token_num, headcnt, headdim)
vis_mh = vis.view(batch, token_num, headcnt, headdim)
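# Illustrative check (an addition): splitting the packed projection into heads is equivalent to
# projecting with each head's own (embdim, headdim) column slice of Wq.
h = 3  # arbitrary head index for the check
q_head = torch.einsum("BSE,EH->BSH", tokens, Wq[:, h * headdim:(h + 1) * headdim])
assert torch.allclose(q_head, qis_mh[:, :, h, :], atol=1E-5, rtol=1E-5)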
scoremat_mh = torch.einsum("BSHC,BTHC->BHST", qis_mh, kis_mh) # Inputs: (B, N, HC, HD); Output: (B, HC, S, T)
print(scoremat_mh.shape) # 1, 12, 5, 5 # Now I have 12 heads, which have given me attention matrices of shape 5x5.
# batch x headcnt x seqlen (query) x seqlen (key)
attmat_mh = F.softmax(scoremat_mh / math.sqrt(headdim), dim=-1)
zis_mh = torch.einsum("BCST,BTCH->BSCH", attmat_mh, vis_mh) # batch x seqlen (query) x headcnt x headdim
zis = zis_mh.reshape(batch, token_num, headcnt * headdim)
# Note: this stops short of full multi-head attention, which would still apply the output linear projection to the concatenated heads (see the sketch after the nn.MultiheadAttention check below).
# We can verify the result against PyTorch's nn.MultiheadAttention
mha = nn.MultiheadAttention(embdim, headcnt, batch_first=True, )
print(mha.in_proj_weight.shape) # 3 * embdim x embdim
mha.in_proj_weight.data = torch.cat([Wq, Wk, Wv], dim=1).T
attn_out, attn_weights = mha(tokens, tokens, tokens, average_attn_weights=False, )
# Which is the same as attmat_mh
assert torch.allclose(attmat_mh, attn_weights, atol=1e-6, rtol=1e-6) # True
print(attn_weights.shape) # batch, heads, tokens, tokens.
print(attn_out.shape)
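# Completing the block (a sketch, an addition to the walkthrough): full multi-head attention
# applies an output linear projection to the concatenated heads. Reusing mha.out_proj, whose
# bias PyTorch initialises to zero, should reproduce attn_out from our manual zis.
attn_out_manual = zis @ mha.out_proj.weight.T + mha.out_proj.bias
print(torch.allclose(attn_out_manual, attn_out, atol=1E-5, rtol=1E-5))  # expected: True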
# Causal Mask from Scratch
# Compute the causal mask described in the paper: in the decoder, a token must not attend to future tokens.
attn_mask = torch.ones(token_num, token_num, )
attn_mask = -1E4 * torch.triu(attn_mask, 1)
print(attn_mask)
scoremat_mh_msk = torch.einsum("BSCH,BTCH->BCST", qis_mh, kis_mh) # batch x headcnt x seqlen (query) x seqlen (key)
scoremat_mh_msk += attn_mask # add the attn mask to the scores before SoftMax; it also gets divided by sqrt(headdim) below, but -1e4/8 is still effectively -inf
attmat_mh_msk = F.softmax(scoremat_mh_msk / math.sqrt(headdim), dim=-1)
zis_mh_msk = torch.einsum("BCST,BTCH->BSCH", attmat_mh_msk, vis_mh) # batch x seqlen (query) x headcnt x headdim
zis_msk = zis_mh_msk.reshape(batch, token_num, headcnt * headdim)
attn_out_causal, attn_weights_causal = mha(tokens, tokens, tokens, average_attn_weights=False, attn_mask=attn_mask)
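# Alternative check (a sketch, an addition): F.scaled_dot_product_attention with is_causal=True
# applies the same "no peeking ahead" mask. It expects (batch, heads, seq, headdim), hence the
# transposes; differences should only be in the ~0 masked entries.
zis_sdpa = F.scaled_dot_product_attention(
    qis_mh.transpose(1, 2), kis_mh.transpose(1, 2), vis_mh.transpose(1, 2), is_causal=True
)  # batch x headcnt x seqlen x headdim
print(torch.allclose(zis_sdpa.transpose(1, 2).reshape(batch, token_num, -1), zis_msk, atol=1E-5))  # expected: True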
# Plotting all heads of the attention mechanism.
plt.figure()
for head in range(headcnt):
    plt.subplot(3, 4, head + 1)
    plt.imshow(attn_weights_causal[0, head].detach().numpy())
    plt.title(f"head {head}")
    plt.axis("off")
plt.show()
# Transformer Block from Scratch
# Modeling the Transformer Block from Scratch using PyTorch
# Transformer Block contains:
# - Layer norm
# - Skip connections
# - Multi-head attention
# - MLP, Feedforward net
class TransformerBlock(nn.Module):
    def __init__(self, embdim: int, headcnt: int, *args, dropout=0.0, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.ln1 = nn.LayerNorm(embdim)
        self.ln2 = nn.LayerNorm(embdim)
        self.attn = nn.MultiheadAttention(embdim, headcnt, batch_first=True,)
        self.ffn = nn.Sequential(
            nn.Linear(embdim, 4 * embdim),
            nn.GELU(),
            nn.Linear(4 * embdim, embdim),
            nn.Dropout(dropout),
        )

    def forward(self, x, is_causal=True):
        """
        The input is a matrix of shape (B, S, E); we therefore assume that token and
        positional embeddings have already been added.
        """
        batch, token_num, hidden_dim = x.shape
        if is_causal:
            attn_mask = torch.ones(token_num, token_num,)
            attn_mask = -1E4 * torch.triu(attn_mask, 1)
        else:
            attn_mask = None
        # Post-LN, as in the original paper: residual add, then LayerNorm.
        residue = x
        attn_output, attn_weights = self.attn(x, x, x, attn_mask=attn_mask, average_attn_weights=False,)
        x = residue + attn_output
        x = self.ln1(x)
        residue = x
        ffn_output = self.ffn(x)
        output = self.ln2(residue + ffn_output)
        return output
if __name__ == "__main__":
    # Testing the Transformer Block
    print("Testing the Transformer Block")
    transformer_block = TransformerBlock(embdim, headcnt)
    tokens = torch.randn(1, 5, embdim)
    output = transformer_block(tokens)
    print(output.shape)
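    # A further sketch (an addition, a toy example): Transformer models stack several blocks.
    # nn.ModuleList is used instead of nn.Sequential so the is_causal flag can be passed per call.
    blocks = nn.ModuleList([TransformerBlock(embdim, headcnt) for _ in range(4)])
    x = tokens
    for block in blocks:
        x = block(x, is_causal=True)
    print(x.shape)  # torch.Size([1, 5, 768])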