dawid-lorek commited on
Commit
06074b9
·
verified ·
1 Parent(s): 0afb7b8

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +84 -135
agent.py CHANGED
@@ -1,141 +1,90 @@
1
  import os
2
- import re
3
- from openai import OpenAI as OpenAIClient
4
- from duckduckgo_search import DDGS
 
 
 
 
 
5
 
6
- def duckduckgo_search(query: str) -> str:
7
- try:
8
- with DDGS() as ddg:
9
- results = ddg.text(query=query, region="wt-wt", max_results=5)
10
- bodies = [r.get('body', '') for r in results if r.get('body')]
11
- return "\n".join(bodies[:3])
12
- except Exception as e:
13
- return f"ERROR: {e}"
14
 
15
- def eval_python_code(code: str) -> str:
16
- try:
17
- return str(eval(code, {"__builtins__": {}}))
18
- except Exception as e:
19
- return f"ERROR: {e}"
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- def format_gaia_answer(answer: str, question: str = "") -> str:
22
- """Strict GAIA output, eliminate apologies, extract only answer value."""
23
- if not answer:
24
- return ""
25
- # Remove apologies and anything after
26
- answer = re.sub(
27
- r'(?i)(I[\' ]?m sorry.*|Unfortunately.*|I cannot.*|I am unable.*|error:.*|no file.*|but.*|however.*|unable to.*|not available.*|if you have access.*|I can\'t.*)',
28
- '', answer).strip()
29
- # Remove everything after the first period if it's not a list
30
- if not ("list" in question or "ingredient" in question or "vegetable" in question):
31
- answer = answer.split('\n')[0].split('.')[0]
32
- # Remove quotes/brackets
33
- answer = answer.strip(' "\'[](),;:')
34
- # Only numbers for count questions
35
- if re.search(r'how many|number of|albums|at bats|total sales|output', question, re.I):
36
- match = re.search(r'(\d+)', answer)
37
- if match:
38
- return match.group(1)
39
- # Only last word for "surname", first for "first name"
40
- if "surname" in question:
41
- return answer.split()[-1]
42
- if "first name" in question:
43
- return answer.split()[0]
44
- # For code outputs, numbers only
45
- if "output" in question and "python" in question:
46
- num = re.search(r'(\d+)', answer)
47
- return num.group(1) if num else answer
48
- # Only country code (3+ uppercase letters or digits)
49
- if re.search(r'IOC country code|award number|NASA', question, re.I):
50
- code = re.search(r'[A-Z0-9]{3,}', answer)
51
- if code:
52
- return code.group(0)
53
- # For lists: comma-separated, alpha, deduped, merged phrases
54
- if "list" in question or "ingredient" in question or "vegetable" in question:
55
- items = [x.strip(' "\'') for x in re.split(r'[,\n]', answer) if x.strip()]
56
- merged = []
57
- skip = False
58
- for i, item in enumerate(items):
59
- if skip:
60
- skip = False
61
- continue
62
- if i + 1 < len(items) and item in ['sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh', 'bell']:
63
- merged.append(f"{item} {items[i+1]}")
64
- skip = True
65
- else:
66
- merged.append(item)
67
- merged = [x.lower() for x in merged]
68
- merged = sorted(set(merged))
69
- return ', '.join(merged)
70
- # For chess: algebraic move
71
- if "algebraic notation" in question or "chess" in question:
72
- move = re.findall(r'[KQRBN]?[a-h]?[1-8]?x?[a-h][1-8][+#]?', answer)
73
- if move:
74
- return move[-1]
75
- return answer.strip(' "\'[](),;:')
76
 
