File size: 3,454 Bytes
10e9b7d
6e92f6f
10e9b7d
eccf8e4
3c4371f
6576efa
759cedb
 
10e9b7d
6576efa
 
6e92f6f
759cedb
6e92f6f
c8bf6ed
759cedb
5ee0d29
6e92f6f
 
5ee0d29
6e92f6f
5ee0d29
759cedb
5ee0d29
6e92f6f
c8bf6ed
759cedb
6e92f6f
759cedb
91cad6f
31243f4
c8bf6ed
 
6576efa
c8bf6ed
2d924bf
759cedb
c8bf6ed
f8e24f8
f86bd24
 
91cad6f
c8bf6ed
f86bd24
c8bf6ed
f86bd24
c8bf6ed
f86bd24
4021bf3
c8bf6ed
6576efa
c8bf6ed
 
 
e80aab9
759cedb
31243f4
6576efa
31243f4
 
6576efa
759cedb
c8bf6ed
 
 
 
 
759cedb
c8bf6ed
 
759cedb
 
 
 
f8e24f8
c8bf6ed
 
 
f8e24f8
c8bf6ed
 
31243f4
759cedb
 
 
 
 
c8bf6ed
 
 
759cedb
c8bf6ed
 
31243f4
c8bf6ed
e80aab9
c8bf6ed
7e4a06b
c8bf6ed
 
 
 
e80aab9
 
3c4371f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import time
import gradio as gr
import requests
import pandas as pd

from smolagents import CodeAgent, OpenAIServerModel
from smolagents.tools import WebSearchTool  # Correct tool

# --- Constants -------------------------------------------------------------
# Base URL of the course scoring service (queried at /questions and /submit).
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Character cap: longer questions are skipped by the runner and truncated by the agent.
MAX_QUESTION_LENGTH = 4_000
# Intended safety margin below an 8192-token limit.
# NOTE(review): not referenced anywhere in the visible code.
MAX_TOKENS_PER_QUESTION = 8_000

# Reliable search tool with retry on rate-limit errors (linear backoff).
# NOTE(review): WebSearchTool is exported by the top-level `smolagents` package;
# confirm the `smolagents.tools` import path used at the top of this file resolves.
class ReliableWebSearchTool(WebSearchTool):
    """``WebSearchTool`` that retries a query when the backend reports a rate limit.

    smolagents invokes a tool via ``Tool.__call__``, which dispatches to
    ``forward``. The retry logic therefore lives in ``forward``. Overriding
    only ``run`` (as before) meant the agent never went through the retry path.
    """

    # Total number of attempts (first try included).
    _MAX_ATTEMPTS = 3
    # After the n-th failed attempt, wait _BACKOFF_SECONDS * n seconds.
    _BACKOFF_SECONDS = 2

    def forward(self, query: str) -> str:
        """Search for *query*, retrying when the error looks like a rate limit.

        Args:
            query: the search query string.

        Returns:
            The search results as produced by ``WebSearchTool.forward``.

        Raises:
            RuntimeError: every attempt hit a rate limit (chained to the last error).
            Exception: any other search error, re-raised immediately.
        """
        last_error = None
        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            try:
                return super().forward(query)
            except Exception as e:
                # Broad substring match so both "rate limit" and "Ratelimit" messages qualify.
                if "rate" not in str(e).lower():
                    raise
                last_error = e
                print(f"[WebSearchTool] Rate limit, retry {attempt}/{self._MAX_ATTEMPTS}")
                # No point sleeping after the final attempt.
                if attempt < self._MAX_ATTEMPTS:
                    time.sleep(self._BACKOFF_SECONDS * attempt)
        raise RuntimeError("WebSearchTool failed after retries") from last_error

    def run(self, query: str) -> str:
        """Backward-compatible alias for :meth:`forward`."""
        return self.forward(query)

