dlaima commited on
Commit
b6e2509
·
verified ·
1 Parent(s): da58915

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -40,9 +40,6 @@ class LocalBartModel:
40
  self.model.eval()
41
 
42
  def __call__(self, prompt: str) -> str:
43
- if not isinstance(prompt, str):
44
- raise ValueError(f"LocalBartModel expects a string prompt, got {type(prompt)}")
45
-
46
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
47
  with torch.no_grad():
48
  outputs = self.model.generate(
@@ -52,8 +49,11 @@ class LocalBartModel:
52
  num_beams=5,
53
  early_stopping=True,
54
  )
55
- output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
56
- return output_text.strip()
 
 
 
57
 
58
  class GaiaAgent:
59
  def __init__(self):
@@ -86,7 +86,7 @@ class GaiaAgent:
86
  elif isinstance(result, str):
87
  return result.strip()
88
  elif isinstance(result, list):
89
- # Try to find assistant response content in list
90
  for item in reversed(result):
91
  if isinstance(item, dict) and item.get("role") == "assistant" and "content" in item:
92
  return item["content"].strip()
 
40
  self.model.eval()
41
 
42
  def __call__(self, prompt: str) -> str:
 
 
 
43
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
44
  with torch.no_grad():
45
  outputs = self.model.generate(
 
49
  num_beams=5,
50
  early_stopping=True,
51
  )
52
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
53
+
54
+ def generate(self, *args, **kwargs):
55
+ # Forward generate calls for ToolCallingAgent compatibility
56
+ return self.model.generate(*args, **kwargs)
57
 
58
  class GaiaAgent:
59
  def __init__(self):
 
86
  elif isinstance(result, str):
87
  return result.strip()
88
  elif isinstance(result, list):
89
+ # Find assistant content if possible
90
  for item in reversed(result):
91
  if isinstance(item, dict) and item.get("role") == "assistant" and "content" in item:
92
  return item["content"].strip()