dawid-lorek commited on
Commit
bd702b9
·
verified ·
1 Parent(s): e6b7fa0

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +51 -147
agent.py CHANGED
@@ -1,148 +1,52 @@
1
- import os
2
- import tempfile
3
- import requests
4
- import re
5
- import pandas as pd
6
-
7
- from langchain_openai import ChatOpenAI
8
- from langchain.agents import initialize_agent, Tool
9
- from langchain.agents.agent_types import AgentType
10
- from langchain_community.tools import DuckDuckGoSearchRun
11
-
12
- # Audio transcription tool (OpenAI Whisper)
13
- def transcribe_audio_tool(file_url: str) -> str:
14
- import openai
15
- openai.api_key = os.getenv("OPENAI_API_KEY")
16
- try:
17
- r = requests.get(file_url, timeout=20)
18
- with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
19
- f.write(r.content)
20
- f.flush()
21
- path = f.name
22
- transcript = openai.Audio.transcribe("whisper-1", open(path, "rb"))
23
- return transcript.get("text", "")
24
- except Exception as e:
25
- return ""
26
-
27
- # Excel reading tool
28
- def read_excel_tool(file_url: str) -> str:
29
- try:
30
- r = requests.get(file_url, timeout=20)
31
- with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as f:
32
- f.write(r.content)
33
- f.flush()
34
- path = f.name
35
- df = pd.read_excel(path)
36
- if 'Type' in df.columns and 'Sales' in df.columns:
37
- total = df[df['Type'].str.lower() == 'food']['Sales'].sum()
38
- return str(round(total, 2))
39
- total = df.select_dtypes(include='number').sum().sum()
40
- return str(round(total, 2))
41
- except Exception as e:
42
- return ""
43
-
44
- # Python code execution tool (caution: only for controlled/tested code!)
45
- def execute_python_tool(code_url: str) -> str:
46
- try:
47
- r = requests.get(code_url, timeout=20)
48
- code = r.content.decode("utf-8")
49
- import io, contextlib
50
- buf = io.StringIO()
51
- with contextlib.redirect_stdout(buf):
52
- exec(code, {})
53
- output = buf.getvalue().strip().split('\n')[-1]
54
- numbers = re.findall(r'[-+]?\d*\.\d+|\d+', output)
55
- return numbers[-1] if numbers else output
56
- except Exception as e:
57
- return ""
58
-
59
- def extract_numbers(text: str) -> str:
60
- nums = re.findall(r'\b\d+\b', text)
61
- return ', '.join(nums) if nums else ""
62
-
63
- def extract_names(text: str) -> str:
64
- words = re.findall(r'\b[A-Z][a-z]{2,}\b', text)
65
- return ', '.join(words) if words else ""
66
-
67
- tools = [
68
- Tool(
69
- name="DuckDuckGo Search",
70
- func=DuckDuckGoSearchRun().run,
71
- description="Use to find factual information or recent events."
72
- ),
73
- Tool(
74
- name="Transcribe Audio",
75
- func=transcribe_audio_tool,
76
- description="Use to transcribe an audio file from a URL (mp3 or wav)."
77
- ),
78
- Tool(
79
- name="Read Excel File",
80
- func=read_excel_tool,
81
- description="Use to read an Excel spreadsheet file from a URL (xlsx) and sum food sales or extract tables."
82
- ),
83
- Tool(
84
- name="Execute Python",
85
- func=execute_python_tool,
86
- description="Use to execute a Python file from a URL and get the final output."
87
- ),
88
- Tool(
89
- name="Extract Numbers",
90
- func=extract_numbers,
91
- description="Use to extract all numbers from provided text."
92
- ),
93
- Tool(
94
- name="Extract Names",
95
- func=extract_names,
96
- description="Use to extract capitalized names from provided text."
97
- )
98
- ]
99
-
100
- PROMPT = (
101
- "You are a general AI assistant. I will ask you a question. "
102
- "Reason step by step, and use tools as needed. "
103
- "Search the web only once per question if needed, then reason further using your tools and the provided information. "
104
- "Only after you are sure, answer with the template: "
105
- "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. "
106
- "If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. "
107
- "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. "
108
- "If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
109
- )
110
-
111
- llm = ChatOpenAI(model="gpt-4o", temperature=0)
112
-
113
- class BasicAgent:
114
- def __init__(self):
115
- self.agent = initialize_agent(
116
- tools=tools,
117
- llm=llm,
118
- agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
119
- verbose=False,
120
- handle_parsing_errors=True,
121
- max_iterations=4, # Prevent endless loops
122
- max_execution_time=60 # Safety timeout in seconds
123
- )
124
- self.prompt = PROMPT
125
-
126
- def fetch_file_url(self, task_id):
127
- DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
128
- try:
129
- url = f"{DEFAULT_API_URL}/files/{task_id}"
130
- r = requests.head(url, timeout=5)
131
- if r.status_code == 200:
132
- return url
133
- except:
134
- pass
135
- return None
136
-
137
- def __call__(self, question: str, task_id: str = None) -> str:
138
- file_url = self.fetch_file_url(task_id) if task_id else None
139
- if file_url:
140
- question_aug = f"{question}\nThis task has assigned file at this URL: {file_url}"
141
- else:
142
- question_aug = question
143
- full_prompt = self.prompt + "\n" + question_aug
144
- result = self.agent.run(full_prompt)
145
- for line in result.splitlines():
146
  if line.strip().lower().startswith("final answer:"):
