dlaima commited on
Commit
abf7526
·
verified ·
1 Parent(s): fa712b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -38,19 +38,28 @@ class LocalBartModel:
38
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  self.model.to(self.device)
40
 
41
- def generate(self, input_ids, **generate_kwargs):
42
- # Defensive: convert list input_ids to tensor if needed
43
- if isinstance(input_ids, list):
44
- input_ids = torch.tensor(input_ids)
 
 
 
45
  input_ids = input_ids.to(self.device)
46
- return self.model.generate(input_ids, **generate_kwargs)
 
 
 
 
 
 
 
47
 
48
  def __call__(self, prompt: str) -> str:
49
  inputs = self.tokenizer(prompt, return_tensors="pt")
50
- input_ids = inputs.input_ids # tensor here
51
 
52
  output_ids = self.generate(
53
- input_ids,
54
  max_length=100,
55
  num_beams=5,
56
  early_stopping=True
 
38
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  self.model.to(self.device)
40
 
41
def generate(self, inputs, **generate_kwargs):
    """Run the wrapped model's ``generate`` on tokenizer output.

    Accepts either a tokenizer output mapping (``BatchEncoding`` / dict
    with ``"input_ids"`` and optionally ``"attention_mask"``) or — for
    backward compatibility with callers written against the previous
    signature — a bare ``input_ids`` tensor or nested list of token ids.

    Args:
        inputs: Tokenizer output mapping, a ``torch.Tensor`` of token ids,
            or a list convertible to one.
        **generate_kwargs: Forwarded verbatim to ``self.model.generate``
            (e.g. ``max_length``, ``num_beams``).

    Returns:
        The tensor of generated token ids from ``self.model.generate``.

    Raises:
        ValueError: If a mapping is given but has no ``"input_ids"`` entry.
    """
    # Backward compatibility: earlier callers passed input_ids directly,
    # sometimes as a plain list (see the pre-refactor defensive branch).
    if isinstance(inputs, list):
        inputs = torch.tensor(inputs)
    if torch.is_tensor(inputs):
        input_ids, attention_mask = inputs, None
    else:
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")

    if input_ids is None:
        raise ValueError("input_ids missing from tokenizer output")

    # Move everything onto the model's device before generation.
    input_ids = input_ids.to(self.device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(self.device)

    return self.model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        **generate_kwargs,
    )
57
 
58
  def __call__(self, prompt: str) -> str:
59
  inputs = self.tokenizer(prompt, return_tensors="pt")
 
60
 
61
  output_ids = self.generate(
62
+ inputs,
63
  max_length=100,
64
  num_beams=5,
65
  early_stopping=True