from dataclasses import dataclass
from typing import Any

from .local_python_executor import (
    BASE_BUILTIN_MODULES,
    BASE_PYTHON_TOOLS,
    evaluate_python_code,
)
from .tools import PipelineTool, Tool


@dataclass
class PreTool:
    name: str
    inputs: dict[str, str]
    output_type: type
    task: str
    description: str
    repo_id: str


class PythonInterpreterTool(Tool):
    name = "python_interpreter"
    description = "This is a tool that evaluates Python code. It can be used to perform calculations."
    inputs = {
        "code": {
            "type": "string",
            "description": "The Python code to run in the interpreter.",
        }
    }
    output_type = "string"

    def __init__(self, *args, authorized_imports=None, **kwargs):
        if authorized_imports is None:
            self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
        else:
            self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(authorized_imports))
        self.inputs = {
            "code": {
                "type": "string",
                "description": (
                    "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
                    f"else you will get an error. This code can only import the following python libraries: {self.authorized_imports}."
                ),
            }
        }
        self.base_python_tools = BASE_PYTHON_TOOLS
        self.python_evaluator = evaluate_python_code
        super().__init__(*args, **kwargs)

    def forward(self, code: str) -> str:
        state = {}
        output = str(
            self.python_evaluator(
                code,
                state=state,
                static_tools=self.base_python_tools,
                authorized_imports=self.authorized_imports,
            )[0]
        )
        return f"Stdout:\n{str(state['_print_outputs'])}\nOutput: {output}"


class FinalAnswerTool(Tool):
    name = "final_answer"
    description = "Provides a final answer to the given problem."
    inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
    output_type = "any"

    def forward(self, answer: Any) -> Any:
        return answer


class UserInputTool(Tool):
    name = "user_input"
    description = "Asks for user's input on a specific question"
    inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
    output_type = "string"

    def forward(self, question: str) -> str:
        user_input = input(f"{question} => Type your answer here:")
        return user_input


class DuckDuckGoSearchTool(Tool):
    name = "web_search"
    description = "Performs a DuckDuckGo web search based on your query (think a Google search) then returns the top search results."
    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
    output_type = "string"

    def __init__(self, max_results=10, **kwargs):
        super().__init__()
        self.max_results = max_results
        try:
            from duckduckgo_search import DDGS
        except ImportError as e:
            raise ImportError(
                "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
            ) from e
        self.ddgs = DDGS(**kwargs)

    def forward(self, query: str) -> str:
        results = self.ddgs.text(query, max_results=self.max_results)
        if len(results) == 0:
            raise Exception("No results found! Try a less restrictive/shorter query.")
        postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
        return "## Search Results\n\n" + "\n\n".join(postprocessed_results)


class GoogleSearchTool(Tool):
    name = "web_search"
    description = """Performs a Google web search for your query then returns a string of the top search results."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."},
        "filter_year": {
            "type": "integer",
            "description": "Optionally restrict results to a certain year",
            "nullable": True,
        },
    }
    output_type = "string"

    def __init__(self, provider: str = "serpapi"):
        super().__init__()
        import os

        self.provider = provider
        if provider == "serpapi":
            self.organic_key = "organic_results"
            api_key_env_name = "SERPAPI_API_KEY"
        else:
            self.organic_key = "organic"
            api_key_env_name = "SERPER_API_KEY"
        self.api_key = os.getenv(api_key_env_name)
        if self.api_key is None:
            raise ValueError(f"Missing API key. Make sure you have '{api_key_env_name}' in your env variables.")

    def forward(self, query: str, filter_year: int | None = None) -> str:
        import requests

        if self.provider == "serpapi":
            params = {
                "q": query,
                "api_key": self.api_key,
                "engine": "google",
                "google_domain": "google.com",
            }
            base_url = "https://serpapi.com/search.json"
        else:
            params = {
                "q": query,
                "api_key": self.api_key,
            }
            base_url = "https://google.serper.dev/search"
        if filter_year is not None:
            params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"

        response = requests.get(base_url, params=params)

        if response.status_code == 200:
            results = response.json()
        else:
            raise ValueError(response.json())

        if self.organic_key not in results:
            if filter_year is not None:
                raise Exception(
                    f"No results found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
                )
            else:
                raise Exception(f"No results found for query: '{query}'. Use a less restrictive query.")
        if len(results[self.organic_key]) == 0:
            year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
            return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."

        web_snippets = []
        for idx, page in enumerate(results[self.organic_key]):
            date_published = ""
            if "date" in page:
                date_published = "\nDate published: " + page["date"]

            source = ""
            if "source" in page:
                source = "\nSource: " + page["source"]

            snippet = ""
            if "snippet" in page:
                snippet = "\n" + page["snippet"]

            redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
            web_snippets.append(redacted_version)

        return "## Search Results\n" + "\n\n".join(web_snippets)


class ApiWebSearchTool(Tool):
    name = "web_search"
    description = "Performs a web search for a query and returns a string of the top search results formatted as markdown with titles, URLs, and descriptions."
    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
    output_type = "string"

    def __init__(
        self,
        endpoint: str = "",
        api_key: str = "",
        api_key_name: str = "",
        headers: dict | None = None,
        params: dict | None = None,
    ):
        import os

        super().__init__()
        self.endpoint = endpoint or "https://api.search.brave.com/res/v1/web/search"
        self.api_key = api_key or os.getenv(api_key_name)
        self.headers = headers or {"X-Subscription-Token": self.api_key}
        self.params = params or {"count": 10}

    def forward(self, query: str) -> str:
        import requests

        params = {**self.params, "q": query}
        response = requests.get(self.endpoint, headers=self.headers, params=params)
        response.raise_for_status()
        data = response.json()
        results = self.extract_results(data)
        return self.format_markdown(results)

    def extract_results(self, data: dict) -> list:
        results = []
        for result in data.get("web", {}).get("results", []):
            results.append(
                {"title": result["title"], "url": result["url"], "description": result.get("description", "")}
            )
        return results

    def format_markdown(self, results: list) -> str:
        if not results:
            return "No results found."
        return "## Search Results\n\n" + "\n\n".join(
            [
                f"{idx}. [{result['title']}]({result['url']})\n{result['description']}"
                for idx, result in enumerate(results, start=1)
            ]
        )
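
# Illustrative usage sketch (comment only). With no arguments the tool targets the
# Brave Search endpoint above; the env variable name below is an example chosen for
# this sketch, not a documented default:
#
#     search_tool = ApiWebSearchTool(api_key_name="BRAVE_API_KEY")
#     print(search_tool.forward("smolagents"))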


class WebSearchTool(Tool):
    name = "web_search"
    description = "Performs a web search for a query and returns a string of the top search results formatted as markdown with titles, links, and descriptions."
    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
    output_type = "string"

    def __init__(self, max_results: int = 10, engine: str = "duckduckgo"):
        super().__init__()
        self.max_results = max_results
        self.engine = engine

    def forward(self, query: str) -> str:
        results = self.search(query)
        if len(results) == 0:
            raise Exception("No results found! Try a less restrictive/shorter query.")
        return self.parse_results(results)

    def search(self, query: str) -> list:
        if self.engine == "duckduckgo":
            return self.search_duckduckgo(query)
        elif self.engine == "bing":
            return self.search_bing(query)
        else:
            raise ValueError(f"Unsupported engine: {self.engine}")

    def parse_results(self, results: list) -> str:
        return "## Search Results\n\n" + "\n\n".join(
            [f"[{result['title']}]({result['link']})\n{result['description']}" for result in results]
        )

    def search_duckduckgo(self, query: str) -> list:
        import requests

        response = requests.get(
            "https://lite.duckduckgo.com/lite/",
            params={"q": query},
            headers={"User-Agent": "Mozilla/5.0"},
        )
        response.raise_for_status()
        parser = self._create_duckduckgo_parser()
        parser.feed(response.text)
        return parser.results

    def _create_duckduckgo_parser(self):
        from html.parser import HTMLParser

        class SimpleResultParser(HTMLParser):
            def __init__(self):
                super().__init__()
                self.results = []
                self.current = {}
                self.capture_title = False
                self.capture_description = False
                self.capture_link = False

            def handle_starttag(self, tag, attrs):
                attrs = dict(attrs)
                if tag == "a" and attrs.get("class") == "result-link":
                    self.capture_title = True
                elif tag == "td" and attrs.get("class") == "result-snippet":
                    self.capture_description = True
                elif tag == "span" and attrs.get("class") == "link-text":
                    self.capture_link = True

            def handle_endtag(self, tag):
                if tag == "a" and self.capture_title:
                    self.capture_title = False
                elif tag == "td" and self.capture_description:
                    self.capture_description = False
                elif tag == "span" and self.capture_link:
                    self.capture_link = False
                elif tag == "tr":
                    # Emit the row as a result only once all three fields were captured.
                    if {"title", "description", "link"} <= self.current.keys():
                        self.current["description"] = " ".join(self.current["description"])
                        self.results.append(self.current)
                        self.current = {}

            def handle_data(self, data):
                if self.capture_title:
                    self.current["title"] = data.strip()
                elif self.capture_description:
                    self.current.setdefault("description", [])
                    self.current["description"].append(data.strip())
                elif self.capture_link:
                    self.current["link"] = "https://" + data.strip()

        return SimpleResultParser()

    def search_bing(self, query: str) -> list:
        import xml.etree.ElementTree as ET

        import requests

        response = requests.get(
            "https://www.bing.com/search",
            params={"q": query, "format": "rss"},
        )
        response.raise_for_status()
        root = ET.fromstring(response.text)
        items = root.findall(".//item")
        results = [
            {
                "title": item.findtext("title"),
                "link": item.findtext("link"),
                "description": item.findtext("description"),
            }
            for item in items[: self.max_results]
        ]
        return results
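
# Illustrative usage sketch (comment only). This tool needs no API key: it scrapes
# DuckDuckGo's lite HTML page by default, or Bing's RSS feed with engine="bing":
#
#     search_tool = WebSearchTool(max_results=5, engine="duckduckgo")
#     print(search_tool.forward("python html.parser example"))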


class VisitWebpageTool(Tool):
    name = "visit_webpage"
    description = (
        "Visits a webpage at the given URL and reads its content as a markdown string. Use this to browse webpages."
    )
    inputs = {
        "url": {
            "type": "string",
            "description": "The URL of the webpage to visit.",
        }
    }
    output_type = "string"

    def __init__(self, max_output_length: int = 40000):
        super().__init__()
        self.max_output_length = max_output_length

    def _truncate_content(self, content: str, max_length: int) -> str:
        if len(content) <= max_length:
            return content
        return (
            content[: max_length // 2]
            + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
            + content[-max_length // 2 :]
        )

    def forward(self, url: str) -> str:
        try:
            import re

            import requests
            from markdownify import markdownify
            from requests.exceptions import RequestException
        except ImportError as e:
            raise ImportError(
                "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
            ) from e
        try:
            # Send a GET request to the URL with a 20-second timeout
            response = requests.get(url, timeout=20)
            response.raise_for_status()  # Raise an exception for bad status codes

            # Convert the HTML content to Markdown
            markdown_content = markdownify(response.text).strip()

            # Collapse runs of three or more line breaks into a single blank line
            markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)

            return self._truncate_content(markdown_content, self.max_output_length)

        except requests.exceptions.Timeout:
            return "The request timed out. Please try again later or check the URL."
        except RequestException as e:
            return f"Error fetching the webpage: {str(e)}"
        except Exception as e:
            return f"An unexpected error occurred: {str(e)}"


class WikipediaSearchTool(Tool):
    """
    WikipediaSearchTool searches Wikipedia and returns a summary or full text of the given topic, along with the page URL.

    Attributes:
        user_agent (str): A custom user-agent string to identify the project. This is required as per Wikipedia API policies, read more here: http://github.com/martin-majlis/Wikipedia-API/blob/master/README.rst
        language (str): The language in which to retrieve Wikipedia articles.
            http://meta.wikimedia.org/wiki/List_of_Wikipedias
        content_type (str): Defines the content to fetch. Can be "summary" for a short summary or "text" for the full article.
        extract_format (str): Defines the output format. Can be `"WIKI"` or `"HTML"`.

    Example:
        >>> from smolagents import CodeAgent, InferenceClientModel, WikipediaSearchTool
        >>> agent = CodeAgent(
        >>>     tools=[
        >>>         WikipediaSearchTool(
        >>>             user_agent="MyResearchBot ([email protected])",
        >>>             language="en",
        >>>             content_type="summary",  # or "text"
        >>>             extract_format="WIKI",
        >>>         )
        >>>     ],
        >>>     model=InferenceClientModel(),
        >>> )
        >>> agent.run("Python_(programming_language)")
    """

    name = "wikipedia_search"
    description = "Searches Wikipedia and returns a summary or full text of the given topic, along with the page URL."
    inputs = {
        "query": {
            "type": "string",
            "description": "The topic to search on Wikipedia.",
        }
    }
    output_type = "string"

    def __init__(
        self,
        user_agent: str = "Smolagents ([email protected])",
        language: str = "en",
        content_type: str = "text",
        extract_format: str = "WIKI",
    ):
        super().__init__()
        try:
            import wikipediaapi
        except ImportError as e:
            raise ImportError(
                "You must install `wikipedia-api` to run this tool: for instance run `pip install wikipedia-api`"
            ) from e
        if not user_agent:
            raise ValueError("User-agent is required. Provide a meaningful identifier for your project.")

        self.user_agent = user_agent
        self.language = language
        self.content_type = content_type

        # Map the string option to the wikipediaapi.ExtractFormat enum
        extract_format_map = {
            "WIKI": wikipediaapi.ExtractFormat.WIKI,
            "HTML": wikipediaapi.ExtractFormat.HTML,
        }

        if extract_format not in extract_format_map:
            raise ValueError("Invalid extract_format. Choose between 'WIKI' or 'HTML'.")

        self.extract_format = extract_format_map[extract_format]

        self.wiki = wikipediaapi.Wikipedia(
            user_agent=self.user_agent, language=self.language, extract_format=self.extract_format
        )

    def forward(self, query: str) -> str:
        try:
            page = self.wiki.page(query)

            if not page.exists():
                return f"No Wikipedia page found for '{query}'. Try a different query."

            title = page.title
            url = page.fullurl

            if self.content_type == "summary":
                text = page.summary
            elif self.content_type == "text":
                text = page.text
            else:
                return "⚠️ Invalid `content_type`. Use either 'summary' or 'text'."

            return f"✅ **Wikipedia Page:** {title}\n\n**Content:** {text}\n\n🔗 **Read more:** {url}"

        except Exception as e:
            return f"Error fetching Wikipedia summary: {str(e)}"


class SpeechToTextTool(PipelineTool):
    default_checkpoint = "openai/whisper-large-v3-turbo"
    description = "This is a tool that transcribes audio into text. It returns the transcribed text."
    name = "transcriber"
    inputs = {
        "audio": {
            "type": "audio",
            "description": "The audio to transcribe. Can be a local path, a URL, or a tensor.",
        }
    }
    output_type = "string"

    def __new__(cls, *args, **kwargs):
        from transformers.models.whisper import WhisperForConditionalGeneration, WhisperProcessor

        cls.pre_processor_class = WhisperProcessor
        cls.model_class = WhisperForConditionalGeneration
        return super().__new__(cls)

    def encode(self, audio):
        from .agent_types import AgentAudio

        audio = AgentAudio(audio).to_raw()
        return self.pre_processor(audio, return_tensors="pt")

    def forward(self, inputs):
        return self.model.generate(inputs["input_features"])

    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
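
# Illustrative usage sketch (comment only; requires `transformers` and `torch`, and
# downloads the Whisper checkpoint on first use). Calling the tool instance goes
# through PipelineTool, which chains encode -> forward -> decode:
#
#     transcriber = SpeechToTextTool()
#     text = transcriber("path/to/audio.wav")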


TOOL_MAPPING = {
    tool_class.name: tool_class
    for tool_class in [
        PythonInterpreterTool,
        DuckDuckGoSearchTool,
        VisitWebpageTool,
    ]
}
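
# TOOL_MAPPING lets callers look up one of these default tool classes by its `name`
# attribute. Note that "web_search" resolves to DuckDuckGoSearchTool here, since
# several search tools in this module share that name:
#
#     tool_cls = TOOL_MAPPING["web_search"]
#     tool = tool_cls()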


__all__ = [
    "ApiWebSearchTool",
    "PythonInterpreterTool",
    "FinalAnswerTool",
    "UserInputTool",
    "WebSearchTool",
    "DuckDuckGoSearchTool",
    "GoogleSearchTool",
    "VisitWebpageTool",
    "WikipediaSearchTool",
    "SpeechToTextTool",
]