Update agent.py
Browse files
agent.py
CHANGED
@@ -1,141 +1,90 @@
|
|
1 |
import os
|
2 |
-
import
|
3 |
-
|
4 |
-
from
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
try:
|
8 |
-
with DDGS() as ddg:
|
9 |
-
results = ddg.text(query=query, region="wt-wt", max_results=5)
|
10 |
-
bodies = [r.get('body', '') for r in results if r.get('body')]
|
11 |
-
return "\n".join(bodies[:3])
|
12 |
-
except Exception as e:
|
13 |
-
return f"ERROR: {e}"
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
if not answer:
|
24 |
-
return ""
|
25 |
-
# Remove apologies and anything after
|
26 |
-
answer = re.sub(
|
27 |
-
r'(?i)(I[\' ]?m sorry.*|Unfortunately.*|I cannot.*|I am unable.*|error:.*|no file.*|but.*|however.*|unable to.*|not available.*|if you have access.*|I can\'t.*)',
|
28 |
-
'', answer).strip()
|
29 |
-
# Remove everything after the first period if it's not a list
|
30 |
-
if not ("list" in question or "ingredient" in question or "vegetable" in question):
|
31 |
-
answer = answer.split('\n')[0].split('.')[0]
|
32 |
-
# Remove quotes/brackets
|
33 |
-
answer = answer.strip(' "\'[](),;:')
|
34 |
-
# Only numbers for count questions
|
35 |
-
if re.search(r'how many|number of|albums|at bats|total sales|output', question, re.I):
|
36 |
-
match = re.search(r'(\d+)', answer)
|
37 |
-
if match:
|
38 |
-
return match.group(1)
|
39 |
-
# Only last word for "surname", first for "first name"
|
40 |
-
if "surname" in question:
|
41 |
-
return answer.split()[-1]
|
42 |
-
if "first name" in question:
|
43 |
-
return answer.split()[0]
|
44 |
-
# For code outputs, numbers only
|
45 |
-
if "output" in question and "python" in question:
|
46 |
-
num = re.search(r'(\d+)', answer)
|
47 |
-
return num.group(1) if num else answer
|
48 |
-
# Only country code (3+ uppercase letters or digits)
|
49 |
-
if re.search(r'IOC country code|award number|NASA', question, re.I):
|
50 |
-
code = re.search(r'[A-Z0-9]{3,}', answer)
|
51 |
-
if code:
|
52 |
-
return code.group(0)
|
53 |
-
# For lists: comma-separated, alpha, deduped, merged phrases
|
54 |
-
if "list" in question or "ingredient" in question or "vegetable" in question:
|
55 |
-
items = [x.strip(' "\'') for x in re.split(r'[,\n]', answer) if x.strip()]
|
56 |
-
merged = []
|
57 |
-
skip = False
|
58 |
-
for i, item in enumerate(items):
|
59 |
-
if skip:
|
60 |
-
skip = False
|
61 |
-
continue
|
62 |
-
if i + 1 < len(items) and item in ['sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh', 'bell']:
|
63 |
-
merged.append(f"{item} {items[i+1]}")
|
64 |
-
skip = True
|
65 |
-
else:
|
66 |
-
merged.append(item)
|
67 |
-
merged = [x.lower() for x in merged]
|
68 |
-
merged = sorted(set(merged))
|
69 |
-
return ', '.join(merged)
|
70 |
-
# For chess: algebraic move
|
71 |
-
if "algebraic notation" in question or "chess" in question:
|
72 |
-
move = re.findall(r'[KQRBN]?[a-h]?[1-8]?x?[a-h][1-8][+#]?', answer)
|
73 |
-
if move:
|
74 |
-
return move[-1]
|
75 |
-
return answer.strip(' "\'[](),;:')
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
-
def
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
# Retry if apology/empty/incorrect
|
100 |
-
if not formatted or "sorry" in formatted.lower() or "unable" in formatted.lower():
|
101 |
-
llm_answer2 = self.llm.chat.completions.create(
|
102 |
-
model="gpt-4o",
|
103 |
-
messages=[
|
104 |
-
{"role": "system", "content": "Only answer with the value. No explanation. Do not apologize. Do not begin with 'I'm sorry', 'Unfortunately', or similar."},
|
105 |
-
{"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
|
106 |
-
],
|
107 |
-
temperature=0.0,
|
108 |
-
max_tokens=128,
|
109 |
-
).choices[0].message.content.strip()
|
110 |
-
formatted = format_gaia_answer(llm_answer2, question)
|
111 |
-
return formatted
|
112 |
-
# For code/math output
|
113 |
-
if "output" in question.lower() and "python" in question.lower():
|
114 |
-
code_match = re.search(r'```python(.*?)```', question, re.DOTALL)
|
115 |
-
code = code_match.group(1) if code_match else ""
|
116 |
-
result = eval_python_code(code)
|
117 |
-
return format_gaia_answer(result, question)
|
118 |
-
# For lists/ingredients, always web search and format
|
119 |
-
if "list" in question.lower() or "ingredient" in question.lower() or "vegetable" in question.lower():
|
120 |
-
web_result = duckduckgo_search(question)
|
121 |
-
llm_answer = self.llm.chat.completions.create(
|
122 |
-
model="gpt-4o",
|
123 |
-
messages=[
|
124 |
-
{"role": "system", "content": "You are a research assistant. Based on the web search results and question, answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations or apologies."},
|
125 |
-
{"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
|
126 |
-
],
|
127 |
-
temperature=0.0,
|
128 |
-
max_tokens=256,
|
129 |
-
).choices[0].message.content.strip()
|
130 |
-
return format_gaia_answer(llm_answer, question)
|
131 |
-
# Fallback: strict LLM answer, formatted
|
132 |
-
llm_answer = self.llm.chat.completions.create(
|
133 |
-
model="gpt-4o",
|
134 |
-
messages=[
|
135 |
-
{"role": "system", "content": "You are a research assistant. Answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations or apologies."},
|
136 |
-
{"role": "user", "content": question}
|
137 |
-
],
|
138 |
-
temperature=0.0,
|
139 |
-
max_tokens=128,
|
140 |
-
).choices[0].message.content.strip()
|
141 |
-
return format_gaia_answer(llm_answer, question)
|
|
|
1 |
import os
|
2 |
+
import requests
|
3 |
+
import base64
|
4 |
+
from langchain_openai import ChatOpenAI
|
5 |
+
from langchain_community.tools import DuckDuckGoSearchRun
|
6 |
+
from langchain.agents import initialize_agent, Tool
|
7 |
+
from langchain.agents.agent_types import AgentType
|
8 |
+
from langchain.memory import ConversationBufferMemory
|
9 |
+
from langchain_core.messages import HumanMessage
|
10 |
|
11 |
+
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
class BasicAgent:
|
14 |
+
def __init__(self):
|
15 |
+
print("BasicAgent initialized.")
|
16 |
+
# Use DuckDuckGo Search only
|
17 |
+
tools = [
|
18 |
+
Tool(
|
19 |
+
name="DuckDuckGo Search",
|
20 |
+
func=DuckDuckGoSearchRun().run,
|
21 |
+
description="Use this tool to find factual information or recent events."
|
22 |
+
),
|
23 |
+
Tool(
|
24 |
+
name="Image Analyzer",
|
25 |
+
func=self.describe_image,
|
26 |
+
description="Analyzes and describes what's in an image. Input is an image path."
|
27 |
+
)
|
28 |
+
]
|
29 |
|
30 |
+
memory = ConversationBufferMemory(memory_key="chat_history")
|
31 |
+
self.model = ChatOpenAI(model="gpt-4.1-mini", temperature=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
+
self.agent = initialize_agent(
|
34 |
+
tools=tools,
|
35 |
+
llm=self.model,
|
36 |
+
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
37 |
+
verbose=True,
|
38 |
+
memory=memory
|
39 |
+
)
|
40 |
+
|
41 |
+
def describe_image(self, img_path: str) -> str:
|
42 |
+
all_text = ""
|
43 |
+
try:
|
44 |
+
r = requests.get(img_path, timeout=10)
|
45 |
+
image_bytes = r.content
|
46 |
+
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
47 |
+
message = [
|
48 |
+
HumanMessage(
|
49 |
+
content=[
|
50 |
+
{
|
51 |
+
"type": "text",
|
52 |
+
"text": (
|
53 |
+
"You're a chess assistant. Answer only with the best move in algebraic notation (e.g., Qd1#)."
|
54 |
+
),
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"type": "image_url",
|
58 |
+
"image_url": {
|
59 |
+
"url": f"data:image/png;base64,{image_base64}"
|
60 |
+
},
|
61 |
+
},
|
62 |
+
]
|
63 |
+
)
|
64 |
+
]
|
65 |
+
response = self.model.invoke(message)
|
66 |
+
all_text += response.content + "\n\n"
|
67 |
+
return all_text.strip()
|
68 |
+
except Exception as e:
|
69 |
+
error_msg = f"Error extracting text: {str(e)}"
|
70 |
+
print(error_msg)
|
71 |
+
return ""
|
72 |
|
73 |
+
def fetch_file(self, task_id):
|
74 |
+
try:
|
75 |
+
url = f"{DEFAULT_API_URL}/files/{task_id}"
|
76 |
+
r = requests.get(url, timeout=10)
|
77 |
+
r.raise_for_status()
|
78 |
+
return url, r.content, r.headers.get("Content-Type", "")
|
79 |
+
except:
|
80 |
+
return None, None, None
|
81 |
+
|
82 |
+
def __call__(self, question: str, task_id: str) -> str:
|
83 |
+
print(f"Agent received question (first 50 chars) {task_id}: {question[:50]}...")
|
84 |
+
file_url, file_content, file_type = self.fetch_file(task_id)
|
85 |
+
print(f"Fetched file {file_type}")
|
86 |
+
if file_url is not None:
|
87 |
+
question = f"{question} This task has assigned file with URL: {file_url}"
|
88 |
+
fixed_answer = self.agent.run(question)
|
89 |
+
print(f"Agent returning fixed answer: {fixed_answer}")
|
90 |
+
return fixed_answer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|