dawid-lorek's picture
Update agent.py
30a91c5 verified
raw
history blame
5.65 kB
# agent.py – HybridAgent: GAIA-style fallback + tools + no errors
import os
import requests
import openai
import traceback
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
from langchain_experimental.tools.python.tool import PythonREPLTool
from langchain_community.document_loaders import YoutubeLoader
# Configure the OpenAI SDK from the environment; the value is None when
# OPENAI_API_KEY is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY")
def search_wikipedia(query: str) -> str:
    """Run *query* through a Wikipedia wrapper and return its text result.

    A new ``WikipediaAPIWrapper`` is built on every call with
    ``top_k_results=1`` and ``doc_content_chars_max=1200``.

    Args:
        query: Search text handed to the wrapper's ``run`` method.

    Returns:
        The wrapper's result, or a ``[TOOL ERROR] Wikipedia: ...`` string
        when anything raises (this function never propagates exceptions).
    """
    try:
        return WikipediaAPIWrapper(
            top_k_results=1,
            doc_content_chars_max=1200,
        ).run(query)
    except Exception as exc:
        # Best-effort tool: report the failure as text so the caller can
        # still build a prompt.
        return f"[TOOL ERROR] Wikipedia: {exc}"
def run_python_code(code: str) -> str:
    """Run a Python snippet through ``PythonREPLTool`` and return its result.

    When the snippet contains no ``print`` and parses as a single
    expression, it is wrapped in ``print((...))`` so that the expression's
    value is printed instead of being discarded. Snippets that are
    statements (assignments, loops, ...) are executed unchanged.

    Args:
        code: Python source to execute.

    Returns:
        Whatever ``PythonREPLTool.run`` returns, or a
        ``[TOOL ERROR] Python: ...`` string when anything raises.
    """
    import ast  # function-scope import: only used to classify the snippet

    try:
        # NOTE(security): the snippet is executed as real Python code. Do not
        # route untrusted input here without sandboxing.
        if "print" not in code:
            expression = code.strip()
            try:
                ast.parse(expression, mode="eval")
            except (SyntaxError, ValueError):
                # Not a bare expression: run the statements as written.
                pass
            else:
                # Extra parentheses keep tuples intact; the newlines stop a
                # trailing "# comment" from swallowing the closing brackets.
                code = f"print((\n{expression}\n))"
        return PythonREPLTool().run(code)
    except Exception as e:
        return f"[TOOL ERROR] Python: {e}"
def get_youtube_transcript(url: str) -> str:
    """Load the documents for a YouTube URL and join their text.

    Args:
        url: Value passed to ``YoutubeLoader.from_youtube_url`` (video info
            loading is disabled via ``add_video_info=False``).

    Returns:
        The ``page_content`` of every loaded document joined with single
        spaces, or a ``[TOOL ERROR] YouTube: ...`` string when anything
        raises.
    """
    try:
        documents = YoutubeLoader.from_youtube_url(
            url, add_video_info=False
        ).load()
        pieces = [document.page_content for document in documents]
        return " ".join(pieces)
    except Exception as exc:
        # Failures are reported as text rather than raised.
        return f"[TOOL ERROR] YouTube: {exc}"
def fetch_file_context(task_id: str) -> str:
    """Download the file attached to *task_id* and return its leading text.

    Args:
        task_id: Identifier appended to the scoring service's ``/files/``
            endpoint.

    Returns:
        The first 3000 characters of the response body when the
        ``Content-Type`` header contains ``"text"``; the literal
        ``[Unsupported file type]`` for any other content type; or a
        ``[FILE ERROR] ...`` string when the request fails or anything else
        raises.
    """
    try:
        response = requests.get(
            f"https://agents-course-unit4-scoring.hf.space/files/{task_id}",
            timeout=10,
        )
        response.raise_for_status()
        content_type = response.headers.get("Content-Type", "")
        if "text" not in content_type:
            return "[Unsupported file type]"
        # Truncate so the file cannot dominate the prompt.
        return response.text[:3000]
    except Exception as exc:
        return f"[FILE ERROR] {exc}"
