HemanM committed
Commit 1861d4a · verified · 1 Parent(s): af358ab

Update evo_plugin_example.py

Files changed (1)
  1. evo_plugin_example.py  +13 -10
evo_plugin_example.py CHANGED
@@ -1,31 +1,34 @@
-# evo_plugin_example.py — small, instruction-following stand-in generator
-# The app will use YOUR evo_plugin.py if present; otherwise it falls back to this.
+# evo_plugin_example.py — FLAN-T5 stand-in (anti-echo tuned)
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 class _HFSeq2SeqGenerator:
     def __init__(self, model_name: str = "google/flan-t5-small"):
-        # CPU is fine for demos; no GPU required on HF Spaces basic CPU.
         self.device = torch.device("cpu")
         self.tok = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device).eval()
 
     @torch.no_grad()
-    def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.4) -> str:
+    def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.0) -> str:
         inputs = self.tok(prompt, return_tensors="pt").to(self.device)
+
+        # Encourage non-trivial length and reduce repeats
+        min_new = max(48, int(0.4 * max_new_tokens))
         out = self.model.generate(
             **inputs,
-            max_new_tokens=int(max_new_tokens),
+            max_length=inputs["input_ids"].shape[1] + int(max_new_tokens),
+            min_length=inputs["input_ids"].shape[1] + int(min_new),
             do_sample=temperature > 0.0,
             temperature=float(max(0.01, temperature)),
             top_p=0.9,
-            num_beams=4,  # beam search makes it less echo-y
+            num_beams=4,
             early_stopping=True,
-            no_repeat_ngram_size=3,  # avoid repeating phrases
+            no_repeat_ngram_size=3,
+            repetition_penalty=1.1,
+            length_penalty=0.1,
         )
-        return self.tok.decode(out[0], skip_special_tokens=True).strip()
-
+        text = self.tok.decode(out[0], skip_special_tokens=True).strip()
+        return text
 
 def load_model():
-    # The app calls this to obtain a generator instance.
     return _HFSeq2SeqGenerator()
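
For context, a minimal host-side usage sketch follows. It assumes the fallback behavior described in the comment removed above (the app uses your evo_plugin.py if present, otherwise this example module); the evo_plugin import name and the prompt text are illustrative assumptions, not part of this commit.

try:
    from evo_plugin import load_model  # user-supplied generator, if present (assumed name)
except ImportError:
    from evo_plugin_example import load_model  # fall back to this example module

gen = load_model()

# With the new default temperature=0.0, do_sample is False, so decoding is
# deterministic 4-beam search; min_new works out to max(48, int(0.4 * 200)) = 80.
text = gen.generate(
    "Summarize: FLAN-T5 is an instruction-tuned encoder-decoder model.",
    max_new_tokens=200,
    temperature=0.0,
)
print(text)

Note that the new min_length/max_length bounds are computed as offsets from the prompt's token count, so longer prompts raise both bounds accordingly.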