# model.py
import math

import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, t):
        """Reshape ``(batch, seq, hidden)`` into ``(batch, heads, seq, head_dim)``."""
        batch_size, seq_len, _ = t.size()
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None, padding_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, hidden_size)``.
            mask: Optional attention mask broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``; entries equal to 1
                (or ``True``) block the corresponding query->key attention.
            padding_mask: Optional boolean tensor of shape ``(batch, seq_len)``;
                ``True`` marks key positions that must not be attended to.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        batch_size, seq_len, _ = x.size()
        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(x))
        v = self._split_heads(self.value(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # Most negative finite value of the score dtype: masks reliably in
        # fp32/bf16 and stays finite (no NaN from inf - inf) under fp16.
        neg_fill = torch.finfo(scores.dtype).min
        if mask is not None:
            scores = scores.masked_fill(mask == 1, neg_fill)
        if padding_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): mask keys for every head/query.
            scores = scores.masked_fill(padding_mask.unsqueeze(1).unsqueeze(2), neg_fill)

        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v).transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
        return self.out(out)


class TransformerLayer(nn.Module):
    """Pre-LayerNorm Transformer block: self-attention followed by a feed-forward net.

    Each sub-layer normalizes its *input* and adds its output back onto the
    un-normalized residual stream::

        x = x + dropout(attn(ln1(x)))
        x = x + ffn(ln2(x))          # ffn ends in its own Dropout

    Args:
        hidden_size: Model width.
        num_heads: Number of attention heads.
        dropout: Dropout probability used in attention, FFN and residual branches.
    """

    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(hidden_size, num_heads, dropout)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.ReLU(),
            nn.Linear(4 * hidden_size, hidden_size),
            nn.Dropout(dropout),
        )
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, padding_mask=None):
        """Run one block; ``mask``/``padding_mask`` are forwarded to attention.

        Returns a tensor with the same shape as ``x``.
        """
        # Normalize only the sub-layer input; the residual path must carry the
        # raw ``x`` so the block can pass information through unchanged.
        x = x + self.dropout(self.attn(self.ln1(x), mask, padding_mask))
        # ``self.ffn`` already ends in Dropout, so no second dropout here.
        x = x + self.ffn(self.ln2(x))
        return x


class TransformerModel(nn.Module):
    """Decoder-only (causally masked) Transformer producing per-token vocab logits.

    Args:
        vocab_size: Number of token ids / output classes.
        hidden_size: Model width.
        num_layers: Number of ``TransformerLayer`` blocks.
        num_heads: Attention heads per block.
        dropout: Dropout probability.
        max_seq_len: Size of the learned positional-embedding table, i.e. the
            longest accepted input sequence (default 512).
    """

    def __init__(self, vocab_size, hidden_size=512, num_layers=6, num_heads=8, dropout=0.1,
                 max_seq_len=512):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(max_seq_len, hidden_size)
        self.layers = nn.ModuleList([
            TransformerLayer(hidden_size, num_heads, dropout) for _ in range(num_layers)
        ])
        self.final_ln = nn.LayerNorm(hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, padding_mask=None):
        """Compute logits for ``input_ids``.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.
            padding_mask: Optional boolean tensor of shape ``(batch, seq_len)``;
                ``True`` marks positions other tokens must not attend to.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        _, seq_len = input_ids.size()
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )

        # (1, seq_len) positions broadcast over the batch dimension.
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.pos_embedding(positions)
        x = self.dropout(x)

        # Strict upper triangle = future positions each query must not see.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1
        ).bool()

        for layer in self.layers:
            x = layer(x, causal_mask, padding_mask)

        x = self.final_ln(x)
        return self.head(x)