# Main agent: a smolagents CodeAgent backed by an OpenAI model plus web search.
class SmartGAIAAgent:
    """Callable wrapper that answers a single question with a smolagents ``CodeAgent``.

    The agent is given ``ReliableWebSearchTool`` plus smolagents' base tools
    (``add_base_tools=True``).
    """

    def __init__(self, model_id: str = "gpt-4"):
        """Build the underlying agent.

        Args:
            model_id: OpenAI model identifier passed to ``OpenAIServerModel``.
                Defaults to ``"gpt-4"``, the previously hard-coded model.

        Raises:
            ValueError: if the ``OPENAI_API_KEY`` environment variable is unset or empty.
        """
        key = os.getenv("OPENAI_API_KEY")
        if not key:
            raise ValueError("Missing OPENAI_API_KEY")
        model = OpenAIServerModel(model_id=model_id, api_key=key)
        self.agent = CodeAgent(
            tools=[ReliableWebSearchTool()],
            model=model,
            add_base_tools=True,
        )

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer as a string.

        The question is truncated to ``MAX_QUESTION_LENGTH`` characters first.
        Any exception raised by the agent is printed and mapped to ``"error"``.
        This way one failing question does not abort a whole batch.
        """
        q = question[:MAX_QUESTION_LENGTH]
        try:
            result = self.agent.run(q)
        except Exception as e:  # deliberate best-effort boundary per question
            print("Agent error:", e)
            return "error"
        # The final answer is not always a str (e.g. numbers or lists).
        # Calling .strip() on it directly used to raise and be reported as "error".
        if result is None:
            return ""
        return str(result).strip()

# Fetch, filter, run, and submit answers

# Lower-case substrings marking questions that depend on audio, image or video
# content the agent cannot process; such questions are skipped, not answered.
_SKIP_KEYWORDS = ('.mp3', '.wav', '.png', '.jpg', 'youtube', 'video', 'watch', 'listen')


def _fetch_questions():
    """Download the question list from the scoring API.

    Raises:
        requests.exceptions.RequestException: on network or HTTP errors.
        ValueError: if the response body is not valid JSON.
    """
    resp = requests.get(f"{DEFAULT_API_URL}/questions", timeout=15)
    resp.raise_for_status()
    return resp.json()


def _should_skip(question: str) -> bool:
    """Return True if *question* is too long or mentions unsupported media."""
    if len(question) > MAX_QUESTION_LENGTH:
        return True
    lowered = question.lower()
    return any(k in lowered for k in _SKIP_KEYWORDS)


def _answer_questions(agent, questions: list) -> tuple[list[dict], list[dict]]:
    """Run *agent* on every usable question.

    Returns:
        ``(payload, logs)``. ``payload`` holds the submission entries
        (``task_id`` / ``submitted_answer``). ``logs`` holds the matching rows
        for the results table.
    """
    payload, logs = [], []
    for item in questions:
        if not isinstance(item, dict):
            continue  # malformed entry; each question is expected to be an object
        tid = item.get("task_id")
        q = item.get("question", "")
        if not tid or not q or _should_skip(q):
            continue
        ans = agent(q)
        payload.append({"task_id": tid, "submitted_answer": ans})
        logs.append({"Task ID": tid, "Question": q, "Submitted Answer": ans})
    return payload, logs


def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Answer all scoring-API questions with ``SmartGAIAAgent`` and submit them.

    Args:
        profile: the logged-in Hugging Face profile, or ``None`` when logged out.

    Returns:
        ``(status_message, DataFrame | None)`` for the Gradio status box and table.
        Network/API failures are reported in the status message instead of
        being raised. A failed submission still returns the answer log.
    """
    username = profile.username if profile else None
    if not username:
        return "Please Login to Hugging Face", None

    # Instantiate agent
    try:
        agent = SmartGAIAAgent()
    except Exception as e:
        return f"Error initializing agent: {e}", None

    # Get questions
    try:
        questions = _fetch_questions()
    except (requests.exceptions.RequestException, ValueError) as e:
        return f"Error fetching questions: {e}", None
    if not isinstance(questions, list):
        return "Error fetching questions: unexpected response format.", None

    payload, logs = _answer_questions(agent, questions)
    results_df = pd.DataFrame(logs)
    if not payload:
        return "No valid questions to submit.", results_df

    sub = {
        "username": username,
        "agent_code": f"https://huggingface.co/spaces/{os.getenv('SPACE_ID')}/tree/main",
        "answers": payload,
    }
    try:
        resp = requests.post(f"{DEFAULT_API_URL}/submit", json=sub, timeout=60)
        resp.raise_for_status()
        result = resp.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # Answering is the expensive step; keep its log visible even if submission fails.
        return f"Submission failed: {e}", results_df

    status = f"Score: {result.get('score')}% ({result.get('correct_count')}/{result.get('total_attempted')})"
    return status, results_df

# --- Gradio UI -------------------------------------------------------------
# Layout: title, HF login, a trigger button, a status box and a results table.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent")
    gr.LoginButton()
    submit_button = gr.Button("Run & Submit")
    status_box = gr.Textbox(lines=5)
    results_table = gr.DataFrame()
    # No inputs are wired here. The `profile` argument is presumably supplied by
    # Gradio's OAuth integration via its gr.OAuthProfile annotation.
    submit_button.click(fn=run_and_submit_all, outputs=[status_box, results_table])

if __name__ == "__main__":
    demo.launch(debug=True, share=False)