Spaces:
Running
Running
File size: 701 Bytes
8d4b0c7 6563ff2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
import torch.nn as nn
from .usta_causal_attention import UstaCausalAttention
class UstaMultiHeadAttention(nn.Module):
    """Multi-head causal self-attention built from independent heads.

    Runs ``num_heads`` separate ``UstaCausalAttention`` modules on the same
    input, concatenates their outputs along the last (feature) axis, and maps
    that concatenation to ``output_dim`` features with a linear projection.

    Args:
        embedding_dim: Passed to every head (presumably the input feature
            width -- TODO confirm against UstaCausalAttention).
        output_dim: Passed to every head as its output width, and used as
            the output width of the final projection.
        context_length: Passed to every head.
        num_heads: Number of attention heads. Must be at least 1.
        dropout_rate: Passed to every head.

    Raises:
        ValueError: If ``num_heads`` is less than 1.
    """

    def __init__(self, embedding_dim, output_dim, context_length, num_heads, dropout_rate=0):
        super().__init__()
        # Fail early: with zero heads, forward() would only fail later inside
        # torch.cat on an empty list.
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        self.heads = nn.ModuleList(
            [
                UstaCausalAttention(embedding_dim, output_dim, context_length, dropout_rate)
                for _ in range(num_heads)
            ]
        )
        # The projection consumes all heads concatenated on the feature axis,
        # i.e. num_heads * output_dim features (assuming each head returns
        # output_dim features -- TODO confirm against UstaCausalAttention).
        # When num_heads * output_dim == embedding_dim this has exactly the
        # shape of an embedding_dim-wide projection, so checkpoints saved
        # with that layout still load.
        self.projection = nn.Linear(num_heads * output_dim, output_dim)

    def forward(self, x):
        """Apply every head to ``x`` and project the concatenated result.

        Args:
            x: Input tensor handed unchanged to each head (presumably
                (seq, embedding_dim) or (batch, seq, embedding_dim) -- TODO
                confirm what UstaCausalAttention accepts).

        Returns:
            A tensor with the head outputs' leading axes and a last axis of
            width ``output_dim``.
        """
        head_outputs = [head(x) for head in self.heads]
        # Concatenate on the last (feature) axis. For 2-D head outputs this is
        # the same as axis 1; for batched 3-D outputs, axis 1 would wrongly
        # stack the heads along the sequence axis instead.
        return self.projection(torch.cat(head_outputs, dim=-1))