Arbnor Tefiki commited on
Commit
8ecb1cd
Β·
1 Parent(s): 94b3868

one percent accuracy

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. app.py +81 -26
  3. custom_tools.py +198 -25
  4. functions.py +346 -99
  5. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
app.py CHANGED
@@ -5,6 +5,8 @@ import pandas as pd
5
  from dotenv import load_dotenv
6
  from functions import *
7
  from langchain_core.messages import HumanMessage
 
 
8
 
9
  load_dotenv()
10
 
@@ -49,60 +51,106 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
49
  results_log = []
50
  answers_payload = []
51
 
 
52
  print(f"Running agent on {len(questions_data)} questions...")
53
- for item in questions_data:
 
 
54
  task_id = item.get("task_id")
55
  question_text = item.get("question")
56
  if not task_id or question_text is None:
57
  print(f"Skipping item with missing task_id or question: {item}")
58
  continue
 
 
 
 
 
59
  try:
 
 
60
  input_messages = [HumanMessage(content=question_text)]
61
-
 
62
  result = agent({"messages": input_messages})
63
-
 
 
64
  if "messages" in result and result["messages"]:
65
- last_valid = next(
66
- (m for m in reversed(result["messages"]) if hasattr(m, "content") and isinstance(m.content, str)),
67
- None
68
- )
69
- if last_valid:
70
- answer = last_valid.content.strip()
71
- else:
72
- answer = "UNKNOWN"
73
- else:
74
- answer = "UNKNOWN"
75
-
76
- print("Answered with:", answer)
77
  answers_payload.append({"task_id": task_id, "submitted_answer": answer})
78
  results_log.append({
79
  "Task ID": task_id,
80
- "Question": question_text,
81
- "Submitted Answer": answer
 
82
  })
 
83
  except Exception as e:
84
  print(f"Error running agent on task {task_id}: {e}")
85
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  if not answers_payload:
88
  print("Agent did not produce any answers to submit.")
89
  return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
90
 
 
 
 
 
 
 
 
91
  submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
92
- print(f"Submitting {len(answers_payload)} answers for user '{username}'...")
93
 
94
  try:
95
  response = requests.post(submit_url, json=submission_data, timeout=60)
96
  response.raise_for_status()
97
  result_data = response.json()
 
 
 
 
 
98
  final_status = (
99
  f"Submission Successful!\n"
100
  f"User: {result_data.get('username')}\n"
101
- f"Overall Score: {result_data.get('score', 'N/A')}% "
102
- f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
103
  f"Message: {result_data.get('message', 'No message received.')}"
104
  )
105
- print("Submission successful.")
 
 
 
 
 
 
106
  results_df = pd.DataFrame(results_log)
107
  return final_status, results_df
108
  except Exception as e:
@@ -113,10 +161,17 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
113
 
114
  # Gradio UI
115
  with gr.Blocks() as demo:
116
- gr.Markdown("# Basic Agent Evaluation Runner")
117
  gr.Markdown(
118
  """
119
- Modify the code here to define your agent's logic, the tools, the necessary packages, etc...
 
 
 
 
 
 
 
120
  """
121
  )
122
 
@@ -154,5 +209,5 @@ if __name__ == "__main__":
154
 
155
  print("-"*(60 + len(" App Starting ")) + "\n")
156
 
157
- print("Launching Gradio Interface for Basic Agent Evaluation...")
158
- demo.launch(debug=True, share=False)
 
5
  from dotenv import load_dotenv
6
  from functions import *
7
  from langchain_core.messages import HumanMessage
8
+ import traceback
9
+ import time
10
 
11
  load_dotenv()
12
 
 
51
  results_log = []
52
  answers_payload = []
53
 
54
+ print(f"\n{'='*60}")
55
  print(f"Running agent on {len(questions_data)} questions...")
56
+ print(f"{'='*60}\n")
57
+
58
+ for idx, item in enumerate(questions_data, 1):
59
  task_id = item.get("task_id")
60
  question_text = item.get("question")
61
  if not task_id or question_text is None:
62
  print(f"Skipping item with missing task_id or question: {item}")
63
  continue
64
+
65
+ print(f"\n--- Question {idx}/{len(questions_data)} ---")
66
+ print(f"Task ID: {task_id}")
67
+ print(f"Question: {question_text}")
68
+
69
  try:
70
+ # Add timeout for each question
71
+ start_time = time.time()
72
  input_messages = [HumanMessage(content=question_text)]
73
+
74
+ # Invoke the agent with the question
75
  result = agent({"messages": input_messages})
76
+
77
+ # Extract the answer from the result
78
+ answer = "UNKNOWN"
79
  if "messages" in result and result["messages"]:
80
+ # Look for the last AI message with content
81
+ for msg in reversed(result["messages"]):
82
+ if hasattr(msg, "content") and isinstance(msg.content, str) and msg.content.strip():
83
+ # Skip planner outputs
84
+ if not any(msg.content.upper().startswith(prefix) for prefix in ["SEARCH:", "CALCULATE:", "DEFINE:", "WIKIPEDIA:", "REVERSE:", "DIRECT:"]):
85
+ answer = msg.content.strip()
86
+ break
87
+
88
+ elapsed_time = time.time() - start_time
89
+ print(f"Answer: {answer}")
90
+ print(f"Time taken: {elapsed_time:.2f}s")
91
+
92
  answers_payload.append({"task_id": task_id, "submitted_answer": answer})
