# Transformers from Scratch using "Attention is All You Need" paper
# Modelling Scaled Dot-Product Attention, Multi-Head Attention, Position-wise Feed-Forward Networks.

# Import Modules
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
import math

# Building single-head and multi-head attention modules from scratch in pure PyTorch

# Set the random seeds for reproducibility
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Self-Attention Mechanism: Single Head
embdim = 256  # D
headdim = 64  # Internal D
tokens = torch.randn(1, 5, embdim)  # batch, tokens, embedding

# Define the projection weights associated with query, key, and value.
# Note: in this single-head demo the value projection maps to embdim (not headdim),
# so the attention output keeps the full embedding width.
Wq = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wk = torch.randn(embdim, headdim) / math.sqrt(embdim)
Wv = torch.randn(embdim, embdim) / math.sqrt(embdim)

# Query, Key, Value
qis = torch.einsum("BSE,EH->BSH", tokens, Wq)  # batch x seqlen x headdim; queries, (1, 5, 64)
kis = torch.einsum("BTE,EH->BTH", tokens, Wk)  # batch x seqlen x headdim; keys
vis = torch.einsum("BTE,EF->BTF", tokens, Wv)  # batch x seqlen x embeddim; values

# Start: Testing code for batched matmul shapes
random_mat1 = torch.randn(2, 5, 4)  # batch, tokens, dimensions
random_mat2 = torch.randn(2, 5, 4)

# (2, 5, 4) @ (2, 4, 5) -> (2, 5, 5)
print(torch.matmul(random_mat1, random_mat2.transpose(1, 2)).shape)
print(qis.shape)
print(kis.shape)
# (Q) N, D @ (K^T) D, N -> N, N
# End: Testing code


scoremat = torch.matmul(qis, kis.transpose(1, 2))  # output: batch x seqlen (query) x seqlen (key)
attmat = F.softmax(scoremat / math.sqrt(headdim), dim=2)  # attention matrix: each row is a distribution over the keys
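# Sanity check: since each row of the attention matrix is a softmax over the keys,
# every row should sum to 1.
assert torch.allclose(attmat.sum(dim=-1), torch.ones_like(attmat[..., 0]), atol=1e-6)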

# Output of the attention mechanism
zis = torch.einsum("BST,BTF->BSF", attmat, vis)

# We can verify the output against PyTorch's built-in scaled dot-product attention
attn_torch = F.scaled_dot_product_attention(qis, kis, vis)
assert torch.allclose(attn_torch, zis, atol=1E-6, rtol=1E-6)  # True

# Multi-Head Attention
embdim = 768
headcnt = 12
headdim = embdim // headcnt
# print(headdim)
assert headdim * headcnt == embdim
tokens = torch.randn(1, 5, embdim)  # batch, tokens, embedding

# The full hidden width is split across the heads: 768 = 12 heads * 64 dims per head
Wq = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)  # heads packed in a single dim
Wk = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)  # heads packed in a single dim
Wv = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)  # heads packed in a single dim

print(Wq.shape)
print(Wk.shape)
print(Wv.shape)
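
# Illustration: the packed projection can equivalently be viewed as one
# (embdim, headdim) weight matrix per head.
print(Wq.view(embdim, headcnt, headdim).shape)  # torch.Size([768, 12, 64])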

batch, token_num, _ = tokens.shape  # batch, tokens (n), embedding shape.
# tokens, B, N, E

# Wq has shape (E, headcnt * headdim): all heads packed into one hidden dim
qis = torch.einsum("BSE,EH->BSH", tokens, Wq)  # batch, N, H  ~ (1, 5, 768)
kis = torch.einsum("BTE,EH->BTH", tokens, Wk)  # batch, N, H
vis = torch.einsum("BTE,EH->BTH", tokens, Wv)  # batch, N, H
# split the single hidden dim into the heads

# Convert dimensions from (B, N, H) to (B, N, headcnt, headdim)
# So for each batch, each token, and each head there is now a headdim-sized slice of features.
qis_mh = qis.view(batch, token_num, headcnt, headdim)  # B, N, HC, HW
kis_mh = kis.view(batch, token_num, headcnt, headdim)
vis_mh = vis.view(batch, token_num, headcnt, headdim)

scoremat_mh = torch.einsum("BSHC,BTHC->BHST", qis_mh, kis_mh)  # inputs: (B, N, headcnt, headdim); output: (B, headcnt, query, key)
print(scoremat_mh.shape)  # (1, 12, 5, 5): 12 heads, each producing a 5x5 attention score matrix

# batch x headcnt x seqlen (query) x seqlen (key)

attmat_mh = F.softmax(scoremat_mh / math.sqrt(headdim), dim=-1)
zis_mh = torch.einsum("BCST,BTCH->BSCH", attmat_mh, vis_mh)  # batch x seqlen (query) x headcnt x headdim
zis = zis_mh.reshape(batch, token_num, headcnt * headdim)

# Note: the heads have been concatenated by the reshape above, but the final output linear projection (W_O from the paper) has not been applied here.
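# A minimal sketch (Wo is a hypothetical random matrix, not used below): applying the
# paper's output projection W_O to the concatenated heads would look like this.
Wo = torch.randn(headcnt * headdim, embdim) / math.sqrt(embdim)
zis_out = torch.einsum("BSH,HE->BSE", zis, Wo)  # batch x seqlen x embdim
print(zis_out.shape)  # torch.Size([1, 5, 768])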

# We can verify the attention weights against PyTorch's nn.MultiheadAttention
# by copying our packed projection weights into its input projection.
mha = nn.MultiheadAttention(embdim, headcnt, batch_first=True)
print(mha.in_proj_weight.shape)  # (3 * embdim, embdim)
mha.in_proj_weight.data = torch.cat([Wq, Wk, Wv], dim=1).T
attn_out, attn_weights = mha(tokens, tokens, tokens, average_attn_weights=False)

# The returned per-head attention weights match our attmat_mh
assert torch.allclose(attmat_mh, attn_weights, atol=1e-6, rtol=1e-6)  # True

print(attn_weights.shape)  # batch, heads, tokens, tokens.
print(attn_out.shape)
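
# Sanity check (a sketch, assuming the module's default zero in_proj_bias and out_proj bias
# initialisation): attn_out is our concatenated heads passed through the module's own
# output projection. We use a slightly looser tolerance for floating-point differences.
zis_proj = mha.out_proj(zis)
assert torch.allclose(attn_out, zis_proj, atol=1e-5, rtol=1e-5)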

# Causal Mask from Scratch
# Compute the causal mask described in the paper: in the decoder we must not attend to future tokens.

attn_mask = torch.ones(token_num, token_num, )
attn_mask = -1E4 * torch.triu(attn_mask, 1)
print(attn_mask)
scoremat_mh_msk = torch.einsum("BSCH,BTCH->BCST", qis_mh, kis_mh)  # batch x headcnt x seqlen (query) x seqlen (key)
scoremat_mh_msk += attn_mask  # add the mask before SoftMax normalization (the mask also gets divided by sqrt(headdim) below, but -1E4 / 8 is still effectively -inf)
attmat_mh_msk = F.softmax(scoremat_mh_msk / math.sqrt(headdim), dim=-1)
zis_mh_msk = torch.einsum("BCST,BTCH->BSCH", attmat_mh_msk, vis_mh)  # batch x seqlen (query) x headcnt x headdim
zis_msk = zis_mh_msk.reshape(batch, token_num, headcnt * headdim)
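
# Sanity check: with the causal mask, attention to future positions is (numerically) zero,
# i.e. every head's attention matrix is lower triangular.
assert torch.allclose(attmat_mh_msk.triu(diagonal=1), torch.zeros_like(attmat_mh_msk), atol=1e-6)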

attn_out_causal, attn_weights_causal = mha(tokens, tokens, tokens, average_attn_weights=False, attn_mask=attn_mask)

# Plot the causal attention matrix of every head.
plt.figure()
for head in range(headcnt):
	plt.subplot(3, 4, head + 1)
	plt.imshow(attn_weights_causal[0, head].detach().numpy())
	plt.title(f"head {head}")
	plt.axis("off")
plt.show()

# Transformer Block from Scratch

# Modeling the Transformer Block from Scratch using PyTorch
# Transformer Block contains:
#     - Layer norm
#     - Skip connections
#     - Multi-head attention
#     - MLP, Feedforward net


class TransformerBlock(nn.Module):

	def __init__(self, embdim: int, headcnt: int, *args, dropout=0.0, **kwargs) -> None:
		super().__init__(*args, **kwargs)
		self.ln1 = nn.LayerNorm(embdim)
		self.ln2 = nn.LayerNorm(embdim)
		self.attn = nn.MultiheadAttention(embdim, headcnt, batch_first=True,)
		self.ffn = nn.Sequential(
			nn.Linear(embdim, 4 * embdim),
			nn.GELU(),
			nn.Linear(4 * embdim, embdim),
			nn.Dropout(dropout),
		)

	def forward(self, x, is_causal=True):
		"""
		Input is a tensor of shape (B, S, E); we assume token and positional embeddings have already been added.
		"""
		batch, token_num, hidden_dim = x.shape
		if is_causal:
			attn_mask = torch.ones(token_num, token_num, device=x.device)
			attn_mask = -1E4 * torch.triu(attn_mask, 1)
		else:
			attn_mask = None

		# Attention sub-layer: residual connection followed by layer norm (post-LN, as in the paper)
		residue = x
		attn_output, attn_weights = self.attn(x, x, x, attn_mask=attn_mask, average_attn_weights=False)
		x = self.ln1(residue + attn_output)
		# Feed-forward sub-layer: residual connection followed by layer norm
		residue = x
		ffn_output = self.ffn(x)
		output = self.ln2(residue + ffn_output)
		return output



if __name__ == "__main__":
	# Testing the Transformer Block
	print("Testing the Transformer Block")
	transformer_block = TransformerBlock(embdim, headcnt)
	tokens = torch.randn(1, 5, embdim)
	output = transformer_block(tokens)
	print(output.shape)
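
	# A minimal sketch (illustration only, not from the paper's full architecture):
	# stacking a few blocks to form a small decoder-style stack. The depth of 4 is
	# an arbitrary choice for this demo.
	blocks = nn.ModuleList([TransformerBlock(embdim, headcnt) for _ in range(4)])
	x = tokens
	for block in blocks:
		x = block(x, is_causal=True)
	print(x.shape)  # should match the input shape: torch.Size([1, 5, 768])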