HemanM committed (verified)
Commit 6a117f9 · 1 Parent(s): 2a6371b

Update evo_plugin_example.py

Files changed (1):
  evo_plugin_example.py +5 -1
evo_plugin_example.py CHANGED
@@ -1,15 +1,18 @@
-# evo_plugin_example.py — FLAN-T5 stand-in (better instruction following)
+# evo_plugin_example.py — small, instruction-following stand-in generator
+# The app will use YOUR evo_plugin.py if present; otherwise it falls back to this.
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 class _HFSeq2SeqGenerator:
     def __init__(self, model_name: str = "google/flan-t5-small"):
+        # CPU is fine for demos; no GPU required on HF Spaces basic CPU.
         self.device = torch.device("cpu")
         self.tok = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device).eval()
 
     @torch.no_grad()
     def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.4) -> str:
+        # Seq2Seq models generate responses that follow instructions better than GPT-2 here.
         inputs = self.tok(prompt, return_tensors="pt").to(self.device)
         out = self.model.generate(
             **inputs,
@@ -21,4 +24,5 @@ class _HFSeq2SeqGenerator:
         return self.tok.decode(out[0], skip_special_tokens=True).strip()
 
 def load_model():
+    # The app calls this to obtain a generator instance.
     return _HFSeq2SeqGenerator()
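
Taken together, the diff implies a small plugin contract: the app imports load_model() from evo_plugin.py when that file exists, otherwise from this example, and expects the returned object to expose generate(prompt, max_new_tokens=..., temperature=...) -> str. A minimal sketch of a drop-in evo_plugin.py under that assumption (the echo generator below is hypothetical, not part of the repo):

# evo_plugin.py: hypothetical drop-in plugin satisfying the same contract.
class _EchoGenerator:
    def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.4) -> str:
        # Placeholder: return a cropped echo of the prompt instead of running a real model.
        return prompt[:max_new_tokens]

def load_model():
    # Must mirror evo_plugin_example.load_model(): return a generator instance.
    return _EchoGenerator()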
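
And a sketch of the fallback the new header comment describes; the app's actual loader is not shown in this diff, so the try/except shape here is an assumption:

try:
    from evo_plugin import load_model          # user-supplied plugin, if present
except ImportError:
    from evo_plugin_example import load_model  # fall back to this example

model = load_model()
print(model.generate("Summarize: FLAN-T5 is an instruction-tuned seq2seq model.",
                     max_new_tokens=60))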