dlaima committed on
Commit
da58915
·
verified ·
1 Parent(s): 79fcd3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -31
app.py CHANGED
@@ -3,8 +3,8 @@ import os
3
  import gradio as gr
4
  import requests
5
  import pandas as pd
6
- from transformers import BartTokenizer, BartForConditionalGeneration
7
  import torch
 
8
 
9
  from smolagents import ToolCallingAgent
10
  from audio_transcriber import AudioTranscriptionTool
@@ -39,43 +39,26 @@ class LocalBartModel:
39
  self.model.to(self.device)
40
  self.model.eval()
41
 
42
- def generate(self, inputs, **generate_kwargs):
43
- # inputs must be dict with input_ids and attention_mask
44
- if not isinstance(inputs, dict):
45
- raise ValueError(f"Expected dict input but got {type(inputs)}")
46
- input_ids = inputs.get("input_ids")
47
- attention_mask = inputs.get("attention_mask")
48
- if input_ids is None or attention_mask is None:
49
- raise ValueError("input_ids and attention_mask are required in inputs dict")
50
-
51
- input_ids = input_ids.to(self.device)
52
- attention_mask = attention_mask.to(self.device)
53
 
 
54
  with torch.no_grad():
55
  outputs = self.model.generate(
56
- input_ids=input_ids,
57
- attention_mask=attention_mask,
58
- **generate_kwargs
 
 
59
  )
60
- return outputs
61
-
62
- def __call__(self, prompt):
63
- if not isinstance(prompt, str):
64
- raise ValueError(f"LocalBartModel expects a string prompt, got {type(prompt)}")
65
-
66
- inputs = self.tokenizer(prompt, return_tensors="pt")
67
- output_ids = self.generate(
68
- inputs,
69
- max_length=100,
70
- num_beams=5,
71
- early_stopping=True
72
- )
73
- output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
74
  return output_text.strip()
75
 
76
  class GaiaAgent:
77
  def __init__(self):
78
  print("Gaia Agent Initialized")
 
79
  self.model = LocalBartModel()
80
 
81
  self.tools = [
@@ -91,19 +74,19 @@ class GaiaAgent:
91
 
92
  def __call__(self, question: str) -> str:
93
  print(f"Agent received question (first 50 chars): {question[:50]}...")
 
94
  full_prompt = f"{SYSTEM_PROMPT}\nQUESTION:\n{question}"
95
 
96
  try:
97
  result = self.agent.run(full_prompt)
98
  print(f"Raw result from agent: {result}")
99
 
100
- # Handle different result types robustly
101
  if isinstance(result, dict) and "answer" in result:
102
  return str(result["answer"]).strip()
103
  elif isinstance(result, str):
104
  return result.strip()
105
  elif isinstance(result, list):
106
- # Try to extract assistant content from list
107
  for item in reversed(result):
108
  if isinstance(item, dict) and item.get("role") == "assistant" and "content" in item:
109
  return item["content"].strip()
@@ -244,3 +227,4 @@ if __name__ == "__main__":
244
 
245
 
246
 
 
 
3
  import gradio as gr
4
  import requests
5
  import pandas as pd
 
6
  import torch
7
+ from transformers import BartTokenizer, BartForConditionalGeneration
8
 
9
  from smolagents import ToolCallingAgent
10
  from audio_transcriber import AudioTranscriptionTool
 
39
  self.model.to(self.device)
40
  self.model.eval()
41
 
42
+ def __call__(self, prompt: str) -> str:
43
+ if not isinstance(prompt, str):
44
+ raise ValueError(f"LocalBartModel expects a string prompt, got {type(prompt)}")
 
 
 
 
 
 
 
 
45
 
46
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
47
  with torch.no_grad():
48
  outputs = self.model.generate(
49
+ input_ids=inputs["input_ids"],
50
+ attention_mask=inputs["attention_mask"],
51
+ max_length=100,
52
+ num_beams=5,
53
+ early_stopping=True,
54
  )
55
+ output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  return output_text.strip()
57
 
58
  class GaiaAgent:
59
  def __init__(self):
60
  print("Gaia Agent Initialized")
61
+
62
  self.model = LocalBartModel()
63
 
64
  self.tools = [
 
74
 
75
  def __call__(self, question: str) -> str:
76
  print(f"Agent received question (first 50 chars): {question[:50]}...")
77
+
78
  full_prompt = f"{SYSTEM_PROMPT}\nQUESTION:\n{question}"
79
 
80
  try:
81
  result = self.agent.run(full_prompt)
82
  print(f"Raw result from agent: {result}")
83
 
 
84
  if isinstance(result, dict) and "answer" in result:
85
  return str(result["answer"]).strip()
86
  elif isinstance(result, str):
87
  return result.strip()
88
  elif isinstance(result, list):
89
+ # Try to find assistant response content in list
90
  for item in reversed(result):
91
  if isinstance(item, dict) and item.get("role") == "assistant" and "content" in item:
92
  return item["content"].strip()
 
227
 
228
 
229
 
230
+