import argparse
import os
import sys
import shutil
import random
import numpy as np
import time
import copy
import math
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
import transformers
from transformers import GPT2TokenizerFast

# ---------------------------
# Utility masks & helpers
# ---------------------------

def subsequent_mask(size):
    """Boolean mask for autoregressive decoding: True where position i may attend
    to position j (j <= i), False for subsequent positions."""
    attn_shape = (1, size, size)
    mask = torch.triu(torch.ones(attn_shape), diagonal=1) == 0
    return mask


def read_corpus(filename, tokenizer):
    """Tokenise a plain-text corpus into a single long id sequence."""
    seq = []
    with open(filename, "rt") as f:
        for line in f:
            line = line.rstrip("\n")
            tokens = tokenizer(line)
            seq.extend(tokens["input_ids"])
    return seq


# ---------------------------
# Embedding & positional code
# ---------------------------

class Embedder(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embed(x.long())


class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_len: int = 4096, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        # Fixed sinusoidal position encodings, pre-computed up to max_seq_len.
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x * math.sqrt(self.d_model)
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)


class Norm(nn.Module):
    """Layer-norm with learnable gain/bias, written out explicitly (close to nn.LayerNorm)."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.size = d_model
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        return self.alpha * (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias


# ---------------------------
# Attention (Euclidean metric)
# ---------------------------

def euclidean_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, d_k: int, mask=None, dropout=None):
    """Scaled Euclidean-distance attention.

    Attention weights are computed from *negative scaled squared Euclidean distances*:

        score_{ij} = -||q_i - k_j||^2 / sqrt(d_k)

    A softmax over the key dimension then yields the usual attention distribution.
    """
    # q, k, v: (bs, h, len, d_k)
    # Compute ||q||^2 and ||k||^2 terms
    q_norm = (q ** 2).sum(dim=-1, keepdim=True)   # (bs, h, len_q, 1)
    k_norm = (k ** 2).sum(dim=-1).unsqueeze(-2)   # (bs, h, 1, len_k)

    # Pairwise squared distances via (a - b)^2 = a^2 + b^2 - 2ab
    scores = q_norm + k_norm - 2 * torch.matmul(q, k.transpose(-2, -1))  # (bs, h, len_q, len_k)
    scores = -scores / math.sqrt(d_k)  # negate & scale so that *smaller distance => larger score*

    if mask is not None:
        mask = mask.unsqueeze(1)  # broadcast across heads
        scores = scores.masked_fill(mask == 0, -1e9)

    attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)

    output = torch.matmul(attn, v)
    return output


class MultiHeadAttention(nn.Module):
    def __init__(self, heads: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % heads == 0, "d_model must be divisible by heads"
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        # project and split multi-head
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k).transpose(1, 2)  # (bs, h, len, d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        # Euclidean attention
        scores = euclidean_attention(q, k, v, self.d_k, mask, self.dropout)
        # merge heads
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.h * self.d_k)
        return self.out(concat)


# ---------------------------
# Feed-forward & decoder
# ---------------------------

class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear_2(self.dropout(F.relu(self.linear_1(x))))


def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, heads: int, dropout: float = 0.1):
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.attn = MultiHeadAttention(heads, d_model, dropout)
        self.ff = FeedForward(d_model, dropout=dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, trg_mask):
        # Pre-norm residual blocks: masked self-attention, then feed-forward.
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn(x2, x2, x2, trg_mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x2))
        return x


class Decoder(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, N: int, heads: int, dropout: float):
        super().__init__()
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        self.layers = get_clones(DecoderLayer(d_model, heads, dropout), N)
        self.norm = Norm(d_model)

    def forward(self, x, trg_mask):
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, trg_mask)
        return self.norm(x)


class GPT2LM(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, N: int, heads: int, dropout: float,
                 tie_weights: bool = False):
        super().__init__()
        self.decoder = Decoder(vocab_size, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, vocab_size)
        if tie_weights:
            self.out.weight = self.decoder.embed.embed.weight
            print("✅ Tied embeddings enabled.")

    def forward(self, x, mask):
        return self.out(self.decoder(x, mask))
# ---------------------------
# Data batcher
# ---------------------------

def batchify(data, batch_size, seq_len):
    nbatch = len(data) // batch_size
    data = torch.tensor(data[: nbatch * batch_size], dtype=torch.long)
    data = data.view(batch_size, -1)
    for i in range(0, data.size(1) - 1, seq_len):
        seq_len_i = min(seq_len, data.size(1) - 1 - i)
        src = data[:, i : i + seq_len_i]
        tgt = data[:, i + 1 : i + 1 + seq_len_i]
        yield src, tgt
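
# Illustrative helper (added as an example, not used by the training loop): shows
# how batchify lays the flat id sequence out as a (batch_size, -1) matrix and
# yields (src, tgt) chunks, with tgt being src shifted one token to the right.
def _demo_batchify():
    data = list(range(10))  # pretend token ids 0..9
    (src, tgt), *rest = batchify(data, batch_size=2, seq_len=3)
    # data is laid out as [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], so the first chunk is:
    assert torch.equal(src, torch.tensor([[0, 1, 2], [5, 6, 7]]))
    assert torch.equal(tgt, torch.tensor([[1, 2, 3], [6, 7, 8]]))
    return [(src, tgt)] + rest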
# ---------------------------
# Train / eval loops
# ---------------------------

def train_model(model, opt):
    print("Starting training (Euclidean attention)…")
    model.train()
    train_ppls, valid_ppls = [], []

    for epoch in range(opt.epochs):
        total_loss, batches = 0.0, 0
        for src, tgt in batchify(opt.train, opt.batchsize, opt.seqlen):
            src, tgt = src.to(opt.device), tgt.to(opt.device)
            mask = subsequent_mask(src.size(1)).to(opt.device)
            output = model(src, mask)
            loss = F.cross_entropy(output.view(-1, opt.vocab_size), tgt.reshape(-1),
                                   ignore_index=opt.src_pad)
            opt.optimizer.zero_grad()
            loss.backward()
            opt.optimizer.step()
            total_loss += loss.item()
            batches += 1

        avg_loss = total_loss / batches
        train_ppl = math.exp(avg_loss)
        train_ppls.append(train_ppl)
        print(f"Epoch {epoch+1}/{opt.epochs} • Train PPL: {train_ppl:.2f}")

        valid_ppl = evaluate(model, opt.valid, opt, tag=f"valid-e{epoch+1}")
        valid_ppls.append(valid_ppl)

    # --- bookkeeping ---
    dir_name = os.path.join("saved", opt.dir_name)
    os.makedirs(dir_name, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(dir_name, "gpt2lm_euclid.pth"))

    plt.plot(range(1, opt.epochs + 1), train_ppls, label="Train PPL")
    plt.plot(range(1, opt.epochs + 1), valid_ppls, label="Valid PPL")
    plt.xlabel("Epoch")
    plt.ylabel("Perplexity")
    plt.title("Euclidean-Attention GPT-2 on WikiText-2")
    plt.legend()
    plt.savefig(os.path.join(dir_name, "learning_curve.png"))
    plt.close()

    with open(os.path.join(dir_name, "perplexity_log.txt"), "w") as f:
        for i in range(opt.epochs):
            f.write(f"Epoch {i+1}: Train {train_ppls[i]:.2f} Valid {valid_ppls[i]:.2f}\n")


def evaluate(model, data, opt, tag="valid"):
    model.eval()
    total_loss, batches = 0.0, 0
    with torch.no_grad():
        for src, tgt in batchify(data, opt.batchsize, opt.seqlen):
            src, tgt = src.to(opt.device), tgt.to(opt.device)
            mask = subsequent_mask(src.size(1)).to(opt.device)
            output = model(src, mask)
            loss = F.cross_entropy(output.view(-1, opt.vocab_size), tgt.reshape(-1),
                                   ignore_index=opt.src_pad)
            total_loss += loss.item()
            batches += 1
    ppl = math.exp(total_loss / batches)
    print(f"{tag.capitalize()} PPL: {ppl:.2f}")
    model.train()
    return ppl


# ---------------------------
# Main entry
# ---------------------------

def main():
    random.seed(10)

    parser = argparse.ArgumentParser()
    parser.add_argument("-no_cuda", action="store_true")
    parser.add_argument("-epochs", type=int, default=20)
    parser.add_argument("-d_model", type=int, default=512)
    parser.add_argument("-n_layers", type=int, default=6)
    parser.add_argument("-heads", type=int, default=8)
    parser.add_argument("-dropout", type=float, default=0.1)
    parser.add_argument("-batchsize", type=int, default=1)
    parser.add_argument("-lr", type=float, default=1e-5)
    parser.add_argument("-seqlen", type=int, default=512)
    parser.add_argument("-tied", type=int, default=1)
    parser.add_argument("-dir_name", type=str, default="model_euclid")
    opt = parser.parse_args()

    opt.device = torch.device("cuda:0" if (not opt.no_cuda and torch.cuda.is_available()) else "cpu")

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    opt.train = read_corpus("wiki2.train.txt", tokenizer)
    opt.valid = read_corpus("wiki2.valid.txt", tokenizer)
    opt.test = read_corpus("wiki2.test.txt", tokenizer)

    opt.vocab_size = 50257
    opt.src_pad = opt.trg_pad = 0

    model = GPT2LM(opt.vocab_size, opt.d_model, opt.n_layers, opt.heads, opt.dropout,
                   tie_weights=(opt.tied == 1)).to(opt.device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.1f}M")

    opt.optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9)

    train_model(model, opt)
    evaluate(model, opt.test, opt, tag="test")


if __name__ == "__main__":
    main()
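
# Example invocation (the script file name is assumed here; use whatever this file
# is called). The WikiText-2 files read by read_corpus above must sit in the
# working directory:
#
#   python train_gpt2_euclid.py -epochs 10 -batchsize 8 -seqlen 256 -dir_name model_euclid
#
# The checkpoint, learning-curve plot and perplexity log are written to saved/<dir_name>/.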