93
  results_log.append({
94
  "Task ID": task_id,
95
+ "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text,
96
+ "Submitted Answer": answer,
97
+ "Time (s)": f"{elapsed_time:.2f}"
98
  })
99
+
100
  except Exception as e:
101
  print(f"Error running agent on task {task_id}: {e}")
102
+ print(f"Traceback: {traceback.format_exc()}")
103
+
104
+ # Still submit UNKNOWN for errors
105
+ answers_payload.append({"task_id": task_id, "submitted_answer": "UNKNOWN"})
106
+ results_log.append({
107
+ "Task ID": task_id,
108
+ "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text,
109
+ "Submitted Answer": f"ERROR: {str(e)[:50]}",
110
+ "Time (s)": "N/A"
111
+ })
112
+
113
+ print(f"\n{'='*60}")
114
+ print(f"Completed processing all questions")
115
+ print(f"{'='*60}\n")
116
 
117
  if not answers_payload:
118
  print("Agent did not produce any answers to submit.")
119
  return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
120
 
121
+ # Summary before submission
122
+ unknown_count = sum(1 for ans in answers_payload if ans["submitted_answer"] == "UNKNOWN")
123
+ print(f"\nSummary before submission:")
124
+ print(f"Total questions: {len(answers_payload)}")
125
+ print(f"UNKNOWN answers: {unknown_count}")
126
+ print(f"Attempted answers: {len(answers_payload) - unknown_count}")
127
+
128
  submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
129
+ print(f"\nSubmitting {len(answers_payload)} answers for user '{username}'...")
130
 
131
  try:
132
  response = requests.post(submit_url, json=submission_data, timeout=60)
133
  response.raise_for_status()
134
  result_data = response.json()
135
+
136
+ score = result_data.get('score', 0)
137
+ correct_count = result_data.get('correct_count', 0)
138
+ total_attempted = result_data.get('total_attempted', 0)
139
+
140
  final_status = (
141
  f"Submission Successful!\n"
142
  f"User: {result_data.get('username')}\n"
143
+ f"Overall Score: {score}% "
144
+ f"({correct_count}/{total_attempted} correct)\n"
145
  f"Message: {result_data.get('message', 'No message received.')}"
146
  )
147
+
148
+ print("\n" + "="*60)
149
+ print("SUBMISSION RESULTS:")
150
+ print(f"Score: {score}%")
151
+ print(f"Correct: {correct_count}/{total_attempted}")
152
+ print("="*60)
153
+
154
  results_df = pd.DataFrame(results_log)
155
  return final_status, results_df
156
  except Exception as e:
 
161
 
162
  # Gradio UI
163
  with gr.Blocks() as demo:
164
+ gr.Markdown("# Enhanced GAIA Agent Evaluation Runner")
165
  gr.Markdown(
166
  """
167
+ This enhanced agent is optimized for GAIA benchmark questions with improved:
168
+ - Planning logic for better tool selection
169
+ - Search capabilities with more comprehensive results
170
+ - Mathematical expression parsing
171
+ - Answer extraction from search results
172
+ - Error handling and logging
173
+
174
+ Target: >50% accuracy on GAIA questions
175
  """
176
  )
177
 
 
209
 
210
  print("-"*(60 + len(" App Starting ")) + "\n")
211
 
212
+ print("Launching Gradio Interface for Enhanced GAIA Agent Evaluation...")
213
+ demo.launch(debug=True, share=False)
custom_tools.py CHANGED
@@ -1,19 +1,21 @@
1
  import requests
2
  from duckduckgo_search import DDGS
3
  from langchain_core.tools import tool
 
 
4
 
5
  @tool
6
  def reverse_text(input: str) -> str:
7
  """Reverse the characters in a text or string.
8
 
9
  Args:
10
- query: The text or string to reverse.
11
  """
12
  return input[::-1]
13
 
14
  @tool
15
  def web_search(query: str) -> str:
16
- """Perform a web search using DuckDuckGo and return the top 3 summarized results.
17
 
18
  Args:
19
  query: The search query to look up.
@@ -21,76 +23,247 @@ def web_search(query: str) -> str:
21
  try:
22
  results = []
23
  with DDGS() as ddgs:
24
- for r in ddgs.text(query, max_results=3):
 
 
 
25
  title = r.get("title", "")
26
  snippet = r.get("body", "")
27
  url = r.get("href", "")
 
28
  if title and snippet:
29
- results.append(f"{title}: {snippet} (URL: {url})")
 
 
 
30
  if not results:
31
- return "No results found."
32
- return "\n\n---\n\n".join(results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  except Exception as e:
34
  return f"Web search error: {e}"
35
 
36
  @tool
37
  def calculate(expression: str) -> str:
38
- """Evaluate a simple math expression and return the result.
39
 
40
  Args:
41
  expression: A string containing the math expression to evaluate.
42
  """
43
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  allowed_names = {
45
  "abs": abs,
46
  "round": round,
47
  "min": min,
48
  "max": max,
49
  "pow": pow,
 
 
 
 
 
 
50
  }
51
- result = eval(expression, {"__builtins__": None}, allowed_names)
52
- return str(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  except Exception as e:
54
  return f"Calculation error: {e}"
55
 
56
  @tool
57
  def wikipedia_summary(query: str) -> str:
58
- """Retrieve a summary of a topic from Wikipedia.
59
 
60
  Args:
61
  query: The subject or topic to summarize.
