a-ragab-h-m committed on
Commit 722e008 · verified · 1 Parent(s): e61766c

Update nets/projections.py

Files changed (1)
  1. nets/projections.py +12 -13
nets/projections.py CHANGED
@@ -13,7 +13,7 @@ class Projections(nn.Module):
 
         self.W_key = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
         self.W_val = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
-        self.W_output = nn.Parameter(torch.Tensor(1, embed_dim, embed_dim))
+        self.W_output = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
 
         self.init_parameters()
 
@@ -24,23 +24,22 @@ class Projections(nn.Module):
 
     def forward(self, h):
         """
-        :param h: (batch_size, graph_size, embed_dim)
-        :return: dict with keys K, V, V_output for attention
+        :param h: Tensor of shape (batch_size, graph_size, embed_dim)
+        :return: dict with keys: K, V, V_output
         """
         batch_size, graph_size, input_dim = h.size()
-        hflat = h.view(-1, input_dim)  # (batch_size * graph_size, embed_dim)
+        hflat = h.contiguous().view(-1, input_dim)  # (batch_size * graph_size, embed_dim)
 
+        # Compute Keys and Values per head
         shp = (self.n_heads, batch_size, graph_size, self.val_dim)
+        K = torch.matmul(hflat, self.W_key).view(shp)
+        V = torch.matmul(hflat, self.W_val).view(shp)
 
-        # Apply projections
-        K = torch.matmul(hflat, self.W_key).view(shp)  # (n_heads, batch_size, graph_size, val_dim)
-        V = torch.matmul(hflat, self.W_val).view(shp)  # (n_heads, batch_size, graph_size, val_dim)
-
-        # Output projection
-        V_output = torch.bmm(h, self.W_output.repeat(batch_size, 1, 1))  # (batch_size, graph_size, embed_dim)
+        # Compute output projection: (batch_size, graph_size, embed_dim)
+        V_output = torch.matmul(h, self.W_output.expand_as(self.W_output))
 
         return {
-            'K': K,
-            'V': V,
-            'V_output': V_output
+            'K': K,  # (n_heads, batch_size, graph_size, val_dim)
+            'V': V,  # (n_heads, batch_size, graph_size, val_dim)
+            'V_output': V_output  # (batch_size, graph_size, embed_dim)
         }
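
The substance of the commit is that W_output drops its leading singleton dimension, so the output projection no longer needs repeat plus torch.bmm; broadcasting in torch.matmul already covers the batch dimension (the expand_as(self.W_output) in the new line is a no-op). A quick sketch of that equivalence, using made-up sizes that are not part of the commit:

    import torch

    # Hypothetical sizes for a numerical check only.
    batch_size, graph_size, embed_dim = 4, 10, 16
    h = torch.randn(batch_size, graph_size, embed_dim)
    W = torch.randn(embed_dim, embed_dim)

    # Old formulation: a (1, embed_dim, embed_dim) weight repeated per batch, then bmm.
    old = torch.bmm(h, W.unsqueeze(0).repeat(batch_size, 1, 1))

    # New formulation: a 2-D weight broadcast by matmul over the batch dimension.
    new = torch.matmul(h, W)

    print(torch.allclose(old, new, atol=1e-6))  # True

Besides avoiding the repeat, the 2-D weight also removes the per-batch copy of W_output that bmm required.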
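For context, a minimal, self-contained sketch of how the updated module might look end to end. The constructor signature, val_dim = embed_dim // n_heads, and the init_parameters body are not shown in this diff, so they are assumptions here rather than the repository's actual code; the forward body mirrors the committed lines, with the redundant expand_as dropped.

    import math
    import torch
    import torch.nn as nn

    class Projections(nn.Module):
        # Sketch only: __init__ signature, val_dim, and init_parameters are assumed.
        def __init__(self, n_heads, embed_dim):
            super().__init__()
            self.n_heads = n_heads
            self.val_dim = embed_dim // n_heads  # assumed per-head dimension

            self.W_key = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
            self.W_val = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
            self.W_output = nn.Parameter(torch.Tensor(embed_dim, embed_dim))  # 2-D after this commit

            self.init_parameters()

        def init_parameters(self):
            # Assumed: uniform init scaled by each parameter's last dimension.
            for param in self.parameters():
                stdv = 1.0 / math.sqrt(param.size(-1))
                param.data.uniform_(-stdv, stdv)

        def forward(self, h):
            batch_size, graph_size, input_dim = h.size()
            hflat = h.contiguous().view(-1, input_dim)  # (batch_size * graph_size, embed_dim)

            # Per-head keys/values: matmul broadcasts hflat across the n_heads dimension.
            shp = (self.n_heads, batch_size, graph_size, self.val_dim)
            K = torch.matmul(hflat, self.W_key).view(shp)
            V = torch.matmul(hflat, self.W_val).view(shp)

            # Output projection: the 2-D weight broadcasts over batch and graph dimensions.
            V_output = torch.matmul(h, self.W_output)

            return {'K': K, 'V': V, 'V_output': V_output}

    # Shape check with made-up sizes.
    proj = Projections(n_heads=8, embed_dim=128)
    out = proj(torch.randn(2, 10, 128))  # h: (batch_size=2, graph_size=10, embed_dim=128)
    print(out['K'].shape)         # torch.Size([8, 2, 10, 16])
    print(out['V_output'].shape)  # torch.Size([2, 10, 128])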