# usta-llm-demo/v2/usta_multi_head_attention_old.py
import torch
import torch.nn as nn

from .usta_causal_attention import UstaCausalAttention


class UstaMultiHeadAttention(nn.Module):
    """Multi-head causal attention built from independent UstaCausalAttention heads."""

    def __init__(self, embedding_dim, output_dim, context_length, num_heads, dropout_rate=0):
        super().__init__()
        # One independent causal-attention head per index; each head maps
        # (batch, context_length, embedding_dim) -> (batch, context_length, output_dim).
        self.heads = nn.ModuleList(
            [UstaCausalAttention(embedding_dim, output_dim, context_length, dropout_rate) for _ in range(num_heads)]
        )
        # The concatenated head outputs have num_heads * output_dim features
        # (equal to embedding_dim only when output_dim == embedding_dim // num_heads),
        # so the projection maps that width back down to output_dim.
        self.projection = nn.Linear(num_heads * output_dim, output_dim)

    def forward(self, x):
        # Run every head on the same input; each returns a tensor whose last
        # dimension has output_dim features.
        attention_outs = [head(x) for head in self.heads]

        # Concatenate along the feature dimension (the last axis, not the
        # sequence axis), giving num_heads * output_dim features per position.
        attention_out = torch.cat(attention_outs, dim=-1)

        return self.projection(attention_out)
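

# Minimal smoke-test sketch, assuming UstaCausalAttention maps inputs of shape
# (batch, context_length, embedding_dim) to (batch, context_length, output_dim).
# The sizes below are illustrative; the package path is an assumption — run as
# `python -m v2.usta_multi_head_attention_old` so the relative import resolves.
if __name__ == "__main__":
    mha = UstaMultiHeadAttention(
        embedding_dim=64, output_dim=16, context_length=8, num_heads=4
    )
    x = torch.randn(2, 8, 64)  # (batch, context_length, embedding_dim)
    out = mha(x)
    print(out.shape)  # expected: torch.Size([2, 8, 16])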