# NOTE(review): removed non-Python extraction residue ("Spaces: / Running / Running",
# apparently a Hugging Face Spaces page header) that made the file unparseable.
import torch
import torch.nn as nn

from .usta_causal_attention import UstaCausalAttention
class UstaMultiHeadAttention(nn.Module):
    """Multi-head causal self-attention built from independent per-head modules.

    Runs `num_heads` separate `UstaCausalAttention` heads on the same input,
    concatenates their outputs along the feature axis, and mixes them with a
    final linear projection.

    Args:
        embedding_dim: Feature size of the input fed to each head.
        output_dim: Feature size produced by each individual head.
        context_length: Maximum sequence length (passed through to each head,
            presumably for the causal mask — confirm in UstaCausalAttention).
        num_heads: Number of parallel attention heads.
        dropout_rate: Dropout probability forwarded to each head (default 0).
    """

    def __init__(self, embedding_dim, output_dim, context_length, num_heads, dropout_rate=0):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                UstaCausalAttention(embedding_dim, output_dim, context_length, dropout_rate)
                for _ in range(num_heads)
            ]
        )
        # BUGFIX: the concatenation of num_heads heads, each emitting
        # `output_dim` features, has `output_dim * num_heads` features.
        # The original `nn.Linear(embedding_dim, output_dim)` only worked
        # when embedding_dim happened to equal output_dim * num_heads and
        # raised a shape error otherwise.
        self.projection = nn.Linear(output_dim * num_heads, output_dim)

    def forward(self, x):
        """Apply every head to `x` and project the concatenated result.

        Accepts either unbatched `(seq, embedding_dim)` or batched
        `(batch, seq, embedding_dim)` input.
        """
        # BUGFIX: concatenate along the LAST (feature) axis. The original
        # `dim=1` was only correct for 2-D input; for batched 3-D input it
        # wrongly joined head outputs along the sequence axis. `dim=-1` is
        # identical for the 2-D case, so existing callers are unaffected.
        attention_out = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.projection(attention_out)