# Duibonduil's picture
# Upload 17 files
# d7949de verified
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any
from .local_python_executor import (
BASE_BUILTIN_MODULES,
BASE_PYTHON_TOOLS,
evaluate_python_code,
)
from .tools import PipelineTool, Tool
@dataclass
class PreTool:
    """Lightweight record describing a tool before it is instantiated."""

    name: str  # tool identifier, e.g. "web_search"
    inputs: dict[str, str]  # input schema: argument name -> description info
    output_type: type  # Python type of the tool's output
    task: str  # task name associated with the tool
    description: str  # human-readable description of what the tool does
    repo_id: str  # presumably a Hugging Face Hub repo id — TODO confirm against callers
class PythonInterpreterTool(Tool):
    """Tool that evaluates Python code in the restricted local interpreter."""

    name = "python_interpreter"
    description = "This is a tool that evaluates python code. It can be used to perform calculations."
    inputs = {
        "code": {
            "type": "string",
            "description": "The python code to run in interpreter",
        }
    }
    output_type = "string"

    def __init__(self, *args, authorized_imports=None, **kwargs):
        """Build the tool, merging any extra authorized imports into the base allowlist."""
        allowed = set(BASE_BUILTIN_MODULES)
        if authorized_imports is not None:
            allowed |= set(authorized_imports)
        self.authorized_imports = list(allowed)
        # Rebuild the input description so the agent model sees the exact allowlist.
        self.inputs = {
            "code": {
                "type": "string",
                "description": (
                    "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
                    f"else you will get an error. This code can only import the following python libraries: {self.authorized_imports}."
                ),
            }
        }
        self.base_python_tools = BASE_PYTHON_TOOLS
        self.python_evaluator = evaluate_python_code
        super().__init__(*args, **kwargs)

    def forward(self, code: str) -> str:
        """Evaluate ``code`` in a fresh state and return captured stdout plus the result."""
        state = {}
        evaluation = self.python_evaluator(
            code,
            state=state,
            static_tools=self.base_python_tools,
            authorized_imports=self.authorized_imports,
        )
        # The evaluator returns (output, is_final_answer); only the output is needed here.
        output = str(evaluation[0])
        return f"Stdout:\n{str(state['_print_outputs'])}\nOutput: {output}"
class FinalAnswerTool(Tool):
    """Terminal tool: returns the agent's final answer unchanged."""

    name = "final_answer"
    description = "Provides a final answer to the given problem."
    inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
    output_type = "any"

    def forward(self, answer: Any) -> Any:
        """Pass the provided answer straight through."""
        return answer
class UserInputTool(Tool):
    """Tool that pauses to collect a reply from the human operator via stdin."""

    name = "user_input"
    description = "Asks for user's input on a specific question"
    inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
    output_type = "string"

    def forward(self, question):
        """Prompt with the question and return whatever the user types."""
        prompt = f"{question} => Type your answer here:"
        return input(prompt)
