Tesvia commited on
Commit
4191a9b
·
verified ·
1 Parent(s): 6e06cc8

Upload 5 files

Browse files
Files changed (4) hide show
  1. agent.py +97 -118
  2. app.py +8 -21
  3. requirements.txt +2 -4
  4. tools.py +158 -130
agent.py CHANGED
@@ -1,150 +1,129 @@
1
- """
2
- GAIA benchmark agent using the OpenAI Agents SDK.
3
- """
4
 
5
- from __future__ import annotations
 
 
 
 
 
 
 
6
 
7
- import asyncio
8
  import os
9
- from typing import Any, Sequence, Callable, List
10
- from datetime import datetime
11
- from agents import RunHooks # for lifecycle hooks
12
 
13
  from dotenv import load_dotenv
14
- from agents import Agent, Runner, FunctionTool, Tool
15
 
16
- # Import all function tools
 
 
 
 
 
 
 
17
  from tools import (
18
- python_run,
19
- load_spreadsheet,
20
- youtube_transcript,
21
- transcribe_audio,
22
- image_ocr,
23
- duckduckgo_search,
24
  )
25
 
 
26
  # ---------------------------------------------------------------------------
27
- # Load the added system prompt
28
  # ---------------------------------------------------------------------------
29
  ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
30
  with open(ADDED_PROMPT_PATH, "r", encoding="utf-8") as f:
31
  ADDED_PROMPT = f.read().strip()
32
 
33
- load_dotenv()
34
 
 
 
 
35
 
36
- def _select_model() -> str:
37
- """Return a model identifier appropriate for the Agents SDK based on environment settings."""
38
- provider = os.getenv("MODEL_PROVIDER", "hf").lower()
39
 
40
- if provider == "openai":
41
- model_name = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
42
- return f"openai/{model_name}"
 
43
 
44
  if provider == "hf":
45
- hf_model_id = os.getenv("HF_MODEL", "Qwen/Qwen2.5-Coder-32B-Instruct")
46
- return f"litellm/huggingface/{hf_model_id}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  raise ValueError(
49
- f"Unsupported MODEL_PROVIDER: {provider!r}. Expected 'openai' or 'hf'."
 
50
  )
51
 
 
 
 
52
 
53
- DEFAULT_TOOLS: List[FunctionTool] = [
54
- python_run,
55
- load_spreadsheet,
56
- youtube_transcript,
57
- transcribe_audio,
58
- image_ocr,
59
- duckduckgo_search,
60
  ]
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- def _build_agent(extra_tools: Sequence[FunctionTool] | None = None) -> Agent:
64
- """Construct the underlying Agents SDK `Agent` instance."""
65
- instructions = (
66
- "You are a helpful assistant tasked with answering questions using the available tools.\n\n"
67
- + ADDED_PROMPT
68
- )
69
 
70
- tools: Sequence[Tool] = list(DEFAULT_TOOLS)
 
 
71
  if extra_tools:
