Update evo_plugin_example.py

evo_plugin_example.py (CHANGED, +13 -10)
@@ -1,31 +1,34 @@
-# evo_plugin_example.py —
-# The app will use YOUR evo_plugin.py if present; otherwise it falls back to this.
+# evo_plugin_example.py — FLAN-T5 stand-in (anti-echo tuned)
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 class _HFSeq2SeqGenerator:
     def __init__(self, model_name: str = "google/flan-t5-small"):
-        # CPU is fine for demos; no GPU required on HF Spaces basic CPU.
         self.device = torch.device("cpu")
         self.tok = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device).eval()
 
     @torch.no_grad()
-    def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.
+    def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.0) -> str:
         inputs = self.tok(prompt, return_tensors="pt").to(self.device)
+
+        # Encourage non-trivial length and reduce repeats
+        min_new = max(48, int(0.4 * max_new_tokens))
         out = self.model.generate(
             **inputs,
-
+            max_length=inputs["input_ids"].shape[1] + int(max_new_tokens),
+            min_length=inputs["input_ids"].shape[1] + int(min_new),
             do_sample=temperature > 0.0,
             temperature=float(max(0.01, temperature)),
             top_p=0.9,
-            num_beams=4,
+            num_beams=4,
             early_stopping=True,
-            no_repeat_ngram_size=3,
+            no_repeat_ngram_size=3,
+            repetition_penalty=1.1,
+            length_penalty=0.1,
         )
-
-
+        text = self.tok.decode(out[0], skip_special_tokens=True).strip()
+        return text
 
 def load_model():
-    # The app calls this to obtain a generator instance.
     return _HFSeq2SeqGenerator()
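The plugin contract is unchanged on both sides of the diff: the app imports load_model() and calls .generate(prompt, ...) on the returned object, preferring a user-supplied evo_plugin.py and falling back to this example (per the removed header comment). A minimal caller sketch, assuming the app resolves the plugin with a plain import-or-fallback (the prompt and resolution logic here are illustrative, not taken from the app):

# Hypothetical caller sketch; the fallback matches the old header comment
# ("The app will use YOUR evo_plugin.py if present; otherwise it falls
# back to this."), but the app's actual resolution logic may differ.
try:
    from evo_plugin import load_model          # user-supplied plugin, if present
except ImportError:
    from evo_plugin_example import load_model  # this fallback module

model = load_model()
print(model.generate("Explain beam search in two sentences.",
                     max_new_tokens=120, temperature=0.0))

With temperature=0.0 the new code keeps do_sample=False, so num_beams=4 runs deterministic beam search; min_length forces a non-trivial completion, and no_repeat_ngram_size=3 plus repetition_penalty=1.1 suppress the echo/repeat behavior the commit title targets. One caveat on the design choice: for encoder-decoder models like T5, max_length and min_length count decoder tokens only, so adding inputs["input_ids"].shape[1] grants a somewhat larger budget than max_new_tokens suggests; recent transformers releases accept max_new_tokens/min_new_tokens directly, which would express the same intent more precisely.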