62
  """
63
  try:
 
 
 
 
 
64
  response = requests.get(
65
- f"https://en.wikipedia.org/api/rest_v1/page/summary/{query}", timeout=10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  )
67
- response.raise_for_status()
68
- data = response.json()
69
- return data.get("extract", "No summary found.")
 
 
 
 
 
 
 
 
70
  except Exception as e:
71
  return f"Wikipedia error: {e}"
72
 
73
  @tool
74
  def define_term(term: str) -> str:
75
- """Provide a dictionary-style definition of a given term using an online API.
76
 
77
  Args:
78
  term: The word or term to define.
79
  """
80
  try:
 
 
 
 
81
  response = requests.get(
82
- f"https://api.dictionaryapi.dev/api/v2/entries/en/{term}", timeout=10
 
83
  )
84
- response.raise_for_status()
85
- data = response.json()
86
- meanings = data[0].get("meanings", [])
87
- if meanings:
88
- defs = meanings[0].get("definitions", [])
89
- if defs:
90
- return defs[0].get("definition", "Definition not found.")
91
- return "Definition not found."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  except Exception as e:
93
  return f"Definition error: {e}"
94
 
95
  # List of tools to register with your agent
96
- TOOLS = [web_search, calculate, wikipedia_summary, define_term, reverse_text]
 
1
  import requests
2
  from duckduckgo_search import DDGS
3
  from langchain_core.tools import tool
4
+ import time
5
+ import re
6
 
7
  @tool
8
  def reverse_text(input: str) -> str:
9
  """Reverse the characters in a text or string.
10
 
11
  Args:
12
+ input: The text or string to reverse.
13
  """
14
  return input[::-1]
15
 
16
  @tool
17
  def web_search(query: str) -> str:
18
+ """Perform a web search using DuckDuckGo and return comprehensive results.
19
 
20
  Args:
21
  query: The search query to look up.
 
23
  try:
24
  results = []
25
  with DDGS() as ddgs:
26
+ # Get more results for better coverage
27
+ search_results = list(ddgs.text(query, max_results=8))
28
+
29
+ for r in search_results:
30
  title = r.get("title", "")
31
  snippet = r.get("body", "")
32
  url = r.get("href", "")
33
+
34
  if title and snippet:
35
+ # Combine title and snippet for more context
36
+ full_text = f"{title}. {snippet}"
37
+ results.append(full_text)
38
+
39
  if not results:
40
+ # Try with modified query
41
+ time.sleep(0.5)
42
+ with DDGS() as ddgs:
43
+ # Add more context to the query
44
+ modified_query = f"{query} facts information details"
45
+ search_results = list(ddgs.text(modified_query, max_results=5))
46
+
47
+ for r in search_results:
48
+ title = r.get("title", "")
49
+ snippet = r.get("body", "")
50
+ if title and snippet:
51
+ results.append(f"{title}. {snippet}")
52
+
53
+ if not results:
54
+ return "No search results found."
55
+
56
+ # Join all results with clear separation
57
+ return "\n\n".join(results)
58
+
59
  except Exception as e:
60
  return f"Web search error: {e}"
61
 
62
  @tool
63
  def calculate(expression: str) -> str:
64
+ """Evaluate a mathematical expression and return the result.
65
 
66
  Args:
67
  expression: A string containing the math expression to evaluate.
68
  """
69
  try:
70
+ # Clean the expression more thoroughly
71
+ expression = expression.strip()
72
+
73
+ # Handle various multiplication notations
74
+ expression = expression.replace("Γ—", "*")
75
+ expression = expression.replace("x", "*")
76
+ expression = expression.replace("X", "*")
77
+
78
+ # Handle exponents
79
+ expression = expression.replace("^", "**")
80
+
81
+ # Remove thousands separators
82
+ expression = expression.replace(",", "")
83
+
84
+ # Handle parentheses
85
+ expression = expression.replace("[", "(").replace("]", ")")
86
+ expression = expression.replace("{", "(").replace("}", ")")
87
+
88
+ # Handle percentage calculations
89
+ # Convert "X% of Y" to "(X/100) * Y"
90
+ percent_pattern = r'(\d+(?:\.\d+)?)\s*%\s*of\s*(\d+(?:\.\d+)?)'
91
+ expression = re.sub(percent_pattern, r'(\1/100) * \2', expression)
92
+
93
+ # Convert standalone percentages
94
+ expression = re.sub(r'(\d+(?:\.\d+)?)\s*%', r'(\1/100)', expression)
95
+
96
+ # Define safe functions and constants
97
  allowed_names = {
98
  "abs": abs,
99
  "round": round,
100
  "min": min,
101
  "max": max,
102
  "pow": pow,
103
+ "sum": sum,
104
+ "len": len,
105
+ "__builtins__": {},
106
+ # Math constants
107
+ "pi": 3.14159265359,
108
+ "e": 2.71828182846,
109
  }
110
+
111
+ # Evaluate the expression
112
+ result = eval(expression, allowed_names)
113
+
114
+ # Format the result nicely
115
+ if isinstance(result, float):
116
+ # Check if it's a whole number
117
+ if result.is_integer():
118
+ return str(int(result))
119
+ else:
120
+ # Round to reasonable precision
121
+ formatted = f"{result:.10f}".rstrip('0').rstrip('.')
122
+ return formatted
123
+ else:
124
+ return str(result)
125
+
126
+ except ZeroDivisionError:
127
+ return "Error: Division by zero"
128
+ except SyntaxError as e:
129
+ return f"Syntax error in expression: {e}"
130
  except Exception as e:
