Upload 17 files
- src/smolagents/__init__.py +31 -0
- src/smolagents/_function_type_hints_utils.py +431 -0
- src/smolagents/agent_types.py +283 -0
- src/smolagents/agents.py +1725 -0
- src/smolagents/cli.py +164 -0
- src/smolagents/default_tools.py +577 -0
- src/smolagents/gradio_ui.py +508 -0
- src/smolagents/local_python_executor.py +1611 -0
- src/smolagents/mcp_client.py +154 -0
- src/smolagents/memory.py +257 -0
- src/smolagents/models.py +1882 -0
- src/smolagents/monitoring.py +265 -0
- src/smolagents/remote_executors.py +451 -0
- src/smolagents/tool_validation.py +266 -0
- src/smolagents/tools.py +1239 -0
- src/smolagents/utils.py +500 -0
- src/smolagents/vision_web_browser.py +228 -0
src/smolagents/__init__.py
ADDED
@@ -0,0 +1,31 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "1.20.0.dev0"

from .agent_types import *  # noqa: I001
from .agents import *  # Above noqa avoids a circular dependency due to cli.py
from .default_tools import *
from .gradio_ui import *
from .local_python_executor import *
from .mcp_client import *
from .memory import *
from .models import *
from .monitoring import *
from .remote_executors import *
from .tools import *
from .utils import *
from .cli import *
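Everything re-exported by these star imports is available at the package root; for example (a minimal check, assuming the package is installed):

    import smolagents

    print(smolagents.__version__)  # "1.20.0.dev0"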
src/smolagents/_function_type_hints_utils.py
ADDED
@@ -0,0 +1,431 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains utilities exclusively taken from the `transformers` repository.

Since they are not specific to `transformers` and `transformers` is a heavy dependency, those helpers have
been duplicated.

TODO: move them to `huggingface_hub` to avoid code duplication.
"""

import inspect
import json
import re
import types
from collections.abc import Callable
from copy import copy
from typing import (
    Any,
    Literal,
    Union,
    get_args,
    get_origin,
    get_type_hints,
)


IMPORT_TO_PACKAGE_MAPPING = {
    "wikipediaapi": "wikipedia-api",
}


def get_package_name(import_name: str) -> str:
    """
    Return the package name for a given import name.

    Args:
        import_name (`str`): Import name to get the package name for.

    Returns:
        `str`: Package name for the given import name.
    """
    return IMPORT_TO_PACKAGE_MAPPING.get(import_name, import_name)


def get_imports(code: str) -> list[str]:
    """
    Extracts all the libraries (not relative imports) that are imported in a piece of code.

    Args:
        code (`str`): Code text to inspect.

    Returns:
        `list[str]`: List of all packages required to use the input code.
    """
    # Filter out try/except blocks so custom code can have try/except imports
    code = re.sub(r"\s*try\s*:.*?except.*?:", "", code, flags=re.DOTALL)

    # Filter out imports under is_flash_attn_2_available blocks to avoid import issues in CPU-only environments
    code = re.sub(
        r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+",
        "",
        code,
        flags=re.MULTILINE,
    )

    # Imports of the form `import xxx` or `import xxx as yyy`
    imports = re.findall(r"^\s*import\s+(\S+?)(?:\s+as\s+\S+)?\s*$", code, flags=re.MULTILINE)
    # Imports of the form `from xxx import yyy`
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", code, flags=re.MULTILINE)
    # Only keep the top-level module
    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
    return [get_package_name(import_name) for import_name in set(imports)]


class TypeHintParsingException(Exception):
    """Exception raised for errors in parsing type hints to generate JSON schemas"""


class DocstringParsingException(Exception):
    """Exception raised for errors in parsing docstrings to generate JSON schemas"""


def get_json_schema(func: Callable) -> dict:
    """
    This function generates a JSON schema for a given function, based on its docstring and type hints. This is
    mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
    the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
    that the function has a docstring, and that each argument has a description in the docstring, in the standard
    Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.

    Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
    optional because most chat templates ignore the return value of the function.

    Args:
        func: The function to generate a JSON schema for.

    Returns:
        A dictionary containing the JSON schema for the function.

    Examples:
    ```python
    >>> def multiply(x: float, y: float):
    >>>    '''
    >>>    A function that multiplies two numbers
    >>>
    >>>    Args:
    >>>        x: The first number to multiply
    >>>        y: The second number to multiply
    >>>    '''
    >>>    return x * y
    >>>
    >>> print(get_json_schema(multiply))
    {
        "name": "multiply",
        "description": "A function that multiplies two numbers",
        "parameters": {
            "type": "object",
            "properties": {
                "x": {"type": "number", "description": "The first number to multiply"},
                "y": {"type": "number", "description": "The second number to multiply"}
            },
            "required": ["x", "y"]
        }
    }
    ```

    The general use for these schemas is that they are used to generate tool descriptions for chat templates that
    support them, like so:

    ```python
    >>> from transformers import AutoTokenizer
    >>> from transformers.utils import get_json_schema
    >>>
    >>> def multiply(x: float, y: float):
    >>>    '''
    >>>    A function that multiplies two numbers
    >>>
    >>>    Args:
    >>>        x: The first number to multiply
    >>>        y: The second number to multiply
    >>>    '''
    >>>    return x * y
    >>>
    >>> multiply_schema = get_json_schema(multiply)
    >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
    >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
    >>> formatted_chat = tokenizer.apply_chat_template(
    >>>     messages,
    >>>     tools=[multiply_schema],
    >>>     chat_template="tool_use",
    >>>     return_dict=True,
    >>>     return_tensors="pt",
    >>>     add_generation_prompt=True
    >>> )
    >>> # The formatted chat can now be passed to model.generate()
    ```

    Each argument description can also have an optional `(choices: ...)` block at the end, such as
    `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
    only be parsed correctly if it is at the end of the line:

    ```python
    >>> def drink_beverage(beverage: str):
    >>>    '''
    >>>    A function that drinks a beverage
    >>>
    >>>    Args:
    >>>        beverage: The beverage to drink (choices: ["tea", "coffee"])
    >>>    '''
    >>>    pass
    >>>
    >>> print(get_json_schema(drink_beverage))
    ```
    {
        'name': 'drink_beverage',
        'description': 'A function that drinks a beverage',
        'parameters': {
            'type': 'object',
            'properties': {
                'beverage': {
                    'type': 'string',
                    'enum': ['tea', 'coffee'],
                    'description': 'The beverage to drink'
                }
            },
            'required': ['beverage']
        }
    }
    """
    doc = inspect.getdoc(func)
    if not doc:
        raise DocstringParsingException(
            f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
        )
    doc = doc.strip()
    main_doc, param_descriptions, return_doc = _parse_google_format_docstring(doc)

    json_schema = _convert_type_hints_to_json_schema(func)
    if (return_dict := json_schema["properties"].pop("return", None)) is not None:
        if return_doc is not None:  # We allow a missing return docstring since most templates ignore it
            return_dict["description"] = return_doc
    for arg, schema in json_schema["properties"].items():
        if arg not in param_descriptions:
            raise DocstringParsingException(
                f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
            )
        desc = param_descriptions[arg]
        enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
        if enum_choices:
            schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
            desc = enum_choices.string[: enum_choices.start()].strip()
        schema["description"] = desc

    output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
    if return_dict is not None:
        output["return"] = return_dict
    return {"type": "function", "function": output}


# Extracts the initial segment of the docstring, containing the function description
description_re = re.compile(r"^(.*?)(?=\n\s*(Args:|Returns:|Raises:)|\Z)", re.DOTALL)
# Extracts the Args: block from the docstring
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
# Splits the Args: block into individual arguments
args_split_re = re.compile(
    r"(?:^|\n)"  # Match the start of the args block, or a newline
    r"\s*(\w+)\s*(?:\([^)]*?\))?:\s*"  # Capture the argument name (ignore the type) and strip spacing
    r"(.*?)\s*"  # Capture the argument description, which can span multiple lines, and strip trailing spacing
    r"(?=\n\s*\w+\s*(?:\([^)]*?\))?:|\Z)",  # Stop when you hit the next argument (with or without type) or the end of the block
    re.DOTALL | re.VERBOSE,
)
# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
returns_re = re.compile(
    r"\n\s*Returns:\n\s*"
    r"(?:[^)]*?:\s*)?"  # Ignore the return type if present
    r"(.*?)"  # Capture the return description
    r"[\n\s]*(Raises:|\Z)",
    re.DOTALL,
)


def _parse_google_format_docstring(
    docstring: str,
) -> tuple[str | None, dict | None, str | None]:
    """
    Parses a Google-style docstring to extract the function description,
    argument descriptions, and return description.

    Args:
        docstring (str): The docstring to parse.

    Returns:
        The function description, arguments, and return description.
    """

    # Extract the sections
    description_match = description_re.search(docstring)
    args_match = args_re.search(docstring)
    returns_match = returns_re.search(docstring)

    # Clean and store the sections
    description = description_match.group(1).strip() if description_match else None
    docstring_args = args_match.group(1).strip() if args_match else None
    returns = returns_match.group(1).strip() if returns_match else None

    # Parsing the arguments into a dictionary
    if docstring_args is not None:
        docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()])  # Remove blank lines
        matches = args_split_re.findall(docstring_args)
        args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
    else:
        args_dict = {}

    return description, args_dict, returns


def _convert_type_hints_to_json_schema(func: Callable, error_on_missing_type_hints: bool = True) -> dict:
    type_hints = get_type_hints(func)
    signature = inspect.signature(func)

    properties = {}
    for param_name, param_type in type_hints.items():
        properties[param_name] = _parse_type_hint(param_type)

    required = []
    for param_name, param in signature.parameters.items():
        if param.annotation == inspect.Parameter.empty and error_on_missing_type_hints:
            raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
        if param_name not in properties:
            properties[param_name] = {}

        if param.default == inspect.Parameter.empty:
            required.append(param_name)
        else:
            properties[param_name]["nullable"] = True

    # Return: multi-type union -> treat as any
    if (
        "return" in properties
        and (return_type := properties["return"].get("type"))
        and not isinstance(return_type, str)
    ):
        properties["return"]["type"] = "any"

    schema = {"type": "object", "properties": properties}
    if required:
        schema["required"] = required

    return schema


def _parse_type_hint(hint: type) -> dict:
    origin = get_origin(hint)
    args = get_args(hint)

    if origin is None:
        try:
            return _get_json_schema_type(hint)
        except KeyError:
            raise TypeHintParsingException(
                "Couldn't parse this type hint, likely due to a custom class or object: ",
                hint,
            )

    elif origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
        return _parse_union_type(args)

    elif origin is list:
        if not args:
            return {"type": "array"}
        else:
            # Lists can only have a single type argument, so recurse into it
            return {"type": "array", "items": _parse_type_hint(args[0])}

    elif origin is tuple:
        if not args:
            return {"type": "array"}
        if len(args) == 1:
            raise TypeHintParsingException(
                f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
                "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
                "more than one element, we recommend "
                "using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
                "pass the element directly."
            )
        if ... in args:
            raise TypeHintParsingException(
                "Conversion of '...' is not supported in Tuple type hints. "
                "Use List[] types for variable-length"
                " inputs instead."
            )
        return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}

    elif origin is dict:
        # The JSON equivalent to a dict is 'object', which mandates that all keys are strings
        # However, we can specify the type of the dict values with "additionalProperties"
        out = {"type": "object"}
        if len(args) == 2:
            out["additionalProperties"] = _parse_type_hint(args[1])
        return out

    elif origin is Literal:
        literal_types = set(type(arg) for arg in args)
        final_type = _parse_union_type(literal_types)

        # A None literal value is represented by the 'nullable' field set by _parse_union_type
        final_type.update({"enum": [arg for arg in args if arg is not None]})
        return final_type

    raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)


def _parse_union_type(args: tuple[Any, ...]) -> dict:
    subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
    if len(subtypes) == 1:
        # A single non-null type can be expressed directly
        return_dict = subtypes[0]
    elif all(isinstance(subtype["type"], str) for subtype in subtypes):
        # A union of basic types can be expressed as a list in the schema
        return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])}
    else:
        # A union of more complex types requires "anyOf"
        return_dict = {"anyOf": subtypes}
    if type(None) in args:
        return_dict["nullable"] = True
    return return_dict


_BASE_TYPE_MAPPING = {
    int: {"type": "integer"},
    float: {"type": "number"},
    str: {"type": "string"},
    bool: {"type": "boolean"},
    list: {"type": "array"},
    dict: {"type": "object"},
    Any: {"type": "any"},
    types.NoneType: {"type": "null"},
}


def _get_json_schema_type(param_type: type) -> dict[str, str]:
    if param_type in _BASE_TYPE_MAPPING:
        return copy(_BASE_TYPE_MAPPING[param_type])
    if str(param_type) == "Image":
        from PIL.Image import Image

        if param_type == Image:
            return {"type": "image"}
    if str(param_type) == "Tensor":
        try:
            from torch import Tensor

            if param_type == Tensor:
                return {"type": "audio"}
        except ModuleNotFoundError:
            pass
    return {"type": "object"}
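As a quick illustration of `get_imports` above (a sketch; the result order is unspecified since the function deduplicates through a set):

    snippet = "import numpy as np\nfrom wikipediaapi import Wikipedia\nfrom .utils import helper"
    print(get_imports(snippet))
    # e.g. ['numpy', 'wikipedia-api'] -- the relative import is dropped, only top-level
    # modules are kept, and 'wikipediaapi' is mapped to its PyPI name via IMPORT_TO_PACKAGE_MAPPING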
src/smolagents/agent_types.py
ADDED
@@ -0,0 +1,283 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import pathlib
import tempfile
import uuid
from io import BytesIO

import PIL.Image
import requests

from .utils import _is_package_available


logger = logging.getLogger(__name__)


class AgentType:
    """
    Abstract class to be reimplemented to define types that can be returned by agents.

    These objects serve three purposes:

    - They behave as if they were the type they're meant to be, e.g., a string for text, a PIL.Image.Image for images
    - They can be stringified: str(object) in order to return a string defining the object
    - They should be displayed correctly in ipython notebooks/colab/jupyter
    """

    def __init__(self, value):
        self._value = value

    def __str__(self):
        return self.to_string()

    def to_raw(self):
        logger.error(
            "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
        )
        return self._value

    def to_string(self) -> str:
        logger.error(
            "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
        )
        return str(self._value)


class AgentText(AgentType, str):
    """
    Text type returned by the agent. Behaves as a string.
    """

    def to_raw(self):
        return self._value

    def to_string(self):
        return str(self._value)


class AgentImage(AgentType, PIL.Image.Image):
    """
    Image type returned by the agent. Behaves as a PIL.Image.Image.
    """

    def __init__(self, value):
        AgentType.__init__(self, value)
        PIL.Image.Image.__init__(self)

        self._path = None
        self._raw = None
        self._tensor = None

        if isinstance(value, AgentImage):
            self._raw, self._path, self._tensor = value._raw, value._path, value._tensor
        elif isinstance(value, PIL.Image.Image):
            self._raw = value
        elif isinstance(value, bytes):
            self._raw = PIL.Image.open(BytesIO(value))
        elif isinstance(value, (str, pathlib.Path)):
            self._path = value
        else:
            try:
                import torch

                if isinstance(value, torch.Tensor):
                    self._tensor = value
                import numpy as np

                if isinstance(value, np.ndarray):
                    self._tensor = torch.from_numpy(value)
            except ModuleNotFoundError:
                pass

        if self._path is None and self._raw is None and self._tensor is None:
            raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")

    def _ipython_display_(self, include=None, exclude=None):
        """
        Displays this type correctly in an ipython notebook (ipython, colab, jupyter, ...)
        """
        from IPython.display import Image, display

        display(Image(self.to_string()))

    def to_raw(self):
        """
        Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.Image.
        """
        if self._raw is not None:
            return self._raw

        if self._path is not None:
            self._raw = PIL.Image.open(self._path)
            return self._raw

        if self._tensor is not None:
            import numpy as np

            array = self._tensor.cpu().detach().numpy()
            return PIL.Image.fromarray((255 - array * 255).astype(np.uint8))

    def to_string(self):
        """
        Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
        version of the image.
        """
        if self._path is not None:
            return self._path

        if self._raw is not None:
            directory = tempfile.mkdtemp()
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
            self._raw.save(self._path, format="png")
            return self._path

        if self._tensor is not None:
            import numpy as np

            array = self._tensor.cpu().detach().numpy()

            # There is likely a simpler way than loading the tensor into an image just to save it
            img = PIL.Image.fromarray((255 - array * 255).astype(np.uint8))

            directory = tempfile.mkdtemp()
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
            img.save(self._path, format="png")

            return self._path

    def save(self, output_bytes, format: str = None, **params):
        """
        Saves the image to a file.
        Args:
            output_bytes (bytes): The output bytes to save the image to.
            format (str): The format to use for the output image. The format is the same as in PIL.Image.save.
            **params: Additional parameters to pass to PIL.Image.save.
        """
        img = self.to_raw()
        img.save(output_bytes, format=format, **params)


class AgentAudio(AgentType, str):
    """
    Audio type returned by the agent.
    """

    def __init__(self, value, samplerate=16_000):
        if not _is_package_available("soundfile") or not _is_package_available("torch"):
            raise ModuleNotFoundError(
                "Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`"
            )
        import numpy as np
        import torch

        super().__init__(value)

        self._path = None
        self._tensor = None

        self.samplerate = samplerate
        if isinstance(value, (str, pathlib.Path)):
            self._path = value
        elif isinstance(value, torch.Tensor):
            self._tensor = value
        elif isinstance(value, tuple):
            self.samplerate = value[0]
            if isinstance(value[1], np.ndarray):
                self._tensor = torch.from_numpy(value[1])
            else:
                self._tensor = torch.tensor(value[1])
        else:
            raise ValueError(f"Unsupported audio type: {type(value)}")

    def _ipython_display_(self, include=None, exclude=None):
        """
        Displays this type correctly in an ipython notebook (ipython, colab, jupyter, ...)
        """
        from IPython.display import Audio, display

        display(Audio(self.to_string(), rate=self.samplerate))

    def to_raw(self):
        """
        Returns the "raw" version of that object. It is a `torch.Tensor` object.
        """
        import soundfile as sf

        if self._tensor is not None:
            return self._tensor

        import torch

        if self._path is not None:
            if "://" in str(self._path):
                response = requests.get(self._path)
                response.raise_for_status()
                tensor, self.samplerate = sf.read(BytesIO(response.content))
            else:
                tensor, self.samplerate = sf.read(self._path)
            self._tensor = torch.tensor(tensor)
            return self._tensor

    def to_string(self):
        """
        Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
        version of the audio.
        """
        import soundfile as sf

        if self._path is not None:
            return self._path

        if self._tensor is not None:
            directory = tempfile.mkdtemp()
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
            sf.write(self._path, self._tensor, samplerate=self.samplerate)
            return self._path


_AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}


def handle_agent_input_types(*args, **kwargs):
    args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
    kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
    return args, kwargs


def handle_agent_output_types(output, output_type=None):
    if output_type in _AGENT_TYPE_MAPPING:
        # If the class has defined outputs, we can map directly according to the class definition
        decoded_outputs = _AGENT_TYPE_MAPPING[output_type](output)
        return decoded_outputs

    # If the class does not have defined output, then we map according to the type
    if isinstance(output, str):
        return AgentText(output)
    if isinstance(output, PIL.Image.Image):
        return AgentImage(output)
    try:
        import torch

        if isinstance(output, torch.Tensor):
            return AgentAudio(output)
    except ModuleNotFoundError:
        pass
    return output


__all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"]
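A small sketch of how these output wrappers behave: `AgentText` subclasses `str`, so the wrapped value still works anywhere a plain string is expected:

    text = handle_agent_output_types("hello", output_type="string")
    assert isinstance(text, AgentText) and isinstance(text, str)
    print(text.to_raw())  # "hello"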
src/smolagents/agents.py
ADDED
@@ -0,0 +1,1725 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
import importlib
|
18 |
+
import inspect
|
19 |
+
import json
|
20 |
+
import os
|
21 |
+
import re
|
22 |
+
import tempfile
|
23 |
+
import textwrap
|
24 |
+
import time
|
25 |
+
import warnings
|
26 |
+
from abc import ABC, abstractmethod
|
27 |
+
from collections.abc import Callable, Generator
|
28 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
29 |
+
from dataclasses import dataclass
|
30 |
+
from logging import getLogger
|
31 |
+
from pathlib import Path
|
32 |
+
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypedDict, Union
|
33 |
+
|
34 |
+
import jinja2
|
35 |
+
import yaml
|
36 |
+
from huggingface_hub import create_repo, metadata_update, snapshot_download, upload_folder
|
37 |
+
from jinja2 import StrictUndefined, Template
|
38 |
+
from rich.console import Group
|
39 |
+
from rich.live import Live
|
40 |
+
from rich.markdown import Markdown
|
41 |
+
from rich.panel import Panel
|
42 |
+
from rich.rule import Rule
|
43 |
+
from rich.text import Text
|
44 |
+
|
45 |
+
|
46 |
+
if TYPE_CHECKING:
|
47 |
+
import PIL.Image
|
48 |
+
|
49 |
+
from .agent_types import AgentAudio, AgentImage, handle_agent_output_types
|
50 |
+
from .default_tools import TOOL_MAPPING, FinalAnswerTool
|
51 |
+
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonExecutor, PythonExecutor, fix_final_answer_code
|
52 |
+
from .memory import (
|
53 |
+
ActionStep,
|
54 |
+
AgentMemory,
|
55 |
+
FinalAnswerStep,
|
56 |
+
PlanningStep,
|
57 |
+
SystemPromptStep,
|
58 |
+
TaskStep,
|
59 |
+
Timing,
|
60 |
+
TokenUsage,
|
61 |
+
ToolCall,
|
62 |
+
)
|
63 |
+
from .models import (
|
64 |
+
CODEAGENT_RESPONSE_FORMAT,
|
65 |
+
ChatMessage,
|
66 |
+
ChatMessageStreamDelta,
|
67 |
+
ChatMessageToolCall,
|
68 |
+
MessageRole,
|
69 |
+
Model,
|
70 |
+
agglomerate_stream_deltas,
|
71 |
+
parse_json_if_needed,
|
72 |
+
)
|
73 |
+
from .monitoring import (
|
74 |
+
YELLOW_HEX,
|
75 |
+
AgentLogger,
|
76 |
+
LogLevel,
|
77 |
+
Monitor,
|
78 |
+
)
|
79 |
+
from .remote_executors import DockerExecutor, E2BExecutor
|
80 |
+
from .tools import Tool, validate_tool_arguments
|
81 |
+
from .utils import (
|
82 |
+
AGENT_GRADIO_APP_TEMPLATE,
|
83 |
+
AgentError,
|
84 |
+
AgentExecutionError,
|
85 |
+
AgentGenerationError,
|
86 |
+
AgentMaxStepsError,
|
87 |
+
AgentParsingError,
|
88 |
+
AgentToolCallError,
|
89 |
+
AgentToolExecutionError,
|
90 |
+
extract_code_from_text,
|
91 |
+
is_valid_name,
|
92 |
+
make_init_file,
|
93 |
+
parse_code_blobs,
|
94 |
+
truncate_content,
|
95 |
+
)
|
96 |
+
|
97 |
+
|
98 |
+
logger = getLogger(__name__)
|
99 |
+
|
100 |
+
|
101 |
+
def get_variable_names(self, template: str) -> set[str]:
|
102 |
+
pattern = re.compile(r"\{\{([^{}]+)\}\}")
|
103 |
+
return {match.group(1).strip() for match in pattern.finditer(template)}
|
104 |
+
|
105 |
+
|
106 |
+
def populate_template(template: str, variables: dict[str, Any]) -> str:
|
107 |
+
compiled_template = Template(template, undefined=StrictUndefined)
|
108 |
+
try:
|
109 |
+
return compiled_template.render(**variables)
|
110 |
+
except Exception as e:
|
111 |
+
raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}")
|
112 |
+
|
113 |
+
|
114 |
+
@dataclass
|
115 |
+
class ActionOutput:
|
116 |
+
output: Any
|
117 |
+
is_final_answer: bool
|
118 |
+
|
119 |
+
|
120 |
+
@dataclass
|
121 |
+
class ToolOutput:
|
122 |
+
id: str
|
123 |
+
output: Any
|
124 |
+
is_final_answer: bool
|
125 |
+
observation: str
|
126 |
+
tool_call: ToolCall
|
127 |
+
|
128 |
+
|
129 |
+
class PlanningPromptTemplate(TypedDict):
|
130 |
+
"""
|
131 |
+
Prompt templates for the planning step.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
plan (`str`): Initial plan prompt.
|
135 |
+
update_plan_pre_messages (`str`): Update plan pre-messages prompt.
|
136 |
+
update_plan_post_messages (`str`): Update plan post-messages prompt.
|
137 |
+
"""
|
138 |
+
|
139 |
+
initial_plan: str
|
140 |
+
update_plan_pre_messages: str
|
141 |
+
update_plan_post_messages: str
|
142 |
+
|
143 |
+
|
144 |
+
class ManagedAgentPromptTemplate(TypedDict):
|
145 |
+
"""
|
146 |
+
Prompt templates for the managed agent.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
task (`str`): Task prompt.
|
150 |
+
report (`str`): Report prompt.
|
151 |
+
"""
|
152 |
+
|
153 |
+
task: str
|
154 |
+
report: str
|
155 |
+
|
156 |
+
|
157 |
+
class FinalAnswerPromptTemplate(TypedDict):
|
158 |
+
"""
|
159 |
+
Prompt templates for the final answer.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
pre_messages (`str`): Pre-messages prompt.
|
163 |
+
post_messages (`str`): Post-messages prompt.
|
164 |
+
"""
|
165 |
+
|
166 |
+
pre_messages: str
|
167 |
+
post_messages: str
|
168 |
+
|
169 |
+
|
170 |
+
class PromptTemplates(TypedDict):
|
171 |
+
"""
|
172 |
+
Prompt templates for the agent.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
system_prompt (`str`): System prompt.
|
176 |
+
planning ([`~agents.PlanningPromptTemplate`]): Planning prompt templates.
|
177 |
+
managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt templates.
|
178 |
+
final_answer ([`~agents.FinalAnswerPromptTemplate`]): Final answer prompt templates.
|
179 |
+
"""
|
180 |
+
|
181 |
+
system_prompt: str
|
182 |
+
planning: PlanningPromptTemplate
|
183 |
+
managed_agent: ManagedAgentPromptTemplate
|
184 |
+
final_answer: FinalAnswerPromptTemplate
|
185 |
+
|
186 |
+
|
187 |
+
EMPTY_PROMPT_TEMPLATES = PromptTemplates(
|
188 |
+
system_prompt="",
|
189 |
+
planning=PlanningPromptTemplate(
|
190 |
+
initial_plan="",
|
191 |
+
update_plan_pre_messages="",
|
192 |
+
update_plan_post_messages="",
|
193 |
+
),
|
194 |
+
managed_agent=ManagedAgentPromptTemplate(task="", report=""),
|
195 |
+
final_answer=FinalAnswerPromptTemplate(pre_messages="", post_messages=""),
|
196 |
+
)
|
197 |
+
|
198 |
+
|
199 |
+
@dataclass
|
200 |
+
class RunResult:
|
201 |
+
"""Holds extended information about an agent run.
|
202 |
+
|
203 |
+
Attributes:
|
204 |
+
output (Any | None): The final output of the agent run, if available.
|
205 |
+
state (Literal["success", "max_steps_error"]): The final state of the agent after the run.
|
206 |
+
messages (list[dict]): The agent's memory, as a list of messages.
|
207 |
+
token_usage (TokenUsage | None): Count of tokens used during the run.
|
208 |
+
timing (Timing): Timing details of the agent run: start time, end time, duration.
|
209 |
+
"""
|
210 |
+
|
211 |
+
output: Any | None
|
212 |
+
state: Literal["success", "max_steps_error"]
|
213 |
+
messages: list[dict]
|
214 |
+
token_usage: TokenUsage | None
|
215 |
+
timing: Timing
|
216 |
+
|
217 |
+
|
218 |
+
StreamEvent: TypeAlias = Union[
|
219 |
+
ChatMessageStreamDelta,
|
220 |
+
ChatMessageToolCall,
|
221 |
+
ActionOutput,
|
222 |
+
ToolCall,
|
223 |
+
ToolOutput,
|
224 |
+
PlanningStep,
|
225 |
+
ActionStep,
|
226 |
+
FinalAnswerStep,
|
227 |
+
]
|
228 |
+
|
229 |
+
|
230 |
+
class MultiStepAgent(ABC):
|
231 |
+
"""
|
232 |
+
Agent class that solves the given task step by step, using the ReAct framework:
|
233 |
+
While the objective is not reached, the agent will perform a cycle of action (given by the LLM) and observation (obtained from the environment).
|
234 |
+
|
235 |
+
Args:
|
236 |
+
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
237 |
+
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
238 |
+
prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
|
239 |
+
instructions (`str`, *optional*): Custom instructions for the agent, will be inserted in the system prompt.
|
240 |
+
max_steps (`int`, default `20`): Maximum number of steps the agent can take to solve the task.
|
241 |
+
add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
|
242 |
+
verbosity_level (`LogLevel`, default `LogLevel.INFO`): Level of verbosity of the agent's logs.
|
243 |
+
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
|
244 |
+
<Deprecated version="1.17.0">
|
245 |
+
Parameter `grammar` is deprecated and will be removed in version 1.20.
|
246 |
+
</Deprecated>
|
247 |
+
managed_agents (`list`, *optional*): Managed agents that the agent can call.
|
248 |
+
step_callbacks (`list[Callable]`, *optional*): Callbacks that will be called at each step.
|
249 |
+
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
|
250 |
+
name (`str`, *optional*): Necessary for a managed agent only - the name by which this agent can be called.
|
251 |
+
description (`str`, *optional*): Necessary for a managed agent only - the description of this agent.
|
252 |
+
provide_run_summary (`bool`, *optional*): Whether to provide a run summary when called as a managed agent.
|
253 |
+
final_answer_checks (`list[Callable]`, *optional*): List of validation functions to run before accepting a final answer.
|
254 |
+
Each function should:
|
255 |
+
- Take the final answer and the agent's memory as arguments.
|
256 |
+
- Return a boolean indicating whether the final answer is valid.
|
257 |
+
"""
|
258 |
+
|
259 |
+
def __init__(
|
260 |
+
self,
|
261 |
+
tools: list[Tool],
|
262 |
+
model: Model,
|
263 |
+
prompt_templates: PromptTemplates | None = None,
|
264 |
+
instructions: str | None = None,
|
265 |
+
max_steps: int = 20,
|
266 |
+
add_base_tools: bool = False,
|
267 |
+
verbosity_level: LogLevel = LogLevel.INFO,
|
268 |
+
grammar: dict[str, str] | None = None,
|
269 |
+
managed_agents: list | None = None,
|
270 |
+
step_callbacks: list[Callable] | None = None,
|
271 |
+
planning_interval: int | None = None,
|
272 |
+
name: str | None = None,
|
273 |
+
description: str | None = None,
|
274 |
+
provide_run_summary: bool = False,
|
275 |
+
final_answer_checks: list[Callable] | None = None,
|
276 |
+
return_full_result: bool = False,
|
277 |
+
logger: AgentLogger | None = None,
|
278 |
+
):
|
279 |
+
self.agent_name = self.__class__.__name__
|
280 |
+
self.model = model
|
281 |
+
self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES
|
282 |
+
if prompt_templates is not None:
|
283 |
+
missing_keys = set(EMPTY_PROMPT_TEMPLATES.keys()) - set(prompt_templates.keys())
|
284 |
+
assert not missing_keys, (
|
285 |
+
f"Some prompt templates are missing from your custom `prompt_templates`: {missing_keys}"
|
286 |
+
)
|
287 |
+
for key, value in EMPTY_PROMPT_TEMPLATES.items():
|
288 |
+
if isinstance(value, dict):
|
289 |
+
for subkey in value.keys():
|
290 |
+
assert key in prompt_templates.keys() and (subkey in prompt_templates[key].keys()), (
|
291 |
+
f"Some prompt templates are missing from your custom `prompt_templates`: {subkey} under {key}"
|
292 |
+
)
|
293 |
+
|
294 |
+
self.max_steps = max_steps
|
295 |
+
self.step_number = 0
|
296 |
+
if grammar is not None:
|
297 |
+
warnings.warn(
|
298 |
+
"Parameter 'grammar' is deprecated and will be removed in version 1.20.",
|
299 |
+
FutureWarning,
|
300 |
+
)
|
301 |
+
self.grammar = grammar
|
302 |
+
self.planning_interval = planning_interval
|
303 |
+
self.state: dict[str, Any] = {}
|
304 |
+
self.name = self._validate_name(name)
|
305 |
+
self.description = description
|
306 |
+
self.provide_run_summary = provide_run_summary
|
307 |
+
self.final_answer_checks = final_answer_checks if final_answer_checks is not None else []
|
308 |
+
self.return_full_result = return_full_result
|
309 |
+
self.instructions = instructions
|
310 |
+
self._setup_managed_agents(managed_agents)
|
311 |
+
self._setup_tools(tools, add_base_tools)
|
312 |
+
self._validate_tools_and_managed_agents(tools, managed_agents)
|
313 |
+
|
314 |
+
self.task: str | None = None
|
315 |
+
self.memory = AgentMemory(self.system_prompt)
|
316 |
+
|
317 |
+
if logger is None:
|
318 |
+
self.logger = AgentLogger(level=verbosity_level)
|
319 |
+
else:
|
320 |
+
self.logger = logger
|
321 |
+
|
322 |
+
self.monitor = Monitor(self.model, self.logger)
|
323 |
+
self.step_callbacks = step_callbacks if step_callbacks is not None else []
|
324 |
+
self.step_callbacks.append(self.monitor.update_metrics)
|
325 |
+
self.stream_outputs = False
|
326 |
+
|
327 |
+
@property
|
328 |
+
def system_prompt(self) -> str:
|
329 |
+
return self.initialize_system_prompt()
|
330 |
+
|
331 |
+
@system_prompt.setter
|
332 |
+
def system_prompt(self, value: str):
|
333 |
+
raise AttributeError(
|
334 |
+
"""The 'system_prompt' property is read-only. Use 'self.prompt_templates["system_prompt"]' instead."""
|
335 |
+
)
|
336 |
+
|
337 |
+
def _validate_name(self, name: str | None) -> str | None:
|
338 |
+
if name is not None and not is_valid_name(name):
|
339 |
+
raise ValueError(f"Agent name '{name}' must be a valid Python identifier and not a reserved keyword.")
|
340 |
+
return name
|
341 |
+
|
342 |
+
def _setup_managed_agents(self, managed_agents: list | None = None) -> None:
|
343 |
+
"""Setup managed agents with proper logging."""
|
344 |
+
self.managed_agents = {}
|
345 |
+
if managed_agents:
|
346 |
+
assert all(agent.name and agent.description for agent in managed_agents), (
|
347 |
+
"All managed agents need both a name and a description!"
|
348 |
+
)
|
349 |
+
self.managed_agents = {agent.name: agent for agent in managed_agents}
|
350 |
+
# Ensure managed agents can be called as tools by the model: set their inputs and output_type
|
351 |
+
for agent in self.managed_agents.values():
|
352 |
+
agent.inputs = {
|
353 |
+
"task": {"type": "string", "description": "Long detailed description of the task."},
|
354 |
+
"additional_args": {
|
355 |
+
"type": "object",
|
356 |
+
"description": "Dictionary of extra inputs to pass to the managed agent, e.g. images, dataframes, or any other contextual data it may need.",
|
357 |
+
},
|
358 |
+
}
|
359 |
+
agent.output_type = "string"
|
360 |
+
|
361 |
+
def _setup_tools(self, tools, add_base_tools):
|
362 |
+
assert all(isinstance(tool, Tool) for tool in tools), "All elements must be instance of Tool (or a subclass)"
|
363 |
+
self.tools = {tool.name: tool for tool in tools}
|
364 |
+
if add_base_tools:
|
365 |
+
self.tools.update(
|
366 |
+
{
|
367 |
+
name: cls()
|
368 |
+
for name, cls in TOOL_MAPPING.items()
|
369 |
+
if name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent"
|
370 |
+
}
|
371 |
+
)
|
372 |
+
self.tools.setdefault("final_answer", FinalAnswerTool())
|
373 |
+
|
374 |
+
def _validate_tools_and_managed_agents(self, tools, managed_agents):
|
375 |
+
tool_and_managed_agent_names = [tool.name for tool in tools]
|
376 |
+
if managed_agents is not None:
|
377 |
+
tool_and_managed_agent_names += [agent.name for agent in managed_agents]
|
378 |
+
if self.name:
|
379 |
+
tool_and_managed_agent_names.append(self.name)
|
380 |
+
if len(tool_and_managed_agent_names) != len(set(tool_and_managed_agent_names)):
|
381 |
+
raise ValueError(
|
382 |
+
"Each tool or managed_agent should have a unique name! You passed these duplicate names: "
|
383 |
+
f"{[name for name in tool_and_managed_agent_names if tool_and_managed_agent_names.count(name) > 1]}"
|
384 |
+
)
|
385 |
+
|
386 |
+
    def run(
        self,
        task: str,
        stream: bool = False,
        reset: bool = True,
        images: list["PIL.Image.Image"] | None = None,
        additional_args: dict | None = None,
        max_steps: int | None = None,
    ):
        """
        Run the agent for the given task.

        Args:
            task (`str`): Task to perform.
            stream (`bool`): Whether to run in streaming mode.
                If `True`, returns a generator that yields each step as it is executed. You must iterate over this generator to process the individual steps (e.g., using a for loop or `next()`).
                If `False`, executes all steps internally and returns only the final answer after completion.
            reset (`bool`): Whether to reset the conversation or keep it going from previous run.
            images (`list[PIL.Image.Image]`, *optional*): Image(s) objects.
            additional_args (`dict`, *optional*): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names!
            max_steps (`int`, *optional*): Maximum number of steps the agent can take to solve the task. If not provided, will use the agent's default value.

        Example:
        ```py
        from smolagents import CodeAgent
        agent = CodeAgent(tools=[])
        agent.run("What is the result of 2 power 3.7384?")
        ```
        """
        max_steps = max_steps or self.max_steps
        self.task = task
        self.interrupt_switch = False
        if additional_args is not None:
            self.state.update(additional_args)
            self.task += f"""
You have been provided with these additional arguments, that you can access using the keys as variables in your python code:
{str(additional_args)}."""

        self.memory.system_prompt = SystemPromptStep(system_prompt=self.system_prompt)
        if reset:
            self.memory.reset()
            self.monitor.reset()

        self.logger.log_task(
            content=self.task.strip(),
            subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
            level=LogLevel.INFO,
            title=self.name if hasattr(self, "name") else None,
        )
        self.memory.steps.append(TaskStep(task=self.task, task_images=images))

        if getattr(self, "python_executor", None):
            self.python_executor.send_variables(variables=self.state)
            self.python_executor.send_tools({**self.tools, **self.managed_agents})

        if stream:
            # The steps are returned as they are executed through a generator to iterate on.
            return self._run_stream(task=self.task, max_steps=max_steps, images=images)
        run_start_time = time.time()
        # Outputs are returned only at the end. We only look at the last step.

        steps = list(self._run_stream(task=self.task, max_steps=max_steps, images=images))
        assert isinstance(steps[-1], FinalAnswerStep)
        output = steps[-1].output

        if self.return_full_result:
            total_input_tokens = 0
            total_output_tokens = 0
            correct_token_usage = True
            for step in self.memory.steps:
                if isinstance(step, (ActionStep, PlanningStep)):
                    if step.token_usage is None:
                        correct_token_usage = False
                        break
                    else:
                        total_input_tokens += step.token_usage.input_tokens
                        total_output_tokens += step.token_usage.output_tokens
            if correct_token_usage:
                token_usage = TokenUsage(input_tokens=total_input_tokens, output_tokens=total_output_tokens)
            else:
                token_usage = None

            if self.memory.steps and isinstance(getattr(self.memory.steps[-1], "error", None), AgentMaxStepsError):
                state = "max_steps_error"
            else:
                state = "success"

            messages = self.memory.get_full_steps()

            return RunResult(
                output=output,
                token_usage=token_usage,
                messages=messages,
                timing=Timing(start_time=run_start_time, end_time=time.time()),
                state=state,
            )

        return output

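    # Illustrative usage sketch (not part of the class): with `stream=True`, the caller
    # drives the generator itself; with `stream=False`, only the final answer (or a
    # `RunResult` when `return_full_result` is set) comes back. Assuming `model` is any
    # configured smolagents Model:
    #
    #   agent = CodeAgent(tools=[], model=model, return_full_result=True)
    #   result = agent.run("What is 2 ** 3.7384?")   # RunResult with output, token_usage, timing
    #   for event in agent.run("Same task, step by step", stream=True):
    #       print(type(event).__name__)              # ActionStep, FinalAnswerStep, ...
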
    def _run_stream(
        self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None
    ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
        self.step_number = 1
        returned_final_answer = False
        while not returned_final_answer and self.step_number <= max_steps:
            if self.interrupt_switch:
                raise AgentError("Agent interrupted.", self.logger)

            # Run a planning step if scheduled
            if self.planning_interval is not None and (
                self.step_number == 1 or (self.step_number - 1) % self.planning_interval == 0
            ):
                planning_start_time = time.time()
                planning_step = None
                for element in self._generate_planning_step(
                    task, is_first_step=len(self.memory.steps) == 1, step=self.step_number
                ):  # Don't use the attribute step_number here, because there can be steps from previous runs
                    yield element
                    planning_step = element
                assert isinstance(planning_step, PlanningStep)  # Last yielded element should be a PlanningStep
                self.memory.steps.append(planning_step)
                planning_end_time = time.time()
                planning_step.timing = Timing(
                    start_time=planning_start_time,
                    end_time=planning_end_time,
                )

            # Start action step!
            action_step_start_time = time.time()
            action_step = ActionStep(
                step_number=self.step_number,
                timing=Timing(start_time=action_step_start_time),
                observations_images=images,
            )
            self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
            try:
                for output in self._step_stream(action_step):
                    # Yield all
                    yield output

                    if isinstance(output, ActionOutput) and output.is_final_answer:
                        final_answer = output.output
                        self.logger.log(
                            Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"),
                            level=LogLevel.INFO,
                        )

                        if self.final_answer_checks:
                            self._validate_final_answer(final_answer)
                        returned_final_answer = True
                        action_step.is_final_answer = True

            except AgentGenerationError as e:
                # Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
                raise e
            except AgentError as e:
                # Other AgentError types are caused by the Model, so we should log them and iterate.
                action_step.error = e
            finally:
                self._finalize_step(action_step)
                self.memory.steps.append(action_step)
                yield action_step
                self.step_number += 1

        if not returned_final_answer and self.step_number == max_steps + 1:
            final_answer = self._handle_max_steps_reached(task, images)
            yield action_step
        yield FinalAnswerStep(handle_agent_output_types(final_answer))

    def _validate_final_answer(self, final_answer: Any):
        for check_function in self.final_answer_checks:
            try:
                assert check_function(final_answer, self.memory)
            except Exception as e:
                raise AgentError(f"Check {check_function.__name__} failed with error: {e}", self.logger)

    def _finalize_step(self, memory_step: ActionStep):
        memory_step.timing.end_time = time.time()
        for callback in self.step_callbacks:
            # For compatibility with old callbacks that don't take the agent as an argument
            callback(memory_step) if len(inspect.signature(callback).parameters) == 1 else callback(
                memory_step, agent=self
            )

    def _handle_max_steps_reached(self, task: str, images: list["PIL.Image.Image"]) -> Any:
        action_step_start_time = time.time()
        final_answer = self.provide_final_answer(task, images)
        final_memory_step = ActionStep(
            step_number=self.step_number,
            error=AgentMaxStepsError("Reached max steps.", self.logger),
            timing=Timing(start_time=action_step_start_time, end_time=time.time()),
            token_usage=final_answer.token_usage,
        )
        final_memory_step.action_output = final_answer.content
        self._finalize_step(final_memory_step)
        self.memory.steps.append(final_memory_step)
        return final_answer.content

    def _generate_planning_step(
        self, task, is_first_step: bool, step: int
    ) -> Generator[ChatMessageStreamDelta | PlanningStep]:
        start_time = time.time()
        if is_first_step:
            input_messages = [
                ChatMessage(
                    role=MessageRole.USER,
                    content=[
                        {
                            "type": "text",
                            "text": populate_template(
                                self.prompt_templates["planning"]["initial_plan"],
                                variables={"task": task, "tools": self.tools, "managed_agents": self.managed_agents},
                            ),
                        }
                    ],
                )
            ]
            if self.stream_outputs and hasattr(self.model, "generate_stream"):
                plan_message_content = ""
                output_stream = self.model.generate_stream(input_messages, stop_sequences=["<end_plan>"])  # type: ignore
                input_tokens, output_tokens = 0, 0
                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
                    for event in output_stream:
                        if event.content is not None:
                            plan_message_content += event.content
                            live.update(Markdown(plan_message_content))
                            if event.token_usage:
                                output_tokens += event.token_usage.output_tokens
                                input_tokens = event.token_usage.input_tokens
                        yield event
            else:
                plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
                plan_message_content = plan_message.content
                input_tokens, output_tokens = (
                    (
                        plan_message.token_usage.input_tokens,
                        plan_message.token_usage.output_tokens,
                    )
                    if plan_message.token_usage
                    else (None, None)
                )
            plan = textwrap.dedent(
                f"""Here are the facts I know and the plan of action that I will follow to solve the task:\n```\n{plan_message_content}\n```"""
            )
        else:
            # Summary mode removes the system prompt and previous planning messages output by the model.
            # Removing previous planning messages avoids influencing too much the new plan.
            memory_messages = self.write_memory_to_messages(summary_mode=True)
            plan_update_pre = ChatMessage(
                role=MessageRole.SYSTEM,
                content=[
                    {
                        "type": "text",
                        "text": populate_template(
                            self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task}
                        ),
                    }
                ],
            )
            plan_update_post = ChatMessage(
                role=MessageRole.USER,
                content=[
                    {
                        "type": "text",
                        "text": populate_template(
                            self.prompt_templates["planning"]["update_plan_post_messages"],
                            variables={
                                "task": task,
                                "tools": self.tools,
                                "managed_agents": self.managed_agents,
                                "remaining_steps": (self.max_steps - step),
                            },
                        ),
                    }
                ],
            )
            input_messages = [plan_update_pre] + memory_messages + [plan_update_post]
            if self.stream_outputs and hasattr(self.model, "generate_stream"):
                plan_message_content = ""
                input_tokens, output_tokens = 0, 0
                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
                    for event in self.model.generate_stream(
                        input_messages,
                        stop_sequences=["<end_plan>"],
                    ):  # type: ignore
                        if event.content is not None:
                            plan_message_content += event.content
                            live.update(Markdown(plan_message_content))
                            if event.token_usage:
                                output_tokens += event.token_usage.output_tokens
                                input_tokens = event.token_usage.input_tokens
                        yield event
            else:
                plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
                plan_message_content = plan_message.content
                if plan_message.token_usage is not None:
                    input_tokens, output_tokens = (
                        plan_message.token_usage.input_tokens,
                        plan_message.token_usage.output_tokens,
                    )
            plan = textwrap.dedent(
                f"""I still need to solve the task I was given:\n```\n{self.task}\n```\n\nHere are the facts I know and my new/updated plan of action to solve the task:\n```\n{plan_message_content}\n```"""
            )
        log_headline = "Initial plan" if is_first_step else "Updated plan"
        self.logger.log(Rule(f"[bold]{log_headline}", style="orange"), Text(plan), level=LogLevel.INFO)
        yield PlanningStep(
            model_input_messages=input_messages,
            plan=plan,
            model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content=plan_message_content),
            token_usage=TokenUsage(input_tokens=input_tokens, output_tokens=output_tokens),
            timing=Timing(start_time=start_time, end_time=time.time()),
        )

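    # Note on scheduling (illustrative): given the condition in `_run_stream`
    # (`step_number == 1 or (step_number - 1) % planning_interval == 0`), an agent
    # created with `planning_interval=3` plans at steps 1, 4, 7, ... — i.e. an initial
    # plan followed by an update every three action steps.
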
    @property
    def logs(self):
        logger.warning(
            "The 'logs' attribute is deprecated and will soon be removed. Please use 'self.memory.steps' instead."
        )
        return [self.memory.system_prompt] + self.memory.steps

    @abstractmethod
    def initialize_system_prompt(self) -> str:
        """To be implemented in child classes"""
        ...

    def interrupt(self):
        """Interrupts the agent execution."""
        self.interrupt_switch = True

    def write_memory_to_messages(
        self,
        summary_mode: bool = False,
    ) -> list[ChatMessage]:
        """
        Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
        that can be used as input to the LLM. Adds a number of keywords (such as PLAN, error, etc) to help
        the LLM.
        """
        messages = self.memory.system_prompt.to_messages(summary_mode=summary_mode)
        for memory_step in self.memory.steps:
            messages.extend(memory_step.to_messages(summary_mode=summary_mode))
        return messages

    def _step_stream(
        self, memory_step: ActionStep
    ) -> Generator[ChatMessageStreamDelta | ToolCall | ToolOutput | ActionOutput]:
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        Yields ChatMessageStreamDelta during the run if streaming is enabled.
        At the end, yields either None if the step is not final, or the final answer.
        """
        raise NotImplementedError("This method should be implemented in child classes")

    def step(self, memory_step: ActionStep) -> Any:
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        Returns either None if the step is not final, or the final answer.
        """
        return list(self._step_stream(memory_step))[-1]

    def extract_action(self, model_output: str, split_token: str) -> tuple[str, str]:
        """
        Parse action from the LLM output

        Args:
            model_output (`str`): Output of the LLM
            split_token (`str`): Separator for the action. Should match the example in the system prompt.
        """
        try:
            split = model_output.split(split_token)
            rationale, action = (
                split[-2],
                split[-1],
            )  # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
        except Exception:
            raise AgentParsingError(
                f"No '{split_token}' token provided in your output.\nYour output:\n{model_output}\n. Be sure to include an action, prefaced with '{split_token}'!",
                self.logger,
            )
        return rationale.strip(), action.strip()

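    # Example of the split behavior (illustrative): with split_token="Action:",
    #   extract_action("Thought: I should search.\nAction:\nsearch('foo')", "Action:")
    # returns ("Thought: I should search.", "search('foo')"); taking split[-2]/split[-1]
    # keeps this robust even if "Action:" also appears earlier in the model output.
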
    def provide_final_answer(self, task: str, images: list["PIL.Image.Image"] | None = None) -> ChatMessage:
        """
        Provide the final answer to the task, based on the logs of the agent's interactions.

        Args:
            task (`str`): Task to perform.
            images (`list[PIL.Image.Image]`, *optional*): Image(s) objects.

        Returns:
            `ChatMessage`: Final answer to the task.
        """
        messages = [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content=[
                    {
                        "type": "text",
                        "text": self.prompt_templates["final_answer"]["pre_messages"],
                    }
                ],
            )
        ]
        if images:
            messages[0].content += [{"type": "image", "image": image} for image in images]
        messages += self.write_memory_to_messages()[1:]
        messages.append(
            ChatMessage(
                role=MessageRole.USER,
                content=[
                    {
                        "type": "text",
                        "text": populate_template(
                            self.prompt_templates["final_answer"]["post_messages"], variables={"task": task}
                        ),
                    }
                ],
            )
        )
        try:
            chat_message: ChatMessage = self.model.generate(messages)
            return chat_message
        except Exception as e:
            return ChatMessage(role=MessageRole.ASSISTANT, content=f"Error in generating final LLM output:\n{e}")

    def visualize(self):
        """Creates a rich tree visualization of the agent's structure."""
        self.logger.visualize_agent_tree(self)

    def replay(self, detailed: bool = False):
        """Prints a pretty replay of the agent's steps.

        Args:
            detailed (bool, optional): If True, also displays the memory at each step. Defaults to False.
                Careful: will increase log length exponentially. Use only for debugging.
        """
        self.memory.replay(self.logger, detailed=detailed)

    def __call__(self, task: str, **kwargs):
        """Adds additional prompting for the managed agent, runs it, and wraps the output.
        This method is called only by a managed agent.
        """
        full_task = populate_template(
            self.prompt_templates["managed_agent"]["task"],
            variables=dict(name=self.name, task=task),
        )
        result = self.run(full_task, **kwargs)
        if isinstance(result, RunResult):
            report = result.output
        else:
            report = result
        answer = populate_template(
            self.prompt_templates["managed_agent"]["report"], variables=dict(name=self.name, final_answer=report)
        )
        if self.provide_run_summary:
            answer += "\n\nFor more detail, find below a summary of this agent's work:\n<summary_of_work>\n"
            for message in self.write_memory_to_messages(summary_mode=True):
                content = message["content"]
                answer += "\n" + truncate_content(str(content)) + "\n---"
            answer += "\n</summary_of_work>"
        return answer

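    # Illustrative flow (hypothetical names): when a manager agent calls
    # `self.managed_agents["web_searcher"]("find X")`, the task is first wrapped with the
    # `managed_agent.task` template, run like any other task, and the output is wrapped
    # with the `managed_agent.report` template before being returned as the observation.
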
    def save(self, output_dir: str | Path, relative_path: str | None = None):
        """
        Saves the relevant code files for your agent. This will copy the code of your agent in `output_dir` as well as autogenerate:

        - a `tools` folder containing the logic for each of the tools under `tools/{tool_name}.py`.
        - a `managed_agents` folder containing the logic for each of the managed agents.
        - an `agent.json` file containing a dictionary representing your agent.
        - a `prompts.yaml` file containing the prompt templates used by your agent.
        - an `app.py` file providing a UI for your agent when it is exported to a Space with `agent.push_to_hub()`
        - a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its
          code)

        Args:
            output_dir (`str` or `Path`): The folder in which you want to save your agent.
        """
        make_init_file(output_dir)

        # Recursively save managed agents
        if self.managed_agents:
            make_init_file(os.path.join(output_dir, "managed_agents"))
            for agent_name, agent in self.managed_agents.items():
                agent_suffix = f"managed_agents.{agent_name}"
                if relative_path:
                    agent_suffix = relative_path + "." + agent_suffix
                agent.save(os.path.join(output_dir, "managed_agents", agent_name), relative_path=agent_suffix)

        class_name = self.__class__.__name__

        # Save tools to different .py files
        for tool in self.tools.values():
            make_init_file(os.path.join(output_dir, "tools"))
            tool.save(os.path.join(output_dir, "tools"), tool_file_name=tool.name, make_gradio_app=False)

        # Save prompts to yaml
        yaml_prompts = yaml.safe_dump(
            self.prompt_templates,
            default_style="|",  # This forces block literals for all strings
            default_flow_style=False,
            width=float("inf"),
            sort_keys=False,
            allow_unicode=True,
            indent=2,
        )

        with open(os.path.join(output_dir, "prompts.yaml"), "w", encoding="utf-8") as f:
            f.write(yaml_prompts)

        # Save agent dictionary to json
        agent_dict = self.to_dict()
        agent_dict["tools"] = [tool.name for tool in self.tools.values()]
        agent_dict["managed_agents"] = {agent.name: agent.__class__.__name__ for agent in self.managed_agents.values()}
        with open(os.path.join(output_dir, "agent.json"), "w", encoding="utf-8") as f:
            json.dump(agent_dict, f, indent=4)

        # Save requirements
        with open(os.path.join(output_dir, "requirements.txt"), "w", encoding="utf-8") as f:
            f.writelines(f"{r}\n" for r in agent_dict["requirements"])

        # Make app.py file with Gradio UI
        agent_name = f"agent_{self.name}" if getattr(self, "name", None) else "agent"
        managed_agent_relative_path = relative_path + "." if relative_path is not None else ""
        app_template = AGENT_GRADIO_APP_TEMPLATE
        template_env = jinja2.Environment(loader=jinja2.BaseLoader(), undefined=jinja2.StrictUndefined)
        template_env.filters["repr"] = repr
        template_env.filters["camelcase"] = lambda value: "".join(word.capitalize() for word in value.split("_"))
        template = template_env.from_string(app_template)

        # Render the app.py file from Jinja2 template
        app_text = template.render(
            {
                "agent_name": agent_name,
                "class_name": class_name,
                "agent_dict": agent_dict,
                "tools": self.tools,
                "managed_agents": self.managed_agents,
                "managed_agent_relative_path": managed_agent_relative_path,
            }
        )

        with open(os.path.join(output_dir, "app.py"), "w", encoding="utf-8") as f:
            f.write(app_text + "\n")  # Append newline at the end

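    # Resulting layout (sketch, for an agent with one tool and one managed agent;
    # `web_search`/`web_searcher` are hypothetical names used for illustration):
    #
    #   output_dir/
    #     agent.json  prompts.yaml  requirements.txt  app.py
    #     tools/web_search.py
    #     managed_agents/web_searcher/agent.json ...
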
    def to_dict(self) -> dict[str, Any]:
        """Convert the agent to a dictionary representation.

        Returns:
            `dict`: Dictionary representation of the agent.
        """
        # TODO: handle serializing step_callbacks and final_answer_checks
        for attr in ["final_answer_checks", "step_callbacks"]:
            if getattr(self, attr, None):
                self.logger.log(f"This agent has {attr}: they will be ignored by this method.", LogLevel.INFO)

        tool_dicts = [tool.to_dict() for tool in self.tools.values()]
        tool_requirements = {req for tool in self.tools.values() for req in tool.to_dict()["requirements"]}
        managed_agents_requirements = {
            req for managed_agent in self.managed_agents.values() for req in managed_agent.to_dict()["requirements"]
        }
        requirements = tool_requirements | managed_agents_requirements
        if hasattr(self, "authorized_imports"):
            requirements.update(
                {package.split(".")[0] for package in self.authorized_imports if package not in BASE_BUILTIN_MODULES}
            )

        agent_dict = {
            "class": self.__class__.__name__,
            "tools": tool_dicts,
            "model": {
                "class": self.model.__class__.__name__,
                "data": self.model.to_dict(),
            },
            "managed_agents": [managed_agent.to_dict() for managed_agent in self.managed_agents.values()],
            "prompt_templates": self.prompt_templates,
            "max_steps": self.max_steps,
            "verbosity_level": int(self.logger.level),
            "grammar": self.grammar,
            "planning_interval": self.planning_interval,
            "name": self.name,
            "description": self.description,
            "requirements": sorted(requirements),
        }
        return agent_dict

    @classmethod
    def from_dict(cls, agent_dict: dict[str, Any], **kwargs) -> "MultiStepAgent":
        """Create agent from a dictionary representation.

        Args:
            agent_dict (`dict[str, Any]`): Dictionary representation of the agent.
            **kwargs: Additional keyword arguments that will override agent_dict values.

        Returns:
            `MultiStepAgent`: Instance of the agent class.
        """
        # Load model
        model_info = agent_dict["model"]
        model_class = getattr(importlib.import_module("smolagents.models"), model_info["class"])
        model = model_class.from_dict(model_info["data"])
        # Load tools
        tools = []
        for tool_info in agent_dict["tools"]:
            tools.append(Tool.from_code(tool_info["code"]))
        # Load managed agents
        managed_agents = []
        for managed_agent_name, managed_agent_class_name in agent_dict["managed_agents"].items():
            managed_agent_class = getattr(importlib.import_module("smolagents.agents"), managed_agent_class_name)
            managed_agents.append(managed_agent_class.from_dict(agent_dict["managed_agents"][managed_agent_name]))
        # Extract base agent parameters
        agent_args = {
            "model": model,
            "tools": tools,
            "prompt_templates": agent_dict.get("prompt_templates"),
            "max_steps": agent_dict.get("max_steps"),
            "verbosity_level": agent_dict.get("verbosity_level"),
            "grammar": agent_dict.get("grammar"),
            "planning_interval": agent_dict.get("planning_interval"),
            "name": agent_dict.get("name"),
            "description": agent_dict.get("description"),
        }
        # Filter out None values to use defaults from __init__
        agent_args = {k: v for k, v in agent_args.items() if v is not None}
        # Update with any additional kwargs
        agent_args.update(kwargs)
        # Create agent instance
        return cls(**agent_args)

    @classmethod
    def from_hub(
        cls,
        repo_id: str,
        token: str | None = None,
        trust_remote_code: bool = False,
        **kwargs,
    ):
        """
        Loads an agent defined on the Hub.

        <Tip warning={true}>

        Loading a tool from the Hub means that you'll download the tool and execute it locally.
        ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
        installing a package using pip/npm/apt.

        </Tip>

        Args:
            repo_id (`str`):
                The name of the repo on the Hub where your tool is defined.
            token (`str`, *optional*):
                The token to identify you on hf.co. If unset, will use the token generated when running
                `huggingface-cli login` (stored in `~/.huggingface`).
            trust_remote_code (`bool`, *optional*, defaults to False):
                This flag marks that you understand the risk of running remote code and that you trust this tool.
                If not setting this to True, loading the tool from Hub will fail.
            kwargs (additional keyword arguments, *optional*):
                Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
                `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your agent, and the
                others will be passed along to its init.
        """
        if not trust_remote_code:
            raise ValueError(
                "Loading an agent from Hub requires to acknowledge you trust its code: to do so, pass `trust_remote_code=True`."
            )

        # Get the agent's Hub folder.
        download_kwargs = {"token": token, "repo_type": "space"} | {
            key: kwargs.pop(key)
            for key in [
                "cache_dir",
                "force_download",
                "proxies",
                "revision",
                "local_files_only",
            ]
            if key in kwargs
        }

        download_folder = Path(snapshot_download(repo_id=repo_id, **download_kwargs))
        return cls.from_folder(download_folder, **kwargs)

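    # Usage sketch ("username/agent-space" is a hypothetical Space saved with `push_to_hub`):
    #
    #   agent = CodeAgent.from_hub("username/agent-space", trust_remote_code=True)
    #
    # Hub-related kwargs such as `revision` are consumed by `snapshot_download`;
    # everything else is forwarded to the agent's __init__ via `from_folder`.
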
    @classmethod
    def from_folder(cls, folder: str | Path, **kwargs):
        """Loads an agent from a local folder.

        Args:
            folder (`str` or `Path`): The folder where the agent is saved.
            **kwargs: Additional keyword arguments that will be passed to the agent's init.
        """
        # Load agent.json
        folder = Path(folder)
        agent_dict = json.loads((folder / "agent.json").read_text())

        # Load managed agents from their respective folders, recursively
        managed_agents = []
        for managed_agent_name, managed_agent_class_name in agent_dict["managed_agents"].items():
            agent_cls = getattr(importlib.import_module("smolagents.agents"), managed_agent_class_name)
            managed_agents.append(agent_cls.from_folder(folder / "managed_agents" / managed_agent_name))
        agent_dict["managed_agents"] = {}

        # Load tools
        tools = []
        for tool_name in agent_dict["tools"]:
            tool_code = (folder / "tools" / f"{tool_name}.py").read_text()
            tools.append({"name": tool_name, "code": tool_code})
        agent_dict["tools"] = tools

        # Add managed agents to kwargs to override the empty list in from_dict
        if managed_agents:
            kwargs["managed_agents"] = managed_agents

        return cls.from_dict(agent_dict, **kwargs)

    def push_to_hub(
        self,
        repo_id: str,
        commit_message: str = "Upload agent",
        private: bool | None = None,
        token: bool | str | None = None,
        create_pr: bool = False,
    ) -> str:
        """
        Upload the agent to the Hub.

        Parameters:
            repo_id (`str`):
                The name of the repository you want to push to. It should contain your organization name when
                pushing to a given organization.
            commit_message (`str`, *optional*, defaults to `"Upload agent"`):
                Message to commit while pushing.
            private (`bool`, *optional*, defaults to `None`):
                Whether to make the repo private. If `None`, the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
            token (`bool` or `str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            create_pr (`bool`, *optional*, defaults to `False`):
                Whether to create a PR with the uploaded files or directly commit.
        """
        repo_url = create_repo(
            repo_id=repo_id,
            token=token,
            private=private,
            exist_ok=True,
            repo_type="space",
            space_sdk="gradio",
        )
        repo_id = repo_url.repo_id
        metadata_update(
            repo_id,
            {"tags": ["smolagents", "agent"]},
            repo_type="space",
            token=token,
            overwrite=True,
        )

        with tempfile.TemporaryDirectory() as work_dir:
            self.save(work_dir)
            logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
            return upload_folder(
                repo_id=repo_id,
                commit_message=commit_message,
                folder_path=work_dir,
                token=token,
                create_pr=create_pr,
                repo_type="space",
            )

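    # Usage sketch (hypothetical repo id): `agent.push_to_hub("username/my-agent")`
    # creates (or reuses) a Gradio Space, tags it with "smolagents", and uploads the
    # files produced by `save()` from a temporary directory.
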
class ToolCallingAgent(MultiStepAgent):
    """
    This agent uses JSON-like tool calls, using method `model.get_tool_call` to leverage the LLM engine's tool calling capabilities.

    Args:
        tools (`list[Tool]`): [`Tool`]s that the agent can use.
        model (`Model`): Model that will generate the agent's actions.
        prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
        planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
        stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution.
        max_tool_threads (`int`, *optional*): Maximum number of threads for parallel tool calls.
            Higher values increase concurrency, but also resource usage.
            Defaults to `ThreadPoolExecutor`'s default.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        tools: list[Tool],
        model: Model,
        prompt_templates: PromptTemplates | None = None,
        planning_interval: int | None = None,
        stream_outputs: bool = False,
        max_tool_threads: int | None = None,
        **kwargs,
    ):
        prompt_templates = prompt_templates or yaml.safe_load(
            importlib.resources.files("smolagents.prompts").joinpath("toolcalling_agent.yaml").read_text()
        )
        super().__init__(
            tools=tools,
            model=model,
            prompt_templates=prompt_templates,
            planning_interval=planning_interval,
            **kwargs,
        )
        # Streaming setup
        self.stream_outputs = stream_outputs
        if self.stream_outputs and not hasattr(self.model, "generate_stream"):
            raise ValueError(
                "`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
            )
        # Tool calling setup
        self.max_tool_threads = max_tool_threads

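    # Instantiation sketch (assuming `model` supports native tool calling and `my_tool`
    # stands for any Tool instance):
    #
    #   agent = ToolCallingAgent(tools=[my_tool], model=model, max_tool_threads=4)
    #
    # `max_tool_threads` only matters when the model emits several tool calls in one
    # step, which are then executed in parallel.
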
    @property
    def tools_and_managed_agents(self):
        """Returns a combined list of tools and managed agents."""
        return list(self.tools.values()) + list(self.managed_agents.values())

    def initialize_system_prompt(self) -> str:
        system_prompt = populate_template(
            self.prompt_templates["system_prompt"],
            variables={
                "tools": self.tools,
                "managed_agents": self.managed_agents,
                "custom_instructions": self.instructions,
            },
        )
        return system_prompt

    def _step_stream(
        self, memory_step: ActionStep
    ) -> Generator[ChatMessageStreamDelta | ToolCall | ToolOutput | ActionOutput]:
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        Yields ChatMessageStreamDelta during the run if streaming is enabled.
        At the end, yields either None if the step is not final, or the final answer.
        """
        memory_messages = self.write_memory_to_messages()

        input_messages = memory_messages.copy()

        # Add new step in logs
        memory_step.model_input_messages = input_messages

        try:
            if self.stream_outputs and hasattr(self.model, "generate_stream"):
                output_stream = self.model.generate_stream(
                    input_messages,
                    stop_sequences=["Observation:", "Calling tools:"],
                    tools_to_call_from=self.tools_and_managed_agents,
                )

                chat_message_stream_deltas: list[ChatMessageStreamDelta] = []
                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
                    for event in output_stream:
                        chat_message_stream_deltas.append(event)
                        live.update(
                            Markdown(agglomerate_stream_deltas(chat_message_stream_deltas).render_as_markdown())
                        )
                        yield event
                chat_message = agglomerate_stream_deltas(chat_message_stream_deltas)
            else:
                chat_message: ChatMessage = self.model.generate(
                    input_messages,
                    stop_sequences=["Observation:", "Calling tools:"],
                    tools_to_call_from=self.tools_and_managed_agents,
                )
                if chat_message.content is None and chat_message.raw is not None:
                    log_content = str(chat_message.raw)
                else:
                    log_content = str(chat_message.content) or ""

                self.logger.log_markdown(
                    content=log_content,
                    title="Output message of the LLM:",
                    level=LogLevel.DEBUG,
                )

            # Record model output
            memory_step.model_output_message = chat_message
            memory_step.model_output = chat_message.content
            memory_step.token_usage = chat_message.token_usage
        except Exception as e:
            raise AgentGenerationError(f"Error while generating output:\n{e}", self.logger) from e

        if chat_message.tool_calls is None or len(chat_message.tool_calls) == 0:
            try:
                chat_message = self.model.parse_tool_calls(chat_message)
            except Exception as e:
                raise AgentParsingError(f"Error while parsing tool call from model output: {e}", self.logger)
        else:
            for tool_call in chat_message.tool_calls:
                tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
        final_answer, got_final_answer = None, False
        for output in self.process_tool_calls(chat_message, memory_step):
            yield output
            if isinstance(output, ToolOutput):
                if output.is_final_answer:
                    if got_final_answer:
                        raise AgentToolExecutionError(
                            "You returned multiple final answers. Please return only one single final answer!",
                            self.logger,
                        )
                    final_answer = output.output
                    got_final_answer = True

                    # Manage state variables
                    if isinstance(final_answer, str) and final_answer in self.state.keys():
                        final_answer = self.state[final_answer]
        yield ActionOutput(
            output=final_answer,
            is_final_answer=got_final_answer,
        )

    def process_tool_calls(
        self, chat_message: ChatMessage, memory_step: ActionStep
    ) -> Generator[ToolCall | ToolOutput]:
        """Process tool calls from the model output and update agent memory.

        Args:
            chat_message (`ChatMessage`): Chat message containing tool calls from the model.
            memory_step (`ActionStep`): Memory ActionStep to update with results.

        Yields:
            `ToolCall | ToolOutput`: The tool call or tool output.
        """
        parallel_calls: dict[str, ToolCall] = {}
        assert chat_message.tool_calls is not None
        for chat_tool_call in chat_message.tool_calls:
            tool_call = ToolCall(
                name=chat_tool_call.function.name, arguments=chat_tool_call.function.arguments, id=chat_tool_call.id
            )
            yield tool_call
            parallel_calls[tool_call.id] = tool_call

        # Helper function to process a single tool call
        def process_single_tool_call(tool_call: ToolCall) -> ToolOutput:
            tool_name = tool_call.name
            tool_arguments = tool_call.arguments or {}
            self.logger.log(
                Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")),
                level=LogLevel.INFO,
            )
            tool_call_result = self.execute_tool_call(tool_name, tool_arguments)
            tool_call_result_type = type(tool_call_result)
            if tool_call_result_type in [AgentImage, AgentAudio]:
                if tool_call_result_type == AgentImage:
                    observation_name = "image.png"
                elif tool_call_result_type == AgentAudio:
                    observation_name = "audio.mp3"
                # TODO: tool_call_result naming could allow for different names of same type
                self.state[observation_name] = tool_call_result
                observation = f"Stored '{observation_name}' in memory."
            else:
                observation = str(tool_call_result).strip()
            self.logger.log(
                f"Observations: {observation.replace('[', '|')}",  # escape potential rich-tag-like components
                level=LogLevel.INFO,
            )
            is_final_answer = tool_name == "final_answer"

            return ToolOutput(
                id=tool_call.id,
                output=tool_call_result,
                is_final_answer=is_final_answer,
                observation=observation,
                tool_call=tool_call,
            )

        # Process tool calls in parallel
        outputs = {}
        if len(parallel_calls) == 1:
            # If there's only one call, process it directly
            tool_call = list(parallel_calls.values())[0]
            tool_output = process_single_tool_call(tool_call)
            outputs[tool_output.id] = tool_output
            yield tool_output
        else:
            # If multiple tool calls, process them in parallel
            with ThreadPoolExecutor(self.max_tool_threads) as executor:
                futures = [
                    executor.submit(process_single_tool_call, tool_call) for tool_call in parallel_calls.values()
                ]
                for future in as_completed(futures):
                    tool_output = future.result()
                    outputs[tool_output.id] = tool_output
                    yield tool_output

        memory_step.tool_calls = [parallel_calls[k] for k in sorted(parallel_calls.keys())]
        memory_step.model_output = memory_step.model_output or ""
        memory_step.observations = memory_step.observations or ""
        for tool_output in [outputs[k] for k in sorted(outputs.keys())]:
            message = f"Tool call {tool_output.id}: calling '{tool_output.tool_call.name}' with arguments: {tool_output.tool_call.arguments}\n"
            memory_step.model_output += message
            memory_step.observations += tool_output.observation + "\n"
        memory_step.model_output = memory_step.model_output.rstrip("\n")
        memory_step.observations = (
            memory_step.observations.rstrip("\n") if memory_step.observations else memory_step.observations
        )

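    # Ordering note (from the logic above): multiple tool calls run in a
    # ThreadPoolExecutor and are yielded in completion order, while memory is rebuilt in
    # deterministic tool-call-id order, so the recorded logs stay stable across runs.
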
    def _substitute_state_variables(self, arguments: dict[str, str] | str) -> dict[str, Any] | str:
        """Replace string values in arguments with their corresponding state values if they exist."""
        if isinstance(arguments, dict):
            return {
                key: self.state.get(value, value) if isinstance(value, str) else value
                for key, value in arguments.items()
            }
        return arguments

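    # Example (illustrative): if a previous step stored `self.state["image.png"]`, a tool
    # call with arguments {"image": "image.png"} receives the stored AgentImage object
    # instead of the literal string "image.png".
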
    def execute_tool_call(self, tool_name: str, arguments: dict[str, str] | str) -> Any:
        """
        Execute a tool or managed agent with the provided arguments.

        The arguments are replaced with the actual values from the state if they refer to state variables.

        Args:
            tool_name (`str`): Name of the tool or managed agent to execute.
            arguments (dict[str, str] | str): Arguments passed to the tool call.
        """
        # Check if the tool exists
        available_tools = {**self.tools, **self.managed_agents}
        if tool_name not in available_tools:
            raise AgentToolExecutionError(
                f"Unknown tool {tool_name}, should be one of: {', '.join(available_tools)}.", self.logger
            )

        # Get the tool and substitute state variables in arguments
        tool = available_tools[tool_name]
        arguments = self._substitute_state_variables(arguments)
        is_managed_agent = tool_name in self.managed_agents

        error_msg = validate_tool_arguments(tool, arguments)
        if error_msg:
            raise AgentToolCallError(error_msg, self.logger)

        try:
            # Call tool with appropriate arguments
            if isinstance(arguments, dict):
                return tool(**arguments) if is_managed_agent else tool(**arguments, sanitize_inputs_outputs=True)
            else:
                return tool(arguments) if is_managed_agent else tool(arguments, sanitize_inputs_outputs=True)

        except Exception as e:
            # Handle execution errors
            if is_managed_agent:
                error_msg = (
                    f"Error executing request to team member '{tool_name}' with arguments {str(arguments)}: {e}\n"
                    "Please try again or request to another team member"
                )
            else:
                error_msg = (
                    f"Error executing tool '{tool_name}' with arguments {str(arguments)}: {type(e).__name__}: {e}\n"
                    "Please try again or use another tool"
                )
            raise AgentToolExecutionError(error_msg, self.logger) from e

class CodeAgent(MultiStepAgent):
    """
    In this agent, the tool calls will be formulated by the LLM in code format, then parsed and executed.

    Args:
        tools (`list[Tool]`): [`Tool`]s that the agent can use.
        model (`Model`): Model that will generate the agent's actions.
        prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
        additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
        planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
        executor_type (`str`, default `"local"`): Which executor type to use between `"local"`, `"e2b"`, or `"docker"`.
        executor_kwargs (`dict`, *optional*): Additional arguments to pass to initialize the executor.
        max_print_outputs_length (`int`, *optional*): Maximum length of the print outputs.
        stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution.
        use_structured_outputs_internally (`bool`, default `False`): Whether to use structured generation at each action step: improves performance for many models.

            <Added version="1.17.0"/>
        grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
            <Deprecated version="1.17.0">
            Parameter `grammar` is deprecated and will be removed in version 1.20.
            </Deprecated>
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        tools: list[Tool],
        model: Model,
        prompt_templates: PromptTemplates | None = None,
        additional_authorized_imports: list[str] | None = None,
        planning_interval: int | None = None,
        executor_type: str | None = "local",
        executor_kwargs: dict[str, Any] | None = None,
        max_print_outputs_length: int | None = None,
        stream_outputs: bool = False,
        use_structured_outputs_internally: bool = False,
        grammar: dict[str, str] | None = None,
        **kwargs,
    ):
        self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
        self.authorized_imports = sorted(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
        self.max_print_outputs_length = max_print_outputs_length
        self._use_structured_outputs_internally = use_structured_outputs_internally
        if use_structured_outputs_internally:
            prompt_templates = prompt_templates or yaml.safe_load(
                importlib.resources.files("smolagents.prompts").joinpath("structured_code_agent.yaml").read_text()
            )
        else:
            prompt_templates = prompt_templates or yaml.safe_load(
                importlib.resources.files("smolagents.prompts").joinpath("code_agent.yaml").read_text()
            )
        if grammar and use_structured_outputs_internally:
            raise ValueError("You cannot use 'grammar' and 'use_structured_outputs_internally' at the same time.")
        super().__init__(
            tools=tools,
            model=model,
            prompt_templates=prompt_templates,
            grammar=grammar,
            planning_interval=planning_interval,
            **kwargs,
        )
        self.stream_outputs = stream_outputs
        if self.stream_outputs and not hasattr(self.model, "generate_stream"):
            raise ValueError(
                "`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
            )
        if "*" in self.additional_authorized_imports:
            self.logger.log(
                "Caution: you set an authorization for all imports, meaning your agent can decide to import any package it deems necessary. This might raise issues if the package is not installed in your environment.",
                level=LogLevel.INFO,
            )
        self.executor_type = executor_type or "local"
        self.executor_kwargs = executor_kwargs or {}
        self.python_executor = self.create_python_executor()

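    # Instantiation sketch (assumptions flagged inline):
    #
    #   agent = CodeAgent(
    #       tools=[],
    #       model=model,                                  # any configured smolagents Model
    #       additional_authorized_imports=["numpy"],      # on top of BASE_BUILTIN_MODULES
    #       executor_type="local",
    #   )
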
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.cleanup()

    def cleanup(self):
        """Clean up resources used by the agent, such as the remote Python executor."""
        if hasattr(self.python_executor, "cleanup"):
            self.python_executor.cleanup()

    def create_python_executor(self) -> PythonExecutor:
        match self.executor_type:
            case "e2b" | "docker":
                if self.managed_agents:
                    raise Exception("Managed agents are not yet supported with remote code execution.")
                if self.executor_type == "e2b":
                    return E2BExecutor(self.additional_authorized_imports, self.logger, **self.executor_kwargs)
                else:
                    return DockerExecutor(self.additional_authorized_imports, self.logger, **self.executor_kwargs)
            case "local":
                return LocalPythonExecutor(
                    self.additional_authorized_imports,
                    **{"max_print_outputs_length": self.max_print_outputs_length} | self.executor_kwargs,
                )
            case _:  # if applicable
                raise ValueError(f"Unsupported executor type: {self.executor_type}")

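    # Executor choice sketch: `executor_type="e2b"` or `"docker"` sandboxes code remotely
    # (and currently excludes managed agents), while the default "local" executor runs in
    # process with `max_print_outputs_length` and any `executor_kwargs` applied.
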
    def initialize_system_prompt(self) -> str:
        system_prompt = populate_template(
            self.prompt_templates["system_prompt"],
            variables={
                "tools": self.tools,
                "managed_agents": self.managed_agents,
                "authorized_imports": (
                    "You can import from any package you want."
                    if "*" in self.authorized_imports
                    else str(self.authorized_imports)
                ),
                "custom_instructions": self.instructions,
            },
        )
        return system_prompt

    def _step_stream(
        self, memory_step: ActionStep
    ) -> Generator[ChatMessageStreamDelta | ToolCall | ToolOutput | ActionOutput]:
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        Yields ChatMessageStreamDelta during the run if streaming is enabled.
        At the end, yields either None if the step is not final, or the final answer.
        """
        memory_messages = self.write_memory_to_messages()

        input_messages = memory_messages.copy()
        ### Generate model output ###
        memory_step.model_input_messages = input_messages
        try:
            additional_args: dict[str, Any] = {}
            if self.grammar:
                additional_args["grammar"] = self.grammar
            if self._use_structured_outputs_internally:
                additional_args["response_format"] = CODEAGENT_RESPONSE_FORMAT
            if self.stream_outputs:
                output_stream = self.model.generate_stream(
                    input_messages,
                    stop_sequences=["<end_code>", "Observation:", "Calling tools:"],
                    **additional_args,
                )
                chat_message_stream_deltas: list[ChatMessageStreamDelta] = []
                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
                    for event in output_stream:
                        chat_message_stream_deltas.append(event)
                        live.update(
                            Markdown(agglomerate_stream_deltas(chat_message_stream_deltas).render_as_markdown())
                        )
                        yield event
                chat_message = agglomerate_stream_deltas(chat_message_stream_deltas)
                memory_step.model_output_message = chat_message
                output_text = chat_message.content
            else:
                chat_message: ChatMessage = self.model.generate(
                    input_messages,
                    stop_sequences=["<end_code>", "Observation:", "Calling tools:"],
                    **additional_args,
                )
                memory_step.model_output_message = chat_message
                output_text = chat_message.content
                self.logger.log_markdown(
                    content=output_text,
                    title="Output message of the LLM:",
                    level=LogLevel.DEBUG,
                )

            # This adds <end_code> sequence to the history.
            # This will nudge ulterior LLM calls to finish with <end_code>, thus efficiently stopping generation.
            if output_text and output_text.strip().endswith("```"):
                output_text += "<end_code>"
                memory_step.model_output_message.content = output_text

            memory_step.token_usage = chat_message.token_usage
            memory_step.model_output = output_text
        except Exception as e:
            raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e

        ### Parse output ###
        try:
            if self._use_structured_outputs_internally:
                code_action = json.loads(output_text)["code"]
                code_action = extract_code_from_text(code_action) or code_action
            else:
                code_action = parse_code_blobs(output_text)
            code_action = fix_final_answer_code(code_action)
            memory_step.code_action = code_action
        except Exception as e:
            error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
            raise AgentParsingError(error_msg, self.logger)

        tool_call = ToolCall(
            name="python_interpreter",
            arguments=code_action,
            id=f"call_{len(self.memory.steps)}",
        )
        yield tool_call
        memory_step.tool_calls = [tool_call]

        ### Execute action ###
        self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
        is_final_answer = False
        try:
            output, execution_logs, is_final_answer = self.python_executor(code_action)
            execution_outputs_console = []
            if len(execution_logs) > 0:
                execution_outputs_console += [
                    Text("Execution logs:", style="bold"),
                    Text(execution_logs),
                ]
            observation = "Execution logs:\n" + execution_logs
        except Exception as e:
            if hasattr(self.python_executor, "state") and "_print_outputs" in self.python_executor.state:
                execution_logs = str(self.python_executor.state["_print_outputs"])
                if len(execution_logs) > 0:
                    execution_outputs_console = [
                        Text("Execution logs:", style="bold"),
                        Text(execution_logs),
                    ]
                    memory_step.observations = "Execution logs:\n" + execution_logs
                    self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
            error_msg = str(e)
            if "Import of " in error_msg and " is not allowed" in error_msg:
                self.logger.log(
                    "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
                    level=LogLevel.INFO,
                )
            raise AgentExecutionError(error_msg, self.logger)

        truncated_output = truncate_content(str(output))
        observation += "Last output from code snippet:\n" + truncated_output
        memory_step.observations = observation

        if not is_final_answer:
            execution_outputs_console += [
                Text(
                    f"Out: {truncated_output}",
                ),
            ]
        self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
        memory_step.action_output = output
        yield ActionOutput(output=output, is_final_answer=is_final_answer)

def to_dict(self) -> dict[str, Any]:
|
1690 |
+
"""Convert the agent to a dictionary representation.
|
1691 |
+
|
1692 |
+
Returns:
|
1693 |
+
`dict`: Dictionary representation of the agent.
|
1694 |
+
"""
|
1695 |
+
agent_dict = super().to_dict()
|
1696 |
+
agent_dict["authorized_imports"] = self.authorized_imports
|
1697 |
+
agent_dict["executor_type"] = self.executor_type
|
1698 |
+
agent_dict["executor_kwargs"] = self.executor_kwargs
|
1699 |
+
agent_dict["max_print_outputs_length"] = self.max_print_outputs_length
|
1700 |
+
return agent_dict
|
1701 |
+
|
1702 |
+
@classmethod
|
1703 |
+
def from_dict(cls, agent_dict: dict[str, Any], **kwargs) -> "CodeAgent":
|
1704 |
+
"""Create CodeAgent from a dictionary representation.
|
1705 |
+
|
1706 |
+
Args:
|
1707 |
+
agent_dict (`dict[str, Any]`): Dictionary representation of the agent.
|
1708 |
+
**kwargs: Additional keyword arguments that will override agent_dict values.
|
1709 |
+
|
1710 |
+
Returns:
|
1711 |
+
`CodeAgent`: Instance of the CodeAgent class.
|
1712 |
+
"""
|
1713 |
+
# Add CodeAgent-specific parameters to kwargs
|
1714 |
+
code_agent_kwargs = {
|
1715 |
+
"additional_authorized_imports": agent_dict.get("authorized_imports"),
|
1716 |
+
"executor_type": agent_dict.get("executor_type"),
|
1717 |
+
"executor_kwargs": agent_dict.get("executor_kwargs"),
|
1718 |
+
"max_print_outputs_length": agent_dict.get("max_print_outputs_length"),
|
1719 |
+
}
|
1720 |
+
# Filter out None values
|
1721 |
+
code_agent_kwargs = {k: v for k, v in code_agent_kwargs.items() if v is not None}
|
1722 |
+
# Update with any additional kwargs
|
1723 |
+
code_agent_kwargs.update(kwargs)
|
1724 |
+
# Call the parent class's from_dict method
|
1725 |
+
return super().from_dict(agent_dict, **code_agent_kwargs)
|
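A note on the serialization hooks that close out this file: `to_dict` folds the executor configuration into the base agent dictionary, and `from_dict` maps the stored `authorized_imports` back onto the `additional_authorized_imports` constructor argument, with caller kwargs taking precedence. A minimal round-trip sketch, illustrative only since the parent `from_dict` (not shown in this hunk) must be able to rebuild the model from the dictionary:

```python
from smolagents import CodeAgent, InferenceClientModel

agent = CodeAgent(tools=[], model=InferenceClientModel(), additional_authorized_imports=["numpy"])
agent_dict = agent.to_dict()  # carries "authorized_imports", "executor_type", "executor_kwargs", ...

# Rebuild the agent; explicit kwargs override the values stored in agent_dict.
restored = CodeAgent.from_dict(agent_dict, max_print_outputs_length=10_000)
```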
src/smolagents/cli.py
ADDED
@@ -0,0 +1,164 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+
+from dotenv import load_dotenv
+
+from smolagents import CodeAgent, InferenceClientModel, LiteLLMModel, Model, OpenAIServerModel, Tool, TransformersModel
+from smolagents.default_tools import TOOL_MAPPING
+
+
+leopard_prompt = "How many seconds would it take for a leopard at full speed to run through Pont des Arts?"
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser(description="Run a CodeAgent with all specified parameters")
+    parser.add_argument(
+        "prompt",
+        type=str,
+        nargs="?",  # Makes it optional
+        default=leopard_prompt,
+        help="The prompt to run with the agent",
+    )
+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="InferenceClientModel",
+        help="The model type to use (e.g., InferenceClientModel, OpenAIServerModel, LiteLLMModel, TransformersModel)",
+    )
+    parser.add_argument(
+        "--model-id",
+        type=str,
+        default="Qwen/Qwen2.5-Coder-32B-Instruct",
+        help="The model ID to use for the specified model type",
+    )
+    parser.add_argument(
+        "--imports",
+        nargs="*",  # accepts zero or more arguments
+        default=[],
+        help="Space-separated list of imports to authorize (e.g., 'numpy pandas')",
+    )
+    parser.add_argument(
+        "--tools",
+        nargs="*",
+        default=["web_search"],
+        help="Space-separated list of tools that the agent can use (e.g., 'tool1 tool2 tool3')",
+    )
+    parser.add_argument(
+        "--verbosity-level",
+        type=int,
+        default=1,
+        help="The verbosity level, as an int in [0, 1, 2].",
+    )
+    group = parser.add_argument_group("api options", "Options for API-based model types")
+    group.add_argument(
+        "--provider",
+        type=str,
+        default=None,
+        help="The inference provider to use for the model",
+    )
+    group.add_argument(
+        "--api-base",
+        type=str,
+        help="The base URL for the model",
+    )
+    group.add_argument(
+        "--api-key",
+        type=str,
+        help="The API key for the model",
+    )
+    return parser.parse_args()
+
+
+def load_model(
+    model_type: str,
+    model_id: str,
+    api_base: str | None = None,
+    api_key: str | None = None,
+    provider: str | None = None,
+) -> Model:
+    if model_type == "OpenAIServerModel":
+        return OpenAIServerModel(
+            api_key=api_key or os.getenv("FIREWORKS_API_KEY"),
+            api_base=api_base or "https://api.fireworks.ai/inference/v1",
+            model_id=model_id,
+        )
+    elif model_type == "LiteLLMModel":
+        return LiteLLMModel(
+            model_id=model_id,
+            api_key=api_key,
+            api_base=api_base,
+        )
+    elif model_type == "TransformersModel":
+        return TransformersModel(model_id=model_id, device_map="auto")
+    elif model_type == "InferenceClientModel":
+        return InferenceClientModel(
+            model_id=model_id,
+            token=api_key or os.getenv("HF_API_KEY"),
+            provider=provider,
+        )
+    else:
+        raise ValueError(f"Unsupported model type: {model_type}")
+
+
+def run_smolagent(
+    prompt: str,
+    tools: list[str],
+    model_type: str,
+    model_id: str,
+    api_base: str | None = None,
+    api_key: str | None = None,
+    imports: list[str] | None = None,
+    provider: str | None = None,
+) -> None:
+    load_dotenv()
+
+    model = load_model(model_type, model_id, api_base=api_base, api_key=api_key, provider=provider)
+
+    available_tools = []
+    for tool_name in tools:
+        if "/" in tool_name:
+            available_tools.append(Tool.from_space(tool_name))
+        else:
+            if tool_name in TOOL_MAPPING:
+                available_tools.append(TOOL_MAPPING[tool_name]())
+            else:
+                raise ValueError(f"Tool {tool_name} is not recognized either as a default tool or a Space.")
+
+    print(f"Running agent with these tools: {tools}")
+    agent = CodeAgent(tools=available_tools, model=model, additional_authorized_imports=imports)
+
+    agent.run(prompt)
+
+
+def main() -> None:
+    args = parse_arguments()
+    run_smolagent(
+        args.prompt,
+        args.tools,
+        args.model_type,
+        args.model_id,
+        provider=args.provider,
+        api_base=args.api_base,
+        api_key=args.api_key,
+        imports=args.imports,
+    )
+
+
+if __name__ == "__main__":
+    main()
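For reference, since the module guards `main()` behind `__name__ == "__main__"`, this CLI can be exercised with `python -m smolagents.cli "your prompt" --tools web_search --imports numpy`, or programmatically through `run_smolagent`. A small sketch using only the defaults shown above (the `InferenceClientModel` branch reads `HF_API_KEY` from the environment when no `api_key` is passed):

```python
from smolagents.cli import run_smolagent

# Equivalent to `python -m smolagents.cli` with no arguments,
# since the positional prompt defaults to leopard_prompt.
run_smolagent(
    prompt="How many seconds would it take for a leopard at full speed to run through Pont des Arts?",
    tools=["web_search"],
    model_type="InferenceClientModel",
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
)
```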
src/smolagents/default_tools.py
ADDED
@@ -0,0 +1,577 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any
+
+from .local_python_executor import (
+    BASE_BUILTIN_MODULES,
+    BASE_PYTHON_TOOLS,
+    evaluate_python_code,
+)
+from .tools import PipelineTool, Tool
+
+
+@dataclass
+class PreTool:
+    name: str
+    inputs: dict[str, str]
+    output_type: type
+    task: str
+    description: str
+    repo_id: str
+
+
+class PythonInterpreterTool(Tool):
+    name = "python_interpreter"
+    description = "This is a tool that evaluates python code. It can be used to perform calculations."
+    inputs = {
+        "code": {
+            "type": "string",
+            "description": "The python code to run in interpreter",
+        }
+    }
+    output_type = "string"
+
+    def __init__(self, *args, authorized_imports=None, **kwargs):
+        if authorized_imports is None:
+            self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
+        else:
+            self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(authorized_imports))
+        self.inputs = {
+            "code": {
+                "type": "string",
+                "description": (
+                    "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
+                    f"else you will get an error. This code can only import the following python libraries: {self.authorized_imports}."
+                ),
+            }
+        }
+        self.base_python_tools = BASE_PYTHON_TOOLS
+        self.python_evaluator = evaluate_python_code
+        super().__init__(*args, **kwargs)
+
+    def forward(self, code: str) -> str:
+        state = {}
+        output = str(
+            self.python_evaluator(
+                code,
+                state=state,
+                static_tools=self.base_python_tools,
+                authorized_imports=self.authorized_imports,
+            )[0]  # The second element is boolean is_final_answer
+        )
+        return f"Stdout:\n{str(state['_print_outputs'])}\nOutput: {output}"
+
+
+class FinalAnswerTool(Tool):
+    name = "final_answer"
+    description = "Provides a final answer to the given problem."
+    inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
+    output_type = "any"
+
+    def forward(self, answer: Any) -> Any:
+        return answer
+
+
+class UserInputTool(Tool):
+    name = "user_input"
+    description = "Asks for user's input on a specific question"
+    inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
+    output_type = "string"
+
+    def forward(self, question):
+        user_input = input(f"{question} => Type your answer here:")
+        return user_input
+
+
+class DuckDuckGoSearchTool(Tool):
+    name = "web_search"
+    description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."""
+    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
+    output_type = "string"
+
+    def __init__(self, max_results=10, **kwargs):
+        super().__init__()
+        self.max_results = max_results
+        try:
+            from duckduckgo_search import DDGS
+        except ImportError as e:
+            raise ImportError(
+                "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
+            ) from e
+        self.ddgs = DDGS(**kwargs)
+
+    def forward(self, query: str) -> str:
+        results = self.ddgs.text(query, max_results=self.max_results)
+        if len(results) == 0:
+            raise Exception("No results found! Try a less restrictive/shorter query.")
+        postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
+        return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
+
+
+class GoogleSearchTool(Tool):
+    name = "web_search"
+    description = """Performs a google web search for your query then returns a string of the top search results."""
+    inputs = {
+        "query": {"type": "string", "description": "The search query to perform."},
+        "filter_year": {
+            "type": "integer",
+            "description": "Optionally restrict results to a certain year",
+            "nullable": True,
+        },
+    }
+    output_type = "string"
+
+    def __init__(self, provider: str = "serpapi"):
+        super().__init__()
+        import os
+
+        self.provider = provider
+        if provider == "serpapi":
+            self.organic_key = "organic_results"
+            api_key_env_name = "SERPAPI_API_KEY"
+        else:
+            self.organic_key = "organic"
+            api_key_env_name = "SERPER_API_KEY"
+        self.api_key = os.getenv(api_key_env_name)
+        if self.api_key is None:
+            raise ValueError(f"Missing API key. Make sure you have '{api_key_env_name}' in your env variables.")
+
+    def forward(self, query: str, filter_year: int | None = None) -> str:
+        import requests
+
+        if self.provider == "serpapi":
+            params = {
+                "q": query,
+                "api_key": self.api_key,
+                "engine": "google",
+                "google_domain": "google.com",
+            }
+            base_url = "https://serpapi.com/search.json"
+        else:
+            params = {
+                "q": query,
+                "api_key": self.api_key,
+            }
+            base_url = "https://google.serper.dev/search"
+        if filter_year is not None:
+            params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
+
+        response = requests.get(base_url, params=params)
+
+        if response.status_code == 200:
+            results = response.json()
+        else:
+            raise ValueError(response.json())
+
+        if self.organic_key not in results.keys():
+            if filter_year is not None:
+                raise Exception(
+                    f"No results found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
+                )
+            else:
+                raise Exception(f"No results found for query: '{query}'. Use a less restrictive query.")
+        if len(results[self.organic_key]) == 0:
+            year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
+            return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
+
+        web_snippets = []
+        if self.organic_key in results:
+            for idx, page in enumerate(results[self.organic_key]):
+                date_published = ""
+                if "date" in page:
+                    date_published = "\nDate published: " + page["date"]
+
+                source = ""
+                if "source" in page:
+                    source = "\nSource: " + page["source"]
+
+                snippet = ""
+                if "snippet" in page:
+                    snippet = "\n" + page["snippet"]
+
+                redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
+                web_snippets.append(redacted_version)
+
+        return "## Search Results\n" + "\n\n".join(web_snippets)
+
+
+class ApiWebSearchTool(Tool):
+    name = "web_search"
+    description = "Performs a web search for a query and returns a string of the top search results formatted as markdown with titles, URLs, and descriptions."
+    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
+    output_type = "string"
+
+    def __init__(
+        self, endpoint: str = "", api_key: str = "", api_key_name: str = "", headers: dict = None, params: dict = None
+    ):
+        import os
+
+        super().__init__()
+        self.endpoint = endpoint or "https://api.search.brave.com/res/v1/web/search"
+        self.api_key = api_key or os.getenv(api_key_name)
+        self.headers = headers or {"X-Subscription-Token": self.api_key}
+        self.params = params or {"count": 10}
+
+    def forward(self, query: str) -> str:
+        import requests
+
+        params = {**self.params, "q": query}
+        response = requests.get(self.endpoint, headers=self.headers, params=params)
+        response.raise_for_status()
+        data = response.json()
+        results = self.extract_results(data)
+        return self.format_markdown(results)
+
+    def extract_results(self, data: dict) -> list:
+        results = []
+        for result in data.get("web", {}).get("results", []):
+            results.append(
+                {"title": result["title"], "url": result["url"], "description": result.get("description", "")}
+            )
+        return results
+
+    def format_markdown(self, results: list) -> str:
+        if not results:
+            return "No results found."
+        return "## Search Results\n\n" + "\n\n".join(
+            [
+                f"{idx}. [{result['title']}]({result['url']})\n{result['description']}"
+                for idx, result in enumerate(results, start=1)
+            ]
+        )
+
+
+class WebSearchTool(Tool):
+    name = "web_search"
+    description = "Performs a web search for a query and returns a string of the top search results formatted as markdown with titles, links, and descriptions."
+    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
+    output_type = "string"
+
+    def __init__(self, max_results: int = 10, engine: str = "duckduckgo"):
+        super().__init__()
+        self.max_results = max_results
+        self.engine = engine
+
+    def forward(self, query: str) -> str:
+        results = self.search(query)
+        if len(results) == 0:
+            raise Exception("No results found! Try a less restrictive/shorter query.")
+        return self.parse_results(results)
+
+    def search(self, query: str) -> list:
+        if self.engine == "duckduckgo":
+            return self.search_duckduckgo(query)
+        elif self.engine == "bing":
+            return self.search_bing(query)
+        else:
+            raise ValueError(f"Unsupported engine: {self.engine}")
+
+    def parse_results(self, results: list) -> str:
+        return "## Search Results\n\n" + "\n\n".join(
+            [f"[{result['title']}]({result['link']})\n{result['description']}" for result in results]
+        )
+
+    def search_duckduckgo(self, query: str) -> list:
+        import requests
+
+        response = requests.get(
+            "https://lite.duckduckgo.com/lite/",
+            params={"q": query},
+            headers={"User-Agent": "Mozilla/5.0"},
+        )
+        response.raise_for_status()
+        parser = self._create_duckduckgo_parser()
+        parser.feed(response.text)
+        return parser.results
+
+    def _create_duckduckgo_parser(self):
+        from html.parser import HTMLParser
+
+        class SimpleResultParser(HTMLParser):
+            def __init__(self):
+                super().__init__()
+                self.results = []
+                self.current = {}
+                self.capture_title = False
+                self.capture_description = False
+                self.capture_link = False
+
+            def handle_starttag(self, tag, attrs):
+                attrs = dict(attrs)
+                if tag == "a" and attrs.get("class") == "result-link":
+                    self.capture_title = True
+                elif tag == "td" and attrs.get("class") == "result-snippet":
+                    self.capture_description = True
+                elif tag == "span" and attrs.get("class") == "link-text":
+                    self.capture_link = True
+
+            def handle_endtag(self, tag):
+                if tag == "a" and self.capture_title:
+                    self.capture_title = False
+                elif tag == "td" and self.capture_description:
+                    self.capture_description = False
+                elif tag == "span" and self.capture_link:
+                    self.capture_link = False
+                elif tag == "tr":
+                    # Store current result if all parts are present
+                    if {"title", "description", "link"} <= self.current.keys():
+                        self.current["description"] = " ".join(self.current["description"])
+                        self.results.append(self.current)
+                        self.current = {}
+
+            def handle_data(self, data):
+                if self.capture_title:
+                    self.current["title"] = data.strip()
+                elif self.capture_description:
+                    self.current.setdefault("description", [])
+                    self.current["description"].append(data.strip())
+                elif self.capture_link:
+                    self.current["link"] = "https://" + data.strip()
+
+        return SimpleResultParser()
+
+    def search_bing(self, query: str) -> list:
+        import xml.etree.ElementTree as ET
+
+        import requests
+
+        response = requests.get(
+            "https://www.bing.com/search",
+            params={"q": query, "format": "rss"},
+        )
+        response.raise_for_status()
+        root = ET.fromstring(response.text)
+        items = root.findall(".//item")
+        results = [
+            {
+                "title": item.findtext("title"),
+                "link": item.findtext("link"),
+                "description": item.findtext("description"),
+            }
+            for item in items[: self.max_results]
+        ]
+        return results
+
+
+class VisitWebpageTool(Tool):
+    name = "visit_webpage"
+    description = (
+        "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
+    )
+    inputs = {
+        "url": {
+            "type": "string",
+            "description": "The url of the webpage to visit.",
+        }
+    }
+    output_type = "string"
+
+    def __init__(self, max_output_length: int = 40000):
+        super().__init__()
+        self.max_output_length = max_output_length
+
+    def _truncate_content(self, content: str, max_length: int) -> str:
+        if len(content) <= max_length:
+            return content
+        return (
+            content[: max_length // 2]
+            + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
+            + content[-max_length // 2 :]
+        )
+
+    def forward(self, url: str) -> str:
+        try:
+            import re
+
+            import requests
+            from markdownify import markdownify
+            from requests.exceptions import RequestException
+        except ImportError as e:
+            raise ImportError(
+                "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
+            ) from e
+        try:
+            # Send a GET request to the URL with a 20-second timeout
+            response = requests.get(url, timeout=20)
+            response.raise_for_status()  # Raise an exception for bad status codes
+
+            # Convert the HTML content to Markdown
+            markdown_content = markdownify(response.text).strip()
+
+            # Remove multiple line breaks
+            markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
+
+            return self._truncate_content(markdown_content, self.max_output_length)
+
+        except requests.exceptions.Timeout:
+            return "The request timed out. Please try again later or check the URL."
+        except RequestException as e:
+            return f"Error fetching the webpage: {str(e)}"
+        except Exception as e:
+            return f"An unexpected error occurred: {str(e)}"
+
+
+class WikipediaSearchTool(Tool):
+    """
+    WikipediaSearchTool searches Wikipedia and returns a summary or full text of the given topic, along with the page URL.
+
+    Attributes:
+        user_agent (str): A custom user-agent string to identify the project. This is required as per Wikipedia API policies, read more here: http://github.com/martin-majlis/Wikipedia-API/blob/master/README.rst
+        language (str): The language in which to retrieve Wikipedia articles.
+            http://meta.wikimedia.org/wiki/List_of_Wikipedias
+        content_type (str): Defines the content to fetch. Can be "summary" for a short summary or "text" for the full article.
+        extract_format (str): Defines the output format. Can be `"WIKI"` or `"HTML"`.
+
+    Example:
+        >>> from smolagents import CodeAgent, InferenceClientModel, WikipediaSearchTool
+        >>> agent = CodeAgent(
+        >>>     tools=[
+        >>>         WikipediaSearchTool(
+        >>>             user_agent="MyResearchBot ([email protected])",
+        >>>             language="en",
+        >>>             content_type="summary",  # or "text"
+        >>>             extract_format="WIKI",
+        >>>         )
+        >>>     ],
+        >>>     model=InferenceClientModel(),
+        >>> )
+        >>> agent.run("Python_(programming_language)")
+    """
+
+    name = "wikipedia_search"
+    description = "Searches Wikipedia and returns a summary or full text of the given topic, along with the page URL."
+    inputs = {
+        "query": {
+            "type": "string",
+            "description": "The topic to search on Wikipedia.",
+        }
+    }
+    output_type = "string"
+
+    def __init__(
+        self,
+        user_agent: str = "Smolagents ([email protected])",
+        language: str = "en",
+        content_type: str = "text",
+        extract_format: str = "WIKI",
+    ):
+        super().__init__()
+        try:
+            import wikipediaapi
+        except ImportError as e:
+            raise ImportError(
+                "You must install `wikipedia-api` to run this tool: for instance run `pip install wikipedia-api`"
+            ) from e
+        if not user_agent:
+            raise ValueError("User-agent is required. Provide a meaningful identifier for your project.")
+
+        self.user_agent = user_agent
+        self.language = language
+        self.content_type = content_type
+
+        # Map string format to wikipediaapi.ExtractFormat
+        extract_format_map = {
+            "WIKI": wikipediaapi.ExtractFormat.WIKI,
+            "HTML": wikipediaapi.ExtractFormat.HTML,
+        }
+
+        if extract_format not in extract_format_map:
+            raise ValueError("Invalid extract_format. Choose between 'WIKI' or 'HTML'.")
+
+        self.extract_format = extract_format_map[extract_format]
+
+        self.wiki = wikipediaapi.Wikipedia(
+            user_agent=self.user_agent, language=self.language, extract_format=self.extract_format
+        )
+
+    def forward(self, query: str) -> str:
+        try:
+            page = self.wiki.page(query)
+
+            if not page.exists():
+                return f"No Wikipedia page found for '{query}'. Try a different query."
+
+            title = page.title
+            url = page.fullurl
+
+            if self.content_type == "summary":
+                text = page.summary
+            elif self.content_type == "text":
+                text = page.text
+            else:
+                return "⚠️ Invalid `content_type`. Use either 'summary' or 'text'."
+
+            return f"✅ **Wikipedia Page:** {title}\n\n**Content:** {text}\n\n🔗 **Read more:** {url}"
+
+        except Exception as e:
+            return f"Error fetching Wikipedia summary: {str(e)}"
+
+
+class SpeechToTextTool(PipelineTool):
+    default_checkpoint = "openai/whisper-large-v3-turbo"
+    description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
+    name = "transcriber"
+    inputs = {
+        "audio": {
+            "type": "audio",
+            "description": "The audio to transcribe. Can be a local path, an url, or a tensor.",
+        }
+    }
+    output_type = "string"
+
+    def __new__(cls, *args, **kwargs):
+        from transformers.models.whisper import WhisperForConditionalGeneration, WhisperProcessor
+
+        cls.pre_processor_class = WhisperProcessor
+        cls.model_class = WhisperForConditionalGeneration
+        return super().__new__(cls)
+
+    def encode(self, audio):
+        from .agent_types import AgentAudio
+
+        audio = AgentAudio(audio).to_raw()
+        return self.pre_processor(audio, return_tensors="pt")
+
+    def forward(self, inputs):
+        return self.model.generate(inputs["input_features"])
+
+    def decode(self, outputs):
+        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
+
+
+TOOL_MAPPING = {
+    tool_class.name: tool_class
+    for tool_class in [
+        PythonInterpreterTool,
+        DuckDuckGoSearchTool,
+        VisitWebpageTool,
+    ]
+}
+
+__all__ = [
+    "ApiWebSearchTool",
+    "PythonInterpreterTool",
+    "FinalAnswerTool",
+    "UserInputTool",
+    "WebSearchTool",
+    "DuckDuckGoSearchTool",
+    "GoogleSearchTool",
+    "VisitWebpageTool",
+    "WikipediaSearchTool",
+    "SpeechToTextTool",
+]
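All of the classes above share the same `Tool` contract (declared `name`/`inputs`/`output_type`, the work done in `forward`), which is what lets `TOOL_MAPPING` hand them to the CLI by name. A quick standalone sketch with two of them — `WebSearchTool` needs only `requests`, while `VisitWebpageTool` additionally needs `markdownify`, and results obviously depend on network access:

```python
from smolagents.default_tools import VisitWebpageTool, WebSearchTool

# Scraping-based search over DuckDuckGo's lite endpoint (or "bing" via RSS).
search = WebSearchTool(max_results=5, engine="duckduckgo")
print(search.forward("smolagents library"))  # "## Search Results" followed by markdown links

# Fetch a page and convert it to markdown, truncated around the middle if too long.
browser = VisitWebpageTool(max_output_length=10_000)
print(browser.forward("https://github.com/huggingface/smolagents"))
```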
src/smolagents/gradio_ui.py
ADDED
@@ -0,0 +1,508 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
import os
|
17 |
+
import re
|
18 |
+
import shutil
|
19 |
+
from pathlib import Path
|
20 |
+
from typing import Generator
|
21 |
+
|
22 |
+
from smolagents.agent_types import AgentAudio, AgentImage, AgentText
|
23 |
+
from smolagents.agents import MultiStepAgent, PlanningStep
|
24 |
+
from smolagents.memory import ActionStep, FinalAnswerStep
|
25 |
+
from smolagents.models import ChatMessageStreamDelta, MessageRole, agglomerate_stream_deltas
|
26 |
+
from smolagents.utils import _is_package_available
|
27 |
+
|
28 |
+
|
29 |
+
def get_step_footnote_content(step_log: ActionStep | PlanningStep, step_name: str) -> str:
|
30 |
+
"""Get a footnote string for a step log with duration and token information"""
|
31 |
+
step_footnote = f"**{step_name}**"
|
32 |
+
if step_log.token_usage is not None:
|
33 |
+
step_footnote += f" | Input tokens: {step_log.token_usage.input_tokens:,} | Output tokens: {step_log.token_usage.output_tokens:,}"
|
34 |
+
step_footnote += f" | Duration: {round(float(step_log.timing.duration), 2)}s" if step_log.timing.duration else ""
|
35 |
+
step_footnote_content = f"""<span style="color: #bbbbc2; font-size: 12px;">{step_footnote}</span> """
|
36 |
+
return step_footnote_content
|
37 |
+
|
38 |
+
|
39 |
+
def _clean_model_output(model_output: str) -> str:
|
40 |
+
"""
|
41 |
+
Clean up model output by removing trailing tags and extra backticks.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
model_output (`str`): Raw model output.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
`str`: Cleaned model output.
|
48 |
+
"""
|
49 |
+
if not model_output:
|
50 |
+
return ""
|
51 |
+
model_output = model_output.strip()
|
52 |
+
# Remove any trailing <end_code> and extra backticks, handling multiple possible formats
|
53 |
+
model_output = re.sub(r"```\s*<end_code>", "```", model_output) # handles ```<end_code>
|
54 |
+
model_output = re.sub(r"<end_code>\s*```", "```", model_output) # handles <end_code>```
|
55 |
+
model_output = re.sub(r"```\s*\n\s*<end_code>", "```", model_output) # handles ```\n<end_code>
|
56 |
+
return model_output.strip()
|
57 |
+
|
58 |
+
|
59 |
+
def _format_code_content(content: str) -> str:
|
60 |
+
"""
|
61 |
+
Format code content as Python code block if it's not already formatted.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
content (`str`): Code content to format.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
`str`: Code content formatted as a Python code block.
|
68 |
+
"""
|
69 |
+
content = content.strip()
|
70 |
+
# Remove existing code blocks and end_code tags
|
71 |
+
content = re.sub(r"```.*?\n", "", content)
|
72 |
+
content = re.sub(r"\s*<end_code>\s*", "", content)
|
73 |
+
content = content.strip()
|
74 |
+
# Add Python code block formatting if not already present
|
75 |
+
if not content.startswith("```python"):
|
76 |
+
content = f"```python\n{content}\n```"
|
77 |
+
return content
|
78 |
+
|
79 |
+
|
80 |
+
def _process_action_step(step_log: ActionStep, skip_model_outputs: bool = False) -> Generator:
|
81 |
+
"""
|
82 |
+
Process an [`ActionStep`] and yield appropriate Gradio ChatMessage objects.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
step_log ([`ActionStep`]): ActionStep to process.
|
86 |
+
skip_model_outputs (`bool`): Whether to skip model outputs.
|
87 |
+
|
88 |
+
Yields:
|
89 |
+
`gradio.ChatMessage`: Gradio ChatMessages representing the action step.
|
90 |
+
"""
|
91 |
+
import gradio as gr
|
92 |
+
|
93 |
+
# Output the step number
|
94 |
+
step_number = f"Step {step_log.step_number}"
|
95 |
+
if not skip_model_outputs:
|
96 |
+
yield gr.ChatMessage(role=MessageRole.ASSISTANT, content=f"**{step_number}**", metadata={"status": "done"})
|
97 |
+
|
98 |
+
# First yield the thought/reasoning from the LLM
|
99 |
+
if not skip_model_outputs and getattr(step_log, "model_output", ""):
|
100 |
+
model_output = _clean_model_output(step_log.model_output)
|
101 |
+
yield gr.ChatMessage(role=MessageRole.ASSISTANT, content=model_output, metadata={"status": "done"})
|
102 |
+
|
103 |
+
# For tool calls, create a parent message
|
104 |
+
if getattr(step_log, "tool_calls", []):
|
105 |
+
first_tool_call = step_log.tool_calls[0]
|
106 |
+
used_code = first_tool_call.name == "python_interpreter"
|
107 |
+
|
108 |
+
# Process arguments based on type
|
109 |
+
args = first_tool_call.arguments
|
110 |
+
if isinstance(args, dict):
|
111 |
+
content = str(args.get("answer", str(args)))
|
112 |
+
else:
|
113 |
+
content = str(args).strip()
|
114 |
+
|
115 |
+
# Format code content if needed
|
116 |
+
if used_code:
|
117 |
+
content = _format_code_content(content)
|
118 |
+
|
119 |
+
# Create the tool call message
|
120 |
+
parent_message_tool = gr.ChatMessage(
|
121 |
+
role=MessageRole.ASSISTANT,
|
122 |
+
content=content,
|
123 |
+
metadata={
|
124 |
+
"title": f"🛠️ Used tool {first_tool_call.name}",
|
125 |
+
"status": "done",
|
126 |
+
},
|
127 |
+
)
|
128 |
+
yield parent_message_tool
|
129 |
+
|
130 |
+
# Display execution logs if they exist
|
131 |
+
if getattr(step_log, "observations", "") and step_log.observations.strip():
|
132 |
+
log_content = step_log.observations.strip()
|
133 |
+
if log_content:
|
134 |
+
log_content = re.sub(r"^Execution logs:\s*", "", log_content)
|
135 |
+
yield gr.ChatMessage(
|
136 |
+
role=MessageRole.ASSISTANT,
|
137 |
+
content=f"```bash\n{log_content}\n",
|
138 |
+
metadata={"title": "📝 Execution Logs", "status": "done"},
|
139 |
+
)
|
140 |
+
|
141 |
+
# Display any images in observations
|
142 |
+
if getattr(step_log, "observations_images", []):
|
143 |
+
for image in step_log.observations_images:
|
144 |
+
path_image = AgentImage(image).to_string()
|
145 |
+
yield gr.ChatMessage(
|
146 |
+
role=MessageRole.ASSISTANT,
|
147 |
+
content={"path": path_image, "mime_type": f"image/{path_image.split('.')[-1]}"},
|
148 |
+
metadata={"title": "🖼️ Output Image", "status": "done"},
|
149 |
+
)
|
150 |
+
|
151 |
+
# Handle errors
|
152 |
+
if getattr(step_log, "error", None):
|
153 |
+
yield gr.ChatMessage(
|
154 |
+
role=MessageRole.ASSISTANT, content=str(step_log.error), metadata={"title": "💥 Error", "status": "done"}
|
155 |
+
)
|
156 |
+
|
157 |
+
# Add step footnote and separator
|
158 |
+
yield gr.ChatMessage(
|
159 |
+
role=MessageRole.ASSISTANT,
|
160 |
+
content=get_step_footnote_content(step_log, step_number),
|
161 |
+
metadata={"status": "done"},
|
162 |
+
)
|
163 |
+
yield gr.ChatMessage(role=MessageRole.ASSISTANT, content="-----", metadata={"status": "done"})
|
164 |
+
|
165 |
+
|
166 |
+
def _process_planning_step(step_log: PlanningStep, skip_model_outputs: bool = False) -> Generator:
|
167 |
+
"""
|
168 |
+
Process a [`PlanningStep`] and yield appropriate gradio.ChatMessage objects.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
step_log ([`PlanningStep`]): PlanningStep to process.
|
172 |
+
|
173 |
+
Yields:
|
174 |
+
`gradio.ChatMessage`: Gradio ChatMessages representing the planning step.
|
175 |
+
"""
|
176 |
+
import gradio as gr
|
177 |
+
|
178 |
+
if not skip_model_outputs:
|
179 |
+
yield gr.ChatMessage(role=MessageRole.ASSISTANT, content="**Planning step**", metadata={"status": "done"})
|
180 |
+
yield gr.ChatMessage(role=MessageRole.ASSISTANT, content=step_log.plan, metadata={"status": "done"})
|
181 |
+
yield gr.ChatMessage(
|
182 |
+
role=MessageRole.ASSISTANT,
|
183 |
+
content=get_step_footnote_content(step_log, "Planning step"),
|
184 |
+
metadata={"status": "done"},
|
185 |
+
)
|
186 |
+
yield gr.ChatMessage(role=MessageRole.ASSISTANT, content="-----", metadata={"status": "done"})
|
187 |
+
|
188 |
+
|
189 |
+
def _process_final_answer_step(step_log: FinalAnswerStep) -> Generator:
|
190 |
+
"""
|
191 |
+
Process a [`FinalAnswerStep`] and yield appropriate gradio.ChatMessage objects.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
step_log ([`FinalAnswerStep`]): FinalAnswerStep to process.
|
195 |
+
|
196 |
+
Yields:
|
197 |
+
`gradio.ChatMessage`: Gradio ChatMessages representing the final answer.
|
198 |
+
"""
|
199 |
+
import gradio as gr
|
200 |
+
|
201 |
+
final_answer = step_log.output
|
202 |
+
if isinstance(final_answer, AgentText):
|
203 |
+
yield gr.ChatMessage(
|
204 |
+
role=MessageRole.ASSISTANT,
|
205 |
+
content=f"**Final answer:**\n{final_answer.to_string()}\n",
|
206 |
+
metadata={"status": "done"},
|
207 |
+
)
|
208 |
+
elif isinstance(final_answer, AgentImage):
|
209 |
+
yield gr.ChatMessage(
|
210 |
+
role=MessageRole.ASSISTANT,
|
211 |
+
content={"path": final_answer.to_string(), "mime_type": "image/png"},
|
212 |
+
metadata={"status": "done"},
|
213 |
+
)
|
214 |
+
elif isinstance(final_answer, AgentAudio):
|
215 |
+
yield gr.ChatMessage(
|
216 |
+
role=MessageRole.ASSISTANT,
|
217 |
+
content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
|
218 |
+
metadata={"status": "done"},
|
219 |
+
)
|
220 |
+
else:
|
221 |
+
yield gr.ChatMessage(
|
222 |
+
role=MessageRole.ASSISTANT, content=f"**Final answer:** {str(final_answer)}", metadata={"status": "done"}
|
223 |
+
)
|
224 |
+
|
225 |
+
|
226 |
+
def pull_messages_from_step(step_log: ActionStep | PlanningStep | FinalAnswerStep, skip_model_outputs: bool = False):
|
227 |
+
"""Extract Gradio ChatMessage objects from agent steps with proper nesting.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
step_log: The step log to display as gr.ChatMessage objects.
|
231 |
+
skip_model_outputs: If True, skip the model outputs when creating the gr.ChatMessage objects:
|
232 |
+
This is used for instance when streaming model outputs have already been displayed.
|
233 |
+
"""
|
234 |
+
if not _is_package_available("gradio"):
|
235 |
+
raise ModuleNotFoundError(
|
236 |
+
"Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
|
237 |
+
)
|
238 |
+
if isinstance(step_log, ActionStep):
|
239 |
+
yield from _process_action_step(step_log, skip_model_outputs)
|
240 |
+
elif isinstance(step_log, PlanningStep):
|
241 |
+
yield from _process_planning_step(step_log, skip_model_outputs)
|
242 |
+
elif isinstance(step_log, FinalAnswerStep):
|
243 |
+
yield from _process_final_answer_step(step_log)
|
244 |
+
else:
|
245 |
+
raise ValueError(f"Unsupported step type: {type(step_log)}")
|
246 |
+
|
247 |
+
|
248 |
+
def stream_to_gradio(
|
249 |
+
agent,
|
250 |
+
task: str,
|
251 |
+
task_images: list | None = None,
|
252 |
+
reset_agent_memory: bool = False,
|
253 |
+
additional_args: dict | None = None,
|
254 |
+
) -> Generator:
|
255 |
+
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
256 |
+
|
257 |
+
if not _is_package_available("gradio"):
|
258 |
+
raise ModuleNotFoundError(
|
259 |
+
"Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
|
260 |
+
)
|
261 |
+
accumulated_events: list[ChatMessageStreamDelta] = []
|
262 |
+
for event in agent.run(
|
263 |
+
task, images=task_images, stream=True, reset=reset_agent_memory, additional_args=additional_args
|
264 |
+
):
|
265 |
+
if isinstance(event, ActionStep | PlanningStep | FinalAnswerStep):
|
266 |
+
for message in pull_messages_from_step(
|
267 |
+
event,
|
268 |
+
# If we're streaming model outputs, no need to display them twice
|
269 |
+
skip_model_outputs=getattr(agent, "stream_outputs", False),
|
270 |
+
):
|
271 |
+
yield message
|
272 |
+
accumulated_events = []
|
273 |
+
elif isinstance(event, ChatMessageStreamDelta):
|
274 |
+
accumulated_events.append(event)
|
275 |
+
text = agglomerate_stream_deltas(accumulated_events).render_as_markdown()
|
276 |
+
yield text
|
277 |
+
|
278 |
+
|
279 |
+
class GradioUI:
|
280 |
+
"""
|
281 |
+
Gradio interface for interacting with a [`MultiStepAgent`].
|
282 |
+
|
283 |
+
This class provides a web interface to interact with the agent in real-time, allowing users to submit prompts, upload files, and receive responses in a chat-like format.
|
284 |
+
It can reset the agent's memory at the start of each interaction if desired.
|
285 |
+
It supports file uploads, which are saved to a specified folder.
|
286 |
+
It uses the [`gradio.Chatbot`] component to display the conversation history.
|
287 |
+
This class requires the `gradio` extra to be installed: `smolagents[gradio]`.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
agent ([`MultiStepAgent`]): The agent to interact with.
|
291 |
+
file_upload_folder (`str`, *optional*): The folder where uploaded files will be saved.
|
292 |
+
If not provided, file uploads are disabled.
|
293 |
+
reset_agent_memory (`bool`, *optional*, defaults to `False`): Whether to reset the agent's memory at the start of each interaction.
|
294 |
+
If `True`, the agent will not remember previous interactions.
|
295 |
+
|
296 |
+
Raises:
|
297 |
+
ModuleNotFoundError: If the `gradio` extra is not installed.
|
298 |
+
|
299 |
+
Example:
|
300 |
+
```python
|
301 |
+
from smolagents import CodeAgent, GradioUI, InferenceClientModel
|
302 |
+
|
303 |
+
model = InferenceClientModel(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct")
|
304 |
+
agent = CodeAgent(tools=[], model=model)
|
305 |
+
gradio_ui = GradioUI(agent, file_upload_folder="uploads", reset_agent_memory=True)
|
306 |
+
gradio_ui.launch()
|
307 |
+
```
|
308 |
+
"""
|
309 |
+
|
310 |
+
def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None, reset_agent_memory: bool = False):
|
311 |
+
if not _is_package_available("gradio"):
|
312 |
+
raise ModuleNotFoundError(
|
313 |
+
"Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
|
314 |
+
)
|
315 |
+
self.agent = agent
|
316 |
+
self.file_upload_folder = Path(file_upload_folder) if file_upload_folder is not None else None
|
317 |
+
self.reset_agent_memory = reset_agent_memory
|
318 |
+
self.name = getattr(agent, "name") or "Agent interface"
|
319 |
+
self.description = getattr(agent, "description", None)
|
320 |
+
if self.file_upload_folder is not None:
|
321 |
+
if not self.file_upload_folder.exists():
|
322 |
+
self.file_upload_folder.mkdir(parents=True, exist_ok=True)
|
323 |
+
|
324 |
+
def interact_with_agent(self, prompt, messages, session_state):
|
325 |
+
import gradio as gr
|
326 |
+
|
327 |
+
# Get the agent type from the template agent
|
328 |
+
if "agent" not in session_state:
|
329 |
+
session_state["agent"] = self.agent
|
330 |
+
|
331 |
+
try:
|
332 |
+
messages.append(gr.ChatMessage(role="user", content=prompt, metadata={"status": "done"}))
|
333 |
+
yield messages
|
334 |
+
|
335 |
+
for msg in stream_to_gradio(
|
336 |
+
session_state["agent"], task=prompt, reset_agent_memory=self.reset_agent_memory
|
337 |
+
):
|
338 |
+
if isinstance(msg, gr.ChatMessage):
|
339 |
+
messages[-1].metadata["status"] = "done"
|
340 |
+
messages.append(msg)
|
341 |
+
elif isinstance(msg, str): # Then it's only a completion delta
|
342 |
+
msg = msg.replace("<", r"\<").replace(">", r"\>") # HTML tags seem to break Gradio Chatbot
|
343 |
+
if messages[-1].metadata["status"] == "pending":
|
344 |
+
messages[-1].content = msg
|
345 |
+
else:
|
346 |
+
messages.append(
|
347 |
+
gr.ChatMessage(role=MessageRole.ASSISTANT, content=msg, metadata={"status": "pending"})
|
348 |
+
)
|
349 |
+
yield messages
|
350 |
+
|
351 |
+
yield messages
|
352 |
+
except Exception as e:
|
353 |
+
yield messages
|
354 |
+
raise gr.Error(f"Error in interaction: {str(e)}")
|
355 |
+
|
356 |
+
def upload_file(self, file, file_uploads_log, allowed_file_types=None):
|
357 |
+
"""
|
358 |
+
Upload a file and add it to the list of uploaded files in the session state.
|
359 |
+
|
360 |
+
The file is saved to the `self.file_upload_folder` folder.
|
361 |
+
If the file type is not allowed, it returns a message indicating the disallowed file type.
|
362 |
+
|
363 |
+
Args:
|
364 |
+
file (`gradio.File`): The uploaded file.
|
365 |
+
file_uploads_log (`list`): A list to log uploaded files.
|
366 |
+
allowed_file_types (`list`, *optional*): List of allowed file extensions. Defaults to [".pdf", ".docx", ".txt"].
|
367 |
+
"""
|
368 |
+
import gradio as gr
|
369 |
+
|
370 |
+
if file is None:
|
371 |
+
return gr.Textbox(value="No file uploaded", visible=True), file_uploads_log
|
372 |
+
|
373 |
+
if allowed_file_types is None:
|
374 |
+
allowed_file_types = [".pdf", ".docx", ".txt"]
|
375 |
+
|
376 |
+
file_ext = os.path.splitext(file.name)[1].lower()
|
377 |
+
if file_ext not in allowed_file_types:
|
378 |
+
return gr.Textbox("File type disallowed", visible=True), file_uploads_log
|
379 |
+
|
380 |
+
# Sanitize file name
|
381 |
+
original_name = os.path.basename(file.name)
|
382 |
+
sanitized_name = re.sub(
|
383 |
+
r"[^\w\-.]", "_", original_name
|
384 |
+
) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
|
385 |
+
|
386 |
+
# Save the uploaded file to the specified folder
|
387 |
+
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
|
388 |
+
shutil.copy(file.name, file_path)
|
389 |
+
|
390 |
+
return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]
|
391 |
+
|
392 |
+
def log_user_message(self, text_input, file_uploads_log):
|
393 |
+
import gradio as gr
|
394 |
+
|
395 |
+
return (
|
396 |
+
text_input
|
397 |
+
+ (
|
398 |
+
f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
|
399 |
+
if len(file_uploads_log) > 0
|
400 |
+
else ""
|
401 |
+
),
|
402 |
+
"",
|
403 |
+
gr.Button(interactive=False),
|
404 |
+
)
|
405 |
+
|
406 |
+
def launch(self, share: bool = True, **kwargs):
|
407 |
+
"""
|
408 |
+
Launch the Gradio app with the agent interface.
|
409 |
+
|
410 |
+
Args:
|
411 |
+
share (`bool`, defaults to `True`): Whether to share the app publicly.
|
412 |
+
**kwargs: Additional keyword arguments to pass to the Gradio launch method.
|
413 |
+
"""
|
414 |
+
self.create_app().launch(debug=True, share=share, **kwargs)
|
415 |
+
|
416 |
+
    def create_app(self):
        import gradio as gr

        with gr.Blocks(theme="ocean", fill_height=True) as demo:
            # Add session state to store session-specific data
            session_state = gr.State({})
            stored_messages = gr.State([])
            file_uploads_log = gr.State([])

            with gr.Sidebar():
                gr.Markdown(
                    f"# {self.name.replace('_', ' ').capitalize()}"
                    "\n> This web ui allows you to interact with a `smolagents` agent that can use tools and execute steps to complete tasks."
                    + (f"\n\n**Agent description:**\n{self.description}" if self.description else "")
                )

                with gr.Group():
                    gr.Markdown("**Your request**", container=True)
                    text_input = gr.Textbox(
                        lines=3,
                        label="Chat Message",
                        container=False,
                        placeholder="Enter your prompt here and press Shift+Enter or press the button",
                    )
                    submit_btn = gr.Button("Submit", variant="primary")

                # If an upload folder is provided, enable the upload feature
                if self.file_upload_folder is not None:
                    upload_file = gr.File(label="Upload a file")
                    upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
                    upload_file.change(
                        self.upload_file,
                        [upload_file, file_uploads_log],
                        [upload_status, file_uploads_log],
                    )

                gr.HTML(
                    "<br><br><h4><center>Powered by <a target='_blank' href='https://github.com/huggingface/smolagents'><b>smolagents</b></a></center></h4>"
                )

            # Main chat interface
            chatbot = gr.Chatbot(
                label="Agent",
                type="messages",
                avatar_images=(
                    None,
                    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
                ),
                resizeable=True,
                scale=1,
                latex_delimiters=[
                    {"left": r"$$", "right": r"$$", "display": True},
                    {"left": r"$", "right": r"$", "display": False},
                    {"left": r"\[", "right": r"\]", "display": True},
                    {"left": r"\(", "right": r"\)", "display": False},
                ],
            )

            # Set up event handlers
            text_input.submit(
                self.log_user_message,
                [text_input, file_uploads_log],
                [stored_messages, text_input, submit_btn],
            ).then(self.interact_with_agent, [stored_messages, chatbot, session_state], [chatbot]).then(
                lambda: (
                    gr.Textbox(
                        interactive=True, placeholder="Enter your prompt here and press Shift+Enter or the button"
                    ),
                    gr.Button(interactive=True),
                ),
                None,
                [text_input, submit_btn],
            )

            submit_btn.click(
                self.log_user_message,
                [text_input, file_uploads_log],
                [stored_messages, text_input, submit_btn],
            ).then(self.interact_with_agent, [stored_messages, chatbot, session_state], [chatbot]).then(
                lambda: (
                    gr.Textbox(
                        interactive=True, placeholder="Enter your prompt here and press Shift+Enter or the button"
                    ),
                    gr.Button(interactive=True),
                ),
                None,
                [text_input, submit_btn],
            )

        return demo


__all__ = ["stream_to_gradio", "GradioUI"]
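# A minimal sketch of serving an agent through this UI (illustrative: assumes the
# CodeAgent and InferenceClientModel classes exported by this package; any model works):
#
#     from smolagents import CodeAgent, GradioUI, InferenceClientModel
#
#     agent = CodeAgent(tools=[], model=InferenceClientModel())
#     GradioUI(agent, file_upload_folder="./uploads").launch(share=False)
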
src/smolagents/local_python_executor.py
ADDED
@@ -0,0 +1,1611 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import builtins
import difflib
import inspect
import logging
import math
import re
from collections.abc import Callable, Mapping
from functools import wraps
from importlib import import_module
from types import BuiltinFunctionType, FunctionType, ModuleType
from typing import Any

from .tools import Tool
from .utils import BASE_BUILTIN_MODULES, truncate_content


logger = logging.getLogger(__name__)

class InterpreterError(ValueError):
    """
    An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
    operations.
    """

    pass


ERRORS = {
    name: getattr(builtins, name)
    for name in dir(builtins)
    if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
}

DEFAULT_MAX_LEN_OUTPUT = 50000
MAX_OPERATIONS = 10000000
MAX_WHILE_ITERATIONS = 1000000

def custom_print(*args):
    return None


def nodunder_getattr(obj, name, default=None):
    if name.startswith("__") and name.endswith("__"):
        raise InterpreterError(f"Forbidden access to dunder attribute: {name}")
    return getattr(obj, name, default)

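# A minimal sketch of the dunder guard above (illustrative, not part of the module):
#
#     nodunder_getattr([], "append")     # -> the bound list.append method
#     nodunder_getattr([], "__class__")  # raises InterpreterError: Forbidden access to dunder attribute: __class__
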
BASE_PYTHON_TOOLS = {
    "print": custom_print,
    "isinstance": isinstance,
    "range": range,
    "float": float,
    "int": int,
    "bool": bool,
    "str": str,
    "set": set,
    "list": list,
    "dict": dict,
    "tuple": tuple,
    "round": round,
    "ceil": math.ceil,
    "floor": math.floor,
    "log": math.log,
    "exp": math.exp,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,
    "atan2": math.atan2,
    "degrees": math.degrees,
    "radians": math.radians,
    "pow": pow,
    "sqrt": math.sqrt,
    "len": len,
    "sum": sum,
    "max": max,
    "min": min,
    "abs": abs,
    "enumerate": enumerate,
    "zip": zip,
    "reversed": reversed,
    "sorted": sorted,
    "all": all,
    "any": any,
    "map": map,
    "filter": filter,
    "ord": ord,
    "chr": chr,
    "next": next,
    "iter": iter,
    "divmod": divmod,
    "callable": callable,
    "getattr": nodunder_getattr,
    "hasattr": hasattr,
    "setattr": setattr,
    "issubclass": issubclass,
    "type": type,
    "complex": complex,
}

# Non-exhaustive list of dangerous modules that should not be imported
DANGEROUS_MODULES = [
    "builtins",
    "io",
    "multiprocessing",
    "os",
    "pathlib",
    "pty",
    "shutil",
    "socket",
    "subprocess",
    "sys",
]

DANGEROUS_FUNCTIONS = [
    "builtins.compile",
    "builtins.eval",
    "builtins.exec",
    "builtins.globals",
    "builtins.locals",
    "builtins.__import__",
    "os.popen",
    "os.system",
    "posix.system",
]

def check_safer_result(result: Any, static_tools: dict[str, Callable] = None, authorized_imports: list[str] = None):
    """
    Checks that a result is safe to return, according to the authorized imports and static tools.

    Args:
        result (Any): The result to check.
        static_tools (dict[str, Callable]): Dictionary of static tools.
        authorized_imports (list[str]): List of authorized imports.

    Raises:
        InterpreterError: If the result is not safe.
    """
    if isinstance(result, ModuleType):
        if not check_import_authorized(result.__name__, authorized_imports):
            raise InterpreterError(f"Forbidden access to module: {result.__name__}")
    elif isinstance(result, dict) and result.get("__spec__"):
        if not check_import_authorized(result["__name__"], authorized_imports):
            raise InterpreterError(f"Forbidden access to module: {result['__name__']}")
    elif isinstance(result, (FunctionType, BuiltinFunctionType)):
        for qualified_function_name in DANGEROUS_FUNCTIONS:
            module_name, function_name = qualified_function_name.rsplit(".", 1)
            if (
                (static_tools is None or function_name not in static_tools)
                and result.__name__ == function_name
                and result.__module__ == module_name
            ):
                raise InterpreterError(f"Forbidden access to function: {function_name}")

def safer_eval(func: Callable):
    """
    Decorator to enhance the security of an evaluation function by checking its return value.

    Args:
        func (Callable): Evaluation function to be made safer.

    Returns:
        Callable: Safer evaluation function with return value check.
    """

    @wraps(func)
    def _check_return(
        expression,
        state,
        static_tools,
        custom_tools,
        authorized_imports=BASE_BUILTIN_MODULES,
    ):
        result = func(expression, state, static_tools, custom_tools, authorized_imports=authorized_imports)
        check_safer_result(result, static_tools, authorized_imports)
        return result

    return _check_return

def safer_func(
    func: Callable,
    static_tools: dict[str, Callable] = BASE_PYTHON_TOOLS,
    authorized_imports: list[str] = BASE_BUILTIN_MODULES,
):
    """
    Decorator to enhance the security of a function call by checking its return value.

    Args:
        func (Callable): Function to be made safer.
        static_tools (dict[str, Callable]): Dictionary of static tools.
        authorized_imports (list[str]): List of authorized imports.

    Returns:
        Callable: Safer function with return value check.
    """
    # If the function is a type, return it directly without wrapping
    if isinstance(func, type):
        return func

    @wraps(func)
    def _check_return(*args, **kwargs):
        result = func(*args, **kwargs)
        check_safer_result(result, static_tools, authorized_imports)
        return result

    return _check_return

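# A minimal sketch of what the wrapper catches (illustrative, not part of the module;
# behavior for os.system relies on the "os.system"/"posix.system" entries above):
#
#     import os
#     safe_getattr = safer_func(getattr)
#     safe_getattr(os, "getcwd")  # fine: not listed in DANGEROUS_FUNCTIONS
#     safe_getattr(os, "system")  # raises InterpreterError: Forbidden access to function: system
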
class PrintContainer:
    def __init__(self):
        self.value = ""

    def append(self, text):
        self.value += text
        return self

    def __iadd__(self, other):
        """Implements the += operator"""
        self.value += str(other)
        return self

    def __str__(self):
        """String representation"""
        return self.value

    def __repr__(self):
        """Representation for debugging"""
        return f"PrintContainer({self.value})"

    def __len__(self):
        """Implements len() function support"""
        return len(self.value)

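# PrintContainer accumulates interpreter print output as a string; a quick sketch:
#
#     out = PrintContainer()
#     out += "hello "
#     out += 42           # non-strings are coerced via str()
#     str(out), len(out)  # -> ("hello 42", 8)
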
class BreakException(Exception):
    pass


class ContinueException(Exception):
    pass


class ReturnException(Exception):
    def __init__(self, value):
        self.value = value


def get_iterable(obj):
    if isinstance(obj, list):
        return obj
    elif hasattr(obj, "__iter__"):
        return list(obj)
    else:
        raise InterpreterError("Object is not iterable")

def fix_final_answer_code(code: str) -> str:
    """
    Sometimes an LLM can try to assign a variable to final_answer, which would break the final_answer() tool.
    This function fixes this behaviour by replacing variable assignments to final_answer with final_answer_variable,
    while preserving function calls to final_answer().
    """
    # First, find if there's a direct assignment to final_answer
    # Use word boundary and negative lookbehind to ensure it's not an object attribute
    assignment_pattern = r"(?<!\.)(?<!\w)\bfinal_answer\s*="
    if "final_answer(" not in code or not re.search(assignment_pattern, code):
        # If the final_answer tool is not called in this blob, then doing the replacement is hazardous because it could distort the model's memory for next steps.
        # Let's not modify the code and let the subsequent assignment error happen.
        return code

    # Pattern for replacing variable assignments
    # Looks for 'final_answer' followed by '=' with optional whitespace
    # Negative lookbehind ensures we don't match object attributes
    assignment_regex = r"(?<!\.)(?<!\w)(\bfinal_answer)(\s*=)"
    code = re.sub(assignment_regex, r"final_answer_variable\2", code)

    # Pattern for replacing variable usage but not function calls
    # Negative lookahead (?!\s*\() ensures we don't match function calls
    # Negative lookbehind (?<!\.|\w) ensures we don't match object methods or other variables
    variable_regex = r"(?<!\.)(?<!\w)(\bfinal_answer\b)(?!\s*\()"
    code = re.sub(variable_regex, "final_answer_variable", code)
    return code

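# A before/after sketch of this rewrite (illustrative input):
#
#     fix_final_answer_code("final_answer = compute()\nfinal_answer(final_answer)")
#     # -> "final_answer_variable = compute()\nfinal_answer(final_answer_variable)"
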
def build_import_tree(authorized_imports: list[str]) -> dict[str, Any]:
    tree = {}
    for import_path in authorized_imports:
        parts = import_path.split(".")
        current = tree
        for part in parts:
            if part not in current:
                current[part] = {}
            current = current[part]
    return tree


def check_import_authorized(import_to_check: str, authorized_imports: list[str]) -> bool:
    current_node = build_import_tree(authorized_imports)
    for part in import_to_check.split("."):
        if "*" in current_node:
            return True
        if part not in current_node:
            return False
        current_node = current_node[part]
    return True

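# A minimal sketch of the wildcard matching (illustrative values):
#
#     authorized = ["math", "numpy.*"]
#     check_import_authorized("math", authorized)          # True
#     check_import_authorized("numpy.linalg", authorized)  # True: "*" authorizes any submodule
#     check_import_authorized("os", authorized)            # False
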
def evaluate_attribute(
    expression: ast.Attribute,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    if expression.attr.startswith("__") and expression.attr.endswith("__"):
        raise InterpreterError(f"Forbidden access to dunder attribute: {expression.attr}")
    value = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
    return getattr(value, expression.attr)

def evaluate_unaryop(
    expression: ast.UnaryOp,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    operand = evaluate_ast(expression.operand, state, static_tools, custom_tools, authorized_imports)
    if isinstance(expression.op, ast.USub):
        return -operand
    elif isinstance(expression.op, ast.UAdd):
        return operand
    elif isinstance(expression.op, ast.Not):
        return not operand
    elif isinstance(expression.op, ast.Invert):
        return ~operand
    else:
        raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")

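# These evaluators operate on raw AST nodes; a quick sketch (assumes evaluate_ast,
# defined later in this module, to resolve the constant operand):
#
#     node = ast.parse("-5", mode="eval").body  # an ast.UnaryOp
#     evaluate_unaryop(node, {}, BASE_PYTHON_TOOLS, {}, BASE_BUILTIN_MODULES)  # -> -5
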
def evaluate_lambda(
    lambda_expression: ast.Lambda,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Callable:
    args = [arg.arg for arg in lambda_expression.args.args]

    def lambda_func(*values: Any) -> Any:
        new_state = state.copy()
        for arg, value in zip(args, values):
            new_state[arg] = value
        return evaluate_ast(
            lambda_expression.body,
            new_state,
            static_tools,
            custom_tools,
            authorized_imports,
        )

    return lambda_func

def evaluate_while(
    while_loop: ast.While,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    iterations = 0
    while evaluate_ast(while_loop.test, state, static_tools, custom_tools, authorized_imports):
        for node in while_loop.body:
            try:
                evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
            except BreakException:
                return None
            except ContinueException:
                break
        iterations += 1
        if iterations > MAX_WHILE_ITERATIONS:
            raise InterpreterError(f"Maximum number of {MAX_WHILE_ITERATIONS} iterations in While loop exceeded")
    return None

def create_function(
    func_def: ast.FunctionDef,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Callable:
    source_code = ast.unparse(func_def)

    def new_func(*args: Any, **kwargs: Any) -> Any:
        func_state = state.copy()
        arg_names = [arg.arg for arg in func_def.args.args]
        default_values = [
            evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) for d in func_def.args.defaults
        ]

        # Apply default values
        defaults = dict(zip(arg_names[-len(default_values) :], default_values))

        # Set positional arguments
        for name, value in zip(arg_names, args):
            func_state[name] = value

        # Set keyword arguments
        for name, value in kwargs.items():
            func_state[name] = value

        # Handle variable arguments
        if func_def.args.vararg:
            vararg_name = func_def.args.vararg.arg
            func_state[vararg_name] = args

        if func_def.args.kwarg:
            kwarg_name = func_def.args.kwarg.arg
            func_state[kwarg_name] = kwargs

        # Set default values for arguments that were not provided
        for name, value in defaults.items():
            if name not in func_state:
                func_state[name] = value

        # Update function state with self and __class__
        if func_def.args.args and func_def.args.args[0].arg == "self":
            if args:
                func_state["self"] = args[0]
                func_state["__class__"] = args[0].__class__

        result = None
        try:
            for stmt in func_def.body:
                result = evaluate_ast(stmt, func_state, static_tools, custom_tools, authorized_imports)
        except ReturnException as e:
            result = e.value

        if func_def.name == "__init__":
            return None

        return result

    # Store original AST, source code, and name
    new_func.__ast__ = func_def
    new_func.__source__ = source_code
    new_func.__name__ = func_def.name

    return new_func

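# A minimal sketch of interpreting a function definition (assumes evaluate_ast,
# defined later in this module, handles the body and Return nodes):
#
#     tree = ast.parse("def add(a, b=1):\n    return a + b")
#     add = create_function(tree.body[0], {}, BASE_PYTHON_TOOLS, {}, BASE_BUILTIN_MODULES)
#     add(2), add(2, 3)  # -> (3, 5)
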
def evaluate_function_def(
    func_def: ast.FunctionDef,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Callable:
    custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools, authorized_imports)
    return custom_tools[func_def.name]

def evaluate_class_def(
    class_def: ast.ClassDef,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> type:
    class_name = class_def.name
    bases = [evaluate_ast(base, state, static_tools, custom_tools, authorized_imports) for base in class_def.bases]
    class_dict = {}

    for stmt in class_def.body:
        if isinstance(stmt, ast.FunctionDef):
            class_dict[stmt.name] = evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
        elif isinstance(stmt, ast.AnnAssign):
            if stmt.value:
                value = evaluate_ast(stmt.value, state, static_tools, custom_tools, authorized_imports)
            target = stmt.target
            # Handle target types for annotation
            if isinstance(target, ast.Name):
                # Simple variable annotation like "x: int"
                annotation = evaluate_ast(stmt.annotation, state, static_tools, custom_tools, authorized_imports)
                class_dict.setdefault("__annotations__", {})[target.id] = annotation
                # Assign value if provided
                if stmt.value:
                    class_dict[target.id] = value
            elif isinstance(target, ast.Attribute):
                # Attribute annotation like "obj.attr: int"
                obj = evaluate_ast(target.value, class_dict, static_tools, custom_tools, authorized_imports)
                # If there's a value assignment, set the attribute
                if stmt.value:
                    setattr(obj, target.attr, value)
            elif isinstance(target, ast.Subscript):
                # Subscript annotation like "dict[key]: int"
                container = evaluate_ast(target.value, class_dict, static_tools, custom_tools, authorized_imports)
                index = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
                # If there's a value assignment, set the item
                if stmt.value:
                    container[index] = value
            else:
                raise InterpreterError(f"Unsupported AnnAssign target in class body: {type(target).__name__}")
        elif isinstance(stmt, ast.Assign):
            value = evaluate_ast(stmt.value, state, static_tools, custom_tools, authorized_imports)
            for target in stmt.targets:
                if isinstance(target, ast.Name):
                    class_dict[target.id] = value
                elif isinstance(target, ast.Attribute):
                    obj = evaluate_ast(target.value, class_dict, static_tools, custom_tools, authorized_imports)
                    setattr(obj, target.attr, value)
        elif isinstance(stmt, ast.Pass):
            pass
        elif (
            isinstance(stmt, ast.Expr)
            and stmt == class_def.body[0]
            and isinstance(stmt.value, ast.Constant)
            and isinstance(stmt.value.value, str)
        ):
            # Check if it is a docstring: first statement in class body which is a string literal expression
            class_dict["__doc__"] = stmt.value.value
        else:
            raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")

    new_class = type(class_name, tuple(bases), class_dict)
    state[class_name] = new_class
    return new_class

def evaluate_annassign(
    annassign: ast.AnnAssign,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    # If there's a value to assign, evaluate it
    if annassign.value:
        value = evaluate_ast(annassign.value, state, static_tools, custom_tools, authorized_imports)
        # Set the value for the target
        set_value(annassign.target, value, state, static_tools, custom_tools, authorized_imports)
        return value
    # For declarations without values (x: int), just return None
    return None

def evaluate_augassign(
    expression: ast.AugAssign,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    def get_current_value(target: ast.AST) -> Any:
        if isinstance(target, ast.Name):
            return state.get(target.id, 0)
        elif isinstance(target, ast.Subscript):
            obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
            key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
            return obj[key]
        elif isinstance(target, ast.Attribute):
            obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
            return getattr(obj, target.attr)
        elif isinstance(target, ast.Tuple):
            return tuple(get_current_value(elt) for elt in target.elts)
        elif isinstance(target, ast.List):
            return [get_current_value(elt) for elt in target.elts]
        else:
            raise InterpreterError(f"AugAssign not supported for {type(target)} targets.")

    current_value = get_current_value(expression.target)
    value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)

    if isinstance(expression.op, ast.Add):
        if isinstance(current_value, list):
            if not isinstance(value_to_add, list):
                raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
            current_value += value_to_add
        else:
            current_value += value_to_add
    elif isinstance(expression.op, ast.Sub):
        current_value -= value_to_add
    elif isinstance(expression.op, ast.Mult):
        current_value *= value_to_add
    elif isinstance(expression.op, ast.Div):
        current_value /= value_to_add
    elif isinstance(expression.op, ast.Mod):
        current_value %= value_to_add
    elif isinstance(expression.op, ast.Pow):
        current_value **= value_to_add
    elif isinstance(expression.op, ast.FloorDiv):
        current_value //= value_to_add
    elif isinstance(expression.op, ast.BitAnd):
        current_value &= value_to_add
    elif isinstance(expression.op, ast.BitOr):
        current_value |= value_to_add
    elif isinstance(expression.op, ast.BitXor):
        current_value ^= value_to_add
    elif isinstance(expression.op, ast.LShift):
        current_value <<= value_to_add
    elif isinstance(expression.op, ast.RShift):
        current_value >>= value_to_add
    else:
        raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")

    # Update the state: current_value has been updated in-place
    set_value(
        expression.target,
        current_value,
        state,
        static_tools,
        custom_tools,
        authorized_imports,
    )

    return current_value

def evaluate_boolop(
    node: ast.BoolOp,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    # Determine which value should trigger short-circuit based on operation type:
    # - 'and' returns the first falsy value encountered (or the last value if all are truthy)
    # - 'or' returns the first truthy value encountered (or the last value if all are falsy)
    is_short_circuit_value = (lambda x: not x) if isinstance(node.op, ast.And) else (lambda x: bool(x))
    for value in node.values:
        result = evaluate_ast(value, state, static_tools, custom_tools, authorized_imports)
        # Short-circuit: return immediately if the condition is met
        if is_short_circuit_value(result):
            return result
    # If no short-circuit occurred, return the last evaluated value
    return result

def evaluate_binop(
    binop: ast.BinOp,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    # Recursively evaluate the left and right operands
    left_val = evaluate_ast(binop.left, state, static_tools, custom_tools, authorized_imports)
    right_val = evaluate_ast(binop.right, state, static_tools, custom_tools, authorized_imports)

    # Determine the operation based on the type of the operator in the BinOp
    if isinstance(binop.op, ast.Add):
        return left_val + right_val
    elif isinstance(binop.op, ast.Sub):
        return left_val - right_val
    elif isinstance(binop.op, ast.Mult):
        return left_val * right_val
    elif isinstance(binop.op, ast.Div):
        return left_val / right_val
    elif isinstance(binop.op, ast.Mod):
        return left_val % right_val
    elif isinstance(binop.op, ast.Pow):
        return left_val**right_val
    elif isinstance(binop.op, ast.FloorDiv):
        return left_val // right_val
    elif isinstance(binop.op, ast.BitAnd):
        return left_val & right_val
    elif isinstance(binop.op, ast.BitOr):
        return left_val | right_val
    elif isinstance(binop.op, ast.BitXor):
        return left_val ^ right_val
    elif isinstance(binop.op, ast.LShift):
        return left_val << right_val
    elif isinstance(binop.op, ast.RShift):
        return left_val >> right_val
    else:
        raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")

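# A quick sketch of dispatching on the operator node (assumes evaluate_ast, defined
# later in this module, to resolve the constant operands):
#
#     node = ast.parse("2 ** 10", mode="eval").body  # an ast.BinOp with an ast.Pow op
#     evaluate_binop(node, {}, BASE_PYTHON_TOOLS, {}, BASE_BUILTIN_MODULES)  # -> 1024
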
def evaluate_assign(
    assign: ast.Assign,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    result = evaluate_ast(assign.value, state, static_tools, custom_tools, authorized_imports)
    if len(assign.targets) == 1:
        target = assign.targets[0]
        set_value(target, result, state, static_tools, custom_tools, authorized_imports)
    else:
        expanded_values = []
        for tgt in assign.targets:
            if isinstance(tgt, ast.Starred):
                expanded_values.extend(result)
            else:
                expanded_values.append(result)

        for tgt, val in zip(assign.targets, expanded_values):
            set_value(tgt, val, state, static_tools, custom_tools, authorized_imports)
    return result

def set_value(
    target: ast.AST,
    value: Any,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    if isinstance(target, ast.Name):
        if target.id in static_tools:
            raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
        state[target.id] = value
    elif isinstance(target, ast.Tuple):
        if not isinstance(value, tuple):
            if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
                value = tuple(value)
            else:
                raise InterpreterError("Cannot unpack non-tuple value")
        if len(target.elts) != len(value):
            raise InterpreterError("Cannot unpack tuple of wrong size")
        for i, elem in enumerate(target.elts):
            set_value(elem, value[i], state, static_tools, custom_tools, authorized_imports)
    elif isinstance(target, ast.Subscript):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
        obj[key] = value
    elif isinstance(target, ast.Attribute):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
        setattr(obj, target.attr, value)

def evaluate_call(
    call: ast.Call,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    if not isinstance(call.func, (ast.Call, ast.Lambda, ast.Attribute, ast.Name, ast.Subscript)):
        raise InterpreterError(f"This is not a correct function: {call.func}.")

    func, func_name = None, None

    if isinstance(call.func, ast.Call):
        func = evaluate_ast(call.func, state, static_tools, custom_tools, authorized_imports)
    elif isinstance(call.func, ast.Lambda):
        func = evaluate_ast(call.func, state, static_tools, custom_tools, authorized_imports)
    elif isinstance(call.func, ast.Attribute):
        obj = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
        func_name = call.func.attr
        if not hasattr(obj, func_name):
            raise InterpreterError(f"Object {obj} has no attribute {func_name}")
        func = getattr(obj, func_name)
    elif isinstance(call.func, ast.Name):
        func_name = call.func.id
        if func_name in state:
            func = state[func_name]
        elif func_name in static_tools:
            func = static_tools[func_name]
        elif func_name in custom_tools:
            func = custom_tools[func_name]
        elif func_name in ERRORS:
            func = ERRORS[func_name]
        else:
            raise InterpreterError(
                f"Forbidden function evaluation: '{call.func.id}' is not among the explicitly allowed tools or defined/imported in the preceding code"
            )
    elif isinstance(call.func, ast.Subscript):
        func = evaluate_ast(call.func, state, static_tools, custom_tools, authorized_imports)
        if not callable(func):
            raise InterpreterError(f"This is not a correct function: {call.func}.")
        func_name = None

    args = []
    for arg in call.args:
        if isinstance(arg, ast.Starred):
            args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools, authorized_imports))
        else:
            args.append(evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports))

    kwargs = {
        keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools, authorized_imports)
        for keyword in call.keywords
    }

    if func_name == "super":
        if not args:
            if "__class__" in state and "self" in state:
                return super(state["__class__"], state["self"])
            else:
                raise InterpreterError("super() needs at least one argument")
        cls = args[0]
        if not isinstance(cls, type):
            raise InterpreterError("super() argument 1 must be type")
        if len(args) == 1:
            return super(cls)
        elif len(args) == 2:
            instance = args[1]
            return super(cls, instance)
        else:
            raise InterpreterError("super() takes at most 2 arguments")
    elif func_name == "print":
        state["_print_outputs"] += " ".join(map(str, args)) + "\n"
        return None
    else:  # Assume it's a callable object
        if (inspect.getmodule(func) == builtins) and inspect.isbuiltin(func) and (func not in static_tools.values()):
            raise InterpreterError(
                f"Invoking a builtin function that has not been explicitly added as a tool is not allowed ({func_name})."
            )
        return func(*args, **kwargs)

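# A sketch of the call gating (illustrative; assumes evaluate_ast for argument nodes):
#
#     node = ast.parse("len([1, 2, 3])", mode="eval").body  # an ast.Call
#     evaluate_call(node, {}, BASE_PYTHON_TOOLS, {}, BASE_BUILTIN_MODULES)  # -> 3
#     # whereas a builtin absent from BASE_PYTHON_TOOLS, e.g. id(...), raises InterpreterError
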
def evaluate_subscript(
    subscript: ast.Subscript,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    index = evaluate_ast(subscript.slice, state, static_tools, custom_tools, authorized_imports)
    value = evaluate_ast(subscript.value, state, static_tools, custom_tools, authorized_imports)
    try:
        return value[index]
    except (KeyError, IndexError, TypeError) as e:
        error_message = f"Could not index {value} with '{index}': {type(e).__name__}: {e}"
        if isinstance(index, str) and isinstance(value, Mapping):
            close_matches = difflib.get_close_matches(index, list(value.keys()))
            if len(close_matches) > 0:
                error_message += f". Maybe you meant one of these indexes instead: {str(close_matches)}"
        raise InterpreterError(error_message) from e

def evaluate_name(
    name: ast.Name,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    if name.id in state:
        return state[name.id]
    elif name.id in static_tools:
        return safer_func(static_tools[name.id], static_tools=static_tools, authorized_imports=authorized_imports)
    elif name.id in custom_tools:
        return custom_tools[name.id]
    elif name.id in ERRORS:
        return ERRORS[name.id]
    close_matches = difflib.get_close_matches(name.id, list(state.keys()))
    if len(close_matches) > 0:
        return state[close_matches[0]]
    raise InterpreterError(f"The variable `{name.id}` is not defined.")

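# Note the fuzzy fallback above: a near-miss name resolves to its closest match in
# state via difflib before erroring out (illustrative sketch):
#
#     node = ast.parse("my_vaiable", mode="eval").body  # an ast.Name with a typo
#     evaluate_name(node, {"my_variable": 42}, BASE_PYTHON_TOOLS, {}, BASE_BUILTIN_MODULES)  # -> 42
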
def evaluate_condition(
    condition: ast.Compare,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> bool | object:
    result = True
    left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
    for i, (op, comparator) in enumerate(zip(condition.ops, condition.comparators)):
        op = type(op)
        right = evaluate_ast(comparator, state, static_tools, custom_tools, authorized_imports)
        if op == ast.Eq:
            current_result = left == right
        elif op == ast.NotEq:
            current_result = left != right
        elif op == ast.Lt:
            current_result = left < right
        elif op == ast.LtE:
            current_result = left <= right
        elif op == ast.Gt:
            current_result = left > right
        elif op == ast.GtE:
            current_result = left >= right
        elif op == ast.Is:
            current_result = left is right
        elif op == ast.IsNot:
            current_result = left is not right
        elif op == ast.In:
            current_result = left in right
        elif op == ast.NotIn:
            current_result = left not in right
        else:
            raise InterpreterError(f"Unsupported comparison operator: {op}")

        if current_result is False:
            return False
        result = current_result if i == 0 else (result and current_result)
        left = right
    return result

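# Chained comparisons short-circuit left to right, as in CPython (sketch; the constant
# operands are resolved by evaluate_ast, defined later in this module):
#
#     node = ast.parse("1 < 2 < 3", mode="eval").body  # an ast.Compare with two ops
#     evaluate_condition(node, {}, BASE_PYTHON_TOOLS, {}, BASE_BUILTIN_MODULES)  # -> True
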
def evaluate_if(
    if_statement: ast.If,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    result = None
    test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools, authorized_imports)
    if test_result:
        for line in if_statement.body:
            line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
            if line_result is not None:
                result = line_result
    else:
        for line in if_statement.orelse:
            line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
            if line_result is not None:
                result = line_result
    return result

def evaluate_for(
    for_loop: ast.For,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    result = None
    iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools, authorized_imports)
    for counter in iterator:
        set_value(
            for_loop.target,
            counter,
            state,
            static_tools,
            custom_tools,
            authorized_imports,
        )
        for node in for_loop.body:
            try:
                line_result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
                if line_result is not None:
                    result = line_result
            except BreakException:
                break
            except ContinueException:
                continue
        else:
            continue
        break
    return result

def evaluate_listcomp(
    listcomp: ast.ListComp,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> list[Any]:
    def inner_evaluate(generators: list[ast.comprehension], index: int, current_state: dict[str, Any]) -> list[Any]:
        if index >= len(generators):
            return [
                evaluate_ast(
                    listcomp.elt,
                    current_state,
                    static_tools,
                    custom_tools,
                    authorized_imports,
                )
            ]
        generator = generators[index]
        iter_value = evaluate_ast(
            generator.iter,
            current_state,
            static_tools,
            custom_tools,
            authorized_imports,
        )
        result = []
        for value in iter_value:
            new_state = current_state.copy()
            if isinstance(generator.target, ast.Tuple):
                for idx, elem in enumerate(generator.target.elts):
                    new_state[elem.id] = value[idx]
            else:
                new_state[generator.target.id] = value
            if all(
                evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
                for if_clause in generator.ifs
            ):
                result.extend(inner_evaluate(generators, index + 1, new_state))
        return result

    return inner_evaluate(listcomp.generators, 0, state)

def evaluate_setcomp(
    setcomp: ast.SetComp,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> set[Any]:
    result = set()
    for gen in setcomp.generators:
        iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools, authorized_imports)
        for value in iter_value:
            new_state = state.copy()
            set_value(
                gen.target,
                value,
                new_state,
                static_tools,
                custom_tools,
                authorized_imports,
            )
            if all(
                evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
                for if_clause in gen.ifs
            ):
                element = evaluate_ast(
                    setcomp.elt,
                    new_state,
                    static_tools,
                    custom_tools,
                    authorized_imports,
                )
                result.add(element)
    return result

def evaluate_try(
    try_node: ast.Try,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    try:
        for stmt in try_node.body:
            evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
    except Exception as e:
        matched = False
        for handler in try_node.handlers:
            if handler.type is None or isinstance(
                e,
                evaluate_ast(handler.type, state, static_tools, custom_tools, authorized_imports),
            ):
                matched = True
                if handler.name:
                    state[handler.name] = e
                for stmt in handler.body:
                    evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
                break
        if not matched:
            raise e
    else:
        if try_node.orelse:
            for stmt in try_node.orelse:
                evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
    finally:
        if try_node.finalbody:
            for stmt in try_node.finalbody:
                evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)

def evaluate_raise(
    raise_node: ast.Raise,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    if raise_node.exc is not None:
        exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools, authorized_imports)
    else:
        exc = None
    if raise_node.cause is not None:
        cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools, authorized_imports)
    else:
        cause = None
    if exc is not None:
        if cause is not None:
            raise exc from cause
        else:
            raise exc
    else:
        raise InterpreterError("Re-raise is not supported without an active exception")

def evaluate_assert(
    assert_node: ast.Assert,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools, authorized_imports)
    if not test_result:
        if assert_node.msg:
            msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools, authorized_imports)
            raise AssertionError(msg)
        else:
            # Include the failing condition in the assertion message
            test_code = ast.unparse(assert_node.test)
            raise AssertionError(f"Assertion failed: {test_code}")

def evaluate_with(
    with_node: ast.With,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    contexts = []
    for item in with_node.items:
        context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools, authorized_imports)
        if item.optional_vars:
            state[item.optional_vars.id] = context_expr.__enter__()
            contexts.append(state[item.optional_vars.id])
        else:
            context_var = context_expr.__enter__()
            contexts.append(context_var)

    try:
        for stmt in with_node.body:
            evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
    except Exception as e:
        for context in reversed(contexts):
            context.__exit__(type(e), e, e.__traceback__)
        raise
    else:
        for context in reversed(contexts):
            context.__exit__(None, None, None)

def get_safe_module(raw_module, authorized_imports, visited=None):
    """Creates a safe copy of a module or returns the original if it's a function"""
    # If it's a function or non-module object, return it directly
    if not isinstance(raw_module, ModuleType):
        return raw_module

    # Handle circular references: Initialize visited set for the first call
    if visited is None:
        visited = set()

    module_id = id(raw_module)
    if module_id in visited:
        return raw_module  # Return original for circular refs

    visited.add(module_id)

    # Create new module for actual modules
    safe_module = ModuleType(raw_module.__name__)

    # Copy all attributes by reference, recursively checking modules
    for attr_name in dir(raw_module):
        try:
            attr_value = getattr(raw_module, attr_name)
        except (ImportError, AttributeError) as e:
            # lazy / dynamic loading module -> INFO log and skip
            logger.info(
                f"Skipping import error while copying {raw_module.__name__}.{attr_name}: {type(e).__name__} - {e}"
            )
            continue
        # Recursively process nested modules, passing visited set
        if isinstance(attr_value, ModuleType):
            attr_value = get_safe_module(attr_value, authorized_imports, visited=visited)

        setattr(safe_module, attr_name, attr_value)

    return safe_module

def evaluate_import(expression, state, authorized_imports):
    if isinstance(expression, ast.Import):
        for alias in expression.names:
            if check_import_authorized(alias.name, authorized_imports):
                raw_module = import_module(alias.name)
                state[alias.asname or alias.name] = get_safe_module(raw_module, authorized_imports)
            else:
                raise InterpreterError(
                    f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
                )
        return None
    elif isinstance(expression, ast.ImportFrom):
        if check_import_authorized(expression.module, authorized_imports):
            raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
            module = get_safe_module(raw_module, authorized_imports)
            if expression.names[0].name == "*":  # Handle "from module import *"
                if hasattr(module, "__all__"):  # If module has __all__, import only those names
                    for name in module.__all__:
                        state[name] = getattr(module, name)
                else:  # If no __all__, import all public names (those not starting with '_')
                    for name in dir(module):
                        if not name.startswith("_"):
                            state[name] = getattr(module, name)
            else:  # regular from imports
                for alias in expression.names:
                    if hasattr(module, alias.name):
                        state[alias.asname or alias.name] = getattr(module, alias.name)
                    else:
                        raise InterpreterError(f"Module {expression.module} has no attribute {alias.name}")
        else:
            raise InterpreterError(
                f"Import from {expression.module} is not allowed. Authorized imports are: {str(authorized_imports)}"
            )
        return None

def evaluate_dictcomp(
|
1239 |
+
dictcomp: ast.DictComp,
|
1240 |
+
state: dict[str, Any],
|
1241 |
+
static_tools: dict[str, Callable],
|
1242 |
+
custom_tools: dict[str, Callable],
|
1243 |
+
authorized_imports: list[str],
|
1244 |
+
) -> dict[Any, Any]:
|
1245 |
+
result = {}
|
1246 |
+
for gen in dictcomp.generators:
|
1247 |
+
iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools, authorized_imports)
|
1248 |
+
for value in iter_value:
|
1249 |
+
new_state = state.copy()
|
1250 |
+
set_value(
|
1251 |
+
gen.target,
|
1252 |
+
value,
|
1253 |
+
new_state,
|
1254 |
+
static_tools,
|
1255 |
+
custom_tools,
|
1256 |
+
authorized_imports,
|
1257 |
+
)
|
1258 |
+
if all(
|
1259 |
+
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
|
1260 |
+
for if_clause in gen.ifs
|
1261 |
+
):
|
1262 |
+
key = evaluate_ast(
|
1263 |
+
dictcomp.key,
|
1264 |
+
new_state,
|
1265 |
+
static_tools,
|
1266 |
+
custom_tools,
|
1267 |
+
authorized_imports,
|
1268 |
+
)
|
1269 |
+
val = evaluate_ast(
|
1270 |
+
dictcomp.value,
|
1271 |
+
new_state,
|
1272 |
+
static_tools,
|
1273 |
+
custom_tools,
|
1274 |
+
authorized_imports,
|
1275 |
+
)
|
1276 |
+
result[key] = val
|
1277 |
+
return result
|
1278 |
+
|
1279 |
+
|
1280 |
+
def evaluate_delete(
|
1281 |
+
delete_node: ast.Delete,
|
1282 |
+
state: dict[str, Any],
|
1283 |
+
static_tools: dict[str, Callable],
|
1284 |
+
custom_tools: dict[str, Callable],
|
1285 |
+
authorized_imports: list[str],
|
1286 |
+
) -> None:
|
1287 |
+
"""
|
1288 |
+
Evaluate a delete statement (del x, del x[y]).
|
1289 |
+
|
1290 |
+
Args:
|
1291 |
+
delete_node: The AST Delete node to evaluate
|
1292 |
+
state: The current state dictionary
|
1293 |
+
static_tools: Dictionary of static tools
|
1294 |
+
custom_tools: Dictionary of custom tools
|
1295 |
+
authorized_imports: List of authorized imports
|
1296 |
+
"""
|
1297 |
+
for target in delete_node.targets:
|
1298 |
+
if isinstance(target, ast.Name):
|
1299 |
+
# Handle simple variable deletion (del x)
|
1300 |
+
if target.id in state:
|
1301 |
+
del state[target.id]
|
1302 |
+
else:
|
1303 |
+
raise InterpreterError(f"Cannot delete name '{target.id}': name is not defined")
|
1304 |
+
elif isinstance(target, ast.Subscript):
|
1305 |
+
# Handle index/key deletion (del x[y])
|
1306 |
+
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
|
1307 |
+
index = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
|
1308 |
+
try:
|
1309 |
+
del obj[index]
|
1310 |
+
except (TypeError, KeyError, IndexError) as e:
|
1311 |
+
raise InterpreterError(f"Cannot delete index/key: {str(e)}")
|
1312 |
+
else:
|
1313 |
+
raise InterpreterError(f"Deletion of {type(target).__name__} targets is not supported")
|
1314 |
+
|
1315 |
+
|
1316 |
+
@safer_eval
|
1317 |
+
def evaluate_ast(
|
1318 |
+
expression: ast.AST,
|
1319 |
+
state: dict[str, Any],
|
1320 |
+
static_tools: dict[str, Callable],
|
1321 |
+
custom_tools: dict[str, Callable],
|
1322 |
+
authorized_imports: list[str] = BASE_BUILTIN_MODULES,
|
1323 |
+
):
|
1324 |
+
"""
|
1325 |
+
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
|
1326 |
+
set of functions.
|
1327 |
+
|
1328 |
+
This function will recurse through the nodes of the tree provided.
|
1329 |
+
|
1330 |
+
Args:
|
1331 |
+
expression (`ast.AST`):
|
1332 |
+
The code to evaluate, as an abstract syntax tree.
|
1333 |
+
state (`Dict[str, Any]`):
|
1334 |
+
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
|
1335 |
+
encounters assignments.
|
1336 |
+
static_tools (`Dict[str, Callable]`):
|
1337 |
+
Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
|
1338 |
+
custom_tools (`Dict[str, Callable]`):
|
1339 |
+
Functions that may be called during the evaluation. These custom_tools can be overwritten.
|
1340 |
+
authorized_imports (`List[str]`):
|
1341 |
+
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
|
1342 |
+
If it contains "*", it will authorize any import. Use this at your own risk!
|
1343 |
+
"""
|
1344 |
+
if state.setdefault("_operations_count", {"counter": 0})["counter"] >= MAX_OPERATIONS:
|
1345 |
+
raise InterpreterError(
|
1346 |
+
f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
|
1347 |
+
)
|
1348 |
+
state["_operations_count"]["counter"] += 1
|
1349 |
+
common_params = (state, static_tools, custom_tools, authorized_imports)
|
1350 |
+
if isinstance(expression, ast.Assign):
|
1351 |
+
# Assignment -> we evaluate the assignment which should update the state
|
1352 |
+
# We return the variable assigned as it may be used to determine the final result.
|
1353 |
+
return evaluate_assign(expression, *common_params)
|
1354 |
+
elif isinstance(expression, ast.AnnAssign):
|
1355 |
+
return evaluate_annassign(expression, *common_params)
|
1356 |
+
elif isinstance(expression, ast.AugAssign):
|
1357 |
+
return evaluate_augassign(expression, *common_params)
|
1358 |
+
elif isinstance(expression, ast.Call):
|
1359 |
+
# Function call -> we return the value of the function call
|
1360 |
+
return evaluate_call(expression, *common_params)
|
1361 |
+
elif isinstance(expression, ast.Constant):
|
1362 |
+
# Constant -> just return the value
|
1363 |
+
return expression.value
|
1364 |
+
elif isinstance(expression, ast.Tuple):
|
1365 |
+
return tuple((evaluate_ast(elt, *common_params) for elt in expression.elts))
|
1366 |
+
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
|
1367 |
+
return evaluate_listcomp(expression, *common_params)
|
1368 |
+
elif isinstance(expression, ast.DictComp):
|
1369 |
+
return evaluate_dictcomp(expression, *common_params)
|
1370 |
+
elif isinstance(expression, ast.SetComp):
|
1371 |
+
return evaluate_setcomp(expression, *common_params)
|
1372 |
+
elif isinstance(expression, ast.UnaryOp):
|
1373 |
+
return evaluate_unaryop(expression, *common_params)
|
1374 |
+
elif isinstance(expression, ast.Starred):
|
1375 |
+
return evaluate_ast(expression.value, *common_params)
|
1376 |
+
elif isinstance(expression, ast.BoolOp):
|
1377 |
+
# Boolean operation -> evaluate the operation
|
1378 |
+
return evaluate_boolop(expression, *common_params)
|
1379 |
+
elif isinstance(expression, ast.Break):
|
1380 |
+
raise BreakException()
|
1381 |
+
elif isinstance(expression, ast.Continue):
|
1382 |
+
raise ContinueException()
|
1383 |
+
elif isinstance(expression, ast.BinOp):
|
1384 |
+
# Binary operation -> execute operation
|
1385 |
+
return evaluate_binop(expression, *common_params)
|
1386 |
+
elif isinstance(expression, ast.Compare):
|
1387 |
+
# Comparison -> evaluate the comparison
|
1388 |
+
return evaluate_condition(expression, *common_params)
|
1389 |
+
elif isinstance(expression, ast.Lambda):
|
1390 |
+
return evaluate_lambda(expression, *common_params)
|
1391 |
+
elif isinstance(expression, ast.FunctionDef):
|
1392 |
+
return evaluate_function_def(expression, *common_params)
|
1393 |
+
elif isinstance(expression, ast.Dict):
|
1394 |
+
# Dict -> evaluate all keys and values
|
1395 |
+
keys = (evaluate_ast(k, *common_params) for k in expression.keys)
|
1396 |
+
values = (evaluate_ast(v, *common_params) for v in expression.values)
|
1397 |
+
return dict(zip(keys, values))
|
1398 |
+
elif isinstance(expression, ast.Expr):
|
1399 |
+
# Expression -> evaluate the content
|
1400 |
+
return evaluate_ast(expression.value, *common_params)
|
1401 |
+
elif isinstance(expression, ast.For):
|
1402 |
+
# For loop -> execute the loop
|
1403 |
+
return evaluate_for(expression, *common_params)
|
1404 |
+
elif isinstance(expression, ast.FormattedValue):
|
1405 |
+
# Formatted value (part of f-string) -> evaluate the content and format it
|
1406 |
+
value = evaluate_ast(expression.value, *common_params)
|
1407 |
+
# Early return if no format spec
|
1408 |
+
if not expression.format_spec:
|
1409 |
+
return value
|
1410 |
+
# Apply format specification
|
1411 |
+
format_spec = evaluate_ast(expression.format_spec, *common_params)
|
1412 |
+
return format(value, format_spec)
|
1413 |
+
elif isinstance(expression, ast.If):
|
1414 |
+
# If -> execute the right branch
|
1415 |
+
return evaluate_if(expression, *common_params)
|
1416 |
+
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
1417 |
+
return evaluate_ast(expression.value, *common_params)
|
1418 |
+
elif isinstance(expression, ast.JoinedStr):
|
1419 |
+
return "".join([str(evaluate_ast(v, *common_params)) for v in expression.values])
|
1420 |
+
elif isinstance(expression, ast.List):
|
1421 |
+
# List -> evaluate all elements
|
1422 |
+
return [evaluate_ast(elt, *common_params) for elt in expression.elts]
|
1423 |
+
elif isinstance(expression, ast.Name):
|
1424 |
+
# Name -> pick up the value in the state
|
1425 |
+
return evaluate_name(expression, *common_params)
|
1426 |
+
elif isinstance(expression, ast.Subscript):
|
1427 |
+
# Subscript -> return the value of the indexing
|
1428 |
+
return evaluate_subscript(expression, *common_params)
|
1429 |
+
elif isinstance(expression, ast.IfExp):
|
1430 |
+
test_val = evaluate_ast(expression.test, *common_params)
|
1431 |
+
if test_val:
|
1432 |
+
return evaluate_ast(expression.body, *common_params)
|
1433 |
+
else:
|
1434 |
+
return evaluate_ast(expression.orelse, *common_params)
|
1435 |
+
elif isinstance(expression, ast.Attribute):
|
1436 |
+
return evaluate_attribute(expression, *common_params)
|
1437 |
+
elif isinstance(expression, ast.Slice):
|
1438 |
+
return slice(
|
1439 |
+
evaluate_ast(expression.lower, *common_params) if expression.lower is not None else None,
|
1440 |
+
evaluate_ast(expression.upper, *common_params) if expression.upper is not None else None,
|
1441 |
+
evaluate_ast(expression.step, *common_params) if expression.step is not None else None,
|
1442 |
+
)
|
1443 |
+
elif isinstance(expression, ast.While):
|
1444 |
+
return evaluate_while(expression, *common_params)
|
1445 |
+
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
|
1446 |
+
return evaluate_import(expression, state, authorized_imports)
|
1447 |
+
elif isinstance(expression, ast.ClassDef):
|
1448 |
+
return evaluate_class_def(expression, *common_params)
|
1449 |
+
elif isinstance(expression, ast.Try):
|
1450 |
+
return evaluate_try(expression, *common_params)
|
1451 |
+
elif isinstance(expression, ast.Raise):
|
1452 |
+
return evaluate_raise(expression, *common_params)
|
1453 |
+
elif isinstance(expression, ast.Assert):
|
1454 |
+
return evaluate_assert(expression, *common_params)
|
1455 |
+
elif isinstance(expression, ast.With):
|
1456 |
+
return evaluate_with(expression, *common_params)
|
1457 |
+
elif isinstance(expression, ast.Set):
|
1458 |
+
return set((evaluate_ast(elt, *common_params) for elt in expression.elts))
|
1459 |
+
elif isinstance(expression, ast.Return):
|
1460 |
+
raise ReturnException(evaluate_ast(expression.value, *common_params) if expression.value else None)
|
1461 |
+
elif isinstance(expression, ast.Pass):
|
1462 |
+
return None
|
1463 |
+
elif isinstance(expression, ast.Delete):
|
1464 |
+
return evaluate_delete(expression, *common_params)
|
1465 |
+
else:
|
1466 |
+
# For now we refuse anything else. Let's add things as we need them.
|
1467 |
+
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
1468 |
+
|
1469 |
+
|
1470 |
+
class FinalAnswerException(Exception):
|
1471 |
+
def __init__(self, value):
|
1472 |
+
self.value = value
|
1473 |
+
|
1474 |
+
|
1475 |
+
def evaluate_python_code(
|
1476 |
+
code: str,
|
1477 |
+
static_tools: dict[str, Callable] | None = None,
|
1478 |
+
custom_tools: dict[str, Callable] | None = None,
|
1479 |
+
state: dict[str, Any] | None = None,
|
1480 |
+
authorized_imports: list[str] = BASE_BUILTIN_MODULES,
|
1481 |
+
max_print_outputs_length: int = DEFAULT_MAX_LEN_OUTPUT,
|
1482 |
+
):
|
1483 |
+
"""
|
1484 |
+
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
1485 |
+
of functions.
|
1486 |
+
|
1487 |
+
This function will recurse through the nodes of the tree provided.
|
1488 |
+
|
1489 |
+
Args:
|
1490 |
+
code (`str`):
|
1491 |
+
The code to evaluate.
|
1492 |
+
static_tools (`Dict[str, Callable]`):
|
1493 |
+
The functions that may be called during the evaluation. These can also be agents in a multiagent setting.
|
1494 |
+
These tools cannot be overwritten in the code: any assignment to their name will raise an error.
|
1495 |
+
custom_tools (`Dict[str, Callable]`):
|
1496 |
+
The functions that may be called during the evaluation.
|
1497 |
+
These tools can be overwritten in the code: any assignment to their name will overwrite them.
|
1498 |
+
state (`Dict[str, Any]`):
|
1499 |
+
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
|
1500 |
+
updated by this function to contain all variables as they are evaluated.
|
1501 |
+
The print outputs will be stored in the state under the key "_print_outputs".
|
1502 |
+
"""
|
1503 |
+
try:
|
1504 |
+
expression = ast.parse(code)
|
1505 |
+
except SyntaxError as e:
|
1506 |
+
raise InterpreterError(
|
1507 |
+
f"Code parsing failed on line {e.lineno} due to: {type(e).__name__}\n"
|
1508 |
+
f"{e.text}"
|
1509 |
+
f"{' ' * (e.offset or 0)}^\n"
|
1510 |
+
f"Error: {str(e)}"
|
1511 |
+
)
|
1512 |
+
|
1513 |
+
if state is None:
|
1514 |
+
state = {}
|
1515 |
+
static_tools = static_tools.copy() if static_tools is not None else {}
|
1516 |
+
custom_tools = custom_tools if custom_tools is not None else {}
|
1517 |
+
result = None
|
1518 |
+
state["_print_outputs"] = PrintContainer()
|
1519 |
+
state["_operations_count"] = {"counter": 0}
|
1520 |
+
|
1521 |
+
if "final_answer" in static_tools:
|
1522 |
+
previous_final_answer = static_tools["final_answer"]
|
1523 |
+
|
1524 |
+
def final_answer(*args, **kwargs): # Allow arbitrary arguments to be passed
|
1525 |
+
raise FinalAnswerException(previous_final_answer(*args, **kwargs))
|
1526 |
+
|
1527 |
+
static_tools["final_answer"] = final_answer
|
1528 |
+
|
1529 |
+
try:
|
1530 |
+
for node in expression.body:
|
1531 |
+
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
1532 |
+
state["_print_outputs"].value = truncate_content(
|
1533 |
+
str(state["_print_outputs"]), max_length=max_print_outputs_length
|
1534 |
+
)
|
1535 |
+
is_final_answer = False
|
1536 |
+
return result, is_final_answer
|
1537 |
+
except FinalAnswerException as e:
|
1538 |
+
state["_print_outputs"].value = truncate_content(
|
1539 |
+
str(state["_print_outputs"]), max_length=max_print_outputs_length
|
1540 |
+
)
|
1541 |
+
is_final_answer = True
|
1542 |
+
return e.value, is_final_answer
|
1543 |
+
except Exception as e:
|
1544 |
+
state["_print_outputs"].value = truncate_content(
|
1545 |
+
str(state["_print_outputs"]), max_length=max_print_outputs_length
|
1546 |
+
)
|
1547 |
+
raise InterpreterError(
|
1548 |
+
f"Code execution failed at line '{ast.get_source_segment(code, node)}' due to: {type(e).__name__}: {e}"
|
1549 |
+
)
|
1550 |
+
|
1551 |
+
|
1552 |
+
class PythonExecutor:
|
1553 |
+
pass
|
1554 |
+
|
1555 |
+
|
1556 |
+
class LocalPythonExecutor(PythonExecutor):
|
1557 |
+
"""
|
1558 |
+
Executor of Python code in a local environment.
|
1559 |
+
|
1560 |
+
This executor evaluates Python code with restricted access to imports and built-in functions,
|
1561 |
+
making it suitable for running untrusted code. It maintains state between executions,
|
1562 |
+
allows for custom tools and functions to be made available to the code, and captures
|
1563 |
+
print outputs separately from return values.
|
1564 |
+
|
1565 |
+
Args:
|
1566 |
+
additional_authorized_imports (`list[str]`):
|
1567 |
+
Additional authorized imports for the executor.
|
1568 |
+
max_print_outputs_length (`int`, defaults to `DEFAULT_MAX_LEN_OUTPUT=50_000`):
|
1569 |
+
Maximum length of the print outputs.
|
1570 |
+
additional_functions (`dict[str, Callable]`, *optional*):
|
1571 |
+
Additional Python functions to be added to the executor.
|
1572 |
+
"""
|
1573 |
+
|
1574 |
+
def __init__(
|
1575 |
+
self,
|
1576 |
+
additional_authorized_imports: list[str],
|
1577 |
+
max_print_outputs_length: int | None = None,
|
1578 |
+
additional_functions: dict[str, Callable] | None = None,
|
1579 |
+
):
|
1580 |
+
self.custom_tools = {}
|
1581 |
+
self.state = {"__name__": "__main__"}
|
1582 |
+
self.max_print_outputs_length = max_print_outputs_length
|
1583 |
+
if max_print_outputs_length is None:
|
1584 |
+
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
|
1585 |
+
self.additional_authorized_imports = additional_authorized_imports
|
1586 |
+
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
|
1587 |
+
# TODO: assert self.authorized imports are all installed locally
|
1588 |
+
self.static_tools = None
|
1589 |
+
self.additional_functions = additional_functions or {}
|
1590 |
+
|
1591 |
+
def __call__(self, code_action: str) -> tuple[Any, str, bool]:
|
1592 |
+
output, is_final_answer = evaluate_python_code(
|
1593 |
+
code_action,
|
1594 |
+
static_tools=self.static_tools,
|
1595 |
+
custom_tools=self.custom_tools,
|
1596 |
+
state=self.state,
|
1597 |
+
authorized_imports=self.authorized_imports,
|
1598 |
+
max_print_outputs_length=self.max_print_outputs_length,
|
1599 |
+
)
|
1600 |
+
logs = str(self.state["_print_outputs"])
|
1601 |
+
return output, logs, is_final_answer
|
1602 |
+
|
1603 |
+
def send_variables(self, variables: dict):
|
1604 |
+
self.state.update(variables)
|
1605 |
+
|
1606 |
+
def send_tools(self, tools: dict[str, Tool]):
|
1607 |
+
# Combine agent tools, base Python tools, and additional Python functions
|
1608 |
+
self.static_tools = {**tools, **BASE_PYTHON_TOOLS.copy(), **self.additional_functions}
|
1609 |
+
|
1610 |
+
|
1611 |
+
__all__ = ["evaluate_python_code", "LocalPythonExecutor"]
|
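To make the executor's contract above concrete, here is a minimal usage sketch. It relies only on the `LocalPythonExecutor` API defined in this file (`send_tools`, then calling the executor with a code string); the code string, the `math` import, and the printed values are invented for illustration:

```python
from smolagents.local_python_executor import LocalPythonExecutor

executor = LocalPythonExecutor(additional_authorized_imports=["math"])
executor.send_tools({})  # no agent tools here: static_tools falls back to BASE_PYTHON_TOOLS

# State persists across calls, and print output is captured separately from the return value.
output, logs, is_final_answer = executor("import math\nx = math.sqrt(2)\nprint(x)\nx")
print(output)           # 1.4142135623730951 (value of the last expression)
print(logs)             # the captured print output
print(is_final_answer)  # False, since final_answer() was never called
```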
src/smolagents/mcp_client.py
ADDED
@@ -0,0 +1,154 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from types import TracebackType
from typing import TYPE_CHECKING, Any

from smolagents.tools import Tool


__all__ = ["MCPClient"]

if TYPE_CHECKING:
    from mcpadapt.core import StdioServerParameters


class MCPClient:
    """Manages the connection to an MCP server and makes its tools available to SmolAgents.

    Note: tools can only be accessed after the connection has been started with the
    `connect()` method, done during the init. If you don't use the context manager,
    we strongly encourage you to use "try ... finally" to ensure the connection is cleaned up.

    Args:
        server_parameters (StdioServerParameters | dict[str, Any] | list[StdioServerParameters | dict[str, Any]]):
            Configuration parameters to connect to the MCP server. Can be a list if you want to connect to multiple MCPs at once.

            - An instance of `mcp.StdioServerParameters` for connecting a Stdio MCP server via standard input/output using a subprocess.

            - A `dict` with at least:
              - "url": URL of the server.
              - "transport": Transport protocol to use, one of:
                - "streamable-http": (recommended) Streamable HTTP transport.
                - "sse": Legacy HTTP+SSE transport (deprecated).
              If "transport" is omitted, the legacy "sse" transport is assumed (a deprecation warning will be issued).

    <Deprecated version="1.17.0">
    The HTTP+SSE transport is deprecated and future behavior will default to the Streamable HTTP transport.
    Please pass explicitly the "transport" key.
    </Deprecated>

    Example:
    ```python
    # fully managed context manager + stdio
    with MCPClient(...) as tools:
        # tools are now available

    # context manager + Streamable HTTP transport:
    with MCPClient({"url": "http://localhost:8000/mcp", "transport": "streamable-http"}) as tools:
        # tools are now available

    # manually manage the connection via the mcp_client object:
    try:
        mcp_client = MCPClient(...)
        tools = mcp_client.get_tools()

        # use your tools here.
    finally:
        mcp_client.disconnect()
    ```
    """

    def __init__(
        self,
        server_parameters: "StdioServerParameters" | dict[str, Any] | list["StdioServerParameters" | dict[str, Any]],
    ):
        try:
            from mcpadapt.core import MCPAdapt
            from mcpadapt.smolagents_adapter import SmolAgentsAdapter
        except ModuleNotFoundError:
            raise ModuleNotFoundError("Please install 'mcp' extra to use MCPClient: `pip install 'smolagents[mcp]'`")
        if isinstance(server_parameters, dict):
            transport = server_parameters.get("transport")
            if transport is None:
                warnings.warn(
                    "Passing a dict as server_parameters without specifying the 'transport' key is deprecated. "
                    "For now, it defaults to the legacy 'sse' (HTTP+SSE) transport, but this default will change "
                    "to 'streamable-http' in version 1.20. Please add the 'transport' key explicitly.",
                    FutureWarning,
                )
                transport = "sse"
                server_parameters["transport"] = transport
            if transport not in {"sse", "streamable-http"}:
                raise ValueError(
                    f"Unsupported transport: {transport}. Supported transports are 'streamable-http' and 'sse'."
                )
        self._adapter = MCPAdapt(server_parameters, SmolAgentsAdapter())
        self._tools: list[Tool] | None = None
        self.connect()

    def connect(self):
        """Connect to the MCP server and initialize the tools."""
        self._tools: list[Tool] = self._adapter.__enter__()

    def disconnect(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        exc_traceback: TracebackType | None = None,
    ):
        """Disconnect from the MCP server"""
        self._adapter.__exit__(exc_type, exc_value, exc_traceback)

    def get_tools(self) -> list[Tool]:
        """The SmolAgents tools available from the MCP server.

        Note: for now, this always returns the tools available at the creation of the session;
        a future release will also return new tools made available by the MCP server at call
        time, if any.

        Raises:
            ValueError: If the MCP server tools are None (usually meaning the server is not started).

        Returns:
            list[Tool]: The SmolAgents tools available from the MCP server.
        """
        if self._tools is None:
            raise ValueError(
                "Couldn't retrieve tools from MCP server, run `mcp_client.connect()` first before accessing `tools`"
            )
        return self._tools

    def __enter__(self) -> list[Tool]:
        """Connect to the MCP server and return the tools directly.

        Note that because of the `.connect` in the init, the mcp_client
        is already connected at this point.
        """
        return self._tools

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ):
        """Disconnect from the MCP server."""
        self.disconnect(exc_type, exc_value, exc_traceback)
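Beyond the docstring examples above, MCP tools are typically handed straight to an agent. A hedged sketch follows: the server URL is a placeholder, and `CodeAgent` and `InferenceClientModel` come from the agents and models modules in this same upload:

```python
from smolagents import CodeAgent, InferenceClientModel, MCPClient

# Placeholder URL; assumes a Streamable HTTP MCP server is running there.
with MCPClient({"url": "http://localhost:8000/mcp", "transport": "streamable-http"}) as tools:
    # `tools` is the list[Tool] returned by __enter__; disconnect happens automatically on exit.
    agent = CodeAgent(tools=tools, model=InferenceClientModel())
    agent.run("Answer using the MCP tools available to you.")
```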
src/smolagents/memory.py
ADDED
@@ -0,0 +1,257 @@
from dataclasses import asdict, dataclass
from logging import getLogger
from typing import TYPE_CHECKING, Any

from smolagents.models import ChatMessage, MessageRole
from smolagents.monitoring import AgentLogger, LogLevel, Timing, TokenUsage
from smolagents.utils import AgentError, make_json_serializable


if TYPE_CHECKING:
    import PIL.Image

    from smolagents.models import ChatMessage
    from smolagents.monitoring import AgentLogger


logger = getLogger(__name__)


@dataclass
class ToolCall:
    name: str
    arguments: Any
    id: str

    def dict(self):
        return {
            "id": self.id,
            "type": "function",
            "function": {
                "name": self.name,
                "arguments": make_json_serializable(self.arguments),
            },
        }


@dataclass
class MemoryStep:
    def dict(self):
        return asdict(self)

    def to_messages(self, summary_mode: bool = False) -> list[ChatMessage]:
        raise NotImplementedError


@dataclass
class ActionStep(MemoryStep):
    step_number: int
    timing: Timing
    model_input_messages: list[ChatMessage] | None = None
    tool_calls: list[ToolCall] | None = None
    error: AgentError | None = None
    model_output_message: ChatMessage | None = None
    model_output: str | list[dict[str, Any]] | None = None
    code_action: str | None = None
    observations: str | None = None
    observations_images: list["PIL.Image.Image"] | None = None
    action_output: Any = None
    token_usage: TokenUsage | None = None
    is_final_answer: bool = False

    def dict(self):
        # We overwrite the method to parse the tool_calls and action_output manually
        return {
            "step_number": self.step_number,
            "timing": self.timing.dict(),
            "model_input_messages": self.model_input_messages,
            "tool_calls": [tc.dict() for tc in self.tool_calls] if self.tool_calls else [],
            "error": self.error.dict() if self.error else None,
            "model_output_message": self.model_output_message.dict() if self.model_output_message else None,
            "model_output": self.model_output,
            "code_action": self.code_action,
            "observations": self.observations,
            "observations_images": [image.tobytes() for image in self.observations_images]
            if self.observations_images
            else None,
            "action_output": make_json_serializable(self.action_output),
            "token_usage": asdict(self.token_usage) if self.token_usage else None,
            "is_final_answer": self.is_final_answer,
        }

    def to_messages(self, summary_mode: bool = False) -> list[ChatMessage]:
        messages = []
        if self.model_output is not None and not summary_mode:
            messages.append(
                ChatMessage(role=MessageRole.ASSISTANT, content=[{"type": "text", "text": self.model_output.strip()}])
            )

        if self.tool_calls is not None:
            messages.append(
                ChatMessage(
                    role=MessageRole.TOOL_CALL,
                    content=[
                        {
                            "type": "text",
                            "text": "Calling tools:\n" + str([tc.dict() for tc in self.tool_calls]),
                        }
                    ],
                )
            )

        if self.observations_images:
            messages.append(
                ChatMessage(
                    role=MessageRole.USER,
                    content=[
                        {
                            "type": "image",
                            "image": image,
                        }
                        for image in self.observations_images
                    ],
                )
            )

        if self.observations is not None:
            messages.append(
                ChatMessage(
                    role=MessageRole.TOOL_RESPONSE,
                    content=[
                        {
                            "type": "text",
                            "text": f"Observation:\n{self.observations}",
                        }
                    ],
                )
            )
        if self.error is not None:
            error_message = (
                "Error:\n"
                + str(self.error)
                + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
            )
            message_content = f"Call id: {self.tool_calls[0].id}\n" if self.tool_calls else ""
            message_content += error_message
            messages.append(
                ChatMessage(role=MessageRole.TOOL_RESPONSE, content=[{"type": "text", "text": message_content}])
            )

        return messages


@dataclass
class PlanningStep(MemoryStep):
    model_input_messages: list[ChatMessage]
    model_output_message: ChatMessage
    plan: str
    timing: Timing
    token_usage: TokenUsage | None = None

    def to_messages(self, summary_mode: bool = False) -> list[ChatMessage]:
        if summary_mode:
            return []
        return [
            ChatMessage(role=MessageRole.ASSISTANT, content=[{"type": "text", "text": self.plan.strip()}]),
            ChatMessage(
                role=MessageRole.USER, content=[{"type": "text", "text": "Now proceed and carry out this plan."}]
            ),
            # This second message creates a role change to prevent models from simply continuing the plan message
        ]


@dataclass
class TaskStep(MemoryStep):
    task: str
    task_images: list["PIL.Image.Image"] | None = None

    def to_messages(self, summary_mode: bool = False) -> list[ChatMessage]:
        content = [{"type": "text", "text": f"New task:\n{self.task}"}]
        if self.task_images:
            content.extend([{"type": "image", "image": image} for image in self.task_images])

        return [ChatMessage(role=MessageRole.USER, content=content)]


@dataclass
class SystemPromptStep(MemoryStep):
    system_prompt: str

    def to_messages(self, summary_mode: bool = False) -> list[ChatMessage]:
        if summary_mode:
            return []
        return [ChatMessage(role=MessageRole.SYSTEM, content=[{"type": "text", "text": self.system_prompt}])]


@dataclass
class FinalAnswerStep(MemoryStep):
    output: Any


class AgentMemory:
    """Memory for the agent, containing the system prompt and all steps taken by the agent.

    This class is used to store the agent's steps, including tasks, actions, and planning steps.
    It allows for resetting the memory, retrieving succinct or full step information, and replaying the agent's steps.

    Args:
        system_prompt (`str`): System prompt for the agent, which sets the context and instructions for the agent's behavior.

    **Attributes**:
        - **system_prompt** (`SystemPromptStep`) -- System prompt step for the agent.
        - **steps** (`list[TaskStep | ActionStep | PlanningStep]`) -- List of steps taken by the agent, which can include tasks, actions, and planning steps.
    """

    def __init__(self, system_prompt: str):
        self.system_prompt: SystemPromptStep = SystemPromptStep(system_prompt=system_prompt)
        self.steps: list[TaskStep | ActionStep | PlanningStep] = []

    def reset(self):
        """Reset the agent's memory, clearing all steps and keeping the system prompt."""
        self.steps = []

    def get_succinct_steps(self) -> list[dict]:
        """Return a succinct representation of the agent's steps, excluding model input messages."""
        return [
            {key: value for key, value in step.dict().items() if key != "model_input_messages"} for step in self.steps
        ]

    def get_full_steps(self) -> list[dict]:
        """Return a full representation of the agent's steps, including model input messages."""
        if len(self.steps) == 0:
            return []
        return [step.dict() for step in self.steps]

    def replay(self, logger: AgentLogger, detailed: bool = False):
        """Prints a pretty replay of the agent's steps.

        Args:
            logger (`AgentLogger`): The logger to print replay logs to.
            detailed (`bool`, default `False`): If True, also displays the memory at each step.
                Careful: this will greatly increase log length. Use only for debugging.
        """
        logger.console.log("Replaying the agent's steps:")
        logger.log_markdown(title="System prompt", content=self.system_prompt.system_prompt, level=LogLevel.ERROR)
        for step in self.steps:
            if isinstance(step, TaskStep):
                logger.log_task(step.task, "", level=LogLevel.ERROR)
            elif isinstance(step, ActionStep):
                logger.log_rule(f"Step {step.step_number}", level=LogLevel.ERROR)
                if detailed and step.model_input_messages is not None:
                    logger.log_messages(step.model_input_messages, level=LogLevel.ERROR)
                if step.model_output is not None:
                    logger.log_markdown(title="Agent output:", content=step.model_output, level=LogLevel.ERROR)
            elif isinstance(step, PlanningStep):
                logger.log_rule("Planning step", level=LogLevel.ERROR)
                if detailed and step.model_input_messages is not None:
                    logger.log_messages(step.model_input_messages, level=LogLevel.ERROR)
                logger.log_markdown(title="Agent output:", content=step.plan, level=LogLevel.ERROR)

    def return_full_code(self) -> str:
        """Returns all code actions from the agent's steps, concatenated as a single script."""
        return "\n\n".join(
            [step.code_action for step in self.steps if isinstance(step, ActionStep) and step.code_action is not None]
        )


__all__ = ["AgentMemory"]
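To show how these step classes compose, here is a small illustrative sketch; the step contents are invented, and `Timing` is the dataclass imported from the monitoring module above:

```python
from smolagents.memory import ActionStep, AgentMemory, TaskStep
from smolagents.monitoring import Timing

memory = AgentMemory(system_prompt="You are a helpful agent.")
memory.steps.append(TaskStep(task="Compute 2 + 2."))
memory.steps.append(
    ActionStep(
        step_number=1,
        timing=Timing(start_time=0.0, end_time=1.0),
        model_output="I will compute the sum.",
        code_action="result = 2 + 2",
        observations="4",
    )
)

# Every step knows how to render itself as ChatMessage objects for the next model call.
messages = [m for step in memory.steps for m in step.to_messages()]
print(memory.return_full_code())  # prints: result = 2 + 2
```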
src/smolagents/models.py
ADDED
@@ -0,0 +1,1882 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import json
|
15 |
+
import logging
|
16 |
+
import os
|
17 |
+
import re
|
18 |
+
import uuid
|
19 |
+
import warnings
|
20 |
+
from collections.abc import Generator
|
21 |
+
from copy import deepcopy
|
22 |
+
from dataclasses import asdict, dataclass
|
23 |
+
from enum import Enum
|
24 |
+
from threading import Thread
|
25 |
+
from typing import TYPE_CHECKING, Any
|
26 |
+
|
27 |
+
from .monitoring import TokenUsage
|
28 |
+
from .tools import Tool
|
29 |
+
from .utils import _is_package_available, encode_image_base64, make_image_url, parse_json_blob
|
30 |
+
|
31 |
+
|
32 |
+
if TYPE_CHECKING:
|
33 |
+
from transformers import StoppingCriteriaList
|
34 |
+
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
|
38 |
+
STRUCTURED_GENERATION_PROVIDERS = ["cerebras", "fireworks-ai"]
|
39 |
+
CODEAGENT_RESPONSE_FORMAT = {
|
40 |
+
"type": "json_schema",
|
41 |
+
"json_schema": {
|
42 |
+
"schema": {
|
43 |
+
"additionalProperties": False,
|
44 |
+
"properties": {
|
45 |
+
"thought": {
|
46 |
+
"description": "A free form text description of the thought process.",
|
47 |
+
"title": "Thought",
|
48 |
+
"type": "string",
|
49 |
+
},
|
50 |
+
"code": {
|
51 |
+
"description": "Valid Python code snippet implementing the thought.",
|
52 |
+
"title": "Code",
|
53 |
+
"type": "string",
|
54 |
+
},
|
55 |
+
},
|
56 |
+
"required": ["thought", "code"],
|
57 |
+
"title": "ThoughtAndCodeAnswer",
|
58 |
+
"type": "object",
|
59 |
+
},
|
60 |
+
"name": "ThoughtAndCodeAnswer",
|
61 |
+
"strict": True,
|
62 |
+
},
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def get_dict_from_nested_dataclasses(obj, ignore_key=None):
|
67 |
+
def convert(obj):
|
68 |
+
if hasattr(obj, "__dataclass_fields__"):
|
69 |
+
return {k: convert(v) for k, v in asdict(obj).items() if k != ignore_key}
|
70 |
+
return obj
|
71 |
+
|
72 |
+
return convert(obj)
|
73 |
+
|
74 |
+
|
75 |
+
@dataclass
|
76 |
+
class ChatMessageToolCallFunction:
|
77 |
+
arguments: Any
|
78 |
+
name: str
|
79 |
+
description: str | None = None
|
80 |
+
|
81 |
+
|
82 |
+
@dataclass
|
83 |
+
class ChatMessageToolCall:
|
84 |
+
function: ChatMessageToolCallFunction
|
85 |
+
id: str
|
86 |
+
type: str
|
87 |
+
|
88 |
+
def __str__(self) -> str:
|
89 |
+
return f"Call: {self.id}: Calling {str(self.function.name)} with arguments: {str(self.function.arguments)}"
|
90 |
+
|
91 |
+
|
92 |
+
class MessageRole(str, Enum):
|
93 |
+
USER = "user"
|
94 |
+
ASSISTANT = "assistant"
|
95 |
+
SYSTEM = "system"
|
96 |
+
TOOL_CALL = "tool-call"
|
97 |
+
TOOL_RESPONSE = "tool-response"
|
98 |
+
|
99 |
+
@classmethod
|
100 |
+
def roles(cls):
|
101 |
+
return [r.value for r in cls]
|
102 |
+
|
103 |
+
|
104 |
+
@dataclass
|
105 |
+
class ChatMessage:
|
106 |
+
role: MessageRole
|
107 |
+
content: str | list[dict[str, Any]] | None = None
|
108 |
+
tool_calls: list[ChatMessageToolCall] | None = None
|
109 |
+
raw: Any | None = None # Stores the raw output from the API
|
110 |
+
token_usage: TokenUsage | None = None
|
111 |
+
|
112 |
+
def model_dump_json(self):
|
113 |
+
return json.dumps(get_dict_from_nested_dataclasses(self, ignore_key="raw"))
|
114 |
+
|
115 |
+
@classmethod
|
116 |
+
def from_dict(cls, data: dict, raw: Any | None = None, token_usage: TokenUsage | None = None) -> "ChatMessage":
|
117 |
+
if data.get("tool_calls"):
|
118 |
+
tool_calls = [
|
119 |
+
ChatMessageToolCall(
|
120 |
+
function=ChatMessageToolCallFunction(**tc["function"]), id=tc["id"], type=tc["type"]
|
121 |
+
)
|
122 |
+
for tc in data["tool_calls"]
|
123 |
+
]
|
124 |
+
data["tool_calls"] = tool_calls
|
125 |
+
return cls(
|
126 |
+
role=data["role"],
|
127 |
+
content=data.get("content"),
|
128 |
+
tool_calls=data.get("tool_calls"),
|
129 |
+
raw=raw,
|
130 |
+
token_usage=token_usage,
|
131 |
+
)
|
132 |
+
|
133 |
+
def dict(self):
|
134 |
+
return get_dict_from_nested_dataclasses(self)
|
135 |
+
|
136 |
+
def render_as_markdown(self) -> str:
|
137 |
+
rendered = str(self.content) or ""
|
138 |
+
if self.tool_calls:
|
139 |
+
rendered += "\n".join(
|
140 |
+
[
|
141 |
+
json.dumps({"tool": tool.function.name, "arguments": tool.function.arguments})
|
142 |
+
for tool in self.tool_calls
|
143 |
+
]
|
144 |
+
)
|
145 |
+
return rendered
|
146 |
+
|
147 |
+
|
148 |
+
def parse_json_if_needed(arguments: str | dict) -> str | dict:
|
149 |
+
if isinstance(arguments, dict):
|
150 |
+
return arguments
|
151 |
+
else:
|
152 |
+
try:
|
153 |
+
return json.loads(arguments)
|
154 |
+
except Exception:
|
155 |
+
return arguments
|
156 |
+
|
157 |
+
|
158 |
+
@dataclass
|
159 |
+
class ChatMessageToolCallStreamDelta:
|
160 |
+
"""Represents a streaming delta for tool calls during generation."""
|
161 |
+
|
162 |
+
index: int | None = None
|
163 |
+
id: str | None = None
|
164 |
+
type: str | None = None
|
165 |
+
function: ChatMessageToolCallFunction | None = None
|
166 |
+
|
167 |
+
|
168 |
+
@dataclass
|
169 |
+
class ChatMessageStreamDelta:
|
170 |
+
content: str | None = None
|
171 |
+
tool_calls: list[ChatMessageToolCallStreamDelta] | None = None
|
172 |
+
token_usage: TokenUsage | None = None
|
173 |
+
|
174 |
+
|
175 |
+
def agglomerate_stream_deltas(
|
176 |
+
stream_deltas: list[ChatMessageStreamDelta], role: MessageRole = MessageRole.ASSISTANT
|
177 |
+
) -> ChatMessage:
|
178 |
+
"""
|
179 |
+
Agglomerate a list of stream deltas into a single stream delta.
|
180 |
+
"""
|
181 |
+
accumulated_tool_calls: dict[int, ChatMessageToolCallStreamDelta] = {}
|
182 |
+
accumulated_content = ""
|
183 |
+
total_input_tokens = 0
|
184 |
+
total_output_tokens = 0
|
185 |
+
for stream_delta in stream_deltas:
|
186 |
+
if stream_delta.token_usage:
|
187 |
+
total_input_tokens += stream_delta.token_usage.input_tokens
|
188 |
+
total_output_tokens += stream_delta.token_usage.output_tokens
|
189 |
+
if stream_delta.content:
|
190 |
+
accumulated_content += stream_delta.content
|
191 |
+
if stream_delta.tool_calls:
|
192 |
+
for tool_call_delta in stream_delta.tool_calls: # ?ormally there should be only one call at a time
|
193 |
+
# Extend accumulated_tool_calls list to accommodate the new tool call if needed
|
194 |
+
if tool_call_delta.index is not None:
|
195 |
+
if tool_call_delta.index not in accumulated_tool_calls:
|
196 |
+
accumulated_tool_calls[tool_call_delta.index] = ChatMessageToolCallStreamDelta(
|
197 |
+
id=tool_call_delta.id,
|
198 |
+
type=tool_call_delta.type,
|
199 |
+
function=ChatMessageToolCallFunction(name="", arguments=""),
|
200 |
+
)
|
201 |
+
# Update the tool call at the specific index
|
202 |
+
tool_call = accumulated_tool_calls[tool_call_delta.index]
|
203 |
+
if tool_call_delta.id:
|
204 |
+
tool_call.id = tool_call_delta.id
|
205 |
+
if tool_call_delta.type:
|
206 |
+
tool_call.type = tool_call_delta.type
|
207 |
+
if tool_call_delta.function:
|
208 |
+
if tool_call_delta.function.name and len(tool_call_delta.function.name) > 0:
|
209 |
+
tool_call.function.name = tool_call_delta.function.name
|
210 |
+
if tool_call_delta.function.arguments:
|
211 |
+
tool_call.function.arguments += tool_call_delta.function.arguments
|
212 |
+
else:
|
213 |
+
raise ValueError(f"Tool call index is not provided in tool delta: {tool_call_delta}")
|
214 |
+
|
215 |
+
return ChatMessage(
|
216 |
+
role=role,
|
217 |
+
content=accumulated_content,
|
218 |
+
tool_calls=[
|
219 |
+
ChatMessageToolCall(
|
220 |
+
function=ChatMessageToolCallFunction(
|
221 |
+
name=tool_call_stream_delta.function.name,
|
222 |
+
arguments=tool_call_stream_delta.function.arguments,
|
223 |
+
),
|
224 |
+
id=tool_call_stream_delta.id or "",
|
225 |
+
type="function",
|
226 |
+
)
|
227 |
+
for tool_call_stream_delta in accumulated_tool_calls.values()
|
228 |
+
if tool_call_stream_delta.function
|
229 |
+
],
|
230 |
+
token_usage=TokenUsage(
|
231 |
+
input_tokens=total_input_tokens,
|
232 |
+
output_tokens=total_output_tokens,
|
233 |
+
),
|
234 |
+
)
|
235 |
+
|
236 |
+
|
237 |
+
tool_role_conversions = {
|
238 |
+
MessageRole.TOOL_CALL: MessageRole.ASSISTANT,
|
239 |
+
MessageRole.TOOL_RESPONSE: MessageRole.USER,
|
240 |
+
}
|
241 |
+
|
242 |
+
|
243 |
+
def get_tool_json_schema(tool: Tool) -> dict:
|
244 |
+
properties = deepcopy(tool.inputs)
|
245 |
+
required = []
|
246 |
+
for key, value in properties.items():
|
247 |
+
if value["type"] == "any":
|
248 |
+
value["type"] = "string"
|
249 |
+
if not ("nullable" in value and value["nullable"]):
|
250 |
+
required.append(key)
|
251 |
+
return {
|
252 |
+
"type": "function",
|
253 |
+
"function": {
|
254 |
+
"name": tool.name,
|
255 |
+
"description": tool.description,
|
256 |
+
"parameters": {
|
257 |
+
"type": "object",
|
258 |
+
"properties": properties,
|
259 |
+
"required": required,
|
260 |
+
},
|
261 |
+
},
|
262 |
+
}
|
263 |
+
|
264 |
+
|
265 |
+
def remove_stop_sequences(content: str, stop_sequences: list[str]) -> str:
|
266 |
+
for stop_seq in stop_sequences:
|
267 |
+
if content[-len(stop_seq) :] == stop_seq:
|
268 |
+
content = content[: -len(stop_seq)]
|
269 |
+
return content
|
270 |
+
|
271 |
+
|
272 |
+
def get_clean_message_list(
    message_list: list[ChatMessage],
    role_conversions: dict[MessageRole, MessageRole] | dict[str, str] = {},
    convert_images_to_image_urls: bool = False,
    flatten_messages_as_text: bool = False,
) -> list[dict[str, Any]]:
    """
    Creates a list of messages to give as input to the LLM. These messages are dictionaries compatible with the transformers LLM chat template.
    Subsequent messages with the same role will be concatenated into a single message.

    Args:
        message_list (`list[ChatMessage]`): List of chat messages.
        role_conversions (`dict[MessageRole, MessageRole]`, *optional*): Mapping to convert roles.
        convert_images_to_image_urls (`bool`, default `False`): Whether to convert images to image URLs.
        flatten_messages_as_text (`bool`, default `False`): Whether to flatten messages as text.
    """
    output_message_list: list[dict[str, Any]] = []
    message_list = deepcopy(message_list)  # Avoid modifying the original list
    for message in message_list:
        role = message.role
        if role not in MessageRole.roles():
            raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")

        if role in role_conversions:
            message.role = role_conversions[role]  # type: ignore
        # encode images if needed
        if isinstance(message.content, list):
            for element in message.content:
                assert isinstance(element, dict), "Error: this element should be a dict:" + str(element)
                if element["type"] == "image":
                    assert not flatten_messages_as_text, f"Cannot use images with {flatten_messages_as_text=}"
                    if convert_images_to_image_urls:
                        element.update(
                            {
                                "type": "image_url",
                                "image_url": {"url": make_image_url(encode_image_base64(element.pop("image")))},
                            }
                        )
                    else:
                        element["image"] = encode_image_base64(element["image"])

        if len(output_message_list) > 0 and message.role == output_message_list[-1]["role"]:
            assert isinstance(message.content, list), "Error: wrong content:" + str(message.content)
            if flatten_messages_as_text:
                output_message_list[-1]["content"] += "\n" + message.content[0]["text"]
            else:
                for el in message.content:
                    if el["type"] == "text" and output_message_list[-1]["content"][-1]["type"] == "text":
                        # Merge consecutive text messages rather than creating new ones
                        output_message_list[-1]["content"][-1]["text"] += "\n" + el["text"]
                    else:
                        output_message_list[-1]["content"].append(el)
        else:
            if flatten_messages_as_text:
                content = message.content[0]["text"]
            else:
                content = message.content
            output_message_list.append(
                {
                    "role": message.role,
                    "content": content,
                }
            )
    return output_message_list


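# Illustrative sketch (not upstream code): two consecutive user messages are merged
# into one output dict. Assumes ChatMessage accepts `role`/`content` directly, as it
# does elsewhere in this module.
def _example_get_clean_message_list():
    messages = [
        ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello"}]),
        ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "World"}]),
    ]
    clean = get_clean_message_list(messages, role_conversions=tool_role_conversions)
    assert len(clean) == 1
    assert clean[0]["content"][0]["text"] == "Hello\nWorld"
    return clean

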
def get_tool_call_from_text(text: str, tool_name_key: str, tool_arguments_key: str) -> ChatMessageToolCall:
    tool_call_dictionary, _ = parse_json_blob(text)
    try:
        tool_name = tool_call_dictionary[tool_name_key]
    except Exception as e:
        raise ValueError(
            f"Key {tool_name_key=} not found in the generated tool call. Got keys: {list(tool_call_dictionary.keys())} instead"
        ) from e
    tool_arguments = tool_call_dictionary.get(tool_arguments_key, None)
    if isinstance(tool_arguments, str):
        tool_arguments = parse_json_if_needed(tool_arguments)
    return ChatMessageToolCall(
        id=str(uuid.uuid4()),
        type="function",
        function=ChatMessageToolCallFunction(name=tool_name, arguments=tool_arguments),
    )


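# Illustrative sketch (not upstream code): parses a JSON tool-call blob emitted as
# plain text by a model into a structured ChatMessageToolCall.
def _example_get_tool_call_from_text():
    text = '{"name": "get_weather", "arguments": {"city": "Paris"}}'
    call = get_tool_call_from_text(text, tool_name_key="name", tool_arguments_key="arguments")
    assert call.function.name == "get_weather"
    assert call.function.arguments == {"city": "Paris"}
    return call

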
def supports_stop_parameter(model_id: str) -> bool:
    """
    Check if the model supports the `stop` parameter.

    Not supported by the reasoning models openai/o3 and openai/o4-mini (and their versioned variants).

    Args:
        model_id (`str`): Model identifier (e.g. "openai/o3", "o4-mini-2025-04-16")

    Returns:
        bool: True if the model supports the stop parameter, False otherwise
    """
    model_name = model_id.split("/")[-1]
    # o3 and o4-mini (including versioned variants like o3-2025-04-16) don't support the stop parameter
    pattern = r"^(o3[-\d]*|o4-mini[-\d]*)$"
    return not re.match(pattern, model_name)


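# Illustrative sketch (not upstream code): only the o3 and o4-mini families, with
# optional version suffixes, are excluded; any other identifier passes the check.
def _example_supports_stop_parameter():
    assert not supports_stop_parameter("openai/o3-2025-04-16")
    assert not supports_stop_parameter("o4-mini")
    assert supports_stop_parameter("gpt-4o")

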
class Model:
    def __init__(
        self,
        flatten_messages_as_text: bool = False,
        tool_name_key: str = "name",
        tool_arguments_key: str = "arguments",
        model_id: str | None = None,
        **kwargs,
    ):
        self.flatten_messages_as_text = flatten_messages_as_text
        self.tool_name_key = tool_name_key
        self.tool_arguments_key = tool_arguments_key
        self.kwargs = kwargs
        self._last_input_token_count: int | None = None
        self._last_output_token_count: int | None = None
        self.model_id: str | None = model_id

    @property
    def last_input_token_count(self) -> int | None:
        warnings.warn(
            "Attribute last_input_token_count is deprecated and will be removed in version 1.20. "
            "Please use TokenUsage.input_tokens instead.",
            FutureWarning,
        )
        return self._last_input_token_count

    @property
    def last_output_token_count(self) -> int | None:
        warnings.warn(
            "Attribute last_output_token_count is deprecated and will be removed in version 1.20. "
            "Please use TokenUsage.output_tokens instead.",
            FutureWarning,
        )
        return self._last_output_token_count

    def _prepare_completion_kwargs(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        tool_choice: str | dict | None = "required",  # Configurable tool_choice parameter
        **kwargs,
    ) -> dict[str, Any]:
        """
        Prepare parameters required for model invocation, handling parameter priorities.

        Parameter priority from high to low:
        1. Explicitly passed kwargs
        2. Specific parameters (stop_sequences, response_format, etc.)
        3. Default values in self.kwargs
        """
        # Clean and standardize the message list
        flatten_messages_as_text = kwargs.pop("flatten_messages_as_text", self.flatten_messages_as_text)
        messages_as_dicts = get_clean_message_list(
            messages,
            role_conversions=custom_role_conversions or tool_role_conversions,
            convert_images_to_image_urls=convert_images_to_image_urls,
            flatten_messages_as_text=flatten_messages_as_text,
        )
        # Use self.kwargs as the base configuration
        completion_kwargs = {
            **self.kwargs,
            "messages": messages_as_dicts,
        }

        # Handle specific parameters
        if stop_sequences is not None:
            # Some models do not support the stop parameter
            if supports_stop_parameter(self.model_id or ""):
                completion_kwargs["stop"] = stop_sequences
        if response_format is not None:
            completion_kwargs["response_format"] = response_format

        # Handle tools parameter
        if tools_to_call_from:
            tools_config = {
                "tools": [get_tool_json_schema(tool) for tool in tools_to_call_from],
            }
            if tool_choice is not None:
                tools_config["tool_choice"] = tool_choice
            completion_kwargs.update(tools_config)

        # Finally, use the passed-in kwargs to override all settings
        completion_kwargs.update(kwargs)

        return completion_kwargs

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Process the input messages and return the model's response.

        Parameters:
            messages (`list[dict[str, str | list[dict]]] | list[ChatMessage]`):
                A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`.
            stop_sequences (`list[str]`, *optional*):
                A list of strings that will stop the generation if encountered in the model's output.
            response_format (`dict[str, str]`, *optional*):
                The response format to use in the model's response.
            tools_to_call_from (`list[Tool]`, *optional*):
                A list of tools that the model can use to generate responses.
            **kwargs:
                Additional keyword arguments to be passed to the underlying model.

        Returns:
            `ChatMessage`: A chat message object containing the model's response.
        """
        raise NotImplementedError("This method must be implemented in child classes")

    def __call__(self, *args, **kwargs):
        return self.generate(*args, **kwargs)

    def parse_tool_calls(self, message: ChatMessage) -> ChatMessage:
        """Sometimes APIs do not return the tool call as a specific object, so we need to parse it."""
        message.role = MessageRole.ASSISTANT  # Overwrite role if needed
        if not message.tool_calls:
            assert message.content is not None, "Message contains no content and no tool calls"
            message.tool_calls = [
                get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_key)
            ]
        assert len(message.tool_calls) > 0, "No tool call was found in the model output"
        for tool_call in message.tool_calls:
            tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
        return message

    def to_dict(self) -> dict:
        """
        Converts the model into a JSON-compatible dictionary.
        """
        model_dictionary = {
            **self.kwargs,
            "model_id": self.model_id,
        }
        for attribute in [
            "custom_role_conversions",
            "temperature",
            "max_tokens",
            "provider",
            "timeout",
            "api_base",
            "torch_dtype",
            "device_map",
            "organization",
            "project",
            "azure_endpoint",
        ]:
            if hasattr(self, attribute):
                model_dictionary[attribute] = getattr(self, attribute)

        dangerous_attributes = ["token", "api_key"]
        for attribute_name in dangerous_attributes:
            if hasattr(self, attribute_name):
                print(
                    f"For security reasons, we do not export the `{attribute_name}` attribute of your model. Please export it manually."
                )
        return model_dictionary

    @classmethod
    def from_dict(cls, model_dictionary: dict[str, Any]) -> "Model":
        return cls(**{k: v for k, v in model_dictionary.items()})


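# Minimal sketch (illustration only, not an upstream class): a Model subclass only
# needs to implement `generate`; `_prepare_completion_kwargs` already handles message
# cleaning, stop sequences, and tool schemas. `_EchoModel` is hypothetical.
class _EchoModel(Model):
    """Hypothetical model that echoes the last user message, e.g. for offline testing."""

    def generate(self, messages, stop_sequences=None, response_format=None, tools_to_call_from=None, **kwargs):
        completion_kwargs = self._prepare_completion_kwargs(messages=messages, stop_sequences=stop_sequences)
        last_content = completion_kwargs["messages"][-1]["content"]
        return ChatMessage(role=MessageRole.ASSISTANT, content=str(last_content))

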
class VLLMModel(Model):
    """Model to use [vLLM](https://docs.vllm.ai/) for fast LLM inference and serving.

    Parameters:
        model_id (`str`):
            The Hugging Face model ID to be used for inference.
            This can be a path or model identifier from the Hugging Face model hub.
        model_kwargs (`dict[str, Any]`, *optional*):
            Additional keyword arguments to pass to the vLLM model (like revision, max_model_len, etc.).
    """

    def __init__(
        self,
        model_id,
        model_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        if not _is_package_available("vllm"):
            raise ModuleNotFoundError("Please install 'vllm' extra to use VLLMModel: `pip install 'smolagents[vllm]'`")

        from vllm import LLM  # type: ignore
        from vllm.transformers_utils.tokenizer import get_tokenizer  # type: ignore

        self.model_kwargs = model_kwargs or {}
        super().__init__(**kwargs)
        self.model_id = model_id
        self.model = LLM(model=model_id, **self.model_kwargs)
        assert self.model is not None
        self.tokenizer = get_tokenizer(model_id)
        self._is_vlm = False  # VLLMModel does not support vision models yet.

    def cleanup(self):
        import gc

        import torch
        from vllm.distributed.parallel_state import (  # type: ignore
            destroy_distributed_environment,
            destroy_model_parallel,
        )

        destroy_model_parallel()
        if self.model is not None:
            # taken from https://github.com/vllm-project/vllm/issues/1908#issuecomment-2076870351
            del self.model.llm_engine.model_executor.driver_worker
        gc.collect()
        destroy_distributed_environment()
        torch.cuda.empty_cache()

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        from vllm import SamplingParams  # type: ignore

        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            flatten_messages_as_text=(not self._is_vlm),
            stop_sequences=stop_sequences,
            tools_to_call_from=tools_to_call_from,
            **kwargs,
        )
        # Override the OpenAI schema for vLLM compatibility
        guided_options_request = {"guided_json": response_format["json_schema"]["schema"]} if response_format else None

        messages = completion_kwargs.pop("messages")
        prepared_stop_sequences = completion_kwargs.pop("stop", [])
        tools = completion_kwargs.pop("tools", None)
        completion_kwargs.pop("tool_choice", None)

        prompt = self.tokenizer.apply_chat_template(
            messages,
            tools=tools,
            add_generation_prompt=True,
            tokenize=False,
        )

        sampling_params = SamplingParams(
            n=kwargs.get("n", 1),
            temperature=kwargs.get("temperature", 0.0),
            max_tokens=kwargs.get("max_tokens", 2048),
            stop=prepared_stop_sequences,
        )

        out = self.model.generate(
            prompt,
            sampling_params=sampling_params,
            guided_options_request=guided_options_request,
        )

        output_text = out[0].outputs[0].text
        self._last_input_token_count = len(out[0].prompt_token_ids)
        self._last_output_token_count = len(out[0].outputs[0].token_ids)
        return ChatMessage(
            role=MessageRole.ASSISTANT,
            content=output_text,
            raw={"out": output_text, "completion_kwargs": completion_kwargs},
            token_usage=TokenUsage(
                input_tokens=len(out[0].prompt_token_ids),
                output_tokens=len(out[0].outputs[0].token_ids),
            ),
        )


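# Hedged usage sketch (not upstream docs): VLLMModel's docstring has no Example
# section, so this mirrors the examples of the other model classes. It assumes a GPU
# machine with the 'vllm' extra installed; the model ID is only an illustration.
def _example_vllm_model():
    engine = VLLMModel(
        model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
        model_kwargs={"max_model_len": 8192},  # hypothetical vLLM option
    )
    messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
    return engine(messages, stop_sequences=["END"])

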
class MLXModel(Model):
    """A class to interact with models loaded using MLX on Apple silicon.

    > [!TIP]
    > You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not the case.

    Parameters:
        model_id (str):
            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
        tool_name_key (str):
            The key, which can usually be found in the model's chat template, for retrieving a tool name.
        tool_arguments_key (str):
            The key, which can usually be found in the model's chat template, for retrieving tool arguments.
        trust_remote_code (bool, default `False`):
            Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
        load_kwargs (dict[str, Any], *optional*):
            Additional keyword arguments to pass to the `mlx_lm.load` method when loading the model and tokenizer.
        apply_chat_template_kwargs (dict, *optional*):
            Additional keyword arguments to pass to the `apply_chat_template` method of the tokenizer.
        kwargs (dict, *optional*):
            Any additional keyword arguments that you want to use in model.generate(), for instance `max_tokens`.

    Example:
    ```python
    >>> engine = MLXModel(
    ...     model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
    ...     max_tokens=10000,
    ... )
    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": "Explain quantum mechanics in simple terms."
    ...     }
    ... ]
    >>> response = engine(messages, stop_sequences=["END"])
    >>> print(response)
    "Quantum mechanics is the branch of physics that studies..."
    ```
    """

    def __init__(
        self,
        model_id: str,
        trust_remote_code: bool = False,
        load_kwargs: dict[str, Any] | None = None,
        apply_chat_template_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        if not _is_package_available("mlx_lm"):
            raise ModuleNotFoundError(
                "Please install 'mlx-lm' extra to use 'MLXModel': `pip install 'smolagents[mlx-lm]'`"
            )
        import mlx_lm

        self.load_kwargs = load_kwargs or {}
        self.load_kwargs.setdefault("tokenizer_config", {}).setdefault("trust_remote_code", trust_remote_code)
        self.apply_chat_template_kwargs = apply_chat_template_kwargs or {}
        self.apply_chat_template_kwargs.setdefault("add_generation_prompt", True)
        # mlx-lm doesn't support vision models: flatten_messages_as_text=True
        super().__init__(model_id=model_id, flatten_messages_as_text=True, **kwargs)

        self.model, self.tokenizer = mlx_lm.load(self.model_id, **self.load_kwargs)
        self.stream_generate = mlx_lm.stream_generate
        self.is_vlm = False  # mlx-lm doesn't support vision models

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        if response_format is not None:
            raise ValueError("MLX does not support structured outputs.")
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            tools_to_call_from=tools_to_call_from,
            **kwargs,
        )
        messages = completion_kwargs.pop("messages")
        stops = completion_kwargs.pop("stop", [])
        tools = completion_kwargs.pop("tools", None)
        completion_kwargs.pop("tool_choice", None)

        prompt_ids = self.tokenizer.apply_chat_template(messages, tools=tools, **self.apply_chat_template_kwargs)

        output_tokens = 0
        text = ""
        for response in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
            output_tokens += 1
            text += response.text
            if any((stop_index := text.rfind(stop)) != -1 for stop in stops):
                text = text[:stop_index]
                break

        self._last_input_token_count = len(prompt_ids)
        self._last_output_token_count = output_tokens
        return ChatMessage(
            role=MessageRole.ASSISTANT,
            content=text,
            raw={"out": text, "completion_kwargs": completion_kwargs},
            token_usage=TokenUsage(
                input_tokens=len(prompt_ids),
                output_tokens=output_tokens,
            ),
        )


class TransformersModel(Model):
    """A class that uses Hugging Face's Transformers library for language model interaction.

    This model allows you to load and use Hugging Face's models locally using the Transformers library. It supports features like stop sequences and grammar customization.

    > [!TIP]
    > You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if it's not the case.

    Parameters:
        model_id (`str`):
            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
            For example, `"Qwen/Qwen2.5-Coder-32B-Instruct"`.
        device_map (`str`, *optional*):
            The device_map to initialize your model with.
        torch_dtype (`str`, *optional*):
            The torch_dtype to initialize your model with.
        trust_remote_code (bool, default `False`):
            Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
        **kwargs:
            Additional keyword arguments to pass to `model.generate()`, for instance `max_new_tokens` or `device`.

    Raises:
        ValueError:
            If the model name is not provided.

    Example:
    ```python
    >>> engine = TransformersModel(
    ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    ...     device="cuda",
    ...     max_new_tokens=5000,
    ... )
    >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
    >>> response = engine(messages, stop_sequences=["END"])
    >>> print(response)
    "Quantum mechanics is the branch of physics that studies..."
    ```
    """

    def __init__(
        self,
        model_id: str | None = None,
        device_map: str | None = None,
        torch_dtype: str | None = None,
        trust_remote_code: bool = False,
        **kwargs,
    ):
        try:
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoModelForImageTextToText,
                AutoProcessor,
                AutoTokenizer,
                TextIteratorStreamer,
            )
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`"
            )

        if not model_id:
            warnings.warn(
                "The 'model_id' parameter will be required in version 2.0.0. "
                "Please update your code to pass this parameter to avoid future errors. "
                "For now, it defaults to 'HuggingFaceTB/SmolLM2-1.7B-Instruct'.",
                FutureWarning,
            )
            model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"

        default_max_tokens = 4096
        max_new_tokens = kwargs.get("max_new_tokens") or kwargs.get("max_tokens")
        if not max_new_tokens:
            kwargs["max_new_tokens"] = default_max_tokens
            logger.warning(
                f"`max_new_tokens` not provided, using this default value for `max_new_tokens`: {default_max_tokens}"
            )

        if device_map is None:
            device_map = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device_map}")
        self._is_vlm = False
        try:
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_id,
                device_map=device_map,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )
            self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=trust_remote_code)
            self._is_vlm = True
            self.streamer = TextIteratorStreamer(self.processor.tokenizer, skip_prompt=True, skip_special_tokens=True)  # type: ignore

        except ValueError as e:
            if "Unrecognized configuration class" in str(e):
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map=device_map,
                    torch_dtype=torch_dtype,
                    trust_remote_code=trust_remote_code,
                )
                self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
                self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)  # type: ignore
            else:
                raise e
        except Exception as e:
            raise ValueError(f"Failed to load tokenizer and model for {model_id=}: {e}") from e
        super().__init__(flatten_messages_as_text=not self._is_vlm, model_id=model_id, **kwargs)

    def make_stopping_criteria(self, stop_sequences: list[str], tokenizer) -> "StoppingCriteriaList":
        from transformers import StoppingCriteria, StoppingCriteriaList

        class StopOnStrings(StoppingCriteria):
            def __init__(self, stop_strings: list[str], tokenizer):
                self.stop_strings = stop_strings
                self.tokenizer = tokenizer
                self.stream = ""

            def reset(self):
                self.stream = ""

            def __call__(self, input_ids, scores, **kwargs):
                generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
                self.stream += generated
                if any([self.stream.endswith(stop_string) for stop_string in self.stop_strings]):
                    return True
                return False

        return StoppingCriteriaList([StopOnStrings(stop_sequences, tokenizer)])

    def _prepare_completion_args(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            tools_to_call_from=tools_to_call_from,  # forward the tools so the "tools" entry below is populated
            **kwargs,
        )

        messages = completion_kwargs.pop("messages")
        stop_sequences = completion_kwargs.pop("stop", None)
        tools = completion_kwargs.pop("tools", None)
        completion_kwargs.pop("tool_choice", None)  # not a valid `model.generate()` argument

        max_new_tokens = (
            kwargs.get("max_new_tokens")
            or kwargs.get("max_tokens")
            or self.kwargs.get("max_new_tokens")
            or self.kwargs.get("max_tokens")
            or 1024
        )
        prompt_tensor = (self.processor if hasattr(self, "processor") else self.tokenizer).apply_chat_template(
            messages,
            tools=tools,
            return_tensors="pt",
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
        )
        prompt_tensor = prompt_tensor.to(self.model.device)  # type: ignore
        if hasattr(prompt_tensor, "input_ids"):
            prompt_tensor = prompt_tensor["input_ids"]

        model_tokenizer = self.processor.tokenizer if hasattr(self, "processor") else self.tokenizer
        stopping_criteria = (
            self.make_stopping_criteria(stop_sequences, tokenizer=model_tokenizer) if stop_sequences else None
        )
        completion_kwargs["max_new_tokens"] = max_new_tokens
        return dict(
            inputs=prompt_tensor,
            use_cache=True,
            stopping_criteria=stopping_criteria,
            **completion_kwargs,
        )

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        if response_format is not None:
            raise ValueError("Transformers does not support structured outputs, use VLLMModel for this.")
        generation_kwargs = self._prepare_completion_args(
            messages=messages,
            stop_sequences=stop_sequences,
            tools_to_call_from=tools_to_call_from,
            **kwargs,
        )
        count_prompt_tokens = generation_kwargs["inputs"].shape[1]  # type: ignore
        out = self.model.generate(
            **generation_kwargs,
        )
        generated_tokens = out[0, count_prompt_tokens:]
        if hasattr(self, "processor"):
            output_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
        else:
            output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        if stop_sequences is not None:
            output_text = remove_stop_sequences(output_text, stop_sequences)

        self._last_input_token_count = count_prompt_tokens
        self._last_output_token_count = len(generated_tokens)
        return ChatMessage(
            role=MessageRole.ASSISTANT,
            content=output_text,
            raw={
                "out": output_text,
                "completion_kwargs": {key: value for key, value in generation_kwargs.items() if key != "inputs"},
            },
            token_usage=TokenUsage(
                input_tokens=count_prompt_tokens,
                output_tokens=len(generated_tokens),
            ),
        )

    def generate_stream(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> Generator[ChatMessageStreamDelta]:
        if response_format is not None:
            raise ValueError("Transformers does not support structured outputs, use VLLMModel for this.")
        generation_kwargs = self._prepare_completion_args(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            **kwargs,
        )
        count_prompt_tokens = generation_kwargs["inputs"].shape[1]  # type: ignore

        thread = Thread(target=self.model.generate, kwargs={"streamer": self.streamer, **generation_kwargs})
        thread.start()

        # Generate with streaming
        for new_text in self.streamer:
            self._last_input_token_count = count_prompt_tokens
            self._last_output_token_count = 1
            yield ChatMessageStreamDelta(
                content=new_text,
                tool_calls=None,
                token_usage=TokenUsage(input_tokens=count_prompt_tokens, output_tokens=1),
            )
        thread.join()


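# Hedged usage sketch (not upstream docs): consuming TransformersModel.generate_stream,
# which yields ChatMessageStreamDelta objects while generation runs in a background
# thread. The model ID matches the class docstring example and is only an illustration.
def _example_transformers_streaming():
    engine = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", max_new_tokens=128)
    messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
    chunks = [delta.content for delta in engine.generate_stream(messages) if delta.content]
    return "".join(chunks)

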
class ApiModel(Model):
    """
    Base class for API-based language models.

    This class serves as a foundation for implementing models that interact with
    external APIs. It handles the common functionality for managing model IDs,
    custom role mappings, and API client connections.

    Parameters:
        model_id (`str`):
            The identifier for the model to be used with the API.
        custom_role_conversions (`dict[str, str]`, *optional*):
            Mapping to convert between internal role names and API-specific role names. Defaults to None.
        client (`Any`, *optional*):
            Pre-configured API client instance. If not provided, a default client will be created. Defaults to None.
        **kwargs: Additional keyword arguments to pass to the parent class.
    """

    def __init__(
        self, model_id: str, custom_role_conversions: dict[str, str] | None = None, client: Any | None = None, **kwargs
    ):
        super().__init__(model_id=model_id, **kwargs)
        self.custom_role_conversions = custom_role_conversions or {}
        self.client = client or self.create_client()

    def create_client(self):
        """Create the API client for the specific service."""
        raise NotImplementedError("Subclasses must implement this method to create a client")


class LiteLLMModel(ApiModel):
    """Model to use [LiteLLM Python SDK](https://docs.litellm.ai/docs/#litellm-python-sdk) to access hundreds of LLMs.

    Parameters:
        model_id (`str`):
            The model identifier to use on the server (e.g. "gpt-3.5-turbo").
        api_base (`str`, *optional*):
            The base URL of the provider API to call the model.
        api_key (`str`, *optional*):
            The API key to use for authentication.
        custom_role_conversions (`dict[str, str]`, *optional*):
            Custom role conversion mapping to convert message roles into others.
            Useful for specific models that do not support specific message roles like "system".
        flatten_messages_as_text (`bool`, *optional*): Whether to flatten messages as text.
            Defaults to `True` for models whose `model_id` starts with "ollama", "groq", or "cerebras".
        **kwargs:
            Additional keyword arguments to pass to the LiteLLM completion call.
    """

    def __init__(
        self,
        model_id: str | None = None,
        api_base: str | None = None,
        api_key: str | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool | None = None,
        **kwargs,
    ):
        if not model_id:
            warnings.warn(
                "The 'model_id' parameter will be required in version 2.0.0. "
                "Please update your code to pass this parameter to avoid future errors. "
                "For now, it defaults to 'anthropic/claude-3-5-sonnet-20240620'.",
                FutureWarning,
            )
            model_id = "anthropic/claude-3-5-sonnet-20240620"
        self.api_base = api_base
        self.api_key = api_key
        flatten_messages_as_text = (
            flatten_messages_as_text
            if flatten_messages_as_text is not None
            else model_id.startswith(("ollama", "groq", "cerebras"))
        )
        super().__init__(
            model_id=model_id,
            custom_role_conversions=custom_role_conversions,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )

    def create_client(self):
        """Create the LiteLLM client."""
        try:
            import litellm
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
            ) from e

        return litellm

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_id,
            api_base=self.api_base,
            api_key=self.api_key,
            convert_images_to_image_urls=True,
            custom_role_conversions=self.custom_role_conversions,
            **kwargs,
        )

        response = self.client.completion(**completion_kwargs)

        self._last_input_token_count = response.usage.prompt_tokens
        self._last_output_token_count = response.usage.completion_tokens
        return ChatMessage.from_dict(
            response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}),
            raw=response,
            token_usage=TokenUsage(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            ),
        )

    def generate_stream(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> Generator[ChatMessageStreamDelta]:
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_id,
            api_base=self.api_base,
            api_key=self.api_key,
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )
        for event in self.client.completion(**completion_kwargs, stream=True, stream_options={"include_usage": True}):
            if getattr(event, "usage", None):
                self._last_input_token_count = event.usage.prompt_tokens
                self._last_output_token_count = event.usage.completion_tokens
                yield ChatMessageStreamDelta(
                    content="",
                    token_usage=TokenUsage(
                        input_tokens=event.usage.prompt_tokens,
                        output_tokens=event.usage.completion_tokens,
                    ),
                )
            if event.choices:
                choice = event.choices[0]
                if choice.delta:
                    yield ChatMessageStreamDelta(
                        content=choice.delta.content,
                        tool_calls=[
                            ChatMessageToolCallStreamDelta(
                                index=delta.index,
                                id=delta.id,
                                type=delta.type,
                                function=delta.function,
                            )
                            for delta in choice.delta.tool_calls
                        ]
                        if choice.delta.tool_calls
                        else None,
                    )
                else:
                    if not getattr(choice, "finish_reason", None):
                        raise ValueError(f"No content or tool calls in event: {event}")


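# Hedged usage sketch (not upstream docs): LiteLLM routes the request based on the
# provider prefix in `model_id`; the API key here is a placeholder.
def _example_litellm_model():
    model = LiteLLMModel(
        model_id="anthropic/claude-3-5-sonnet-20240620",
        api_key="YOUR_API_KEY",  # placeholder
        temperature=0.2,
    )
    messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
    return model(messages)

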
class LiteLLMRouterModel(LiteLLMModel):
    """Router-based client for interacting with the [LiteLLM Python SDK Router](https://docs.litellm.ai/docs/routing).

    This class provides a high-level interface for distributing requests among multiple language models using
    the LiteLLM SDK's routing capabilities. It is responsible for initializing and configuring the router client,
    applying custom role conversions, and managing message formatting to ensure seamless integration with various LLMs.

    Parameters:
        model_id (`str`):
            Identifier for the model group to use from the model list (e.g., "model-group-1").
        model_list (`list[dict[str, Any]]`):
            Model configurations to be used for routing.
            Each configuration should include the model group name and any necessary parameters.
            For more details, refer to the [LiteLLM Routing](https://docs.litellm.ai/docs/routing#quick-start) documentation.
        client_kwargs (`dict[str, Any]`, *optional*):
            Additional configuration parameters for the Router client. For more details, see the
            [LiteLLM Routing Configurations](https://docs.litellm.ai/docs/routing).
        custom_role_conversions (`dict[str, str]`, *optional*):
            Custom role conversion mapping to convert message roles into others.
            Useful for specific models that do not support specific message roles like "system".
        flatten_messages_as_text (`bool`, *optional*): Whether to flatten messages as text.
            Defaults to `True` for models whose `model_id` starts with "ollama", "groq", or "cerebras".
        **kwargs:
            Additional keyword arguments to pass to the LiteLLM Router completion method.

    Example:
    ```python
    >>> import os
    >>> from smolagents import CodeAgent, WebSearchTool, LiteLLMRouterModel
    >>> os.environ["OPENAI_API_KEY"] = ""
    >>> os.environ["AWS_ACCESS_KEY_ID"] = ""
    >>> os.environ["AWS_SECRET_ACCESS_KEY"] = ""
    >>> os.environ["AWS_REGION"] = ""
    >>> llm_loadbalancer_model_list = [
    ...     {
    ...         "model_name": "model-group-1",
    ...         "litellm_params": {
    ...             "model": "gpt-4o-mini",
    ...             "api_key": os.getenv("OPENAI_API_KEY"),
    ...         },
    ...     },
    ...     {
    ...         "model_name": "model-group-1",
    ...         "litellm_params": {
    ...             "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    ...             "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
    ...             "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
    ...             "aws_region_name": os.getenv("AWS_REGION"),
    ...         },
    ...     },
    ... ]
    >>> model = LiteLLMRouterModel(
    ...     model_id="model-group-1",
    ...     model_list=llm_loadbalancer_model_list,
    ...     client_kwargs={
    ...         "routing_strategy": "simple-shuffle"
    ...     },
    ... )
    >>> agent = CodeAgent(tools=[WebSearchTool()], model=model)
    >>> agent.run("How many seconds would it take for a leopard at full speed to run through Pont des Arts?")
    ```
    """

    def __init__(
        self,
        model_id: str,
        model_list: list[dict[str, Any]],
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool | None = None,
        **kwargs,
    ):
        self.client_kwargs = {
            "model_list": model_list,
            **(client_kwargs or {}),
        }
        super().__init__(
            model_id=model_id,
            custom_role_conversions=custom_role_conversions,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )

    def create_client(self):
        try:
            from litellm.router import Router
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Please install 'litellm' extra to use LiteLLMRouterModel: `pip install 'smolagents[litellm]'`"
            ) from e
        return Router(**self.client_kwargs)


class InferenceClientModel(ApiModel):
    """A class to interact with Hugging Face's Inference Providers for language model interaction.

    This model allows you to communicate with Hugging Face's models using Inference Providers. It can be used in serverless mode, with a dedicated endpoint, or even with a local URL, supporting features like stop sequences and grammar customization.

    Providers include Cerebras, Cohere, Fal, Fireworks, HF-Inference, Hyperbolic, Nebius, Novita, Replicate, SambaNova, Together, and more.

    Parameters:
        model_id (`str`, *optional*, default `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
            The Hugging Face model ID to be used for inference.
            This can be a model identifier from the Hugging Face model hub or a URL to a deployed Inference Endpoint.
            Currently, it defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`, but this may change in the future.
        provider (`str`, *optional*):
            Name of the provider to use for inference. A list of supported providers can be found in the [Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
            Defaults to "auto", i.e. the first of the providers available for the model, sorted by the user's order [here](https://hf.co/settings/inference-providers).
            If `base_url` is passed, then `provider` is not used.
        token (`str`, *optional*):
            Token used by the Hugging Face API for authentication. This token needs to be authorized to 'Make calls to the serverless Inference Providers'.
            If the model is gated (like Llama-3 models), the token also needs 'Read access to contents of all public gated repos you can access'.
            If not provided, the class will try to use the environment variable 'HF_TOKEN', else use the token stored in the Hugging Face CLI configuration.
        timeout (`int`, *optional*, defaults to 120):
            Timeout for the API request, in seconds.
        client_kwargs (`dict[str, Any]`, *optional*):
            Additional keyword arguments to pass to the Hugging Face InferenceClient.
        custom_role_conversions (`dict[str, str]`, *optional*):
            Custom role conversion mapping to convert message roles into others.
            Useful for specific models that do not support specific message roles like "system".
        api_key (`str`, *optional*):
            Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClientModel`]
            follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
        bill_to (`str`, *optional*):
            The billing account to use for the requests. By default the requests are billed on the user's account. Requests can only be billed to
            an organization the user is a member of, and which has subscribed to Enterprise Hub.
        base_url (`str`, *optional*):
            Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClientModel`]
            follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
        **kwargs:
            Additional keyword arguments to pass to the Hugging Face InferenceClient.

    Raises:
        ValueError:
            If the model name is not provided.

    Example:
    ```python
    >>> engine = InferenceClientModel(
    ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    ...     provider="nebius",
    ...     token="your_hf_token_here",
    ...     max_tokens=5000,
    ... )
    >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
    >>> response = engine(messages, stop_sequences=["END"])
    >>> print(response)
    "Quantum mechanics is the branch of physics that studies..."
    ```
    """

    def __init__(
        self,
        model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
        provider: str | None = None,
        token: str | None = None,
        timeout: int = 120,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        api_key: str | None = None,
        bill_to: str | None = None,
        base_url: str | None = None,
        **kwargs,
    ):
        if token is not None and api_key is not None:
            raise ValueError(
                "Received both `token` and `api_key` arguments. Please provide only one of them."
                " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
                " It has the exact same behavior as `token`."
            )
        token = token if token is not None else api_key
        if token is None:
            token = os.getenv("HF_TOKEN")
        self.client_kwargs = {
            **(client_kwargs or {}),
            "model": model_id,
            "provider": provider,
            "token": token,
            "timeout": timeout,
            "bill_to": bill_to,
            "base_url": base_url,
        }
        super().__init__(model_id=model_id, custom_role_conversions=custom_role_conversions, **kwargs)

    def create_client(self):
        """Create the Hugging Face client."""
        from huggingface_hub import InferenceClient

        return InferenceClient(**self.client_kwargs)

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        if response_format is not None and self.client_kwargs["provider"] not in STRUCTURED_GENERATION_PROVIDERS:
            raise ValueError(
                "InferenceClientModel only supports structured outputs with these providers: "
                + ", ".join(STRUCTURED_GENERATION_PROVIDERS)
            )
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            tools_to_call_from=tools_to_call_from,
            # response_format=response_format,
            convert_images_to_image_urls=True,
            custom_role_conversions=self.custom_role_conversions,
            **kwargs,
        )
        response = self.client.chat_completion(**completion_kwargs)

        self._last_input_token_count = response.usage.prompt_tokens
        self._last_output_token_count = response.usage.completion_tokens
        return ChatMessage.from_dict(
            asdict(response.choices[0].message),
            raw=response,
            token_usage=TokenUsage(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            ),
        )

    def generate_stream(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> Generator[ChatMessageStreamDelta]:
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_id,
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )
        for event in self.client.chat.completions.create(
            **completion_kwargs, stream=True, stream_options={"include_usage": True}
        ):
            if getattr(event, "usage", None):
                self._last_input_token_count = event.usage.prompt_tokens
                self._last_output_token_count = event.usage.completion_tokens
                yield ChatMessageStreamDelta(
                    content="",
                    token_usage=TokenUsage(
                        input_tokens=event.usage.prompt_tokens,
                        output_tokens=event.usage.completion_tokens,
                    ),
                )
            if event.choices:
                choice = event.choices[0]
                if choice.delta:
                    yield ChatMessageStreamDelta(
                        content=choice.delta.content,
                        tool_calls=[
                            ChatMessageToolCallStreamDelta(
                                index=delta.index,
                                id=delta.id,
                                type=delta.type,
                                function=delta.function,
                            )
                            for delta in choice.delta.tool_calls
                        ]
                        if choice.delta.tool_calls
                        else None,
                    )
                else:
                    if not getattr(choice, "finish_reason", None):
                        raise ValueError(f"No content or tool calls in event: {event}")


class OpenAIServerModel(ApiModel):
|
1475 |
+
"""This model connects to an OpenAI-compatible API server.
|
1476 |
+
|
1477 |
+
Parameters:
|
1478 |
+
model_id (`str`):
|
1479 |
+
The model identifier to use on the server (e.g. "gpt-3.5-turbo").
|
1480 |
+
api_base (`str`, *optional*):
|
1481 |
+
The base URL of the OpenAI-compatible API server.
|
1482 |
+
api_key (`str`, *optional*):
|
1483 |
+
The API key to use for authentication.
|
1484 |
+
organization (`str`, *optional*):
|
1485 |
+
The organization to use for the API request.
|
1486 |
+
project (`str`, *optional*):
|
1487 |
+
The project to use for the API request.
|
1488 |
+
client_kwargs (`dict[str, Any]`, *optional*):
|
1489 |
+
Additional keyword arguments to pass to the OpenAI client (like organization, project, max_retries etc.).
|
1490 |
+
custom_role_conversions (`dict[str, str]`, *optional*):
|
1491 |
+
Custom role conversion mapping to convert message roles in others.
|
1492 |
+
Useful for specific models that do not support specific message roles like "system".
|
1493 |
+
flatten_messages_as_text (`bool`, default `False`):
|
1494 |
+
Whether to flatten messages as text.
|
1495 |
+
**kwargs:
|
1496 |
+
Additional keyword arguments to pass to the OpenAI API.
|
1497 |
+
"""
|
1498 |
+
|
1499 |
+
def __init__(
|
1500 |
+
self,
|
1501 |
+
model_id: str,
|
1502 |
+
api_base: str | None = None,
|
1503 |
+
api_key: str | None = None,
|
1504 |
+
organization: str | None = None,
|
1505 |
+
project: str | None = None,
|
1506 |
+
client_kwargs: dict[str, Any] | None = None,
|
1507 |
+
custom_role_conversions: dict[str, str] | None = None,
|
1508 |
+
flatten_messages_as_text: bool = False,
|
1509 |
+
**kwargs,
|
1510 |
+
):
|
1511 |
+
self.client_kwargs = {
|
1512 |
+
**(client_kwargs or {}),
|
1513 |
+
"api_key": api_key,
|
1514 |
+
"base_url": api_base,
|
1515 |
+
"organization": organization,
|
1516 |
+
"project": project,
|
1517 |
+
}
|
1518 |
+
super().__init__(
|
1519 |
+
model_id=model_id,
|
1520 |
+
custom_role_conversions=custom_role_conversions,
|
1521 |
+
flatten_messages_as_text=flatten_messages_as_text,
|
1522 |
+
**kwargs,
|
1523 |
+
)
|
1524 |
+
|
1525 |
+
def create_client(self):
|
1526 |
+
        try:
            import openai
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Please install 'openai' extra to use OpenAIServerModel: `pip install 'smolagents[openai]'`"
            ) from e

        return openai.OpenAI(**self.client_kwargs)

    def generate_stream(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> Generator[ChatMessageStreamDelta]:
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_id,
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )
        for event in self.client.chat.completions.create(
            **completion_kwargs, stream=True, stream_options={"include_usage": True}
        ):
            if event.usage:
                self._last_input_token_count = event.usage.prompt_tokens
                self._last_output_token_count = event.usage.completion_tokens
                yield ChatMessageStreamDelta(
                    content="",
                    token_usage=TokenUsage(
                        input_tokens=event.usage.prompt_tokens,
                        output_tokens=event.usage.completion_tokens,
                    ),
                )
            if event.choices:
                choice = event.choices[0]
                if choice.delta:
                    yield ChatMessageStreamDelta(
                        content=choice.delta.content,
                        tool_calls=[
                            ChatMessageToolCallStreamDelta(
                                index=delta.index,
                                id=delta.id,
                                type=delta.type,
                                function=delta.function,
                            )
                            for delta in choice.delta.tool_calls
                        ]
                        if choice.delta.tool_calls
                        else None,
                    )
                else:
                    if not getattr(choice, "finish_reason", None):
                        raise ValueError(f"No content or tool calls in event: {event}")

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_id,
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )
        response = self.client.chat.completions.create(**completion_kwargs)

        # `response.usage` has been reported to be None in some cases when using OpenRouter: see GH-1401
        self._last_input_token_count = getattr(response.usage, "prompt_tokens", 0)
        self._last_output_token_count = getattr(response.usage, "completion_tokens", 0)
        return ChatMessage.from_dict(
            response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}),
            raw=response,
            token_usage=TokenUsage(
                # Reuse the getattr-guarded counts so a missing `usage` cannot raise here
                input_tokens=self._last_input_token_count,
                output_tokens=self._last_output_token_count,
            ),
        )


OpenAIModel = OpenAIServerModel
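
For orientation, here is a minimal usage sketch for `OpenAIServerModel` (not part of the uploaded file): the `api_key` constructor argument and the `ChatMessage`/`MessageRole` helpers are assumed from earlier in this module, and the key value is a hypothetical placeholder, so treat the exact signatures as illustrative rather than authoritative.

# Illustrative sketch, assuming the constructor arguments documented earlier in models.py.
from smolagents import ChatMessage, MessageRole, OpenAIServerModel

model = OpenAIServerModel(
    model_id="gpt-4o-mini",
    api_key="...",  # hypothetical placeholder credential
)
messages = [ChatMessage(role=MessageRole.USER, content="Say hello in one word.")]

# Blocking call: returns a single ChatMessage with token_usage attached
reply = model.generate(messages)
print(reply.content)

# Streaming call: yields ChatMessageStreamDelta chunks, plus a usage-only
# chunk (content="") whenever the server reports token counts
for delta in model.generate_stream(messages):
    if delta.content:
        print(delta.content, end="")
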

class AzureOpenAIServerModel(OpenAIServerModel):
    """This model connects to an Azure OpenAI deployment.

    Parameters:
        model_id (`str`):
            The model deployment name to use when connecting (e.g. "gpt-4o-mini").
        azure_endpoint (`str`, *optional*):
            The Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`. If not provided, it will be inferred from the `AZURE_OPENAI_ENDPOINT` environment variable.
        api_key (`str`, *optional*):
            The API key to use for authentication. If not provided, it will be inferred from the `AZURE_OPENAI_API_KEY` environment variable.
        api_version (`str`, *optional*):
            The API version to use. If not provided, it will be inferred from the `OPENAI_API_VERSION` environment variable.
        client_kwargs (`dict[str, Any]`, *optional*):
            Additional keyword arguments to pass to the AzureOpenAI client (like organization, project, max_retries, etc.).
        custom_role_conversions (`dict[str, str]`, *optional*):
            Custom role conversion mapping to convert message roles to other roles.
            Useful for specific models that do not support specific message roles like "system".
        **kwargs:
            Additional keyword arguments to pass to the Azure OpenAI API.
    """

    def __init__(
        self,
        model_id: str,
        azure_endpoint: str | None = None,
        api_key: str | None = None,
        api_version: str | None = None,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        **kwargs,
    ):
        client_kwargs = client_kwargs or {}
        client_kwargs.update(
            {
                "api_version": api_version,
                "azure_endpoint": azure_endpoint,
            }
        )
        super().__init__(
            model_id=model_id,
            api_key=api_key,
            client_kwargs=client_kwargs,
            custom_role_conversions=custom_role_conversions,
            **kwargs,
        )

    def create_client(self):
        try:
            import openai
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Please install 'openai' extra to use AzureOpenAIServerModel: `pip install 'smolagents[openai]'`"
            ) from e

        return openai.AzureOpenAI(**self.client_kwargs)


AzureOpenAIModel = AzureOpenAIServerModel
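
The Azure variant only swaps the client construction. A hedged wiring sketch (not part of the uploaded file; the endpoint, key, and API-version values are placeholders):

# Illustrative sketch: all credential values below are placeholders.
from smolagents import AzureOpenAIServerModel

model = AzureOpenAIServerModel(
    model_id="my-gpt4o-deployment",  # the Azure *deployment* name, per the docstring above
    azure_endpoint="https://example-resource.azure.openai.com/",
    api_key="...",  # or rely on the AZURE_OPENAI_API_KEY environment variable
    api_version="2024-06-01",  # or rely on the OPENAI_API_VERSION environment variable
)
# create_client() above builds openai.AzureOpenAI(**self.client_kwargs), so
# azure_endpoint and api_version flow through the client_kwargs set in __init__.
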
class AmazonBedrockServerModel(ApiModel):
    """
    A model class for interacting with Amazon Bedrock Server models through the Bedrock API.

    This class provides an interface to interact with various Bedrock language models,
    allowing for customized model inference, guardrail configuration, message handling,
    and other parameters allowed by the boto3 API.

    Parameters:
        model_id (`str`):
            The model identifier to use on Bedrock (e.g. "us.amazon.nova-pro-v1:0").
        client (`boto3.client`, *optional*):
            A custom boto3 client for AWS interactions. If not provided, a default client will be created.
        client_kwargs (dict[str, Any], *optional*):
            Keyword arguments used to configure the boto3 client if it needs to be created internally.
            Examples include `region_name`, `config`, or `endpoint_url`.
        custom_role_conversions (`dict[str, str]`, *optional*):
            Custom role conversion mapping to convert message roles to other roles.
            Useful for specific models that do not support specific message roles like "system".
            Defaults to converting all roles to the "user" role to enable using all the Bedrock models.
        flatten_messages_as_text (`bool`, default `False`):
            Whether to flatten messages as text.
        **kwargs:
            Additional keyword arguments passed directly to the underlying API calls.

    Example:
        Creating a model instance with default settings:
        >>> bedrock_model = AmazonBedrockServerModel(
        ...     model_id='us.amazon.nova-pro-v1:0'
        ... )

        Creating a model instance with a custom boto3 client:
        >>> import boto3
        >>> client = boto3.client('bedrock-runtime', region_name='us-west-2')
        >>> bedrock_model = AmazonBedrockServerModel(
        ...     model_id='us.amazon.nova-pro-v1:0',
        ...     client=client
        ... )

        Creating a model instance with client_kwargs for internal client creation:
        >>> bedrock_model = AmazonBedrockServerModel(
        ...     model_id='us.amazon.nova-pro-v1:0',
        ...     client_kwargs={'region_name': 'us-west-2', 'endpoint_url': 'https://custom-endpoint.com'}
        ... )

        Creating a model instance with inference and guardrail configurations:
        >>> additional_api_config = {
        ...     "inferenceConfig": {
        ...         "maxTokens": 3000
        ...     },
        ...     "guardrailConfig": {
        ...         "guardrailIdentifier": "identify1",
        ...         "guardrailVersion": 'v1'
        ...     },
        ... }
        >>> bedrock_model = AmazonBedrockServerModel(
        ...     model_id='anthropic.claude-3-haiku-20240307-v1:0',
        ...     **additional_api_config
        ... )
    """

    def __init__(
        self,
        model_id: str,
        client=None,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        **kwargs,
    ):
        self.client_kwargs = client_kwargs or {}

        # Bedrock only supports the `assistant` and `user` roles.
        # Many Bedrock models do not allow conversations to start with the `assistant` role, so the default converts every role to `user`.
        # This parameter is retained for future model implementations and extended support.
        custom_role_conversions = custom_role_conversions or {
            MessageRole.SYSTEM: MessageRole.USER,
            MessageRole.ASSISTANT: MessageRole.USER,
            MessageRole.TOOL_CALL: MessageRole.USER,
            MessageRole.TOOL_RESPONSE: MessageRole.USER,
        }

        super().__init__(
            model_id=model_id,
            custom_role_conversions=custom_role_conversions,
            flatten_messages_as_text=False,  # The Bedrock API doesn't support flattened messages: it requires a list of messages
            client=client,
            **kwargs,
        )

    def _prepare_completion_kwargs(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        tool_choice: str | dict[Any, Any] | None = None,
        **kwargs,
    ) -> dict:
        """
        Overrides the base method to handle Bedrock-specific configurations.

        This implementation adapts the completion keyword arguments to align with
        Bedrock's requirements, ensuring compatibility with its unique setup and
        constraints.
        """
        completion_kwargs = super()._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=None,  # Bedrock handles stop sequences via its inference config instead
            tools_to_call_from=tools_to_call_from,
            custom_role_conversions=custom_role_conversions,
            convert_images_to_image_urls=convert_images_to_image_urls,
            **kwargs,
        )

        # Not all models in Bedrock support `toolConfig`. Also, smolagents already includes the tool call in the prompt,
        # so adding `toolConfig` could cause conflicts. We remove it to avoid issues.
        completion_kwargs.pop("toolConfig", None)

        # The Bedrock API does not support the `type` key in requests.
        # This block of code modifies the object to meet Bedrock's requirements.
        for message in completion_kwargs.get("messages", []):
            for content in message.get("content", []):
                if "type" in content:
                    del content["type"]

        return {
            "modelId": self.model_id,
            **completion_kwargs,
        }

    def create_client(self):
        try:
            import boto3  # type: ignore
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Please install 'bedrock' extra to use AmazonBedrockServerModel: `pip install 'smolagents[bedrock]'`"
            ) from e

        return boto3.client("bedrock-runtime", **self.client_kwargs)

    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        if response_format is not None:
            raise ValueError("Amazon Bedrock does not support response_format")
        completion_kwargs: dict = self._prepare_completion_kwargs(
            messages=messages,
            tools_to_call_from=tools_to_call_from,
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )

        # self.client is created in the ApiModel class
        response = self.client.converse(**completion_kwargs)

        # Keep only the text of the first content block of the message
        response["output"]["message"]["content"] = response["output"]["message"]["content"][0]["text"]

        self._last_input_token_count = response["usage"]["inputTokens"]
        self._last_output_token_count = response["usage"]["outputTokens"]
        return ChatMessage.from_dict(
            response["output"]["message"],
            raw=response,
            token_usage=TokenUsage(
                input_tokens=response["usage"]["inputTokens"],
                output_tokens=response["usage"]["outputTokens"],
            ),
        )


AmazonBedrockModel = AmazonBedrockServerModel

__all__ = [
    "MessageRole",
    "tool_role_conversions",
    "get_clean_message_list",
    "Model",
    "MLXModel",
    "TransformersModel",
    "ApiModel",
    "InferenceClientModel",
    "LiteLLMModel",
    "LiteLLMRouterModel",
    "OpenAIServerModel",
    "OpenAIModel",
    "VLLMModel",
    "AzureOpenAIServerModel",
    "AzureOpenAIModel",
    "AmazonBedrockServerModel",
    "AmazonBedrockModel",
    "ChatMessage",
]
src/smolagents/monitoring.py
ADDED
@@ -0,0 +1,265 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass, field
from enum import IntEnum

from rich import box
from rich.console import Console, Group
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.table import Table
from rich.text import Text
from rich.tree import Tree

from smolagents.utils import escape_code_brackets


__all__ = ["AgentLogger", "LogLevel", "Monitor", "TokenUsage", "Timing"]


@dataclass
class TokenUsage:
    """
    Contains the token usage information for a given step or run.
    """

    input_tokens: int
    output_tokens: int
    total_tokens: int = field(init=False)

    def __post_init__(self):
        self.total_tokens = self.input_tokens + self.output_tokens

    def dict(self):
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "total_tokens": self.total_tokens,
        }


@dataclass
class Timing:
    """
    Contains the timing information for a given step or run.
    """

    start_time: float
    end_time: float | None = None

    @property
    def duration(self):
        return None if self.end_time is None else self.end_time - self.start_time

    def dict(self):
        return {
            "start_time": self.start_time,
            "end_time": self.end_time,
            "duration": self.duration,
        }

    def __repr__(self) -> str:
        return f"Timing(start_time={self.start_time}, end_time={self.end_time}, duration={self.duration})"

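These two dataclasses are the unit of accounting used by the monitor below. A small, self-contained sketch of how they behave (not part of the uploaded file):

# Minimal demonstration of the dataclasses defined above.
usage = TokenUsage(input_tokens=120, output_tokens=30)
assert usage.total_tokens == 150           # computed in __post_init__
assert usage.dict()["total_tokens"] == 150

timing = Timing(start_time=10.0)
assert timing.duration is None             # still running: no end_time yet
timing.end_time = 12.5
assert timing.duration == 2.5              # end_time - start_time
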
class Monitor:
    def __init__(self, tracked_model, logger):
        self.step_durations = []
        self.tracked_model = tracked_model
        self.logger = logger
        self.total_input_token_count = 0
        self.total_output_token_count = 0

    def get_total_token_counts(self) -> TokenUsage:
        return TokenUsage(
            input_tokens=self.total_input_token_count,
            output_tokens=self.total_output_token_count,
        )

    def reset(self):
        self.step_durations = []
        self.total_input_token_count = 0
        self.total_output_token_count = 0

    def update_metrics(self, step_log):
        """Update the metrics of the monitor.

        Args:
            step_log ([`MemoryStep`]): Step log to update the monitor with.
        """
        step_duration = step_log.timing.duration
        self.step_durations.append(step_duration)
        console_outputs = f"[Step {len(self.step_durations)}: Duration {step_duration:.2f} seconds"

        if step_log.token_usage is not None:
            self.total_input_token_count += step_log.token_usage.input_tokens
            self.total_output_token_count += step_log.token_usage.output_tokens
            console_outputs += (
                f" | Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
            )
        console_outputs += "]"
        self.logger.log(Text(console_outputs, style="dim"), level=1)


class LogLevel(IntEnum):
    OFF = -1  # No output
    ERROR = 0  # Only errors
    INFO = 1  # Normal output (default)
    DEBUG = 2  # Detailed output


YELLOW_HEX = "#d4b702"


class AgentLogger:
    def __init__(self, level: LogLevel = LogLevel.INFO, console: Console | None = None):
        self.level = level
        if console is None:
            self.console = Console()
        else:
            self.console = console

    def log(self, *args, level: int | str | LogLevel = LogLevel.INFO, **kwargs) -> None:
        """Logs a message to the console.

        Args:
            level (LogLevel, optional): Defaults to LogLevel.INFO.
        """
        if isinstance(level, str):
            level = LogLevel[level.upper()]
        if level <= self.level:
            self.console.print(*args, **kwargs)

    def log_error(self, error_message: str) -> None:
        self.log(escape_code_brackets(error_message), style="bold red", level=LogLevel.ERROR)

    def log_markdown(self, content: str, title: str | None = None, level=LogLevel.INFO, style=YELLOW_HEX) -> None:
        markdown_content = Syntax(
            content,
            lexer="markdown",
            theme="github-dark",
            word_wrap=True,
        )
        if title:
            self.log(
                Group(
                    Rule(
                        "[bold italic]" + title,
                        align="left",
                        style=style,
                    ),
                    markdown_content,
                ),
                level=level,
            )
        else:
            self.log(markdown_content, level=level)

    def log_code(self, title: str, content: str, level: int = LogLevel.INFO) -> None:
        self.log(
            Panel(
                Syntax(
                    content,
                    lexer="python",
                    theme="monokai",
                    word_wrap=True,
                ),
                title="[bold]" + title,
                title_align="left",
                box=box.HORIZONTALS,
            ),
            level=level,
        )

    def log_rule(self, title: str, level: int = LogLevel.INFO) -> None:
        self.log(
            Rule(
                "[bold]" + title,
                characters="━",
                style=YELLOW_HEX,
            ),
            level=level,  # pass through the requested level instead of hardcoding LogLevel.INFO
        )

    def log_task(self, content: str, subtitle: str, title: str | None = None, level: LogLevel = LogLevel.INFO) -> None:
        self.log(
            Panel(
                f"\n[bold]{escape_code_brackets(content)}\n",
                title="[bold]New run" + (f" - {title}" if title else ""),
                subtitle=subtitle,
                border_style=YELLOW_HEX,
                subtitle_align="left",
            ),
            level=level,
        )

    def log_messages(self, messages: list[dict], level: LogLevel = LogLevel.DEBUG) -> None:
        messages_as_string = "\n".join([json.dumps(dict(message), indent=4) for message in messages])
        self.log(
            Syntax(
                messages_as_string,
                lexer="markdown",
                theme="github-dark",
                word_wrap=True,
            ),
            level=level,
        )

    def visualize_agent_tree(self, agent):
        def create_tools_section(tools_dict):
            table = Table(show_header=True, header_style="bold")
            table.add_column("Name", style="#1E90FF")
            table.add_column("Description")
            table.add_column("Arguments")

            for name, tool in tools_dict.items():
                args = [
                    f"{arg_name} (`{info.get('type', 'Any')}`{', optional' if info.get('optional') else ''}): {info.get('description', '')}"
                    for arg_name, info in getattr(tool, "inputs", {}).items()
                ]
                table.add_row(name, getattr(tool, "description", str(tool)), "\n".join(args))

            return Group("🛠️ [italic #1E90FF]Tools:[/italic #1E90FF]", table)

        def get_agent_headline(agent, name: str | None = None):
            name_headline = f"{name} | " if name else ""
            return f"[bold {YELLOW_HEX}]{name_headline}{agent.__class__.__name__} | {agent.model.model_id}"

        def build_agent_tree(parent_tree, agent_obj):
            """Recursively builds the agent tree."""
            parent_tree.add(create_tools_section(agent_obj.tools))

            if agent_obj.managed_agents:
                agents_branch = parent_tree.add("🤖 [italic #1E90FF]Managed agents:")
                for name, managed_agent in agent_obj.managed_agents.items():
                    agent_tree = agents_branch.add(get_agent_headline(managed_agent, name))
                    if managed_agent.__class__.__name__ == "CodeAgent":
                        agent_tree.add(
                            f"✅ [italic #1E90FF]Authorized imports:[/italic #1E90FF] {managed_agent.additional_authorized_imports}"
                        )
                    agent_tree.add(f"📝 [italic #1E90FF]Description:[/italic #1E90FF] {managed_agent.description}")
                    build_agent_tree(agent_tree, managed_agent)

        main_tree = Tree(get_agent_headline(agent))
        if agent.__class__.__name__ == "CodeAgent":
            main_tree.add(
                f"✅ [italic #1E90FF]Authorized imports:[/italic #1E90FF] {agent.additional_authorized_imports}"
            )
        build_agent_tree(main_tree, agent)
        self.console.print(main_tree)
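
To see how the level filtering in `AgentLogger.log` plays out, a short sketch (not part of the uploaded file; the rendered output depends on the rich console):

# Minimal demonstration of AgentLogger level filtering.
logger = AgentLogger(level=LogLevel.INFO)

logger.log("visible: INFO <= INFO")                       # printed
logger.log("hidden: DEBUG > INFO", level=LogLevel.DEBUG)  # filtered out
logger.log("also visible", level="error")                 # strings are upper-cased into LogLevel members
logger.log_rule("New section")                            # yellow horizontal rule
logger.log_code("Generated code", "print('hi')")          # syntax-highlighted panel
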
src/smolagents/remote_executors.py
ADDED
@@ -0,0 +1,451 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import inspect
import json
import pickle
import time
from io import BytesIO
from pathlib import Path
from textwrap import dedent
from typing import Any

import PIL.Image
import requests

from .default_tools import FinalAnswerTool
from .local_python_executor import PythonExecutor
from .monitoring import LogLevel
from .tools import Tool, get_tools_definition_code
from .utils import AgentError


try:
    from dotenv import load_dotenv

    load_dotenv()
except ModuleNotFoundError:
    pass


class RemotePythonExecutor(PythonExecutor):
    FINAL_ANSWER_EXCEPTION = "FinalAnswerException"

    def __init__(self, additional_imports: list[str], logger):
        self.additional_imports = additional_imports
        self.logger = logger
        self.logger.log("Initializing executor, hold on...")
        self.installed_packages = []

    def run_code_raise_errors(self, code: str) -> tuple[Any, str, bool]:
        """
        Execute code, returning the result and output, and also determining whether
        the result is the final answer.
        """
        raise NotImplementedError

    def send_tools(self, tools: dict[str, Tool]):
        if "final_answer" in tools:
            self._patch_final_answer_with_exception(tools["final_answer"])
        # Install tool packages
        packages_to_install = {
            pkg
            for tool in tools.values()
            for pkg in tool.to_dict()["requirements"]
            if pkg not in self.installed_packages + ["smolagents"]
        }
        if packages_to_install:
            self.installed_packages += self.install_packages(list(packages_to_install))
        # Get tool definitions
        code = get_tools_definition_code(tools)
        if code:
            execution = self.run_code_raise_errors(code)
            self.logger.log(execution[1])

    def send_variables(self, variables: dict):
        """
        Send variables to the kernel namespace using pickle.
        """
        pickled_vars = base64.b64encode(pickle.dumps(variables)).decode()
        code = f"""
import pickle, base64
vars_dict = pickle.loads(base64.b64decode('{pickled_vars}'))
locals().update(vars_dict)
"""
        self.run_code_raise_errors(code)

    def __call__(self, code_action: str) -> tuple[Any, str, bool]:
        """Run the code and determine if it is the final answer."""
        return self.run_code_raise_errors(code_action)

    def install_packages(self, additional_imports: list[str]):
        if additional_imports:
            _, execution_logs, _ = self.run_code_raise_errors(f"!pip install {' '.join(additional_imports)}")
            self.logger.log(execution_logs)
        return additional_imports

    def _patch_final_answer_with_exception(self, final_answer_tool: FinalAnswerTool):
        """Patch the FinalAnswerTool to raise an exception.

        This is necessary because the remote executors
        rely on the FinalAnswerTool to detect the final answer.
        It modifies the `forward` method of the FinalAnswerTool to raise
        a `FinalAnswerException` carrying the final answer as a pickled value.
        This allows the executor to catch this exception and return the final answer.

        Args:
            final_answer_tool (`FinalAnswerTool`): FinalAnswerTool instance to patch.
        """

        # Create a new class that inherits from the original FinalAnswerTool
        class _FinalAnswerTool(final_answer_tool.__class__):
            pass

        # Add a new forward method that raises the FinalAnswerException
        # - Define the new forward method function
        def forward(self, *args, **kwargs) -> Any:
            import base64
            import pickle

            class FinalAnswerException(Exception):
                def __init__(self, value):
                    self.value = value

            raise FinalAnswerException(base64.b64encode(pickle.dumps(self._forward(*args, **kwargs))).decode())

        # - Set the new forward method function on the _FinalAnswerTool class
        _FinalAnswerTool.forward = forward

        # Rename the original forward method to _forward
        # - Get the original forward method function from the final_answer_tool instance
        original_forward_function = final_answer_tool.forward.__func__
        # - Set it as the _forward method on the _FinalAnswerTool class
        _FinalAnswerTool._forward = original_forward_function
        # - Update the source code of the new forward method to match the original but with the new name
        _FinalAnswerTool._forward.__source__ = inspect.getsource(original_forward_function).replace(
            "def forward(", "def _forward("
        )

        # Set the new class as the class of the final_answer_tool instance
        final_answer_tool.__class__ = _FinalAnswerTool

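The pickle-and-base64 round trip used by the patched `forward` above is worth seeing in isolation. This sketch (not part of the uploaded file) mirrors what the executors do when they catch the exception; the `encoded` value stands in for the error value that E2B or the Jupyter kernel reports:

# Stand-alone illustration of the final-answer transport used above.
import base64
import pickle

answer = {"result": 42}

# What the patched forward() raises inside the sandbox:
encoded = base64.b64encode(pickle.dumps(answer)).decode()

# What run_code_raise_errors() does when it sees an error whose name
# matches FINAL_ANSWER_EXCEPTION ("FinalAnswerException"):
decoded = pickle.loads(base64.b64decode(encoded))
assert decoded == answer
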
class E2BExecutor(RemotePythonExecutor):
    """
    Executes Python code using E2B.

    Args:
        additional_imports (`list[str]`): Additional imports to install.
        logger (`Logger`): Logger to use.
        **kwargs: Additional arguments to pass to the E2B Sandbox.
    """

    def __init__(self, additional_imports: list[str], logger, **kwargs):
        super().__init__(additional_imports, logger)
        try:
            from e2b_code_interpreter import Sandbox
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                """Please install 'e2b' extra to use E2BExecutor: `pip install 'smolagents[e2b]'`"""
            )
        self.sandbox = Sandbox(**kwargs)
        self.installed_packages = self.install_packages(additional_imports)
        self.logger.log("E2B is running", level=LogLevel.INFO)

    def run_code_raise_errors(self, code: str) -> tuple[Any, str, bool]:
        execution = self.sandbox.run_code(code)
        execution_logs = "\n".join([str(log) for log in execution.logs.stdout])

        # Handle errors
        if execution.error:
            # Check if the error is a FinalAnswerException
            if execution.error.name == RemotePythonExecutor.FINAL_ANSWER_EXCEPTION:
                final_answer = pickle.loads(base64.b64decode(execution.error.value))
                return final_answer, execution_logs, True

            # Construct error message
            error_message = (
                f"{execution_logs}\n"
                f"Executing code yielded an error:\n"
                f"{execution.error.name}\n"
                f"{execution.error.value}\n"
                f"{execution.error.traceback}"
            )
            raise AgentError(error_message, self.logger)

        # Handle results
        if not execution.results:
            return None, execution_logs, False

        for result in execution.results:
            if not result.is_main_result:
                continue
            # Handle image outputs
            for attribute_name in ["jpeg", "png"]:
                img_data = getattr(result, attribute_name, None)
                if img_data is not None:
                    decoded_bytes = base64.b64decode(img_data.encode("utf-8"))
                    return PIL.Image.open(BytesIO(decoded_bytes)), execution_logs, False
            # Handle other data formats
            for attribute_name in [
                "chart",
                "data",
                "html",
                "javascript",
                "json",
                "latex",
                "markdown",
                "pdf",
                "svg",
                "text",
            ]:
                data = getattr(result, attribute_name, None)
                if data is not None:
                    return data, execution_logs, False
        # If no main result found, return None
        return None, execution_logs, False

    def cleanup(self):
        """Clean up the E2B sandbox and resources."""
        try:
            if hasattr(self, "sandbox"):
                self.logger.log("Shutting down sandbox...", level=LogLevel.INFO)
                self.sandbox.kill()
                self.logger.log("Sandbox cleanup completed", level=LogLevel.INFO)
                del self.sandbox
        except Exception as e:
            self.logger.log_error(f"Error during cleanup: {e}")


class DockerExecutor(RemotePythonExecutor):
    """
    Executes Python code using Jupyter Kernel Gateway in a Docker container.
    """

    def __init__(
        self,
        additional_imports: list[str],
        logger,
        host: str = "127.0.0.1",
        port: int = 8888,
        image_name: str = "jupyter-kernel",
        build_new_image: bool = True,
        container_run_kwargs: dict[str, Any] | None = None,
    ):
        """
        Initialize the Docker-based Jupyter Kernel Gateway executor.

        Args:
            additional_imports: Additional imports to install.
            logger: Logger to use.
            host: Host to bind to.
            port: Port to bind to.
            image_name: Name of the Docker image to use. If the image doesn't exist, it will be built.
            build_new_image: If True, the image will be rebuilt even if it already exists.
            container_run_kwargs: Additional keyword arguments to pass to the Docker container run command.
        """
        super().__init__(additional_imports, logger)
        try:
            import docker
            from websocket import create_connection
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Please install 'docker' extra to use DockerExecutor: `pip install 'smolagents[docker]'`"
            )
        self.host = host
        self.port = port
        self.image_name = image_name

        # Initialize Docker
        try:
            self.client = docker.from_env()
        except docker.errors.DockerException as e:
            raise RuntimeError("Could not connect to Docker daemon: make sure Docker is running.") from e

        # Build and start container
        try:
            # Check if image exists, unless forced to rebuild
            if not build_new_image:
                try:
                    self.client.images.get(self.image_name)
                    self.logger.log(f"Using existing Docker image: {self.image_name}", level=LogLevel.INFO)
                except docker.errors.ImageNotFound:
                    self.logger.log(f"Image {self.image_name} not found, building...", level=LogLevel.INFO)
                    build_new_image = True

            if build_new_image:
                self.logger.log(f"Building Docker image {self.image_name}...", level=LogLevel.INFO)
                dockerfile_path = Path(__file__).parent / "Dockerfile"
                if not dockerfile_path.exists():
                    with open(dockerfile_path, "w") as f:
                        f.write(
                            dedent(
                                """\
                                FROM python:3.12-slim

                                RUN pip install jupyter_kernel_gateway jupyter_client

                                EXPOSE 8888
                                CMD ["jupyter", "kernelgateway", "--KernelGatewayApp.ip='0.0.0.0'", "--KernelGatewayApp.port=8888", "--KernelGatewayApp.allow_origin='*'"]
                                """
                            )
                        )
                _, build_logs = self.client.images.build(
                    path=str(dockerfile_path.parent), dockerfile=str(dockerfile_path), tag=self.image_name
                )
                for log_chunk in build_logs:
                    # Only log non-empty messages
                    if log_message := log_chunk.get("stream", "").rstrip():
                        self.logger.log(log_message, level=LogLevel.DEBUG)

            self.logger.log(f"Starting container on {host}:{port}...", level=LogLevel.INFO)
            # Create base container parameters
            container_kwargs = {}
            if container_run_kwargs:
                container_kwargs.update(container_run_kwargs)

            # Ensure required port mapping and background running
            if not isinstance(container_kwargs.get("ports"), dict):
                container_kwargs["ports"] = {}
            container_kwargs["ports"]["8888/tcp"] = (host, port)
            container_kwargs["detach"] = True

            self.container = self.client.containers.run(self.image_name, **container_kwargs)

            retries = 0
            while self.container.status != "running" and retries < 5:
                self.logger.log(f"Container status: {self.container.status}, waiting...", level=LogLevel.INFO)
                time.sleep(1)
                self.container.reload()
                retries += 1

            self.base_url = f"http://{host}:{port}"

            # Create new kernel via HTTP
            r = requests.post(f"{self.base_url}/api/kernels")
            if r.status_code != 201:
                error_details = {
                    "status_code": r.status_code,
                    "headers": dict(r.headers),
                    "url": r.url,
                    "body": r.text,
                    "request_method": r.request.method,
                    "request_headers": dict(r.request.headers),
                    "request_body": r.request.body,
                }
                self.logger.log_error(f"Failed to create kernel. Details: {json.dumps(error_details, indent=2)}")
                raise RuntimeError(f"Failed to create kernel: Status {r.status_code}\nResponse: {r.text}") from None

            self.kernel_id = r.json()["id"]

            ws_url = f"ws://{host}:{port}/api/kernels/{self.kernel_id}/channels"
            self.ws = create_connection(ws_url)

            self.installed_packages = self.install_packages(additional_imports)
            self.logger.log(
                f"Container {self.container.short_id} is running with kernel {self.kernel_id}", level=LogLevel.INFO
            )

        except Exception as e:
            self.cleanup()
            raise RuntimeError(f"Failed to initialize Jupyter kernel: {e}") from e

    def run_code_raise_errors(self, code_action: str) -> tuple[Any, str, bool]:
        try:
            # Send execute request
            msg_id = self._send_execute_request(code_action)

            # Collect output and results
            outputs = []
            result = None
            is_final_answer = False

            while True:
                msg = json.loads(self.ws.recv())
                parent_msg_id = msg.get("parent_header", {}).get("msg_id")
                # Skip unrelated messages
                if parent_msg_id != msg_id:
                    continue
                msg_type = msg.get("msg_type", "")
                msg_content = msg.get("content", {})
                if msg_type == "stream":
                    outputs.append(msg_content["text"])
                elif msg_type == "execute_result":
                    result = msg_content["data"].get("text/plain", None)
                elif msg_type == "error":
                    if msg_content.get("ename", "") == RemotePythonExecutor.FINAL_ANSWER_EXCEPTION:
                        result = pickle.loads(base64.b64decode(msg_content.get("evalue", "")))
                        is_final_answer = True
                    else:
                        raise AgentError("\n".join(msg_content.get("traceback", [])), self.logger)
                elif msg_type == "status" and msg_content["execution_state"] == "idle":
                    break

            return result, "".join(outputs), is_final_answer

        except Exception as e:
            self.logger.log_error(f"Code execution failed: {e}")
            raise

    def _send_execute_request(self, code: str) -> str:
        """Send code execution request to kernel."""
        import uuid

        # Generate a unique message ID
        msg_id = str(uuid.uuid4())

        # Create execute request
        execute_request = {
            "header": {
                "msg_id": msg_id,
                "username": "anonymous",
                "session": str(uuid.uuid4()),
                "msg_type": "execute_request",
                "version": "5.0",
            },
            "parent_header": {},
            "metadata": {},
            "content": {
                "code": code,
                "silent": False,
                "store_history": True,
                "user_expressions": {},
                "allow_stdin": False,
            },
        }

        self.ws.send(json.dumps(execute_request))
        return msg_id

    def cleanup(self):
        """Clean up the Docker container and resources."""
        try:
            if hasattr(self, "container"):
                self.logger.log(f"Stopping and removing container {self.container.short_id}...", level=LogLevel.INFO)
                self.container.stop()
                self.container.remove()
                self.logger.log("Container cleanup completed", level=LogLevel.INFO)
                del self.container
        except Exception as e:
            self.logger.log_error(f"Error during cleanup: {e}")

    def delete(self):
        """Ensure cleanup on deletion."""
        self.cleanup()


__all__ = ["E2BExecutor", "DockerExecutor"]
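
Putting the executors to use, a hedged sketch (not part of the uploaded file): it requires Docker running locally, and the `AgentLogger` comes from monitoring.py above. The exact text of `result` depends on the kernel's text/plain repr.

# Illustrative sketch of driving a DockerExecutor directly.
from smolagents.monitoring import AgentLogger, LogLevel

logger = AgentLogger(level=LogLevel.INFO)
executor = DockerExecutor(additional_imports=["numpy"], logger=logger, port=8888)
try:
    result, logs, is_final_answer = executor.run_code_raise_errors(
        "import numpy as np\nnp.arange(3).sum()"
    )
    print(result, is_final_answer)  # the kernel's text/plain repr of the last expression, and False
finally:
    executor.cleanup()
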
src/smolagents/tool_validation.py
ADDED
@@ -0,0 +1,266 @@
import ast
import builtins
from itertools import zip_longest

from .utils import BASE_BUILTIN_MODULES, get_source, is_valid_name


_BUILTIN_NAMES = set(vars(builtins))


class MethodChecker(ast.NodeVisitor):
    """
    Checks that a method
    - only uses defined names
    - contains no local imports (e.g. numpy is ok but local_script is not)
    """

    def __init__(self, class_attributes: set[str], check_imports: bool = True):
        self.undefined_names = set()
        self.imports = {}
        self.from_imports = {}
        self.assigned_names = set()
        self.arg_names = set()
        self.class_attributes = class_attributes
        self.errors = []
        self.check_imports = check_imports
        self.typing_names = {"Any"}
        self.defined_classes = set()

    def visit_arguments(self, node):
        """Collect function arguments"""
        self.arg_names = {arg.arg for arg in node.args}
        if node.kwarg:
            self.arg_names.add(node.kwarg.arg)
        if node.vararg:
            self.arg_names.add(node.vararg.arg)

    def visit_Import(self, node):
        for name in node.names:
            actual_name = name.asname or name.name
            self.imports[actual_name] = name.name

    def visit_ImportFrom(self, node):
        module = node.module or ""
        for name in node.names:
            actual_name = name.asname or name.name
            self.from_imports[actual_name] = (module, name.name)

    def visit_Assign(self, node):
        for target in node.targets:
            if isinstance(target, ast.Name):
                self.assigned_names.add(target.id)
            elif isinstance(target, (ast.Tuple, ast.List)):
                for elt in target.elts:
                    if isinstance(elt, ast.Name):
                        self.assigned_names.add(elt.id)
        self.visit(node.value)

    def visit_With(self, node):
        """Track aliases in 'with' statements (the 'y' in 'with X as y')"""
        for item in node.items:
            if item.optional_vars:  # This is the 'y' in 'with X as y'
                if isinstance(item.optional_vars, ast.Name):
                    self.assigned_names.add(item.optional_vars.id)
        self.generic_visit(node)

    def visit_ExceptHandler(self, node):
        """Track exception aliases (the 'e' in 'except Exception as e')"""
        if node.name:  # This is the 'e' in 'except Exception as e'
            self.assigned_names.add(node.name)
        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        """Track annotated assignments."""
        if isinstance(node.target, ast.Name):
            self.assigned_names.add(node.target.id)
        if node.value:
            self.visit(node.value)

    def visit_For(self, node):
        target = node.target
        if isinstance(target, ast.Name):
            self.assigned_names.add(target.id)
        elif isinstance(target, ast.Tuple):
            for elt in target.elts:
                if isinstance(elt, ast.Name):
                    self.assigned_names.add(elt.id)
        self.generic_visit(node)

    def _handle_comprehension_generators(self, generators):
        """Helper method to handle generators in all types of comprehensions"""
        for generator in generators:
            if isinstance(generator.target, ast.Name):
                self.assigned_names.add(generator.target.id)
            elif isinstance(generator.target, ast.Tuple):
                for elt in generator.target.elts:
                    if isinstance(elt, ast.Name):
                        self.assigned_names.add(elt.id)

    def visit_ListComp(self, node):
        """Track variables in list comprehensions"""
        self._handle_comprehension_generators(node.generators)
        self.generic_visit(node)

    def visit_DictComp(self, node):
        """Track variables in dictionary comprehensions"""
        self._handle_comprehension_generators(node.generators)
        self.generic_visit(node)

    def visit_SetComp(self, node):
        """Track variables in set comprehensions"""
        self._handle_comprehension_generators(node.generators)
        self.generic_visit(node)

    def visit_Attribute(self, node):
        if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
            self.generic_visit(node)

    def visit_ClassDef(self, node):
        """Track class definitions"""
        self.defined_classes.add(node.name)
        self.generic_visit(node)

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Load):
            if not (
                node.id in _BUILTIN_NAMES
                or node.id in BASE_BUILTIN_MODULES
                or node.id in self.arg_names
                or node.id == "self"
                or node.id in self.class_attributes
                or node.id in self.imports
                or node.id in self.from_imports
                or node.id in self.assigned_names
                or node.id in self.typing_names
                or node.id in self.defined_classes
            ):
                self.errors.append(f"Name '{node.id}' is undefined.")

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            if not (
                node.func.id in _BUILTIN_NAMES
                or node.func.id in BASE_BUILTIN_MODULES
                or node.func.id in self.arg_names
                or node.func.id == "self"
                or node.func.id in self.class_attributes
                or node.func.id in self.imports
                or node.func.id in self.from_imports
                or node.func.id in self.assigned_names
                or node.func.id in self.defined_classes
            ):
                self.errors.append(f"Name '{node.func.id}' is undefined.")
        self.generic_visit(node)


def validate_tool_attributes(cls, check_imports: bool = True) -> None:
    """
    Validates that a Tool class follows the proper patterns:
    0. Any argument of __init__ should have a default.
       Args chosen at init are not traceable, so we cannot rebuild the source code for them; thus any important arg should be defined as a class attribute.
    1. About the class:
        - Class attributes should only be strings or dicts
        - Class attributes cannot be complex attributes
    2. About all class methods:
        - Imports must be from packages, not local files
        - All methods must be self-contained

    Raises an exception listing all errors encountered; returns None if validation passes.
    """

    class ClassLevelChecker(ast.NodeVisitor):
        def __init__(self):
            self.imported_names = set()
            self.complex_attributes = set()
            self.class_attributes = set()
            self.non_defaults = set()
            self.non_literal_defaults = set()
            self.in_method = False
            self.invalid_attributes = []

        def visit_FunctionDef(self, node):
            if node.name == "__init__":
                self._check_init_function_parameters(node)
            old_context = self.in_method
            self.in_method = True
            self.generic_visit(node)
            self.in_method = old_context

        def visit_Assign(self, node):
            if self.in_method:
                return
            # Track class attributes
            for target in node.targets:
                if isinstance(target, ast.Name):
                    self.class_attributes.add(target.id)

            # Check if the assignment is more complex than simple literals
            if not all(
                isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set))
                for val in ast.walk(node.value)
            ):
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        self.complex_attributes.add(target.id)

            # Check specific class attributes
            if getattr(node.targets[0], "id", "") == "name":
                if not isinstance(node.value, ast.Constant):
                    self.invalid_attributes.append(f"Class attribute 'name' must be a constant, found '{node.value}'")
                elif not isinstance(node.value.value, str):
                    self.invalid_attributes.append(
                        f"Class attribute 'name' must be a string, found '{node.value.value}'"
                    )
                elif not is_valid_name(node.value.value):
                    self.invalid_attributes.append(
                        f"Class attribute 'name' must be a valid Python identifier and not a reserved keyword, found '{node.value.value}'"
                    )

        def _check_init_function_parameters(self, node):
            # Check defaults in parameters
            for arg, default in reversed(list(zip_longest(reversed(node.args.args), reversed(node.args.defaults)))):
                if default is None:
                    if arg.arg != "self":
                        self.non_defaults.add(arg.arg)
                elif not isinstance(default, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)):
                    self.non_literal_defaults.add(arg.arg)

    class_level_checker = ClassLevelChecker()
    source = get_source(cls)
    tree = ast.parse(source)
    class_node = tree.body[0]
    if not isinstance(class_node, ast.ClassDef):
        raise ValueError("Source code must define a class")
    class_level_checker.visit(class_node)

    errors = []
    # Check invalid class attributes
    if class_level_checker.invalid_attributes:
        errors += class_level_checker.invalid_attributes
    if class_level_checker.complex_attributes:
        errors.append(
            f"Complex attributes should be defined in __init__, not as class attributes: "
            f"{', '.join(class_level_checker.complex_attributes)}"
        )
    if class_level_checker.non_defaults:
        errors.append(
            f"Parameters in __init__ must have default values, found required parameters: "
            f"{', '.join(class_level_checker.non_defaults)}"
        )
    if class_level_checker.non_literal_defaults:
        errors.append(
            f"Parameters in __init__ must have literal default values, found non-literal defaults: "
            f"{', '.join(class_level_checker.non_literal_defaults)}"
        )

    # Run checks on all methods
    for node in class_node.body:
        if isinstance(node, ast.FunctionDef):
            method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
            method_checker.visit(node)
            errors += [f"- {node.name}: {error}" for error in method_checker.errors]

    if errors:
        raise ValueError(f"Tool validation failed for {cls.__name__}:\n" + "\n".join(errors))
    return
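
To make the rules above concrete, a sketch of one class that passes validation and one that fails (not part of the uploaded file; both classes are hypothetical, and since validation reads class source via `get_source`, they must live in an importable file):

# Illustrative sketch of validate_tool_attributes on hypothetical Tool-like classes.
class GoodTool:
    name = "good_tool"  # constant string, valid identifier
    description = "Adds one."

    def __init__(self, precision: int = 2):  # every arg has a literal default
        self.precision = precision

    def forward(self, x: float) -> float:
        return round(x + 1, self.precision)  # only builtins and self attributes


class BadTool:
    name = "bad tool"  # not a valid Python identifier -> reported

    def __init__(self, client):  # required arg with no default -> reported
        self.client = client


validate_tool_attributes(GoodTool)  # passes: returns None
try:
    validate_tool_attributes(BadTool)
except ValueError as e:
    print(e)  # "Tool validation failed for BadTool:" followed by the itemized errors
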
src/smolagents/tools.py
ADDED
@@ -0,0 +1,1239 @@
1 + #!/usr/bin/env python
2 + # coding=utf-8
3 +
4 + # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5 + #
6 + # Licensed under the Apache License, Version 2.0 (the "License");
7 + # you may not use this file except in compliance with the License.
8 + # You may obtain a copy of the License at
9 + #
10 + #     http://www.apache.org/licenses/LICENSE-2.0
11 + #
12 + # Unless required by applicable law or agreed to in writing, software
13 + # distributed under the License is distributed on an "AS IS" BASIS,
14 + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 + # See the License for the specific language governing permissions and
16 + # limitations under the License.
17 + from __future__ import annotations
18 +
19 + import ast
20 + import inspect
21 + import json
22 + import logging
23 + import os
24 + import sys
25 + import tempfile
26 + import textwrap
27 + import types
28 + import warnings
29 + from collections.abc import Callable
30 + from contextlib import contextmanager
31 + from functools import wraps
32 + from pathlib import Path
33 + from typing import TYPE_CHECKING, Any
34 +
35 + from huggingface_hub import (
36 +     CommitOperationAdd,
37 +     create_commit,
38 +     create_repo,
39 +     get_collection,
40 +     hf_hub_download,
41 +     metadata_update,
42 + )
43 +
44 + from ._function_type_hints_utils import (
45 +     TypeHintParsingException,
46 +     _convert_type_hints_to_json_schema,
47 +     _get_json_schema_type,
48 +     get_imports,
49 +     get_json_schema,
50 + )
51 + from .agent_types import handle_agent_input_types, handle_agent_output_types
52 + from .tool_validation import MethodChecker, validate_tool_attributes
53 + from .utils import (
54 +     BASE_BUILTIN_MODULES,
55 +     _is_package_available,
56 +     get_source,
57 +     instance_to_source,
58 +     is_valid_name,
59 + )
60 +
61 +
62 + if TYPE_CHECKING:
63 +     import mcp
64 +
65 +
66 + logger = logging.getLogger(__name__)
67 +
68 +
69 + def validate_after_init(cls):
70 +     original_init = cls.__init__
71 +
72 +     @wraps(original_init)
73 +     def new_init(self, *args, **kwargs):
74 +         original_init(self, *args, **kwargs)
75 +         self.validate_arguments()
76 +
77 +     cls.__init__ = new_init
78 +     return cls
79 +
80 +
81 + AUTHORIZED_TYPES = [
82 +     "string",
83 +     "boolean",
84 +     "integer",
85 +     "number",
86 +     "image",
87 +     "audio",
88 +     "array",
89 +     "object",
90 +     "any",
91 +     "null",
92 + ]
93 +
94 + CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
95 +
96 +
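`validate_after_init` patches a subclass's `__init__` so that `validate_arguments()` runs as soon as an instance is built; `Tool.__init_subclass__` below applies it automatically. A minimal sketch of the effect in isolation (hypothetical class, reusing the decorator defined above):

```py
@validate_after_init
class Greeter:
    def __init__(self, name):
        self.name = name

    def validate_arguments(self):
        # called automatically right after __init__ returns
        if not isinstance(self.name, str):
            raise TypeError("name must be a string")

Greeter("Ada")   # constructs fine
Greeter(42)      # raises TypeError from validate_arguments
```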
97 + class Tool:
98 +     """
99 +     A base class for the functions used by the agent. Subclass this and implement the `forward` method as well as the
100 +     following class attributes:
101 +
102 +     - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
103 +       will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
104 +       returns the text contained in the file'.
105 +     - **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
106 +       `"text-classifier"` or `"image_generator"`.
107 +     - **inputs** (`Dict[str, Dict[str, Union[str, type, bool]]]`) -- The dict of modalities expected for the inputs.
108 +       It has one `type` key and one `description` key.
109 +       This is used by `launch_gradio_demo` or to make a nice space from your tool, and also can be used in the generated
110 +       description for your tool.
111 +     - **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
112 +       or to make a nice space from your tool, and also can be used in the generated description for your tool.
113 +
114 +     You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
115 +     usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
116 +     instantiation.
117 +     """
118 +
119 +     name: str
120 +     description: str
121 +     inputs: dict[str, dict[str, str | type | bool]]
122 +     output_type: str
123 +
124 +     def __init__(self, *args, **kwargs):
125 +         self.is_initialized = False
126 +
127 +     def __init_subclass__(cls, **kwargs):
128 +         super().__init_subclass__(**kwargs)
129 +         validate_after_init(cls)
130 +
131 +     def validate_arguments(self):
132 +         required_attributes = {
133 +             "description": str,
134 +             "name": str,
135 +             "inputs": dict,
136 +             "output_type": str,
137 +         }
138 +         # Validate class attributes
139 +         for attr, expected_type in required_attributes.items():
140 +             attr_value = getattr(self, attr, None)
141 +             if attr_value is None:
142 +                 raise TypeError(f"You must set an attribute {attr}.")
143 +             if not isinstance(attr_value, expected_type):
144 +                 raise TypeError(
145 +                     f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
146 +                 )
147 +         # - Validate name
148 +         if not is_valid_name(self.name):
149 +             raise Exception(
150 +                 f"Invalid Tool name '{self.name}': must be a valid Python identifier and not a reserved keyword"
151 +             )
152 +         # Validate inputs
153 +         for input_name, input_content in self.inputs.items():
154 +             assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
155 +             assert "type" in input_content and "description" in input_content, (
156 +                 f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
157 +             )
158 +             if input_content["type"] not in AUTHORIZED_TYPES:
159 +                 raise Exception(
160 +                     f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}."
161 +                 )
162 +         # Validate output type
163 +         assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
164 +
165 +         # Validate forward function signature, except for Tools that use a "generic" signature (PipelineTool, SpaceToolWrapper, LangChainToolWrapper)
166 +         if not (
167 +             hasattr(self, "skip_forward_signature_validation")
168 +             and getattr(self, "skip_forward_signature_validation") is True
169 +         ):
170 +             signature = inspect.signature(self.forward)
171 +             actual_keys = set(key for key in signature.parameters.keys() if key != "self")
172 +             expected_keys = set(self.inputs.keys())
173 +             if actual_keys != expected_keys:
174 +                 raise Exception(
175 +                     f"In tool '{self.name}', 'forward' method parameters were {actual_keys}, but expected {expected_keys}. "
176 +                     f"It should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
177 +                 )
178 +
179 +             json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[
180 +                 "properties"
181 +             ]  # This function will not raise an error on missing docstrings, contrary to get_json_schema
182 +             for key, value in self.inputs.items():
183 +                 assert key in json_schema, (
184 +                     f"Input '{key}' should be present in function signature, found only {json_schema.keys()}"
185 +                 )
186 +                 if "nullable" in value:
187 +                     assert "nullable" in json_schema[key], (
188 +                         f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
189 +                     )
190 +                 if key in json_schema and "nullable" in json_schema[key]:
191 +                     assert "nullable" in value, (
192 +                         f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
193 +                     )
194 +
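Per the class docstring and the checks in `validate_arguments`, a conforming subclass needs matching `inputs` keys and `forward` parameters. A minimal sketch (hypothetical tool, not part of this commit):

```py
from smolagents import Tool

class CharacterCountTool(Tool):
    name = "character_count"
    description = "Counts the characters in a text and returns the count as a string."
    inputs = {"text": {"type": "string", "description": "The text to measure"}}
    output_type = "string"

    def forward(self, text: str) -> str:
        # parameter name must match the 'inputs' key, or validation fails
        return str(len(text))

tool = CharacterCountTool()   # validate_arguments() runs automatically here
print(tool(text="hello"))     # -> "5"
```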
195 +     def forward(self, *args, **kwargs):
196 +         return NotImplementedError("Write this method in your subclass of `Tool`.")
197 +
198 +     def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
199 +         if not self.is_initialized:
200 +             self.setup()
201 +
202 +         # Handle the case where the arguments might be passed as a single dictionary
203 +         if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict):
204 +             potential_kwargs = args[0]
205 +
206 +             # If the dictionary keys match our input parameters, convert it to kwargs
207 +             if all(key in self.inputs for key in potential_kwargs):
208 +                 args = ()
209 +                 kwargs = potential_kwargs
210 +
211 +         if sanitize_inputs_outputs:
212 +             args, kwargs = handle_agent_input_types(*args, **kwargs)
213 +         outputs = self.forward(*args, **kwargs)
214 +         if sanitize_inputs_outputs:
215 +             outputs = handle_agent_output_types(outputs, self.output_type)
216 +         return outputs
217 +
218 +     def setup(self):
219 +         """
220 +         Overwrite this method here for any operation that is expensive and needs to be executed before you start using
221 +         your tool. Such as loading a big model.
222 +         """
223 +         self.is_initialized = True
224 +
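`__call__` therefore accepts either keyword arguments or a single dict whose keys match `inputs`. A quick sketch using the hypothetical tool above:

```py
tool = CharacterCountTool()

tool(text="hello")       # plain kwargs
tool({"text": "hello"})  # a single dict is unpacked into kwargs
tool("hello")            # positional args are forwarded as-is to forward()

# sanitize_inputs_outputs=True additionally converts agent types
# (e.g. AgentImage/AgentAudio) on the way in and out
tool({"text": "hello"}, sanitize_inputs_outputs=True)
```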
225 +     def to_dict(self) -> dict:
226 +         """Returns a dictionary representing the tool"""
227 +         class_name = self.__class__.__name__
228 +         if type(self).__name__ == "SimpleTool":
229 +             # Check that imports are self-contained
230 +             source_code = get_source(self.forward).replace("@tool", "")
231 +             forward_node = ast.parse(source_code)
232 +             # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code
233 +             method_checker = MethodChecker(set())
234 +             method_checker.visit(forward_node)
235 +
236 +             if len(method_checker.errors) > 0:
237 +                 errors = [f"- {error}" for error in method_checker.errors]
238 +                 raise (ValueError(f"SimpleTool validation failed for {self.name}:\n" + "\n".join(errors)))
239 +
240 +             forward_source_code = get_source(self.forward)
241 +             tool_code = textwrap.dedent(
242 +                 f"""
243 +             from smolagents import Tool
244 +             from typing import Any, Optional
245 +
246 +             class {class_name}(Tool):
247 +                 name = "{self.name}"
248 +                 description = {json.dumps(textwrap.dedent(self.description).strip())}
249 +                 inputs = {repr(self.inputs)}
250 +                 output_type = "{self.output_type}"
251 +             """
252 +             ).strip()
253 +             import re
254 +
255 +             def add_self_argument(source_code: str) -> str:
256 +                 """Add 'self' as first argument to a function definition if not present."""
257 +                 pattern = r"def forward\(((?!self)[^)]*)\)"
258 +
259 +                 def replacement(match):
260 +                     args = match.group(1).strip()
261 +                     if args:  # If there are other arguments
262 +                         return f"def forward(self, {args})"
263 +                     return "def forward(self)"
264 +
265 +                 return re.sub(pattern, replacement, source_code)
266 +
267 +             forward_source_code = forward_source_code.replace(self.name, "forward")
268 +             forward_source_code = add_self_argument(forward_source_code)
269 +             forward_source_code = forward_source_code.replace("@tool", "").strip()
270 +             tool_code += "\n\n" + textwrap.indent(forward_source_code, "    ")
271 +
272 +         else:  # If the tool was not created by the @tool decorator, it was made by subclassing Tool
273 +             if type(self).__name__ in [
274 +                 "SpaceToolWrapper",
275 +                 "LangChainToolWrapper",
276 +                 "GradioToolWrapper",
277 +             ]:
278 +                 raise ValueError(
279 +                     "Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors."
280 +                 )
281 +
282 +             validate_tool_attributes(self.__class__)
283 +
284 +             tool_code = "from typing import Any, Optional\n" + instance_to_source(self, base_cls=Tool)
285 +
286 +         requirements = {el for el in get_imports(tool_code) if el not in sys.stdlib_module_names} | {"smolagents"}
287 +
288 +         return {"name": self.name, "code": tool_code, "requirements": sorted(requirements)}
289 +
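So for the hypothetical tool above, `to_dict()` returns something of this shape (illustrative values, not verbatim output):

```py
tool.to_dict()
# {
#   "name": "character_count",
#   "code": "from typing import Any, Optional\n...class CharacterCountTool(Tool):...",
#   "requirements": ["smolagents"],
# }
```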
290 +     @classmethod
291 +     def from_dict(cls, tool_dict: dict[str, Any], **kwargs) -> "Tool":
292 +         """
293 +         Create tool from a dictionary representation.
294 +
295 +         Args:
296 +             tool_dict (`dict[str, Any]`): Dictionary representation of the tool.
297 +             **kwargs: Additional keyword arguments to pass to the tool's constructor.
298 +
299 +         Returns:
300 +             `Tool`: Tool object.
301 +         """
302 +         if "code" not in tool_dict:
303 +             raise ValueError("Tool dictionary must contain 'code' key with the tool source code")
304 +         return cls.from_code(tool_dict["code"], **kwargs)
305 +
306 +     def save(self, output_dir: str | Path, tool_file_name: str = "tool", make_gradio_app: bool = True):
307 +         """
308 +         Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
309 +         tool in `output_dir` as well as autogenerate:
310 +
311 +         - a `{tool_file_name}.py` file containing the logic for your tool.
312 +         If you pass `make_gradio_app=True`, this will also write:
313 +         - an `app.py` file providing a UI for your tool when it is exported to a Space with `tool.push_to_hub()`
314 +         - a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its
315 +           code)
316 +
317 +         Args:
318 +             output_dir (`str` or `Path`): The folder in which you want to save your tool.
319 +             tool_file_name (`str`, *optional*): The file name in which you want to save your tool.
320 +             make_gradio_app (`bool`, *optional*, defaults to True): Whether to also export a `requirements.txt` file and Gradio UI.
321 +         """
322 +         # Ensure output directory exists
323 +         output_path = Path(output_dir)
324 +         output_path.mkdir(parents=True, exist_ok=True)
325 +         # Save tool file
326 +         self._write_file(output_path / f"{tool_file_name}.py", self._get_tool_code())
327 +         if make_gradio_app:
328 +             # Save app file
329 +             self._write_file(output_path / "app.py", self._get_gradio_app_code(tool_module_name=tool_file_name))
330 +             # Save requirements file
331 +             self._write_file(output_path / "requirements.txt", self._get_requirements())
332 +
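A sketch of the resulting layout when saving the hypothetical tool above (directory name is arbitrary):

```py
tool.save("my_tool_dir")
# my_tool_dir/
# ├── tool.py            # regenerated source of the tool class
# ├── app.py             # launch_gradio_demo wrapper (make_gradio_app=True)
# └── requirements.txt   # detected imports, e.g. "smolagents"
```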
333 +     def _write_file(self, file_path: Path, content: str) -> None:
334 +         """Writes content to a file with UTF-8 encoding."""
335 +         file_path.write_text(content, encoding="utf-8")
336 +
337 +     def push_to_hub(
338 +         self,
339 +         repo_id: str,
340 +         commit_message: str = "Upload tool",
341 +         private: bool | None = None,
342 +         token: bool | str | None = None,
343 +         create_pr: bool = False,
344 +     ) -> str:
345 +         """
346 +         Upload the tool to the Hub.
347 +
348 +         Parameters:
349 +             repo_id (`str`):
350 +                 The name of the repository you want to push your tool to. It should contain your organization name when
351 +                 pushing to a given organization.
352 +             commit_message (`str`, *optional*, defaults to `"Upload tool"`):
353 +                 Message to commit while pushing.
354 +             private (`bool`, *optional*):
355 +                 Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
356 +             token (`bool` or `str`, *optional*):
357 +                 The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
358 +                 when running `huggingface-cli login` (stored in `~/.huggingface`).
359 +             create_pr (`bool`, *optional*, defaults to `False`):
360 +                 Whether to create a PR with the uploaded files or directly commit.
361 +         """
362 +         # Initialize repository
363 +         repo_id = self._initialize_hub_repo(repo_id, token, private)
364 +         # Prepare files for commit
365 +         additions = self._prepare_hub_files()
366 +         # Create commit
367 +         return create_commit(
368 +             repo_id=repo_id,
369 +             operations=additions,
370 +             commit_message=commit_message,
371 +             token=token,
372 +             create_pr=create_pr,
373 +             repo_type="space",
374 +         )
375 +
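Typical call, assuming a logged-in user and a hypothetical repo name:

```py
tool.push_to_hub("your-username/character-count-tool", create_pr=False)
# creates (or reuses) a Gradio Space and commits tool.py, app.py, requirements.txt
```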
376 +     @staticmethod
377 +     def _initialize_hub_repo(repo_id: str, token: bool | str | None, private: bool | None) -> str:
378 +         """Initialize repository on Hugging Face Hub."""
379 +         repo_url = create_repo(
380 +             repo_id=repo_id,
381 +             token=token,
382 +             private=private,
383 +             exist_ok=True,
384 +             repo_type="space",
385 +             space_sdk="gradio",
386 +         )
387 +         metadata_update(repo_url.repo_id, {"tags": ["smolagents", "tool"]}, repo_type="space", token=token)
388 +         return repo_url.repo_id
389 +
390 +     def _prepare_hub_files(self) -> list:
391 +         """Prepare files for Hub commit."""
392 +         additions = [
393 +             # Add tool code
394 +             CommitOperationAdd(
395 +                 path_in_repo="tool.py",
396 +                 path_or_fileobj=self._get_tool_code().encode(),
397 +             ),
398 +             # Add Gradio app
399 +             CommitOperationAdd(
400 +                 path_in_repo="app.py",
401 +                 path_or_fileobj=self._get_gradio_app_code().encode(),
402 +             ),
403 +             # Add requirements
404 +             CommitOperationAdd(
405 +                 path_in_repo="requirements.txt",
406 +                 path_or_fileobj=self._get_requirements().encode(),
407 +             ),
408 +         ]
409 +         return additions
410 +
411 +     def _get_tool_code(self) -> str:
412 +         """Get the tool's code."""
413 +         return self.to_dict()["code"]
414 +
415 +     def _get_gradio_app_code(self, tool_module_name: str = "tool") -> str:
416 +         """Get the Gradio app code."""
417 +         class_name = self.__class__.__name__
418 +         return textwrap.dedent(
419 +             f"""\
420 +             from smolagents import launch_gradio_demo
421 +             from {tool_module_name} import {class_name}
422 +
423 +             tool = {class_name}()
424 +             launch_gradio_demo(tool)
425 +             """
426 +         )
427 +
428 +     def _get_requirements(self) -> str:
429 +         """Get the requirements."""
430 +         return "\n".join(self.to_dict()["requirements"])
431 +
432 +     @classmethod
433 +     def from_hub(
434 +         cls,
435 +         repo_id: str,
436 +         token: str | None = None,
437 +         trust_remote_code: bool = False,
438 +         **kwargs,
439 +     ):
440 +         """
441 +         Loads a tool defined on the Hub.
442 +
443 +         <Tip warning={true}>
444 +
445 +         Loading a tool from the Hub means that you'll download the tool and execute it locally.
446 +         ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
447 +         installing a package using pip/npm/apt.
448 +
449 +         </Tip>
450 +
451 +         Args:
452 +             repo_id (`str`):
453 +                 The name of the Space repo on the Hub where your tool is defined.
454 +             token (`str`, *optional*):
455 +                 The token to identify you on hf.co. If unset, will use the token generated when running
456 +                 `huggingface-cli login` (stored in `~/.huggingface`).
457 +             trust_remote_code (`bool`, *optional*, defaults to False):
458 +                 This flag marks that you understand the risk of running remote code and that you trust this tool.
459 +                 If not set to True, loading the tool from the Hub will fail.
460 +             kwargs (additional keyword arguments, *optional*):
461 +                 Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
462 +                 `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
463 +                 others will be passed along to its init.
464 +         """
465 +         if not trust_remote_code:
466 +             raise ValueError(
467 +                 "Loading a tool from Hub requires to acknowledge you trust its code: to do so, pass `trust_remote_code=True`."
468 +             )
469 +
470 +         # Get the tool's tool.py file.
471 +         tool_file = hf_hub_download(
472 +             repo_id,
473 +             "tool.py",
474 +             token=token,
475 +             repo_type="space",
476 +             cache_dir=kwargs.get("cache_dir"),
477 +             force_download=kwargs.get("force_download"),
478 +             proxies=kwargs.get("proxies"),
479 +             revision=kwargs.get("revision"),
480 +             subfolder=kwargs.get("subfolder"),
481 +             local_files_only=kwargs.get("local_files_only"),
482 +         )
483 +
484 +         tool_code = Path(tool_file).read_text()
485 +         return Tool.from_code(tool_code, **kwargs)
486 +
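Typical call, assuming a hypothetical Space hosting a saved tool:

```py
downloaded = Tool.from_hub(
    "your-username/character-count-tool",
    trust_remote_code=True,  # required: the tool's code runs locally
)
```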
487 +     @classmethod
488 +     def from_code(cls, tool_code: str, **kwargs):
489 +         module = types.ModuleType("dynamic_tool")
490 +
491 +         exec(tool_code, module.__dict__)
492 +
493 +         # Find the Tool subclass
494 +         tool_class = next(
495 +             (
496 +                 obj
497 +                 for _, obj in inspect.getmembers(module, inspect.isclass)
498 +                 if issubclass(obj, Tool) and obj is not Tool
499 +             ),
500 +             None,
501 +         )
502 +
503 +         if tool_class is None:
504 +             raise ValueError("No Tool subclass found in the code.")
505 +
506 +         if not isinstance(tool_class.inputs, dict):
507 +             tool_class.inputs = ast.literal_eval(tool_class.inputs)
508 +
509 +         return tool_class(**kwargs)
510 +
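A sketch of round-tripping a tool through its source string:

```py
source = tool.to_dict()["code"]   # serialized source of the tool class
rebuilt = Tool.from_code(source)  # exec'd in a fresh module, then instantiated
assert rebuilt(text="hello") == "5"
```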
511 +     @staticmethod
512 +     def from_space(
513 +         space_id: str,
514 +         name: str,
515 +         description: str,
516 +         api_name: str | None = None,
517 +         token: str | None = None,
518 +     ):
519 +         """
520 +         Creates a [`Tool`] from a Space given its id on the Hub.
521 +
522 +         Args:
523 +             space_id (`str`):
524 +                 The id of the Space on the Hub.
525 +             name (`str`):
526 +                 The name of the tool.
527 +             description (`str`):
528 +                 The description of the tool.
529 +             api_name (`str`, *optional*):
530 +                 The specific api_name to use, if the space has several tabs. If not specified, will default to the first available API.
531 +             token (`str`, *optional*):
532 +                 Add your token to access private spaces or increase your GPU quotas.
533 +         Returns:
534 +             [`Tool`]:
535 +                 The Space, as a tool.
536 +
537 +         Examples:
538 +         ```py
539 +         >>> image_generator = Tool.from_space(
540 +         ...     space_id="black-forest-labs/FLUX.1-schnell",
541 +         ...     name="image-generator",
542 +         ...     description="Generate an image from a prompt"
543 +         ... )
544 +         >>> image = image_generator("Generate an image of a cool surfer in Tahiti")
545 +         ```
546 +         ```py
547 +         >>> face_swapper = Tool.from_space(
548 +         ...     "tuan2308/face-swap",
549 +         ...     "face_swapper",
550 +         ...     "Tool that puts the face shown on the first image on the second image. You can give it paths to images.",
551 +         ... )
552 +         >>> image = face_swapper('./aymeric.jpeg', './ruth.jpg')
553 +         ```
554 +         """
555 +         from gradio_client import Client, handle_file
556 +
557 +         class SpaceToolWrapper(Tool):
558 +             skip_forward_signature_validation = True
559 +
560 +             def __init__(
561 +                 self,
562 +                 space_id: str,
563 +                 name: str,
564 +                 description: str,
565 +                 api_name: str | None = None,
566 +                 token: str | None = None,
567 +             ):
568 +                 self.name = name
569 +                 self.description = description
570 +                 self.client = Client(space_id, hf_token=token)
571 +                 space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
572 +
573 +                 # If api_name is not defined, take the first of the available APIs for this space
574 +                 if api_name is None:
575 +                     api_name = list(space_description.keys())[0]
576 +                     logger.warning(
577 +                         f"Since `api_name` was not defined, it was automatically set to the first available API: `{api_name}`."
578 +                     )
579 +                 self.api_name = api_name
580 +
581 +                 try:
582 +                     space_description_api = space_description[api_name]
583 +                 except KeyError:
584 +                     raise KeyError(f"Could not find specified {api_name=} among available api names.")
585 +
586 +                 self.inputs = {}
587 +                 for parameter in space_description_api["parameters"]:
588 +                     if not parameter["parameter_has_default"]:
589 +                         parameter_type = parameter["type"]["type"]
590 +                         if parameter_type == "object":
591 +                             parameter_type = "any"
592 +                         self.inputs[parameter["parameter_name"]] = {
593 +                             "type": parameter_type,
594 +                             "description": parameter["python_type"]["description"],
595 +                         }
596 +                 output_component = space_description_api["returns"][0]["component"]
597 +                 if output_component == "Image":
598 +                     self.output_type = "image"
599 +                 elif output_component == "Audio":
600 +                     self.output_type = "audio"
601 +                 else:
602 +                     self.output_type = "any"
603 +                 self.is_initialized = True
604 +
605 +             def sanitize_argument_for_prediction(self, arg):
606 +                 from gradio_client.utils import is_http_url_like
607 +                 from PIL.Image import Image
608 +
609 +                 if isinstance(arg, Image):
610 +                     temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
611 +                     arg.save(temp_file.name)
612 +                     arg = temp_file.name
613 +                 if (
614 +                     (isinstance(arg, str) and os.path.isfile(arg))
615 +                     or (isinstance(arg, Path) and arg.exists() and arg.is_file())
616 +                     or is_http_url_like(arg)
617 +                 ):
618 +                     arg = handle_file(arg)
619 +                 return arg
620 +
621 +             def forward(self, *args, **kwargs):
622 +                 # Preprocess args and kwargs:
623 +                 args = list(args)
624 +                 for i, arg in enumerate(args):
625 +                     args[i] = self.sanitize_argument_for_prediction(arg)
626 +                 for arg_name, arg in kwargs.items():
627 +                     kwargs[arg_name] = self.sanitize_argument_for_prediction(arg)
628 +
629 +                 output = self.client.predict(*args, api_name=self.api_name, **kwargs)
630 +                 if isinstance(output, tuple) or isinstance(output, list):
631 +                     return output[
632 +                         0
633 +                     ]  # Sometimes the Space also returns the generation seed, in which case the result is at index 0
634 +                 return output
635 +
636 +         return SpaceToolWrapper(
637 +             space_id=space_id,
638 +             name=name,
639 +             description=description,
640 +             api_name=api_name,
641 +             token=token,
642 +         )
643 +
644 +     @staticmethod
645 +     def from_gradio(gradio_tool):
646 +         """
647 +         Creates a [`Tool`] from a gradio tool.
648 +         """
649 +         import inspect
650 +
651 +         class GradioToolWrapper(Tool):
652 +             def __init__(self, _gradio_tool):
653 +                 self.name = _gradio_tool.name
654 +                 self.description = _gradio_tool.description
655 +                 self.output_type = "string"
656 +                 self._gradio_tool = _gradio_tool
657 +                 func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
658 +                 self.inputs = {
659 +                     key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
660 +                 }
661 +                 self.forward = self._gradio_tool.run
662 +
663 +         return GradioToolWrapper(gradio_tool)
664 +
665 +     @staticmethod
666 +     def from_langchain(langchain_tool):
667 +         """
668 +         Creates a [`Tool`] from a langchain tool.
669 +         """
670 +
671 +         class LangChainToolWrapper(Tool):
672 +             skip_forward_signature_validation = True
673 +
674 +             def __init__(self, _langchain_tool):
675 +                 self.name = _langchain_tool.name.lower()
676 +                 self.description = _langchain_tool.description
677 +                 self.inputs = _langchain_tool.args.copy()
678 +                 for input_content in self.inputs.values():
679 +                     if "title" in input_content:
680 +                         input_content.pop("title")
681 +                     input_content["description"] = ""
682 +                 self.output_type = "string"
683 +                 self.langchain_tool = _langchain_tool
684 +                 self.is_initialized = True
685 +
686 +             def forward(self, *args, **kwargs):
687 +                 tool_input = kwargs.copy()
688 +                 for index, argument in enumerate(args):
689 +                     if index < len(self.inputs):
690 +                         input_key = next(iter(self.inputs))
691 +                         tool_input[input_key] = argument
692 +                 return self.langchain_tool.run(tool_input)
693 +
694 +         return LangChainToolWrapper(langchain_tool)
695 +
696 +
697 + def launch_gradio_demo(tool: Tool):
698 +     """
699 +     Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
700 +     `inputs` and `output_type`.
701 +
702 +     Args:
703 +         tool (`Tool`): The tool for which to launch the demo.
704 +     """
705 +     try:
706 +         import gradio as gr
707 +     except ImportError:
708 +         raise ImportError("Gradio should be installed in order to launch a gradio demo.")
709 +
710 +     TYPE_TO_COMPONENT_CLASS_MAPPING = {
711 +         "boolean": gr.Checkbox,
712 +         "image": gr.Image,
713 +         "audio": gr.Audio,
714 +         "string": gr.Textbox,
715 +         "integer": gr.Textbox,
716 +         "number": gr.Textbox,
717 +     }
718 +
719 +     def tool_forward(*args, **kwargs):
720 +         return tool(*args, sanitize_inputs_outputs=True, **kwargs)
721 +
722 +     tool_forward.__signature__ = inspect.signature(tool.forward)
723 +
724 +     gradio_inputs = []
725 +     for input_name, input_details in tool.inputs.items():
726 +         input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
727 +         new_component = input_gradio_component_class(label=input_name)
728 +         gradio_inputs.append(new_component)
729 +
730 +     output_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[tool.output_type]
731 +     gradio_output = output_gradio_component_class(label="Output")
732 +
733 +     gr.Interface(
734 +         fn=tool_forward,
735 +         inputs=gradio_inputs,
736 +         outputs=gradio_output,
737 +         title=tool.name,
738 +         description=tool.description,
739 +         api_name=tool.name,
740 +     ).launch()
741 +
742 +
743 + def load_tool(
744 +     repo_id,
745 +     model_repo_id: str | None = None,
746 +     token: str | None = None,
747 +     trust_remote_code: bool = False,
748 +     **kwargs,
749 + ):
750 +     """
751 +     Main function to quickly load a tool from the Hub.
752 +
753 +     <Tip warning={true}>
754 +
755 +     Loading a tool means that you'll download the tool and execute it locally.
756 +     ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
757 +     installing a package using pip/npm/apt.
758 +
759 +     </Tip>
760 +
761 +     Args:
762 +         repo_id (`str`):
763 +             Space repo ID of a tool on the Hub.
764 +         model_repo_id (`str`, *optional*):
765 +             Use this argument to use a different model than the default one for the tool you selected.
766 +         token (`str`, *optional*):
767 +             The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
768 +             login` (stored in `~/.huggingface`).
769 +         trust_remote_code (`bool`, *optional*, defaults to False):
770 +             This needs to be accepted in order to load a tool from Hub.
771 +         kwargs (additional keyword arguments, *optional*):
772 +             Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
773 +             `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
774 +             will be passed along to its init.
775 +     """
776 +     return Tool.from_hub(
777 +         repo_id,
778 +         model_repo_id=model_repo_id,
779 +         token=token,
780 +         trust_remote_code=trust_remote_code,
781 +         **kwargs,
782 +     )
783 +
784 +
785 + def add_description(description):
786 +     """
787 +     A decorator that adds a description to a function.
788 +     """
789 +
790 +     def inner(func):
791 +         func.description = description
792 +         func.name = func.__name__
793 +         return func
794 +
795 +     return inner
796 +
797 +
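A quick sketch of the decorator in use (hypothetical function; `add_description` is importable from the module itself):

```py
from smolagents.tools import add_description

@add_description("Adds two integers and returns their sum.")
def add(a: int, b: int) -> int:
    return a + b

print(add.name)         # "add"
print(add.description)  # "Adds two integers and returns their sum."
```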
798 + class ToolCollection:
799 +     """
800 +     Tool collections enable loading a collection of tools in the agent's toolbox.
801 +
802 +     Collections can be loaded from a collection in the Hub or from an MCP server, see:
803 +     - [`ToolCollection.from_hub`]
804 +     - [`ToolCollection.from_mcp`]
805 +
806 +     For example and usage, see: [`ToolCollection.from_hub`] and [`ToolCollection.from_mcp`]
807 +     """
808 +
809 +     def __init__(self, tools: list[Tool]):
810 +         self.tools = tools
811 +
812 +     @classmethod
813 +     def from_hub(
814 +         cls,
815 +         collection_slug: str,
816 +         token: str | None = None,
817 +         trust_remote_code: bool = False,
818 +     ) -> "ToolCollection":
819 +         """Loads a tool collection from the Hub.
820 +
821 +         It adds a collection of tools from all Spaces in the collection to the agent's toolbox.
822 +
823 +         > [!NOTE]
824 +         > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
825 +         > like for this collection to showcase them.
826 +
827 +         Args:
828 +             collection_slug (str): The collection slug referencing the collection.
829 +             token (str, *optional*): The authentication token if the collection is private.
830 +             trust_remote_code (bool, *optional*, defaults to False): Whether to trust the remote code.
831 +
832 +         Returns:
833 +             ToolCollection: A tool collection instance loaded with the tools.
834 +
835 +         Example:
836 +         ```py
837 +         >>> from smolagents import ToolCollection, CodeAgent
838 +
839 +         >>> image_tool_collection = ToolCollection.from_hub("huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
840 +         >>> agent = CodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
841 +
842 +         >>> agent.run("Please draw me a picture of rivers and lakes.")
843 +         ```
844 +         """
845 +         _collection = get_collection(collection_slug, token=token)
846 +         _hub_repo_ids = {item.item_id for item in _collection.items if item.item_type == "space"}
847 +
848 +         tools = {Tool.from_hub(repo_id, token, trust_remote_code) for repo_id in _hub_repo_ids}
849 +
850 +         return cls(tools)
851 +
852 +     @classmethod
853 +     @contextmanager
854 +     def from_mcp(
855 +         cls, server_parameters: "mcp.StdioServerParameters" | dict, trust_remote_code: bool = False
856 +     ) -> "ToolCollection":
857 +         """Automatically load a tool collection from an MCP server.
858 +
859 +         This method supports Stdio, Streamable HTTP, and legacy HTTP+SSE MCP servers. Look at the `server_parameters`
860 +         argument for more details on how to connect to each MCP server.
861 +
862 +         Note: a separate thread will be spawned to run an asyncio event loop handling
863 +         the MCP server.
864 +
865 +         Args:
866 +             server_parameters (`mcp.StdioServerParameters` or `dict`):
867 +                 Configuration parameters to connect to the MCP server. This can be:
868 +
869 +                 - An instance of `mcp.StdioServerParameters` for connecting a Stdio MCP server via standard input/output using a subprocess.
870 +
871 +                 - A `dict` with at least:
872 +                   - "url": URL of the server.
873 +                   - "transport": Transport protocol to use, one of:
874 +                     - "streamable-http": (recommended) Streamable HTTP transport.
875 +                     - "sse": Legacy HTTP+SSE transport (deprecated).
876 +                   If "transport" is omitted, the legacy "sse" transport is assumed (a deprecation warning will be issued).
877 +
878 +                 <Deprecated version="1.17.0">
879 +                 The HTTP+SSE transport is deprecated and future behavior will default to the Streamable HTTP transport.
880 +                 Please pass explicitly the "transport" key.
881 +                 </Deprecated>
882 +             trust_remote_code (`bool`, *optional*, defaults to `False`):
883 +                 Whether to trust the execution of code from tools defined on the MCP server.
884 +                 This option should only be set to `True` if you trust the MCP server,
885 +                 and understand the risks associated with running remote code on your local machine.
886 +                 If set to `False`, loading tools from MCP will fail.
887 +
888 +
889 +         Returns:
890 +             ToolCollection: A tool collection instance.
891 +
892 +         Example with a Stdio MCP server:
893 +         ```py
894 +         >>> import os
895 +         >>> from smolagents import ToolCollection, CodeAgent, InferenceClientModel
896 +         >>> from mcp import StdioServerParameters
897 +
898 +         >>> model = InferenceClientModel()
899 +
900 +         >>> server_parameters = StdioServerParameters(
901 +         >>>     command="uvx",
902 +         >>>     args=["--quiet", "[email protected]"],
903 +         >>>     env={"UV_PYTHON": "3.12", **os.environ},
904 +         >>> )
905 +
906 +         >>> with ToolCollection.from_mcp(server_parameters, trust_remote_code=True) as tool_collection:
907 +         >>>     agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True, model=model)
908 +         >>>     agent.run("Please find a remedy for hangover.")
909 +         ```
910 +
911 +         Example with a Streamable HTTP MCP server:
912 +         ```py
913 +         >>> with ToolCollection.from_mcp({"url": "http://127.0.0.1:8000/mcp", "transport": "streamable-http"}, trust_remote_code=True) as tool_collection:
914 +         >>>     agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True, model=model)
915 +         >>>     agent.run("Please find a remedy for hangover.")
916 +         ```
917 +         """
918 +         try:
919 +             from mcpadapt.core import MCPAdapt
920 +             from mcpadapt.smolagents_adapter import SmolAgentsAdapter
921 +         except ImportError:
922 +             raise ImportError(
923 +                 """Please install 'mcp' extra to use ToolCollection.from_mcp: `pip install "smolagents[mcp]"`."""
924 +             )
925 +         if isinstance(server_parameters, dict):
926 +             transport = server_parameters.get("transport")
927 +             if transport is None:
928 +                 warnings.warn(
929 +                     "Passing a dict as server_parameters without specifying the 'transport' key is deprecated. "
930 +                     "For now, it defaults to the legacy 'sse' (HTTP+SSE) transport, but this default will change "
931 +                     "to 'streamable-http' in version 1.20. Please add the 'transport' key explicitly. ",
932 +                     FutureWarning,
933 +                 )
934 +                 transport = "sse"
935 +                 server_parameters["transport"] = transport
936 +             if transport not in {"sse", "streamable-http"}:
937 +                 raise ValueError(
938 +                     f"Unsupported transport: {transport}. Supported transports are 'streamable-http' and 'sse'."
939 +                 )
940 +         if not trust_remote_code:
941 +             raise ValueError(
942 +                 "Loading tools from MCP requires you to acknowledge you trust the MCP server, "
943 +                 "as it will execute code on your local machine: pass `trust_remote_code=True`."
944 +             )
945 +         with MCPAdapt(server_parameters, SmolAgentsAdapter()) as tools:
946 +             yield cls(tools)
947 +
948 +
949 + def tool(tool_function: Callable) -> Tool:
950 +     """
951 +     Convert a function into an instance of a dynamically created Tool subclass.
952 +
953 +     Args:
954 +         tool_function (`Callable`): Function to convert into a Tool subclass.
955 +             Should have type hints for each input and a type hint for the output.
956 +             Should also have a docstring including the description of the function
957 +             and an 'Args:' part where each argument is described.
958 +     """
959 +     tool_json_schema = get_json_schema(tool_function)["function"]
960 +     if "return" not in tool_json_schema:
961 +         raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
962 +
963 +     class SimpleTool(Tool):
964 +         def __init__(self):
965 +             self.is_initialized = True
966 +
967 +     # Set the class attributes
968 +     SimpleTool.name = tool_json_schema["name"]
969 +     SimpleTool.description = tool_json_schema["description"]
970 +     SimpleTool.inputs = tool_json_schema["parameters"]["properties"]
971 +     SimpleTool.output_type = tool_json_schema["return"]["type"]
972 +
973 +     @wraps(tool_function)
974 +     def wrapped_function(*args, **kwargs):
975 +         return tool_function(*args, **kwargs)
976 +
977 +     # Bind the copied function to the forward method
978 +     SimpleTool.forward = staticmethod(wrapped_function)
979 +
980 +     # Get the signature parameters of the tool function
981 +     sig = inspect.signature(tool_function)
982 +     # - Add "self" as first parameter to tool_function signature
983 +     new_sig = sig.replace(
984 +         parameters=[inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(sig.parameters.values())
985 +     )
986 +     # - Set the signature of the forward method
987 +     SimpleTool.forward.__signature__ = new_sig
988 +
989 +     # Create and attach the source code of the dynamically created tool class and forward method
990 +     # - Get the source code of tool_function
991 +     tool_source = inspect.getsource(tool_function)
992 +     # - Remove the tool decorator and function definition line
993 +     tool_source_body = "\n".join(tool_source.split("\n")[2:])
994 +     # - Dedent
995 +     tool_source_body = textwrap.dedent(tool_source_body)
996 +     # - Create the forward method source, including def line and indentation
997 +     forward_method_source = f"def forward{str(new_sig)}:\n{textwrap.indent(tool_source_body, '    ')}"
998 +     # - Create the class source
999 +     class_source = (
1000 +         textwrap.dedent(f"""
1001 +         class SimpleTool(Tool):
1002 +             name: str = "{tool_json_schema["name"]}"
1003 +             description: str = {json.dumps(textwrap.dedent(tool_json_schema["description"]).strip())}
1004 +             inputs: dict[str, dict[str, str]] = {tool_json_schema["parameters"]["properties"]}
1005 +             output_type: str = "{tool_json_schema["return"]["type"]}"
1006 +
1007 +             def __init__(self):
1008 +                 self.is_initialized = True
1009 +
1010 +         """)
1011 +         + textwrap.indent(forward_method_source, "    ")  # indent for class method
1012 +     )
1013 +     # - Store the source code on both class and method for inspection
1014 +     SimpleTool.__source__ = class_source
1015 +     SimpleTool.forward.__source__ = forward_method_source
1016 +
1017 +     simple_tool = SimpleTool()
1018 +     return simple_tool
1019 +
1020 +
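As the docstring above requires, the decorated function needs full type hints plus an 'Args:' docstring section. A minimal sketch (hypothetical function):

```py
from smolagents import tool

@tool
def character_count(text: str) -> str:
    """
    Counts the characters in a text.

    Args:
        text: The text to measure.
    """
    return str(len(text))

print(character_count.name)        # "character_count"
print(character_count.inputs)      # {"text": {"type": "string", "description": "The text to measure."}}
print(character_count(text="hi"))  # "2"
```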
1021 + class PipelineTool(Tool):
1022 +     """
1023 +     A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
1024 +     need to specify:
1025 +
1026 +     - **model_class** (`type`) -- The class to use to load the model in this tool.
1027 +     - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
1028 +     - **pre_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the
1029 +       pre-processor
1030 +     - **post_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the
1031 +       post-processor (when different from the pre-processor).
1032 +
1033 +     Args:
1034 +         model (`str` or [`transformers.PreTrainedModel`], *optional*):
1035 +             The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
1036 +             value of the class attribute `default_checkpoint`.
1037 +         pre_processor (`str` or `Any`, *optional*):
1038 +             The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
1039 +             tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
1040 +             unset.
1041 +         post_processor (`str` or `Any`, *optional*):
1042 +             The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
1043 +             tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
1044 +             unset.
1045 +         device (`int`, `str` or `torch.device`, *optional*):
1046 +             The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
1047 +             CPU otherwise.
1048 +         device_map (`str` or `dict`, *optional*):
1049 +             If passed along, will be used to instantiate the model.
1050 +         model_kwargs (`dict`, *optional*):
1051 +             Any keyword argument to send to the model instantiation.
1052 +         token (`str`, *optional*):
1053 +             The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
1054 +             running `huggingface-cli login` (stored in `~/.huggingface`).
1055 +         hub_kwargs (additional keyword arguments, *optional*):
1056 +             Any additional keyword argument to send to the methods that will load the data from the Hub.
1057 +     """
1058 +
1059 +     pre_processor_class = None
1060 +     model_class = None
1061 +     post_processor_class = None
1062 +     default_checkpoint = None
1063 +     description = "This is a pipeline tool"
1064 +     name = "pipeline"
1065 +     inputs = {"prompt": str}
1066 +     output_type = str
1067 +     skip_forward_signature_validation = True
1068 +
1069 +     def __init__(
1070 +         self,
1071 +         model=None,
1072 +         pre_processor=None,
1073 +         post_processor=None,
1074 +         device=None,
1075 +         device_map=None,
1076 +         model_kwargs=None,
1077 +         token=None,
1078 +         **hub_kwargs,
1079 +     ):
1080 +         if not _is_package_available("accelerate") or not _is_package_available("torch"):
1081 +             raise ModuleNotFoundError(
1082 +                 "Please install 'transformers' extra to use a PipelineTool: `pip install 'smolagents[transformers]'`"
1083 +             )
1084 +
1085 +         if model is None:
1086 +             if self.default_checkpoint is None:
1087 +                 raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
1088 +             model = self.default_checkpoint
1089 +         if pre_processor is None:
1090 +             pre_processor = model
1091 +
1092 +         self.model = model
1093 +         self.pre_processor = pre_processor
1094 +         self.post_processor = post_processor
1095 +         self.device = device
1096 +         self.device_map = device_map
1097 +         self.model_kwargs = {} if model_kwargs is None else model_kwargs
1098 +         if device_map is not None:
1099 +             self.model_kwargs["device_map"] = device_map
1100 +         self.hub_kwargs = hub_kwargs
1101 +         self.hub_kwargs["token"] = token
1102 +
1103 +         super().__init__()
1104 +
1105 +     def setup(self):
1106 +         """
1107 +         Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
1108 +         """
1109 +         if isinstance(self.pre_processor, str):
1110 +             if self.pre_processor_class is None:
1111 +                 from transformers import AutoProcessor
1112 +
1113 +                 self.pre_processor_class = AutoProcessor
1114 +             self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
1115 +
1116 +         if isinstance(self.model, str):
1117 +             self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
1118 +
1119 +         if self.post_processor is None:
1120 +             self.post_processor = self.pre_processor
1121 +         elif isinstance(self.post_processor, str):
1122 +             if self.post_processor_class is None:
1123 +                 from transformers import AutoProcessor
1124 +
1125 +                 self.post_processor_class = AutoProcessor
1126 +             self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
1127 +
1128 +         if self.device is None:
1129 +             if self.device_map is not None:
1130 +                 self.device = list(self.model.hf_device_map.values())[0]
1131 +             else:
1132 +                 from accelerate import PartialState
1133 +
1134 +                 self.device = PartialState().default_device
1135 +
1136 +         if self.device_map is None:
1137 +             self.model.to(self.device)
1138 +
1139 +         super().setup()
1140 +
1141 +     def encode(self, raw_inputs):
1142 +         """
1143 +         Uses the `pre_processor` to prepare the inputs for the `model`.
1144 +         """
1145 +         return self.pre_processor(raw_inputs)
1146 +
1147 +     def forward(self, inputs):
1148 +         """
1149 +         Sends the inputs through the `model`.
1150 +         """
1151 +         import torch
1152 +
1153 +         with torch.no_grad():
1154 +             return self.model(**inputs)
1155 +
1156 +     def decode(self, outputs):
1157 +         """
1158 +         Uses the `post_processor` to decode the model output.
1159 +         """
1160 +         return self.post_processor(outputs)
1161 +
1162 +     def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
1163 +         import torch
1164 +         from accelerate.utils import send_to_device
1165 +
1166 +         if not self.is_initialized:
1167 +             self.setup()
1168 +
1169 +         if sanitize_inputs_outputs:
1170 +             args, kwargs = handle_agent_input_types(*args, **kwargs)
1171 +         encoded_inputs = self.encode(*args, **kwargs)
1172 +
1173 +         tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
1174 +         non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
1175 +
1176 +         encoded_inputs = send_to_device(tensor_inputs, self.device)
1177 +         outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
1178 +         outputs = send_to_device(outputs, "cpu")
1179 +         decoded_outputs = self.decode(outputs)
1180 +         if sanitize_inputs_outputs:
1181 +             decoded_outputs = handle_agent_output_types(decoded_outputs, self.output_type)
1182 +         return decoded_outputs
1183 +
1184 +
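A sketch of the subclassing pattern `PipelineTool` expects: set `model_class`/`default_checkpoint` and override `encode`/`forward`/`decode` as needed. The checkpoint choice and class are hypothetical, and the `transformers` extra is assumed installed:

```py
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from smolagents.tools import PipelineTool

class SummarizerTool(PipelineTool):
    name = "summarizer"
    description = "Summarizes an English text."
    inputs = {"text": {"type": "string", "description": "Text to summarize"}}
    output_type = "string"

    model_class = AutoModelForSeq2SeqLM
    pre_processor_class = AutoTokenizer
    default_checkpoint = "sshleifer/distilbart-cnn-12-6"  # hypothetical choice

    def encode(self, text):
        return self.pre_processor(text, return_tensors="pt", truncation=True)

    def forward(self, inputs):
        return self.model.generate(**inputs)

    def decode(self, outputs):
        # post_processor defaults to the pre_processor (the tokenizer)
        return self.pre_processor.decode(outputs[0], skip_special_tokens=True)
```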
1185 + def get_tools_definition_code(tools: dict[str, Tool]) -> str:
1186 +     tool_codes = []
1187 +     for tool in tools.values():
1188 +         validate_tool_attributes(tool.__class__, check_imports=False)
1189 +         tool_code = instance_to_source(tool, base_cls=Tool)
1190 +         tool_code = tool_code.replace("from smolagents.tools import Tool", "")
1191 +         tool_code += f"\n\n{tool.name} = {tool.__class__.__name__}()\n"
1192 +         tool_codes.append(tool_code)
1193 +
1194 +     tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
1195 +     tool_definition_code += textwrap.dedent(
1196 +         """
1197 +     from typing import Any
1198 +
1199 +     class Tool:
1200 +         def __call__(self, *args, **kwargs):
1201 +             return self.forward(*args, **kwargs)
1202 +
1203 +         def forward(self, *args, **kwargs):
1204 +             pass  # to be implemented in child class
1205 +     """
1206 +     )
1207 +     tool_definition_code += "\n\n".join(tool_codes)
1208 +     return tool_definition_code
1209 +
1210 +
1211 + def validate_tool_arguments(tool: Tool, arguments: Any) -> str | None:
1212 +     if isinstance(arguments, dict):
1213 +         for key, value in arguments.items():
1214 +             if key not in tool.inputs:
1215 +                 return f"Argument {key} is not in the tool's input schema."
1216 +
1217 +             parsed_type = _get_json_schema_type(type(value))["type"]
1218 +
1219 +             if parsed_type != tool.inputs[key]["type"] and not tool.inputs[key]["type"] == "any":
1220 +                 return f"Argument {key} has type '{parsed_type}' but should be '{tool.inputs[key]['type']}'."
1221 +         for key in tool.inputs:
1222 +             if key not in arguments:
1223 +                 return f"Argument {key} is required."
1224 +         return None
1225 +     else:
1226 +         expected_type = list(tool.inputs.values())[0]["type"]
1227 +         if _get_json_schema_type(type(arguments))["type"] != expected_type and not expected_type == "any":
1228 +             return f"Argument has type '{type(arguments).__name__}' but should be '{expected_type}'."
1229 +         return None
1230 +
1231 +
1232 + __all__ = [
1233 +     "AUTHORIZED_TYPES",
1234 +     "Tool",
1235 +     "tool",
1236 +     "load_tool",
1237 +     "launch_gradio_demo",
1238 +     "ToolCollection",
1239 + ]
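A sketch of how `validate_tool_arguments` reports mismatches for the hypothetical character-count tool defined earlier (it returns None when everything checks out):

```py
from smolagents.tools import validate_tool_arguments

validate_tool_arguments(character_count, {"text": "hi"})   # -> None
validate_tool_arguments(character_count, {"text": 42})    # -> "Argument text has type 'integer' but should be 'string'."
validate_tool_arguments(character_count, {"oops": "hi"})  # -> "Argument oops is not in the tool's input schema."
```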
src/smolagents/utils.py
ADDED
@@ -0,0 +1,500 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import base64
import importlib.metadata
import importlib.util
import inspect
import json
import keyword
import os
import re
import types
from functools import lru_cache
from io import BytesIO
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING, Any


if TYPE_CHECKING:
    from smolagents.memory import AgentLogger


__all__ = ["AgentError"]


@lru_cache
def _is_package_available(package_name: str) -> bool:
    try:
        importlib.metadata.version(package_name)
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


BASE_BUILTIN_MODULES = [
    "collections",
    "datetime",
    "itertools",
    "math",
    "queue",
    "random",
    "re",
    "stat",
    "statistics",
    "time",
    "unicodedata",
]


def escape_code_brackets(text: str) -> str:
    """Escapes square brackets in code segments while preserving Rich styling tags."""

    def replace_bracketed_content(match):
        content = match.group(1)
        cleaned = re.sub(
            r"bold|red|green|blue|yellow|magenta|cyan|white|black|italic|dim|\s|#[0-9a-fA-F]{6}", "", content
        )
        return f"\\[{content}\\]" if cleaned.strip() else f"[{content}]"

    return re.sub(r"\[([^\]]*)\]", replace_bracketed_content, text)
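
A quick usage sketch (not part of the file): code-like brackets are escaped while Rich style tags survive.

print(escape_code_brackets("array[0] looks [bold red]"))
# array\[0\] looks [bold red]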


class AgentError(Exception):
    """Base class for other agent-related exceptions"""

    def __init__(self, message, logger: "AgentLogger"):
        super().__init__(message)
        self.message = message
        logger.log_error(message)

    def dict(self) -> dict[str, str]:
        return {"type": self.__class__.__name__, "message": str(self.message)}


class AgentParsingError(AgentError):
    """Exception raised for errors in parsing in the agent"""

    pass


class AgentExecutionError(AgentError):
    """Exception raised for errors in execution in the agent"""

    pass


class AgentMaxStepsError(AgentError):
    """Exception raised when the agent exceeds its maximum number of steps"""

    pass


class AgentToolCallError(AgentExecutionError):
    """Exception raised for errors when incorrect arguments are passed to the tool"""

    pass


class AgentToolExecutionError(AgentExecutionError):
    """Exception raised for errors when executing a tool"""

    pass


class AgentGenerationError(AgentError):
    """Exception raised for errors in generation in the agent"""

    pass
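
The subclassing gives callers two levels of granularity when handling failures; a small sketch:

# narrow: tool-related problems; broad: any agent failure
assert issubclass(AgentToolCallError, AgentExecutionError)
assert issubclass(AgentToolExecutionError, AgentExecutionError)
assert issubclass(AgentMaxStepsError, AgentError)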


def make_json_serializable(obj: Any) -> Any:
    """Recursive function to make objects JSON serializable"""
    if obj is None:
        return None
    elif isinstance(obj, (str, int, float, bool)):
        # Try to parse string as JSON if it looks like a JSON object/array
        if isinstance(obj, str):
            try:
                if (obj.startswith("{") and obj.endswith("}")) or (obj.startswith("[") and obj.endswith("]")):
                    parsed = json.loads(obj)
                    return make_json_serializable(parsed)
            except json.JSONDecodeError:
                pass
        return obj
    elif isinstance(obj, (list, tuple)):
        return [make_json_serializable(item) for item in obj]
    elif isinstance(obj, dict):
        return {str(k): make_json_serializable(v) for k, v in obj.items()}
    elif hasattr(obj, "__dict__"):
        # For custom objects, convert their __dict__ to a serializable format
        return {"_type": obj.__class__.__name__, **{k: make_json_serializable(v) for k, v in obj.__dict__.items()}}
    else:
        # For any other type, convert to string
        return str(obj)
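
A usage sketch (the Point class is hypothetical): nested objects and JSON-looking strings are both normalized.

class Point:  # hypothetical object, for illustration only
    def __init__(self, x, y):
        self.x, self.y = x, y

print(make_json_serializable({"p": Point(1, 2), "raw": '{"a": 1}'}))
# {'p': {'_type': 'Point', 'x': 1, 'y': 2}, 'raw': {'a': 1}}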


def parse_json_blob(json_blob: str) -> tuple[dict[str, str], str]:
    "Extracts the JSON blob from the input and returns the JSON data and the rest of the input."
    try:
        first_accolade_index = json_blob.find("{")
        last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
        json_data = json_blob[first_accolade_index : last_accolade_index + 1]
        json_data = json.loads(json_data, strict=False)
        return json_data, json_blob[:first_accolade_index]
    except IndexError:
        raise ValueError("The model output does not contain any JSON blob.")
    except json.JSONDecodeError as e:
        place = e.pos
        if json_blob[place - 1 : place + 2] == "},\n":
            raise ValueError(
                "JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
            )
        raise ValueError(
            f"The JSON blob you used is invalid due to the following error: {e}.\n"
            f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
            f"'{json_blob[place - 4 : place + 5]}'."
        )
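
For instance, with a typical tool-call completion (a sketch, not in the file):

data, prefix = parse_json_blob('Thought: search it\n{"name": "web_search", "arguments": {"query": "smolagents"}}')
print(data["name"])    # web_search
print(prefix.strip())  # Thought: search it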


def extract_code_from_text(text: str) -> str | None:
    """Extract code from the LLM's output."""
    pattern = r"<code>(.*?)</code>"
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return "\n\n".join(match.strip() for match in matches)
    return None


def parse_code_blobs(text: str) -> str:
    """Extract code blocks from the LLM's output.

    If a valid code block is passed, it returns it directly.

    Args:
        text (`str`): LLM's output text to parse.

    Returns:
        `str`: Extracted code block.

    Raises:
        ValueError: If no valid code block is found in the text.
    """
    matches = extract_code_from_text(text)
    if matches:
        return matches
    # Maybe the LLM output a code blob directly
    try:
        ast.parse(text)
        return text
    except SyntaxError:
        pass

    if "final" in text and "answer" in text:
        raise ValueError(
            dedent(
                f"""
                Your code snippet is invalid, because the regex pattern <code>(.*?)</code> was not found in it.
                Here is your code snippet:
                {text}
                It seems like you're trying to return the final answer, you can do it as follows:
                <code>
                final_answer("YOUR FINAL ANSWER HERE")
                </code>
                """
            ).strip()
        )
    raise ValueError(
        dedent(
            f"""
            Your code snippet is invalid, because the regex pattern <code>(.*?)</code> was not found in it.
            Here is your code snippet:
            {text}
            Make sure to include code with the correct pattern, for instance:
            Thoughts: Your thoughts
            <code>
            # Your python code here
            </code>
            """
        ).strip()
    )
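
A sketch of both accepted shapes, a tagged snippet and a bare code blob:

print(parse_code_blobs("Thoughts: add\n<code>\nresult = 2 + 2\n</code>"))  # result = 2 + 2
print(parse_code_blobs("x = 1"))  # valid Python passes through unchanged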


MAX_LENGTH_TRUNCATE_CONTENT = 20000


def truncate_content(content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT) -> str:
    if len(content) <= max_length:
        return content
    else:
        return (
            content[: max_length // 2]
            + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
            + content[-max_length // 2 :]
        )
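
For example (a sketch): oversized text keeps its head and tail around a truncation marker.

blob = "x" * 25_000
clipped = truncate_content(blob)
assert len(clipped) < len(blob)
assert "_This content has been truncated" in clipped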


class ImportFinder(ast.NodeVisitor):
    def __init__(self):
        self.packages = set()

    def visit_Import(self, node):
        for alias in node.names:
            # Get the base package name (before any dots)
            base_package = alias.name.split(".")[0]
            self.packages.add(base_package)

    def visit_ImportFrom(self, node):
        if node.module:  # for "from x import y" statements
            # Get the base package name (before any dots)
            base_package = node.module.split(".")[0]
            self.packages.add(base_package)
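
A usage sketch: only top-level package names are collected.

finder = ImportFinder()
finder.visit(ast.parse("import numpy.linalg\nfrom collections import deque"))
print(sorted(finder.packages))  # ['collections', 'numpy']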


def get_method_source(method):
    """Get source code for a method, including bound methods."""
    if isinstance(method, types.MethodType):
        method = method.__func__
    return get_source(method)


def is_same_method(method1, method2):
    """Compare two methods by their source code."""
    try:
        source1 = get_method_source(method1)
        source2 = get_method_source(method2)

        # Remove method decorators if any
        source1 = "\n".join(line for line in source1.split("\n") if not line.strip().startswith("@"))
        source2 = "\n".join(line for line in source2.split("\n") if not line.strip().startswith("@"))

        return source1 == source2
    except (TypeError, OSError):
        return False


def is_same_item(item1, item2):
    """Compare two class items (methods or attributes) for equality."""
    if callable(item1) and callable(item2):
        return is_same_method(item1, item2)
    else:
        return item1 == item2


def instance_to_source(instance, base_cls=None):
    """Convert an instance to its class source code representation."""
    cls = instance.__class__
    class_name = cls.__name__

    # Start building class lines
    class_lines = []
    if base_cls:
        class_lines.append(f"class {class_name}({base_cls.__name__}):")
    else:
        class_lines.append(f"class {class_name}:")

    # Add docstring if it exists and differs from base
    if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__):
        class_lines.append(f'    """{cls.__doc__}"""')

    # Add class-level attributes
    class_attrs = {
        name: value
        for name, value in cls.__dict__.items()
        if not name.startswith("__")
        and not callable(value)
        and not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
    }

    for name, value in class_attrs.items():
        if isinstance(value, str):
            # multiline value
            if "\n" in value:
                escaped_value = value.replace('"""', r"\"\"\"")  # Escape triple quotes
                class_lines.append(f'    {name} = """{escaped_value}"""')
            else:
                class_lines.append(f"    {name} = {json.dumps(value)}")
        else:
            class_lines.append(f"    {name} = {repr(value)}")

    if class_attrs:
        class_lines.append("")

    # Add methods
    methods = {
        name: func.__wrapped__ if hasattr(func, "__wrapped__") else func
        for name, func in cls.__dict__.items()
        if callable(func)
        and (
            not base_cls
            or not hasattr(base_cls, name)
            or (
                isinstance(func, (staticmethod, classmethod))
                or (getattr(base_cls, name).__code__.co_code != func.__code__.co_code)
            )
        )
    }

    for name, method in methods.items():
        method_source = get_source(method)
        # Clean up the indentation
        method_lines = method_source.split("\n")
        first_line = method_lines[0]
        indent = len(first_line) - len(first_line.lstrip())
        method_lines = [line[indent:] for line in method_lines]
        method_source = "\n".join(["    " + line if line.strip() else line for line in method_lines])
        class_lines.append(method_source)
        class_lines.append("")

    # Find required imports using ImportFinder
    import_finder = ImportFinder()
    import_finder.visit(ast.parse("\n".join(class_lines)))
    required_imports = import_finder.packages

    # Build final code with imports
    final_lines = []

    # Add base class import if needed
    if base_cls:
        final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")

    # Add discovered imports
    for package in required_imports:
        final_lines.append(f"import {package}")

    if final_lines:  # Add empty line after imports
        final_lines.append("")

    # Add the class code
    final_lines.extend(class_lines)

    return "\n".join(final_lines)
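
A sketch (run from a script file so inspect can read the source; Greeter is a hypothetical class):

class Greeter:  # hypothetical class, for illustration only
    greeting = "hello"

    def greet(self, name):
        return f"{self.greeting}, {name}"

print(instance_to_source(Greeter()))
# prints a standalone `class Greeter:` definition with its attribute and method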


def get_source(obj) -> str:
    """Get the source code of a class or callable object (e.g.: function, method).
    First attempts to get the source code using `inspect.getsource`.
    In a dynamic environment (e.g.: Jupyter, IPython), if this fails,
    falls back to retrieving the source code from the current interactive shell session.

    Args:
        obj: A class or callable object (e.g.: function, method)

    Returns:
        str: The source code of the object, dedented and stripped

    Raises:
        TypeError: If object is not a class or callable
        OSError: If source code cannot be retrieved from any source
        ValueError: If source cannot be found in IPython history

    Note:
        TODO: handle Python standard REPL
    """
    if not (isinstance(obj, type) or callable(obj)):
        raise TypeError(f"Expected class or callable, got {type(obj)}")

    inspect_error = None
    try:
        # Handle dynamically created classes
        source = getattr(obj, "__source__", None) or inspect.getsource(obj)
        return dedent(source).strip()
    except OSError as e:
        # let's keep track of the exception to raise it if all further methods fail
        inspect_error = e
    try:
        import IPython

        shell = IPython.get_ipython()
        if not shell:
            raise ImportError("No active IPython shell found")
        all_cells = "\n".join(shell.user_ns.get("In", [])).strip()
        if not all_cells:
            raise ValueError("No code cells found in IPython session")

        tree = ast.parse(all_cells)
        for node in ast.walk(tree):
            if isinstance(node, (ast.ClassDef, ast.FunctionDef)) and node.name == obj.__name__:
                return dedent("\n".join(all_cells.split("\n")[node.lineno - 1 : node.end_lineno])).strip()
        raise ValueError(f"Could not find source code for {obj.__name__} in IPython history")
    except ImportError:
        # IPython is not available, let's just raise the original inspect error
        raise inspect_error
    except ValueError as e:
        # IPython is available but we couldn't find the source code, let's raise the error
        raise e from inspect_error


def encode_image_base64(image):
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def make_image_url(base64_image):
    return f"data:image/png;base64,{base64_image}"


def make_init_file(folder: str | Path):
    os.makedirs(folder, exist_ok=True)
    # Create __init__
    with open(os.path.join(folder, "__init__.py"), "w"):
        pass


def is_valid_name(name: str) -> bool:
    return name.isidentifier() and not keyword.iskeyword(name) if isinstance(name, str) else False
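
Quick checks of the last two helpers (a sketch):

assert is_valid_name("my_tool") and not is_valid_name("class") and not is_valid_name("2fast")
print(make_image_url("aGk="))  # data:image/png;base64,aGk=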


AGENT_GRADIO_APP_TEMPLATE = """import yaml
import os
from smolagents import GradioUI, {{ class_name }}, {{ agent_dict['model']['class'] }}

# Get current directory path
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

{% for tool in tools.values() -%}
from {{managed_agent_relative_path}}tools.{{ tool.name }} import {{ tool.__class__.__name__ }} as {{ tool.name | camelcase }}
{% endfor %}
{% for managed_agent in managed_agents.values() -%}
from {{managed_agent_relative_path}}managed_agents.{{ managed_agent.name }}.app import agent_{{ managed_agent.name }}
{% endfor %}

model = {{ agent_dict['model']['class'] }}(
{% for key in agent_dict['model']['data'] if key not in ['class', 'last_input_token_count', 'last_output_token_count'] -%}
    {{ key }}={{ agent_dict['model']['data'][key]|repr }},
{% endfor %})

{% for tool in tools.values() -%}
{{ tool.name }} = {{ tool.name | camelcase }}()
{% endfor %}

with open(os.path.join(CURRENT_DIR, "prompts.yaml"), 'r') as stream:
    prompt_templates = yaml.safe_load(stream)

{{ agent_name }} = {{ class_name }}(
    model=model,
    tools=[{% for tool_name in tools.keys() if tool_name != "final_answer" %}{{ tool_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
    managed_agents=[{% for subagent_name in managed_agents.keys() %}agent_{{ subagent_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
    {% for attribute_name, value in agent_dict.items() if attribute_name not in ["model", "tools", "prompt_templates", "authorized_imports", "managed_agents", "requirements"] -%}
    {{ attribute_name }}={{ value|repr }},
    {% endfor %}prompt_templates=prompt_templates
)
if __name__ == "__main__":
    GradioUI({{ agent_name }}).launch()
""".strip()
src/smolagents/vision_web_browser.py
ADDED
@@ -0,0 +1,228 @@
import argparse
from io import BytesIO
from time import sleep

import helium
import PIL.Image
from dotenv import load_dotenv
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys

from smolagents import CodeAgent, WebSearchTool, tool
from smolagents.agents import ActionStep
from smolagents.cli import load_model


github_request = """
I'm trying to find how hard I have to work to get a repo in github.com/trending.
Can you navigate to the profile for the top author of the top trending repo, and give me their total number of commits over the last year?
"""  # The agent is able to achieve this request only when powered by GPT-4o or Claude-3.5-sonnet.

search_request = """
Please navigate to https://en.wikipedia.org/wiki/Chicago and give me a sentence containing the word "1992" that mentions a construction accident.
"""


def parse_arguments():
    parser = argparse.ArgumentParser(description="Run a web browser automation script with a specified model.")
    parser.add_argument(
        "prompt",
        type=str,
        nargs="?",  # Makes it optional
        default=search_request,
        help="The prompt to run with the agent",
    )
    parser.add_argument(
        "--model-type",
        type=str,
        default="LiteLLMModel",
        help="The model type to use (e.g., OpenAIServerModel, LiteLLMModel, TransformersModel, InferenceClientModel)",
    )
    parser.add_argument(
        "--model-id",
        type=str,
        default="gpt-4o",
        help="The model ID to use for the specified model type",
    )
    parser.add_argument(
        "--provider",
        type=str,
        help="The inference provider to use for the model",
    )
    parser.add_argument(
        "--api-base",
        type=str,
        help="The API base to use for the model",
    )
    parser.add_argument(
        "--api-key",
        type=str,
        help="The API key to use for the model",
    )
    return parser.parse_args()


def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> None:
    sleep(1.0)  # Let JavaScript animations happen before taking the screenshot
    driver = helium.get_driver()
    current_step = memory_step.step_number
    if driver is not None:
        for previous_memory_step in agent.memory.steps:  # Remove previous screenshots from logs for lean processing
            if isinstance(previous_memory_step, ActionStep) and previous_memory_step.step_number <= current_step - 2:
                previous_memory_step.observations_images = None
        png_bytes = driver.get_screenshot_as_png()
        image = PIL.Image.open(BytesIO(png_bytes))
        print(f"Captured a browser screenshot: {image.size} pixels")
        memory_step.observations_images = [image.copy()]  # Create a copy to ensure it persists, important!

    # Update observations with current URL
    url_info = f"Current url: {driver.current_url}"
    memory_step.observations = (
        url_info if memory_step.observations is None else memory_step.observations + "\n" + url_info
    )
    return


@tool
def search_item_ctrl_f(text: str, nth_result: int = 1) -> str:
    """
    Searches for text on the current page via Ctrl + F and jumps to the nth occurrence.
    Args:
        text: The text to search for
        nth_result: Which occurrence to jump to (default: 1)
    """
    elements = driver.find_elements(By.XPATH, f"//*[contains(text(), '{text}')]")
    if nth_result > len(elements):
        raise Exception(f"Match n°{nth_result} not found (only {len(elements)} matches found)")
    result = f"Found {len(elements)} matches for '{text}'. "
    elem = elements[nth_result - 1]
    driver.execute_script("arguments[0].scrollIntoView(true);", elem)
    result += f"Focused on element {nth_result} of {len(elements)}"
    return result


@tool
def go_back() -> None:
    """Goes back to previous page."""
    driver.back()


@tool
def close_popups() -> str:
    """
    Closes any visible modal or pop-up on the page. Use this to dismiss pop-up windows! This does not work on cookie consent banners.
    """
    webdriver.ActionChains(driver).send_keys(Keys.ESCAPE).perform()


def initialize_driver():
    """Initialize the Selenium WebDriver."""
    chrome_options = webdriver.ChromeOptions()
    chrome_options.add_argument("--force-device-scale-factor=1")
    chrome_options.add_argument("--window-size=1000,1350")
    chrome_options.add_argument("--disable-pdf-viewer")
    chrome_options.add_argument("--window-position=0,0")
    return helium.start_chrome(headless=False, options=chrome_options)


def initialize_agent(model):
    """Initialize the CodeAgent with the specified model."""
    return CodeAgent(
        tools=[WebSearchTool(), go_back, close_popups, search_item_ctrl_f],
        model=model,
        additional_authorized_imports=["helium"],
        step_callbacks=[save_screenshot],
        max_steps=20,
        verbosity_level=2,
    )


helium_instructions = """
Use your web_search tool when you want to get Google search results.
Then you can use helium to access websites. Don't use helium for Google search, only for navigating websites!
Don't bother about the helium driver, it's already managed.
We've already run "from helium import *"
Then you can go to pages!
<code>
go_to('github.com/trending')
</code>

You can directly click clickable elements by inputting the text that appears on them.
<code>
click("Top products")
</code>

If it's a link:
<code>
click(Link("Top products"))
</code>

If you try to interact with an element and it's not found, you'll get a LookupError.
In general, stop your action after each button click to see what happens on your screenshot.
Never try to log in to a page.

To scroll up or down, use scroll_down or scroll_up, passing the number of pixels to scroll as an argument.
<code>
scroll_down(num_pixels=1200)  # This will scroll one viewport down
</code>

When you have pop-ups with a cross icon to close, don't try to click the close icon by finding its element or targeting an 'X' element (this most often fails).
Just use your built-in tool `close_popups` to close them:
<code>
close_popups()
</code>

You can use .exists() to check for the existence of an element. For example:
<code>
if Text('Accept cookies?').exists():
    click('I accept')
</code>

Proceed in several steps rather than trying to solve the task in one shot.
And at the end, only when you have your answer, return your final answer.
<code>
final_answer("YOUR_ANSWER_HERE")
</code>

If pages seem stuck on loading, you might have to wait, for instance `import time` and run `time.sleep(5.0)`. But don't overuse this!
To list elements on a page, DO NOT try code-based element searches like 'contributors = find_all(S("ol > li"))': just look at the latest screenshot you have and read it visually, or use your tool search_item_ctrl_f.
Of course, you can act on buttons like a user would do when navigating.
After each code blob you write, you will be automatically provided with an updated screenshot of the browser and the current browser url.
But beware that the screenshot will only be taken at the end of the whole action; it won't see intermediate states.
Don't kill the browser.
When you have modals or cookie banners on screen, you should get rid of them before you can click anything else.
"""


def run_webagent(
    prompt: str,
    model_type: str,
    model_id: str,
    provider: str | None = None,
    api_base: str | None = None,
    api_key: str | None = None,
) -> None:
    # Load environment variables
    load_dotenv()

    # Initialize the model based on the provided arguments
    model = load_model(model_type, model_id, provider=provider, api_base=api_base, api_key=api_key)

    global driver
    driver = initialize_driver()
    agent = initialize_agent(model)

    # Run the agent with the provided prompt
    agent.python_executor("from helium import *")
    agent.run(prompt + helium_instructions)


def main() -> None:
    # Parse command line arguments
    args = parse_arguments()
    run_webagent(args.prompt, args.model_type, args.model_id, args.provider, args.api_base, args.api_key)


if __name__ == "__main__":
    main()
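
For reference, the script can also be driven programmatically rather than through the CLI; a sketch, assuming a local Chrome install for helium/selenium and model credentials in .env:

from smolagents.vision_web_browser import run_webagent, search_request

run_webagent(search_request, model_type="OpenAIServerModel", model_id="gpt-4o")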