77
- class GaiaAgent:
78
- def __init__(self):
79
- self.llm = OpenAIClient(api_key=os.getenv("OPENAI_API_KEY"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- def __call__(self, question: str, task_id: str = None) -> str:
82
- search_keywords = [
83
- "who", "when", "what", "which", "how many", "number", "name", "albums", "surname", "at bats",
84
- "nasa", "city", "winner", "code", "vegetable", "ingredient", "magda m.", "featured article"
85
- ]
86
- needs_search = any(kw in question.lower() for kw in search_keywords)
87
- if needs_search:
88
- web_result = duckduckgo_search(question)
89
- llm_answer = self.llm.chat.completions.create(
90
- model="gpt-4o",
91
- messages=[
92
- {"role": "system", "content": "You are a research assistant. Based on the web search results and question, answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations or apologies."},
93
- {"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
94
- ],
95
- temperature=0.0,
96
- max_tokens=256,
97
- ).choices[0].message.content.strip()
98
- formatted = format_gaia_answer(llm_answer, question)
99
- # Retry if apology/empty/incorrect
100
- if not formatted or "sorry" in formatted.lower() or "unable" in formatted.lower():
101
- llm_answer2 = self.llm.chat.completions.create(
102
- model="gpt-4o",
103
- messages=[
104
- {"role": "system", "content": "Only answer with the value. No explanation. Do not apologize. Do not begin with 'I'm sorry', 'Unfortunately', or similar."},
105
- {"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
106
- ],
107
- temperature=0.0,
108
- max_tokens=128,
109
- ).choices[0].message.content.strip()
110
- formatted = format_gaia_answer(llm_answer2, question)
111
- return formatted
112
- # For code/math output
113
- if "output" in question.lower() and "python" in question.lower():
114
- code_match = re.search(r'```python(.*?)```', question, re.DOTALL)
115
- code = code_match.group(1) if code_match else ""
116
- result = eval_python_code(code)
117
- return format_gaia_answer(result, question)
118
- # For lists/ingredients, always web search and format
119
- if "list" in question.lower() or "ingredient" in question.lower() or "vegetable" in question.lower():
120
- web_result = duckduckgo_search(question)
121
- llm_answer = self.llm.chat.completions.create(
122
- model="gpt-4o",
123
- messages=[
124
- {"role": "system", "content": "You are a research assistant. Based on the web search results and question, answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations or apologies."},
125
- {"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
126
- ],
127
- temperature=0.0,
128
- max_tokens=256,
129
- ).choices[0].message.content.strip()
130
- return format_gaia_answer(llm_answer, question)
131
- # Fallback: strict LLM answer, formatted
132
- llm_answer = self.llm.chat.completions.create(
133
- model="gpt-4o",
134
- messages=[
135
- {"role": "system", "content": "You are a research assistant. Answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations or apologies."},
136
- {"role": "user", "content": question}
137
- ],
138
- temperature=0.0,
139
- max_tokens=128,
140
- ).choices[0].message.content.strip()
141
- return format_gaia_answer(llm_answer, question)
 
1
  import os
2
+ import requests
3
+ import base64
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_community.tools import DuckDuckGoSearchRun
6
+ from langchain.agents import initialize_agent, Tool
7
+ from langchain.agents.agent_types import AgentType
8
+ from langchain.memory import ConversationBufferMemory
9
+ from langchain_core.messages import HumanMessage
10
 
11
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
 
 
 
 
 
 
12
 
13
+ class BasicAgent:
14
+ def __init__(self):
15
+ print("BasicAgent initialized.")
16
+ # Use DuckDuckGo Search only
17
+ tools = [
18
+ Tool(
19
+ name="DuckDuckGo Search",
20
+ func=DuckDuckGoSearchRun().run,
21
+ description="Use this tool to find factual information or recent events."
22
+ ),
23
+ Tool(
24
+ name="Image Analyzer",
25
+ func=self.describe_image,
26
+ description="Analyzes and describes what's in an image. Input is an image path."
27
+ )
28
+ ]
29
 
30
+ memory = ConversationBufferMemory(memory_key="chat_history")
31
+ self.model = ChatOpenAI(model="gpt-4.1-mini", temperature=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ self.agent = initialize_agent(
34
+ tools=tools,
35
+ llm=self.model,
36
+ agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
37
+ verbose=True,
38
+ memory=memory
39
+ )
40
+
41
+ def describe_image(self, img_path: str) -> str:
42
+ all_text = ""
43
+ try:
44
+ r = requests.get(img_path, timeout=10)
45
+ image_bytes = r.content
46
+ image_base64 = base64.b64encode(image_bytes).decode("utf-8")
47
+ message = [
48
+ HumanMessage(
49
+ content=[
50
+ {
51
+ "type": "text",
52
+ "text": (
53
+ "You're a chess assistant. Answer only with the best move in algebraic notation (e.g., Qd1#)."
54
+ ),
55
+ },
56
+ {
57
+ "type": "image_url",
58
+ "image_url": {
59
+ "url": f"data:image/png;base64,{image_base64}"
60
+ },
61
+ },
62
+ ]
63
+ )
64
+ ]
65
+ response = self.model.invoke(message)
66
+ all_text += response.content + "\n\n"
67
+ return all_text.strip()
68
+ except Exception as e:
69
+ error_msg = f"Error extracting text: {str(e)}"
70
+ print(error_msg)
71
+ return ""
72
 
73
+ def fetch_file(self, task_id):
74
+ try:
75
+ url = f"{DEFAULT_API_URL}/files/{task_id}"
76
+ r = requests.get(url, timeout=10)
77
+ r.raise_for_status()
78
+ return url, r.content, r.headers.get("Content-Type", "")
79
+ except:
80
+ return None, None, None
81
+
82
+ def __call__(self, question: str, task_id: str) -> str:
83
+ print(f"Agent received question (first 50 chars) {task_id}: {question[:50]}...")
84
+ file_url, file_content, file_type = self.fetch_file(task_id)
85
+ print(f"Fetched file {file_type}")
86
+ if file_url is not None:
87
+ question = f"{question} This task has assigned file with URL: {file_url}"
88
+ fixed_answer = self.agent.run(question)
89
+ print(f"Agent returning fixed answer: {fixed_answer}")
90
+ return fixed_answer