Tesvia commited on
Commit
bb49a20
·
verified ·
1 Parent(s): 52d1305

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +12 -12
agent.py CHANGED
@@ -21,9 +21,14 @@ from smolagents import (
21
  Tool,
22
  )
23
 
24
- # Custom Tools
25
- from .tools import (PythonRunTool, ExcelLoaderTool, YouTubeTranscriptTool,
26
- AudioTranscriptionTool, SimpleOCRTool)
 
 
 
 
 
27
 
28
  # ---------------------------------------------------------------------------
29
  # Model selection helper
@@ -31,7 +36,6 @@ from .tools import (PythonRunTool, ExcelLoaderTool, YouTubeTranscriptTool,
31
 
32
  load_dotenv() # Make sure we read credentials from .env when running locally
33
 
34
-
35
  def _select_model():
36
  """Return a smolagents *model* as configured by the ``MODEL_PROVIDER`` env."""
37
 
@@ -88,14 +92,10 @@ class GAIAAgent(CodeAgent):
88
  # ---------------------------------------------------------------------------
89
 
90
  def gaia_agent(*, extra_tools: Sequence[Tool] | None = None) -> GAIAAgent:
91
- base_tools = [
92
- DuckDuckGoSearchTool(),
93
- CustomTool1(),
94
- CustomTool2(),
95
- ]
96
  if extra_tools:
97
- base_tools.extend(extra_tools)
98
- return GAIAAgent(tools=base_tools)
99
-
100
 
101
  __all__ = ["GAIAAgent", "gaia_agent"]
 
21
  Tool,
22
  )
23
 
24
+ # Custom Tools from tools.py
25
+ from .tools import (
26
+ PythonRunTool,
27
+ ExcelLoaderTool,
28
+ YouTubeTranscriptTool,
29
+ AudioTranscriptionTool,
30
+ SimpleOCRTool,
31
+ )
32
 
33
  # ---------------------------------------------------------------------------
34
  # Model selection helper
 
36
 
37
  load_dotenv() # Make sure we read credentials from .env when running locally
38
 
 
39
  def _select_model():
40
  """Return a smolagents *model* as configured by the ``MODEL_PROVIDER`` env."""
41
 
 
92
  # ---------------------------------------------------------------------------
93
 
94
  def gaia_agent(*, extra_tools: Sequence[Tool] | None = None) -> GAIAAgent:
95
+ # Compose the toolset: always include all default tools, plus any extras
96
+ toolset = list(DEFAULT_TOOLS)
 
 
 
97
  if extra_tools:
98
+ toolset.extend(extra_tools)
99
+ return GAIAAgent(tools=toolset)
 
100
 
101
  __all__ = ["GAIAAgent", "gaia_agent"]