dlaima committed on
Commit
a942c8c
·
verified ·
1 Parent(s): 8b469fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -31
app.py CHANGED
@@ -4,8 +4,7 @@ import gradio as gr
4
  import requests
5
  import pandas as pd
6
 
7
- from smolagents import Tool, CodeAgent, HfApiModel
8
-
9
  from audio_transcriber import AudioTranscriptionTool
10
  from image_analyzer import ImageAnalysisTool
11
  from wikipedia_searcher import WikipediaSearcher
@@ -13,46 +12,68 @@ from wikipedia_searcher import WikipediaSearcher
13
  # GAIA scoring endpoint
14
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
15
 
16
- # Static system prompt for GAIA
17
- SYSTEM_PROMPT = """You are an agent solving the GAIA benchmark and you are required to provide exact answers.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  Rules to follow:
19
  1. Return only the exact requested answer: no explanation and no reasoning.
20
  2. For yes/no questions, return exactly \"Yes\" or \"No\".
21
  3. For dates, use the exact format requested.
22
  4. For numbers, use the exact number, no other format.
23
  5. For names, use the exact name as found in sources.
24
- 6. If the question has an associated file, process it accordingly.
25
  Examples of good responses:
26
  - \"42\"
 
27
  - \"Yes\"
28
  - \"October 5, 2001\"
29
  - \"Buenos Aires\"
30
  Never include phrases like \"the answer is...\" or \"Based on my research\".
31
- Only return the exact answer."""
32
-
33
- # Define agent tools
34
- audio_tool = AudioTranscriptionTool()
35
- image_tool = ImageAnalysisTool()
36
- wikipedia_tool = WikipediaSearcher()
37
-
38
- tools = [audio_tool, image_tool, wikipedia_tool]
39
-
40
- # Define the custom agent using Dolphin model (free Mixtral)
41
- class MyAgent(CodeAgent):
42
- def __init__(self):
43
- model = HfApiModel(
44
- model="cognitivecomputations/dolphin-2.6-mixtral-8x7b",
45
- api_key=os.getenv("HF_API_TOKEN", "").strip(),
46
- # No system_prompt here
47
- )
48
- super().__init__(model=model, tools=tools)
49
 
50
- def __call__(self, question_dict):
51
- system_message = {"role": "system", "content": SYSTEM_PROMPT}
52
- user_message = {"role": "user", "content": question_dict.get("question", "")}
53
- messages = [system_message, user_message]
54
- # Pass messages directly in the call
55
- return self.model(messages)
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # Evaluation + Submission function
58
  def run_and_submit_all(profile: gr.OAuthProfile | None):
@@ -70,7 +91,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
70
  submit_url = f"{api_url}/submit"
71
 
72
  try:
73
- agent = MyAgent()
74
  except Exception as e:
75
  print(f"Error initializing agent: {e}")
76
  return f"Error initializing agent: {e}", None
@@ -97,7 +118,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
97
  if not task_id:
98
  continue
99
  try:
100
- submitted_answer = agent(item)
101
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
102
  results_log.append({
103
  "Task ID": task_id,
 
4
  import requests
5
  import pandas as pd
6
 
7
+ from smolagents import InferenceClientModel, ToolCallingAgent
 
8
  from audio_transcriber import AudioTranscriptionTool
9
  from image_analyzer import ImageAnalysisTool
10
  from wikipedia_searcher import WikipediaSearcher
 
12
  # GAIA scoring endpoint
13
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
14
 
15
+ # Define the GaiaAgent class with embedded prompt in __call__
16
class GaiaAgent:
    """Callable wrapper around a smolagents ``ToolCallingAgent`` for GAIA questions.

    An instance is called with the question text and returns the answer as a
    stripped string. Failures are reported as ``"ERROR: ..."`` or
    ``"AGENT ERROR: ..."`` strings instead of being raised, so a caller that
    loops over many questions keeps going after a bad one.
    """

    # Instruction preamble placed before every question.
    # NOTE(review): rule 6 tells the model to download an attached file via the
    # task ID, but __call__ only receives the question text -- confirm that one
    # of the tools can actually obtain the ID.
    _PROMPT_HEADER = """You are an agent solving the GAIA benchmark and you are required to provide exact answers.
Rules to follow:
1. Return only the exact requested answer: no explanation and no reasoning.
2. For yes/no questions, return exactly \"Yes\" or \"No\".
3. For dates, use the exact format requested.
4. For numbers, use the exact number, no other format.
5. For names, use the exact name as found in sources.
6. If the question has an associated file, download the file first using the task ID.
Examples of good responses:
- \"42\"
- \"Arturo Nunez\"
- \"Yes\"
- \"October 5, 2001\"
- \"Buenos Aires\"
Never include phrases like \"the answer is...\" or \"Based on my research\".
Only return the exact answer.
QUESTION:
"""

    def __init__(self):
        """Build the inference model, the tool list and the tool-calling agent."""
        print("Gaia Agent Initialized")

        # An unset or blank variable becomes None rather than "", so the model
        # wrapper can apply its own default credential lookup.
        token = os.getenv("HUGGINGFACEHUB_API_TOKEN", "").strip() or None

        self.model = InferenceClientModel(
            model_id="cognitivecomputations/dolphin-2.6-mixtral-8x7b",
            token=token,
        )

        self.tools = [
            AudioTranscriptionTool(),
            ImageAnalysisTool(),
            WikipediaSearcher(),
        ]

        self.agent = ToolCallingAgent(
            tools=self.tools,
            model=self.model,
        )

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its answer as a string.

        Args:
            question: The question text. ``None`` is treated as an empty string.

        Returns:
            The stripped answer, ``"ERROR: ..."`` when the agent's result has
            an unrecognised shape, or ``"AGENT ERROR: ..."`` when the run raised.
        """
        # The caller reads the question with dict.get, so a null value is
        # possible; coerce it so the log slice below cannot raise.
        question = "" if question is None else str(question)
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        prompt = f"{self._PROMPT_HEADER}{question}\n"

        try:
            result = self.agent.run(prompt)
            print(f"Raw result from agent: {result}")
            return self._extract_answer(result)
        except Exception as e:
            # Deliberate catch-all: one failing question must not abort the
            # caller's whole evaluation loop.
            print(f"Exception during agent run: {e}")
            return f"AGENT ERROR: {e}"

    @staticmethod
    def _extract_answer(result) -> str:
        """Convert whatever the agent run returned into a bare answer string."""
        if isinstance(result, dict) and "answer" in result:
            return str(result["answer"]).strip()
        if isinstance(result, str):
            return result.strip()
        if isinstance(result, (int, float)):
            # A numeric final answer is a valid answer, not an error.
            return str(result)
        if isinstance(result, list):
            # Treat the list as a chat transcript; the last assistant turn wins.
            for item in reversed(result):
                if (
                    isinstance(item, dict)
                    and item.get("role") == "assistant"
                    and "content" in item
                ):
                    return str(item["content"]).strip()
            return "ERROR: Unexpected list format"
        return "ERROR: Unexpected result type"
77
 
78
  # Evaluation + Submission function
79
  def run_and_submit_all(profile: gr.OAuthProfile | None):
 
91
  submit_url = f"{api_url}/submit"
92
 
93
  try:
94
+ agent = GaiaAgent()
95
  except Exception as e:
96
  print(f"Error initializing agent: {e}")
97
  return f"Error initializing agent: {e}", None
 
118
  if not task_id:
119
  continue
120
  try:
121
+ submitted_answer = agent(item.get("question", ""))
122
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
123
  results_log.append({
124
  "Task ID": task_id,