File size: 3,255 Bytes
10e9b7d
6e92f6f
10e9b7d
eccf8e4
3c4371f
6576efa
759cedb
1dcff21
10e9b7d
6576efa
 
1dcff21
 
b93c01e
1dcff21
 
 
 
 
 
 
 
 
 
 
 
91cad6f
31243f4
c8bf6ed
 
6576efa
c8bf6ed
2d924bf
1dcff21
b93c01e
1dcff21
f86bd24
 
91cad6f
b93c01e
f86bd24
b93c01e
f86bd24
c8bf6ed
f86bd24
4021bf3
6576efa
c8bf6ed
 
1dcff21
e80aab9
31243f4
6576efa
31243f4
 
6576efa
b93c01e
 
 
 
 
 
c8bf6ed
 
 
759cedb
b93c01e
759cedb
b93c01e
 
 
 
 
f8e24f8
c8bf6ed
 
31243f4
b93c01e
759cedb
 
b93c01e
759cedb
b93c01e
 
1dcff21
 
 
c8bf6ed
31243f4
b93c01e
e80aab9
c8bf6ed
7e4a06b
c8bf6ed
1dcff21
c8bf6ed
 
e80aab9
 
3c4371f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import time
import gradio as gr
import requests
import pandas as pd

from smolagents import CodeAgent, OpenAIServerModel
from smolagents.tools.web_search import WebSearchTool

# Constants
# Base URL of the scoring service; "/questions" and "/submit" are appended to it.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Maximum question length in characters: run_and_submit_all skips longer
# questions and SmartGAIAAgent truncates its input to this many characters.
MAX_QUESTION_LENGTH = 4000

class RetryableWebSearchTool(WebSearchTool):
    """Web search tool that retries when the backend reports rate limiting.

    smolagents dispatches tool calls through ``Tool.forward`` (the ``Tool``
    base class has no ``run`` hook), so the retry loop lives in ``forward``.
    ``run`` is kept as a thin alias for any code that calls it directly.
    """

    def forward(self, query: str) -> str:
        """Run the search, retrying up to three times on rate-limit errors.

        Args:
            query: Search query, passed unchanged to the parent tool.

        Returns:
            Whatever the parent tool returns for the query.

        Raises:
            RuntimeError: If every attempt failed with a rate-limit error; the
                last such error is attached as ``__cause__``.
            Exception: Any error whose message does not mention "rate" is
                re-raised immediately without retrying.
        """
        max_attempts = 3
        last_error = None
        for attempt in range(1, max_attempts + 1):
            try:
                return super().forward(query)
            except Exception as e:
                # Heuristic: only errors whose message mentions "rate" are retried.
                if "rate" not in str(e).lower():
                    raise
                last_error = e
                print(f"[WebSearch] Rate limit, retry {attempt}/{max_attempts}")
                # Linear backoff (2s, 4s). Sleeping after the final attempt
                # would only delay the failure, so it is skipped.
                if attempt < max_attempts:
                    time.sleep(2 * attempt)
        raise RuntimeError("Web search failed after retries") from last_error

    def run(self, query: str) -> str:
        """Backward-compatible alias for :meth:`forward`."""
        return self.forward(query)

class SmartGAIAAgent:
    """Callable wrapper around a smolagents ``CodeAgent`` backed by an OpenAI model.

    Calling an instance with a question returns the agent's answer as a
    stripped string, or the literal ``"error"`` when the agent fails.
    """

    def __init__(self, model_id: str = "gpt-4"):
        """Build the underlying agent.

        Args:
            model_id: Model identifier passed to ``OpenAIServerModel``.

        Raises:
            ValueError: If the ``OPENAI_API_KEY`` environment variable is
                unset or empty.
        """
        key = os.getenv("OPENAI_API_KEY")
        if not key:
            raise ValueError("Missing OPENAI_API_KEY")
        model = OpenAIServerModel(model_id=model_id, api_key=key)
        # NOTE(review): add_base_tools=True may register its own web-search
        # tool under the same name — confirm the retrying tool is the one used.
        self.agent = CodeAgent(
            model=model,
            tools=[RetryableWebSearchTool()],
            add_base_tools=True,
        )

    def __call__(self, question: str) -> str:
        """Answer *question*, truncated to ``MAX_QUESTION_LENGTH`` characters.

        Args:
            question: The question text.

        Returns:
            The agent's answer converted to ``str`` and stripped of
            surrounding whitespace, or ``"error"`` if the agent raised or
            produced no answer.
        """
        question = question[:MAX_QUESTION_LENGTH]
        try:
            result = self.agent.run(question)
        except Exception as e:
            # Best-effort boundary: one failing question must not abort the batch.
            print("Agent error:", e)
            return "error"
        if result is None:
            print("Agent error:", "agent returned no answer")
            return "error"
        # The agent may return a non-string (e.g. a number); stringify before
        # stripping so a valid answer is not turned into "error".
        return str(result).strip()

