|
import contextlib
import io
import logging
import os
import re
import tempfile

import pandas as pd
import requests
from langchain.agents import initialize_agent, Tool
from langchain.agents.agent_types import AgentType
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_openai import ChatOpenAI
|
|
|
|
|
def transcribe_audio_tool(file_url: str) -> str:
    """Download an audio file and return its Whisper ("whisper-1") transcript.

    Args:
        file_url: HTTP(S) URL of the audio file to transcribe.

    Returns:
        The transcribed text, or an empty string when the download or the
        transcription fails (the failure is logged, not raised).
    """
    import openai

    api_key = os.getenv("OPENAI_API_KEY")
    path = None
    try:
        response = requests.get(file_url, timeout=20)
        # Do not send an HTML error page to the transcription endpoint.
        response.raise_for_status()

        # delete=False so the file can be reopened by name after it is closed;
        # it is removed explicitly in the ``finally`` clause below.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            tmp.write(response.content)
            path = tmp.name

        with open(path, "rb") as audio:
            if hasattr(openai, "OpenAI"):
                # openai>=1.0: the module-level ``openai.Audio`` API was
                # removed; transcription goes through a client instance.
                client = openai.OpenAI(api_key=api_key)
                result = client.audio.transcriptions.create(
                    model="whisper-1", file=audio
                )
                return result.text or ""

            # Legacy SDK (<1.0) code path.
            openai.api_key = api_key
            transcript = openai.Audio.transcribe("whisper-1", audio)
            return transcript.get("text", "")
    except Exception:
        # Best-effort tool: report the problem but never crash the agent.
        logging.getLogger(__name__).exception(
            "Audio transcription failed for %s", file_url
        )
        return ""
    finally:
        if path is not None:
            with contextlib.suppress(OSError):
                os.remove(path)
|
|
|
|
|
def read_excel_tool(file_url: str) -> str:
    """Download a spreadsheet and return a sales total as a string.

    If the first sheet has both a ``Type`` and a ``Sales`` column, the result
    is the sum of ``Sales`` over rows whose ``Type`` equals "food"
    (case-insensitive). Otherwise it is the sum of every numeric cell.

    Args:
        file_url: HTTP(S) URL of the ``.xlsx`` file.

    Returns:
        The total rounded to two decimals, as a string, or an empty string
        when the download or parsing fails (the failure is logged).
    """
    try:
        response = requests.get(file_url, timeout=20)
        response.raise_for_status()

        # Parse from memory: no temporary file is left behind on disk.
        df = pd.read_excel(io.BytesIO(response.content))

        if "Type" in df.columns and "Sales" in df.columns:
            # astype(str) keeps the comparison working when the column holds
            # non-string values; missing cells become "nan" and never match.
            is_food = df["Type"].astype(str).str.lower() == "food"
            total = df.loc[is_food, "Sales"].sum()
        else:
            total = df.select_dtypes(include="number").sum().sum()

        return str(round(total, 2))
    except Exception:
        # Best-effort tool: report the problem but never crash the agent.
        logging.getLogger(__name__).exception(
            "Reading Excel file failed for %s", file_url
        )
        return ""
|
|
|
|
|
def execute_python_tool(code_url: str) -> str:
    """Download a Python script, run it, and return its final output.

    The script's stdout is captured. The result is the last number found on
    the last printed line, or that whole line when it contains no number.

    Args:
        code_url: HTTP(S) URL of the Python source file.

    Returns:
        The extracted number (as text) or the last output line; an empty
        string when the download or execution fails (the failure is logged).
    """
    try:
        response = requests.get(code_url, timeout=20)
        response.raise_for_status()
        code = response.content.decode("utf-8")

        captured = io.StringIO()
        with contextlib.redirect_stdout(captured):
            try:
                # SECURITY: this runs downloaded code with the full privileges
                # of this process. Only point it at trusted URLs, or move the
                # execution into a sandbox.
                # ``__name__`` is set so ``if __name__ == "__main__":`` guards
                # in the script actually execute.
                exec(code, {"__name__": "__main__"})
            except SystemExit:
                # The script called sys.exit(); keep whatever it printed.
                pass

        output = captured.getvalue().strip().split("\n")[-1]

        # The sign applies to integers as well as decimals, but only when it
        # is not glued to a preceding word character (so "2024-01-15" still
        # yields "15", while "-5" yields "-5").
        numbers = re.findall(r"(?:(?<![\w.])[-+])?(?:\d*\.\d+|\d+)", output)
        return numbers[-1] if numbers else output
    except Exception:
        # Best-effort tool: report the problem but never crash the agent.
        logging.getLogger(__name__).exception(
            "Executing Python file failed for %s", code_url
        )
        return ""