131
  return f"Calculation error: {e}"
132
 
133
  @tool
134
  def wikipedia_summary(query: str) -> str:
135
+ """Retrieve a comprehensive summary of a topic from Wikipedia.
136
 
137
  Args:
138
  query: The subject or topic to summarize.
139
  """
140
  try:
141
+ # Clean the query
142
+ query = query.strip()
143
+
144
+ # First, try direct API
145
+ clean_query = query.replace(" ", "_")
146
  response = requests.get(
147
+ f"https://en.wikipedia.org/api/rest_v1/page/summary/{clean_query}",
148
+ timeout=10,
149
+ headers={"User-Agent": "Mozilla/5.0"}
150
+ )
151
+
152
+ if response.status_code == 200:
153
+ data = response.json()
154
+ extract = data.get("extract", "")
155
+ if extract and extract != "No summary found.":
156
+ title = data.get("title", query)
157
+ description = data.get("description", "")
158
+
159
+ # Get additional details from the full article if needed
160
+ full_response = requests.get(
161
+ f"https://en.wikipedia.org/w/api.php",
162
+ params={
163
+ "action": "query",
164
+ "prop": "extracts",
165
+ "exintro": True,
166
+ "explaintext": True,
167
+ "titles": title,
168
+ "format": "json"
169
+ },
170
+ timeout=10
171
+ )
172
+
173
+ result = extract
174
+ if description and description not in extract:
175
+ result = f"{description}. {extract}"
176
+
177
+ if full_response.status_code == 200:
178
+ full_data = full_response.json()
179
+ pages = full_data.get("query", {}).get("pages", {})
180
+ for page_id, page_info in pages.items():
181
+ full_extract = page_info.get("extract", "")
182
+ if full_extract and len(full_extract) > len(result):
183
+ result = full_extract[:1000] # Limit length
184
+
185
+ return result
186
+
187
+ # Fallback: Try searching Wikipedia
188
+ search_response = requests.get(
189
+ "https://en.wikipedia.org/w/api.php",
190
+ params={
191
+ "action": "opensearch",
192
+ "search": query,
193
+ "limit": 3,
194
+ "format": "json"
195
+ },
196
+ timeout=10
197
  )
198
+
199
+ if search_response.status_code == 200:
200
+ search_data = search_response.json()
201
+ if len(search_data) > 1 and search_data[1]:
202
+ # Try the first result
203
+ first_result = search_data[1][0]
204
+ if first_result:
205
+ return wikipedia_summary(first_result)
206
+
207
+ return f"No Wikipedia article found for '{query}'."
208
+
209
  except Exception as e:
210
  return f"Wikipedia error: {e}"
211
 
212
  @tool
213
  def define_term(term: str) -> str:
214
+ """Provide a comprehensive dictionary definition of a given term.
215
 
216
  Args:
217
  term: The word or term to define.
218
  """
219
  try:
220
+ # Clean the term
221
+ term = term.strip().lower()
222
+ term = re.sub(r'[^\w\s-]', '', term) # Remove punctuation except hyphens
223
+
224
  response = requests.get(
225
+ f"https://api.dictionaryapi.dev/api/v2/entries/en/{term}",
226
+ timeout=10
227
  )
228
+
229
+ if response.status_code == 200:
230
+ data = response.json()
231
+ all_definitions = []
232
+
233
+ # Collect all definitions with their parts of speech
234
+ for entry in data:
235
+ word = entry.get("word", term)
236
+ meanings = entry.get("meanings", [])
237
+
238
+ for meaning in meanings:
239
+ part_of_speech = meaning.get("partOfSpeech", "")
240
+ definitions = meaning.get("definitions", [])
241
+
242
+ for definition in definitions:
243
+ def_text = definition.get("definition", "")
244
+ if def_text:
245
+ if part_of_speech:
246
+ all_definitions.append(f"({part_of_speech}) {def_text}")
247
+ else:
248
+ all_definitions.append(def_text)
249
+
250
+ if all_definitions:
251
+ # Return the most comprehensive definition
252
+ # Prefer longer, more detailed definitions
253
+ all_definitions.sort(key=len, reverse=True)
254
+ return all_definitions[0]
255
+
256
+ # Try alternative approach - use the error message if it's informative
257
+ if response.status_code == 404:
258
+ error_data = response.json()
259
+ if "message" in error_data:
260
+ return f"No definition found for '{term}'"
261
+
262
+ # Last resort - return a clear message
263
+ return f"Unable to find definition for '{term}'"
264
+
265
  except Exception as e:
266
  return f"Definition error: {e}"
267
 
268
  # List of tools to register with your agent
269
+ TOOLS = [web_search, calculate, wikipedia_summary, define_term, reverse_text]
functions.py CHANGED
@@ -1,140 +1,387 @@
1
  import os
2
  import re
 
3
  from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import ToolNode
5
- from langchain_core.messages import HumanMessage, SystemMessage
6
  from huggingface_hub import InferenceClient
7
  from custom_tools import TOOLS
8
- from langchain_core.messages import AIMessage
9
 
10
  HF_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")
11
  client = InferenceClient(token=HF_TOKEN)
12
 
