import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
import math
class MultiHeadAttention(nn.Module):
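    """
    Multi-head (self-)attention over a set of input embeddings.

    Projects the inputs into per-head queries, keys and values, computes
    scaled dot-product attention (optionally restricted by a mask), and
    projects the concatenated heads back to embed_dim.
    """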
    def __init__(
            self,
            n_heads,
            input_dim,
            embed_dim=None,
            val_dim=None,
            key_dim=None
    ):
        super(MultiHeadAttention, self).__init__()

        if val_dim is None:
            assert embed_dim is not None, "Provide either embed_dim or val_dim"
            val_dim = embed_dim // n_heads
        if key_dim is None:
            key_dim = val_dim

        self.n_heads = n_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.val_dim = val_dim
        self.key_dim = key_dim

        self.norm_factor = 1 / math.sqrt(key_dim)  # Scaled dot-product, see "Attention Is All You Need"

        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim), requires_grad=True)
        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim), requires_grad=True)
        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim), requires_grad=True)

        if embed_dim is not None:
            # W_out projects the concatenated value heads: (n_heads * val_dim) -> embed_dim
            self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim), requires_grad=True)

        self.init_parameters()
    def init_parameters(self):
        # Uniform initialization scaled by the last dimension of each parameter
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)
    def forward(self, q, h=None, mask=None):
        """
        :param q: queries (batch_size, n_query, input_dim)
        :param h: data (batch_size, graph_size, input_dim); defaults to q for self-attention
        :param mask: mask (batch_size, n_query, graph_size) or viewable as that
                     (i.e. can be 2-dim if n_query == 1).
                     Entries should be 1 (True) where attention is NOT allowed
                     (i.e. the mask is a negative adjacency matrix).
        :return: output embeddings (batch_size, n_query, embed_dim)
        """
        if h is None:
            h = q  # compute self-attention

        # h should be (batch_size, graph_size, input_dim)
        batch_size, graph_size, input_dim = h.size()
        n_query = q.size(1)
        assert q.size(0) == batch_size
        assert q.size(2) == input_dim
        assert input_dim == self.input_dim, "Wrong embedding dimension of input"

        hflat = h.contiguous().view(-1, input_dim)
        qflat = q.contiguous().view(-1, input_dim)

        # Last dimension can be different for keys and values
        shp = (self.n_heads, batch_size, graph_size, -1)
        shp_q = (self.n_heads, batch_size, n_query, -1)

        # Calculate queries, (n_heads, batch_size, n_query, key_dim)
        Q = torch.matmul(qflat, self.W_query).view(shp_q)
        # Calculate keys and values, (n_heads, batch_size, graph_size, key_dim / val_dim)
        K = torch.matmul(hflat, self.W_key).view(shp)
        V = torch.matmul(hflat, self.W_val).view(shp)

        # Calculate compatibility (n_heads, batch_size, n_query, graph_size)
        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))

        # Optionally apply mask to prevent attention
        if mask is not None:
            mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
            compatibility[mask.bool()] = -np.inf

        attn = F.softmax(compatibility, dim=-1)

        # If there are nodes with no neighbours then softmax returns nan, so we fix them to 0
        if mask is not None:
            attnc = attn.clone()
            attnc[mask.bool()] = 0
            attn = attnc

        heads = torch.matmul(attn, V)

        # Concatenate the heads and project back to embed_dim
        out = torch.mm(
            heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.val_dim),
            self.W_out.view(-1, self.embed_dim)
        ).view(batch_size, n_query, self.embed_dim)

        return out
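

# A minimal usage sketch of the module above. The dimensions and the mask
# pattern below are arbitrary illustration values (an assumption for the
# example, not values required by this file).
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, graph_size, input_dim, embed_dim, n_heads = 2, 5, 16, 32, 4
    attention = MultiHeadAttention(n_heads=n_heads, input_dim=input_dim, embed_dim=embed_dim)

    h = torch.randn(batch_size, graph_size, input_dim)

    # Self-attention: queries and data are the same set of embeddings
    out = attention(h)
    print(out.shape)  # torch.Size([2, 5, 32])

    # Masked attention: True marks pairs where attention is NOT allowed,
    # here hiding the last node from every query
    mask = torch.zeros(batch_size, graph_size, graph_size, dtype=torch.bool)
    mask[:, :, -1] = True
    out_masked = attention(h, mask=mask)
    print(out_masked.shape)  # torch.Size([2, 5, 32])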