|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
import ast |
|
import inspect |
|
import json |
|
import logging |
|
import os |
|
import sys |
|
import tempfile |
|
import textwrap |
|
import types |
|
import warnings |
|
from collections.abc import Callable |
|
from contextlib import contextmanager |
|
from functools import wraps |
|
from pathlib import Path |
|
from typing import TYPE_CHECKING, Any |
|
|
|
from huggingface_hub import ( |
|
CommitOperationAdd, |
|
create_commit, |
|
create_repo, |
|
get_collection, |
|
hf_hub_download, |
|
metadata_update, |
|
) |
|
|
|
from ._function_type_hints_utils import ( |
|
TypeHintParsingException, |
|
_convert_type_hints_to_json_schema, |
|
_get_json_schema_type, |
|
get_imports, |
|
get_json_schema, |
|
) |
|
from .agent_types import handle_agent_input_types, handle_agent_output_types |
|
from .tool_validation import MethodChecker, validate_tool_attributes |
|
from .utils import ( |
|
BASE_BUILTIN_MODULES, |
|
_is_package_available, |
|
get_source, |
|
instance_to_source, |
|
is_valid_name, |
|
) |
|
|
|
|
|
if TYPE_CHECKING: |
|
import mcp |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def validate_after_init(cls): |
|
original_init = cls.__init__ |
|
|
|
@wraps(original_init) |
|
def new_init(self, *args, **kwargs): |
|
original_init(self, *args, **kwargs) |
|
self.validate_arguments() |
|
|
|
cls.__init__ = new_init |
|
return cls |
|
|
|
|
|
AUTHORIZED_TYPES = [ |
|
"string", |
|
"boolean", |
|
"integer", |
|
"number", |
|
"image", |
|
"audio", |
|
"array", |
|
"object", |
|
"any", |
|
"null", |
|
] |
|
|
|
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"} |
|
|
|
|
|
class Tool: |
|
""" |
|
A base class for the functions used by the agent. Subclass this and implement the `forward` method as well as the |
|
following class attributes: |
|
|
|
- **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it |
|
will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and |
|
returns the text contained in the file'. |
|
- **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance |
|
`"text-classifier"` or `"image_generator"`. |
|
- **inputs** (`Dict[str, Dict[str, Union[str, type, bool]]]`) -- The dict of modalities expected for the inputs. |
|
It has one `type`key and a `description`key. |
|
This is used by `launch_gradio_demo` or to make a nice space from your tool, and also can be used in the generated |
|
description for your tool. |
|
- **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo` |
|
or to make a nice space from your tool, and also can be used in the generated description for your tool. |
|
|
|
You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being |
|
usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at |
|
instantiation. |
|
""" |
|
|
|
name: str |
|
description: str |
|
inputs: dict[str, dict[str, str | type | bool]] |
|
output_type: str |
|
|
|
def __init__(self, *args, **kwargs): |
|
self.is_initialized = False |
|
|
|
def __init_subclass__(cls, **kwargs): |
|
super().__init_subclass__(**kwargs) |
|
validate_after_init(cls) |
|
|
|
def validate_arguments(self): |
|
required_attributes = { |
|
"description": str, |
|
"name": str, |
|
"inputs": dict, |
|
"output_type": str, |
|
} |
|
|
|
for attr, expected_type in required_attributes.items(): |
|
attr_value = getattr(self, attr, None) |
|
if attr_value is None: |
|
raise TypeError(f"You must set an attribute {attr}.") |
|
if not isinstance(attr_value, expected_type): |
|
raise TypeError( |
|
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead." |
|
) |
|
|
|
if not is_valid_name(self.name): |
|
raise Exception( |
|
f"Invalid Tool name '{self.name}': must be a valid Python identifier and not a reserved keyword" |
|
) |
|
|
|
for input_name, input_content in self.inputs.items(): |
|
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary." |
|
assert "type" in input_content and "description" in input_content, ( |
|
f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." |
|
) |
|
if input_content["type"] not in AUTHORIZED_TYPES: |
|
raise Exception( |
|
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}." |
|
) |
|
|
|
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES |
|
|
|
|
|
if not ( |
|
hasattr(self, "skip_forward_signature_validation") |
|
and getattr(self, "skip_forward_signature_validation") is True |
|
): |
|
signature = inspect.signature(self.forward) |
|
actual_keys = set(key for key in signature.parameters.keys() if key != "self") |
|
expected_keys = set(self.inputs.keys()) |
|
if actual_keys != expected_keys: |
|
raise Exception( |
|
f"In tool '{self.name}', 'forward' method parameters were {actual_keys}, but expected {expected_keys}. " |
|
f"It should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." |
|
) |
|
|
|
json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[ |
|
"properties" |
|
] |
|
for key, value in self.inputs.items(): |
|
assert key in json_schema, ( |
|
f"Input '{key}' should be present in function signature, found only {json_schema.keys()}" |
|
) |
|
if "nullable" in value: |
|
assert "nullable" in json_schema[key], ( |
|
f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." |
|
) |
|
if key in json_schema and "nullable" in json_schema[key]: |
|
assert "nullable" in value, ( |
|
f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." |
|
) |
|
|
|
def forward(self, *args, **kwargs): |
|
return NotImplementedError("Write this method in your subclass of `Tool`.") |
|
|
|
def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs): |
|
if not self.is_initialized: |
|
self.setup() |
|
|
|
|
|
if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict): |
|
potential_kwargs = args[0] |
|
|
|
|
|
if all(key in self.inputs for key in potential_kwargs): |
|
args = () |
|
kwargs = potential_kwargs |
|
|
|
if sanitize_inputs_outputs: |
|
args, kwargs = handle_agent_input_types(*args, **kwargs) |
|
outputs = self.forward(*args, **kwargs) |
|
if sanitize_inputs_outputs: |
|
outputs = handle_agent_output_types(outputs, self.output_type) |
|
return outputs |
|
|
|
def setup(self): |
|
""" |
|
Overwrite this method here for any operation that is expensive and needs to be executed before you start using |
|
your tool. Such as loading a big model. |
|
""" |
|
self.is_initialized = True |
|
|
|
def to_dict(self) -> dict: |
|
"""Returns a dictionary representing the tool""" |
|
class_name = self.__class__.__name__ |
|
if type(self).__name__ == "SimpleTool": |
|
|
|
source_code = get_source(self.forward).replace("@tool", "") |
|
forward_node = ast.parse(source_code) |
|
|
|
method_checker = MethodChecker(set()) |
|
method_checker.visit(forward_node) |
|
|
|
if len(method_checker.errors) > 0: |
|
errors = [f"- {error}" for error in method_checker.errors] |
|
raise (ValueError(f"SimpleTool validation failed for {self.name}:\n" + "\n".join(errors))) |
|
|
|
forward_source_code = get_source(self.forward) |
|
tool_code = textwrap.dedent( |
|
f""" |
|
from smolagents import Tool |
|
from typing import Any, Optional |
|
|
|
class {class_name}(Tool): |
|
name = "{self.name}" |
|
description = {json.dumps(textwrap.dedent(self.description).strip())} |
|
inputs = {repr(self.inputs)} |
|
output_type = "{self.output_type}" |
|
""" |
|
).strip() |
|
import re |
|
|
|
def add_self_argument(source_code: str) -> str: |
|
"""Add 'self' as first argument to a function definition if not present.""" |
|
pattern = r"def forward\(((?!self)[^)]*)\)" |
|
|
|
def replacement(match): |
|
args = match.group(1).strip() |
|
if args: |
|
return f"def forward(self, {args})" |
|
return "def forward(self)" |
|
|
|
return re.sub(pattern, replacement, source_code) |
|
|
|
forward_source_code = forward_source_code.replace(self.name, "forward") |
|
forward_source_code = add_self_argument(forward_source_code) |
|
forward_source_code = forward_source_code.replace("@tool", "").strip() |
|
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ") |
|
|
|
else: |
|
if type(self).__name__ in [ |
|
"SpaceToolWrapper", |
|
"LangChainToolWrapper", |
|
"GradioToolWrapper", |
|
]: |
|
raise ValueError( |
|
"Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors." |
|
) |
|
|
|
validate_tool_attributes(self.__class__) |
|
|
|
tool_code = "from typing import Any, Optional\n" + instance_to_source(self, base_cls=Tool) |
|
|
|
requirements = {el for el in get_imports(tool_code) if el not in sys.stdlib_module_names} | {"smolagents"} |
|
|
|
return {"name": self.name, "code": tool_code, "requirements": sorted(requirements)} |
|
|
|
@classmethod |
|
def from_dict(cls, tool_dict: dict[str, Any], **kwargs) -> "Tool": |
|
""" |
|
Create tool from a dictionary representation. |
|
|
|
Args: |
|
tool_dict (`dict[str, Any]`): Dictionary representation of the tool. |
|
**kwargs: Additional keyword arguments to pass to the tool's constructor. |
|
|
|
Returns: |
|
`Tool`: Tool object. |
|
""" |
|
if "code" not in tool_dict: |
|
raise ValueError("Tool dictionary must contain 'code' key with the tool source code") |
|
return cls.from_code(tool_dict["code"], **kwargs) |
|
|
|
def save(self, output_dir: str | Path, tool_file_name: str = "tool", make_gradio_app: bool = True): |
|
""" |
|
Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your |
|
tool in `output_dir` as well as autogenerate: |
|
|
|
- a `{tool_file_name}.py` file containing the logic for your tool. |
|
If you pass `make_gradio_app=True`, this will also write: |
|
- an `app.py` file providing a UI for your tool when it is exported to a Space with `tool.push_to_hub()` |
|
- a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its |
|
code) |
|
|
|
Args: |
|
output_dir (`str` or `Path`): The folder in which you want to save your tool. |
|
tool_file_name (`str`, *optional*): The file name in which you want to save your tool. |
|
make_gradio_app (`bool`, *optional*, defaults to True): Whether to also export a `requirements.txt` file and Gradio UI. |
|
""" |
|
|
|
output_path = Path(output_dir) |
|
output_path.mkdir(parents=True, exist_ok=True) |
|
|
|
self._write_file(output_path / f"{tool_file_name}.py", self._get_tool_code()) |
|
if make_gradio_app: |
|
|
|
self._write_file(output_path / "app.py", self._get_gradio_app_code(tool_module_name=tool_file_name)) |
|
|
|
self._write_file(output_path / "requirements.txt", self._get_requirements()) |
|
|
|
def _write_file(self, file_path: Path, content: str) -> None: |
|
"""Writes content to a file with UTF-8 encoding.""" |
|
file_path.write_text(content, encoding="utf-8") |
|
|
|
def push_to_hub( |
|
self, |
|
repo_id: str, |
|
commit_message: str = "Upload tool", |
|
private: bool | None = None, |
|
token: bool | str | None = None, |
|
create_pr: bool = False, |
|
) -> str: |
|
""" |
|
Upload the tool to the Hub. |
|
|
|
Parameters: |
|
repo_id (`str`): |
|
The name of the repository you want to push your tool to. It should contain your organization name when |
|
pushing to a given organization. |
|
commit_message (`str`, *optional*, defaults to `"Upload tool"`): |
|
Message to commit while pushing. |
|
private (`bool`, *optional*): |
|
Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. |
|
token (`bool` or `str`, *optional*): |
|
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated |
|
when running `huggingface-cli login` (stored in `~/.huggingface`). |
|
create_pr (`bool`, *optional*, defaults to `False`): |
|
Whether to create a PR with the uploaded files or directly commit. |
|
""" |
|
|
|
repo_id = self._initialize_hub_repo(repo_id, token, private) |
|
|
|
additions = self._prepare_hub_files() |
|
|
|
return create_commit( |
|
repo_id=repo_id, |
|
operations=additions, |
|
commit_message=commit_message, |
|
token=token, |
|
create_pr=create_pr, |
|
repo_type="space", |
|
) |
|
|
|
@staticmethod |
|
def _initialize_hub_repo(repo_id: str, token: bool | str | None, private: bool | None) -> str: |
|
"""Initialize repository on Hugging Face Hub.""" |
|
repo_url = create_repo( |
|
repo_id=repo_id, |
|
token=token, |
|
private=private, |
|
exist_ok=True, |
|
repo_type="space", |
|
space_sdk="gradio", |
|
) |
|
metadata_update(repo_url.repo_id, {"tags": ["smolagents", "tool"]}, repo_type="space", token=token) |
|
return repo_url.repo_id |
|
|
|
def _prepare_hub_files(self) -> list: |
|
"""Prepare files for Hub commit.""" |
|
additions = [ |
|
|
|
CommitOperationAdd( |
|
path_in_repo="tool.py", |
|
path_or_fileobj=self._get_tool_code().encode(), |
|
), |
|
|
|
CommitOperationAdd( |
|
path_in_repo="app.py", |
|
path_or_fileobj=self._get_gradio_app_code().encode(), |
|
), |
|
|
|
CommitOperationAdd( |
|
path_in_repo="requirements.txt", |
|
path_or_fileobj=self._get_requirements().encode(), |
|
), |
|
] |
|
return additions |
|
|
|
def _get_tool_code(self) -> str: |
|
"""Get the tool's code.""" |
|
return self.to_dict()["code"] |
|
|
|
def _get_gradio_app_code(self, tool_module_name: str = "tool") -> str: |
|
"""Get the Gradio app code.""" |
|
class_name = self.__class__.__name__ |
|
return textwrap.dedent( |
|
f"""\ |
|
from smolagents import launch_gradio_demo |
|
from {tool_module_name} import {class_name} |
|
|
|
tool = {class_name}() |
|
launch_gradio_demo(tool) |
|
""" |
|
) |
|
|
|
def _get_requirements(self) -> str: |
|
"""Get the requirements.""" |
|
return "\n".join(self.to_dict()["requirements"]) |
|
|
|
@classmethod |
|
def from_hub( |
|
cls, |
|
repo_id: str, |
|
token: str | None = None, |
|
trust_remote_code: bool = False, |
|
**kwargs, |
|
): |
|
""" |
|
Loads a tool defined on the Hub. |
|
|
|
<Tip warning={true}> |
|
|
|
Loading a tool from the Hub means that you'll download the tool and execute it locally. |
|
ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when |
|
installing a package using pip/npm/apt. |
|
|
|
</Tip> |
|
|
|
Args: |
|
repo_id (`str`): |
|
The name of the Space repo on the Hub where your tool is defined. |
|
token (`str`, *optional*): |
|
The token to identify you on hf.co. If unset, will use the token generated when running |
|
`huggingface-cli login` (stored in `~/.huggingface`). |
|
trust_remote_code(`str`, *optional*, defaults to False): |
|
This flags marks that you understand the risk of running remote code and that you trust this tool. |
|
If not setting this to True, loading the tool from Hub will fail. |
|
kwargs (additional keyword arguments, *optional*): |
|
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as |
|
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the |
|
others will be passed along to its init. |
|
""" |
|
if not trust_remote_code: |
|
raise ValueError( |
|
"Loading a tool from Hub requires to acknowledge you trust its code: to do so, pass `trust_remote_code=True`." |
|
) |
|
|
|
|
|
tool_file = hf_hub_download( |
|
repo_id, |
|
"tool.py", |
|
token=token, |
|
repo_type="space", |
|
cache_dir=kwargs.get("cache_dir"), |
|
force_download=kwargs.get("force_download"), |
|
proxies=kwargs.get("proxies"), |
|
revision=kwargs.get("revision"), |
|
subfolder=kwargs.get("subfolder"), |
|
local_files_only=kwargs.get("local_files_only"), |
|
) |
|
|
|
tool_code = Path(tool_file).read_text() |
|
return Tool.from_code(tool_code, **kwargs) |
|
|
|
@classmethod |
|
def from_code(cls, tool_code: str, **kwargs): |
|
module = types.ModuleType("dynamic_tool") |
|
|
|
exec(tool_code, module.__dict__) |
|
|
|
|
|
tool_class = next( |
|
( |
|
obj |
|
for _, obj in inspect.getmembers(module, inspect.isclass) |
|
if issubclass(obj, Tool) and obj is not Tool |
|
), |
|
None, |
|
) |
|
|
|
if tool_class is None: |
|
raise ValueError("No Tool subclass found in the code.") |
|
|
|
if not isinstance(tool_class.inputs, dict): |
|
tool_class.inputs = ast.literal_eval(tool_class.inputs) |
|
|
|
return tool_class(**kwargs) |
|
|
|
@staticmethod |
|
def from_space( |
|
space_id: str, |
|
name: str, |
|
description: str, |
|
api_name: str | None = None, |
|
token: str | None = None, |
|
): |
|
""" |
|
Creates a [`Tool`] from a Space given its id on the Hub. |
|
|
|
Args: |
|
space_id (`str`): |
|
The id of the Space on the Hub. |
|
name (`str`): |
|
The name of the tool. |
|
description (`str`): |
|
The description of the tool. |
|
api_name (`str`, *optional*): |
|
The specific api_name to use, if the space has several tabs. If not precised, will default to the first available api. |
|
token (`str`, *optional*): |
|
Add your token to access private spaces or increase your GPU quotas. |
|
Returns: |
|
[`Tool`]: |
|
The Space, as a tool. |
|
|
|
Examples: |
|
```py |
|
>>> image_generator = Tool.from_space( |
|
... space_id="black-forest-labs/FLUX.1-schnell", |
|
... name="image-generator", |
|
... description="Generate an image from a prompt" |
|
... ) |
|
>>> image = image_generator("Generate an image of a cool surfer in Tahiti") |
|
``` |
|
```py |
|
>>> face_swapper = Tool.from_space( |
|
... "tuan2308/face-swap", |
|
... "face_swapper", |
|
... "Tool that puts the face shown on the first image on the second image. You can give it paths to images.", |
|
... ) |
|
>>> image = face_swapper('./aymeric.jpeg', './ruth.jpg') |
|
``` |
|
""" |
|
from gradio_client import Client, handle_file |
|
|
|
class SpaceToolWrapper(Tool): |
|
skip_forward_signature_validation = True |
|
|
|
def __init__( |
|
self, |
|
space_id: str, |
|
name: str, |
|
description: str, |
|
api_name: str | None = None, |
|
token: str | None = None, |
|
): |
|
self.name = name |
|
self.description = description |
|
self.client = Client(space_id, hf_token=token) |
|
space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"] |
|
|
|
|
|
if api_name is None: |
|
api_name = list(space_description.keys())[0] |
|
logger.warning( |
|
f"Since `api_name` was not defined, it was automatically set to the first available API: `{api_name}`." |
|
) |
|
self.api_name = api_name |
|
|
|
try: |
|
space_description_api = space_description[api_name] |
|
except KeyError: |
|
raise KeyError(f"Could not find specified {api_name=} among available api names.") |
|
|
|
self.inputs = {} |
|
for parameter in space_description_api["parameters"]: |
|
if not parameter["parameter_has_default"]: |
|
parameter_type = parameter["type"]["type"] |
|
if parameter_type == "object": |
|
parameter_type = "any" |
|
self.inputs[parameter["parameter_name"]] = { |
|
"type": parameter_type, |
|
"description": parameter["python_type"]["description"], |
|
} |
|
output_component = space_description_api["returns"][0]["component"] |
|
if output_component == "Image": |
|
self.output_type = "image" |
|
elif output_component == "Audio": |
|
self.output_type = "audio" |
|
else: |
|
self.output_type = "any" |
|
self.is_initialized = True |
|
|
|
def sanitize_argument_for_prediction(self, arg): |
|
from gradio_client.utils import is_http_url_like |
|
from PIL.Image import Image |
|
|
|
if isinstance(arg, Image): |
|
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) |
|
arg.save(temp_file.name) |
|
arg = temp_file.name |
|
if ( |
|
(isinstance(arg, str) and os.path.isfile(arg)) |
|
or (isinstance(arg, Path) and arg.exists() and arg.is_file()) |
|
or is_http_url_like(arg) |
|
): |
|
arg = handle_file(arg) |
|
return arg |
|
|
|
def forward(self, *args, **kwargs): |
|
|
|
args = list(args) |
|
for i, arg in enumerate(args): |
|
args[i] = self.sanitize_argument_for_prediction(arg) |
|
for arg_name, arg in kwargs.items(): |
|
kwargs[arg_name] = self.sanitize_argument_for_prediction(arg) |
|
|
|
output = self.client.predict(*args, api_name=self.api_name, **kwargs) |
|
if isinstance(output, tuple) or isinstance(output, list): |
|
return output[ |
|
0 |
|
] |
|
return output |
|
|
|
return SpaceToolWrapper( |
|
space_id=space_id, |
|
name=name, |
|
description=description, |
|
api_name=api_name, |
|
token=token, |
|
) |
|
|
|
@staticmethod |
|
def from_gradio(gradio_tool): |
|
""" |
|
Creates a [`Tool`] from a gradio tool. |
|
""" |
|
import inspect |
|
|
|
class GradioToolWrapper(Tool): |
|
def __init__(self, _gradio_tool): |
|
self.name = _gradio_tool.name |
|
self.description = _gradio_tool.description |
|
self.output_type = "string" |
|
self._gradio_tool = _gradio_tool |
|
func_args = list(inspect.signature(_gradio_tool.run).parameters.items()) |
|
self.inputs = { |
|
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args |
|
} |
|
self.forward = self._gradio_tool.run |
|
|
|
return GradioToolWrapper(gradio_tool) |
|
|
|
@staticmethod |
|
def from_langchain(langchain_tool): |
|
""" |
|
Creates a [`Tool`] from a langchain tool. |
|
""" |
|
|
|
class LangChainToolWrapper(Tool): |
|
skip_forward_signature_validation = True |
|
|
|
def __init__(self, _langchain_tool): |
|
self.name = _langchain_tool.name.lower() |
|
self.description = _langchain_tool.description |
|
self.inputs = _langchain_tool.args.copy() |
|
for input_content in self.inputs.values(): |
|
if "title" in input_content: |
|
input_content.pop("title") |
|
input_content["description"] = "" |
|
self.output_type = "string" |
|
self.langchain_tool = _langchain_tool |
|
self.is_initialized = True |
|
|
|
def forward(self, *args, **kwargs): |
|
tool_input = kwargs.copy() |
|
for index, argument in enumerate(args): |
|
if index < len(self.inputs): |
|
input_key = next(iter(self.inputs)) |
|
tool_input[input_key] = argument |
|
return self.langchain_tool.run(tool_input) |
|
|
|
return LangChainToolWrapper(langchain_tool) |
|
|
|
|
|
def launch_gradio_demo(tool: Tool): |
|
""" |
|
Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes |
|
`inputs` and `output_type`. |
|
|
|
Args: |
|
tool (`Tool`): The tool for which to launch the demo. |
|
""" |
|
try: |
|
import gradio as gr |
|
except ImportError: |
|
raise ImportError("Gradio should be installed in order to launch a gradio demo.") |
|
|
|
TYPE_TO_COMPONENT_CLASS_MAPPING = { |
|
"boolean": gr.Checkbox, |
|
"image": gr.Image, |
|
"audio": gr.Audio, |
|
"string": gr.Textbox, |
|
"integer": gr.Textbox, |
|
"number": gr.Textbox, |
|
} |
|
|
|
def tool_forward(*args, **kwargs): |
|
return tool(*args, sanitize_inputs_outputs=True, **kwargs) |
|
|
|
tool_forward.__signature__ = inspect.signature(tool.forward) |
|
|
|
gradio_inputs = [] |
|
for input_name, input_details in tool.inputs.items(): |
|
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]] |
|
new_component = input_gradio_component_class(label=input_name) |
|
gradio_inputs.append(new_component) |
|
|
|
output_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[tool.output_type] |
|
gradio_output = output_gradio_component_class(label="Output") |
|
|
|
gr.Interface( |
|
fn=tool_forward, |
|
inputs=gradio_inputs, |
|
outputs=gradio_output, |
|
title=tool.name, |
|
description=tool.description, |
|
api_name=tool.name, |
|
).launch() |
|
|
|
|
|
def load_tool( |
|
repo_id, |
|
model_repo_id: str | None = None, |
|
token: str | None = None, |
|
trust_remote_code: bool = False, |
|
**kwargs, |
|
): |
|
""" |
|
Main function to quickly load a tool from the Hub. |
|
|
|
<Tip warning={true}> |
|
|
|
Loading a tool means that you'll download the tool and execute it locally. |
|
ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when |
|
installing a package using pip/npm/apt. |
|
|
|
</Tip> |
|
|
|
Args: |
|
repo_id (`str`): |
|
Space repo ID of a tool on the Hub. |
|
model_repo_id (`str`, *optional*): |
|
Use this argument to use a different model than the default one for the tool you selected. |
|
token (`str`, *optional*): |
|
The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli |
|
login` (stored in `~/.huggingface`). |
|
trust_remote_code (`bool`, *optional*, defaults to False): |
|
This needs to be accepted in order to load a tool from Hub. |
|
kwargs (additional keyword arguments, *optional*): |
|
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as |
|
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others |
|
will be passed along to its init. |
|
""" |
|
return Tool.from_hub( |
|
repo_id, |
|
model_repo_id=model_repo_id, |
|
token=token, |
|
trust_remote_code=trust_remote_code, |
|
**kwargs, |
|
) |
|
|
|
|
|
def add_description(description): |
|
""" |
|
A decorator that adds a description to a function. |
|
""" |
|
|
|
def inner(func): |
|
func.description = description |
|
func.name = func.__name__ |
|
return func |
|
|
|
return inner |
|
|
|
|
|
class ToolCollection: |
|
""" |
|
Tool collections enable loading a collection of tools in the agent's toolbox. |
|
|
|
Collections can be loaded from a collection in the Hub or from an MCP server, see: |
|
- [`ToolCollection.from_hub`] |
|
- [`ToolCollection.from_mcp`] |
|
|
|
For example and usage, see: [`ToolCollection.from_hub`] and [`ToolCollection.from_mcp`] |
|
""" |
|
|
|
def __init__(self, tools: list[Tool]): |
|
self.tools = tools |
|
|
|
@classmethod |
|
def from_hub( |
|
cls, |
|
collection_slug: str, |
|
token: str | None = None, |
|
trust_remote_code: bool = False, |
|
) -> "ToolCollection": |
|
"""Loads a tool collection from the Hub. |
|
|
|
it adds a collection of tools from all Spaces in the collection to the agent's toolbox |
|
|
|
> [!NOTE] |
|
> Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd |
|
> like for this collection to showcase them. |
|
|
|
Args: |
|
collection_slug (str): The collection slug referencing the collection. |
|
token (str, *optional*): The authentication token if the collection is private. |
|
trust_remote_code (bool, *optional*, defaults to False): Whether to trust the remote code. |
|
|
|
Returns: |
|
ToolCollection: A tool collection instance loaded with the tools. |
|
|
|
Example: |
|
```py |
|
>>> from smolagents import ToolCollection, CodeAgent |
|
|
|
>>> image_tool_collection = ToolCollection.from_hub("huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f") |
|
>>> agent = CodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True) |
|
|
|
>>> agent.run("Please draw me a picture of rivers and lakes.") |
|
``` |
|
""" |
|
_collection = get_collection(collection_slug, token=token) |
|
_hub_repo_ids = {item.item_id for item in _collection.items if item.item_type == "space"} |
|
|
|
tools = {Tool.from_hub(repo_id, token, trust_remote_code) for repo_id in _hub_repo_ids} |
|
|
|
return cls(tools) |
|
|
|
@classmethod |
|
@contextmanager |
|
def from_mcp( |
|
cls, server_parameters: "mcp.StdioServerParameters" | dict, trust_remote_code: bool = False |
|
) -> "ToolCollection": |
|
"""Automatically load a tool collection from an MCP server. |
|
|
|
This method supports Stdio, Streamable HTTP, and legacy HTTP+SSE MCP servers. Look at the `server_parameters` |
|
argument for more details on how to connect to each MCP server. |
|
|
|
Note: a separate thread will be spawned to run an asyncio event loop handling |
|
the MCP server. |
|
|
|
Args: |
|
server_parameters (`mcp.StdioServerParameters` or `dict`): |
|
Configuration parameters to connect to the MCP server. This can be: |
|
|
|
- An instance of `mcp.StdioServerParameters` for connecting a Stdio MCP server via standard input/output using a subprocess. |
|
|
|
- A `dict` with at least: |
|
- "url": URL of the server. |
|
- "transport": Transport protocol to use, one of: |
|
- "streamable-http": (recommended) Streamable HTTP transport. |
|
- "sse": Legacy HTTP+SSE transport (deprecated). |
|
If "transport" is omitted, the legacy "sse" transport is assumed (a deprecation warning will be issued). |
|
|
|
<Deprecated version="1.17.0"> |
|
The HTTP+SSE transport is deprecated and future behavior will default to the Streamable HTTP transport. |
|
Please pass explicitly the "transport" key. |
|
</Deprecated> |
|
trust_remote_code (`bool`, *optional*, defaults to `False`): |
|
Whether to trust the execution of code from tools defined on the MCP server. |
|
This option should only be set to `True` if you trust the MCP server, |
|
and undertand the risks associated with running remote code on your local machine. |
|
If set to `False`, loading tools from MCP will fail. |
|
|
|
|
|
Returns: |
|
ToolCollection: A tool collection instance. |
|
|
|
Example with a Stdio MCP server: |
|
```py |
|
>>> import os |
|
>>> from smolagents import ToolCollection, CodeAgent, InferenceClientModel |
|
>>> from mcp import StdioServerParameters |
|
|
|
>>> model = InferenceClientModel() |
|
|
|
>>> server_parameters = StdioServerParameters( |
|
>>> command="uvx", |
|
>>> args=["--quiet", "[email protected]"], |
|
>>> env={"UV_PYTHON": "3.12", **os.environ}, |
|
>>> ) |
|
|
|
>>> with ToolCollection.from_mcp(server_parameters, trust_remote_code=True) as tool_collection: |
|
>>> agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True, model=model) |
|
>>> agent.run("Please find a remedy for hangover.") |
|
``` |
|
|
|
Example with a Streamable HTTP MCP server: |
|
```py |
|
>>> with ToolCollection.from_mcp({"url": "http://127.0.0.1:8000/mcp", "transport": "streamable-http"}, trust_remote_code=True) as tool_collection: |
|
>>> agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True, model=model) |
|
>>> agent.run("Please find a remedy for hangover.") |
|
``` |
|
""" |
|
try: |
|
from mcpadapt.core import MCPAdapt |
|
from mcpadapt.smolagents_adapter import SmolAgentsAdapter |
|
except ImportError: |
|
raise ImportError( |
|
"""Please install 'mcp' extra to use ToolCollection.from_mcp: `pip install "smolagents[mcp]"`.""" |
|
) |
|
if isinstance(server_parameters, dict): |
|
transport = server_parameters.get("transport") |
|
if transport is None: |
|
warnings.warn( |
|
"Passing a dict as server_parameters without specifying the 'transport' key is deprecated. " |
|
"For now, it defaults to the legacy 'sse' (HTTP+SSE) transport, but this default will change " |
|
"to 'streamable-http' in version 1.20. Please add the 'transport' key explicitly. ", |
|
FutureWarning, |
|
) |
|
transport = "sse" |
|
server_parameters["transport"] = transport |
|
if transport not in {"sse", "streamable-http"}: |
|
raise ValueError( |
|
f"Unsupported transport: {transport}. Supported transports are 'streamable-http' and 'sse'." |
|
) |
|
if not trust_remote_code: |
|
raise ValueError( |
|
"Loading tools from MCP requires you to acknowledge you trust the MCP server, " |
|
"as it will execute code on your local machine: pass `trust_remote_code=True`." |
|
) |
|
with MCPAdapt(server_parameters, SmolAgentsAdapter()) as tools: |
|
yield cls(tools) |
|
|
|
|
|
def tool(tool_function: Callable) -> Tool: |
|
""" |
|
Convert a function into an instance of a dynamically created Tool subclass. |
|
|
|
Args: |
|
tool_function (`Callable`): Function to convert into a Tool subclass. |
|
Should have type hints for each input and a type hint for the output. |
|
Should also have a docstring including the description of the function |
|
and an 'Args:' part where each argument is described. |
|
""" |
|
tool_json_schema = get_json_schema(tool_function)["function"] |
|
if "return" not in tool_json_schema: |
|
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") |
|
|
|
class SimpleTool(Tool): |
|
def __init__(self): |
|
self.is_initialized = True |
|
|
|
|
|
SimpleTool.name = tool_json_schema["name"] |
|
SimpleTool.description = tool_json_schema["description"] |
|
SimpleTool.inputs = tool_json_schema["parameters"]["properties"] |
|
SimpleTool.output_type = tool_json_schema["return"]["type"] |
|
|
|
@wraps(tool_function) |
|
def wrapped_function(*args, **kwargs): |
|
return tool_function(*args, **kwargs) |
|
|
|
|
|
SimpleTool.forward = staticmethod(wrapped_function) |
|
|
|
|
|
sig = inspect.signature(tool_function) |
|
|
|
new_sig = sig.replace( |
|
parameters=[inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(sig.parameters.values()) |
|
) |
|
|
|
SimpleTool.forward.__signature__ = new_sig |
|
|
|
|
|
|
|
tool_source = inspect.getsource(tool_function) |
|
|
|
tool_source_body = "\n".join(tool_source.split("\n")[2:]) |
|
|
|
tool_source_body = textwrap.dedent(tool_source_body) |
|
|
|
forward_method_source = f"def forward{str(new_sig)}:\n{textwrap.indent(tool_source_body, ' ')}" |
|
|
|
class_source = ( |
|
textwrap.dedent(f""" |
|
class SimpleTool(Tool): |
|
name: str = "{tool_json_schema["name"]}" |
|
description: str = {json.dumps(textwrap.dedent(tool_json_schema["description"]).strip())} |
|
inputs: dict[str, dict[str, str]] = {tool_json_schema["parameters"]["properties"]} |
|
output_type: str = "{tool_json_schema["return"]["type"]}" |
|
|
|
def __init__(self): |
|
self.is_initialized = True |
|
|
|
""") |
|
+ textwrap.indent(forward_method_source, " ") |
|
) |
|
|
|
SimpleTool.__source__ = class_source |
|
SimpleTool.forward.__source__ = forward_method_source |
|
|
|
simple_tool = SimpleTool() |
|
return simple_tool |
|
|
|
|
|
class PipelineTool(Tool): |
|
""" |
|
A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will |
|
need to specify: |
|
|
|
- **model_class** (`type`) -- The class to use to load the model in this tool. |
|
- **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one. |
|
- **pre_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the |
|
pre-processor |
|
- **post_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the |
|
post-processor (when different from the pre-processor). |
|
|
|
Args: |
|
model (`str` or [`transformers.PreTrainedModel`], *optional*): |
|
The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the |
|
value of the class attribute `default_checkpoint`. |
|
pre_processor (`str` or `Any`, *optional*): |
|
The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a |
|
tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if |
|
unset. |
|
post_processor (`str` or `Any`, *optional*): |
|
The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a |
|
tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if |
|
unset. |
|
device (`int`, `str` or `torch.device`, *optional*): |
|
The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the |
|
CPU otherwise. |
|
device_map (`str` or `dict`, *optional*): |
|
If passed along, will be used to instantiate the model. |
|
model_kwargs (`dict`, *optional*): |
|
Any keyword argument to send to the model instantiation. |
|
token (`str`, *optional*): |
|
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when |
|
running `huggingface-cli login` (stored in `~/.huggingface`). |
|
hub_kwargs (additional keyword arguments, *optional*): |
|
Any additional keyword argument to send to the methods that will load the data from the Hub. |
|
""" |
|
|
|
pre_processor_class = None |
|
model_class = None |
|
post_processor_class = None |
|
default_checkpoint = None |
|
description = "This is a pipeline tool" |
|
name = "pipeline" |
|
inputs = {"prompt": str} |
|
output_type = str |
|
skip_forward_signature_validation = True |
|
|
|
def __init__( |
|
self, |
|
model=None, |
|
pre_processor=None, |
|
post_processor=None, |
|
device=None, |
|
device_map=None, |
|
model_kwargs=None, |
|
token=None, |
|
**hub_kwargs, |
|
): |
|
if not _is_package_available("accelerate") or not _is_package_available("torch"): |
|
raise ModuleNotFoundError( |
|
"Please install 'transformers' extra to use a PipelineTool: `pip install 'smolagents[transformers]'`" |
|
) |
|
|
|
if model is None: |
|
if self.default_checkpoint is None: |
|
raise ValueError("This tool does not implement a default checkpoint, you need to pass one.") |
|
model = self.default_checkpoint |
|
if pre_processor is None: |
|
pre_processor = model |
|
|
|
self.model = model |
|
self.pre_processor = pre_processor |
|
self.post_processor = post_processor |
|
self.device = device |
|
self.device_map = device_map |
|
self.model_kwargs = {} if model_kwargs is None else model_kwargs |
|
if device_map is not None: |
|
self.model_kwargs["device_map"] = device_map |
|
self.hub_kwargs = hub_kwargs |
|
self.hub_kwargs["token"] = token |
|
|
|
super().__init__() |
|
|
|
def setup(self): |
|
""" |
|
Instantiates the `pre_processor`, `model` and `post_processor` if necessary. |
|
""" |
|
if isinstance(self.pre_processor, str): |
|
if self.pre_processor_class is None: |
|
from transformers import AutoProcessor |
|
|
|
self.pre_processor_class = AutoProcessor |
|
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs) |
|
|
|
if isinstance(self.model, str): |
|
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs) |
|
|
|
if self.post_processor is None: |
|
self.post_processor = self.pre_processor |
|
elif isinstance(self.post_processor, str): |
|
if self.post_processor_class is None: |
|
from transformers import AutoProcessor |
|
|
|
self.post_processor_class = AutoProcessor |
|
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs) |
|
|
|
if self.device is None: |
|
if self.device_map is not None: |
|
self.device = list(self.model.hf_device_map.values())[0] |
|
else: |
|
from accelerate import PartialState |
|
|
|
self.device = PartialState().default_device |
|
|
|
if self.device_map is None: |
|
self.model.to(self.device) |
|
|
|
super().setup() |
|
|
|
def encode(self, raw_inputs): |
|
""" |
|
Uses the `pre_processor` to prepare the inputs for the `model`. |
|
""" |
|
return self.pre_processor(raw_inputs) |
|
|
|
def forward(self, inputs): |
|
""" |
|
Sends the inputs through the `model`. |
|
""" |
|
import torch |
|
|
|
with torch.no_grad(): |
|
return self.model(**inputs) |
|
|
|
def decode(self, outputs): |
|
""" |
|
Uses the `post_processor` to decode the model output. |
|
""" |
|
return self.post_processor(outputs) |
|
|
|
def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs): |
|
import torch |
|
from accelerate.utils import send_to_device |
|
|
|
if not self.is_initialized: |
|
self.setup() |
|
|
|
if sanitize_inputs_outputs: |
|
args, kwargs = handle_agent_input_types(*args, **kwargs) |
|
encoded_inputs = self.encode(*args, **kwargs) |
|
|
|
tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)} |
|
non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)} |
|
|
|
encoded_inputs = send_to_device(tensor_inputs, self.device) |
|
outputs = self.forward({**encoded_inputs, **non_tensor_inputs}) |
|
outputs = send_to_device(outputs, "cpu") |
|
decoded_outputs = self.decode(outputs) |
|
if sanitize_inputs_outputs: |
|
decoded_outputs = handle_agent_output_types(decoded_outputs, self.output_type) |
|
return decoded_outputs |
|
|
|
|
|
def get_tools_definition_code(tools: dict[str, Tool]) -> str: |
|
tool_codes = [] |
|
for tool in tools.values(): |
|
validate_tool_attributes(tool.__class__, check_imports=False) |
|
tool_code = instance_to_source(tool, base_cls=Tool) |
|
tool_code = tool_code.replace("from smolagents.tools import Tool", "") |
|
tool_code += f"\n\n{tool.name} = {tool.__class__.__name__}()\n" |
|
tool_codes.append(tool_code) |
|
|
|
tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES]) |
|
tool_definition_code += textwrap.dedent( |
|
""" |
|
from typing import Any |
|
|
|
class Tool: |
|
def __call__(self, *args, **kwargs): |
|
return self.forward(*args, **kwargs) |
|
|
|
def forward(self, *args, **kwargs): |
|
pass # to be implemented in child class |
|
""" |
|
) |
|
tool_definition_code += "\n\n".join(tool_codes) |
|
return tool_definition_code |
|
|
|
|
|
def validate_tool_arguments(tool: Tool, arguments: Any) -> str | None: |
|
if isinstance(arguments, dict): |
|
for key, value in arguments.items(): |
|
if key not in tool.inputs: |
|
return f"Argument {key} is not in the tool's input schema." |
|
|
|
parsed_type = _get_json_schema_type(type(value))["type"] |
|
|
|
if parsed_type != tool.inputs[key]["type"] and not tool.inputs[key]["type"] == "any": |
|
return f"Argument {key} has type '{parsed_type}' but should be '{tool.inputs[key]['type']}'." |
|
for key in tool.inputs: |
|
if key not in arguments: |
|
return f"Argument {key} is required." |
|
return None |
|
else: |
|
expected_type = list(tool.inputs.values())[0]["type"] |
|
if _get_json_schema_type(type(arguments))["type"] != expected_type and not expected_type == "any": |
|
return f"Argument has type '{type(arguments).__name__}' but should be '{expected_type}'." |
|
return None |
|
|
|
|
|
__all__ = [ |
|
"AUTHORIZED_TYPES", |
|
"Tool", |
|
"tool", |
|
"load_tool", |
|
"launch_gradio_demo", |
|
"ToolCollection", |
|
] |
|
|