147
- return line.split(":", 1)[-1].strip(" .\"'")
148
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ answer = response.choices[0].message.content.strip()
2
+ final_line = ""
3
+ for line in answer.splitlines():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  if line.strip().lower().startswith("final answer:"):
5
+ final_line = line.split(":", 1)[-1].strip(" .\"'")
6
+ break
7
+
8
+ bads = [
9
+ "", "unknown", "unable to determine", "unable to provide page numbers",
10
+ "unable to access video content directly", "unable to analyze video content",
11
+ "unable to determine without code", "unable to determine without file",
12
+ "follow the steps to locate the paper and find the nasa award number in the acknowledgment section",
13
+ "i am unable to view images or access external content directly", "unable to determine without access to the file",
14
+ "no results found", "n/a", "[your final answer]", "i'm sorry", "i apologize"
15
+ ]
16
+ if final_line.lower() in bads or final_line.lower().startswith("unable") or final_line.lower().startswith("i'm sorry") or final_line.lower().startswith("i apologize"):
17
+ # --- Try to extract a plausible answer from web or file ---
18
+ # Example: For numbers
19
+ numbers = re.findall(r'\b\d{2,}\b', search_snippet)
20
+ if numbers:
21
+ return numbers[0]
22
+ # Example: For possible names (capitalize words)
23
+ words = re.findall(r'\b[A-Z][a-z]{2,}\b', search_snippet)
24
+ if words:
25
+ return words[0]
26
+ # Example: For Excel, code, or file extraction, return the result
27
+ if file_result:
28
+ file_numbers = re.findall(r'\b\d{2,}\b', file_result)
29
+ if file_numbers:
30
+ return file_numbers[0]
31
+ file_words = re.findall(r'\b[A-Z][a-z]{2,}\b', file_result)
32
+ if file_words:
33
+ return file_words[0]
34
+ # --- Try to re-ask the LLM to answer "without apologies" ---
35
+ retry_prompt = (
36
+ "Based ONLY on the search results and/or file content above, return a direct answer to the question. "
37
+ "If you do not know, make your best plausible guess. Do NOT apologize or say you cannot assist. "
38
+ f"File: {file_result}\n\nWeb: {search_snippet}\n\nQuestion: {question}\nFINAL ANSWER:"
39
+ )
40
+ response2 = self.llm.chat.completions.create(
41
+ model="gpt-4o",
42
+ messages=[{"role": "system", "content": retry_prompt}],
43
+ temperature=0.1,
44
+ max_tokens=128,
45
+ )
46
+ retry_answer = response2.choices[0].message.content.strip()
47
+ for line in retry_answer.splitlines():
48
+ if line.strip().lower().startswith("final answer:"):
49
+ return line.split(":", 1)[-1].strip(" .\"'")
50
+ if retry_answer:
51
+ return retry_answer.strip(" .\"'")
52
+ return final_line