import math

import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    """Standard multi-head scaled dot-product self-attention."""

    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Query, key, and value projections plus the output projection.
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout)
        # Attention logits are divided by sqrt(head_dim) to keep their variance stable.
        self.scale = math.sqrt(self.head_dim)

    def forward(self, x, mask=None, padding_mask=None):
        batch_size, seq_len, _ = x.size()

        # Project and split into heads: (batch, num_heads, seq_len, head_dim).
        q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len).
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # `mask` marks blocked positions (1/True = blocked), e.g. a causal mask.
        # A large-but-finite fill value (-1e4) is used instead of -inf so that a fully
        # masked row yields a uniform softmax rather than NaN and stays in fp16 range.
        if mask is not None:
            scores = scores.masked_fill(mask == 1, -1e4)
        # `padding_mask` is (batch, seq_len) with True at padding tokens; broadcast it
        # over heads and query positions.
        if padding_mask is not None:
            padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(padding_mask, -1e4)

        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Merge heads back to (batch, seq_len, hidden_size) and project out.
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        out = self.out(out)
        return out

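# Illustrative usage sketch, not part of the original module: a quick shape check for
# MultiHeadSelfAttention. The helper name and the batch/sequence/hidden sizes below
# are arbitrary assumptions chosen for the demo.
def _demo_attention_shapes():
    mha = MultiHeadSelfAttention(hidden_size=64, num_heads=4)
    x = torch.randn(2, 10, 64)  # (batch, seq_len, hidden_size)
    causal = torch.triu(torch.ones(10, 10), diagonal=1).bool()  # True above the diagonal
    out = mha(x, mask=causal)
    assert out.shape == (2, 10, 64)  # attention preserves the input shape
    return out
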
class TransformerLayer(nn.Module):
    """Pre-LayerNorm transformer block: self-attention and feed-forward sublayers with residuals."""

    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(hidden_size, num_heads, dropout)
        # Position-wise feed-forward network with the usual 4x expansion.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.ReLU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        # A single residual dropout is applied to each sublayer's output.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, padding_mask=None):
        # Pre-LN residuals: each sublayer sees a normalized input, but its output is
        # added back onto the unnormalized residual stream.
        attn_out = self.attn(self.ln1(x), mask, padding_mask)
        x = x + self.dropout(attn_out)

        ffn_out = self.ffn(self.ln2(x))
        x = x + self.dropout(ffn_out)
        return x

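# Illustrative sketch, not part of the original module: a single block keeps the
# input shape, so layers can be stacked freely. The helper name and sizes are
# arbitrary assumptions for the demo.
def _demo_layer_shapes():
    layer = TransformerLayer(hidden_size=64, num_heads=4)
    x = torch.randn(2, 10, 64)  # (batch, seq_len, hidden_size)
    causal = torch.triu(torch.ones(10, 10), diagonal=1).bool()
    y = layer(x, mask=causal)
    assert y.shape == x.shape
    return y
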
class TransformerModel(nn.Module):
    """Decoder-only transformer language model with learned positional embeddings."""

    def __init__(self, vocab_size, hidden_size=512, num_layers=6, num_heads=8, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        # Learned positional embeddings; sequences longer than 512 tokens are not supported.
        self.pos_embedding = nn.Embedding(512, hidden_size)
        self.layers = nn.ModuleList([
            TransformerLayer(hidden_size, num_heads, dropout) for _ in range(num_layers)
        ])
        self.final_ln = nn.LayerNorm(hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, padding_mask=None):
        batch_size, seq_len = input_ids.size()
        positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0).expand_as(input_ids)
        x = self.token_embedding(input_ids) + self.pos_embedding(positions)
        x = self.dropout(x)

        # Causal mask: True above the diagonal, so position i cannot attend to j > i.
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1).bool()

        for layer in self.layers:
            x = layer(x, causal_mask, padding_mask)

        x = self.final_ln(x)
        logits = self.head(x)
        return logits
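

# Minimal end-to-end sketch, not part of the original module: run the model on a toy
# batch where the last few positions are treated as padding. The vocabulary size,
# batch shape, and padding cut-off are arbitrary assumptions for the demo.
if __name__ == "__main__":
    model = TransformerModel(vocab_size=1000, hidden_size=128, num_layers=2, num_heads=4)
    model.eval()  # disable dropout for a deterministic forward pass
    input_ids = torch.randint(0, 1000, (2, 16))
    padding_mask = torch.zeros(2, 16, dtype=torch.bool)
    padding_mask[:, 12:] = True  # pretend the last four tokens of each row are padding
    with torch.no_grad():
        logits = model(input_ids, padding_mask=padding_mask)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])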