# evo_plugin_example.py — FLAN-T5 stand-in (truncation + clean kwargs)
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class _HFSeq2SeqGenerator:
    def __init__(self, model_name: str = "google/flan-t5-small"):
        self.device = torch.device("cpu")
        self.tok = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device).eval()
        # FLAN-T5 encoder max length
        ml = getattr(self.tok, "model_max_length", 512) or 512
        # Some tokenizers report a huge sentinel value; clamp to 512 for T5-small
        self.max_src_len = min(512, int(ml if ml < 10000 else 512))

    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.0) -> str:
        # TRUNCATE input to the model's max encoder length
        inputs = self.tok(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_src_len,
        ).to(self.device)
        do_sample = float(temperature) > 0.0
        gen_kwargs = dict(
            max_new_tokens=int(max_new_tokens),
            num_beams=4,  # stable, less echo
            early_stopping=True,
            no_repeat_ngram_size=3,
            repetition_penalty=1.1,
            length_penalty=0.1,
        )
        # Only include sampling args when sampling is ON (silences warnings)
        if do_sample:
            gen_kwargs.update(
                do_sample=True,
                temperature=float(max(0.01, temperature)),
                top_p=0.9,
            )
        # Encourage non-trivial length without tying it to input length
        gen_kwargs["min_new_tokens"] = max(48, int(0.4 * max_new_tokens))
        out = self.model.generate(**inputs, **gen_kwargs)
        return self.tok.decode(out[0], skip_special_tokens=True).strip()


def load_model():
    return _HFSeq2SeqGenerator()
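

# Minimal usage sketch (illustrative only; the prompt text and argument values
# below are assumptions, not part of the plugin contract beyond load_model/generate).
if __name__ == "__main__":
    gen = load_model()
    # temperature=0.0 keeps sampling off, so generation uses beam search only
    print(
        gen.generate(
            "Summarize: The quick brown fox jumps over the lazy dog.",
            max_new_tokens=64,
            temperature=0.0,
        )
    )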