class DuckDuckGoSearchTool(Tool):
    """Web search backed by the `duckduckgo_search` package."""

    name = "web_search"
    description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."""
    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
    output_type = "string"

    def __init__(self, max_results=10, **kwargs):
        """Initialize the DDGS client; extra kwargs are forwarded to the DDGS constructor."""
        super().__init__()
        self.max_results = max_results
        try:
            from duckduckgo_search import DDGS
        except ImportError as e:
            raise ImportError(
                "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
            ) from e
        self.ddgs = DDGS(**kwargs)

    def forward(self, query: str) -> str:
        """Run the search and return the hits formatted as a markdown list."""
        hits = self.ddgs.text(query, max_results=self.max_results)
        if not hits:
            raise Exception("No results found! Try a less restrictive/shorter query.")
        formatted = "\n\n".join(f"[{hit['title']}]({hit['href']})\n{hit['body']}" for hit in hits)
        return "## Search Results\n\n" + formatted
class GoogleSearchTool(Tool):
    """Google web search via either the SerpAPI or Serper provider."""

    name = "web_search"
    description = """Performs a google web search for your query then returns a string of the top search results."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."},
        "filter_year": {
            "type": "integer",
            "description": "Optionally restrict results to a certain year",
            "nullable": True,
        },
    }
    output_type = "string"

    def __init__(self, provider: str = "serpapi"):
        """Select the provider ("serpapi" or anything else for Serper) and read its API key.

        Raises:
            ValueError: If the provider's API key environment variable is not set.
        """
        super().__init__()
        import os

        self.provider = provider
        if provider == "serpapi":
            self.organic_key = "organic_results"
            api_key_env_name = "SERPAPI_API_KEY"
        else:
            self.organic_key = "organic"
            api_key_env_name = "SERPER_API_KEY"
        self.api_key = os.getenv(api_key_env_name)
        if self.api_key is None:
            raise ValueError(f"Missing API key. Make sure you have '{api_key_env_name}' in your env variables.")

    def forward(self, query: str, filter_year: int | None = None) -> str:
        """Query the configured provider and return organic results formatted as markdown.

        Raises:
            ValueError: On a non-200 HTTP response (with the provider's JSON error body).
            Exception: When the response contains no organic-results key at all.
        """
        import requests

        if self.provider == "serpapi":
            params = {
                "q": query,
                "api_key": self.api_key,
                "engine": "google",
                "google_domain": "google.com",
            }
            base_url = "https://serpapi.com/search.json"
        else:
            params = {
                "q": query,
                "api_key": self.api_key,
            }
            base_url = "https://google.serper.dev/search"
        if filter_year is not None:
            # Google "custom date range" syntax restricts hits to the given calendar year.
            params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
        response = requests.get(base_url, params=params)
        if response.status_code == 200:
            results = response.json()
        else:
            raise ValueError(response.json())
        if self.organic_key not in results.keys():
            if filter_year is not None:
                raise Exception(
                    f"No results found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
                )
            else:
                raise Exception(f"No results found for query: '{query}'. Use a less restrictive query.")
        if len(results[self.organic_key]) == 0:
            year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
            return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
        # A redundant `if self.organic_key in results:` guard was removed here: the key's
        # absence already raises above, so the guard was always true (dead branch).
        web_snippets = []
        for idx, page in enumerate(results[self.organic_key]):
            date_published = ""
            if "date" in page:
                date_published = "\nDate published: " + page["date"]
            source = ""
            if "source" in page:
                source = "\nSource: " + page["source"]
            snippet = ""
            if "snippet" in page:
                snippet = "\n" + page["snippet"]
            redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
            web_snippets.append(redacted_version)
        return "## Search Results\n" + "\n\n".join(web_snippets)
class ApiWebSearchTool(Tool):
    """Web search against a JSON HTTP API (Brave Search by default)."""

    name = "web_search"
    description = "Performs a web search for a query and returns a string of the top search results formatted as markdown with titles, URLs, and descriptions."
    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
    output_type = "string"

    def __init__(
        self, endpoint: str = "", api_key: str = "", api_key_name: str = "", headers: dict = None, params: dict = None
    ):
        """Configure the endpoint, key, headers and query params, falling back to Brave defaults."""
        import os

        super().__init__()
        self.endpoint = endpoint or "https://api.search.brave.com/res/v1/web/search"
        # An explicit api_key wins; otherwise look it up under the given env-var name.
        self.api_key = api_key or os.getenv(api_key_name)
        self.headers = headers or {"X-Subscription-Token": self.api_key}
        self.params = params or {"count": 10}

    def forward(self, query: str) -> str:
        """Issue the search request and return the results rendered as markdown."""
        import requests

        request_params = {**self.params, "q": query}
        resp = requests.get(self.endpoint, headers=self.headers, params=request_params)
        resp.raise_for_status()
        return self.format_markdown(self.extract_results(resp.json()))

    def extract_results(self, data: dict) -> list:
        """Pull title/url/description dicts out of the API's `web.results` payload."""
        web_results = data.get("web", {}).get("results", [])
        return [
            {"title": item["title"], "url": item["url"], "description": item.get("description", "")}
            for item in web_results
        ]

    def format_markdown(self, results: list) -> str:
        """Render extracted results as a numbered markdown list (1-based)."""
        if not results:
            return "No results found."
        lines = []
        for position, entry in enumerate(results, start=1):
            lines.append(f"{position}. [{entry['title']}]({entry['url']})\n{entry['description']}")
        return "## Search Results\n\n" + "\n\n".join(lines)
