# evo_core_gpt.py — minimal decoder-only LM matching your checkpoint keys
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # names match your state_dict:
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def forward(self, x, attn_mask=None):
        B, T, C = x.shape
        qkv = self.qkv_proj(x)  # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, T, C)
        # reshape to heads
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)
        # scaled dot-product attention with causal mask
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, T, T)
        causal = torch.ones((T, T), device=x.device, dtype=torch.bool).tril()
        att = att.masked_fill(~causal, float("-inf"))
        if attn_mask is not None:
            att = att + attn_mask
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, H, T, Hd)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)
        return self.out_proj(y)
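

# Optional sanity check, a minimal sketch assuming PyTorch >= 2.0: the hand-rolled masked
# softmax in SelfAttention above should match F.scaled_dot_product_attention with
# is_causal=True up to floating-point tolerance. The helper name and the sizes below are
# arbitrary illustrations; nothing here is required to load the checkpoint.
def _check_attention_matches_sdpa(d_model: int = 64, num_heads: int = 4, B: int = 2, T: int = 8):
    attn = SelfAttention(d_model, num_heads).eval()
    x = torch.randn(B, T, d_model)
    with torch.no_grad():
        # rebuild q/k/v exactly as SelfAttention.forward does
        q, k, v = attn.qkv_proj(x).chunk(3, dim=-1)
        q = q.view(B, T, num_heads, d_model // num_heads).transpose(1, 2)
        k = k.view(B, T, num_heads, d_model // num_heads).transpose(1, 2)
        v = v.view(B, T, num_heads, d_model // num_heads).transpose(1, 2)
        y_ref = attn.out_proj(
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
            .transpose(1, 2)
            .contiguous()
            .view(B, T, d_model)
        )
        assert torch.allclose(attn(x), y_ref, atol=1e-4), "attention does not match SDPA reference"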


class FFN(nn.Module):
    def __init__(self, d_model: int, mult: int = 4):
        super().__init__()
        # match your names: ffn.net.0 (Linear), .3 (Linear), with GELU in between
        # (the p=0.0 Dropout at index 2 is there only so the second Linear lands at index 3)
        self.net = nn.Sequential(
            nn.Linear(d_model, mult * d_model, bias=True),
            nn.GELU(),
            nn.Dropout(p=0.0),
            nn.Linear(mult * d_model, d_model, bias=True),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        # names must match: blocks.N.ln1/ln2, .attn.*, .ffn.*
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = SelfAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model)

    def forward(self, x):
        # pre-norm residual connections around attention and FFN
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class EvoGPT(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_positions: int, num_heads: int):
        super().__init__()
        # names must match your checkpoint:
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(n_positions, d_model)
        self.blocks = nn.ModuleList([Block(d_model, num_heads) for _ in range(n_layers)])
        # (no final ln_f in your keys; the output projection is tied to token_emb)
        self.register_buffer("pos_idx", torch.arange(0, n_positions).long(), persistent=False)

    def forward(self, input_ids):
        B, T = input_ids.shape
        pos = self.pos_idx[:T]
        x = self.token_emb(input_ids) + self.pos_emb(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        # weight tying: logits = x @ token_emb^T
        logits = x @ self.token_emb.weight.t()
        return logits

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=128, temperature=0.7, eos_token_id=None):
        self.eval()
        n_positions = self.pos_idx.numel()
        for _ in range(int(max_new_tokens)):
            # crop the context so positions never run past the learned pos_emb table
            idx_cond = input_ids if input_ids.size(1) <= n_positions else input_ids[:, -n_positions:]
            logits = self.forward(idx_cond)[:, -1, :]  # (B, V)
            if temperature and temperature > 0.0:
                probs = F.softmax(logits / max(0.01, temperature), dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
            else:
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_id], dim=1)
            # stop only once every sequence in the batch has emitted EOS
            if eos_token_id is not None and (next_id == eos_token_id).all():
                break
        return input_ids
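

# Quick smoke test with random weights: a minimal usage sketch. The hyperparameters and
# the checkpoint path below are illustrative placeholders, not values taken from your
# checkpoint; swap in the real ones, then load the state_dict whose keys this module
# mirrors, e.g. model.load_state_dict(torch.load("path/to/checkpoint.pt", map_location="cpu")).
if __name__ == "__main__":
    model = EvoGPT(vocab_size=256, d_model=128, n_layers=2, n_positions=64, num_heads=4)
    prompt = torch.randint(0, 256, (1, 8))  # dummy token ids; a real run needs your tokenizer
    out = model.generate(prompt, max_new_tokens=16, temperature=0.7)
    print(out.shape)  # expected: torch.Size([1, 24]), i.e. 8 prompt tokens + 16 generated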