13
- planner_prompt = SystemMessage(content="""
14
- You are a planning assistant. Your job is to decide how to answer a question.
15
 
16
- - If the answer is easy and factual, answer it directly.
17
- - If you are not 100% certain or the answer requires looking up real-world information, say:
18
- I need to search this.
 
 
 
 
 
19
 
20
- - If the question contains math or expressions like +, -, /, ^, say:
21
- I need to calculate this.
 
 
 
22
 
23
- - If a word should be explained, say:
24
- I need to define this.
25
-
26
- -If the question asks about a person, historical event, or specific topic, say:
27
- I need to look up wikipedia.
28
-
29
- -If the questions asks for backwards pronounciation or reversing text, say:
30
- I need to reverse text.
31
 
32
- Only respond with one line explaining what you will do.
33
- Do not try to answer yet.
34
-
35
- e.g:
36
- Q: How many studio albums did Mercedes Sosa release between 2000 and 2009?
37
- A: I need to search this.
38
 
39
- Q: What does the word 'ephemeral' mean?
40
- A: I need to define this.
41
 
42
- Q: What is 23 * 6 + 3?
43
- A: I need to calculate this.
 
 
44
 
45
- Q: Reverse this: 'tfel drow eht'
46
- A: I need to reverse text.
 
 
 
 
 
 
47
 
48
- Q: What bird species are seen in this video?
49
- A: UNKNOWN
50
- """)
 
 
 
 
 
 
51
 
52
  def planner_node(state: MessagesState):