class WebSearchTool(Tool):
    """Dependency-free web search that scrapes DuckDuckGo Lite or Bing's RSS feed."""

    name = "web_search"
    description = "Performs a web search for a query and returns a string of the top search results formatted as markdown with titles, links, and descriptions."
    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
    output_type = "string"

    def __init__(self, max_results: int = 10, engine: str = "duckduckgo"):
        """Store the result cap and backend engine ("duckduckgo" or "bing")."""
        super().__init__()
        self.max_results = max_results
        self.engine = engine

    def forward(self, query: str) -> str:
        """Search with the configured engine and return markdown-formatted results.

        Raises:
            Exception: If the engine returns no results.
        """
        results = self.search(query)
        if len(results) == 0:
            raise Exception("No results found! Try a less restrictive/shorter query.")
        return self.parse_results(results)

    def search(self, query: str) -> list:
        """Dispatch to the engine-specific search implementation.

        Raises:
            ValueError: For an unknown engine name.
        """
        if self.engine == "duckduckgo":
            return self.search_duckduckgo(query)
        elif self.engine == "bing":
            return self.search_bing(query)
        else:
            raise ValueError(f"Unsupported engine: {self.engine}")

    def parse_results(self, results: list) -> str:
        """Format result dicts (title/link/description) as a markdown list."""
        return "## Search Results\n\n" + "\n\n".join(
            [f"[{result['title']}]({result['link']})\n{result['description']}" for result in results]
        )

    def search_duckduckgo(self, query: str) -> list:
        """Scrape DuckDuckGo Lite's HTML results page."""
        import requests

        response = requests.get(
            "https://lite.duckduckgo.com/lite/",
            params={"q": query},
            headers={"User-Agent": "Mozilla/5.0"},
        )
        response.raise_for_status()
        parser = self._create_duckduckgo_parser()
        parser.feed(response.text)
        # FIX: honor self.max_results here as search_bing already does; previously
        # every parsed result was returned regardless of the configured cap.
        return parser.results[: self.max_results]

    def _create_duckduckgo_parser(self):
        """Build a throwaway HTMLParser that collects title/link/description per result row."""
        from html.parser import HTMLParser

        class SimpleResultParser(HTMLParser):
            def __init__(self):
                super().__init__()
                self.results = []
                self.current = {}
                self.capture_title = False
                self.capture_description = False
                self.capture_link = False

            def handle_starttag(self, tag, attrs):
                attrs = dict(attrs)
                # The Lite page marks each field with a distinctive class attribute.
                if tag == "a" and attrs.get("class") == "result-link":
                    self.capture_title = True
                elif tag == "td" and attrs.get("class") == "result-snippet":
                    self.capture_description = True
                elif tag == "span" and attrs.get("class") == "link-text":
                    self.capture_link = True

            def handle_endtag(self, tag):
                if tag == "a" and self.capture_title:
                    self.capture_title = False
                elif tag == "td" and self.capture_description:
                    self.capture_description = False
                elif tag == "span" and self.capture_link:
                    self.capture_link = False
                elif tag == "tr":
                    # Store current result if all parts are present
                    if {"title", "description", "link"} <= self.current.keys():
                        self.current["description"] = " ".join(self.current["description"])
                        self.results.append(self.current)
                    self.current = {}

            def handle_data(self, data):
                if self.capture_title:
                    self.current["title"] = data.strip()
                elif self.capture_description:
                    # Snippets may arrive in several chunks; join them at row end.
                    self.current.setdefault("description", [])
                    self.current["description"].append(data.strip())
                elif self.capture_link:
                    self.current["link"] = "https://" + data.strip()

        return SimpleResultParser()

    def search_bing(self, query: str) -> list:
        """Query Bing's RSS feed and return up to max_results items."""
        import xml.etree.ElementTree as ET

        import requests

        response = requests.get(
            "https://www.bing.com/search",
            params={"q": query, "format": "rss"},
        )
        response.raise_for_status()
        root = ET.fromstring(response.text)
        items = root.findall(".//item")
        results = [
            {
                "title": item.findtext("title"),
                "link": item.findtext("link"),
                "description": item.findtext("description"),
            }
            for item in items[: self.max_results]
        ]
        return results
