HemanM committed (verified)
Commit 1794bf5 · 1 Parent(s): 1861d4a

Update evo_plugin_example.py

Files changed (1)
  1. evo_plugin_example.py +30 -14
evo_plugin_example.py CHANGED
@@ -1,4 +1,4 @@
-# evo_plugin_example.py — FLAN-T5 stand-in (anti-echo tuned)
+# evo_plugin_example.py — FLAN-T5 stand-in (truncation + clean kwargs)
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
@@ -7,28 +7,44 @@ class _HFSeq2SeqGenerator:
         self.device = torch.device("cpu")
         self.tok = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device).eval()
+        # FLAN-T5 encoder max length
+        ml = getattr(self.tok, "model_max_length", 512) or 512
+        # Some tokenizers report a huge sentinel value; clamp to 512 for T5-small
+        self.max_src_len = min(512, int(ml if ml < 10000 else 512))
 
     @torch.no_grad()
     def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.0) -> str:
-        inputs = self.tok(prompt, return_tensors="pt").to(self.device)
+        # TRUNCATE input to model's max encoder length
+        inputs = self.tok(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=self.max_src_len,
+        ).to(self.device)
 
-        # Encourage non-trivial length and reduce repeats
-        min_new = max(48, int(0.4 * max_new_tokens))
-        out = self.model.generate(
-            **inputs,
-            max_length=inputs["input_ids"].shape[1] + int(max_new_tokens),
-            min_length=inputs["input_ids"].shape[1] + int(min_new),
-            do_sample=temperature > 0.0,
-            temperature=float(max(0.01, temperature)),
-            top_p=0.9,
-            num_beams=4,
+        do_sample = float(temperature) > 0.0
+
+        gen_kwargs = dict(
+            max_new_tokens=int(max_new_tokens),
+            num_beams=4,  # stable, less echo
             early_stopping=True,
             no_repeat_ngram_size=3,
             repetition_penalty=1.1,
             length_penalty=0.1,
         )
-        text = self.tok.decode(out[0], skip_special_tokens=True).strip()
-        return text
+        # Only include sampling args when sampling is ON (silences warnings)
+        if do_sample:
+            gen_kwargs.update(
+                do_sample=True,
+                temperature=float(max(0.01, temperature)),
+                top_p=0.9,
+            )
+
+        # Encourage non-trivial length without tying to input length
+        gen_kwargs["min_new_tokens"] = max(48, int(0.4 * max_new_tokens))
+
+        out = self.model.generate(**inputs, **gen_kwargs)
+        return self.tok.decode(out[0], skip_special_tokens=True).strip()
 
 def load_model():
     return _HFSeq2SeqGenerator()
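
For reference, a minimal usage sketch of the updated plugin. The default value of `model_name` in `__init__` is not shown in these hunks, so `load_model()` is called with its defaults; the prompt text and the `usage_sketch.py` name are illustrative assumptions, not part of the commit.

# usage_sketch.py: a minimal sketch, assuming evo_plugin_example.py is on the import path
from evo_plugin_example import load_model

gen = load_model()  # builds the tokenizer and FLAN-T5 model on CPU with the class defaults

# temperature=0.0 -> beam search only; sampling kwargs stay out of gen_kwargs
print(gen.generate("Summarize: FLAN-T5 is an instruction-tuned T5 model.", max_new_tokens=64))

# temperature>0.0 -> do_sample, temperature, and top_p are added to gen_kwargs
print(gen.generate("Summarize: FLAN-T5 is an instruction-tuned T5 model.", max_new_tokens=64, temperature=0.7))

The design change this exercises: generation kwargs are built as a dict so sampling arguments are only passed when sampling is enabled, and the minimum length is expressed with min_new_tokens rather than being tied to the (now truncated) input length.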