# evo_plugin.py — robust path fallback + clear errors
import os
import re

import torch

# Optional SentencePiece tokenizer
try:
    import sentencepiece as spm
    _HAS_SPM = True
except Exception:
    _HAS_SPM = False

from evo_core_gpt import EvoGPT  # core decoder model, defined in evo_core_gpt.py

SEARCH_PATHS_WEIGHTS = [
    os.environ.get("EVO_DECODER_PATH", ""),   # allow override
    "models/evo_decoder.pt",
    "evo_decoder.pt",                         # root fallback
]
SEARCH_PATHS_SPM = [
    os.environ.get("EVO_SPM_PATH", ""),
    "models/evo_tokenizer.model",
    "evo_tokenizer.model",                    # root fallback
]

def _first_existing(paths):
    for p in paths:
        if p and os.path.exists(p):
            return p
    return None

def _infer_config(sd):
    vocab_size, d_model = sd["token_emb.weight"].shape
    n_positions = sd["pos_emb.weight"].shape[0]
    layer_ids = sorted({int(re.findall(r"blocks\.(\d+)\.", k)[0])
                        for k in sd.keys() if k.startswith("blocks.")})
    n_layers = 1 + max(layer_ids) if layer_ids else 1
    return vocab_size, d_model, n_layers, n_positions
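
# Illustration of the checkpoint layout _infer_config assumes (key names match
# the code above; the example shapes are made up):
#   sd["token_emb.weight"].shape -> (vocab_size, d_model),   e.g. (32000, 512)
#   sd["pos_emb.weight"].shape   -> (n_positions, d_model),  e.g. (1024, 512)
#   "blocks.0.*", "blocks.1.*", ... -> n_layers = max block index + 1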

class _SPTokenizer:
    def __init__(self, spm_path):
        if not _HAS_SPM:
            raise RuntimeError("sentencepiece not installed; add 'sentencepiece' to requirements.txt")
        self.sp = spm.SentencePieceProcessor(model_file=spm_path)
        self.bos_id = self.sp.bos_id() if self.sp.bos_id() >= 0 else None
        self.eos_id = self.sp.eos_id() if self.sp.eos_id() >= 0 else None
    def encode(self, text): return self.sp.encode(text, out_type=int)
    def decode(self, ids):  return self.sp.decode(ids)
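
# Round-trip sketch (assumes a trained SentencePiece model file is present;
# the path below is just the default search location):
#   tok = _SPTokenizer("models/evo_tokenizer.model")
#   ids = tok.encode("hello world")  # -> list of int piece ids
#   tok.decode(ids)                  # -> "hello world"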

class EvoTextGenerator:
    def __init__(self, num_heads: int | None = None, device: str = "cpu"):
        self.device = torch.device(device)

        # 1) find weights
        weights_path = _first_existing(SEARCH_PATHS_WEIGHTS)
        if not weights_path:
            raise FileNotFoundError("evo_decoder.pt not found. Put it at models/evo_decoder.pt or repo root, or set EVO_DECODER_PATH.")
        sd = torch.load(weights_path, map_location="cpu")

        # 2) infer config and heads
        vocab_size, d_model, n_layers, n_positions = _infer_config(sd)
        if num_heads is None:
            # choose a divisor of d_model (8 is a safe default for many small models)
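            # (e.g. d_model=512 -> picks 8; d_model=300 -> skips 8, picks 12)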
            for h in (8, 12, 16, 4):
                if d_model % h == 0:
                    num_heads = h
                    break
            else:
                raise ValueError(f"Pick num_heads that divides d_model={d_model}")

        # 3) build model & load state
        self.model = EvoGPT(vocab_size, d_model, n_layers, n_positions, num_heads=num_heads).to(self.device)
        missing, unexpected = self.model.load_state_dict(sd, strict=False)
        if missing or unexpected:
            print("[EvoGPT] load_state_dict notice -> missing:", missing, "unexpected:", unexpected)
        self.model.eval()

        # 4) tokenizer (optional until you add it)
        spm_path = _first_existing(SEARCH_PATHS_SPM)
        if spm_path:
            self.tok = _SPTokenizer(spm_path)
        else:
            # No tokenizer yet → raise a clear error so evo_inference falls back to FLAN
            raise FileNotFoundError(
                "Tokenizer not found (expected models/evo_tokenizer.model or EVO_SPM_PATH). "
                "Upload the exact SentencePiece model used in training."
            )

    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.4) -> str:
        ids = self.tok.encode(prompt)
        if self.tok.bos_id is not None and (not ids or ids[0] != self.tok.bos_id):
            ids = [self.tok.bos_id] + ids
        input_ids = torch.tensor([ids], dtype=torch.long, device=self.device)
        out_ids = self.model.generate(
            input_ids,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            eos_token_id=self.tok.eos_id,
        )[0].tolist()
        return self.tok.decode(out_ids)

def load_model():
    # evo_inference will try this; if it raises (e.g., no tokenizer), it falls back to FLAN
    return EvoTextGenerator()
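

if __name__ == "__main__":
    # Minimal smoke test, assuming the weights and the SentencePiece model are
    # already in place. The prompt and sampling settings are arbitrary.
    gen = load_model()
    print(gen.generate("Hello, Evo.", max_new_tokens=50, temperature=0.4))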