Update evo_plugin_example.py
evo_plugin_example.py  CHANGED  (+30 −14)
@@ -1,4 +1,4 @@
-# evo_plugin_example.py — FLAN-T5 stand-in (
+# evo_plugin_example.py — FLAN-T5 stand-in (truncation + clean kwargs)
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
@@ -7,28 +7,44 @@ class _HFSeq2SeqGenerator:
         self.device = torch.device("cpu")
         self.tok = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device).eval()
+        # FLAN-T5 encoder max length
+        ml = getattr(self.tok, "model_max_length", 512) or 512
+        # Some tokenizers report a huge sentinel value; clamp to 512 for T5-small
+        self.max_src_len = min(512, int(ml if ml < 10000 else 512))
 
     @torch.no_grad()
     def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.0) -> str:
-        …
+        # TRUNCATE input to model's max encoder length
+        inputs = self.tok(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=self.max_src_len,
+        ).to(self.device)
 
-        …
-        …
-        …
-        …
-        …
-            min_length=inputs["input_ids"].shape[1] + int(min_new),
-            do_sample=temperature > 0.0,
-            temperature=float(max(0.01, temperature)),
-            top_p=0.9,
-            num_beams=4,
+        do_sample = float(temperature) > 0.0
+
+        gen_kwargs = dict(
+            max_new_tokens=int(max_new_tokens),
+            num_beams=4,  # stable, less echo
             early_stopping=True,
             no_repeat_ngram_size=3,
             repetition_penalty=1.1,
             length_penalty=0.1,
         )
-        …
-        …
+        # Only include sampling args when sampling is ON (silences warnings)
+        if do_sample:
+            gen_kwargs.update(
+                do_sample=True,
+                temperature=float(max(0.01, temperature)),
+                top_p=0.9,
+            )
+
+        # Encourage non-trivial length without tying to input length
+        gen_kwargs["min_new_tokens"] = max(48, int(0.4 * max_new_tokens))
+
+        out = self.model.generate(**inputs, **gen_kwargs)
+        return self.tok.decode(out[0], skip_special_tokens=True).strip()
 
 def load_model():
     return _HFSeq2SeqGenerator()
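The `ml < 10000` guard added in `__init__` exists because tokenizers without a configured encoder limit report a huge sentinel value (on the order of 1e30 in transformers) from `model_max_length`. A minimal sketch of that check in isolation, assuming the `t5-small` tokenizer as a stand-in for whatever FLAN-T5 checkpoint the plugin actually loads:

# clamp_check.py — standalone look at the model_max_length clamp
# (assumes "t5-small"; any T5-family tokenizer should behave the same)
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
ml = getattr(tok, "model_max_length", 512) or 512

# Tokenizers with no configured limit report a sentinel around 1e30;
# anything above the 10000 cutoff is treated as "unknown" and clamped.
max_src_len = min(512, int(ml if ml < 10000 else 512))
print(ml, "->", max_src_len)  # expected: 512 -> 512 for t5-small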
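As a quick smoke test of the new code paths, here is a minimal sketch of how the plugin contract might be exercised, assuming the file is importable as `evo_plugin_example` (the default `model_name` falls outside the shown hunks, so the checkpoint it downloads is whatever the constructor defaults to):

# smoke_test.py — hypothetical driver for the plugin (module name assumed)
from evo_plugin_example import load_model

gen = load_model()

# temperature == 0.0 takes the beam-search path: do_sample/temperature/top_p
# never enter gen_kwargs, so transformers emits no unused-argument warnings.
print(gen.generate("Summarize: FLAN-T5 is an instruction-tuned T5 model.",
                   max_new_tokens=96, temperature=0.0))

# temperature > 0.0 switches on sampling (do_sample=True, top_p=0.9).
print(gen.generate("Write one sentence about autumn.",
                   max_new_tokens=96, temperature=0.7))

Either call should produce at least `min_new_tokens` output tokens, here max(48, int(0.4 * 96)) = 48, without the old `min_length` coupling to the input length.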