dawid-lorek committed on
Commit
14c8db3
·
verified ·
1 Parent(s): 4044d5c

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +62 -10
agent.py CHANGED
@@ -4,7 +4,11 @@ import requests
4
  import tempfile
5
  import pandas as pd
6
  from openai import OpenAI
7
- from duckduckgo_search import DDGS
 
 
 
 
8
 
9
  PROMPT = (
10
  "You are a general AI assistant. I will ask you a question. "
@@ -21,6 +25,8 @@ class BasicAgent:
21
  print("BasicAgent initialized.")
22
 
23
  def web_search(self, query: str, max_results: int = 5) -> str:
 
 
24
  try:
25
  with DDGS() as ddgs:
26
  results = list(ddgs.text(query, max_results=max_results))
@@ -54,6 +60,40 @@ class BasicAgent:
54
  except Exception as e:
55
  return ""
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def fetch_file_url(self, task_id):
58
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
59
  try:
@@ -68,16 +108,31 @@ class BasicAgent:
68
  def __call__(self, question: str, task_id: str = None) -> str:
69
  file_url = self.fetch_file_url(task_id) if task_id else None
70
  file_result = None
71
- # --- Always try Excel tool if file exists ---
 
72
  if file_url:
73
- try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  file_result = self.excel_tool(file_url)
75
- # Only return if it actually found a number
76
  if file_result and re.match(r'^\d+(\.\d+)?$', file_result):
77
  return file_result
78
- except Exception:
79
- pass # fallback below
80
 
 
81
  search_snippet = self.web_search(question)
82
  prompt = PROMPT + f"\n\nWeb search results:\n{search_snippet}\n\nQuestion: {question}"
83
  response = self.llm.chat.completions.create(
@@ -104,16 +159,13 @@ class BasicAgent:
104
  ]
105
  norm_final = (final_line or "").lower()
106
  if norm_final in bads or norm_final.startswith("unable") or norm_final.startswith("i'm sorry") or norm_final.startswith("i apologize"):
107
- # --- Try to extract a plausible answer from web or file ---
108
- # Try: numbers from search
109
  numbers = re.findall(r'\b\d{2,}\b', search_snippet)
110
  if numbers:
111
  return numbers[0]
112
- # Try: possible names from search
113
  words = re.findall(r'\b[A-Z][a-z]{2,}\b', search_snippet)
114
  if words:
115
  return words[0]
116
- # Try: numbers or words from file_result
117
  if file_result:
118
  file_numbers = re.findall(r'\b\d{2,}\b', str(file_result))
119
  if file_numbers:
 
4
  import tempfile
5
  import pandas as pd
6
  from openai import OpenAI
7
+
8
+ try:
9
+ from duckduckgo_search import DDGS
10
+ except ImportError:
11
+ DDGS = None # In case of install problems
12
 
13
  PROMPT = (
14
  "You are a general AI assistant. I will ask you a question. "
 
25
  print("BasicAgent initialized.")
26
 
27
  def web_search(self, query: str, max_results: int = 5) -> str:
28
+ if not DDGS:
29
+ return ""
30
  try:
31
  with DDGS() as ddgs:
32
  results = list(ddgs.text(query, max_results=max_results))
 
60
  except Exception as e:
61
  return ""
62
 
63
def transcribe_audio(self, file_url: str) -> str:
    """Download an audio file and return its Whisper transcription.

    The audio is fetched from ``file_url``, written to a temporary file
    (the transcription endpoint needs a named file with a recognisable
    extension) and sent to the ``whisper-1`` model.

    Args:
        file_url: URL of the audio file to transcribe.

    Returns:
        The transcribed text, or an empty string if the download or the
        transcription fails for any reason (best-effort tool).
    """
    audio_path = None
    try:
        r = requests.get(file_url, timeout=20)
        r.raise_for_status()
        # Only ".wav" is recognised from the URL; anything else is treated
        # as MP3. Compare case-insensitively so "FILE.WAV" is handled too.
        ext = ".wav" if file_url.lower().endswith(".wav") else ".mp3"
        with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as f:
            f.write(r.content)
            audio_path = f.name
        # The module imports the v1 client (`from openai import OpenAI`);
        # the legacy `openai.Audio.transcribe` call is not available there,
        # so go through the client's transcription endpoint instead.
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        with open(audio_path, "rb") as audio_file:
            transcript = client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
            )
        return transcript.text or ""
    except Exception as e:
        # Best-effort: report the failure but let the caller fall back.
        print(f"transcribe_audio failed: {e}")
        return ""
    finally:
        # delete=False means the temp file must be removed explicitly.
        if audio_path:
            try:
                os.remove(audio_path)
            except OSError:
                pass
81
+
82
def execute_python(self, file_url: str) -> str:
    """Download a Python script, run it, and return its final output value.

    The script's stdout is captured; the last non-empty-trimmed line is
    inspected and, if it contains numbers, the last number on that line is
    returned. Otherwise the whole last line is returned.

    Args:
        file_url: URL of the Python source file to execute.

    Returns:
        The last number printed on the final output line (sign preserved),
        the final output line itself if it has no number, or an empty
        string if downloading or executing the script fails.
    """
    import contextlib
    import io

    try:
        r = requests.get(file_url, timeout=20)
        r.raise_for_status()
        code = r.content.decode("utf-8")
        buf = io.StringIO()
        # SECURITY NOTE(review): this executes downloaded code in-process
        # with full privileges. Only acceptable if file_url always points
        # at a trusted source; consider a sandboxed subprocess otherwise.
        with contextlib.redirect_stdout(buf):
            try:
                # Run as a script so `if __name__ == "__main__":` guards fire.
                exec(code, {"__name__": "__main__"})
            except SystemExit:
                # A script ending with sys.exit() should not abort the
                # agent; whatever it printed so far is still usable.
                pass
        output = buf.getvalue().strip().split('\n')[-1]
        # Group the alternatives so the optional sign applies to integers
        # as well as decimals ("-5" must not become "5").
        numbers = re.findall(r'[-+]?(?:\d*\.\d+|\d+)', output)
        return numbers[-1] if numbers else output
    except Exception as e:
        # Best-effort: report the failure but let the caller fall back.
        print(f"execute_python failed: {e}")
        return ""
96
+
97
  def fetch_file_url(self, task_id):
98
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
99
  try:
 
108
  def __call__(self, question: str, task_id: str = None) -> str:
109
  file_url = self.fetch_file_url(task_id) if task_id else None
110
  file_result = None
111
+ ext = file_url.split('.')[-1].lower() if file_url else ""
112
+ # --- Try all known file tools by extension ---
113
  if file_url:
114
+ # Excel: chosen by file extension, or when the question mentions Excel/spreadsheet
115
+ if ext in ["xlsx", "xls"] or "excel" in question.lower() or "spreadsheet" in question.lower():
116
+ file_result = self.excel_tool(file_url)
117
+ if file_result and re.match(r'^\d+(\.\d+)?$', file_result):
118
+ return file_result
119
+ # Audio
120
+ elif ext in ["mp3", "wav"] or "audio" in question.lower() or "transcribe" in question.lower():
121
+ file_result = self.transcribe_audio(file_url)
122
+ if file_result and file_result.strip():
123
+ return file_result
124
+ # Python code
125
+ elif ext == "py":
126
+ file_result = self.execute_python(file_url)
127
+ if file_result and file_result.strip():
128
+ return file_result
129
+ # Fallback: try Excel anyway
130
+ if not file_result:
131
  file_result = self.excel_tool(file_url)
 
132
  if file_result and re.match(r'^\d+(\.\d+)?$', file_result):
133
  return file_result
 
 
134
 
135
+ # --- Web search and LLM as before ---
136
  search_snippet = self.web_search(question)
137
  prompt = PROMPT + f"\n\nWeb search results:\n{search_snippet}\n\nQuestion: {question}"
138
  response = self.llm.chat.completions.create(
 
159
  ]
160
  norm_final = (final_line or "").lower()
161
  if norm_final in bads or norm_final.startswith("unable") or norm_final.startswith("i'm sorry") or norm_final.startswith("i apologize"):
162
+ # Try to extract a plausible answer from web or file
 
163
  numbers = re.findall(r'\b\d{2,}\b', search_snippet)
164
  if numbers:
165
  return numbers[0]
 
166
  words = re.findall(r'\b[A-Z][a-z]{2,}\b', search_snippet)
167
  if words:
168
  return words[0]
 
169
  if file_result:
170
  file_numbers = re.findall(r'\b\d{2,}\b', str(file_result))
171
  if file_numbers: