File size: 3,276 Bytes
10e9b7d
6e92f6f
10e9b7d
eccf8e4
3c4371f
6576efa
759cedb
fdb94f8
10e9b7d
6576efa
 
1dcff21
 
fdb94f8
 
1dcff21
 
 
 
 
 
 
 
 
 
fdb94f8
1dcff21
fdb94f8
91cad6f
31243f4
c8bf6ed
 
6576efa
c8bf6ed
2d924bf
fdb94f8
1dcff21
 
f86bd24
 
91cad6f
fdb94f8
f86bd24
fdb94f8
f86bd24
c8bf6ed
f86bd24
4021bf3
fdb94f8
6576efa
c8bf6ed
 
1dcff21
e80aab9
31243f4
6576efa
31243f4
 
6576efa
fdb94f8
 
 
c8bf6ed
fdb94f8
 
c8bf6ed
 
fdb94f8
 
759cedb
fdb94f8
 
 
f8e24f8
c8bf6ed
 
31243f4
fdb94f8
759cedb
 
fdb94f8
759cedb
fdb94f8
1dcff21
 
 
c8bf6ed
31243f4
b93c01e
e80aab9
c8bf6ed
7e4a06b
c8bf6ed
1dcff21
c8bf6ed
 
e80aab9
 
3c4371f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import time
import gradio as gr
import requests
import pandas as pd

from smolagents import CodeAgent, OpenAIServerModel
from smolagents.tools import WebSearchTool, BaseTools

# Constants
# Base URL of the scoring service; the "/questions" and "/submit" endpoints
# are requested relative to it.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Maximum question length in characters: the agent truncates its input to this
# length, and run_and_submit_all skips questions that are longer than it.
MAX_QUESTION_LENGTH = 4000

# Reliable search tool with retry
class ReliableWebSearch(WebSearchTool):
    """Web search tool that retries when the backend reports rate limiting.

    NOTE(review): this overrides ``run``; confirm that ``run`` is the entry
    point the agent framework actually invokes on a tool (some tool base
    classes dispatch through a differently named method), otherwise the retry
    logic is never exercised.
    """

    def run(self, query: str) -> str:
        """Run the search for *query*, retrying on rate-limit errors.

        Up to three attempts are made. An exception whose message contains
        "rate" (case-insensitive) is treated as a rate limit and triggers a
        retry after a linearly growing pause (2s, then 4s). Any other
        exception propagates immediately.

        Args:
            query: The search query forwarded to the parent implementation.

        Returns:
            Whatever the parent ``run`` returns for the query.

        Raises:
            RuntimeError: If every attempt was rate limited. The last
                rate-limit exception is attached as ``__cause__``.
        """
        max_attempts = 3
        last_error = None
        for attempt in range(max_attempts):
            try:
                return super().run(query)
            except Exception as e:
                if "rate" not in str(e).lower():
                    raise
                last_error = e
                print(f"[WebSearch] Rate limit, retry {attempt+1}/{max_attempts}")
                # Only back off when another attempt will follow; sleeping
                # after the final failure would just delay the error.
                if attempt + 1 < max_attempts:
                    time.sleep(2 * (attempt + 1))
        # Chain the last rate-limit error so the root cause stays visible.
        raise RuntimeError("WebSearchTool failed after retries") from last_error

