dlaima commited on
Commit
1381703
·
verified ·
1 Parent(s): add03b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -51
app.py CHANGED
@@ -3,7 +3,8 @@ import os
3
  import gradio as gr
4
  import requests
5
  import pandas as pd
6
- from smolagents import ToolCallingAgent, OpenAIClientModel
 
7
  from audio_transcriber import AudioTranscriptionTool
8
  from image_analyzer import ImageAnalysisTool
9
  from wikipedia_searcher import WikipediaSearcher
@@ -18,7 +19,13 @@ SYSTEM_PROMPT = (
18
  "3. For dates, use the exact requested format.\n"
19
  "4. For numbers, use only the number.\n"
20
  "5. For names, use the exact name from sources.\n"
21
- "6. If the question has a file, download it using the task ID and process it.\n"
 
 
 
 
 
 
22
  "Never say 'the answer is...'. Only return the answer.\n"
23
  )
24
 
@@ -26,68 +33,31 @@ class GaiaAgent:
26
  def __init__(self):
27
  print("Gaia Agent Initialized")
28
 
29
- openai_api_key = os.getenv("OPENAI_API_KEY")
30
- if not openai_api_key:
31
- raise EnvironmentError("OPENAI_API_KEY not found in environment variables.")
32
-
33
- self.model = OpenAIClientModel(
34
  model_name="gpt-3.5-turbo",
35
- api_key=openai_api_key
36
  )
37
 
 
38
  self.tools = [
39
  AudioTranscriptionTool(),
40
  ImageAnalysisTool(),
41
  WikipediaSearcher()
42
  ]
43
 
 
44
  self.agent = ToolCallingAgent(
45
  tools=self.tools,
46
  model=self.model
47
  )
48
 
49
- def download_file(self, task_id: str, file_extension: str) -> str:
50
- file_url = f"{DEFAULT_API_URL}/files/{task_id}.{file_extension}"
51
- local_filename = f"temp_{task_id}.{file_extension}"
52
-
53
- try:
54
- r = requests.get(file_url, timeout=30)
55
- r.raise_for_status()
56
- with open(local_filename, "wb") as f:
57
- f.write(r.content)
58
- return local_filename
59
- except Exception as e:
60
- print(f"Error downloading file for task {task_id}: {e}")
61
- return ""
62
-
63
- def __call__(self, question: str, task_id: str | None = None, file_name: str | None = None) -> str:
64
  print(f"Agent received question (first 50 chars): {question[:50]}...")
65
 
66
- # If there's a file related to the question, download it and prepare tool input
67
- tool_inputs = {}
68
- if task_id and file_name:
69
- ext = file_name.split(".")[-1].lower()
70
- local_path = self.download_file(task_id, ext)
71
- if local_path:
72
- if ext in ["mp3", "wav"]:
73
- tool_inputs = {"file_path": local_path}
74
- question = f"Transcribe the audio file."
75
- elif ext in ["jpg", "jpeg", "png"]:
76
- tool_inputs = {"image_path": local_path, "question": question}
77
- else:
78
- print(f"Unsupported file extension: {ext}")
79
-
80
  full_prompt = f"{SYSTEM_PROMPT}\nQUESTION:\n{question}"
81
 
82
  try:
83
- # If there's a file to process, call the tool with inputs
84
- if tool_inputs:
85
- for tool in self.tools:
86
- if all(k in tool.inputs for k in tool_inputs.keys()):
87
- result = tool.forward(**tool_inputs)
88
- return result.strip()
89
-
90
- # Otherwise, just call the agent with the prompt
91
  result = self.agent.run(full_prompt)
92
  print(f"Raw result from agent: {result}")
93
 
@@ -144,13 +114,33 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
144
 
145
  for item in questions_data:
146
  task_id = item.get("task_id")
147
- question_text = item.get("question", "")
148
- file_name = item.get("file_name") # file_name may or may not be present
149
-
150
  if not task_id:
151
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  try:
153
- submitted_answer = agent(question_text, task_id=task_id, file_name=file_name)
154
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
155
  results_log.append({
156
  "Task ID": task_id,
@@ -165,6 +155,13 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
165
  "Submitted Answer": error_msg
166
  })
167
 
 
 
 
 
 
 
 
168
  if not answers_payload:
169
  return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
170
 
@@ -215,7 +212,7 @@ with gr.Blocks() as demo:
215
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
216
  results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
