File size: 3,907 Bytes
9f06b73
59ff18d
e836bd4
9f06b73
 
 
95da673
48d9442
d092bd1
9f06b73
92b0d1a
9f06b73
 
d092bd1
 
9881ede
9f06b73
 
9881ede
9f06b73
d092bd1
 
 
9f06b73
9881ede
 
9f06b73
d092bd1
9f06b73
48d9442
 
ffe4aa3
48d9442
 
 
ffe4aa3
9f06b73
ffe4aa3
9f06b73
5db119a
9f06b73
 
 
 
 
 
5db119a
9f06b73
d092bd1
9f06b73
 
 
88fa1a5
9f06b73
 
 
 
 
 
 
 
 
48d9442
9f06b73
 
 
 
 
 
 
 
 
48d9442
9f06b73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
893a971
9f06b73
 
 
ee7947e
9f06b73
 
 
 
 
 
 
 
 
 
 
 
 
 
2c93f3f
ee7947e
6ad00cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# agent.py – HybridAgent: GAIA-style fallback + tools + no errors

import os
import requests
import openai
import traceback

from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
from langchain_experimental.tools.python.tool import PythonREPLTool
from langchain_community.document_loaders import YoutubeLoader

# Set OpenAI API key
openai.api_key = os.getenv("OPENAI_API_KEY")

def search_wikipedia(query: str) -> str:
    try:
        wiki = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=1200)
        return wiki.run(query)
    except Exception as e:
        return f"[TOOL ERROR] Wikipedia: {e}"

def run_python_code(code: str) -> str:
    try:
        tool = PythonREPLTool()
        if 'print' not in code:
            code = f"print({repr(code)})"
        return tool.run(code)
    except Exception as e:
        return f"[TOOL ERROR] Python: {e}"

def get_youtube_transcript(url: str) -> str:
    try:
        loader = YoutubeLoader.from_youtube_url(url, add_video_info=False)
        docs = loader.load()
        return " ".join(d.page_content for d in docs)
    except Exception as e:
        return f"[TOOL ERROR] YouTube: {e}"

def fetch_file_context(task_id: str) -> str:
    try:
        url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        if "text" in resp.headers.get("Content-Type", ""):
            return resp.text[:3000]  # Truncate to stay safe
        return f"[Unsupported file type]"
    except Exception as e:
        return f"[FILE ERROR] {e}"

def build_prompt(question: str, context: str = "") -> str:
    return f"""
You are a highly skilled research assistant completing factual benchmark questions.

If any tools are required, try to simulate reasoning or summarize fallback.
Return a concise factual answer only – no explanations.

{context.strip()}

Question: {question.strip()}
Answer:"""

def ask_openai(prompt: str) -> str:
    try:
        response = openai.ChatCompletion.create(
            model="gpt-4-turbo",
            messages=[
                {"role": "system", "content": "Answer factually. Return only final result."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.0
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"[OPENAI ERROR] {e}"

# === Unified entry point ===

def answer_question(question: str, task_id: str = None) -> str:
    try:
        file_context = fetch_file_context(task_id) if task_id else ""

        # Optional: tool-based enhancement
        if "wikipedia" in question.lower():
            wiki = search_wikipedia(question)
            file_context += f"\n[Wikipedia] {wiki}"
        elif "youtube.com" in question:
            yt = get_youtube_transcript(question)
            file_context += f"\n[YouTube Transcript] {yt}"
        elif "* on the set" in question and file_context:
        # Parse table to extract non-commutative elements
    import re
    import pandas as pd

    try:
        table_lines = [line.strip() for line in file_context.splitlines() if '|' in line and '*' not in line[:2]]
        headers = re.split(r'\|+', table_lines[0])[1:-1]
        data_rows = [re.split(r'\|+', row)[1:-1] for row in table_lines[1:]]
        index = [row[0] for row in data_rows]
        matrix = [row[1:] for row in data_rows]
        df = pd.DataFrame(matrix, index=index, columns=headers)

        non_comm = set()
        for a in df.index:
            for b in df.columns:
                if df.at[a, b] != df.at[b, a]:
                    non_comm.add(a)
                    non_comm.add(b)
        result = ", ".join(sorted(non_comm))
        file_context += f"[Parsed Non-Commutative Set] {result}"
    except Exception as e:
        file_context += f"[Table Parse Error] {e}"