class VisitWebpageTool(Tool):
    """Fetches a URL and returns its content converted to (possibly truncated) markdown."""

    name = "visit_webpage"
    description = (
        "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
    )
    inputs = {
        "url": {
            "type": "string",
            "description": "The url of the webpage to visit.",
        }
    }
    output_type = "string"

    def __init__(self, max_output_length: int = 40000):
        """Store the character cap applied to the returned markdown."""
        super().__init__()
        self.max_output_length = max_output_length

    def _truncate_content(self, content: str, max_length: int) -> str:
        """Keep the head and tail of over-long content with a truncation notice between."""
        if len(content) <= max_length:
            return content
        head = content[: max_length // 2]
        tail = content[-max_length // 2 :]
        notice = f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
        return head + notice + tail

    def forward(self, url: str) -> str:
        """GET the url, convert the HTML body to markdown, and return it truncated."""
        try:
            import re

            import requests
            from markdownify import markdownify
            from requests.exceptions import RequestException
        except ImportError as e:
            raise ImportError(
                "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
            ) from e
        try:
            # 20-second timeout so a hanging server cannot stall the agent forever.
            response = requests.get(url, timeout=20)
            response.raise_for_status()
            markdown_text = markdownify(response.text).strip()
            # Collapse runs of three or more newlines into a single blank line.
            markdown_text = re.sub(r"\n{3,}", "\n\n", markdown_text)
            return self._truncate_content(markdown_text, self.max_output_length)
        except requests.exceptions.Timeout:
            return "The request timed out. Please try again later or check the URL."
        except RequestException as e:
            return f"Error fetching the webpage: {str(e)}"
        except Exception as e:
            return f"An unexpected error occurred: {str(e)}"
class WikipediaSearchTool(Tool):
    """
    Searches Wikipedia and returns a summary or the full text of a topic, plus the page URL.

    Attributes:
        user_agent (str): A custom user-agent string identifying the project, required by
            Wikipedia API policy: http://github.com/martin-majlis/Wikipedia-API/blob/master/README.rst
        language (str): Language of the retrieved articles, see
            http://meta.wikimedia.org/wiki/List_of_Wikipedias
        content_type (str): "summary" for a short summary or "text" for the full article.
        extract_format (str): Output format, either `"WIKI"` or `"HTML"`.

    Example:
        >>> from smolagents import CodeAgent, InferenceClientModel, WikipediaSearchTool
        >>> agent = CodeAgent(
        >>>     tools=[
        >>>         WikipediaSearchTool(
        >>>             user_agent="MyResearchBot ([email protected])",
        >>>             language="en",
        >>>             content_type="summary",  # or "text"
        >>>             extract_format="WIKI",
        >>>         )
        >>>     ],
        >>>     model=InferenceClientModel(),
        >>> )
        >>> agent.run("Python_(programming_language)")
    """

    name = "wikipedia_search"
    description = "Searches Wikipedia and returns a summary or full text of the given topic, along with the page URL."
    inputs = {
        "query": {
            "type": "string",
            "description": "The topic to search on Wikipedia.",
        }
    }
    output_type = "string"

    def __init__(
        self,
        user_agent: str = "Smolagents ([email protected])",
        language: str = "en",
        content_type: str = "text",
        extract_format: str = "WIKI",
    ):
        """Validate configuration and build the wikipediaapi client."""
        super().__init__()
        try:
            import wikipediaapi
        except ImportError as e:
            raise ImportError(
                "You must install `wikipedia-api` to run this tool: for instance run `pip install wikipedia-api`"
            ) from e
        if not user_agent:
            raise ValueError("User-agent is required. Provide a meaningful identifier for your project.")
        self.user_agent = user_agent
        self.language = language
        self.content_type = content_type
        # Translate the string option into the library's ExtractFormat enum.
        format_choices = {
            "WIKI": wikipediaapi.ExtractFormat.WIKI,
            "HTML": wikipediaapi.ExtractFormat.HTML,
        }
        if extract_format not in format_choices:
            raise ValueError("Invalid extract_format. Choose between 'WIKI' or 'HTML'.")
        self.extract_format = format_choices[extract_format]
        self.wiki = wikipediaapi.Wikipedia(
            user_agent=self.user_agent, language=self.language, extract_format=self.extract_format
        )

    def forward(self, query: str) -> str:
        """Look up the page and return its summary or full text with a link."""
        try:
            page = self.wiki.page(query)
            if not page.exists():
                return f"No Wikipedia page found for '{query}'. Try a different query."
            title = page.title
            url = page.fullurl
            if self.content_type == "summary":
                text = page.summary
            elif self.content_type == "text":
                text = page.text
            else:
                return "⚠️ Invalid `content_type`. Use either 'summary' or 'text'."
            return f"✅ **Wikipedia Page:** {title}\n\n**Content:** {text}\n\n🔗 **Read more:** {url}"
        except Exception as e:
            return f"Error fetching Wikipedia summary: {str(e)}"
class SpeechToTextTool(PipelineTool):
    """Transcribes audio to text using a Whisper checkpoint via the PipelineTool machinery."""

    default_checkpoint = "openai/whisper-large-v3-turbo"
    description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
    name = "transcriber"
    inputs = {
        "audio": {
            "type": "audio",
            "description": "The audio to transcribe. Can be a local path, an url, or a tensor.",
        }
    }
    output_type = "string"

    def __new__(cls, *args, **kwargs):
        # Import transformers lazily at instantiation time so the module can be imported
        # without transformers installed; the classes are attached as class attributes
        # that PipelineTool presumably uses to build the processor/model — TODO confirm.
        from transformers.models.whisper import WhisperForConditionalGeneration, WhisperProcessor

        cls.pre_processor_class = WhisperProcessor
        cls.model_class = WhisperForConditionalGeneration
        return super().__new__(cls)

    def encode(self, audio):
        """Normalize the audio input to raw form and run the Whisper processor on it."""
        from .agent_types import AgentAudio

        audio = AgentAudio(audio).to_raw()
        return self.pre_processor(audio, return_tensors="pt")

    def forward(self, inputs):
        """Generate token ids from the processed input features."""
        return self.model.generate(inputs["input_features"])

    def decode(self, outputs):
        """Decode generated token ids into the transcription string (first batch element)."""
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
# Default tools resolvable by their `name` attribute (e.g. when configuring an agent).
# NOTE(review): only a subset of the tools defined above is registered here; presumably
# intentional, since several search tools share the name "web_search" and would collide.
TOOL_MAPPING = {
    tool_class.name: tool_class
    for tool_class in [
        PythonInterpreterTool,
        DuckDuckGoSearchTool,
        VisitWebpageTool,
    ]
}

# Public API of this module.
__all__ = [
    "ApiWebSearchTool",
    "PythonInterpreterTool",
    "FinalAnswerTool",
    "UserInputTool",
    "WebSearchTool",
    "DuckDuckGoSearchTool",
    "GoogleSearchTool",
    "VisitWebpageTool",
    "WikipediaSearchTool",
    "SpeechToTextTool",
]