# evo-gov-copilot-mu / evo_core_gpt.py
# HemanM's picture
# Create evo_core_gpt.py
# ecf4bc9 verified
# evo_core_gpt.py — minimal decoder-only LM matching your checkpoint keys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    Parameter names (``qkv_proj``, ``out_proj``) match the checkpoint's
    state_dict keys and must not be renamed.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # names match the checkpoint's state_dict:
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def forward(self, x, attn_mask=None):
        """Apply causal multi-head self-attention to ``x``.

        Args:
            x: tensor of shape (B, T, C) with C == d_model.
            attn_mask: optional mask broadcastable to (B, num_heads, T, T).
                A floating-point mask is added to the attention scores
                (use -inf to block a position). A bool mask marks positions
                that may be attended to (True = keep), the same convention as
                ``torch.nn.functional.scaled_dot_product_attention``.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.shape
        qkv = self.qkv_proj(x)  # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, T, C)
        # split channels into heads: (B, T, C) -> (B, H, T, Hd)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        # scaled dot-product scores
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, T, T)
        # causal mask: query position t may only attend to key positions <= t
        causal = torch.ones((T, T), device=x.device, dtype=torch.bool).tril()
        att = att.masked_fill(~causal, float("-inf"))
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Adding a bool tensor would silently add 0/1 to the scores;
                # interpret it as a keep-mask instead.
                att = att.masked_fill(~attn_mask, float("-inf"))
            else:
                att = att + attn_mask
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, H, T, Hd)
        # merge heads back: (B, H, T, Hd) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y)
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    All layers live in ``self.net``, so the state_dict keys are ``net.0.*``
    (expansion Linear, d_model -> mult * d_model) and ``net.3.*`` (projection
    Linear, mult * d_model -> d_model). Slots 1 and 2 hold the parameter-free
    GELU and Dropout.
    """

    def __init__(self, d_model: int, mult: int = 4):
        super().__init__()
        hidden = mult * d_model
        stack = [
            nn.Linear(d_model, hidden, bias=True),  # net.0
            nn.GELU(),                              # net.1
            nn.Dropout(p=0.0),                      # net.2: p=0.0, a no-op; keeps the next Linear at index 3
            nn.Linear(hidden, d_model, bias=True),  # net.3
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` (last dim == d_model) through the feed-forward stack."""
        out = self.net(x)
        return out
class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``h = x + attn(ln1(x))`` and returns ``h + ffn(ln2(h))``.
    Submodule names (ln1, attn, ln2, ffn) match the checkpoint's
    ``blocks.N.*`` state_dict keys.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        # registration order and names must stay as-is to match the checkpoint
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = SelfAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model)

    def forward(self, x):
        """Apply the attention sub-layer, then the FFN sub-layer, each with a residual."""
        attn_out = self.attn(self.ln1(x))
        hidden = x + attn_out
        ffn_out = self.ffn(self.ln2(hidden))
        return hidden + ffn_out
class EvoGPT(nn.Module):
    """Minimal decoder-only language model whose parameter names match the checkpoint.

    Token + learned position embeddings feed ``n_layers`` ``Block``s. Logits
    come from weight tying with ``token_emb``; there is no final LayerNorm and
    no separate LM head.
    """

    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_positions: int, num_heads: int):
        super().__init__()
        # names must match the checkpoint:
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(n_positions, d_model)
        self.blocks = nn.ModuleList([Block(d_model, num_heads) for _ in range(n_layers)])
        # (no final ln_f in the checkpoint keys; logits come from tying to token_emb)
        # non-persistent: position ids are neither saved to nor expected in the state_dict
        self.register_buffer("pos_idx", torch.arange(0, n_positions).long(), persistent=False)

    def forward(self, input_ids):
        """Return logits of shape (B, T, vocab_size) for token ids ``input_ids`` (B, T).

        Raises:
            ValueError: if T exceeds the number of learned positions (n_positions).
        """
        B, T = input_ids.shape
        max_len = self.pos_emb.num_embeddings
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds n_positions={max_len}")
        pos = self.pos_idx[:T]
        x = self.token_emb(input_ids) + self.pos_emb(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        # weight tying: logits = x @ token_emb^T
        logits = x @ self.token_emb.weight.t()
        return logits

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=128, temperature=0.7, eos_token_id=None):
        """Autoregressively extend ``input_ids`` (B, T) by up to ``max_new_tokens`` tokens.

        Args:
            input_ids: integer token ids of shape (B, T), T >= 1.
            max_new_tokens: maximum number of tokens to append.
            temperature: sampling temperature. A falsy or non-positive value
                selects greedy (argmax) decoding. Positive values below 0.01
                are clamped to 0.01.
            eos_token_id: optional end-of-sequence id. After a row emits it,
                that row's later positions are filled with ``eos_token_id``.
                Generation stops early once every row has emitted it.

        Returns:
            Tensor of shape (B, T + n) with n <= max_new_tokens.

        Only the last ``n_positions`` tokens are fed to the model at each step,
        so generation can continue past the positional-embedding limit. The
        module's training/eval mode is restored before returning.
        """
        was_training = self.training
        self.eval()
        max_len = self.pos_emb.num_embeddings
        finished = None
        if eos_token_id is not None:
            # per-row "has emitted EOS" flags, shape (B, 1) to match next_id
            finished = torch.zeros(
                (input_ids.shape[0], 1), dtype=torch.bool, device=input_ids.device
            )
        try:
            for _ in range(int(max_new_tokens)):
                # crop to the positional window; longer contexts would overflow pos_emb
                context = input_ids[:, -max_len:]
                logits = self.forward(context)[:, -1, :]  # (B, V)
                if temperature and temperature > 0.0:
                    probs = F.softmax(logits / max(0.01, temperature), dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)
                else:
                    next_id = torch.argmax(logits, dim=-1, keepdim=True)
                if finished is not None:
                    # rows that already ended keep emitting EOS as padding
                    next_id = next_id.masked_fill(finished, eos_token_id)
                input_ids = torch.cat([input_ids, next_id], dim=1)
                if finished is not None:
                    finished = finished | (next_id == eos_token_id)
                    if finished.all():
                        break
        finally:
            self.train(was_training)
        return input_ids