72
- tools = list(tools) + list(extra_tools)
73
-
74
- return Agent(
75
- name="GAIA Agent",
76
- instructions=instructions,
77
- tools=tools,
78
- model=_select_model(),
79
- )
80
-
81
-
82
- class LoggingHooks(RunHooks):
83
- """RunHooks to log question start, model used, and each tool‐call step."""
84
- def __init__(self):
85
- self.step_counter = 0
86
-
87
- async def on_agent_start(self, context, agent):
88
- qnum = context.context.get("question_number")
89
- qtext = context.context.get("question_text")
90
- model = agent.model
91
- ts = datetime.now().isoformat()
92
- print(f"[{ts}] [Question {qnum}] Starting agent (model={model}) for question: '{qtext}'")
93
-
94
- async def on_tool_start(self, context, agent, tool):
95
- self.step_counter += 1
96
- qnum = context.context.get("question_number")
97
- ts = datetime.now().isoformat()
98
- print(f"[{ts}] [Question {qnum}] Step {self.step_counter}: Invoking tool '{tool.name}'")
99
-
100
- async def on_tool_end(self, context, agent, tool, result):
101
- qnum = context.context.get("question_number")
102
- ts = datetime.now().isoformat()
103
- print(f"[{ts}] [Question {qnum}] Step {self.step_counter}: Tool '{tool.name}' completed")
104
-
105
-
106
- class GAIAAgent:
107
- """Thin synchronous wrapper around an asynchronous Agents SDK agent."""
108
-
109
- def __init__(self, *, extra_tools: Sequence[FunctionTool] | None = None):
110
- self._agent = _build_agent(extra_tools=extra_tools)
111
-
112
- async def _arun(self, question: str, context_data=None, hooks=None) -> str:
113
- # Pass context and hooks to Runner.run if provided
114
- if context_data is not None and hooks is not None:
115
- result = await Runner.run(
116
- self._agent,
117
- question,
118
- context=context_data,
119
- hooks=hooks
120
- )
121
- else:
122
- result = await Runner.run(self._agent, question)
123
- return str(result.final_output).strip()
124
-
125
- def __call__(self, question: str, question_number: int | None = None, **_kwargs) -> str:
126
- # Prepare logging context if a question_number is given
127
- context_data = None
128
- hooks = None
129
- if question_number is not None:
130
- context_data = {
131
- "question_number": question_number,
132
- "question_text": question
133
- }
134
- hooks = LoggingHooks()
135
-
136
- try:
137
- loop = asyncio.get_running_loop()
138
- except RuntimeError:
139
- # No running loop: use asyncio.run
140
- return asyncio.run(self._arun(question, context_data, hooks))
141
- else:
142
- return loop.run_until_complete(self._arun(question, context_data, hooks))
143
-
144
-
145
- def gaia_agent(*, extra_tools: Sequence[FunctionTool] | None = None) -> GAIAAgent:
146
- """Factory returning a ready‑to‑use GAIAAgent instance."""
147
- return GAIAAgent(extra_tools=extra_tools)
148
-
149
 
150
  __all__ = ["GAIAAgent", "gaia_agent"]
 
1
+ """GAIA benchmark agent using *smolagents*.
 
 
2
 
3
+ This module exposes:
4
+
5
+ * ``gaia_agent()`` – factory returning a ready‑to‑use agent instance.
6
+ * ``GAIAAgent`` – subclass of ``smolagents.CodeAgent``.
7
+
8
+ The LLM backend is chosen at runtime via the ``MODEL_PROVIDER``
9
+ environment variable (``hf`` or ``openai``) exactly like *example.py*.
10
+ """
11
 
 
12
  import os
13
+ from typing import Any, Sequence
 
 
14
 
15
  from dotenv import load_dotenv
 
16
 
17
+ # SmolAgents Tools
18
+ from smolagents import (
19
+ CodeAgent,
20
+ DuckDuckGoSearchTool,
21
+ Tool
22
+ )
23
+
24
+ # Custom Tools from tools.py
25
  from tools import (
26
+ PythonRunTool,
27
+ ExcelLoaderTool,
28
+ YouTubeTranscriptTool,
29
+ AudioTranscriptionTool,
30
+ SimpleOCRTool,
 
31
  )
32
 
33
+
34
  # ---------------------------------------------------------------------------
35
+ # Load the added system prompt from system_prompt.txt (located in the same directory)
36
  # ---------------------------------------------------------------------------
37
  ADDED_PROMPT_PATH = os.path.join(os.path.dirname(__file__), "added_prompt.txt")
38
  with open(ADDED_PROMPT_PATH, "r", encoding="utf-8") as f:
39
  ADDED_PROMPT = f.read().strip()
40
 
 
41
 
42
+ # ---------------------------------------------------------------------------
43
+ # Model selection helper
44
+ # ---------------------------------------------------------------------------
45
 
46
+ load_dotenv() # Make sure we read credentials from .env when running locally
 
 
47
 
48
+ def _select_model():
49
+ """Return a smolagents *model* as configured by the ``MODEL_PROVIDER`` env."""
50
+
51
+ provider = os.getenv("MODEL_PROVIDER", "hf").lower()
52
 
53
  if provider == "hf":
54
+ from smolagents import InferenceClientModel
55
+ hf_model_id = os.getenv("HF_MODEL", "HuggingFaceH4/zephyr-7b-beta")
56
+ hf_token = os.getenv("HF_API_KEY")
57
+ return InferenceClientModel(
58
+ model_id=hf_model_id,
59
+ token=hf_token
60
+ )
61
+
62
+ if provider == "openai":
63
+ from smolagents import OpenAIServerModel
64
+ openai_model_id = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
65
+ openai_token = os.getenv("OPENAI_API_KEY")
66
+ return OpenAIServerModel(
67
+ model_id=openai_model_id,
68
+ api_key=openai_token
69
+ )
70
 