def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch the questions, answer each with the agent, and submit the answers.

    Args:
        profile: OAuth profile of the logged-in user, or ``None`` when nobody
            is logged in.

    Returns:
        A ``(status, table)`` tuple. ``status`` is a human-readable message
        (score summary or error description). ``table`` is a ``DataFrame`` of
        the attempted questions and answers, or ``None`` when the run stopped
        before any question was processed.
    """
    username = profile.username if profile else None
    if not username:
        return "Please Login to Hugging Face", None

    try:
        agent = SmartGAIAAgent()
    except Exception as e:
        return f"Error initializing agent: {e}", None

    # Report network / decoding failures as a status message, matching the
    # agent-initialisation path above, instead of raising into the UI.
    try:
        response = requests.get(f"{DEFAULT_API_URL}/questions", timeout=15)
        response.raise_for_status()
        questions = response.json()
    except (requests.RequestException, ValueError) as e:
        return f"Error fetching questions: {e}", None
    if not isinstance(questions, list):
        return "Error fetching questions: unexpected response format", None

    # Substrings that cause a question to be skipped (audio/image/video hints).
    skip_markers = ('.mp3', '.wav', '.jpg', '.png', 'youtube', 'video', 'watch', 'listen')

    payload = []
    logs = []

    for item in questions:
        if not isinstance(item, dict):
            continue
        tid = item.get("task_id")
        q = item.get("question", "")
        if not tid or not q or len(q) > MAX_QUESTION_LENGTH:
            continue
        lowered = q.lower()
        if any(marker in lowered for marker in skip_markers):
            continue
        answer = agent(q)
        payload.append({"task_id": tid, "submitted_answer": answer})
        logs.append({"Task ID": tid, "Question": q, "Submitted Answer": answer})

    if not payload:
        return "No valid questions to submit.", pd.DataFrame(logs)

    submission = {
        "username": username,
        # SPACE_ID is read from the environment; when it is unset the URL
        # contains the literal text "None".
        "agent_code": f"https://huggingface.co/spaces/{os.getenv('SPACE_ID')}/tree/main",
        "answers": payload,
    }

    # Keep the computed answers visible even if the submission itself fails.
    try:
        resp = requests.post(f"{DEFAULT_API_URL}/submit", json=submission, timeout=30)
        resp.raise_for_status()
        result = resp.json()
    except (requests.RequestException, ValueError) as e:
        return f"Submission failed: {e}", pd.DataFrame(logs)

    status = f"Score: {result.get('score')}% ({result.get('correct_count')}/{result.get('total_attempted')})"
    return status, pd.DataFrame(logs)

# Gradio UI
# Layout, top to bottom: heading, login button, run button, status box, results table.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent")
    gr.LoginButton()
    run_btn = gr.Button("Run & Submit")
    status = gr.Textbox(lines=5)  # shows the status string returned by run_and_submit_all
    table = gr.DataFrame()  # shows the DataFrame (or None) returned by run_and_submit_all
    # No `inputs` are wired here; presumably Gradio supplies the `profile`
    # argument from the login session based on the handler's type hint —
    # TODO confirm against the Gradio OAuth documentation.
    run_btn.click(run_and_submit_all, outputs=[status, table])

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True, share=False)