217
 
218
- run_button.click(fn=run_and_submit_all, inputs=[gr.get_last_logged_in_user()], outputs=[status_output, results_table])
219
 
220
  if __name__ == "__main__":
221
  print("\n" + "-"*30 + " App Starting " + "-"*30)
@@ -239,4 +236,3 @@ if __name__ == "__main__":
239
 
240
 
241
 
242
-
 
3
  import gradio as gr
4
  import requests
5
  import pandas as pd
6
+
7
+ from smolagents import ToolCallingAgent, OpenAIServerModel
8
  from audio_transcriber import AudioTranscriptionTool
9
  from image_analyzer import ImageAnalysisTool
10
  from wikipedia_searcher import WikipediaSearcher
 
19
  "3. For dates, use the exact requested format.\n"
20
  "4. For numbers, use only the number.\n"
21
  "5. For names, use the exact name from sources.\n"
22
+ "6. If the question has a file, download it using the task ID.\n"
23
+ "Examples:\n"
24
+ "- '42'\n"
25
+ "- 'Arturo Nunez'\n"
26
+ "- 'Yes'\n"
27
+ "- 'October 5, 2001'\n"
28
+ "- 'Buenos Aires'\n"
29
  "Never say 'the answer is...'. Only return the answer.\n"
30
  )
31
 
 
33
  def __init__(self):
34
  print("Gaia Agent Initialized")
35
 
36
+ # Initialize the OpenAI GPT-3.5-turbo model via smolagents OpenAIServerModel
37
+ self.model = OpenAIServerModel(
 
 
 
38
  model_name="gpt-3.5-turbo",
39
+ api_key=os.getenv("OPENAI_API_KEY") # Make sure you set this in your environment
40
  )
41
 
42
+ # Initialize the tools
43
  self.tools = [
44
  AudioTranscriptionTool(),
45
  ImageAnalysisTool(),
46
  WikipediaSearcher()
47
  ]
48
 
49
+ # Create the agent with tools and model
50
  self.agent = ToolCallingAgent(
51
  tools=self.tools,
52
  model=self.model
53
  )
54
 
55
+ def __call__(self, question: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  print(f"Agent received question (first 50 chars): {question[:50]}...")
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  full_prompt = f"{SYSTEM_PROMPT}\nQUESTION:\n{question}"
59
 
60
  try:
 
 
 
 
 
 
 
 
61
  result = self.agent.run(full_prompt)
62
  print(f"Raw result from agent: {result}")
63
 
 
114
 
115
  for item in questions_data:
116
  task_id = item.get("task_id")
 
 
 
117
  if not task_id:
118
  continue
119
+
120
+ question_text = item.get("question", "")
121
+
122
+ # Download associated file if any (mp3 or jpeg) according to GAIA benchmark task
123
+ file_url = item.get("file_url")
124
+ local_file_path = None
125
+ if file_url:
126
+ try:
127
+ ext = file_url.split(".")[-1].lower()
128
+ if ext in ["mp3", "wav", "jpeg", "jpg", "png"]:
129
+ local_file_path = f"./temp_{task_id}.{ext}"
130
+ with requests.get(file_url, stream=True) as r:
131
+ r.raise_for_status()
132
+ with open(local_file_path, "wb") as f:
133
+ for chunk in r.iter_content(chunk_size=8192):
134
+ f.write(chunk)
135
+ print(f"Downloaded file for task {task_id} to {local_file_path}")
136
+
137
+ # Append info about the file path to the question so the agent knows to use it
138
+ question_text += f"\n\nFile path: {local_file_path}"
139
+ except Exception as e:
140
+ print(f"Failed to download file for task {task_id}: {e}")
141
+
142
  try:
143
+ submitted_answer = agent(question_text)
144
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
145
  results_log.append({
146
  "Task ID": task_id,
 
155
  "Submitted Answer": error_msg
156
  })
157
 
158
+ # Cleanup downloaded file
159
+ if local_file_path:
160
+ try:
161
+ os.remove(local_file_path)
162
+ except Exception:
163
+ pass
164
+
165
  if not answers_payload:
166
  return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
167
 
 
212
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
213
  results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
214
 
215
+ run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
216
 
217
  if __name__ == "__main__":
218
  print("\n" + "-"*30 + " App Starting " + "-"*30)
 
236
 
237
 
238