71
  raise ValueError(
72
+ f"Unsupported MODEL_PROVIDER: {provider!r}. "
73
+ "Use 'hf' (default) or 'openai'."
74
  )
75
 
76
+ # ---------------------------------------------------------------------------
77
+ # Core Agent implementation
78
+ # ---------------------------------------------------------------------------
79
 
80
+ DEFAULT_TOOLS = [
81
+ DuckDuckGoSearchTool(),
82
+ PythonRunTool(),
83
+ ExcelLoaderTool(),
84
+ YouTubeTranscriptTool(),
85
+ AudioTranscriptionTool(),
86
+ SimpleOCRTool(),
87
  ]
88
 
89
+ class GAIAAgent(CodeAgent):
90
+ def __init__(
91
+ self,
92
+ tools=None
93
+ ):
94
+ super().__init__(
95
+ tools=tools or DEFAULT_TOOLS,
96
+ model=_select_model()
97
+ )
98
+ # Append the additional prompt to the existing system prompt
99
+ self.prompt_templates["system_prompt"] += f"\n\n{ADDED_PROMPT}"
100
+
101
+ # Convenience so the object itself can be *called* directly
102
+ def __call__(self, question: str, **kwargs: Any) -> str:
103
+ steps = self.run(question, **kwargs)
104
+ # If steps is a primitive, just return it
105
+ if isinstance(steps, (int, float, str)):
106
+ return str(steps).strip()
107
+ last_step = None
108
+ for step in steps:
109
+ last_step = step
110
+ # Defensive: handle int/float/str directly
111
+ if isinstance(last_step, (int, float, str)):
112
+ return str(last_step).strip()
113
+ answer = getattr(last_step, "answer", None)
114
+ if answer is not None:
115
+ return str(answer).strip()
116
+ return str(last_step).strip()
117
 
118
+ # ---------------------------------------------------------------------------
119
+ # Factory helpers expected by app.py
120
+ # ---------------------------------------------------------------------------
 
 
 
121
 
122
+ def gaia_agent(*, extra_tools: Sequence[Tool] | None = None) -> GAIAAgent:
123
+ # Compose the toolset: always include all default tools, plus any extras
124
+ toolset = list(DEFAULT_TOOLS)
125
  if extra_tools:
