Spaces:
Sleeping
Sleeping
Arbnor Tefiki
commited on
Commit
Β·
8ecb1cd
1
Parent(s):
94b3868
one percent accuracy
Browse files- .gitignore +1 -0
- app.py +81 -26
- custom_tools.py +198 -25
- functions.py +346 -99
- requirements.txt +3 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__/
|
app.py
CHANGED
@@ -5,6 +5,8 @@ import pandas as pd
|
|
5 |
from dotenv import load_dotenv
|
6 |
from functions import *
|
7 |
from langchain_core.messages import HumanMessage
|
|
|
|
|
8 |
|
9 |
load_dotenv()
|
10 |
|
@@ -49,60 +51,106 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
49 |
results_log = []
|
50 |
answers_payload = []
|
51 |
|
|
|
52 |
print(f"Running agent on {len(questions_data)} questions...")
|
53 |
-
|
|
|
|
|
54 |
task_id = item.get("task_id")
|
55 |
question_text = item.get("question")
|
56 |
if not task_id or question_text is None:
|
57 |
print(f"Skipping item with missing task_id or question: {item}")
|
58 |
continue
|
|
|
|
|
|
|
|
|
|
|
59 |
try:
|
|
|
|
|
60 |
input_messages = [HumanMessage(content=question_text)]
|
61 |
-
|
|
|
62 |
result = agent({"messages": input_messages})
|
63 |
-
|
|
|
|
|
64 |
if "messages" in result and result["messages"]:
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
answers_payload.append({"task_id": task_id, "submitted_answer": answer})
|
78 |
results_log.append({
|
79 |
"Task ID": task_id,
|
80 |
-
"Question": question_text,
|
81 |
-
"Submitted Answer": answer
|
|
|
82 |
})
|
|
|
83 |
except Exception as e:
|
84 |
print(f"Error running agent on task {task_id}: {e}")
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
if not answers_payload:
|
88 |
print("Agent did not produce any answers to submit.")
|
89 |
return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
|
92 |
-
print(f"
|
93 |
|
94 |
try:
|
95 |
response = requests.post(submit_url, json=submission_data, timeout=60)
|
96 |
response.raise_for_status()
|
97 |
result_data = response.json()
|
|
|
|
|
|
|
|
|
|
|
98 |
final_status = (
|
99 |
f"Submission Successful!\n"
|
100 |
f"User: {result_data.get('username')}\n"
|
101 |
-
f"Overall Score: {
|
102 |
-
f"({
|
103 |
f"Message: {result_data.get('message', 'No message received.')}"
|
104 |
)
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
results_df = pd.DataFrame(results_log)
|
107 |
return final_status, results_df
|
108 |
except Exception as e:
|
@@ -113,10 +161,17 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
113 |
|
114 |
# Gradio UI
|
115 |
with gr.Blocks() as demo:
|
116 |
-
gr.Markdown("#
|
117 |
gr.Markdown(
|
118 |
"""
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
"""
|
121 |
)
|
122 |
|
@@ -154,5 +209,5 @@ if __name__ == "__main__":
|
|
154 |
|
155 |
print("-"*(60 + len(" App Starting ")) + "\n")
|
156 |
|
157 |
-
print("Launching Gradio Interface for
|
158 |
-
demo.launch(debug=True, share=False)
|
|
|
5 |
from dotenv import load_dotenv
|
6 |
from functions import *
|
7 |
from langchain_core.messages import HumanMessage
|
8 |
+
import traceback
|
9 |
+
import time
|
10 |
|
11 |
load_dotenv()
|
12 |
|
|
|
51 |
results_log = []
|
52 |
answers_payload = []
|
53 |
|
54 |
+
print(f"\n{'='*60}")
|
55 |
print(f"Running agent on {len(questions_data)} questions...")
|
56 |
+
print(f"{'='*60}\n")
|
57 |
+
|
58 |
+
for idx, item in enumerate(questions_data, 1):
|
59 |
task_id = item.get("task_id")
|
60 |
question_text = item.get("question")
|
61 |
if not task_id or question_text is None:
|
62 |
print(f"Skipping item with missing task_id or question: {item}")
|
63 |
continue
|
64 |
+
|
65 |
+
print(f"\n--- Question {idx}/{len(questions_data)} ---")
|
66 |
+
print(f"Task ID: {task_id}")
|
67 |
+
print(f"Question: {question_text}")
|
68 |
+
|
69 |
try:
|
70 |
+
# Add timeout for each question
|
71 |
+
start_time = time.time()
|
72 |
input_messages = [HumanMessage(content=question_text)]
|
73 |
+
|
74 |
+
# Invoke the agent with the question
|
75 |
result = agent({"messages": input_messages})
|
76 |
+
|
77 |
+
# Extract the answer from the result
|
78 |
+
answer = "UNKNOWN"
|
79 |
if "messages" in result and result["messages"]:
|
80 |
+
# Look for the last AI message with content
|
81 |
+
for msg in reversed(result["messages"]):
|
82 |
+
if hasattr(msg, "content") and isinstance(msg.content, str) and msg.content.strip():
|
83 |
+
# Skip planner outputs
|
84 |
+
if not any(msg.content.upper().startswith(prefix) for prefix in ["SEARCH:", "CALCULATE:", "DEFINE:", "WIKIPEDIA:", "REVERSE:", "DIRECT:"]):
|
85 |
+
answer = msg.content.strip()
|
86 |
+
break
|
87 |
+
|
88 |
+
elapsed_time = time.time() - start_time
|
89 |
+
print(f"Answer: {answer}")
|
90 |
+
print(f"Time taken: {elapsed_time:.2f}s")
|
91 |
+
|
92 |
answers_payload.append({"task_id": task_id, "submitted_answer": answer})
|
93 |
results_log.append({
|
94 |
"Task ID": task_id,
|
95 |
+
"Question": question_text[:100] + "..." if len(question_text) > 100 else question_text,
|
96 |
+
"Submitted Answer": answer,
|
97 |
+
"Time (s)": f"{elapsed_time:.2f}"
|
98 |
})
|
99 |
+
|
100 |
except Exception as e:
|
101 |
print(f"Error running agent on task {task_id}: {e}")
|
102 |
+
print(f"Traceback: {traceback.format_exc()}")
|
103 |
+
|
104 |
+
# Still submit UNKNOWN for errors
|
105 |
+
answers_payload.append({"task_id": task_id, "submitted_answer": "UNKNOWN"})
|
106 |
+
results_log.append({
|
107 |
+
"Task ID": task_id,
|
108 |
+
"Question": question_text[:100] + "..." if len(question_text) > 100 else question_text,
|
109 |
+
"Submitted Answer": f"ERROR: {str(e)[:50]}",
|
110 |
+
"Time (s)": "N/A"
|
111 |
+
})
|
112 |
+
|
113 |
+
print(f"\n{'='*60}")
|
114 |
+
print(f"Completed processing all questions")
|
115 |
+
print(f"{'='*60}\n")
|
116 |
|
117 |
if not answers_payload:
|
118 |
print("Agent did not produce any answers to submit.")
|
119 |
return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
|
120 |
|
121 |
+
# Summary before submission
|
122 |
+
unknown_count = sum(1 for ans in answers_payload if ans["submitted_answer"] == "UNKNOWN")
|
123 |
+
print(f"\nSummary before submission:")
|
124 |
+
print(f"Total questions: {len(answers_payload)}")
|
125 |
+
print(f"UNKNOWN answers: {unknown_count}")
|
126 |
+
print(f"Attempted answers: {len(answers_payload) - unknown_count}")
|
127 |
+
|
128 |
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
|
129 |
+
print(f"\nSubmitting {len(answers_payload)} answers for user '{username}'...")
|
130 |
|
131 |
try:
|
132 |
response = requests.post(submit_url, json=submission_data, timeout=60)
|
133 |
response.raise_for_status()
|
134 |
result_data = response.json()
|
135 |
+
|
136 |
+
score = result_data.get('score', 0)
|
137 |
+
correct_count = result_data.get('correct_count', 0)
|
138 |
+
total_attempted = result_data.get('total_attempted', 0)
|
139 |
+
|
140 |
final_status = (
|
141 |
f"Submission Successful!\n"
|
142 |
f"User: {result_data.get('username')}\n"
|
143 |
+
f"Overall Score: {score}% "
|
144 |
+
f"({correct_count}/{total_attempted} correct)\n"
|
145 |
f"Message: {result_data.get('message', 'No message received.')}"
|
146 |
)
|
147 |
+
|
148 |
+
print("\n" + "="*60)
|
149 |
+
print("SUBMISSION RESULTS:")
|
150 |
+
print(f"Score: {score}%")
|
151 |
+
print(f"Correct: {correct_count}/{total_attempted}")
|
152 |
+
print("="*60)
|
153 |
+
|
154 |
results_df = pd.DataFrame(results_log)
|
155 |
return final_status, results_df
|
156 |
except Exception as e:
|
|
|
161 |
|
162 |
# Gradio UI
|
163 |
with gr.Blocks() as demo:
|
164 |
+
gr.Markdown("# Enhanced GAIA Agent Evaluation Runner")
|
165 |
gr.Markdown(
|
166 |
"""
|
167 |
+
This enhanced agent is optimized for GAIA benchmark questions with improved:
|
168 |
+
- Planning logic for better tool selection
|
169 |
+
- Search capabilities with more comprehensive results
|
170 |
+
- Mathematical expression parsing
|
171 |
+
- Answer extraction from search results
|
172 |
+
- Error handling and logging
|
173 |
+
|
174 |
+
Target: >50% accuracy on GAIA questions
|
175 |
"""
|
176 |
)
|
177 |
|
|
|
209 |
|
210 |
print("-"*(60 + len(" App Starting ")) + "\n")
|
211 |
|
212 |
+
print("Launching Gradio Interface for Enhanced GAIA Agent Evaluation...")
|
213 |
+
demo.launch(debug=True, share=False)
|
custom_tools.py
CHANGED
@@ -1,19 +1,21 @@
|
|
1 |
import requests
|
2 |
from duckduckgo_search import DDGS
|
3 |
from langchain_core.tools import tool
|
|
|
|
|
4 |
|
5 |
@tool
|
6 |
def reverse_text(input: str) -> str:
|
7 |
"""Reverse the characters in a text or string.
|
8 |
|
9 |
Args:
|
10 |
-
|
11 |
"""
|
12 |
return input[::-1]
|
13 |
|
14 |
@tool
|
15 |
def web_search(query: str) -> str:
|
16 |
-
"""Perform a web search using DuckDuckGo and return
|
17 |
|
18 |
Args:
|
19 |
query: The search query to look up.
|
@@ -21,76 +23,247 @@ def web_search(query: str) -> str:
|
|
21 |
try:
|
22 |
results = []
|
23 |
with DDGS() as ddgs:
|
24 |
-
|
|
|
|
|
|
|
25 |
title = r.get("title", "")
|
26 |
snippet = r.get("body", "")
|
27 |
url = r.get("href", "")
|
|
|
28 |
if title and snippet:
|
29 |
-
|
|
|
|
|
|
|
30 |
if not results:
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
except Exception as e:
|
34 |
return f"Web search error: {e}"
|
35 |
|
36 |
@tool
|
37 |
def calculate(expression: str) -> str:
|
38 |
-
"""Evaluate a
|
39 |
|
40 |
Args:
|
41 |
expression: A string containing the math expression to evaluate.
|
42 |
"""
|
43 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
allowed_names = {
|
45 |
"abs": abs,
|
46 |
"round": round,
|
47 |
"min": min,
|
48 |
"max": max,
|
49 |
"pow": pow,
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
}
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
except Exception as e:
|
54 |
return f"Calculation error: {e}"
|
55 |
|
56 |
@tool
|
57 |
def wikipedia_summary(query: str) -> str:
|
58 |
-
"""Retrieve a summary of a topic from Wikipedia.
|
59 |
|
60 |
Args:
|
61 |
query: The subject or topic to summarize.
|
62 |
"""
|
63 |
try:
|
|
|
|
|
|
|
|
|
|
|
64 |
response = requests.get(
|
65 |
-
f"https://en.wikipedia.org/api/rest_v1/page/summary/{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
)
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
except Exception as e:
|
71 |
return f"Wikipedia error: {e}"
|
72 |
|
73 |
@tool
|
74 |
def define_term(term: str) -> str:
|
75 |
-
"""Provide a dictionary
|
76 |
|
77 |
Args:
|
78 |
term: The word or term to define.
|
79 |
"""
|
80 |
try:
|
|
|
|
|
|
|
|
|
81 |
response = requests.get(
|
82 |
-
f"https://api.dictionaryapi.dev/api/v2/entries/en/{term}",
|
|
|
83 |
)
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
except Exception as e:
|
93 |
return f"Definition error: {e}"
|
94 |
|
95 |
# List of tools to register with your agent
|
96 |
-
TOOLS = [web_search, calculate, wikipedia_summary, define_term, reverse_text]
|
|
|
1 |
import requests
|
2 |
from duckduckgo_search import DDGS
|
3 |
from langchain_core.tools import tool
|
4 |
+
import time
|
5 |
+
import re
|
6 |
|
7 |
@tool
|
8 |
def reverse_text(input: str) -> str:
|
9 |
"""Reverse the characters in a text or string.
|
10 |
|
11 |
Args:
|
12 |
+
input: The text or string to reverse.
|
13 |
"""
|
14 |
return input[::-1]
|
15 |
|
16 |
@tool
|
17 |
def web_search(query: str) -> str:
|
18 |
+
"""Perform a web search using DuckDuckGo and return comprehensive results.
|
19 |
|
20 |
Args:
|
21 |
query: The search query to look up.
|
|
|
23 |
try:
|
24 |
results = []
|
25 |
with DDGS() as ddgs:
|
26 |
+
# Get more results for better coverage
|
27 |
+
search_results = list(ddgs.text(query, max_results=8))
|
28 |
+
|
29 |
+
for r in search_results:
|
30 |
title = r.get("title", "")
|
31 |
snippet = r.get("body", "")
|
32 |
url = r.get("href", "")
|
33 |
+
|
34 |
if title and snippet:
|
35 |
+
# Combine title and snippet for more context
|
36 |
+
full_text = f"{title}. {snippet}"
|
37 |
+
results.append(full_text)
|
38 |
+
|
39 |
if not results:
|
40 |
+
# Try with modified query
|
41 |
+
time.sleep(0.5)
|
42 |
+
with DDGS() as ddgs:
|
43 |
+
# Add more context to the query
|
44 |
+
modified_query = f"{query} facts information details"
|
45 |
+
search_results = list(ddgs.text(modified_query, max_results=5))
|
46 |
+
|
47 |
+
for r in search_results:
|
48 |
+
title = r.get("title", "")
|
49 |
+
snippet = r.get("body", "")
|
50 |
+
if title and snippet:
|
51 |
+
results.append(f"{title}. {snippet}")
|
52 |
+
|
53 |
+
if not results:
|
54 |
+
return "No search results found."
|
55 |
+
|
56 |
+
# Join all results with clear separation
|
57 |
+
return "\n\n".join(results)
|
58 |
+
|
59 |
except Exception as e:
|
60 |
return f"Web search error: {e}"
|
61 |
|
62 |
@tool
|
63 |
def calculate(expression: str) -> str:
|
64 |
+
"""Evaluate a mathematical expression and return the result.
|
65 |
|
66 |
Args:
|
67 |
expression: A string containing the math expression to evaluate.
|
68 |
"""
|
69 |
try:
|
70 |
+
# Clean the expression more thoroughly
|
71 |
+
expression = expression.strip()
|
72 |
+
|
73 |
+
# Handle various multiplication notations
|
74 |
+
expression = expression.replace("Γ", "*")
|
75 |
+
expression = expression.replace("x", "*")
|
76 |
+
expression = expression.replace("X", "*")
|
77 |
+
|
78 |
+
# Handle exponents
|
79 |
+
expression = expression.replace("^", "**")
|
80 |
+
|
81 |
+
# Remove thousands separators
|
82 |
+
expression = expression.replace(",", "")
|
83 |
+
|
84 |
+
# Handle parentheses
|
85 |
+
expression = expression.replace("[", "(").replace("]", ")")
|
86 |
+
expression = expression.replace("{", "(").replace("}", ")")
|
87 |
+
|
88 |
+
# Handle percentage calculations
|
89 |
+
# Convert "X% of Y" to "(X/100) * Y"
|
90 |
+
percent_pattern = r'(\d+(?:\.\d+)?)\s*%\s*of\s*(\d+(?:\.\d+)?)'
|
91 |
+
expression = re.sub(percent_pattern, r'(\1/100) * \2', expression)
|
92 |
+
|
93 |
+
# Convert standalone percentages
|
94 |
+
expression = re.sub(r'(\d+(?:\.\d+)?)\s*%', r'(\1/100)', expression)
|
95 |
+
|
96 |
+
# Define safe functions and constants
|
97 |
allowed_names = {
|
98 |
"abs": abs,
|
99 |
"round": round,
|
100 |
"min": min,
|
101 |
"max": max,
|
102 |
"pow": pow,
|
103 |
+
"sum": sum,
|
104 |
+
"len": len,
|
105 |
+
"__builtins__": {},
|
106 |
+
# Math constants
|
107 |
+
"pi": 3.14159265359,
|
108 |
+
"e": 2.71828182846,
|
109 |
}
|
110 |
+
|
111 |
+
# Evaluate the expression
|
112 |
+
result = eval(expression, allowed_names)
|
113 |
+
|
114 |
+
# Format the result nicely
|
115 |
+
if isinstance(result, float):
|
116 |
+
# Check if it's a whole number
|
117 |
+
if result.is_integer():
|
118 |
+
return str(int(result))
|
119 |
+
else:
|
120 |
+
# Round to reasonable precision
|
121 |
+
formatted = f"{result:.10f}".rstrip('0').rstrip('.')
|
122 |
+
return formatted
|
123 |
+
else:
|
124 |
+
return str(result)
|
125 |
+
|
126 |
+
except ZeroDivisionError:
|
127 |
+
return "Error: Division by zero"
|
128 |
+
except SyntaxError as e:
|
129 |
+
return f"Syntax error in expression: {e}"
|
130 |
except Exception as e:
|
131 |
return f"Calculation error: {e}"
|
132 |
|
133 |
@tool
|
134 |
def wikipedia_summary(query: str) -> str:
|
135 |
+
"""Retrieve a comprehensive summary of a topic from Wikipedia.
|
136 |
|
137 |
Args:
|
138 |
query: The subject or topic to summarize.
|
139 |
"""
|
140 |
try:
|
141 |
+
# Clean the query
|
142 |
+
query = query.strip()
|
143 |
+
|
144 |
+
# First, try direct API
|
145 |
+
clean_query = query.replace(" ", "_")
|
146 |
response = requests.get(
|
147 |
+
f"https://en.wikipedia.org/api/rest_v1/page/summary/{clean_query}",
|
148 |
+
timeout=10,
|
149 |
+
headers={"User-Agent": "Mozilla/5.0"}
|
150 |
+
)
|
151 |
+
|
152 |
+
if response.status_code == 200:
|
153 |
+
data = response.json()
|
154 |
+
extract = data.get("extract", "")
|
155 |
+
if extract and extract != "No summary found.":
|
156 |
+
title = data.get("title", query)
|
157 |
+
description = data.get("description", "")
|
158 |
+
|
159 |
+
# Get additional details from the full article if needed
|
160 |
+
full_response = requests.get(
|
161 |
+
f"https://en.wikipedia.org/w/api.php",
|
162 |
+
params={
|
163 |
+
"action": "query",
|
164 |
+
"prop": "extracts",
|
165 |
+
"exintro": True,
|
166 |
+
"explaintext": True,
|
167 |
+
"titles": title,
|
168 |
+
"format": "json"
|
169 |
+
},
|
170 |
+
timeout=10
|
171 |
+
)
|
172 |
+
|
173 |
+
result = extract
|
174 |
+
if description and description not in extract:
|
175 |
+
result = f"{description}. {extract}"
|
176 |
+
|
177 |
+
if full_response.status_code == 200:
|
178 |
+
full_data = full_response.json()
|
179 |
+
pages = full_data.get("query", {}).get("pages", {})
|
180 |
+
for page_id, page_info in pages.items():
|
181 |
+
full_extract = page_info.get("extract", "")
|
182 |
+
if full_extract and len(full_extract) > len(result):
|
183 |
+
result = full_extract[:1000] # Limit length
|
184 |
+
|
185 |
+
return result
|
186 |
+
|
187 |
+
# Fallback: Try searching Wikipedia
|
188 |
+
search_response = requests.get(
|
189 |
+
"https://en.wikipedia.org/w/api.php",
|
190 |
+
params={
|
191 |
+
"action": "opensearch",
|
192 |
+
"search": query,
|
193 |
+
"limit": 3,
|
194 |
+
"format": "json"
|
195 |
+
},
|
196 |
+
timeout=10
|
197 |
)
|
198 |
+
|
199 |
+
if search_response.status_code == 200:
|
200 |
+
search_data = search_response.json()
|
201 |
+
if len(search_data) > 1 and search_data[1]:
|
202 |
+
# Try the first result
|
203 |
+
first_result = search_data[1][0]
|
204 |
+
if first_result:
|
205 |
+
return wikipedia_summary(first_result)
|
206 |
+
|
207 |
+
return f"No Wikipedia article found for '{query}'."
|
208 |
+
|
209 |
except Exception as e:
|
210 |
return f"Wikipedia error: {e}"
|
211 |
|
212 |
@tool
|
213 |
def define_term(term: str) -> str:
|
214 |
+
"""Provide a comprehensive dictionary definition of a given term.
|
215 |
|
216 |
Args:
|
217 |
term: The word or term to define.
|
218 |
"""
|
219 |
try:
|
220 |
+
# Clean the term
|
221 |
+
term = term.strip().lower()
|
222 |
+
term = re.sub(r'[^\w\s-]', '', term) # Remove punctuation except hyphens
|
223 |
+
|
224 |
response = requests.get(
|
225 |
+
f"https://api.dictionaryapi.dev/api/v2/entries/en/{term}",
|
226 |
+
timeout=10
|
227 |
)
|
228 |
+
|
229 |
+
if response.status_code == 200:
|
230 |
+
data = response.json()
|
231 |
+
all_definitions = []
|
232 |
+
|
233 |
+
# Collect all definitions with their parts of speech
|
234 |
+
for entry in data:
|
235 |
+
word = entry.get("word", term)
|
236 |
+
meanings = entry.get("meanings", [])
|
237 |
+
|
238 |
+
for meaning in meanings:
|
239 |
+
part_of_speech = meaning.get("partOfSpeech", "")
|
240 |
+
definitions = meaning.get("definitions", [])
|
241 |
+
|
242 |
+
for definition in definitions:
|
243 |
+
def_text = definition.get("definition", "")
|
244 |
+
if def_text:
|
245 |
+
if part_of_speech:
|
246 |
+
all_definitions.append(f"({part_of_speech}) {def_text}")
|
247 |
+
else:
|
248 |
+
all_definitions.append(def_text)
|
249 |
+
|
250 |
+
if all_definitions:
|
251 |
+
# Return the most comprehensive definition
|
252 |
+
# Prefer longer, more detailed definitions
|
253 |
+
all_definitions.sort(key=len, reverse=True)
|
254 |
+
return all_definitions[0]
|
255 |
+
|
256 |
+
# Try alternative approach - use the error message if it's informative
|
257 |
+
if response.status_code == 404:
|
258 |
+
error_data = response.json()
|
259 |
+
if "message" in error_data:
|
260 |
+
return f"No definition found for '{term}'"
|
261 |
+
|
262 |
+
# Last resort - return a clear message
|
263 |
+
return f"Unable to find definition for '{term}'"
|
264 |
+
|
265 |
except Exception as e:
|
266 |
return f"Definition error: {e}"
|
267 |
|
268 |
# List of tools to register with your agent
|
269 |
+
TOOLS = [web_search, calculate, wikipedia_summary, define_term, reverse_text]
|
functions.py
CHANGED
@@ -1,140 +1,387 @@
|
|
1 |
import os
|
2 |
import re
|
|
|
3 |
from langgraph.graph import START, StateGraph, MessagesState
|
4 |
from langgraph.prebuilt import ToolNode
|
5 |
-
from langchain_core.messages import HumanMessage, SystemMessage
|
6 |
from huggingface_hub import InferenceClient
|
7 |
from custom_tools import TOOLS
|
8 |
-
from langchain_core.messages import AIMessage
|
9 |
|
10 |
HF_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")
|
11 |
client = InferenceClient(token=HF_TOKEN)
|
12 |
|
13 |
-
|
14 |
-
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
-If the questions asks for backwards pronounciation or reversing text, say:
|
30 |
-
I need to reverse text.
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
Q: How many studio albums did Mercedes Sosa release between 2000 and 2009?
|
37 |
-
A: I need to search this.
|
38 |
|
39 |
-
|
40 |
-
A: I need to define this.
|
41 |
|
42 |
-
|
43 |
-
|
|
|
|
|
44 |
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
def planner_node(state: MessagesState):
|
53 |
-
|
54 |
-
|
55 |
-
#
|
56 |
-
|
57 |
-
for msg in
|
58 |
-
if isinstance(msg,
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
Use the tool result and the original question to give the final answer.
|
79 |
-
If the tool result is unhelpful or unclear, respond with 'UNKNOWN'.
|
80 |
-
Respond with only the answer β no explanations.
|
81 |
-
""")
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
role = "user"
|
92 |
-
else:
|
93 |
-
raise ValueError(f"Unsupported message type: {type(msg)}")
|
94 |
-
messages_dict.append({"role": role, "content": msg.content})
|
95 |
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
100 |
|
101 |
-
|
102 |
-
print("Final answer output:\n", text)
|
103 |
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
def tools_condition(state: MessagesState) -> str:
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
116 |
return "tools"
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
118 |
return "end"
|
119 |
|
120 |
-
class PatchedToolNode(ToolNode):
|
121 |
-
def invoke(self, state: MessagesState, config) -> dict:
|
122 |
-
result = super().invoke(state)
|
123 |
-
tool_output = result.get("messages", [])[0].content if result.get("messages") else "UNKNOWN"
|
124 |
-
|
125 |
-
# Append tool result as a HumanMessage so assistant sees it
|
126 |
-
new_messages = state["messages"] + [HumanMessage(content=f"Tool result:\n{tool_output}")]
|
127 |
-
return {"messages": new_messages}
|
128 |
-
|
129 |
def build_graph():
|
|
|
130 |
builder = StateGraph(MessagesState)
|
131 |
-
|
|
|
132 |
builder.add_node("planner", planner_node)
|
|
|
133 |
builder.add_node("assistant", assistant_node)
|
134 |
-
|
135 |
-
|
136 |
builder.add_edge(START, "planner")
|
137 |
builder.add_conditional_edges("planner", tools_condition)
|
138 |
builder.add_edge("tools", "assistant")
|
139 |
-
|
140 |
return builder.compile()
|
|
|
1 |
import os
|
2 |
import re
|
3 |
+
import json
|
4 |
from langgraph.graph import START, StateGraph, MessagesState
|
5 |
from langgraph.prebuilt import ToolNode
|
6 |
+
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
|
7 |
from huggingface_hub import InferenceClient
|
8 |
from custom_tools import TOOLS
|
|
|
9 |
|
10 |
HF_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")
|
11 |
client = InferenceClient(token=HF_TOKEN)
|
12 |
|
13 |
+
# Enhanced planner prompt with better instructions
|
14 |
+
planner_prompt = SystemMessage(content="""You are an expert planning assistant for answering factual questions. Your job is to analyze each question and determine the BEST tool to use.
|
15 |
|
16 |
+
TOOL SELECTION RULES:
|
17 |
+
1. SEARCH: Use for ANY factual questions about:
|
18 |
+
- People (births, deaths, ages, achievements, relationships)
|
19 |
+
- Events (dates, locations, participants, outcomes)
|
20 |
+
- Places (locations, populations, geography)
|
21 |
+
- Current information (weather, news, prices)
|
22 |
+
- Specific facts requiring recent or detailed information
|
23 |
+
- Questions with numbers, dates, or statistics about real things
|
24 |
|
25 |
+
2. CALCULATE: Use ONLY for pure mathematical expressions that can be evaluated
|
26 |
+
- Basic arithmetic (23 * 6 + 3)
|
27 |
+
- Percentages (15% of 250)
|
28 |
+
- Unit conversions with clear numbers
|
29 |
+
- Mathematical formulas
|
30 |
|
31 |
+
3. WIKIPEDIA: Use for general knowledge topics that need comprehensive overview
|
32 |
+
- Historical events or periods
|
33 |
+
- Scientific concepts
|
34 |
+
- Geographic locations
|
35 |
+
- Famous people (when general info is needed)
|
|
|
|
|
|
|
36 |
|
37 |
+
4. DEFINE: Use ONLY when asked for the definition of a single word
|
38 |
+
- "What does X mean?"
|
39 |
+
- "Define X"
|
40 |
+
- Single vocabulary words
|
|
|
|
|
41 |
|
42 |
+
5. REVERSE: Use ONLY when explicitly asked to reverse text
|
|
|
43 |
|
44 |
+
6. DIRECT: Use ONLY for:
|
45 |
+
- Greetings ("Hello", "Hi")
|
46 |
+
- Meta questions about the assistant
|
47 |
+
- Questions that are clearly unanswerable
|
48 |
|
49 |
+
IMPORTANT PATTERNS:
|
50 |
+
- "How many..." β Usually SEARCH (unless pure math)
|
51 |
+
- "Who is..." β WIKIPEDIA or SEARCH
|
52 |
+
- "When did..." β SEARCH
|
53 |
+
- "Where is..." β SEARCH
|
54 |
+
- "What is the [statistic/number]..." β SEARCH
|
55 |
+
- "Calculate..." β CALCULATE
|
56 |
+
- Names of people/places/things β SEARCH or WIKIPEDIA
|
57 |
|
58 |
+
RESPONSE FORMAT: Respond with EXACTLY one of:
|
59 |
+
- "SEARCH: [exact search query]"
|
60 |
+
- "CALCULATE: [mathematical expression]"
|
61 |
+
- "WIKIPEDIA: [topic]"
|
62 |
+
- "DEFINE: [word]"
|
63 |
+
- "REVERSE: [text]"
|
64 |
+
- "DIRECT: [answer]"
|
65 |
+
|
66 |
+
Extract the most relevant query from the question. Be specific and include key terms.""")
|
67 |
|
68 |
def planner_node(state: MessagesState):
|
69 |
+
messages = state["messages"]
|
70 |
+
|
71 |
+
# Get the last human message
|
72 |
+
question = None
|
73 |
+
for msg in reversed(messages):
|
74 |
+
if isinstance(msg, HumanMessage):
|
75 |
+
question = msg.content
|
76 |
+
break
|
77 |
+
|
78 |
+
if not question:
|
79 |
+
return {"messages": [AIMessage(content="DIRECT: UNKNOWN")]}
|
80 |
+
|
81 |
+
# Quick pattern matching for common cases
|
82 |
+
question_lower = question.lower()
|
83 |
+
|
84 |
+
# Mathematical calculations
|
85 |
+
if any(op in question for op in ['*', '+', '-', '/', '^']) or \
|
86 |
+
re.search(r'\d+\s*[xοΏ½οΏ½]\s*\d+', question) or \
|
87 |
+
re.search(r'\d+%\s+of\s+\d+', question_lower) or \
|
88 |
+
'calculate' in question_lower and not 'how many' in question_lower:
|
89 |
+
# Extract the mathematical expression
|
90 |
+
expr = question
|
91 |
+
for remove in ['calculate', 'what is', 'what\'s', '?', 'equals']:
|
92 |
+
expr = expr.lower().replace(remove, '')
|
93 |
+
expr = expr.strip()
|
94 |
+
return {"messages": [AIMessage(content=f"CALCULATE: {expr}")]}
|
95 |
+
|
96 |
+
# Definitions
|
97 |
+
if question_lower.startswith(('define ', 'what does ')) and ' mean' in question_lower:
|
98 |
+
word = re.search(r'(?:define |what does )(\w+)', question_lower)
|
99 |
+
if word:
|
100 |
+
return {"messages": [AIMessage(content=f"DEFINE: {word.group(1)}")]}
|
101 |
+
|
102 |
+
# Text reversal
|
103 |
+
if 'reverse' in question_lower:
|
104 |
+
# Extract text to reverse
|
105 |
+
match = re.search(r'reverse[:\s]+["\']?(.+?)["\']?$', question, re.IGNORECASE)
|
106 |
+
if match:
|
107 |
+
return {"messages": [AIMessage(content=f"REVERSE: {match.group(1).strip()}")]}
|
108 |
+
|
109 |
+
# For most factual questions, use search
|
110 |
+
factual_indicators = [
|
111 |
+
'how many', 'how much', 'how old', 'when did', 'when was',
|
112 |
+
'where is', 'where was', 'who is', 'who was', 'what year',
|
113 |
+
'which', 'name of', 'number of', 'amount of', 'age of',
|
114 |
+
'population', 'capital', 'president', 'founded', 'created',
|
115 |
+
'discovered', 'invented', 'released', 'published', 'born',
|
116 |
+
'died', 'location', 'situated', 'temperature', 'weather',
|
117 |
+
'price', 'cost', 'worth', 'value', 'rate'
|
118 |
+
]
|
119 |
+
|
120 |
+
if any(indicator in question_lower for indicator in factual_indicators):
|
121 |
+
return {"messages": [AIMessage(content=f"SEARCH: {question}")]}
|
122 |
+
|
123 |
+
# Use planner LLM for complex cases
|
124 |
+
messages_dict = [
|
125 |
+
{"role": "system", "content": planner_prompt.content},
|
126 |
+
{"role": "user", "content": question}
|
127 |
+
]
|
128 |
|
129 |
+
try:
|
130 |
+
response = client.chat.completions.create(
|
131 |
+
model="meta-llama/Meta-Llama-3-70B-Instruct",
|
132 |
+
messages=messages_dict,
|
133 |
+
max_tokens=100,
|
134 |
+
temperature=0.1
|
135 |
+
)
|
136 |
+
|
137 |
+
plan = response.choices[0].message.content.strip()
|
138 |
+
print(f"Question: {question}")
|
139 |
+
print(f"Planner output: {plan}")
|
140 |
+
|
141 |
+
return {"messages": [AIMessage(content=plan)]}
|
142 |
+
|
143 |
+
except Exception as e:
|
144 |
+
print(f"Planner error: {e}")
|
145 |
+
# Default to search for errors
|
146 |
+
return {"messages": [AIMessage(content=f"SEARCH: {question}")]}
|
147 |
|
148 |
+
def extract_query_from_plan(plan: str, original_question: str):
|
149 |
+
"""Extract the query/expression from the planner output"""
|
150 |
+
if ":" in plan:
|
151 |
+
parts = plan.split(":", 1)
|
152 |
+
if len(parts) == 2:
|
153 |
+
query = parts[1].strip()
|
154 |
+
# Remove quotes if present
|
155 |
+
query = query.strip("'\"")
|
156 |
+
return query
|
157 |
+
|
158 |
+
# Fallback to original question
|
159 |
+
return original_question
|
160 |
|
161 |
+
def tool_calling_node(state: MessagesState):
|
162 |
+
"""Call the appropriate tool based on planner decision"""
|
163 |
+
messages = state["messages"]
|
164 |
+
|
165 |
+
# Get planner output
|
166 |
+
plan = None
|
167 |
+
for msg in reversed(messages):
|
168 |
+
if isinstance(msg, AIMessage):
|
169 |
+
plan = msg.content
|
170 |
+
break
|
171 |
+
|
172 |
+
# Get original question
|
173 |
+
original_question = None
|
174 |
+
for msg in messages:
|
175 |
+
if isinstance(msg, HumanMessage):
|
176 |
+
original_question = msg.content
|
177 |
+
break
|
178 |
+
|
179 |
+
if not plan or not original_question:
|
180 |
+
return {"messages": [ToolMessage(content="UNKNOWN", tool_call_id="error")]}
|
181 |
+
|
182 |
+
plan_upper = plan.upper()
|
183 |
+
|
184 |
+
try:
|
185 |
+
if plan_upper.startswith("SEARCH:"):
|
186 |
+
query = extract_query_from_plan(plan, original_question)
|
187 |
+
tool = next(t for t in TOOLS if t.name == "web_search")
|
188 |
+
result = tool.invoke({"query": query})
|
189 |
+
|
190 |
+
elif plan_upper.startswith("CALCULATE:"):
|
191 |
+
expression = extract_query_from_plan(plan, original_question)
|
192 |
+
# Clean up the expression more thoroughly
|
193 |
+
expression = expression.replace("Γ", "*").replace("x", "*").replace("X", "*")
|
194 |
+
expression = expression.replace("^", "**")
|
195 |
+
expression = expression.replace(",", "")
|
196 |
+
|
197 |
+
# Handle percentage calculations
|
198 |
+
if "%" in expression:
|
199 |
+
# Convert "X% of Y" to "Y * X / 100"
|
200 |
+
match = re.search(r'(\d+(?:\.\d+)?)\s*%\s*of\s*(\d+(?:\.\d+)?)', expression)
|
201 |
+
if match:
|
202 |
+
expression = f"{match.group(2)} * {match.group(1)} / 100"
|
203 |
+
else:
|
204 |
+
expression = expression.replace("%", "/ 100")
|
205 |
+
|
206 |
+
tool = next(t for t in TOOLS if t.name == "calculate")
|
207 |
+
result = tool.invoke({"expression": expression})
|
208 |
+
|
209 |
+
elif plan_upper.startswith("DEFINE:"):
|
210 |
+
term = extract_query_from_plan(plan, original_question)
|
211 |
+
term = term.strip("'\"?.,!").lower()
|
212 |
+
tool = next(t for t in TOOLS if t.name == "define_term")
|
213 |
+
result = tool.invoke({"term": term})
|
214 |
+
|
215 |
+
elif plan_upper.startswith("WIKIPEDIA:"):
|
216 |
+
topic = extract_query_from_plan(plan, original_question)
|
217 |
+
tool = next(t for t in TOOLS if t.name == "wikipedia_summary")
|
218 |
+
result = tool.invoke({"query": topic})
|
219 |
+
|
220 |
+
elif plan_upper.startswith("REVERSE:"):
|
221 |
+
text = extract_query_from_plan(plan, original_question)
|
222 |
+
text = text.strip("'\"")
|
223 |
+
tool = next(t for t in TOOLS if t.name == "reverse_text")
|
224 |
+
result = tool.invoke({"input": text})
|
225 |
+
|
226 |
+
elif plan_upper.startswith("DIRECT:"):
|
227 |
+
result = extract_query_from_plan(plan, original_question)
|
228 |
+
|
229 |
+
elif "UNKNOWN" in plan_upper:
|
230 |
+
result = "UNKNOWN"
|
231 |
+
|
232 |
+
else:
|
233 |
+
# Fallback: search
|
234 |
+
print(f"Unrecognized plan format: {plan}, falling back to search")
|
235 |
+
tool = next(t for t in TOOLS if t.name == "web_search")
|
236 |
+
result = tool.invoke({"query": original_question})
|
237 |
+
|
238 |
+
except Exception as e:
|
239 |
+
print(f"Tool error: {e}")
|
240 |
+
# Try to provide a more specific error or fallback
|
241 |
+
if "calculate" in plan_upper:
|
242 |
+
result = "Calculation error"
|
243 |
+
else:
|
244 |
+
result = "UNKNOWN"
|
245 |
+
|
246 |
+
print(f"Tool result: {result[:200]}...")
|
247 |
+
return {"messages": [ToolMessage(content=str(result), tool_call_id="tool_call")]}
|
248 |
|
249 |
+
# Enhanced answer extraction
|
250 |
+
answer_prompt = SystemMessage(content="""You are an expert at extracting precise answers from search results and tool outputs.
|
|
|
|
|
|
|
|
|
251 |
|
252 |
+
CRITICAL RULES:
|
253 |
+
1. Extract the EXACT answer the question is asking for
|
254 |
+
2. For numerical questions, return ONLY the number (no units unless asked)
|
255 |
+
3. For yes/no questions, return ONLY "yes" or "no"
|
256 |
+
4. For counting questions ("how many"), return ONLY the number
|
257 |
+
5. For naming questions, return ONLY the name(s)
|
258 |
+
6. Be as concise as possible - typically 1-10 words
|
259 |
+
7. If the information is clearly not in the tool result, return "UNKNOWN"
|
|
|
|
|
|
|
|
|
260 |
|
261 |
+
PATTERN MATCHING:
|
262 |
+
- "How many..." β Return just the number
|
263 |
+
- "What is the name of..." β Return just the name
|
264 |
+
- "When did..." β Return just the date/year
|
265 |
+
- "Where is..." β Return just the location
|
266 |
+
- "Who is/was..." β Return just the name or brief role
|
267 |
+
- "Is/Are..." β Return "yes" or "no"
|
268 |
|
269 |
+
IMPORTANT: Look for specific numbers, dates, names, or facts in the tool result that directly answer the question.""")
|
|
|
270 |
|
271 |
+
def assistant_node(state: MessagesState):
|
272 |
+
"""Generate final answer based on tool results"""
|
273 |
+
messages = state["messages"]
|
274 |
+
|
275 |
+
# Get original question
|
276 |
+
original_question = None
|
277 |
+
for msg in messages:
|
278 |
+
if isinstance(msg, HumanMessage):
|
279 |
+
original_question = msg.content
|
280 |
+
break
|
281 |
+
|
282 |
+
# Get tool result
|
283 |
+
tool_result = None
|
284 |
+
for msg in reversed(messages):
|
285 |
+
if isinstance(msg, ToolMessage):
|
286 |
+
tool_result = msg.content
|
287 |
+
break
|
288 |
+
|
289 |
+
if not tool_result or not original_question:
|
290 |
+
return {"messages": [AIMessage(content="UNKNOWN")]}
|
291 |
+
|
292 |
+
# For calculation results, often just return the number
|
293 |
+
if "Calculation error" not in tool_result and re.match(r'^-?\d+\.?\d*$', tool_result.strip()):
|
294 |
+
return {"messages": [AIMessage(content=tool_result.strip())]}
|
295 |
+
|
296 |
+
# For simple reversed text, return it directly
|
297 |
+
if len(tool_result.split()) == 1 and original_question.lower().startswith('reverse'):
|
298 |
+
return {"messages": [AIMessage(content=tool_result)]}
|
299 |
+
|
300 |
+
# Extract specific patterns from questions
|
301 |
+
question_lower = original_question.lower()
|
302 |
+
|
303 |
+
# Try to extract numbers for "how many" questions
|
304 |
+
if "how many" in question_lower and tool_result != "UNKNOWN":
|
305 |
+
# Look for numbers in the result
|
306 |
+
numbers = re.findall(r'\b\d+\b', tool_result)
|
307 |
+
if numbers:
|
308 |
+
# Often the first prominent number is the answer
|
309 |
+
for num in numbers:
|
310 |
+
# Check if this number is mentioned in context of the question topic
|
311 |
+
context_window = 50
|
312 |
+
num_index = tool_result.find(num)
|
313 |
+
if num_index != -1:
|
314 |
+
context = tool_result[max(0, num_index-context_window):num_index+context_window+len(num)]
|
315 |
+
# Check if relevant keywords from question appear near the number
|
316 |
+
question_keywords = [w for w in question_lower.split() if len(w) > 3 and w not in ['what', 'when', 'where', 'many', 'much']]
|
317 |
+
if any(keyword in context.lower() for keyword in question_keywords):
|
318 |
+
return {"messages": [AIMessage(content=num)]}
|
319 |
+
|
320 |
+
# Use LLM for complex extraction
|
321 |
+
messages_dict = [
|
322 |
+
{"role": "system", "content": answer_prompt.content},
|
323 |
+
{"role": "user", "content": f"Question: {original_question}\n\nTool result: {tool_result}\n\nExtract the precise answer:"}
|
324 |
+
]
|
325 |
+
|
326 |
+
try:
|
327 |
+
response = client.chat.completions.create(
|
328 |
+
model="meta-llama/Meta-Llama-3-70B-Instruct",
|
329 |
+
messages=messages_dict,
|
330 |
+
max_tokens=50,
|
331 |
+
temperature=0.1
|
332 |
+
)
|
333 |
+
|
334 |
+
answer = response.choices[0].message.content.strip()
|
335 |
+
|
336 |
+
# Clean up common issues
|
337 |
+
answer = answer.replace("Answer:", "").replace("A:", "").strip()
|
338 |
+
answer = answer.strip(".")
|
339 |
+
|
340 |
+
# For yes/no questions, ensure lowercase
|
341 |
+
if answer.lower() in ['yes', 'no']:
|
342 |
+
answer = answer.lower()
|
343 |
+
|
344 |
+
print(f"Final answer: {answer}")
|
345 |
+
return {"messages": [AIMessage(content=answer)]}
|
346 |
+
|
347 |
+
except Exception as e:
|
348 |
+
print(f"Assistant error: {e}")
|
349 |
+
return {"messages": [AIMessage(content="UNKNOWN")]}
|
350 |
|
351 |
def tools_condition(state: MessagesState) -> str:
|
352 |
+
"""Decide whether to use tools or end"""
|
353 |
+
last_msg = state["messages"][-1]
|
354 |
+
|
355 |
+
if not isinstance(last_msg, AIMessage):
|
356 |
+
return "end"
|
357 |
+
|
358 |
+
content = last_msg.content.upper()
|
359 |
+
|
360 |
+
# Check if we need to use a tool
|
361 |
+
tool_keywords = ["SEARCH:", "CALCULATE:", "DEFINE:", "WIKIPEDIA:", "REVERSE:"]
|
362 |
+
|
363 |
+
if any(content.startswith(keyword) for keyword in tool_keywords):
|
364 |
return "tools"
|
365 |
+
|
366 |
+
# For DIRECT answers or UNKNOWN, go straight to assistant to format properly
|
367 |
+
if content.startswith("DIRECT:") or "UNKNOWN" in content:
|
368 |
+
# Still go through assistant to extract the answer
|
369 |
+
return "tools"
|
370 |
+
|
371 |
return "end"
|
372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
def build_graph():
|
374 |
+
"""Build the LangGraph workflow"""
|
375 |
builder = StateGraph(MessagesState)
|
376 |
+
|
377 |
+
# Add nodes
|
378 |
builder.add_node("planner", planner_node)
|
379 |
+
builder.add_node("tools", tool_calling_node)
|
380 |
builder.add_node("assistant", assistant_node)
|
381 |
+
|
382 |
+
# Add edges
|
383 |
builder.add_edge(START, "planner")
|
384 |
builder.add_conditional_edges("planner", tools_condition)
|
385 |
builder.add_edge("tools", "assistant")
|
386 |
+
|
387 |
return builder.compile()
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
requests
|
3 |
+
gradio[oauth]
|