dlaima commited on
Commit
62ad750
·
verified ·
1 Parent(s): e0e7440

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -42
app.py CHANGED
@@ -3,10 +3,7 @@ import os
3
  import gradio as gr
4
  import requests
5
  import pandas as pd
6
- import torch
7
- from transformers import BartTokenizer, BartForConditionalGeneration
8
-
9
- from smolagents import ToolCallingAgent
10
  from audio_transcriber import AudioTranscriptionTool
11
  from image_analyzer import ImageAnalysisTool
12
  from wikipedia_searcher import WikipediaSearcher
@@ -21,45 +18,22 @@ SYSTEM_PROMPT = (
21
  "3. For dates, use the exact requested format.\n"
22
  "4. For numbers, use only the number.\n"
23
  "5. For names, use the exact name from sources.\n"
24
- "6. If the question has a file, download it using the task ID.\n"
25
- "Examples:\n"
26
- "- '42'\n"
27
- "- 'Arturo Nunez'\n"
28
- "- 'Yes'\n"
29
- "- 'October 5, 2001'\n"
30
- "- 'Buenos Aires'\n"
31
  "Never say 'the answer is...'. Only return the answer.\n"
32
  )
33
 
34
- class LocalBartModel:
35
- def __init__(self):
36
- self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
37
- self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
38
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
- self.model.to(self.device)
40
- self.model.eval()
41
-
42
- def __call__(self, prompt: str) -> str:
43
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
44
- with torch.no_grad():
45
- outputs = self.model.generate(
46
- input_ids=inputs["input_ids"],
47
- attention_mask=inputs["attention_mask"],
48
- max_length=100,
49
- num_beams=5,
50
- early_stopping=True,
51
- )
52
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
53
-
54
- def generate(self, *args, **kwargs):
55
- # Forward generate calls for ToolCallingAgent compatibility
56
- return self.model.generate(*args, **kwargs)
57
-
58
  class GaiaAgent:
59
  def __init__(self):
60
  print("Gaia Agent Initialized")
61
 
62
- self.model = LocalBartModel()
 
 
 
 
 
 
 
63
 
64
  self.tools = [
65
  AudioTranscriptionTool(),
@@ -72,12 +46,48 @@ class GaiaAgent:
72
  model=self.model
73
  )
74
 
75
- def __call__(self, question: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  print(f"Agent received question (first 50 chars): {question[:50]}...")
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  full_prompt = f"{SYSTEM_PROMPT}\nQUESTION:\n{question}"
79
 
80
  try:
 
 
 
 
 
 
 
 
81
  result = self.agent.run(full_prompt)
82
  print(f"Raw result from agent: {result}")
83
 
@@ -86,7 +96,6 @@ class GaiaAgent:
86
  elif isinstance(result, str):
87
  return result.strip()
88
  elif isinstance(result, list):
89
- # Find assistant content if possible
90
  for item in reversed(result):
91
  if isinstance(item, dict) and item.get("role") == "assistant" and "content" in item:
92
  return item["content"].strip()
@@ -135,21 +144,24 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
135
 
136
  for item in questions_data:
137
  task_id = item.get("task_id")
 
 
 
138
  if not task_id:
139
  continue
140
  try:
141
- submitted_answer = agent(item.get("question", ""))
142
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
143
  results_log.append({
144
  "Task ID": task_id,
145
- "Question": item.get("question", ""),
146
  "Submitted Answer": submitted_answer
147
  })
148
  except Exception as e:
149
  error_msg = f"AGENT ERROR: {e}"
150
  results_log.append({
151
  "Task ID": task_id,
152
- "Question": item.get("question", ""),
153
  "Submitted Answer": error_msg
154
  })
155
 
@@ -203,7 +215,7 @@ with gr.Blocks() as demo:
203
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
204
  results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
205
 
206
- run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
207
 
208
  if __name__ == "__main__":
209
  print("\n" + "-"*30 + " App Starting " + "-"*30)
 
3
  import gradio as gr
4
  import requests
5
  import pandas as pd
6
+ from smolagents import ToolCallingAgent, OpenAIClientModel
 
 
 
7
  from audio_transcriber import AudioTranscriptionTool
8
  from image_analyzer import ImageAnalysisTool
9
  from wikipedia_searcher import WikipediaSearcher
 
18
  "3. For dates, use the exact requested format.\n"
19
  "4. For numbers, use only the number.\n"
20
  "5. For names, use the exact name from sources.\n"
21
+ "6. If the question has a file, download it using the task ID and process it.\n"
 
 
 
 
 
 
22
  "Never say 'the answer is...'. Only return the answer.\n"
23
  )
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class GaiaAgent:
26
  def __init__(self):
