# evo_plugin.py - robust path fallback + clear errors
import os
import re
from typing import Optional

import torch

# Optional SentencePiece tokenizer
try:
    import sentencepiece as spm
    _HAS_SPM = True
except Exception:
    _HAS_SPM = False

from evo_core_gpt import EvoGPT  # you already added this
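
# Search order for both artifacts: env-var override first, then models/, then repo root.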
SEARCH_PATHS_WEIGHTS = [
    os.environ.get("EVO_DECODER_PATH", ""),  # allow override
    "models/evo_decoder.pt",
    "evo_decoder.pt",  # root fallback
]
SEARCH_PATHS_SPM = [
    os.environ.get("EVO_SPM_PATH", ""),
    "models/evo_tokenizer.model",
    "evo_tokenizer.model",  # root fallback
]


def _first_existing(paths):
    for p in paths:
        if p and os.path.exists(p):
            return p
    return None

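
# Infer the architecture from checkpoint tensor shapes alone: token_emb.weight is
# (vocab_size, d_model), pos_emb.weight is (n_positions, d_model) for a learned
# positional table, and the highest "blocks.<i>." index gives the layer count.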
def _infer_config(sd):
    vocab_size, d_model = sd["token_emb.weight"].shape
    n_positions = sd["pos_emb.weight"].shape[0]
    layer_ids = sorted({int(re.findall(r"blocks\.(\d+)\.", k)[0])
                        for k in sd.keys() if k.startswith("blocks.")})
    n_layers = 1 + max(layer_ids) if layer_ids else 1
    return vocab_size, d_model, n_layers, n_positions

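
# Thin SentencePiece wrapper. sp.bos_id()/sp.eos_id() return -1 when the model was
# trained without those pieces; normalize that to None so callers can test for it.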
class _SPTokenizer:
    def __init__(self, spm_path):
        if not _HAS_SPM:
            raise RuntimeError("sentencepiece not installed; add 'sentencepiece' to requirements.txt")
        self.sp = spm.SentencePieceProcessor(model_file=spm_path)
        self.bos_id = self.sp.bos_id() if self.sp.bos_id() >= 0 else None
        self.eos_id = self.sp.eos_id() if self.sp.eos_id() >= 0 else None

    def encode(self, text):
        return self.sp.encode(text, out_type=int)

    def decode(self, ids):
        return self.sp.decode(ids)


class EvoTextGenerator:
    def __init__(self, num_heads: Optional[int] = None, device: str = "cpu"):
        self.device = torch.device(device)
        # 1) find weights
        weights_path = _first_existing(SEARCH_PATHS_WEIGHTS)
        if not weights_path:
            raise FileNotFoundError(
                "evo_decoder.pt not found. Put it at models/evo_decoder.pt or repo root, "
                "or set EVO_DECODER_PATH."
            )
        sd = torch.load(weights_path, map_location="cpu")
        # 2) infer config and heads
        vocab_size, d_model, n_layers, n_positions = _infer_config(sd)
        if num_heads is None:
            # choose a divisor of d_model (8 is a safe default for many small models)
            for h in (8, 12, 16, 4):
                if d_model % h == 0:
                    num_heads = h
                    break
            else:
                raise ValueError(f"Pick num_heads that divides d_model={d_model}")
        # 3) build model & load state
        self.model = EvoGPT(vocab_size, d_model, n_layers, n_positions, num_heads=num_heads).to(self.device)
        missing, unexpected = self.model.load_state_dict(sd, strict=False)
        if missing or unexpected:
            print("[EvoGPT] load_state_dict notice -> missing:", missing, "unexpected:", unexpected)
        self.model.eval()
        # 4) tokenizer (optional until you add it)
        spm_path = _first_existing(SEARCH_PATHS_SPM)
        if spm_path:
            self.tok = _SPTokenizer(spm_path)
        else:
            # No tokenizer yet -> raise a clear error so evo_inference falls back to FLAN
            raise FileNotFoundError(
                "Tokenizer not found (expected models/evo_tokenizer.model or EVO_SPM_PATH). "
                "Upload the exact SentencePiece model used in training."
            )
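
    # Sampling itself is delegated to EvoGPT.generate (assumed to accept
    # max_new_tokens, temperature, and eos_token_id, per the call below); this
    # wrapper only handles SentencePiece encoding and BOS/EOS bookkeeping.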
    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.4) -> str:
        ids = self.tok.encode(prompt)
        if self.tok.bos_id is not None and (not ids or ids[0] != self.tok.bos_id):
            ids = [self.tok.bos_id] + ids
        input_ids = torch.tensor([ids], dtype=torch.long, device=self.device)
        out_ids = self.model.generate(
            input_ids,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            eos_token_id=self.tok.eos_id,
        )[0].tolist()
        return self.tok.decode(out_ids)


def load_model():
    # evo_inference will try this; if it raises (e.g., no tokenizer), it falls back to FLAN
    return EvoTextGenerator()
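

# Illustrative smoke test, not part of the plugin contract; assumes both
# evo_decoder.pt and evo_tokenizer.model are present at one of the searched paths.
if __name__ == "__main__":
    gen = load_model()
    print(gen.generate("Summarise the governance brief in two sentences.", max_new_tokens=64))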