53
- hf_messages = [planner_prompt] + state["messages"]
54
-
55
- # Properly map LangChain message objects to dicts
56
- messages_dict = []
57
- for msg in hf_messages:
58
- if isinstance(msg, SystemMessage):
59
- role = "system"
60
- elif isinstance(msg, HumanMessage):
61
- role = "user"
62
- else:
63
- raise ValueError(f"Unsupported message type: {type(msg)}")
64
- messages_dict.append({"role": role, "content": msg.content})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- response = client.chat.completions.create(
67
- model="mistralai/Mistral-7B-Instruct-v0.2",
68
- messages=messages_dict,
69
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- text = response.choices[0].message.content.strip()
72
- print("Planner output:\n", text)
 
 
 
 
 
 
 
 
 
 
73
 
74
- return {"messages": [SystemMessage(content=text)]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- answer_prompt = SystemMessage(content="""
77
- You are now given the result of a tool (like a search, calculator, or text reversal).
78
- Use the tool result and the original question to give the final answer.
79
- If the tool result is unhelpful or unclear, respond with 'UNKNOWN'.
80
- Respond with only the answer β€” no explanations.
81
- """)
82
 
83
- def assistant_node(state: MessagesState):
84
- hf_messages = [answer_prompt] + state["messages"]
85
-
86
- messages_dict = []
87
- for msg in hf_messages:
88
- if isinstance(msg, SystemMessage):
89
- role = "system"
90
- elif isinstance(msg, HumanMessage):
91
- role = "user"
92
- else:
93
- raise ValueError(f"Unsupported message type: {type(msg)}")
94
- messages_dict.append({"role": role, "content": msg.content})
95
 
96
- response = client.chat.completions.create(
97
- model="mistralai/Mistral-7B-Instruct-v0.2",
98
- messages=messages_dict,
99
- )
 
 
 
100
 
101
- text = response.choices[0].message.content.strip()
102
- print("Final answer output:\n", text)
103
 
104
- return {"messages": [AIMessage(content=text)]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  def tools_condition(state: MessagesState) -> str:
107
- last_msg = state["messages"][-1].content.lower()
108
-
109
- if any(trigger in last_msg for trigger in [
110
- "i need to search",
111
- "i need to calculate",
112
- "i need to define",
113
- "i need to reverse text",
114
- "i need to look up wikipedia"
115
- ]):
 
 
 
116
  return "tools"
117
-
 
 
 
 
 
118
  return "end"
119
 
120
- class PatchedToolNode(ToolNode):
121
- def invoke(self, state: MessagesState, config) -> dict:
122
- result = super().invoke(state)
123
- tool_output = result.get("messages", [])[0].content if result.get("messages") else "UNKNOWN"
124
-
125
- # Append tool result as a HumanMessage so assistant sees it
126
- new_messages = state["messages"] + [HumanMessage(content=f"Tool result:\n{tool_output}")]
127
- return {"messages": new_messages}
128
-
129
  def build_graph():
 
130
  builder = StateGraph(MessagesState)
131
-
 
132
  builder.add_node("planner", planner_node)
 
133
  builder.add_node("assistant", assistant_node)
134
- builder.add_node("tools", PatchedToolNode(TOOLS))
135
-
136
  builder.add_edge(START, "planner")
137
  builder.add_conditional_edges("planner", tools_condition)
138
  builder.add_edge("tools", "assistant")
139
-
140
  return builder.compile()
 
1
  import os
2
  import re
3
+ import json
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import ToolNode
6
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
7
  from huggingface_hub import InferenceClient
8
  from custom_tools import TOOLS
 
9
 
10
  HF_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")
11
  client = InferenceClient(token=HF_TOKEN)
12
 
13
+ # Enhanced planner prompt with better instructions
14
+ planner_prompt = SystemMessage(content="""You are an expert planning assistant for answering factual questions. Your job is to analyze each question and determine the BEST tool to use.
15
 
16
+ TOOL SELECTION RULES:
17
+ 1. SEARCH: Use for ANY factual questions about:
18
+ - People (births, deaths, ages, achievements, relationships)
19
+ - Events (dates, locations, participants, outcomes)
20
+ - Places (locations, populations, geography)
21
+ - Current information (weather, news, prices)
22
+ - Specific facts requiring recent or detailed information
23
+ - Questions with numbers, dates, or statistics about real things
24
 
25
+ 2. CALCULATE: Use ONLY for pure mathematical expressions that can be evaluated
26
+ - Basic arithmetic (23 * 6 + 3)
27
+ - Percentages (15% of 250)
28
+ - Unit conversions with clear numbers
29
+ - Mathematical formulas
30
 
31
+ 3. WIKIPEDIA: Use for general knowledge topics that need comprehensive overview
32
+ - Historical events or periods
33
+ - Scientific concepts
34
+ - Geographic locations
35
+ - Famous people (when general info is needed)
 
 
 
36
 
37
+ 4. DEFINE: Use ONLY when asked for the definition of a single word
38
+ - "What does X mean?"
39
+ - "Define X"
40
+ - Single vocabulary words
 
 
41
 
42
+ 5. REVERSE: Use ONLY when explicitly asked to reverse text
 
43
 
44
+ 6. DIRECT: Use ONLY for:
45
+ - Greetings ("Hello", "Hi")
46
+ - Meta questions about the assistant
47
+ - Questions that are clearly unanswerable
48
 
49
+ IMPORTANT PATTERNS:
50
+ - "How many..." β†’ Usually SEARCH (unless pure math)
51
+ - "Who is..." β†’ WIKIPEDIA or SEARCH
52
+ - "When did..." β†’ SEARCH
53
+ - "Where is..." β†’ SEARCH
54
+ - "What is the [statistic/number]..." β†’ SEARCH
55
+ - "Calculate..." β†’ CALCULATE
56
+ - Names of people/places/things β†’ SEARCH or WIKIPEDIA
57
 
58
+ RESPONSE FORMAT: Respond with EXACTLY one of:
59
+ - "SEARCH: [exact search query]"
60
+ - "CALCULATE: [mathematical expression]"
61
+ - "WIKIPEDIA: [topic]"
62
+ - "DEFINE: [word]"
63
+ - "REVERSE: [text]"
64
+ - "DIRECT: [answer]"
65
+
66
+ Extract the most relevant query from the question. Be specific and include key terms.""")
67
 
68
  def planner_node(state: MessagesState):
69
+ messages = state["messages"]
70
+
71
+ # Get the last human message
72
+ question = None
73
+ for msg in reversed(messages):
74
+ if isinstance(msg, HumanMessage):
75
+ question = msg.content
76
+ break
77
+
78
+ if not question:
79
+ return {"messages": [AIMessage(content="DIRECT: UNKNOWN")]}
80
+
81
+ # Quick pattern matching for common cases
82
+ question_lower = question.lower()
83
+
84
+ # Mathematical calculations
85
+ if any(op in question for op in ['*', '+', '-', '/', '^']) or \
86
+ re.search(r'\d+\s*[xοΏ½οΏ½]\s*\d+', question) or \
87
+ re.search(r'\d+%\s+of\s+\d+', question_lower) or \
88
+ 'calculate' in question_lower and not 'how many' in question_lower:
89
+ # Extract the mathematical expression
90
+ expr = question
91
+ for remove in ['calculate', 'what is', 'what\'s', '?', 'equals']:
92
+ expr = expr.lower().replace(remove, '')
93
+ expr = expr.strip()
94
+ return {"messages": [AIMessage(content=f"CALCULATE: {expr}")]}
95
+
96
+ # Definitions
97
+ if question_lower.startswith(('define ', 'what does ')) and ' mean' in question_lower:
98
+ word = re.search(r'(?:define |what does )(\w+)', question_lower)
99
+ if word:
100
+ return {"messages": [AIMessage(content=f"DEFINE: {word.group(1)}")]}
101
+
102
+ # Text reversal
103
+ if 'reverse' in question_lower:
104
+ # Extract text to reverse
105
+ match = re.search(r'reverse[:\s]+["\']?(.+?)["\']?$', question, re.IGNORECASE)
106
+ if match:
107
+ return {"messages": [AIMessage(content=f"REVERSE: {match.group(1).strip()}")]}
108
+
109
+ # For most factual questions, use search
110
+ factual_indicators = [
111
+ 'how many', 'how much', 'how old', 'when did', 'when was',
112
+ 'where is', 'where was', 'who is', 'who was', 'what year',
113
+ 'which', 'name of', 'number of', 'amount of', 'age of',
114
+ 'population', 'capital', 'president', 'founded', 'created',
115
+ 'discovered', 'invented', 'released', 'published', 'born',
116
+ 'died', 'location', 'situated', 'temperature', 'weather',
117
+ 'price', 'cost', 'worth', 'value', 'rate'
118
+ ]
119
+
120
+ if any(indicator in question_lower for indicator in factual_indicators):
121
+ return {"messages": [AIMessage(content=f"SEARCH: {question}")]}
122
+
123
+ # Use planner LLM for complex cases
124
+ messages_dict = [
125
+ {"role": "system", "content": planner_prompt.content},
126
+ {"role": "user", "content": question}
127
+ ]
128
 
129
+ try:
130
+ response = client.chat.completions.create(
131
+ model="meta-llama/Meta-Llama-3-70B-Instruct",
132
+ messages=messages_dict,
133
+ max_tokens=100,
134
+ temperature=0.1
135
+ )
136
+
137
+ plan = response.choices[0].message.content.strip()
138
+ print(f"Question: {question}")
139
+ print(f"Planner output: {plan}")
140
+
141
+ return {"messages": [AIMessage(content=plan)]}
142
+
143
+ except Exception as e:
144
+ print(f"Planner error: {e}")
145
+ # Default to search for errors
146
+ return {"messages": [AIMessage(content=f"SEARCH: {question}")]}
147
 
148
+ def extract_query_from_plan(plan: str, original_question: str):
149
+ """Extract the query/expression from the planner output"""
150
+ if ":" in plan:
151
+ parts = plan.split(":", 1)
152
+ if len(parts) == 2:
153
+ query = parts[1].strip()
154
+ # Remove quotes if present
155
+ query = query.strip("'\"")
156
+ return query
157
+
158
+ # Fallback to original question
159
+ return original_question
160
 
161
+ def tool_calling_node(state: MessagesState):
162
+ """Call the appropriate tool based on planner decision"""
163
+ messages = state["messages"]
164
+
165
+ # Get planner output
166
+ plan = None
167
+ for msg in reversed(messages):
168
+ if isinstance(msg, AIMessage):
169
+ plan = msg.content
170
+ break
171
+
172
+ # Get original question
173
+ original_question = None
174
+ for msg in messages:
175
+ if isinstance(msg, HumanMessage):
176
+ original_question = msg.content
177
+ break
178
+
179
+ if not plan or not original_question:
180
+ return {"messages": [ToolMessage(content="UNKNOWN", tool_call_id="error")]}
181
+
182
+ plan_upper = plan.upper()
183
+
184
+ try:
185
+ if plan_upper.startswith("SEARCH:"):
186
+ query = extract_query_from_plan(plan, original_question)
187
+ tool = next(t for t in TOOLS if t.name == "web_search")
188
+ result = tool.invoke({"query": query})
189
+
190
+ elif plan_upper.startswith("CALCULATE:"):
191
+ expression = extract_query_from_plan(plan, original_question)
192
+ # Clean up the expression more thoroughly
193
+ expression = expression.replace("Γ—", "*").replace("x", "*").replace("X", "*")
194
+ expression = expression.replace("^", "**")
195
+ expression = expression.replace(",", "")
196
+
197
+ # Handle percentage calculations
198
+ if "%" in expression:
199
+ # Convert "X% of Y" to "Y * X / 100"
200
+ match = re.search(r'(\d+(?:\.\d+)?)\s*%\s*of\s*(\d+(?:\.\d+)?)', expression)
201
+ if match:
202
+ expression = f"{match.group(2)} * {match.group(1)} / 100"
203
+ else:
204
+ expression = expression.replace("%", "/ 100")
205
+
206
+ tool = next(t for t in TOOLS if t.name == "calculate")
207
+ result = tool.invoke({"expression": expression})
208
+
209
+ elif plan_upper.startswith("DEFINE:"):
210
+ term = extract_query_from_plan(plan, original_question)
211
+ term = term.strip("'\"?.,!").lower()
212
+ tool = next(t for t in TOOLS if t.name == "define_term")
213
+ result = tool.invoke({"term": term})
214
+
215
+ elif plan_upper.startswith("WIKIPEDIA:"):
216
+ topic = extract_query_from_plan(plan, original_question)
217
+ tool = next(t for t in TOOLS if t.name == "wikipedia_summary")
218
+ result = tool.invoke({"query": topic})
219
+
220
+ elif plan_upper.startswith("REVERSE:"):
221
+ text = extract_query_from_plan(plan, original_question)
222
+ text = text.strip("'\"")
223
+ tool = next(t for t in TOOLS if t.name == "reverse_text")
224
+ result = tool.invoke({"input": text})
225
+
226
+ elif plan_upper.startswith("DIRECT:"):
227
+ result = extract_query_from_plan(plan, original_question)
228
+
229
+ elif "UNKNOWN" in plan_upper:
230
+ result = "UNKNOWN"
231
+
232
+ else:
233
+ # Fallback: search
234
+ print(f"Unrecognized plan format: {plan}, falling back to search")
235
+ tool = next(t for t in TOOLS if t.name == "web_search")
236
+ result = tool.invoke({"query": original_question})
237
+
238
+ except Exception as e:
239
+ print(f"Tool error: {e}")
240
+ # Try to provide a more specific error or fallback
241
+ if "calculate" in plan_upper:
242
+ result = "Calculation error"
243
+ else:
244
+ result = "UNKNOWN"
245
+
246
+ print(f"Tool result: {result[:200]}...")
247
+ return {"messages": [ToolMessage(content=str(result), tool_call_id="tool_call")]}
248
 
249
+ # Enhanced answer extraction
250
+ answer_prompt = SystemMessage(content="""You are an expert at extracting precise answers from search results and tool outputs.
 
 
 
 
251
 
252
+ CRITICAL RULES:
253
+ 1. Extract the EXACT answer the question is asking for
254
+ 2. For numerical questions, return ONLY the number (no units unless asked)
255
+ 3. For yes/no questions, return ONLY "yes" or "no"
256
+ 4. For counting questions ("how many"), return ONLY the number
257
+ 5. For naming questions, return ONLY the name(s)
258
+ 6. Be as concise as possible - typically 1-10 words
259
+ 7. If the information is clearly not in the tool result, return "UNKNOWN"
 
 
 
 
260
 
261
+ PATTERN MATCHING:
262
+ - "How many..." β†’ Return just the number
263
+ - "What is the name of..." β†’ Return just the name
264
+ - "When did..." β†’ Return just the date/year
265
+ - "Where is..." β†’ Return just the location
266
+ - "Who is/was..." β†’ Return just the name or brief role
267
+ - "Is/Are..." β†’ Return "yes" or "no"
268
 
269
+ IMPORTANT: Look for specific numbers, dates, names, or facts in the tool result that directly answer the question.""")
 
270
 
271
+ def assistant_node(state: MessagesState):
272
+ """Generate final answer based on tool results"""
273
+ messages = state["messages"]
274
+
275
+ # Get original question
276
+ original_question = None
277
+ for msg in messages:
278
+ if isinstance(msg, HumanMessage):
279
+ original_question = msg.content
280
+ break
281
+
282
+ # Get tool result
283
+ tool_result = None
284
+ for msg in reversed(messages):
285
+ if isinstance(msg, ToolMessage):
286
+ tool_result = msg.content
287
+ break
288
+
289
+ if not tool_result or not original_question:
290
+ return {"messages": [AIMessage(content="UNKNOWN")]}
291
+
292
+ # For calculation results, often just return the number
293
+ if "Calculation error" not in tool_result and re.match(r'^-?\d+\.?\d*$', tool_result.strip()):
294
+ return {"messages": [AIMessage(content=tool_result.strip())]}
295
+
296
+ # For simple reversed text, return it directly
297
+ if len(tool_result.split()) == 1 and original_question.lower().startswith('reverse'):
298
+ return {"messages": [AIMessage(content=tool_result)]}
299
+
300
+ # Extract specific patterns from questions
301
+ question_lower = original_question.lower()
302
+
303
+ # Try to extract numbers for "how many" questions
304
+ if "how many" in question_lower and tool_result != "UNKNOWN":
305
+ # Look for numbers in the result
306
+ numbers = re.findall(r'\b\d+\b', tool_result)
307
+ if numbers:
308
+ # Often the first prominent number is the answer
309
+ for num in numbers:
310
+ # Check if this number is mentioned in context of the question topic
311
+ context_window = 50
312
+ num_index = tool_result.find(num)
313
+ if num_index != -1:
314
+ context = tool_result[max(0, num_index-context_window):num_index+context_window+len(num)]
315
+ # Check if relevant keywords from question appear near the number
316
+ question_keywords = [w for w in question_lower.split() if len(w) > 3 and w not in ['what', 'when', 'where', 'many', 'much']]
317
+ if any(keyword in context.lower() for keyword in question_keywords):
318
+ return {"messages": [AIMessage(content=num)]}
319
+
320
+ # Use LLM for complex extraction
321
+ messages_dict = [
322
+ {"role": "system", "content": answer_prompt.content},
323
+ {"role": "user", "content": f"Question: {original_question}\n\nTool result: {tool_result}\n\nExtract the precise answer:"}
324
+ ]
325
+
326
+ try:
327
+ response = client.chat.completions.create(
328
+ model="meta-llama/Meta-Llama-3-70B-Instruct",
329
+ messages=messages_dict,
330
+ max_tokens=50,
331
+ temperature=0.1
332
+ )
333
+
334
+ answer = response.choices[0].message.content.strip()
335
+
336
+ # Clean up common issues
337
+ answer = answer.replace("Answer:", "").replace("A:", "").strip()
338
+ answer = answer.strip(".")
339
+
340
+ # For yes/no questions, ensure lowercase
341
+ if answer.lower() in ['yes', 'no']:
342
+ answer = answer.lower()
343
+
344
+ print(f"Final answer: {answer}")
345
+ return {"messages": [AIMessage(content=answer)]}
346
+
347
+ except Exception as e:
348
+ print(f"Assistant error: {e}")
349
+ return {"messages": [AIMessage(content="UNKNOWN")]}
350
 
351
  def tools_condition(state: MessagesState) -> str:
352
+ """Decide whether to use tools or end"""
353
+ last_msg = state["messages"][-1]
354
+
355
+ if not isinstance(last_msg, AIMessage):
356
+ return "end"
357
+
358
+ content = last_msg.content.upper()
359
+
360
+ # Check if we need to use a tool
361
+ tool_keywords = ["SEARCH:", "CALCULATE:", "DEFINE:", "WIKIPEDIA:", "REVERSE:"]
362
+
363
+ if any(content.startswith(keyword) for keyword in tool_keywords):
364
  return "tools"
365
+
366
+ # For DIRECT answers or UNKNOWN, go straight to assistant to format properly
367
+ if content.startswith("DIRECT:") or "UNKNOWN" in content:
368
+ # Still go through assistant to extract the answer
369
+ return "tools"
370
+
371
  return "end"
372
 
 
 
 
 
 
 
 
 
 
373
  def build_graph():
374
+ """Build the LangGraph workflow"""
375
  builder = StateGraph(MessagesState)
376
+
377
+ # Add nodes
378
  builder.add_node("planner", planner_node)
379
+ builder.add_node("tools", tool_calling_node)
380
  builder.add_node("assistant", assistant_node)
381
+
382
+ # Add edges
383
  builder.add_edge(START, "planner")
384
  builder.add_conditional_edges("planner", tools_condition)
385
  builder.add_edge("tools", "assistant")
386
+
387
  return builder.compile()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ requests
3
+ gradio[oauth]