# Main agent using GPT-4 & tools
class SmartGAIAAgent:
    """Callable question-answering agent built on a ``CodeAgent``.

    The agent is configured with a GPT-4 ``OpenAIServerModel``, the
    ``ReliableWebSearch`` tool, and ``add_base_tools=True``.
    """

    def __init__(self):
        """Create the underlying model and agent.

        Raises:
            ValueError: If the ``OPENAI_API_KEY`` environment variable is
                unset or empty.
        """
        key = os.getenv("OPENAI_API_KEY")
        if not key:
            raise ValueError("Missing OPENAI_API_KEY")
        model = OpenAIServerModel(model_id="gpt-4", api_key=key)
        self.agent = CodeAgent(
            tools=[ReliableWebSearch()],
            model=model,
            add_base_tools=True,
        )

    def __call__(self, question: str) -> str:
        """Answer *question* and return the answer as stripped text.

        The question is truncated to ``MAX_QUESTION_LENGTH`` characters before
        being passed to the agent.

        Args:
            question: The question text.

        Returns:
            The agent's answer converted to ``str`` and stripped of
            surrounding whitespace, or the literal ``"error"`` if the agent
            raised or produced no answer (``None``).
        """
        q = question[:MAX_QUESTION_LENGTH]
        # Keep the try body to the agent call so only agent failures are
        # reported as "error".
        try:
            result = self.agent.run(q)
        except Exception as e:
            print("Agent error:", e)
            return "error"
        if result is None:
            print("Agent error:", "no answer returned")
            return "error"
        # The agent may hand back a non-string answer (e.g. a number);
        # stringify it instead of failing on a missing ``.strip``.
        return str(result).strip()

# Fetch, filter, run, and submit answers
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch the questions, answer the eligible ones, and submit the answers.

    Questions are skipped when they lack a task id or text, exceed
    ``MAX_QUESTION_LENGTH`` characters, or mention one of the media-related
    keywords in ``skip_kw``.

    Args:
        profile: The logged-in user's profile, or ``None`` when nobody is
            logged in.

    Returns:
        A ``(status, table)`` tuple. ``status`` is a human-readable message
        (score summary or error description). ``table`` is a DataFrame with
        one row per answered question, or ``None`` when processing stopped
        before any question was attempted.
    """
    username = profile.username if profile else None
    if not username:
        return "Please Login to Hugging Face", None

    try:
        agent = SmartGAIAAgent()
    except Exception as e:
        return f"Error initializing agent: {e}", None

    # Report fetch problems as a status message, like the agent-init path,
    # instead of letting the exception escape the UI callback.
    try:
        resp = requests.get(f"{DEFAULT_API_URL}/questions", timeout=15)
        resp.raise_for_status()
        questions = resp.json()
    except (requests.RequestException, ValueError) as e:
        return f"Error fetching questions: {e}", None

    skip_kw = ('.mp3', '.wav', '.png', '.jpg', 'youtube', 'video', 'watch', 'listen')
    payload, logs = [], []
    for item in questions:
        tid = item.get("task_id")
        q = item.get("question", "")
        if not tid or not q or len(q) > MAX_QUESTION_LENGTH:
            continue
        # Lower-case once per question rather than once per keyword.
        q_lower = q.lower()
        if any(k in q_lower for k in skip_kw):
            continue
        ans = agent(q)
        payload.append({"task_id": tid, "submitted_answer": ans})
        logs.append({"Task ID": tid, "Question": q, "Submitted Answer": ans})

    if not payload:
        return "No valid questions to submit.", pd.DataFrame(logs)

    sub = {
        "username": username,
        "agent_code": f"https://huggingface.co/spaces/{os.getenv('SPACE_ID')}/tree/main",
        "answers": payload,
    }
    # On a failed submission still return the answers table so the computed
    # answers are not lost.
    try:
        resp = requests.post(f"{DEFAULT_API_URL}/submit", json=sub, timeout=60)
        resp.raise_for_status()
        result = resp.json()
    except (requests.RequestException, ValueError) as e:
        return f"Submission failed: {e}", pd.DataFrame(logs)

    status = f"Score: {result.get('score')}% ({result.get('correct_count')}/{result.get('total_attempted')})"
    return status, pd.DataFrame(logs)

# Gradio UI
# Components are declared in display order: title, login button, run button,
# status text box, and results table.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent")
    gr.LoginButton()
    run_btn = gr.Button("Run & Submit")
    status = gr.Textbox(lines=5)
    table = gr.DataFrame()
    # No explicit inputs are wired; the handler's two return values fill the
    # status box and the table, in that order.
    # NOTE(review): presumably Gradio supplies the handler's `profile` argument
    # from the login session based on its type hint - confirm against the
    # Gradio OAuth documentation.
    run_btn.click(run_and_submit_all, outputs=[status, table])

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True, share=False)