HemanM committed on
Commit
ffe2489
·
verified ·
1 Parent(s): 3ae1eff

Update evo_plugin_example.py

Browse files
Files changed (1) hide show
  1. evo_plugin_example.py +5 -2
evo_plugin_example.py CHANGED
@@ -12,17 +12,20 @@ class _HFSeq2SeqGenerator:
12
 
13
  @torch.no_grad()
14
  def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.4) -> str:
15
- # Seq2Seq models generate responses that follow instructions better than GPT-2 here.
16
  inputs = self.tok(prompt, return_tensors="pt").to(self.device)
17
  out = self.model.generate(
18
  **inputs,
19
  max_new_tokens=int(max_new_tokens),
20
  do_sample=temperature > 0.0,
21
  temperature=float(max(0.01, temperature)),
22
- top_p=0.95,
 
 
 
23
  )
24
  return self.tok.decode(out[0], skip_special_tokens=True).strip()
25
 
 
26
  def load_model():
27
  # The app calls this to obtain a generator instance.
28
  return _HFSeq2SeqGenerator()
 
12
 
13
  @torch.no_grad()
14
  def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.4) -> str:
 
15
  inputs = self.tok(prompt, return_tensors="pt").to(self.device)
16
  out = self.model.generate(
17
  **inputs,
18
  max_new_tokens=int(max_new_tokens),
19
  do_sample=temperature > 0.0,
20
  temperature=float(max(0.01, temperature)),
21
+ top_p=0.9,
22
+ num_beams=4, # beam search makes it less echo-y
23
+ early_stopping=True,
24
+ no_repeat_ngram_size=3, # avoid repeating phrases
25
  )
26
  return self.tok.decode(out[0], skip_special_tokens=True).strip()
27
 
28
+
29
  def load_model():
30
  # The app calls this to obtain a generator instance.
31
  return _HFSeq2SeqGenerator()