import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, dimension_for_model, num_of_heads, dropout=0.1):
        '''
        Initializes the multi-head attention module.
        dimension_for_model: the dimensionality of the model/embeddings (the same value used in the embedding layer)
        num_of_heads: the number of attention heads
        dropout: the dropout rate applied to the attention weights (see positional_encodings), defaulted to 0.1
        '''
 
        # Initializing the parent class (nn.Module)
        super(MultiHeadAttention, self).__init__()
        assert dimension_for_model % num_of_heads == 0, "dimension_for_model must be divisible by num_of_heads"

        self.num_of_heads = num_of_heads
        self.dimension_for_model = dimension_for_model
        self.d_k = dimension_for_model//num_of_heads  # This is the dimension for each head

        # Creating linear layers that project the inputs into query, key, and value spaces
        self.linear_query = nn.Linear(dimension_for_model, dimension_for_model)
        self.linear_key = nn.Linear(dimension_for_model, dimension_for_model)
        self.linear_value = nn.Linear(dimension_for_model, dimension_for_model)
        self.linear_out = nn.Linear(dimension_for_model, dimension_for_model)  # Output projection applied after the heads are concatenated

        # Adding dropout layer
        self.dropout = nn.Dropout(dropout)
        # Defining softmax over the last dimension (applied to the attention scores in forward)
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, query, key, value, mask=None):
        '''
        Forward pass for multi-head attention.
        query: tensor of shape (batch_size, query_length, dimension_for_model)
        key: tensor of shape (batch_size, key_length, dimension_for_model)
        value: tensor with the same shape as key
        mask: optional tensor, broadcastable to the attention scores, where entries equal to 0 are masked out
        '''
        batch_size = query.size(0)
        seq_len = query.size(1)  # length of the query sequence
        key_len = key.size(1)    # key/value length can differ from seq_len (e.g. in cross-attention)

        # Projecting using linear layers
        Q = self.linear_query(query)
        K = self.linear_key(key)
        V = self.linear_value(value)

        # Splitting tensors into multiple heads: (batch_size, num_of_heads, length, d_k)
        Q = Q.view(batch_size, seq_len, self.num_of_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, key_len, self.num_of_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, key_len, self.num_of_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention scores: (Q @ K^T) / sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Converting scores to attention weights and applying dropout
        attn = self.softmax(scores)
        attn = self.dropout(attn)

        output = torch.matmul(attn, V)

        # Concatenating the outputs of all heads back into (batch_size, seq_len, dimension_for_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dimension_for_model)

        # Linear projection to combine all heads
        output = self.linear_out(output)

        return output, attn
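

# A minimal usage sketch of MultiHeadAttention. The hyperparameters and tensor
# shapes below (dimension_for_model=512, num_of_heads=8, batch of 2, sequence
# length of 10) are illustrative assumptions, not values required by the module.
if __name__ == "__main__":
    batch_size, seq_len = 2, 10
    dimension_for_model, num_of_heads = 512, 8

    attention = MultiHeadAttention(dimension_for_model, num_of_heads, dropout=0.1)
    x = torch.randn(batch_size, seq_len, dimension_for_model)

    # Self-attention: query, key, and value are the same tensor
    output, attn = attention(x, x, x)
    print(output.shape)  # torch.Size([2, 10, 512])
    print(attn.shape)    # torch.Size([2, 8, 10, 10]) -> (batch, heads, query_len, key_len)

    # Causal mask, broadcastable to the scores: position i attends only to positions <= i
    causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
    masked_output, masked_attn = attention(x, x, x, mask=causal_mask)
    print(masked_output.shape)  # torch.Size([2, 10, 512])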