# Conv_GPT / model.py
# Uploaded by nnsohamnn — "Upload 4 files", commit 591ec58 (verified).
# model.py
import torch
import torch.nn as nn
import math
class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention split across several heads.

    Queries, keys and values are separate ``hidden_size -> hidden_size``
    linear projections of the same input. Each head attends over
    ``hidden_size // num_heads`` features; the head outputs are concatenated
    and passed through a final output projection.

    Args:
        hidden_size: input/output feature size; must be a multiple of
            ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Keep this creation order fixed: it determines the order in which
        # parameters draw from the RNG at initialisation.
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        # Attention logits are divided by sqrt(head_dim).
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq_len, hidden) into (batch, heads, seq_len, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None, padding_mask=None):
        """Attend every position of ``x`` over the (unmasked) positions of ``x``.

        Args:
            x: tensor of shape (batch, seq_len, hidden_size).
            mask: optional mask broadcastable to (batch, heads, seq_len,
                seq_len); entries equal to 1/True are blocked.
            padding_mask: optional boolean tensor, presumably of shape
                (batch, seq_len); True marks key positions to ignore.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        bsz, length, _ = x.size()
        q, k, v = (
            self._split_heads(proj(x), bsz, length)
            for proj in (self.query, self.key, self.value)
        )
        logits = q @ k.transpose(-2, -1) / self.scale

        # A finite -1e4 instead of -inf keeps the fill value representable in fp16.
        blocked = -1e4
        if mask is not None:
            logits = logits.masked_fill(mask == 1, blocked)
        if padding_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): hide the same keys for every
            # head and every query position.
            key_mask = padding_mask.unsqueeze(1).unsqueeze(2)
            logits = logits.masked_fill(key_mask, blocked)

        weights = self.dropout(torch.softmax(logits, dim=-1))
        merged = (weights @ v).transpose(1, 2).contiguous()
        return self.out(merged.view(bsz, length, self.hidden_size))
class TransformerLayer(nn.Module):
    """One transformer block: multi-head self-attention, then a position-wise FFN.

    NOTE(review): although LayerNorm is applied *before* each sublayer, this is
    not standard pre-LN. ``forward`` overwrites ``x`` with the normalised
    tensor, so each residual connection adds the sublayer output onto the
    *normalised* input instead of the raw input. The effect is that each
    residual sum gets normalised at the start of the next sublayer
    (post-LN-like), and the block's own output is the un-normalised last sum.
    The structure is kept as-is so existing checkpoints produce the same
    outputs.

    Args:
        hidden_size: model width (last dimension of input and output).
        num_heads: number of attention heads; must divide ``hidden_size``.
        dropout: dropout probability used for the attention weights, the
            attention residual branch, and the FFN residual branch.
    """

    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(hidden_size, num_heads, dropout)
        # 4x expansion FFN; its trailing Dropout is the (only) dropout applied
        # to the FFN residual branch.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.ReLU(),
            nn.Linear(4 * hidden_size, hidden_size),
            nn.Dropout(dropout),
        )
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        # Dropout for the attention residual branch.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, padding_mask=None):
        """Apply the block.

        Args:
            x: activations whose last dimension is ``hidden_size``
                (presumably (batch, seq_len, hidden_size)).
            mask: optional attention mask passed to the attention module;
                entries equal to 1/True are blocked.
            padding_mask: optional boolean key-padding mask passed to the
                attention module; True marks keys to ignore.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.ln1(x)
        attn_out = self.attn(x, mask, padding_mask)
        x = x + self.dropout(attn_out)
        x = self.ln2(x)
        # self.ffn already ends in nn.Dropout. Applying self.dropout again would
        # drop the FFN branch twice (effective rate 1 - (1 - p) ** 2).
        ffn_out = self.ffn(x)
        x = x + ffn_out
        return x
class TransformerModel(nn.Module):
    """Causal transformer producing per-position vocabulary logits.

    Pipeline: token embedding + learned absolute position embedding, dropout,
    a stack of ``TransformerLayer`` blocks under a causal mask, a final
    LayerNorm, and a linear head over the vocabulary.

    Args:
        vocab_size: number of token ids; sets the embedding table size and
            the last dimension of the returned logits.
        hidden_size: model width.
        num_layers: number of ``TransformerLayer`` blocks.
        num_heads: attention heads per block.
        dropout: dropout probability for the embeddings and every block.
        max_seq_len: number of learned position embeddings, i.e. the longest
            sequence ``forward`` accepts. The default of 512 matches the
            previously hard-coded value. Other values change the
            ``pos_embedding`` weight shape, so existing checkpoints will
            not load.
    """

    def __init__(self, vocab_size, hidden_size=512, num_layers=6, num_heads=8,
                 dropout=0.1, max_seq_len=512):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(max_seq_len, hidden_size)
        self.layers = nn.ModuleList([
            TransformerLayer(hidden_size, num_heads, dropout) for _ in range(num_layers)
        ])
        self.final_ln = nn.LayerNorm(hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, padding_mask=None):
        """Compute logits for every position of ``input_ids``.

        Args:
            input_ids: integer tensor of shape (batch, seq_len).
            padding_mask: optional mask passed unchanged to every layer; the
                attention in this file expects a boolean (batch, seq_len)
                tensor with True at key positions to ignore.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: if ``seq_len`` exceeds the number of position
                embeddings.
        """
        _, seq_len = input_ids.size()
        # Read the limit from the table itself so it is always consistent
        # with the loaded weights.
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the {max_len} position embeddings"
            )
        # (1, seq_len) positions -> (1, seq_len, hidden); broadcasts over batch.
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.pos_embedding(positions)
        x = self.dropout(x)
        # True strictly above the diagonal: position i cannot attend to j > i.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1
        ).bool()
        for layer in self.layers:
            x = layer(x, causal_mask, padding_mask)
        x = self.final_ln(x)
        return self.head(x)