def build_prompt(question: str, context: str = "") -> str:
    """Assemble the user prompt sent to the chat model.

    Args:
        question: The benchmark question; surrounding whitespace is removed.
        context: Optional supporting text (file contents, tool output);
            surrounding whitespace is removed. An empty context leaves a
            blank line in its place.

    Returns:
        A newline-separated prompt that starts with an empty line, lists the
        fixed instructions, then the context, the question, and a trailing
        ``Answer:`` cue with no newline after it.
    """
    prompt_lines = [
        "",  # the prompt deliberately begins with a newline
        "You are a highly skilled research assistant completing factual benchmark questions.",
        "If any tools are required, try to simulate reasoning or summarize fallback.",
        "Return a concise factual answer only – no explanations.",
        context.strip(),
        f"Question: {question.strip()}",
        "Answer:",
    ]
    return "\n".join(prompt_lines)
def ask_openai(prompt: str) -> str:
    """Send *prompt* to the ``gpt-4-turbo`` chat model and return its reply.

    Works with both generations of the ``openai`` SDK: the module-level
    ``openai.chat.completions.create`` is used when the installed package
    provides it (1.x), otherwise the legacy ``openai.ChatCompletion.create``
    (0.x) is called. On 1.x the legacy name only raises, which previously
    turned every call into an ``[OPENAI ERROR]`` string.

    Args:
        prompt: Full user message, sent after a fixed system instruction.

    Returns:
        The stripped content of the first choice, or an
        ``[OPENAI ERROR] ...`` string when anything raises.
    """
    try:
        chat_namespace = getattr(openai, "chat", None)
        if chat_namespace is not None:
            create_completion = chat_namespace.completions.create
        else:
            create_completion = openai.ChatCompletion.create
        response = create_completion(
            model="gpt-4-turbo",
            messages=[
                {"role": "system", "content": "Answer factually. Return only final result."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.0,  # deterministic output for benchmark answers
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"[OPENAI ERROR] {e}"
# === Unified entry point ===
def answer_question(question: str, task_id: str = None) -> str:
try:
file_context = fetch_file_context(task_id) if task_id else ""
# Optional: tool-based enhancement
if "wikipedia" in question.lower():
wiki = search_wikipedia(question)
file_context += f"\n[Wikipedia] {wiki}"
elif "youtube.com" in question:
yt = get_youtube_transcript(question)
file_context += f"\n[YouTube Transcript] {yt}"
elif "* on the set" in question and file_context:
try:
import re
import pandas as pd
table_lines = [line.strip() for line in file_context.splitlines() if '|' in line and '*' not in line[:2]]
headers = re.split(r'\|+', table_lines[0])[1:-1]
data_rows = [re.split(r'\|+', row)[1:-1] for row in table_lines[1:]]
index = [row[0] for row in data_rows]
matrix = [row[1:] for row in data_rows]
df = pd.DataFrame(matrix, index=index, columns=headers)
non_comm = set()
for a in df.index:
for b in df.columns:
if df.at[a, b] != df.at[b, a]:
non_comm.add(a)
non_comm.add(b)
result = ", ".join(sorted(non_comm))
file_context += f"\n[Parsed Non-Commutative Set] {result}"
except Exception as e:
file_context += f"\n[Table Parse Error] {e}"
# Dynamic parser for operation table
import re
import pandas as pd
try:
table_lines = [line.strip() for line in file_context.splitlines() if '|' in line and '*' not in line[:2]]
headers = re.split(r'\|+', table_lines[0])[1:-1]
data_rows = [re.split(r'\|+', row)[1:-1] for row in table_lines[1:]]
index = [row[0] for row in data_rows]
matrix = [row[1:] for row in data_rows]
df = pd.DataFrame(matrix, index=index, columns=headers)
non_comm = set()
for a in df.index:
for b in df.columns:
if df.at[a, b] != df.at[b, a]:
non_comm.add(a)
non_comm.add(b)
result = ", ".join(sorted(non_comm))
file_context += f"[Parsed Non-Commutative Set] {result}"
except Exception as e:
file_context += f"[Table Parse Error] {e}"
# Parse table to extract non-commutative elements
import re
import pandas as pd
try:
table_lines = [line.strip() for line in file_context.splitlines() if '|' in line and '*' not in line[:2]]
headers = re.split(r'\|+', table_lines[0])[1:-1]
data_rows = [re.split(r'\|+', row)[1:-1] for row in table_lines[1:]]
index = [row[0] for row in data_rows]
matrix = [row[1:] for row in data_rows]
df = pd.DataFrame(matrix, index=index, columns=headers)
non_comm = set()
for a in df.index:
for b in df.columns:
if df.at[a, b] != df.at[b, a]:
non_comm.add(a)
non_comm.add(b)
result = ", ".join(sorted(non_comm))
file_context += f"[Parsed Non-Commutative Set] {result}"
except Exception as e:
file_context += f"[Table Parse Error] {e}"