Tesvia commited on
Commit
52d1305
·
verified ·
1 Parent(s): 6eeaf9f

Upload 4 files

Browse files
Files changed (4) hide show
  1. agent.py +101 -88
  2. app.py +2 -3
  3. requirements.txt +5 -2
  4. tools.py +104 -0
agent.py CHANGED
@@ -1,88 +1,101 @@
1
- """
2
- agent.py – central coordinator for smolagents-powered agent.
3
-
4
- This file exposes a single helper function `my_agent()` that returns an
5
- object which is **callable** (i.e. implements `__call__(question:str) -> str`)
6
- so that `app.py` can stay unchanged apart from a single import.
7
-
8
- * Adding new tools
9
- ------------------
10
- 1. Drop the tool file inside the ``/tools`` package.
11
- 2. Import the tool class in `my_agent` and append it to the ``tools`` list.
12
- The rest of the application will automatically pick it up.
13
- """
14
-
15
- from typing import List, Sequence
16
-
17
- try:
18
- from smolagents import Agent, Tool # type: ignore
19
- except ImportError as exc: # pragma: no cover
20
- raise ImportError(
21
- "smolagents must be in requirements.txt. "
22
- "Add `smolagents` to your dependencies."
23
- ) from exc
24
-
25
- # Available tools
26
- from tools.web_search import DuckDuckGoSearchTool # noqa: E402
27
-
28
-
29
- class SmolAgentWrapper:
30
- """
31
- Thin wrapper that makes a smolagents.Agent *callable*.
32
-
33
- The evaluation harness in app.py expects an object that can be called
34
- directly with a single question and that returns a string. The underlying
35
- smolagents agent is session-aware and can handle multi-turn conversations
36
- but we keep the public interface single-turn for now.
37
- """
38
-
39
- def __init__(self, tools: Sequence["Tool"] | None = None) -> None: # type: ignore[name-defined]
40
- if tools is None:
41
- tools = [DuckDuckGoSearchTool()]
42
- self._agent = Agent(tools=list(tools))
43
-
44
- # Allow the object itself to be called like a function
45
- def __call__(self, question: str) -> str: # noqa: D401 (simple summary ok)
46
- """
47
- Ask the underlying smolagents Agent a **single** question and return the answer.
48
-
49
- Any exception is caught and surfaced as a readable string in order not
50
- to crash the evaluation loop.
51
- """
52
- try:
53
- response = self._agent.run(question)
54
- # smolagents may return dicts or ToolOutput objects; normalise to str
55
- if isinstance(response, str):
56
- return response
57
- return str(response)
58
- except Exception as err: # pragma: no cover
59
- return f"ERROR: {type(err).__name__}: {err}"
60
-
61
-
62
- # --------------------------------------------------------------------------- #
63
- # Helper – this is what app.py will import
64
- # --------------------------------------------------------------------------- #
65
-
66
- def my_agent(extra_tools: Sequence["Tool"] | None = None) -> SmolAgentWrapper: # type: ignore[name-defined]
67
- """
68
- Factory that returns a ready-to-go agent.
69
-
70
- Parameters
71
- ----------
72
- extra_tools:
73
- Optional sequence of additional smolagents Tool objects to extend the
74
- agent's capabilities. They are appended **after** the default search
75
- tool so they can override it if they expose the same name.
76
-
77
- Returns
78
- -------
79
- SmolAgentWrapper
80
- A callable object compatible with the original BasicAgent.
81
- """
82
- tools: List["Tool"] = [DuckDuckGoSearchTool()]
83
- if extra_tools:
84
- tools.extend(extra_tools)
85
- return SmolAgentWrapper(tools=tools)
86
-
87
-
88
- __all__ = ["my_agent", "SmolAgentWrapper"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """agent.py – GAIA benchmark agent using *smolagents*.
2
+
3
+ This module exposes:
4
+
5
+ * ``gaia_agent()`` factory returning a ready‑to‑use agent instance.
6
+ * ``GAIAAgent`` – subclass of ``smolagents.CodeAgent``.
7
+
8
+ The LLM backend is chosen at runtime via the ``MODEL_PROVIDER``
9
+ environment variable (``hf`` or ``openai``) exactly like *example.py*.
10
+ """
11
+
12
+ import os
13
+ from typing import Any, Sequence
14
+
15
+ from dotenv import load_dotenv
16
+
17
+ # SmolAgents Tools
18
+ from smolagents import (
19
+ CodeAgent,
20
+ DuckDuckGoSearchTool,
21
+ Tool,
22
+ )
23
+
24
+ # Custom Tools
25
+ from .tools import (PythonRunTool, ExcelLoaderTool, YouTubeTranscriptTool,
26
+ AudioTranscriptionTool, SimpleOCRTool)
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Model selection helper
30
+ # ---------------------------------------------------------------------------
31
+
32
+ load_dotenv() # Make sure we read credentials from .env when running locally
33
+
34
+
35
+ def _select_model():
36
+ """Return a smolagents *model* as configured by the ``MODEL_PROVIDER`` env."""
37
+
38
+ provider = os.getenv("MODEL_PROVIDER", "hf").lower()
39
+
40
+ if provider == "hf":
41
+ from smolagents import InferenceClientModel
42
+ hf_model_id = os.getenv("HF_MODEL", "HuggingFaceH4/zephyr-7b-beta")
43
+ hf_token = os.getenv("HF_API_KEY")
44
+ return InferenceClientModel(model_id=hf_model_id, token=hf_token)
45
+
46
+ if provider == "openai":
47
+ from smolagents import OpenAIServerModel
48
+ openai_model_id = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
49
+ openai_token = os.getenv("OPENAI_API_KEY")
50
+ return OpenAIServerModel(model_id=openai_model_id, api_key=openai_token)
51
+
52
+ raise ValueError(
53
+ f"Unsupported MODEL_PROVIDER: {provider!r}. "
54
+ "Use 'hf' (default) or 'openai'."
55
+ )
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # Core Agent implementation
59
+ # ---------------------------------------------------------------------------
60
+
61
+ DEFAULT_TOOLS = [
62
+ DuckDuckGoSearchTool(),
63
+ PythonRunTool(),
64
+ ExcelLoaderTool(),
65
+ YouTubeTranscriptTool(),
66
+ AudioTranscriptionTool(),
67
+ SimpleOCRTool(),
68
+ ]
69
+
70
+ class GAIAAgent(CodeAgent):
71
+ def __init__(self, tools=None):
72
+ super().__init__(tools=tools or DEFAULT_TOOLS, model=_select_model())
73
+
74
+ # Convenience so the object itself can be *called* directly
75
+ def __call__(self, question: str, **kwargs: Any) -> str:
76
+ steps = self.run(question, **kwargs)
77
+ last_step = None
78
+ for step in steps:
79
+ last_step = step
80
+ # If last_step is a FinalAnswerStep with .answer, return it
81
+ answer = getattr(last_step, "answer", None)
82
+ if answer is not None:
83
+ return str(answer)
84
+ return str(last_step)
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # Factory helpers expected by app.py
88
+ # ---------------------------------------------------------------------------
89
+
90
+ def gaia_agent(*, extra_tools: Sequence[Tool] | None = None) -> GAIAAgent:
91
+ base_tools = [
92
+ DuckDuckGoSearchTool(),
93
+ CustomTool1(),
94
+ CustomTool2(),
95
+ ]
96
+ if extra_tools:
97
+ base_tools.extend(extra_tools)
98
+ return GAIAAgent(tools=base_tools)
99
+
100
+
101
+ __all__ = ["GAIAAgent", "gaia_agent"]
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import os
2
  import gradio as gr
3
  import requests
4
- import inspect
5
  import pandas as pd
6
 
7
  # --- Our Agent ---
8
- from agent import my_agent
9
 
10
  # (Keep Constants as is)
11
  # --- Constants ---
@@ -33,7 +32,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
33
 
34
  # 1. Instantiate Agent (now using smolagents)
35
  try:
36
- agent = my_agent()
37
  print("SmolAgent instantiated successfully.")
38
  except Exception as e:
39
  print(f"Error instantiating agent: {e}")
 
1
  import os
2
  import gradio as gr
3
  import requests
 
4
  import pandas as pd
5
 
6
  # --- Our Agent ---
7
+ from agent import gaia_agent
8
 
9
  # (Keep Constants as is)
10
  # --- Constants ---
 
32
 
33
  # 1. Instantiate Agent (now using smolagents)
34
  try:
35
+ agent = gaia_agent()
36
  print("SmolAgent instantiated successfully.")
37
  except Exception as e:
38
  print(f"Error instantiating agent: {e}")
requirements.txt CHANGED
@@ -1,5 +1,8 @@
1
  gradio
2
  requests
3
  pandas
4
- smolagents
5
- duckduckgo-search
 
 
 
 
1
  gradio
2
  requests
3
  pandas
4
+ smolagents[openai]
5
+ duckduckgo-search
6
+ youtube-transcript-api
7
+ pytesseract
8
+ pillow
tools.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Custom tools for smolagents GAIA agent
2
+ from __future__ import annotations
3
+ import contextlib
4
+ import io
5
+ import os
6
+ from typing import Any, Dict, List, Hashable
7
+
8
+ from smolagents import Tool
9
+
10
+ # ---- 1. PythonRunTool ------------------------------------------------------
11
+ class PythonRunTool(Tool):
12
+ name = "python_run"
13
+ description = (
14
+ "Execute trusted Python code and return printed output "
15
+ "+ repr() of the last expression (or _result variable)."
16
+ )
17
+
18
+ def forward(self, code: str) -> str: # type: ignore[override]
19
+ buf, ns = io.StringIO(), {}
20
+ last = None
21
+ try:
22
+ with contextlib.redirect_stdout(buf):
23
+ exec(compile(code, "<agent-python>", "exec"), {}, ns)
24
+ last = ns.get("_result", None)
25
+ except Exception as e:
26
+ raise RuntimeError(f"PythonRunTool error: {e}") from e
27
+ out = buf.getvalue()
28
+ return (out + (repr(last) if last is not None else "")).strip()
29
+
30
+ # ---- 2. ExcelLoaderTool ----------------------------------------------------
31
+ class ExcelLoaderTool(Tool):
32
+ name = "load_spreadsheet"
33
+ description = (
34
+ "Read .xlsx/.xls/.csv from disk and return "
35
+ "rows as a list of dictionaries with string keys."
36
+ )
37
+
38
+ def forward(self, path: str, sheet: str | int | None = None) -> List[Dict[str, Any]]: # type: ignore[override]
39
+ import pandas as pd
40
+ if not os.path.isfile(path):
41
+ raise FileNotFoundError(path)
42
+ ext = os.path.splitext(path)[1].lower()
43
+ if ext == ".csv":
44
+ df = pd.read_csv(path)
45
+ else:
46
+ df = pd.read_excel(path, sheet_name=sheet)
47
+ # Ensure all keys are str for type safety
48
+ records = [{str(k): v for k, v in row.items()} for row in df.to_dict(orient="records")]
49
+ return records
50
+
51
+ # ---- 3. YouTubeTranscriptTool ---------------------------------------------
52
+ class YouTubeTranscriptTool(Tool):
53
+ name = "youtube_transcript"
54
+ description = "Return the subtitles of a YouTube URL using youtube-transcript-api."
55
+
56
+ def forward(self, url: str, lang: str = "en") -> str: # type: ignore[override]
57
+ from urllib.parse import urlparse, parse_qs
58
+ # Per Pylance, import from private API
59
+ from youtube_transcript_api._api import YouTubeTranscriptApi
60
+ vid = parse_qs(urlparse(url).query).get("v", [None])[0] or url.split("/")[-1]
61
+ data = YouTubeTranscriptApi.get_transcript(vid, languages=[lang, "en", "en-US", "en-GB"])
62
+ return " ".join(d["text"] for d in data).strip()
63
+
64
+ # ---- 4. AudioTranscriptionTool --------------------------------------------
65
+ class AudioTranscriptionTool(Tool):
66
+ name = "transcribe_audio"
67
+ description = "Transcribe an audio file with OpenAI Whisper, returns plain text."
68
+
69
+ def forward(self, path: str, model: str = "whisper-1") -> str: # type: ignore[override]
70
+ import openai
71
+ import os
72
+ if not os.path.isfile(path):
73
+ raise FileNotFoundError(path)
74
+ openai.api_key = os.getenv("OPENAI_API_KEY")
75
+ # Version/API guard for openai.Audio
76
+ if not hasattr(openai, "Audio"):
77
+ raise ImportError(
78
+ "Your OpenAI package does not support Audio. "
79
+ "Please upgrade it with: pip install --upgrade openai"
80
+ )
81
+ with open(path, "rb") as fp:
82
+ # type: ignore[attr-defined]
83
+ return openai.Audio.transcribe(model=model, file=fp)["text"].strip()
84
+
85
+ # ---- 5. SimpleOCRTool ------------------------------------------------------
86
+ class SimpleOCRTool(Tool):
87
+ name = "image_ocr"
88
+ description = "Return any text spotted in an image via pytesseract OCR."
89
+
90
+ def forward(self, path: str) -> str: # type: ignore[override]
91
+ from PIL import Image
92
+ import pytesseract
93
+ if not os.path.isfile(path):
94
+ raise FileNotFoundError(path)
95
+ return pytesseract.image_to_string(Image.open(path)).strip()
96
+
97
+ # ---------------------------------------------------------------------------
98
+ __all__ = [
99
+ "PythonRunTool",
100
+ "ExcelLoaderTool",
101
+ "YouTubeTranscriptTool",
102
+ "AudioTranscriptionTool",
103
+ "SimpleOCRTool",
104
+ ]