# evo_core_gpt.py — minimal decoder-only LM whose module names match the checkpoint's state_dict keys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    Submodule names (``qkv_proj``, ``out_proj``) match the checkpoint's
    state_dict keys.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        # Raise rather than assert so the check survives ``python -O``.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # names match the checkpoint's state_dict:
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def forward(self, x, attn_mask=None):
        """Apply causal multi-head attention to ``x``.

        Args:
            x: tensor of shape (B, T, C); C must equal ``d_model``.
            attn_mask: optional additive mask. It is added to the
                (B, H, T, T) attention scores after the causal mask, so it
                must broadcast to that shape.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.shape
        qkv = self.qkv_proj(x)  # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, T, C)

        # split into heads
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)

        # scaled dot-product scores
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, T, T)
        # causal mask: position t may only attend to positions <= t
        causal = torch.ones((T, T), device=x.device, dtype=torch.bool).tril()
        att = att.masked_fill(~causal, float("-inf"))
        if attn_mask is not None:
            att = att + attn_mask
        att = F.softmax(att, dim=-1)

        y = att @ v  # (B, H, T, Hd)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge heads -> (B, T, C)
        return self.out_proj(y)


class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear."""

    def __init__(self, d_model: int, mult: int = 4):
        super().__init__()
        # Layout must give state_dict keys ffn.net.0 (Linear) and ffn.net.3 (Linear).
        # The p=0.0 Dropout is a no-op kept only so the second Linear lands at index 3.
        self.net = nn.Sequential(
            nn.Linear(d_model, mult * d_model, bias=True),
            nn.GELU(),
            nn.Dropout(p=0.0),
            nn.Linear(mult * d_model, d_model, bias=True),
        )

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        return self.net(x)


class Block(nn.Module):
    """Pre-norm transformer block: attention and FFN sublayers with residuals."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        # names must match: blocks.N.ln1/ln2, .attn.*, .ffn.*
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = SelfAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model)

    def forward(self, x):
        """Return ``x`` updated by the attention and FFN residual branches."""
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class EvoGPT(nn.Module):
    """Decoder-only language model with learned positions and tied output weights.

    Args:
        vocab_size: number of token ids.
        d_model: embedding / hidden width.
        n_layers: number of transformer blocks.
        n_positions: size of the learned positional table; the maximum
            sequence length ``forward`` accepts.
        num_heads: attention heads per block.
    """

    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_positions: int, num_heads: int):
        super().__init__()
        # Plain attribute (not a parameter/buffer), so it does not affect state_dict keys.
        self.n_positions = n_positions
        # names must match the checkpoint:
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(n_positions, d_model)
        self.blocks = nn.ModuleList([Block(d_model, num_heads) for _ in range(n_layers)])
        # (no final ln_f in the checkpoint keys; outputs are projected by tying to token_emb)
        # Non-persistent buffer: follows .to(device) without adding a state_dict key.
        self.register_buffer("pos_idx", torch.arange(0, n_positions).long(), persistent=False)

    def forward(self, input_ids):
        """Compute next-token logits for every position.

        Args:
            input_ids: integer token ids of shape (B, T) with T <= n_positions.

        Returns:
            Logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds ``n_positions`` (the positional table
                has no entry for the extra positions).
        """
        _, T = input_ids.shape
        if T > self.n_positions:
            raise ValueError(
                f"sequence length {T} exceeds n_positions={self.n_positions}"
            )
        pos = self.pos_idx[:T]
        x = self.token_emb(input_ids) + self.pos_emb(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        # weight tying: logits = x @ token_emb^T
        logits = x @ self.token_emb.weight.t()
        return logits

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=128, temperature=0.7, eos_token_id=None):
        """Autoregressively extend ``input_ids`` by up to ``max_new_tokens`` tokens.

        Only the last ``n_positions`` tokens are fed to the model at each step,
        so generation can run past the positional table's length.

        Args:
            input_ids: integer token ids of shape (B, T).
            max_new_tokens: maximum number of tokens to append.
            temperature: sampling temperature. A falsy or non-positive value
                selects greedy (argmax) decoding; positive values are
                clamped to at least 0.01.
            eos_token_id: if given, a row that emits this id is finished.
                Later positions in that row are filled with the same id, and
                generation stops early once every row is finished.

        Returns:
            Tensor of shape (B, T + k) with k <= ``max_new_tokens``.

        The module's previous train/eval mode is restored on return.
        """
        was_training = self.training
        self.eval()
        try:
            finished = None
            if eos_token_id is not None:
                finished = torch.zeros(
                    input_ids.size(0), dtype=torch.bool, device=input_ids.device
                )
            for _ in range(int(max_new_tokens)):
                # Crop to the positional table so long generations do not overflow it.
                context = input_ids[:, -self.n_positions:]
                logits = self.forward(context)[:, -1, :]  # (B, V)
                if temperature and temperature > 0.0:
                    probs = F.softmax(logits / max(0.01, temperature), dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)
                else:
                    next_id = torch.argmax(logits, dim=-1, keepdim=True)
                if finished is not None:
                    # Rows that already produced EOS keep emitting EOS (padding).
                    next_id = next_id.masked_fill(finished[:, None], eos_token_id)
                    finished = finished | (next_id.squeeze(-1) == eos_token_id)
                input_ids = torch.cat([input_ids, next_id], dim=1)
                if finished is not None and bool(finished.all()):
                    break
        finally:
            self.train(was_training)
        return input_ids