Spaces:
Running
Running
File size: 846 Bytes
8d4b0c7 6563ff2 8d4b0c7 6563ff2 8d4b0c7 6563ff2 8d4b0c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
import torch.nn as nn
class UstaMultiHeadAttention(nn.Module):
    """Causal multi-head self-attention followed by a linear projection.

    Wraps ``nn.MultiheadAttention`` (used with its default
    ``batch_first=False`` layout, so the token axis is dimension 0) and
    applies an upper-triangular boolean mask so that each position can
    only attend to itself and earlier positions.

    Args:
        embedding_dim: Size of each input token embedding.
        output_dim: Size of each output vector after the final projection.
        context_length: Maximum number of tokens attended over; longer
            inputs are truncated to their first ``context_length`` tokens.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability passed to the attention layer.
        device: Device on which parameters and the mask are created.
    """

    def __init__(self, embedding_dim, output_dim, context_length, num_heads, dropout_rate=0, device="cpu"):
        super().__init__()
        self.context_length = context_length
        self.multi_head_attention = nn.MultiheadAttention(
            embedding_dim,
            num_heads,
            dropout=dropout_rate,
            device=device,
        )
        self.projection = nn.Linear(embedding_dim, output_dim, device=device)

        # True strictly above the diagonal: for a boolean attn_mask, True
        # marks positions that must NOT be attended to (i.e. future tokens).
        causal_mask = torch.triu(
            torch.ones(context_length, context_length, device=device),
            diagonal=1,
        ).bool()
        # Registered as a buffer so it follows the module across .to()/.cuda()
        # and is saved in the state dict without being a trainable parameter.
        self.register_buffer("mask", causal_mask)

    def forward(self, x):
        """Apply causal self-attention and project the result.

        Args:
            x: Tensor whose first dimension indexes tokens and whose last
                dimension is ``embedding_dim``.

        Returns:
            Tensor with the same leading dimensions as the (possibly
            truncated) input and ``output_dim`` as its last dimension.
        """
        # Truncate first, then read the length from the truncated tensor so
        # the mask size always matches the sequence actually attended over.
        x = x[:self.context_length]
        number_of_tokens = x.shape[0]
        attention_mask = self.mask[:number_of_tokens, :number_of_tokens]

        # Self-attention: query, key and value are all the same sequence.
        # The second return value (attention weights) is not needed here.
        out, _ = self.multi_head_attention(x, x, x, attn_mask=attention_mask)
        return self.projection(out)