HemanM commited on
Commit
05084ef
·
verified ·
1 Parent(s): ecf4bc9

Create evo_plugin.py

Browse files
Files changed (1) hide show
  1. evo_plugin.py +78 -0
evo_plugin.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# evo_plugin.py — REAL Evo decoder integration (state_dict -> generation)
import os

import torch
from typing import Optional

# Optional dependency: SentencePiece. The module still imports cleanly
# without it; _SPTokenizer raises a clear RuntimeError on first use instead.
try:
    import sentencepiece as spm
    _HAS_SPM = True
except Exception:
    _HAS_SPM = False

from evo_core_gpt import EvoGPT

# All paths/hyperparameters are overridable via environment variables.
MODEL_PATH = os.environ.get("EVO_DECODER_PATH", "models/evo_decoder.pt")
SPM_PATH = os.environ.get("EVO_SPM_PATH", "models/evo_tokenizer.model")  # SentencePiece file
# Head count cannot be inferred from checkpoint tensor shapes, so it must be
# supplied explicitly. <-- set this to your trained value
NUM_HEADS = int(os.environ.get("EVO_NUM_HEADS", "8"))
18
class _SPTokenizer:
    """Thin wrapper around a SentencePiece processor.

    Exposes encode/decode plus BOS/EOS ids; either id is None when the
    underlying SPM model does not define that special token.
    """

    def __init__(self, spm_path: str):
        if not _HAS_SPM:
            raise RuntimeError("sentencepiece not installed; add 'sentencepiece' to requirements.txt")
        self.sp = spm.SentencePieceProcessor(model_file=spm_path)
        # SentencePiece reports -1 for special tokens it does not define.
        bos = self.sp.bos_id()
        eos = self.sp.eos_id()
        self.bos_id = bos if bos >= 0 else None
        self.eos_id = eos if eos >= 0 else None

    def encode(self, text: str):
        """Tokenize *text* into a list of integer token ids."""
        return self.sp.encode(text, out_type=int)

    def decode(self, ids):
        """Convert a list of token ids back into a string."""
        return self.sp.decode(ids)
32
+
33
class EvoTextGenerator:
    """Load a trained EvoGPT decoder from a raw state_dict and generate text.

    Architecture hyperparameters (vocab size, model width, layer count,
    context length) are inferred from tensor shapes in the checkpoint; the
    attention head count cannot be recovered from shapes, so it is taken
    from *num_heads* (default: EVO_NUM_HEADS env var).
    """

    def __init__(self, weights_path: str = MODEL_PATH, spm_path: str = SPM_PATH,
                 num_heads: int = NUM_HEADS, device: str = "cpu"):
        """Build the model from *weights_path* and the tokenizer from *spm_path*.

        Raises FileNotFoundError if the SentencePiece model is missing, and
        KeyError if the checkpoint lacks the expected embedding tensors.
        """
        import re

        self.device = torch.device(device)

        # 1) Peek shapes from the state_dict to infer the configuration.
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (or pass weights_only=True on
        # torch >= 2.0 when the file is a plain tensor state_dict).
        sd = torch.load(weights_path, map_location="cpu")
        vocab_size, d_model = sd["token_emb.weight"].shape
        n_positions = sd["pos_emb.weight"].shape[0]

        # Count transformer layers by scanning keys of the form "blocks.N.".
        # Keys under "blocks." that lack a numeric index (e.g. a shared norm)
        # are skipped instead of crashing with an IndexError.
        layer_ids = set()
        for key in sd.keys():
            m = re.match(r"blocks\.(\d+)\.", key)
            if m:
                layer_ids.add(int(m.group(1)))
        n_layers = (1 + max(layer_ids)) if layer_ids else 1

        # 2) Build the model and load the weights.
        self.model = EvoGPT(vocab_size, d_model, n_layers, n_positions,
                            num_heads=num_heads).to(self.device)
        missing, unexpected = self.model.load_state_dict(sd, strict=False)
        if missing or unexpected:
            # Both lists should normally be empty; surface mismatches to aid
            # debugging of architecture/checkpoint drift.
            print("[EvoGPT] load_state_dict notice -> missing:", missing, "unexpected:", unexpected)
        self.model.eval()

        # 3) Tokenizer.
        if not os.path.exists(spm_path):
            raise FileNotFoundError(f"Tokenizer not found at {spm_path}. Provide SentencePiece model via models/evo_tokenizer.model or set EVO_SPM_PATH.")
        self.tok = _SPTokenizer(spm_path)

    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.4) -> str:
        """Generate a continuation of *prompt* and return the decoded text.

        The returned text includes the prompt tokens (they are not stripped
        before decoding). Sampling stops after *max_new_tokens* tokens or at
        the tokenizer's EOS id, when one is defined.
        """
        ids = self.tok.encode(prompt)
        # Prepend BOS when the tokenizer defines one and it is not present.
        if self.tok.bos_id is not None and (not ids or ids[0] != self.tok.bos_id):
            ids = [self.tok.bos_id] + ids
        input_ids = torch.tensor([ids], dtype=torch.long, device=self.device)
        out_ids = self.model.generate(
            input_ids,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            eos_token_id=self.tok.eos_id,
        )[0].tolist()
        return self.tok.decode(out_ids)
76
+
77
def load_model():
    """Plugin factory: return an EvoTextGenerator built with env-var defaults."""
    return EvoTextGenerator()