|
|
|
|
|
def extract_numbers(text: str) -> str:
    """Return every standalone number in *text* as a comma-separated string.

    Integers and simple decimals ("3.14") are recognised. Digits embedded in
    a word ("a1") are skipped, signs are not captured, and thousands
    separators are not interpreted ("1,000" yields "1, 000").

    Args:
        text: Arbitrary text to scan.

    Returns:
        The numbers in order of appearance joined by ", ", or an empty
        string when none are found.
    """
    # The optional fraction keeps a decimal such as "3.14" in one piece
    # instead of splitting it into "3" and "14".
    nums = re.findall(r"\b\d+(?:\.\d+)?\b", text)
    return ", ".join(nums)
|
|
|
def extract_names(text: str) -> str:
    """Return the capitalised words in *text* as a comma-separated string.

    A word qualifies when it is one uppercase ASCII letter followed by at
    least two lowercase ASCII letters and nothing else (e.g. "Alice", "Bob").
    Sentence-initial words match too, so the result is a candidate list.

    Args:
        text: Arbitrary text to scan.

    Returns:
        The matching words in order of appearance joined by ", ", or an
        empty string when none are found.
    """
    words = re.findall(r"\b[A-Z][a-z]{2,}\b", text)
    return ", ".join(words)
|
|
|
|
|
# Tools handed to the agent built in BasicAgent.__init__. Each entry wraps one
# callable that takes a single string; ``name`` and ``description`` are the
# text supplied to LangChain for that tool.
tools = [
    Tool(
        name="DuckDuckGo Search",
        func=DuckDuckGoSearchRun().run,
        description="Use to find factual information or recent events.",
    ),
    Tool(
        name="Transcribe Audio",
        func=transcribe_audio_tool,
        description="Use to transcribe an audio file from a URL (mp3 or wav).",
    ),
    Tool(
        name="Read Excel File",
        func=read_excel_tool,
        description="Use to read an Excel spreadsheet file from a URL (xlsx) and sum food sales or extract tables.",
    ),
    Tool(
        name="Execute Python",
        func=execute_python_tool,
        description="Use to execute a Python file from a URL and get the final output.",
    ),
    Tool(
        name="Extract Numbers",
        func=extract_numbers,
        description="Use to extract all numbers from provided text.",
    ),
    Tool(
        name="Extract Names",
        func=extract_names,
        description="Use to extract capitalized names from provided text.",
    ),
]
|
|
|
# Instruction preamble prepended to every question. It fixes the
# "FINAL ANSWER: ..." reply template that BasicAgent.__call__ parses.
PROMPT = (
    "You are a general AI assistant. I will ask you a question. "
    "Reason step by step, and use tools as needed. Only after you are sure, answer with the template: "
    "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. "
    "If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. "
    "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. "
    "If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
)
|
|
|
# Shared chat model for the agent; temperature=0 is the least random setting.
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
|
|
class BasicAgent:
    """Callable question-answering agent built on a zero-shot ReAct loop.

    Calling an instance with a question (and optionally a task id) returns
    the text after "FINAL ANSWER:" in the agent's reply, or the whole reply
    when no such line is present.
    """

    # Base URL of the service queried for a task's attached file.
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

    def __init__(self):
        """Build the LangChain agent from the module-level tools and model."""
        self.agent = initialize_agent(
            tools=tools,
            llm=llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=False,
            handle_parsing_errors=True,
        )
        self.prompt = PROMPT

    def fetch_file_url(self, task_id):
        """Return the attachment URL for *task_id*, or None if unavailable.

        A HEAD request probes ``<DEFAULT_API_URL>/files/<task_id>``; only a
        200 response counts as "file exists". Any request failure is logged
        and treated as "no file".
        """
        url = f"{self.DEFAULT_API_URL}/files/{task_id}"
        try:
            response = requests.head(url, timeout=5)
        except Exception:
            # Not a bare ``except:`` so KeyboardInterrupt/SystemExit still
            # propagate; everything else stays best-effort.
            logging.getLogger(__name__).warning(
                "Could not probe file for task %s", task_id, exc_info=True
            )
            return None
        return url if response.status_code == 200 else None

    @staticmethod
    def _extract_final_answer(result: str) -> str:
        """Return the text after the first "final answer:" line prefix.

        The prefix match is case-insensitive. Surrounding spaces, periods
        and quote characters are stripped from the answer. When no line
        carries the prefix, *result* is returned unchanged.
        """
        for line in result.splitlines():
            if line.strip().lower().startswith("final answer:"):
                return line.split(":", 1)[-1].strip(" .\"'")
        return result

    def __call__(self, question: str, task_id: str = None) -> str:
        """Answer *question*, mentioning the task's file URL when one exists.

        Args:
            question: The question text.
            task_id: Optional task identifier used to look up an attached
                file; skipped when falsy.

        Returns:
            The extracted final answer, or the raw agent reply when it does
            not follow the "FINAL ANSWER:" template.
        """
        file_url = self.fetch_file_url(task_id) if task_id else None
        if file_url:
            question_aug = f"{question}\nThis task has assigned file at this URL: {file_url}"
        else:
            question_aug = question

        full_prompt = self.prompt + "\n" + question_aug
        result = self.agent.run(full_prompt)
        return self._extract_final_answer(result)