File size: 1,959 Bytes
1794bf5
e94ab3b
b7222e2
e94ab3b
b7222e2
 
e94ab3b
 
b7222e2
1794bf5
 
 
 
e94ab3b
 
1861d4a
1794bf5
 
 
 
 
 
 
1861d4a
1794bf5
 
 
 
 
ffe2489
1861d4a
 
 
e94ab3b
1794bf5
 
 
 
 
 
 
 
 
 
 
 
 
ffe2489
e94ab3b
b7222e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# evo_plugin_example.py — FLAN-T5 stand-in (truncation + clean kwargs)
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

class _HFSeq2SeqGenerator:
    def __init__(self, model_name: str = "google/flan-t5-small"):
        self.device = torch.device("cpu")
        self.tok = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device).eval()
        # FLAN-T5 encoder max length
        ml = getattr(self.tok, "model_max_length", 512) or 512
        # Some tokenizers report a huge sentinel value; clamp to 512 for T5-small
        self.max_src_len = min(512, int(ml if ml < 10000 else 512))

    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.0) -> str:
        # TRUNCATE input to model's max encoder length
        inputs = self.tok(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_src_len,
        ).to(self.device)

        do_sample = float(temperature) > 0.0

        gen_kwargs = dict(
            max_new_tokens=int(max_new_tokens),
            num_beams=4,                 # stable, less echo
            early_stopping=True,
            no_repeat_ngram_size=3,
            repetition_penalty=1.1,
            length_penalty=0.1,
        )
        # Only include sampling args when sampling is ON (silences warnings)
        if do_sample:
            gen_kwargs.update(
                do_sample=True,
                temperature=float(max(0.01, temperature)),
                top_p=0.9,
            )

        # Encourage non-trivial length without tying to input length
        gen_kwargs["min_new_tokens"] = max(48, int(0.4 * max_new_tokens))

        out = self.model.generate(**inputs, **gen_kwargs)
        return self.tok.decode(out[0], skip_special_tokens=True).strip()

def load_model():
    """Plugin entry point: build the default seq2seq generator.

    Uses the generator's default checkpoint and device.
    """
    generator = _HFSeq2SeqGenerator()
    return generator