126
+ toolset.extend(extra_tools)
127
+ return GAIAAgent(tools=toolset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  __all__ = ["GAIAAgent", "gaia_agent"]
app.py CHANGED
@@ -32,10 +32,10 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
32
  questions_url = f"{api_url}/questions"
33
  submit_url = f"{api_url}/submit"
34
 
35
- # 1. Instantiate Agent (now using OpenAI Agents SDK)
36
  try:
37
  agent = gaia_agent()
38
- print("OpenAI Agent instantiated successfully.")
39
  except Exception as e:
40
  print(f"Error instantiating agent: {e}")
41
  return f"Error initializing agent: {e}", None
@@ -70,16 +70,14 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
70
  results_log = []
71
  answers_payload = []
72
  print(f"Running agent on {len(questions_data)} questions...")
73
- for idx, item in enumerate(questions_data, start=1):
74
  task_id = item.get("task_id")
75
  question_text = item.get("question")
76
  if not task_id or question_text is None:
77
  print(f"Skipping item with missing task_id or question: {item}")
78
  continue
79
  try:
80
- # pass in question_number for logging hooks
81
- submitted_answer = agent(question_text, question_number=idx)
82
-
83
  # --- DEBUG LOGGING ---
84
  if DEBUG:
85
  print(f"[DEBUG] Task {task_id}: Answer type: {type(submitted_answer)}, Value: {repr(submitted_answer)}")
@@ -88,22 +86,11 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
88
 
89
  # Force string type here just in case (defensive)
90
  submitted_answer = str(submitted_answer).strip()
91
- answers_payload.append({
92
- "task_id": task_id,
93
- "submitted_answer": submitted_answer
94
- })
95
- results_log.append({
96
- "Task ID": task_id,
97
- "Question": question_text,
98
- "Submitted Answer": submitted_answer
99
- })
100
  except Exception as e:
101
- print(f"Error running agent on task {task_id}: {e}")
102
- results_log.append({
103
- "Task ID": task_id,
104
- "Question": question_text,
105
- "Submitted Answer": f"AGENT ERROR: {e}"
106
- })
107
 
108
  if not answers_payload:
109
  print("Agent did not produce any answers to submit.")
 
32
  questions_url = f"{api_url}/questions"
33
  submit_url = f"{api_url}/submit"
34
 
35
+ # 1. Instantiate Agent (now using smolagents)
36
  try:
37
  agent = gaia_agent()
38
+ print("SmolAgent instantiated successfully.")
39
  except Exception as e:
40
  print(f"Error instantiating agent: {e}")
41
  return f"Error initializing agent: {e}", None
 
70
  results_log = []
71
  answers_payload = []
72
  print(f"Running agent on {len(questions_data)} questions...")
73
+ for item in questions_data:
74
  task_id = item.get("task_id")
75
  question_text = item.get("question")
76
  if not task_id or question_text is None:
77
  print(f"Skipping item with missing task_id or question: {item}")
78
  continue
79
  try:
80
+ submitted_answer = agent(question_text)
 
 
81
  # --- DEBUG LOGGING ---
82
  if DEBUG:
83
  print(f"[DEBUG] Task {task_id}: Answer type: {type(submitted_answer)}, Value: {repr(submitted_answer)}")
 
86
 
87
  # Force string type here just in case (defensive)
88
  submitted_answer = str(submitted_answer).strip()
89
+ answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
90
+ results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
 
 
 
 
 
 
 
91
  except Exception as e:
92
+ print(f"Error running agent on task {task_id}: {e}")
93
+ results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
 
 
 
 
94
 
95
  if not answers_payload:
96
  print("Agent did not produce any answers to submit.")
requirements.txt CHANGED
@@ -1,10 +1,8 @@
1
  gradio
2
  requests
3
  pandas
4
- openai-agents[litellm]
5
- openai>=1.3
6
  duckduckgo-search
7
  youtube-transcript-api
8
  pytesseract
9
- pillow
10
- python-dotenv
 
1
  gradio
2
  requests
3
  pandas
4
+ smolagents[openai]
 
5
  duckduckgo-search
6
  youtube-transcript-api
7
  pytesseract
8
+ pillow
 
tools.py CHANGED
@@ -1,142 +1,170 @@
1
- """
2
- Custom function tools for OpenAI Agents SDK GAIA agent.
3
- """
4
-
5
  from __future__ import annotations
6
-
7
  import contextlib
8
  import io
9
  import os
10
- from typing import List, Dict
11
 
12
- from agents import function_tool
13
 
14
- # 1. --------------------------------------------------------------------
15
- @function_tool
16
- def python_run(code: str) -> str:
17
- """Execute trusted Python code and return the captured stdout together with
18
- the repr() of the last expression (or `_result` variable if set).
19
-
20
- Args:
21
- code: Python code to execute.
22
  """
23
- buf = io.StringIO()
24
- ns: dict = {}
25
- last = None
26
- try:
27
- with contextlib.redirect_stdout(buf):
28
- exec(compile(code, "<agent-python>", "exec"), {}, ns)
29
- last = ns.get("_result")
30
- except Exception as e:
31
- raise RuntimeError(f"python_run error: {e}") from e
32
-
33
- out = buf.getvalue()
34
- return (out + (repr(last) if last is not None else "")).strip()
35
-
36
-
37
- # 2. --------------------------------------------------------------------
38
- @function_tool
39
- def load_spreadsheet(path: str, sheet: str | int | None = None) -> list[Dict[str, str]]:
40
- """Read .csv, .xls or .xlsx from disk and return rows as list of dictionaries.
41
-
42
- Args:
43
- path: Path to spreadsheet file.
44
- sheet: Sheet name or index (for Excel files only).
 
 
 
 
 
 
45
  """
46
- import pandas as pd
47
-
48
- if not os.path.isfile(path):
49
- raise FileNotFoundError(path)
50
- ext = os.path.splitext(path)[1].lower()
51
- if ext == ".csv":
52
- df = pd.read_csv(path)
53
- dfs = [df]
54
- else:
55
- sheets = pd.read_excel(path, sheet_name=sheet if sheet not in ("", None) else None)
56
- if isinstance(sheets, dict):
57
- dfs = sheets.values()
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  else:
59
- dfs = [sheets]
60
- results = []
61
- for df in dfs:
62
- results.extend([{str(k): v for k, v in row.items()} for row in df.to_dict(orient="records")])
63
- return results
64
-
65
-
66
- # 3. --------------------------------------------------------------------
67
- @function_tool
68
- def youtube_transcript(url: str, lang: str = "en") -> str:
69
- """Fetch the subtitles of a YouTube video.
70
-
71
- Args:
72
- url: YouTube video URL.
73
- lang: Preferred transcript language code (default "en").
74
- """
75
- from urllib.parse import urlparse, parse_qs
76
- from youtube_transcript_api._api import YouTubeTranscriptApi
77
-
78
- vid = parse_qs(urlparse(url).query).get("v", [None])[0] or url.split("/")[-1]
79
- data = YouTubeTranscriptApi.get_transcript(
80
- vid, languages=[lang, "en", "en-US", "en-GB"]
81
- )
82
- return " ".join(chunk["text"] for chunk in data).strip()
83
-
84
-
85
- # 4. --------------------------------------------------------------------
86
- @function_tool
87
- def transcribe_audio(path: str, model: str = "whisper-1") -> str:
88
- """Transcribe an audio file using OpenAI Whisper.
89
-
90
- Args:
91
- path: Path to audio file (wav / mp3 / m4a / etc.).
92
- model: Whisper model name (default "whisper-1").
93
  """
94
- import openai
95
-
96
- if not os.path.isfile(path):
97
- raise FileNotFoundError(path)
98
-
99
- client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
100
- with open(path, "rb") as fp:
101
- transcript = client.audio.transcriptions.create(model=model, file=fp)
102
- return transcript.text.strip()
103
-
104
-
105
- # 5. --------------------------------------------------------------------
106
- @function_tool
107
- def image_ocr(path: str) -> str:
108
- """Perform OCR on an image using Tesseract.
109
-
110
- Args:
111
- path: Path to image file.
 
 
 
 
 
 
 
 
 
 
 
112
  """
113
- from PIL import Image
114
- import pytesseract
115
-
116
- if not os.path.isfile(path):
117
- raise FileNotFoundError(path)
118
- return pytesseract.image_to_string(Image.open(path)).strip()
119
-
120
-
121
- # 6. --------------------------------------------------------------------
122
- @function_tool
123
- def duckduckgo_search(query: str, max_results: int = 5) -> List[Dict[str, str]]:
124
- """Search DuckDuckGo and return a list of result dicts with title, href and body.
125
-
126
- Args:
127
- query: The search query.
128
- max_results: Maximum results to return (default 5).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  """
130
- from duckduckgo_search import DDGS
131
-
132
- results = []
133
- with DDGS() as ddgs:
134
- for r in ddgs.text(query, max_results=max_results):
135
- results.append(
136
- {
137
- "title": r.get("title", ""),
138
- "href": r.get("href", ""),
139
- "body": r.get("body", ""),
140
- }
141
- )
142
- return results
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Custom tools for smolagents GAIA agent
 
 
 
2
  from __future__ import annotations
 
3
  import contextlib
4
  import io
5
  import os
6
+ from typing import Any, Dict, List
7
 
8
+ from smolagents import Tool
9
 
10
+ # ---- 1. PythonRunTool ------------------------------------------------------
11
+ class PythonRunTool(Tool):
12
+ name = "python_run"
13
+ description = """
14
+ Execute trusted Python code and return printed output + repr() of the last expression (or _result variable).
 
 
 
15
  """
16
+ inputs = {
17
+ "code": {
18
+ "type": "string",
19
+ "description": "Python code to execute",
20
+ "required": True
21
+ }
22
+ }
23
+ output_type = "string"
24
+
25
+ def forward(self, code: str) -> str:
26
+ buf, ns = io.StringIO(), {}
27
+ last = None
28
+ try:
29
+ with contextlib.redirect_stdout(buf):
30
+ exec(compile(code, "<agent-python>", "exec"), {}, ns)
31
+ last = ns.get("_result", None)
32
+ except Exception as e:
33
+ raise RuntimeError(f"PythonRunTool error: {e}") from e
34
+ out = buf.getvalue()
35
+ # Always return a string
36
+ result = (out + (repr(last) if last is not None else "")).strip()
37
+ return str(result)
38
+
39
+ # ---- 2. ExcelLoaderTool ----------------------------------------------------
40
+ class ExcelLoaderTool(Tool):
41
+ name = "load_spreadsheet"
42
+ description = """
43
+ Read .xlsx/.xls/.csv from disk and return rows as a list of dictionaries with string keys.
44
  """
45
+ inputs = {
46
+ "path": {
47
+ "type": "string",
48
+ "description": "Path to .csv/.xls/.xlsx file",
49
+ "required": True
50
+ },
51
+ "sheet": {
52
+ "type": "string",
53
+ "description": "Sheet name or index (optional, required for Excel files only)",
54
+ "required": False,
55
+ "default": "",
56
+ "nullable": True
57
+ }
58
+ }
59
+ output_type = "array"
60
+
61
+ def forward(self, path: str, sheet: str | int | None = None) -> str:
62
+ import pandas as pd
63
+ if not os.path.isfile(path):
64
+ raise FileNotFoundError(path)
65
+ ext = os.path.splitext(path)[1].lower()
66
+ if sheet == "":
67
+ sheet = None
68
+ if ext == ".csv":
69
+ df = pd.read_csv(path)
70
  else:
71
+ df = pd.read_excel(path, sheet_name=sheet)
72
+ if isinstance(df, dict):
73
+ # If user did not specify a sheet, use the first one found
74
+ first_sheet = next(iter(df))
75
+ df = df[first_sheet]
76
+ records = [{str(k): v for k, v in row.items()} for row in df.to_dict(orient="records")]
77
+ # Always return a string
78
+ return str(records)
79
+
80
+ # ---- 3. YouTubeTranscriptTool ---------------------------------------------
81
+ class YouTubeTranscriptTool(Tool):
82
+ name = "youtube_transcript"
83
+ description = """
84
+ Return the subtitles of a YouTube URL using youtube-transcript-api.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  """
86
+ inputs = {
87
+ "url": {
88
+ "type": "string",
89
+ "description": "YouTube URL",
90
+ "required": True
91
+ },
92
+ "lang": {
93
+ "type": "string",
94
+ "description": "Transcript language (default: en)",
95
+ "required": False,
96
+ "default": "en",
97
+ "nullable": True
98
+ }
99
+ }
100
+ output_type = "string"
101
+
102
+ def forward(self, url: str, lang: str = "en") -> str:
103
+ from urllib.parse import urlparse, parse_qs
104
+ from youtube_transcript_api._api import YouTubeTranscriptApi
105
+ vid = parse_qs(urlparse(url).query).get("v", [None])[0] or url.split("/")[-1]
106
+ data = YouTubeTranscriptApi.get_transcript(vid, languages=[lang, "en", "en-US", "en-GB"])
107
+ text = " ".join(d["text"] for d in data).strip()
108
+ return str(text)
109
+
110
+ # ---- 4. AudioTranscriptionTool --------------------------------------------
111
+ class AudioTranscriptionTool(Tool):
112
+ name = "transcribe_audio"
113
+ description = """
114
+ Transcribe an audio file with OpenAI Whisper, returns plain text."
115
  """
116
+ inputs = {
117
+ "path": {
118
+ "type": "string",
119
+ "description": "Path to audio file",
120
+ "required": True
121
+ },
122
+ "model": {
123
+ "type": "string",
124
+ "description": "Model name for transcription (default: whisper-1)",
125
+ "required": False,
126
+ "default": "whisper-1",
127
+ "nullable": True
128
+ }
129
+ }
130
+ output_type = "string"
131
+
132
+ def forward(self, path: str, model: str = "whisper-1") -> str:
133
+ import openai
134
+ if not os.path.isfile(path):
135
+ raise FileNotFoundError(path)
136
+ client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
137
+ with open(path, "rb") as fp:
138
+ transcript = client.audio.transcriptions.create(model=model, file=fp)
139
+ return str(transcript.text.strip())
140
+
141
+ # ---- 5. SimpleOCRTool ------------------------------------------------------
142
+ class SimpleOCRTool(Tool):
143
+ name = "image_ocr"
144
+ description = """
145
+ Return any text spotted in an image via pytesseract OCR.
146
  """
147
+ inputs = {
148
+ "path": {
149
+ "type": "string",
150
+ "description": "Path to image file",
151
+ "required": True
152
+ }
153
+ }
154
+ output_type = "string"
155
+
156
+ def forward(self, path: str) -> str:
157
+ from PIL import Image
158
+ import pytesseract
159
+ if not os.path.isfile(path):
160
+ raise FileNotFoundError(path)
161
+ return str(pytesseract.image_to_string(Image.open(path)).strip())
162
+
163
+ # ---------------------------------------------------------------------------
164
+ __all__ = [
165
+ "PythonRunTool",
166
+ "ExcelLoaderTool",
167
+ "YouTubeTranscriptTool",
168
+ "AudioTranscriptionTool",
169
+ "SimpleOCRTool",
170
+ ]