27
  print("Gaia Agent Initialized")
28
 
29
+ openai_api_key = os.getenv("OPENAI_API_KEY")
30
+ if not openai_api_key:
31
+ raise EnvironmentError("OPENAI_API_KEY not found in environment variables.")
32
+
33
+ self.model = OpenAIClientModel(
34
+ model_name="gpt-3.5-turbo",
35
+ api_key=openai_api_key
36
+ )
37
 
38
  self.tools = [
39
  AudioTranscriptionTool(),
 
46
  model=self.model
47
  )
48
 
49
+ def download_file(self, task_id: str, file_extension: str) -> str:
50
+ file_url = f"{DEFAULT_API_URL}/files/{task_id}.{file_extension}"
51
+ local_filename = f"temp_{task_id}.{file_extension}"
52
+
53
+ try:
54
+ r = requests.get(file_url, timeout=30)
55
+ r.raise_for_status()
56
+ with open(local_filename, "wb") as f:
57
+ f.write(r.content)
58
+ return local_filename
59
+ except Exception as e:
60
+ print(f"Error downloading file for task {task_id}: {e}")
61
+ return ""
62
+
63
+ def __call__(self, question: str, task_id: str | None = None, file_name: str | None = None) -> str:
64
  print(f"Agent received question (first 50 chars): {question[:50]}...")
65
 
66
+ # If there's a file related to the question, download it and prepare tool input
67
+ tool_inputs = {}
68
+ if task_id and file_name:
69
+ ext = file_name.split(".")[-1].lower()
70
+ local_path = self.download_file(task_id, ext)
71
+ if local_path:
72
+ if ext in ["mp3", "wav"]:
73
+ tool_inputs = {"file_path": local_path}
74
+ question = f"Transcribe the audio file."
75
+ elif ext in ["jpg", "jpeg", "png"]:
76
+ tool_inputs = {"image_path": local_path, "question": question}
77
+ else:
78
+ print(f"Unsupported file extension: {ext}")
79
+
80
  full_prompt = f"{SYSTEM_PROMPT}\nQUESTION:\n{question}"
81
 
82
  try:
83
+ # If there's a file to process, call the tool with inputs
84
+ if tool_inputs:
85
+ for tool in self.tools:
86
+ if all(k in tool.inputs for k in tool_inputs.keys()):
87
+ result = tool.forward(**tool_inputs)
88
+ return result.strip()
89
+
90
+ # Otherwise, just call the agent with the prompt
91
  result = self.agent.run(full_prompt)
92
  print(f"Raw result from agent: {result}")
93
 
 
96
  elif isinstance(result, str):
97
  return result.strip()
98
  elif isinstance(result, list):
 
99
  for item in reversed(result):
100
  if isinstance(item, dict) and item.get("role") == "assistant" and "content" in item:
101
  return item["content"].strip()
 
144
 
145
  for item in questions_data:
146
  task_id = item.get("task_id")
147
+ question_text = item.get("question", "")
148
+ file_name = item.get("file_name") # file_name may or may not be present
149
+
150
  if not task_id:
151
  continue
152
  try:
153
+ submitted_answer = agent(question_text, task_id=task_id, file_name=file_name)
154
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
155
  results_log.append({
156
  "Task ID": task_id,
157
+ "Question": question_text,
158
  "Submitted Answer": submitted_answer
159
  })
160
  except Exception as e:
161
  error_msg = f"AGENT ERROR: {e}"
162
  results_log.append({
163
  "Task ID": task_id,
164
+ "Question": question_text,
165
  "Submitted Answer": error_msg
166
  })
167
 
 
215
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
216
  results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
217
 
218
+ run_button.click(fn=run_and_submit_all, inputs=[gr.get_last_logged_in_user()], outputs=[status_output, results_table])
219
 
220
  if __name__ == "__main__":
221
  print("\n" + "-"*30 + " App Starting " + "-"*30)