dlaima committed on
Commit
fa712b5
·
verified ·
1 Parent(s): 9e16e60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -31,7 +31,6 @@ SYSTEM_PROMPT = (
31
  "Never say 'the answer is...'. Only return the answer.\n"
32
  )
33
 
34
- # Local wrapper for facebook/bart-base that exposes generate()
35
  class LocalBartModel:
36
  def __init__(self, model_name="facebook/bart-base"):
37
  self.tokenizer = BartTokenizer.from_pretrained(model_name)
@@ -40,12 +39,18 @@ class LocalBartModel:
40
  self.model.to(self.device)
41
 
42
  def generate(self, input_ids, **generate_kwargs):
43
- return self.model.generate(input_ids.to(self.device), **generate_kwargs)
 
 
 
 
44
 
45
  def __call__(self, prompt: str) -> str:
46
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
 
 
47
  output_ids = self.generate(
48
- inputs.input_ids,
49
  max_length=100,
50
  num_beams=5,
51
  early_stopping=True
 
31
  "Never say 'the answer is...'. Only return the answer.\n"
32
  )
33
 
 
34
  class LocalBartModel:
35
  def __init__(self, model_name="facebook/bart-base"):
36
  self.tokenizer = BartTokenizer.from_pretrained(model_name)
 
39
  self.model.to(self.device)
40
 
41
  def generate(self, input_ids, **generate_kwargs):
42
+ # Defensive: convert list input_ids to tensor if needed
43
+ if isinstance(input_ids, list):
44
+ input_ids = torch.tensor(input_ids)
45
+ input_ids = input_ids.to(self.device)
46
+ return self.model.generate(input_ids, **generate_kwargs)
47
 
48
  def __call__(self, prompt: str) -> str:
49
+ inputs = self.tokenizer(prompt, return_tensors="pt")
50
+ input_ids = inputs.input_ids # tensor here
51
+
52
  output_ids = self.generate(
53
+ input_ids,
54
  max_length=100,
55
  num_beams=5,
56
  early_stopping=True