Duibonduil committed
Commit d7949de · verified · 1 Parent(s): 2928e27

Upload 17 files
src/smolagents/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ __version__ = "1.20.0.dev0"
18
+
19
+ from .agent_types import * # noqa: I001
20
+ from .agents import * # Above noqa avoids a circular dependency due to cli.py
21
+ from .default_tools import *
22
+ from .gradio_ui import *
23
+ from .local_python_executor import *
24
+ from .mcp_client import *
25
+ from .memory import *
26
+ from .models import *
27
+ from .monitoring import *
28
+ from .remote_executors import *
29
+ from .tools import *
30
+ from .utils import *
31
+ from .cli import *
src/smolagents/_function_type_hints_utils.py ADDED
@@ -0,0 +1,431 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """This module contains utilities exclusively taken from `transformers` repository.
18
+
19
+ Since they are not specific to `transformers`, and `transformers` is a heavy dependency, these helpers have
20
+ been duplicated.
21
+
22
+ TODO: move them to `huggingface_hub` to avoid code duplication.
23
+ """
24
+
25
+ import inspect
26
+ import json
27
+ import re
28
+ import types
29
+ from collections.abc import Callable
30
+ from copy import copy
31
+ from typing import (
32
+ Any,
33
+ Literal,
34
+ Union,
35
+ get_args,
36
+ get_origin,
37
+ get_type_hints,
38
+ )
39
+
40
+
41
+ IMPORT_TO_PACKAGE_MAPPING = {
42
+ "wikipediaapi": "wikipedia-api",
43
+ }
44
+
45
+
46
+ def get_package_name(import_name: str) -> str:
47
+ """
48
+ Return the package name for a given import name.
49
+
50
+ Args:
51
+ import_name (`str`): Import name to get the package name for.
52
+
53
+ Returns:
54
+ `str`: Package name for the given import name.
55
+ """
56
+ return IMPORT_TO_PACKAGE_MAPPING.get(import_name, import_name)
57
+
58
+
59
+ def get_imports(code: str) -> list[str]:
60
+ """
61
+ Extracts all the libraries (not relative imports) that are imported in a code.
62
+
63
+ Args:
64
+ code (`str`): Code text to inspect.
65
+
66
+ Returns:
67
+ `list[str]`: List of all packages required to use the input code.
68
+ """
69
+ # Filter out try/except blocks so that custom code can keep imports inside try/except blocks
70
+ code = re.sub(r"\s*try\s*:.*?except.*?:", "", code, flags=re.DOTALL)
71
+
72
+ # Filter out imports under an is_flash_attn_2_available block to avoid import issues in CPU-only environments
73
+ code = re.sub(
74
+ r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+",
75
+ "",
76
+ code,
77
+ flags=re.MULTILINE,
78
+ )
79
+
80
+ # Imports of the form `import xxx` or `import xxx as yyy`
81
+ imports = re.findall(r"^\s*import\s+(\S+?)(?:\s+as\s+\S+)?\s*$", code, flags=re.MULTILINE)
82
+ # Imports of the form `from xxx import yyy`
83
+ imports += re.findall(r"^\s*from\s+(\S+)\s+import", code, flags=re.MULTILINE)
84
+ # Only keep the top-level module
85
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
86
+ return [get_package_name(import_name) for import_name in set(imports)]
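For illustration only (not part of the committed file): a minimal sketch of `get_imports` together with the `get_package_name` mapping, assuming the package is installed so this module path resolves.

```python
# Hypothetical usage sketch; the import path assumes this commit is installed.
from smolagents._function_type_hints_utils import get_imports

code = """
import numpy as np
from torch import Tensor
from .sibling import helper
import wikipediaapi
"""
# The relative import is dropped, and "wikipediaapi" is mapped to its
# PyPI package name; ordering is unspecified because of the set() dedup.
assert sorted(get_imports(code)) == ["numpy", "torch", "wikipedia-api"]
```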
87
+
88
+
89
+ class TypeHintParsingException(Exception):
90
+ """Exception raised for errors in parsing type hints to generate JSON schemas"""
91
+
92
+
93
+ class DocstringParsingException(Exception):
94
+ """Exception raised for errors in parsing docstrings to generate JSON schemas"""
95
+
96
+
97
+ def get_json_schema(func: Callable) -> dict:
98
+ """
99
+ This function generates a JSON schema for a given function, based on its docstring and type hints. This is
100
+ mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
101
+ the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
102
+ that the function has a docstring, and that each argument has a description in the docstring, in the standard
103
+ Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.
104
+
105
+ Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
106
+ optional because most chat templates ignore the return value of the function.
107
+
108
+ Args:
109
+ func: The function to generate a JSON schema for.
110
+
111
+ Returns:
112
+ A dictionary containing the JSON schema for the function.
113
+
114
+ Examples:
115
+ ```python
116
+ >>> def multiply(x: float, y: float):
117
+ >>> '''
118
+ >>> A function that multiplies two numbers
119
+ >>>
120
+ >>> Args:
121
+ >>> x: The first number to multiply
122
+ >>> y: The second number to multiply
123
+ >>> '''
124
+ >>> return x * y
125
+ >>>
126
+ >>> print(get_json_schema(multiply))
127
+ {
128
+ "name": "multiply",
129
+ "description": "A function that multiplies two numbers",
130
+ "parameters": {
131
+ "type": "object",
132
+ "properties": {
133
+ "x": {"type": "number", "description": "The first number to multiply"},
134
+ "y": {"type": "number", "description": "The second number to multiply"}
135
+ },
136
+ "required": ["x", "y"]
137
+ }
138
+ }
139
+ ```
140
+
141
+ The general use for these schemas is that they are used to generate tool descriptions for chat templates that
142
+ support them, like so:
143
+
144
+ ```python
145
+ >>> from transformers import AutoTokenizer
146
+ >>> from transformers.utils import get_json_schema
147
+ >>>
148
+ >>> def multiply(x: float, y: float):
149
+ >>> '''
150
+ >>> A function that multiplies two numbers
151
+ >>>
152
+ >>> Args:
153
+ >>> x: The first number to multiply
154
+ >>> y: The second number to multiply
155
+ >>> '''
156
+ >>> return x * y
157
+ >>>
158
+ >>> multiply_schema = get_json_schema(multiply)
159
+ >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
160
+ >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
161
+ >>> formatted_chat = tokenizer.apply_chat_template(
162
+ >>> messages,
163
+ >>> tools=[multiply_schema],
164
+ >>> chat_template="tool_use",
165
+ >>> return_dict=True,
166
+ >>> return_tensors="pt",
167
+ >>> add_generation_prompt=True
168
+ >>> )
169
+ >>> # The formatted chat can now be passed to model.generate()
170
+ ```
171
+
172
+ Each argument description can also have an optional `(choices: ...)` block at the end, such as
173
+ `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
174
+ only be parsed correctly if it is at the end of the line:
175
+
176
+ ```python
177
+ >>> def drink_beverage(beverage: str):
178
+ >>> '''
179
+ >>> A function that drinks a beverage
180
+ >>>
181
+ >>> Args:
182
+ >>> beverage: The beverage to drink (choices: ["tea", "coffee"])
183
+ >>> '''
184
+ >>> pass
185
+ >>>
186
+ >>> print(get_json_schema(drink_beverage))
187
+ {
188
+ 'name': 'drink_beverage',
189
+ 'description': 'A function that drinks a beverage',
190
+ 'parameters': {
191
+ 'type': 'object',
192
+ 'properties': {
193
+ 'beverage': {
194
+ 'type': 'string',
195
+ 'enum': ['tea', 'coffee'],
196
+ 'description': 'The beverage to drink'
197
+ }
198
+ },
199
+ 'required': ['beverage']
200
+ }
201
+ }
202
+ ```
203
+ """
204
+ doc = inspect.getdoc(func)
205
+ if not doc:
206
+ raise DocstringParsingException(
207
+ f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
208
+ )
209
+ doc = doc.strip()
210
+ main_doc, param_descriptions, return_doc = _parse_google_format_docstring(doc)
211
+
212
+ json_schema = _convert_type_hints_to_json_schema(func)
213
+ if (return_dict := json_schema["properties"].pop("return", None)) is not None:
214
+ if return_doc is not None: # We allow a missing return docstring since most templates ignore it
215
+ return_dict["description"] = return_doc
216
+ for arg, schema in json_schema["properties"].items():
217
+ if arg not in param_descriptions:
218
+ raise DocstringParsingException(
219
+ f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
220
+ )
221
+ desc = param_descriptions[arg]
222
+ enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
223
+ if enum_choices:
224
+ schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
225
+ desc = enum_choices.string[: enum_choices.start()].strip()
226
+ schema["description"] = desc
227
+
228
+ output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
229
+ if return_dict is not None:
230
+ output["return"] = return_dict
231
+ return {"type": "function", "function": output}
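One detail worth noting: the docstring examples above print the inner `function` dict, while the code actually wraps it in a `{"type": "function", ...}` envelope on the last line. A minimal sketch of that behavior (illustrative, assuming the module is importable under this path):

```python
from smolagents._function_type_hints_utils import get_json_schema

def get_weather(city: str) -> str:
    """
    Get the weather for a city.

    Args:
        city: The city to look up.
    """
    return f"sunny in {city}"

schema = get_json_schema(get_weather)
assert schema["type"] == "function"  # the envelope added by the return above
assert schema["function"]["name"] == "get_weather"
assert schema["function"]["parameters"]["required"] == ["city"]
assert schema["function"]["parameters"]["properties"]["city"] == {
    "type": "string",
    "description": "The city to look up.",
}
```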
232
+
233
+
234
+ # Extracts the initial segment of the docstring, containing the function description
235
+ description_re = re.compile(r"^(.*?)(?=\n\s*(Args:|Returns:|Raises:)|\Z)", re.DOTALL)
236
+ # Extracts the Args: block from the docstring
237
+ args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
238
+ # Splits the Args: block into individual arguments
239
+ args_split_re = re.compile(
240
+ r"(?:^|\n)" # Match the start of the args block, or a newline
241
+ r"\s*(\w+)\s*(?:\([^)]*?\))?:\s*" # Capture the argument name (ignore the type) and strip spacing
242
+ r"(.*?)\s*" # Capture the argument description, which can span multiple lines, and strip trailing spacing
243
+ r"(?=\n\s*\w+\s*(?:\([^)]*?\))?:|\Z)", # Stop when you hit the next argument (with or without type) or the end of the block
244
+ re.DOTALL | re.VERBOSE,
245
+ )
246
+ # Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
247
+ returns_re = re.compile(
248
+ r"\n\s*Returns:\n\s*"
249
+ r"(?:[^)]*?:\s*)?" # Ignore the return type if present
250
+ r"(.*?)" # Capture the return description
251
+ r"[\n\s]*(Raises:|\Z)",
252
+ re.DOTALL,
253
+ )
254
+
255
+
256
+ def _parse_google_format_docstring(
257
+ docstring: str,
258
+ ) -> tuple[str | None, dict | None, str | None]:
259
+ """
260
+ Parses a Google-style docstring to extract the function description,
261
+ argument descriptions, and return description.
262
+
263
+ Args:
264
+ docstring (str): The docstring to parse.
265
+
266
+ Returns:
267
+ The function description, arguments, and return description.
268
+ """
269
+
270
+ # Extract the sections
271
+ description_match = description_re.search(docstring)
272
+ args_match = args_re.search(docstring)
273
+ returns_match = returns_re.search(docstring)
274
+
275
+ # Clean and store the sections
276
+ description = description_match.group(1).strip() if description_match else None
277
+ docstring_args = args_match.group(1).strip() if args_match else None
278
+ returns = returns_match.group(1).strip() if returns_match else None
279
+
280
+ # Parsing the arguments into a dictionary
281
+ if docstring_args is not None:
282
+ docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines
283
+ matches = args_split_re.findall(docstring_args)
284
+ args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
285
+ else:
286
+ args_dict = {}
287
+
288
+ return description, args_dict, returns
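A small sketch of the parser on a Google-style docstring (illustrative, calling the private helper directly); note how a multi-line argument description is collapsed onto one line:

```python
from smolagents._function_type_hints_utils import _parse_google_format_docstring

doc = """Adds two numbers.

Args:
    a: The first number.
    b: The second number,
        possibly described over two lines.

Returns:
    The sum of a and b.
"""
description, args, returns = _parse_google_format_docstring(doc)
assert description == "Adds two numbers."
assert args["b"] == "The second number, possibly described over two lines."
assert returns == "The sum of a and b."
```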
289
+
290
+
291
+ def _convert_type_hints_to_json_schema(func: Callable, error_on_missing_type_hints: bool = True) -> dict:
292
+ type_hints = get_type_hints(func)
293
+ signature = inspect.signature(func)
294
+
295
+ properties = {}
296
+ for param_name, param_type in type_hints.items():
297
+ properties[param_name] = _parse_type_hint(param_type)
298
+
299
+ required = []
300
+ for param_name, param in signature.parameters.items():
301
+ if param.annotation == inspect.Parameter.empty and error_on_missing_type_hints:
302
+ raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
303
+ if param_name not in properties:
304
+ properties[param_name] = {}
305
+
306
+ if param.default == inspect.Parameter.empty:
307
+ required.append(param_name)
308
+ else:
309
+ properties[param_name]["nullable"] = True
310
+
311
+ # Return: multi-type union -> treat as any
312
+ if (
313
+ "return" in properties
314
+ and (return_type := properties["return"].get("type"))
315
+ and not isinstance(return_type, str)
316
+ ):
317
+ properties["return"]["type"] = "any"
318
+
319
+ schema = {"type": "object", "properties": properties}
320
+ if required:
321
+ schema["required"] = required
322
+
323
+ return schema
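Illustrative sketch of the conversion above: parameters with defaults are left out of `required` and marked `nullable`, and the return annotation shows up as a `return` property (popped off later by `get_json_schema`):

```python
from smolagents._function_type_hints_utils import _convert_type_hints_to_json_schema

def greet(name: str, excited: bool = False) -> str:
    return f"Hi {name}{'!' if excited else '.'}"

schema = _convert_type_hints_to_json_schema(greet)
assert schema["required"] == ["name"]
assert schema["properties"]["excited"] == {"type": "boolean", "nullable": True}
assert schema["properties"]["return"] == {"type": "string"}
```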
324
+
325
+
326
+ def _parse_type_hint(hint: type) -> dict:
327
+ origin = get_origin(hint)
328
+ args = get_args(hint)
329
+
330
+ if origin is None:
331
+ try:
332
+ return _get_json_schema_type(hint)
333
+ except KeyError:
334
+ raise TypeHintParsingException(
335
+ "Couldn't parse this type hint, likely due to a custom class or object: ",
336
+ hint,
337
+ )
338
+
339
+ elif origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
340
+ return _parse_union_type(args)
341
+
342
+ elif origin is list:
343
+ if not args:
344
+ return {"type": "array"}
345
+ else:
346
+ # Lists can only have a single type argument, so recurse into it
347
+ return {"type": "array", "items": _parse_type_hint(args[0])}
348
+
349
+ elif origin is tuple:
350
+ if not args:
351
+ return {"type": "array"}
352
+ if len(args) == 1:
353
+ raise TypeHintParsingException(
354
+ f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
355
+ "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
356
+ "more than one element, we recommend "
357
+ "using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
358
+ "pass the element directly."
359
+ )
360
+ if ... in args:
361
+ raise TypeHintParsingException(
362
+ "Conversion of '...' is not supported in Tuple type hints. "
363
+ "Use List[] types for variable-length"
364
+ " inputs instead."
365
+ )
366
+ return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
367
+
368
+ elif origin is dict:
369
+ # The JSON equivalent to a dict is 'object', which mandates that all keys are strings
370
+ # However, we can specify the type of the dict values with "additionalProperties"
371
+ out = {"type": "object"}
372
+ if len(args) == 2:
373
+ out["additionalProperties"] = _parse_type_hint(args[1])
374
+ return out
375
+
376
+ elif origin is Literal:
377
+ literal_types = set(type(arg) for arg in args)
378
+ final_type = _parse_union_type(literal_types)
379
+
380
+ # None literal value is represented by 'nullable' field set by _parse_union_type
381
+ final_type.update({"enum": [arg for arg in args if arg is not None]})
382
+ return final_type
383
+
384
+ raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
385
+
386
+
387
+ def _parse_union_type(args: tuple[Any, ...]) -> dict:
388
+ subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
389
+ if len(subtypes) == 1:
390
+ # A single non-null type can be expressed directly
391
+ return_dict = subtypes[0]
392
+ elif all(isinstance(subtype["type"], str) for subtype in subtypes):
393
+ # A union of basic types can be expressed as a list in the schema
394
+ return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])}
395
+ else:
396
+ # A union of more complex types requires "anyOf"
397
+ return_dict = {"anyOf": subtypes}
398
+ if type(None) in args:
399
+ return_dict["nullable"] = True
400
+ return return_dict
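A sketch of the union handling (illustrative, Python 3.10+ `X | Y` syntax): `None` members become a `nullable` flag, and a union of basic types becomes a sorted list of type names:

```python
from smolagents._function_type_hints_utils import _parse_type_hint

assert _parse_type_hint(int | None) == {"type": "integer", "nullable": True}
assert _parse_type_hint(int | str) == {"type": ["integer", "string"]}
```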
401
+
402
+
403
+ _BASE_TYPE_MAPPING = {
404
+ int: {"type": "integer"},
405
+ float: {"type": "number"},
406
+ str: {"type": "string"},
407
+ bool: {"type": "boolean"},
408
+ list: {"type": "array"},
409
+ dict: {"type": "object"},
410
+ Any: {"type": "any"},
411
+ types.NoneType: {"type": "null"},
412
+ }
413
+
414
+
415
+ def _get_json_schema_type(param_type: type) -> dict[str, str]:
416
+ if param_type in _BASE_TYPE_MAPPING:
417
+ return copy(_BASE_TYPE_MAPPING[param_type])
418
+ if str(param_type) == "Image":
419
+ from PIL.Image import Image
420
+
421
+ if param_type == Image:
422
+ return {"type": "image"}
423
+ if str(param_type) == "Tensor":
424
+ try:
425
+ from torch import Tensor
426
+
427
+ if param_type == Tensor:
428
+ return {"type": "audio"}
429
+ except ModuleNotFoundError:
430
+ pass
431
+ return {"type": "object"}
src/smolagents/agent_types.py ADDED
@@ -0,0 +1,283 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import logging
16
+ import os
17
+ import pathlib
18
+ import tempfile
19
+ import uuid
20
+ from io import BytesIO
21
+
22
+ import PIL.Image
23
+ import requests
24
+
25
+ from .utils import _is_package_available
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class AgentType:
32
+ """
33
+ Abstract class to be reimplemented to define types that can be returned by agents.
34
+
35
+ These objects serve three purposes:
36
+
37
+ - They behave as if they were the type they're meant to be, e.g., a string for text, a PIL.Image.Image for images
38
+ - They can be stringified: calling str(object) returns a string describing the object
39
+ - They should display correctly in IPython notebooks/Colab/Jupyter
40
+ """
41
+
42
+ def __init__(self, value):
43
+ self._value = value
44
+
45
+ def __str__(self):
46
+ return self.to_string()
47
+
48
+ def to_raw(self):
49
+ logger.error(
50
+ "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
51
+ )
52
+ return self._value
53
+
54
+ def to_string(self) -> str:
55
+ logger.error(
56
+ "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
57
+ )
58
+ return str(self._value)
59
+
60
+
61
+ class AgentText(AgentType, str):
62
+ """
63
+ Text type returned by the agent. Behaves as a string.
64
+ """
65
+
66
+ def to_raw(self):
67
+ return self._value
68
+
69
+ def to_string(self):
70
+ return str(self._value)
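Illustrative sketch of the three-purpose contract described above, on the simplest subclass:

```python
from smolagents.agent_types import AgentText

t = AgentText("hello")
assert isinstance(t, str)      # behaves as the underlying type
assert t.to_raw() == "hello"   # raw value
assert str(t) == "hello"       # stringification goes through to_string()
```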
71
+
72
+
73
+ class AgentImage(AgentType, PIL.Image.Image):
74
+ """
75
+ Image type returned by the agent. Behaves as a PIL.Image.Image.
76
+ """
77
+
78
+ def __init__(self, value):
79
+ AgentType.__init__(self, value)
80
+ PIL.Image.Image.__init__(self)
81
+
82
+ self._path = None
83
+ self._raw = None
84
+ self._tensor = None
85
+
86
+ if isinstance(value, AgentImage):
87
+ self._raw, self._path, self._tensor = value._raw, value._path, value._tensor
88
+ elif isinstance(value, PIL.Image.Image):
89
+ self._raw = value
90
+ elif isinstance(value, bytes):
91
+ self._raw = PIL.Image.open(BytesIO(value))
92
+ elif isinstance(value, (str, pathlib.Path)):
93
+ self._path = value
94
+ else:
95
+ try:
96
+ import torch
97
+
98
+ if isinstance(value, torch.Tensor):
99
+ self._tensor = value
100
+ import numpy as np
101
+
102
+ if isinstance(value, np.ndarray):
103
+ self._tensor = torch.from_numpy(value)
104
+ except ModuleNotFoundError:
105
+ pass
106
+
107
+ if self._path is None and self._raw is None and self._tensor is None:
108
+ raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
109
+
110
+ def _ipython_display_(self, include=None, exclude=None):
111
+ """
112
+ Correctly displays this type in an IPython notebook (IPython, Colab, Jupyter, ...)
113
+ """
114
+ from IPython.display import Image, display
115
+
116
+ display(Image(self.to_string()))
117
+
118
+ def to_raw(self):
119
+ """
120
+ Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.Image.
121
+ """
122
+ if self._raw is not None:
123
+ return self._raw
124
+
125
+ if self._path is not None:
126
+ self._raw = PIL.Image.open(self._path)
127
+ return self._raw
128
+
129
+ if self._tensor is not None:
130
+ import numpy as np
131
+
132
+ array = self._tensor.cpu().detach().numpy()
133
+ return PIL.Image.fromarray((255 - array * 255).astype(np.uint8))
134
+
135
+ def to_string(self):
136
+ """
137
+ Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
138
+ version of the image.
139
+ """
140
+ if self._path is not None:
141
+ return self._path
142
+
143
+ if self._raw is not None:
144
+ directory = tempfile.mkdtemp()
145
+ self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
146
+ self._raw.save(self._path, format="png")
147
+ return self._path
148
+
149
+ if self._tensor is not None:
150
+ import numpy as np
151
+
152
+ array = self._tensor.cpu().detach().numpy()
153
+
154
+ # There is likely a simpler way than loading into an image and then saving it
155
+ img = PIL.Image.fromarray((255 - array * 255).astype(np.uint8))
156
+
157
+ directory = tempfile.mkdtemp()
158
+ self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
159
+ img.save(self._path, format="png")
160
+
161
+ return self._path
162
+
163
+ def save(self, output_bytes, format: str = None, **params):
164
+ """
165
+ Saves the image to a file.
166
+ Args:
167
+ output_bytes: The output file object or buffer to save the image to (forwarded to PIL.Image.save).
168
+ format (str): The format to use for the output image. The format is the same as in PIL.Image.save.
169
+ **params: Additional parameters to pass to PIL.Image.save.
170
+ """
171
+ img = self.to_raw()
172
+ img.save(output_bytes, format=format, **params)
173
+
174
+
175
+ class AgentAudio(AgentType, str):
176
+ """
177
+ Audio type returned by the agent.
178
+ """
179
+
180
+ def __init__(self, value, samplerate=16_000):
181
+ if not _is_package_available("soundfile") or not _is_package_available("torch"):
182
+ raise ModuleNotFoundError(
183
+ "Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`"
184
+ )
185
+ import numpy as np
186
+ import torch
187
+
188
+ super().__init__(value)
189
+
190
+ self._path = None
191
+ self._tensor = None
192
+
193
+ self.samplerate = samplerate
194
+ if isinstance(value, (str, pathlib.Path)):
195
+ self._path = value
196
+ elif isinstance(value, torch.Tensor):
197
+ self._tensor = value
198
+ elif isinstance(value, tuple):
199
+ self.samplerate = value[0]
200
+ if isinstance(value[1], np.ndarray):
201
+ self._tensor = torch.from_numpy(value[1])
202
+ else:
203
+ self._tensor = torch.tensor(value[1])
204
+ else:
205
+ raise ValueError(f"Unsupported audio type: {type(value)}")
206
+
207
+ def _ipython_display_(self, include=None, exclude=None):
208
+ """
209
+ Correctly displays this type in an IPython notebook (IPython, Colab, Jupyter, ...)
210
+ """
211
+ from IPython.display import Audio, display
212
+
213
+ display(Audio(self.to_string(), rate=self.samplerate))
214
+
215
+ def to_raw(self):
216
+ """
217
+ Returns the "raw" version of that object. It is a `torch.Tensor` object.
218
+ """
219
+ import soundfile as sf
220
+
221
+ if self._tensor is not None:
222
+ return self._tensor
223
+
224
+ import torch
225
+
226
+ if self._path is not None:
227
+ if "://" in str(self._path):
228
+ response = requests.get(self._path)
229
+ response.raise_for_status()
230
+ tensor, self.samplerate = sf.read(BytesIO(response.content))
231
+ else:
232
+ tensor, self.samplerate = sf.read(self._path)
233
+ self._tensor = torch.tensor(tensor)
234
+ return self._tensor
235
+
236
+ def to_string(self):
237
+ """
238
+ Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
239
+ version of the audio.
240
+ """
241
+ import soundfile as sf
242
+
243
+ if self._path is not None:
244
+ return self._path
245
+
246
+ if self._tensor is not None:
247
+ directory = tempfile.mkdtemp()
248
+ self._path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
249
+ sf.write(self._path, self._tensor, samplerate=self.samplerate)
250
+ return self._path
251
+
252
+
253
+ _AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
254
+
255
+
256
+ def handle_agent_input_types(*args, **kwargs):
257
+ args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
258
+ kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
259
+ return args, kwargs
260
+
261
+
262
+ def handle_agent_output_types(output, output_type=None):
263
+ if output_type in _AGENT_TYPE_MAPPING:
264
+ # If the class has defined outputs, we can map directly according to the class definition
265
+ decoded_outputs = _AGENT_TYPE_MAPPING[output_type](output)
266
+ return decoded_outputs
267
+
268
+ # If the class does not have defined outputs, then we map according to the type
269
+ if isinstance(output, str):
270
+ return AgentText(output)
271
+ if isinstance(output, PIL.Image.Image):
272
+ return AgentImage(output)
273
+ try:
274
+ import torch
275
+
276
+ if isinstance(output, torch.Tensor):
277
+ return AgentAudio(output)
278
+ except ModuleNotFoundError:
279
+ pass
280
+ return output
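A minimal sketch of the output mapping (illustrative): a declared `output_type` takes precedence, otherwise the runtime type decides via the isinstance checks above:

```python
from smolagents.agent_types import AgentText, handle_agent_output_types

declared = handle_agent_output_types("done", output_type="string")
inferred = handle_agent_output_types("done")  # no declared type: inferred from str
assert isinstance(declared, AgentText) and isinstance(inferred, AgentText)
```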
281
+
282
+
283
+ __all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"]
src/smolagents/agents.py ADDED
@@ -0,0 +1,1725 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import importlib
18
+ import inspect
19
+ import json
20
+ import os
21
+ import re
22
+ import tempfile
23
+ import textwrap
24
+ import time
25
+ import warnings
26
+ from abc import ABC, abstractmethod
27
+ from collections.abc import Callable, Generator
28
+ from concurrent.futures import ThreadPoolExecutor, as_completed
29
+ from dataclasses import dataclass
30
+ from logging import getLogger
31
+ from pathlib import Path
32
+ from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypedDict, Union
33
+
34
+ import jinja2
35
+ import yaml
36
+ from huggingface_hub import create_repo, metadata_update, snapshot_download, upload_folder
37
+ from jinja2 import StrictUndefined, Template
38
+ from rich.console import Group
39
+ from rich.live import Live
40
+ from rich.markdown import Markdown
41
+ from rich.panel import Panel
42
+ from rich.rule import Rule
43
+ from rich.text import Text
44
+
45
+
46
+ if TYPE_CHECKING:
47
+ import PIL.Image
48
+
49
+ from .agent_types import AgentAudio, AgentImage, handle_agent_output_types
50
+ from .default_tools import TOOL_MAPPING, FinalAnswerTool
51
+ from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonExecutor, PythonExecutor, fix_final_answer_code
52
+ from .memory import (
53
+ ActionStep,
54
+ AgentMemory,
55
+ FinalAnswerStep,
56
+ PlanningStep,
57
+ SystemPromptStep,
58
+ TaskStep,
59
+ Timing,
60
+ TokenUsage,
61
+ ToolCall,
62
+ )
63
+ from .models import (
64
+ CODEAGENT_RESPONSE_FORMAT,
65
+ ChatMessage,
66
+ ChatMessageStreamDelta,
67
+ ChatMessageToolCall,
68
+ MessageRole,
69
+ Model,
70
+ agglomerate_stream_deltas,
71
+ parse_json_if_needed,
72
+ )
73
+ from .monitoring import (
74
+ YELLOW_HEX,
75
+ AgentLogger,
76
+ LogLevel,
77
+ Monitor,
78
+ )
79
+ from .remote_executors import DockerExecutor, E2BExecutor
80
+ from .tools import Tool, validate_tool_arguments
81
+ from .utils import (
82
+ AGENT_GRADIO_APP_TEMPLATE,
83
+ AgentError,
84
+ AgentExecutionError,
85
+ AgentGenerationError,
86
+ AgentMaxStepsError,
87
+ AgentParsingError,
88
+ AgentToolCallError,
89
+ AgentToolExecutionError,
90
+ extract_code_from_text,
91
+ is_valid_name,
92
+ make_init_file,
93
+ parse_code_blobs,
94
+ truncate_content,
95
+ )
96
+
97
+
98
+ logger = getLogger(__name__)
99
+
100
+
101
+ def get_variable_names(template: str) -> set[str]:
102
+ pattern = re.compile(r"\{\{([^{}]+)\}\}")
103
+ return {match.group(1).strip() for match in pattern.finditer(template)}
104
+
105
+
106
+ def populate_template(template: str, variables: dict[str, Any]) -> str:
107
+ compiled_template = Template(template, undefined=StrictUndefined)
108
+ try:
109
+ return compiled_template.render(**variables)
110
+ except Exception as e:
111
+ raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}")
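Illustrative sketch of `populate_template` (import path assumed): `StrictUndefined` makes a missing variable fail loudly instead of silently rendering as an empty string:

```python
from smolagents.agents import populate_template

print(populate_template("Solve: {{ task }}", variables={"task": "2 ** 3.7384"}))
# Solve: 2 ** 3.7384

try:
    populate_template("Solve: {{ task }}", variables={})
except Exception as e:
    print(e)  # Error during jinja template rendering: UndefinedError: 'task' is undefined
```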
112
+
113
+
114
+ @dataclass
115
+ class ActionOutput:
116
+ output: Any
117
+ is_final_answer: bool
118
+
119
+
120
+ @dataclass
121
+ class ToolOutput:
122
+ id: str
123
+ output: Any
124
+ is_final_answer: bool
125
+ observation: str
126
+ tool_call: ToolCall
127
+
128
+
129
+ class PlanningPromptTemplate(TypedDict):
130
+ """
131
+ Prompt templates for the planning step.
132
+
133
+ Args:
134
+ initial_plan (`str`): Initial plan prompt.
135
+ update_plan_pre_messages (`str`): Update plan pre-messages prompt.
136
+ update_plan_post_messages (`str`): Update plan post-messages prompt.
137
+ """
138
+
139
+ initial_plan: str
140
+ update_plan_pre_messages: str
141
+ update_plan_post_messages: str
142
+
143
+
144
+ class ManagedAgentPromptTemplate(TypedDict):
145
+ """
146
+ Prompt templates for the managed agent.
147
+
148
+ Args:
149
+ task (`str`): Task prompt.
150
+ report (`str`): Report prompt.
151
+ """
152
+
153
+ task: str
154
+ report: str
155
+
156
+
157
+ class FinalAnswerPromptTemplate(TypedDict):
158
+ """
159
+ Prompt templates for the final answer.
160
+
161
+ Args:
162
+ pre_messages (`str`): Pre-messages prompt.
163
+ post_messages (`str`): Post-messages prompt.
164
+ """
165
+
166
+ pre_messages: str
167
+ post_messages: str
168
+
169
+
170
+ class PromptTemplates(TypedDict):
171
+ """
172
+ Prompt templates for the agent.
173
+
174
+ Args:
175
+ system_prompt (`str`): System prompt.
176
+ planning ([`~agents.PlanningPromptTemplate`]): Planning prompt templates.
177
+ managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt templates.
178
+ final_answer ([`~agents.FinalAnswerPromptTemplate`]): Final answer prompt templates.
179
+ """
180
+
181
+ system_prompt: str
182
+ planning: PlanningPromptTemplate
183
+ managed_agent: ManagedAgentPromptTemplate
184
+ final_answer: FinalAnswerPromptTemplate
185
+
186
+
187
+ EMPTY_PROMPT_TEMPLATES = PromptTemplates(
188
+ system_prompt="",
189
+ planning=PlanningPromptTemplate(
190
+ initial_plan="",
191
+ update_plan_pre_messages="",
192
+ update_plan_post_messages="",
193
+ ),
194
+ managed_agent=ManagedAgentPromptTemplate(task="", report=""),
195
+ final_answer=FinalAnswerPromptTemplate(pre_messages="", post_messages=""),
196
+ )
197
+
198
+
199
+ @dataclass
200
+ class RunResult:
201
+ """Holds extended information about an agent run.
202
+
203
+ Attributes:
204
+ output (Any | None): The final output of the agent run, if available.
205
+ state (Literal["success", "max_steps_error"]): The final state of the agent after the run.
206
+ messages (list[dict]): The agent's memory, as a list of messages.
207
+ token_usage (TokenUsage | None): Count of tokens used during the run.
208
+ timing (Timing): Timing details of the agent run: start time, end time, duration.
209
+ """
210
+
211
+ output: Any | None
212
+ state: Literal["success", "max_steps_error"]
213
+ messages: list[dict]
214
+ token_usage: TokenUsage | None
215
+ timing: Timing
216
+
217
+
218
+ StreamEvent: TypeAlias = Union[
219
+ ChatMessageStreamDelta,
220
+ ChatMessageToolCall,
221
+ ActionOutput,
222
+ ToolCall,
223
+ ToolOutput,
224
+ PlanningStep,
225
+ ActionStep,
226
+ FinalAnswerStep,
227
+ ]
228
+
229
+
230
+ class MultiStepAgent(ABC):
231
+ """
232
+ Agent class that solves the given task step by step, using the ReAct framework:
233
+ Until the objective is reached, the agent performs a cycle of action (given by the LLM) and observation (obtained from the environment).
234
+
235
+ Args:
236
+ tools (`list[Tool]`): [`Tool`]s that the agent can use.
237
+ model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
238
+ prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
239
+ instructions (`str`, *optional*): Custom instructions for the agent; they will be inserted into the system prompt.
240
+ max_steps (`int`, default `20`): Maximum number of steps the agent can take to solve the task.
241
+ add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
242
+ verbosity_level (`LogLevel`, default `LogLevel.INFO`): Level of verbosity of the agent's logs.
243
+ grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
244
+ <Deprecated version="1.17.0">
245
+ Parameter `grammar` is deprecated and will be removed in version 1.20.
246
+ </Deprecated>
247
+ managed_agents (`list`, *optional*): Managed agents that the agent can call.
248
+ step_callbacks (`list[Callable]`, *optional*): Callbacks that will be called at each step.
249
+ planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
250
+ name (`str`, *optional*): Necessary for a managed agent only - the name by which this agent can be called.
251
+ description (`str`, *optional*): Necessary for a managed agent only - the description of this agent.
252
+ provide_run_summary (`bool`, *optional*): Whether to provide a run summary when called as a managed agent.
253
+ final_answer_checks (`list[Callable]`, *optional*): List of validation functions to run before accepting a final answer.
254
+ Each function should:
255
+ - Take the final answer and the agent's memory as arguments.
256
+ - Return a boolean indicating whether the final answer is valid.
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ tools: list[Tool],
262
+ model: Model,
263
+ prompt_templates: PromptTemplates | None = None,
264
+ instructions: str | None = None,
265
+ max_steps: int = 20,
266
+ add_base_tools: bool = False,
267
+ verbosity_level: LogLevel = LogLevel.INFO,
268
+ grammar: dict[str, str] | None = None,
269
+ managed_agents: list | None = None,
270
+ step_callbacks: list[Callable] | None = None,
271
+ planning_interval: int | None = None,
272
+ name: str | None = None,
273
+ description: str | None = None,
274
+ provide_run_summary: bool = False,
275
+ final_answer_checks: list[Callable] | None = None,
276
+ return_full_result: bool = False,
277
+ logger: AgentLogger | None = None,
278
+ ):
279
+ self.agent_name = self.__class__.__name__
280
+ self.model = model
281
+ self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES
282
+ if prompt_templates is not None:
283
+ missing_keys = set(EMPTY_PROMPT_TEMPLATES.keys()) - set(prompt_templates.keys())
284
+ assert not missing_keys, (
285
+ f"Some prompt templates are missing from your custom `prompt_templates`: {missing_keys}"
286
+ )
287
+ for key, value in EMPTY_PROMPT_TEMPLATES.items():
288
+ if isinstance(value, dict):
289
+ for subkey in value.keys():
290
+ assert key in prompt_templates.keys() and (subkey in prompt_templates[key].keys()), (
291
+ f"Some prompt templates are missing from your custom `prompt_templates`: {subkey} under {key}"
292
+ )
293
+
294
+ self.max_steps = max_steps
295
+ self.step_number = 0
296
+ if grammar is not None:
297
+ warnings.warn(
298
+ "Parameter 'grammar' is deprecated and will be removed in version 1.20.",
299
+ FutureWarning,
300
+ )
301
+ self.grammar = grammar
302
+ self.planning_interval = planning_interval
303
+ self.state: dict[str, Any] = {}
304
+ self.name = self._validate_name(name)
305
+ self.description = description
306
+ self.provide_run_summary = provide_run_summary
307
+ self.final_answer_checks = final_answer_checks if final_answer_checks is not None else []
308
+ self.return_full_result = return_full_result
309
+ self.instructions = instructions
310
+ self._setup_managed_agents(managed_agents)
311
+ self._setup_tools(tools, add_base_tools)
312
+ self._validate_tools_and_managed_agents(tools, managed_agents)
313
+
314
+ self.task: str | None = None
315
+ self.memory = AgentMemory(self.system_prompt)
316
+
317
+ if logger is None:
318
+ self.logger = AgentLogger(level=verbosity_level)
319
+ else:
320
+ self.logger = logger
321
+
322
+ self.monitor = Monitor(self.model, self.logger)
323
+ self.step_callbacks = step_callbacks if step_callbacks is not None else []
324
+ self.step_callbacks.append(self.monitor.update_metrics)
325
+ self.stream_outputs = False
326
+
327
+ @property
328
+ def system_prompt(self) -> str:
329
+ return self.initialize_system_prompt()
330
+
331
+ @system_prompt.setter
332
+ def system_prompt(self, value: str):
333
+ raise AttributeError(
334
+ """The 'system_prompt' property is read-only. Use 'self.prompt_templates["system_prompt"]' instead."""
335
+ )
336
+
337
+ def _validate_name(self, name: str | None) -> str | None:
338
+ if name is not None and not is_valid_name(name):
339
+ raise ValueError(f"Agent name '{name}' must be a valid Python identifier and not a reserved keyword.")
340
+ return name
341
+
342
+ def _setup_managed_agents(self, managed_agents: list | None = None) -> None:
343
+ """Setup managed agents with proper logging."""
344
+ self.managed_agents = {}
345
+ if managed_agents:
346
+ assert all(agent.name and agent.description for agent in managed_agents), (
347
+ "All managed agents need both a name and a description!"
348
+ )
349
+ self.managed_agents = {agent.name: agent for agent in managed_agents}
350
+ # Ensure managed agents can be called as tools by the model: set their inputs and output_type
351
+ for agent in self.managed_agents.values():
352
+ agent.inputs = {
353
+ "task": {"type": "string", "description": "Long detailed description of the task."},
354
+ "additional_args": {
355
+ "type": "object",
356
+ "description": "Dictionary of extra inputs to pass to the managed agent, e.g. images, dataframes, or any other contextual data it may need.",
357
+ },
358
+ }
359
+ agent.output_type = "string"
360
+
361
+ def _setup_tools(self, tools, add_base_tools):
362
+ assert all(isinstance(tool, Tool) for tool in tools), "All elements must be instances of Tool (or a subclass)"
363
+ self.tools = {tool.name: tool for tool in tools}
364
+ if add_base_tools:
365
+ self.tools.update(
366
+ {
367
+ name: cls()
368
+ for name, cls in TOOL_MAPPING.items()
369
+ if name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent"
370
+ }
371
+ )
372
+ self.tools.setdefault("final_answer", FinalAnswerTool())
373
+
374
+ def _validate_tools_and_managed_agents(self, tools, managed_agents):
375
+ tool_and_managed_agent_names = [tool.name for tool in tools]
376
+ if managed_agents is not None:
377
+ tool_and_managed_agent_names += [agent.name for agent in managed_agents]
378
+ if self.name:
379
+ tool_and_managed_agent_names.append(self.name)
380
+ if len(tool_and_managed_agent_names) != len(set(tool_and_managed_agent_names)):
381
+ raise ValueError(
382
+ "Each tool or managed_agent should have a unique name! You passed these duplicate names: "
383
+ f"{[name for name in tool_and_managed_agent_names if tool_and_managed_agent_names.count(name) > 1]}"
384
+ )
385
+
386
+ def run(
387
+ self,
388
+ task: str,
389
+ stream: bool = False,
390
+ reset: bool = True,
391
+ images: list["PIL.Image.Image"] | None = None,
392
+ additional_args: dict | None = None,
393
+ max_steps: int | None = None,
394
+ ):
395
+ """
396
+ Run the agent for the given task.
397
+
398
+ Args:
399
+ task (`str`): Task to perform.
400
+ stream (`bool`): Whether to run in streaming mode.
401
+ If `True`, returns a generator that yields each step as it is executed. You must iterate over this generator to process the individual steps (e.g., using a for loop or `next()`).
402
+ If `False`, executes all steps internally and returns only the final answer after completion.
403
+ reset (`bool`): Whether to reset the conversation or keep it going from previous run.
404
+ images (`list[PIL.Image.Image]`, *optional*): Image objects.
405
+ additional_args (`dict`, *optional*): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names!
406
+ max_steps (`int`, *optional*): Maximum number of steps the agent can take to solve the task. If not provided, the agent's default value is used.
407
+
408
+ Example:
409
+ ```py
410
+ from smolagents import CodeAgent
411
+ agent = CodeAgent(tools=[])
412
+ agent.run("What is the result of 2 power 3.7384?")
413
+ ```
414
+ """
415
+ max_steps = max_steps or self.max_steps
416
+ self.task = task
417
+ self.interrupt_switch = False
418
+ if additional_args is not None:
419
+ self.state.update(additional_args)
420
+ self.task += f"""
421
+ You have been provided with these additional arguments, which you can access using the keys as variables in your Python code:
422
+ {str(additional_args)}."""
423
+
424
+ self.memory.system_prompt = SystemPromptStep(system_prompt=self.system_prompt)
425
+ if reset:
426
+ self.memory.reset()
427
+ self.monitor.reset()
428
+
429
+ self.logger.log_task(
430
+ content=self.task.strip(),
431
+ subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
432
+ level=LogLevel.INFO,
433
+ title=self.name if hasattr(self, "name") else None,
434
+ )
435
+ self.memory.steps.append(TaskStep(task=self.task, task_images=images))
436
+
437
+ if getattr(self, "python_executor", None):
438
+ self.python_executor.send_variables(variables=self.state)
439
+ self.python_executor.send_tools({**self.tools, **self.managed_agents})
440
+
441
+ if stream:
442
+ # The steps are returned as they are executed through a generator to iterate on.
443
+ return self._run_stream(task=self.task, max_steps=max_steps, images=images)
444
+ run_start_time = time.time()
445
+ # Outputs are returned only at the end. We only look at the last step.
446
+
447
+ steps = list(self._run_stream(task=self.task, max_steps=max_steps, images=images))
448
+ assert isinstance(steps[-1], FinalAnswerStep)
449
+ output = steps[-1].output
450
+
451
+ if self.return_full_result:
452
+ total_input_tokens = 0
453
+ total_output_tokens = 0
454
+ correct_token_usage = True
455
+ for step in self.memory.steps:
456
+ if isinstance(step, (ActionStep, PlanningStep)):
457
+ if step.token_usage is None:
458
+ correct_token_usage = False
459
+ break
460
+ else:
461
+ total_input_tokens += step.token_usage.input_tokens
462
+ total_output_tokens += step.token_usage.output_tokens
463
+ if correct_token_usage:
464
+ token_usage = TokenUsage(input_tokens=total_input_tokens, output_tokens=total_output_tokens)
465
+ else:
466
+ token_usage = None
467
+
468
+ if self.memory.steps and isinstance(getattr(self.memory.steps[-1], "error", None), AgentMaxStepsError):
469
+ state = "max_steps_error"
470
+ else:
471
+ state = "success"
472
+
473
+ messages = self.memory.get_full_steps()
474
+
475
+ return RunResult(
476
+ output=output,
477
+ token_usage=token_usage,
478
+ messages=messages,
479
+ timing=Timing(start_time=run_start_time, end_time=time.time()),
480
+ state=state,
481
+ )
482
+
483
+ return output
484
+
485
+ def _run_stream(
486
+ self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None
487
+ ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
488
+ self.step_number = 1
489
+ returned_final_answer = False
490
+ while not returned_final_answer and self.step_number <= max_steps:
491
+ if self.interrupt_switch:
492
+ raise AgentError("Agent interrupted.", self.logger)
493
+
494
+ # Run a planning step if scheduled
495
+ if self.planning_interval is not None and (
496
+ self.step_number == 1 or (self.step_number - 1) % self.planning_interval == 0
497
+ ):
498
+ planning_start_time = time.time()
499
+ planning_step = None
500
+ for element in self._generate_planning_step(
501
+ task, is_first_step=len(self.memory.steps) == 1, step=self.step_number
502
+ ): # Don't use the attribute step_number here, because there can be steps from previous runs
503
+ yield element
504
+ planning_step = element
505
+ assert isinstance(planning_step, PlanningStep) # Last yielded element should be a PlanningStep
506
+ self.memory.steps.append(planning_step)
507
+ planning_end_time = time.time()
508
+ planning_step.timing = Timing(
509
+ start_time=planning_start_time,
510
+ end_time=planning_end_time,
511
+ )
512
+
513
+ # Start action step!
514
+ action_step_start_time = time.time()
515
+ action_step = ActionStep(
516
+ step_number=self.step_number,
517
+ timing=Timing(start_time=action_step_start_time),
518
+ observations_images=images,
519
+ )
520
+ self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
521
+ try:
522
+ for output in self._step_stream(action_step):
523
+ # Yield all
524
+ yield output
525
+
526
+ if isinstance(output, ActionOutput) and output.is_final_answer:
527
+ final_answer = output.output
528
+ self.logger.log(
529
+ Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"),
530
+ level=LogLevel.INFO,
531
+ )
532
+
533
+ if self.final_answer_checks:
534
+ self._validate_final_answer(final_answer)
535
+ returned_final_answer = True
536
+ action_step.is_final_answer = True
537
+
538
+ except AgentGenerationError as e:
539
+ # Agent generation errors are not caused by a Model error but by an implementation error, so we should raise them and exit.
540
+ raise e
541
+ except AgentError as e:
542
+ # Other AgentError types are caused by the Model, so we should log them and iterate.
543
+ action_step.error = e
544
+ finally:
545
+ self._finalize_step(action_step)
546
+ self.memory.steps.append(action_step)
547
+ yield action_step
548
+ self.step_number += 1
549
+
550
+ if not returned_final_answer and self.step_number == max_steps + 1:
551
+ final_answer = self._handle_max_steps_reached(task, images)
552
+ yield action_step
553
+ yield FinalAnswerStep(handle_agent_output_types(final_answer))
554
+
555
+ def _validate_final_answer(self, final_answer: Any):
556
+ for check_function in self.final_answer_checks:
557
+ try:
558
+ assert check_function(final_answer, self.memory)
559
+ except Exception as e:
560
+ raise AgentError(f"Check {check_function.__name__} failed with error: {e}", self.logger)
561
+
562
+ def _finalize_step(self, memory_step: ActionStep):
563
+ memory_step.timing.end_time = time.time()
564
+ for callback in self.step_callbacks:
565
+ # For compatibility with old callbacks that don't take the agent as an argument
566
+ callback(memory_step) if len(inspect.signature(callback).parameters) == 1 else callback(
567
+ memory_step, agent=self
568
+ )
569
+
570
+ def _handle_max_steps_reached(self, task: str, images: list["PIL.Image.Image"]) -> Any:
571
+ action_step_start_time = time.time()
572
+ final_answer = self.provide_final_answer(task, images)
573
+ final_memory_step = ActionStep(
574
+ step_number=self.step_number,
575
+ error=AgentMaxStepsError("Reached max steps.", self.logger),
576
+ timing=Timing(start_time=action_step_start_time, end_time=time.time()),
577
+ token_usage=final_answer.token_usage,
578
+ )
579
+ final_memory_step.action_output = final_answer.content
580
+ self._finalize_step(final_memory_step)
581
+ self.memory.steps.append(final_memory_step)
582
+ return final_answer.content
583
+
584
+ def _generate_planning_step(
585
+ self, task, is_first_step: bool, step: int
586
+ ) -> Generator[ChatMessageStreamDelta | PlanningStep]:
587
+ start_time = time.time()
588
+ if is_first_step:
589
+ input_messages = [
590
+ ChatMessage(
591
+ role=MessageRole.USER,
592
+ content=[
593
+ {
594
+ "type": "text",
595
+ "text": populate_template(
596
+ self.prompt_templates["planning"]["initial_plan"],
597
+ variables={"task": task, "tools": self.tools, "managed_agents": self.managed_agents},
598
+ ),
599
+ }
600
+ ],
601
+ )
602
+ ]
603
+ if self.stream_outputs and hasattr(self.model, "generate_stream"):
604
+ plan_message_content = ""
605
+ output_stream = self.model.generate_stream(input_messages, stop_sequences=["<end_plan>"]) # type: ignore
606
+ input_tokens, output_tokens = 0, 0
607
+ with Live("", console=self.logger.console, vertical_overflow="visible") as live:
608
+ for event in output_stream:
609
+ if event.content is not None:
610
+ plan_message_content += event.content
611
+ live.update(Markdown(plan_message_content))
612
+ if event.token_usage:
613
+ output_tokens += event.token_usage.output_tokens
614
+ input_tokens = event.token_usage.input_tokens
615
+ yield event
616
+ else:
617
+ plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
618
+ plan_message_content = plan_message.content
619
+ input_tokens, output_tokens = (
620
+ (
621
+ plan_message.token_usage.input_tokens,
622
+ plan_message.token_usage.output_tokens,
623
+ )
624
+ if plan_message.token_usage
625
+ else (None, None)
626
+ )
627
+ plan = textwrap.dedent(
628
+ f"""Here are the facts I know and the plan of action that I will follow to solve the task:\n```\n{plan_message_content}\n```"""
629
+ )
630
+ else:
631
+ # Summary mode removes the system prompt and previous planning messages output by the model.
632
+ # Removing previous planning messages avoids overly influencing the new plan.
633
+ memory_messages = self.write_memory_to_messages(summary_mode=True)
634
+ plan_update_pre = ChatMessage(
635
+ role=MessageRole.SYSTEM,
636
+ content=[
637
+ {
638
+ "type": "text",
639
+ "text": populate_template(
640
+ self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task}
641
+ ),
642
+ }
643
+ ],
644
+ )
645
+ plan_update_post = ChatMessage(
646
+ role=MessageRole.USER,
647
+ content=[
648
+ {
649
+ "type": "text",
650
+ "text": populate_template(
651
+ self.prompt_templates["planning"]["update_plan_post_messages"],
652
+ variables={
653
+ "task": task,
654
+ "tools": self.tools,
655
+ "managed_agents": self.managed_agents,
656
+ "remaining_steps": (self.max_steps - step),
657
+ },
658
+ ),
659
+ }
660
+ ],
661
+ )
662
+ input_messages = [plan_update_pre] + memory_messages + [plan_update_post]
663
+ if self.stream_outputs and hasattr(self.model, "generate_stream"):
664
+ plan_message_content = ""
665
+ input_tokens, output_tokens = 0, 0
666
+ with Live("", console=self.logger.console, vertical_overflow="visible") as live:
667
+ for event in self.model.generate_stream(
668
+ input_messages,
669
+ stop_sequences=["<end_plan>"],
670
+ ): # type: ignore
671
+ if event.content is not None:
672
+ plan_message_content += event.content
673
+ live.update(Markdown(plan_message_content))
674
+ if event.token_usage:
675
+ output_tokens += event.token_usage.output_tokens
676
+ input_tokens = event.token_usage.input_tokens
677
+ yield event
678
+ else:
679
+ plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
680
+ plan_message_content = plan_message.content
681
+ if plan_message.token_usage is not None:
682
+ input_tokens, output_tokens = (
683
+ plan_message.token_usage.input_tokens,
684
+ plan_message.token_usage.output_tokens,
685
+ )
686
+ plan = textwrap.dedent(
687
+ f"""I still need to solve the task I was given:\n```\n{self.task}\n```\n\nHere are the facts I know and my new/updated plan of action to solve the task:\n```\n{plan_message_content}\n```"""
688
+ )
689
+ log_headline = "Initial plan" if is_first_step else "Updated plan"
690
+ self.logger.log(Rule(f"[bold]{log_headline}", style="orange"), Text(plan), level=LogLevel.INFO)
691
+ yield PlanningStep(
692
+ model_input_messages=input_messages,
693
+ plan=plan,
694
+ model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content=plan_message_content),
695
+ token_usage=TokenUsage(input_tokens=input_tokens, output_tokens=output_tokens),
696
+ timing=Timing(start_time=start_time, end_time=time.time()),
697
+ )
698
+
699
+ @property
700
+ def logs(self):
701
+ logger.warning(
702
+ "The 'logs' attribute is deprecated and will soon be removed. Please use 'self.memory.steps' instead."
703
+ )
704
+ return [self.memory.system_prompt] + self.memory.steps
705
+
706
+ @abstractmethod
707
+ def initialize_system_prompt(self) -> str:
708
+ """To be implemented in child classes"""
709
+ ...
710
+
711
+ def interrupt(self):
712
+ """Interrupts the agent execution."""
713
+ self.interrupt_switch = True
714
+
715
+ def write_memory_to_messages(
716
+ self,
717
+ summary_mode: bool = False,
718
+ ) -> list[ChatMessage]:
719
+ """
720
+ Reads past LLM outputs, actions, and observations or errors from the memory into a series of messages
721
+ that can be used as input to the LLM. Adds a number of keywords (such as PLAN, error, etc.) to help
722
+ the LLM.
723
+ """
724
+ messages = self.memory.system_prompt.to_messages(summary_mode=summary_mode)
725
+ for memory_step in self.memory.steps:
726
+ messages.extend(memory_step.to_messages(summary_mode=summary_mode))
727
+ return messages
728
+
729
+ def _step_stream(
730
+ self, memory_step: ActionStep
731
+ ) -> Generator[ChatMessageStreamDelta | ToolCall | ToolOutput | ActionOutput]:
732
+ """
733
+ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
734
+ Yields ChatMessageStreamDelta during the run if streaming is enabled.
735
+ At the end, yields either None if the step is not final, or the final answer.
736
+ """
737
+ raise NotImplementedError("This method should be implemented in child classes")
738
+
739
+ def step(self, memory_step: ActionStep) -> Any:
740
+ """
741
+ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
742
+ Returns either None if the step is not final, or the final answer.
743
+ """
744
+ return list(self._step_stream(memory_step))[-1]
745
+
746
+ def extract_action(self, model_output: str, split_token: str) -> tuple[str, str]:
747
+ """
748
+ Parse action from the LLM output
749
+
750
+ Args:
751
+ model_output (`str`): Output of the LLM
752
+ split_token (`str`): Separator for the action. Should match the example in the system prompt.
753
+ """
754
+ try:
755
+ split = model_output.split(split_token)
756
+ rationale, action = (
757
+ split[-2],
758
+ split[-1],
759
+ ) # NOTE: indexing from the end handles the case where split_token appears more than once in the output
760
+ except Exception:
761
+ raise AgentParsingError(
762
+ f"No '{split_token}' token provided in your output.\nYour output:\n{model_output}\n. Be sure to include an action, prefaced with '{split_token}'!",
763
+ self.logger,
764
+ )
765
+ return rationale.strip(), action.strip()
766
+
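+
+ # Illustrative sketch of the split logic above (hypothetical model output;
+ # not part of the original file):
+ #
+ #     output = "Thought: search the web.\nAction:\nweb_search('leopard speed')"
+ #     rationale, action = agent.extract_action(output, split_token="Action:")
+ #     # rationale == "Thought: search the web."
+ #     # action == "web_search('leopard speed')"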
767
+ def provide_final_answer(self, task: str, images: list["PIL.Image.Image"] | None = None) -> ChatMessage:
768
+ """
769
+ Provide the final answer to the task, based on the logs of the agent's interactions.
770
+
771
+ Args:
772
+ task (`str`): Task to perform.
773
+ images (`list[PIL.Image.Image]`, *optional*): Image object(s).
774
+
775
+ Returns:
776
+ `ChatMessage`: The final answer to the task, as a chat message.
777
+ """
778
+ messages = [
779
+ ChatMessage(
780
+ role=MessageRole.SYSTEM,
781
+ content=[
782
+ {
783
+ "type": "text",
784
+ "text": self.prompt_templates["final_answer"]["pre_messages"],
785
+ }
786
+ ],
787
+ )
788
+ ]
789
+ if images:
790
+ messages[0].content += [{"type": "image", "image": image} for image in images]
791
+ messages += self.write_memory_to_messages()[1:]
792
+ messages.append(
793
+ ChatMessage(
794
+ role=MessageRole.USER,
795
+ content=[
796
+ {
797
+ "type": "text",
798
+ "text": populate_template(
799
+ self.prompt_templates["final_answer"]["post_messages"], variables={"task": task}
800
+ ),
801
+ }
802
+ ],
803
+ )
804
+ )
805
+ try:
806
+ chat_message: ChatMessage = self.model.generate(messages)
807
+ return chat_message
808
+ except Exception as e:
809
+ return ChatMessage(role=MessageRole.ASSISTANT, content=f"Error in generating final LLM output:\n{e}")
810
+
811
+ def visualize(self):
812
+ """Creates a rich tree visualization of the agent's structure."""
813
+ self.logger.visualize_agent_tree(self)
814
+
815
+ def replay(self, detailed: bool = False):
816
+ """Prints a pretty replay of the agent's steps.
817
+
818
+ Args:
819
+ detailed (bool, optional): If True, also displays the memory at each step. Defaults to False.
820
+ Careful: this will greatly increase log length. Use only for debugging.
821
+ """
822
+ self.memory.replay(self.logger, detailed=detailed)
823
+
824
+ def __call__(self, task: str, **kwargs):
825
+ """Adds additional prompting for the managed agent, runs it, and wraps the output.
826
+ This method is called only by a managed agent.
827
+ """
828
+ full_task = populate_template(
829
+ self.prompt_templates["managed_agent"]["task"],
830
+ variables=dict(name=self.name, task=task),
831
+ )
832
+ result = self.run(full_task, **kwargs)
833
+ if isinstance(result, RunResult):
834
+ report = result.output
835
+ else:
836
+ report = result
837
+ answer = populate_template(
838
+ self.prompt_templates["managed_agent"]["report"], variables=dict(name=self.name, final_answer=report)
839
+ )
840
+ if self.provide_run_summary:
841
+ answer += "\n\nFor more detail, find below a summary of this agent's work:\n<summary_of_work>\n"
842
+ for message in self.write_memory_to_messages(summary_mode=True):
843
+ content = message["content"]
844
+ answer += "\n" + truncate_content(str(content)) + "\n---"
845
+ answer += "\n</summary_of_work>"
846
+ return answer
847
+
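+
+ # Illustrative sketch of how a manager agent reaches this __call__ (names are
+ # hypothetical; assumes `model` is any configured Model):
+ #
+ #     web_agent = ToolCallingAgent(tools=[DuckDuckGoSearchTool()], model=model,
+ #                                  name="web_agent", description="Runs web searches.")
+ #     manager = CodeAgent(tools=[], model=model, managed_agents=[web_agent])
+ #     # During manager.run(...), calling web_agent(task) wraps the task with the
+ #     # managed_agent prompt templates and returns the formatted report.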
848
+ def save(self, output_dir: str | Path, relative_path: str | None = None):
849
+ """
850
+ Saves the relevant code files for your agent. This will copy the code of your agent in `output_dir` as well as autogenerate:
851
+
852
+ - a `tools` folder containing the logic for each of the tools under `tools/{tool_name}.py`.
853
+ - a `managed_agents` folder containing the logic for each of the managed agents.
854
+ - an `agent.json` file containing a dictionary representing your agent.
855
+ - a `prompts.yaml` file containing the prompt templates used by your agent.
856
+ - an `app.py` file providing a UI for your agent when it is exported to a Space with `agent.push_to_hub()`
857
+ - a `requirements.txt` containing the names of the modules used by your agent (as detected when inspecting its
858
+ code)
859
+
860
+ Args:
861
+ output_dir (`str` or `Path`): The folder in which you want to save your agent.
+ relative_path (`str`, *optional*): Internal import-path prefix used when recursively saving managed agents.
862
+ """
863
+ make_init_file(output_dir)
864
+
865
+ # Recursively save managed agents
866
+ if self.managed_agents:
867
+ make_init_file(os.path.join(output_dir, "managed_agents"))
868
+ for agent_name, agent in self.managed_agents.items():
869
+ agent_suffix = f"managed_agents.{agent_name}"
870
+ if relative_path:
871
+ agent_suffix = relative_path + "." + agent_suffix
872
+ agent.save(os.path.join(output_dir, "managed_agents", agent_name), relative_path=agent_suffix)
873
+
874
+ class_name = self.__class__.__name__
875
+
876
+ # Save tools to different .py files
877
+ for tool in self.tools.values():
878
+ make_init_file(os.path.join(output_dir, "tools"))
879
+ tool.save(os.path.join(output_dir, "tools"), tool_file_name=tool.name, make_gradio_app=False)
880
+
881
+ # Save prompts to yaml
882
+ yaml_prompts = yaml.safe_dump(
883
+ self.prompt_templates,
884
+ default_style="|", # This forces block literals for all strings
885
+ default_flow_style=False,
886
+ width=float("inf"),
887
+ sort_keys=False,
888
+ allow_unicode=True,
889
+ indent=2,
890
+ )
891
+
892
+ with open(os.path.join(output_dir, "prompts.yaml"), "w", encoding="utf-8") as f:
893
+ f.write(yaml_prompts)
894
+
895
+ # Save agent dictionary to json
896
+ agent_dict = self.to_dict()
897
+ agent_dict["tools"] = [tool.name for tool in self.tools.values()]
898
+ agent_dict["managed_agents"] = {agent.name: agent.__class__.__name__ for agent in self.managed_agents.values()}
899
+ with open(os.path.join(output_dir, "agent.json"), "w", encoding="utf-8") as f:
900
+ json.dump(agent_dict, f, indent=4)
901
+
902
+ # Save requirements
903
+ with open(os.path.join(output_dir, "requirements.txt"), "w", encoding="utf-8") as f:
904
+ f.writelines(f"{r}\n" for r in agent_dict["requirements"])
905
+
906
+ # Make agent.py file with Gradio UI
907
+ agent_name = f"agent_{self.name}" if getattr(self, "name", None) else "agent"
908
+ managed_agent_relative_path = relative_path + "." if relative_path is not None else ""
909
+ app_template = AGENT_GRADIO_APP_TEMPLATE
910
+ template_env = jinja2.Environment(loader=jinja2.BaseLoader(), undefined=jinja2.StrictUndefined)
911
+ template_env.filters["repr"] = repr
912
+ template_env.filters["camelcase"] = lambda value: "".join(word.capitalize() for word in value.split("_"))
913
+ template = template_env.from_string(app_template)
914
+
915
+ # Render the app.py file from Jinja2 template
916
+ app_text = template.render(
917
+ {
918
+ "agent_name": agent_name,
919
+ "class_name": class_name,
920
+ "agent_dict": agent_dict,
921
+ "tools": self.tools,
922
+ "managed_agents": self.managed_agents,
923
+ "managed_agent_relative_path": managed_agent_relative_path,
924
+ }
925
+ )
926
+
927
+ with open(os.path.join(output_dir, "app.py"), "w", encoding="utf-8") as f:
928
+ f.write(app_text + "\n") # Append newline at the end
929
+
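+
+ # Illustrative usage sketch (hypothetical directory name):
+ #
+ #     agent.save("my_agent")
+ #     # -> my_agent/agent.json, my_agent/prompts.yaml, my_agent/app.py,
+ #     #    my_agent/requirements.txt, plus tools/ and managed_agents/ subfolders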
930
+ def to_dict(self) -> dict[str, Any]:
931
+ """Convert the agent to a dictionary representation.
932
+
933
+ Returns:
934
+ `dict`: Dictionary representation of the agent.
935
+ """
936
+ # TODO: handle serializing step_callbacks and final_answer_checks
937
+ for attr in ["final_answer_checks", "step_callbacks"]:
938
+ if getattr(self, attr, None):
939
+ self.logger.log(f"This agent has {attr}: they will be ignored by this method.", LogLevel.INFO)
940
+
941
+ tool_dicts = [tool.to_dict() for tool in self.tools.values()]
942
+ tool_requirements = {req for tool in self.tools.values() for req in tool.to_dict()["requirements"]}
943
+ managed_agents_requirements = {
944
+ req for managed_agent in self.managed_agents.values() for req in managed_agent.to_dict()["requirements"]
945
+ }
946
+ requirements = tool_requirements | managed_agents_requirements
947
+ if hasattr(self, "authorized_imports"):
948
+ requirements.update(
949
+ {package.split(".")[0] for package in self.authorized_imports if package not in BASE_BUILTIN_MODULES}
950
+ )
951
+
952
+ agent_dict = {
953
+ "class": self.__class__.__name__,
954
+ "tools": tool_dicts,
955
+ "model": {
956
+ "class": self.model.__class__.__name__,
957
+ "data": self.model.to_dict(),
958
+ },
959
+ "managed_agents": [managed_agent.to_dict() for managed_agent in self.managed_agents.values()],
960
+ "prompt_templates": self.prompt_templates,
961
+ "max_steps": self.max_steps,
962
+ "verbosity_level": int(self.logger.level),
963
+ "grammar": self.grammar,
964
+ "planning_interval": self.planning_interval,
965
+ "name": self.name,
966
+ "description": self.description,
967
+ "requirements": sorted(requirements),
968
+ }
969
+ return agent_dict
970
+
971
+ @classmethod
972
+ def from_dict(cls, agent_dict: dict[str, Any], **kwargs) -> "MultiStepAgent":
973
+ """Create agent from a dictionary representation.
974
+
975
+ Args:
976
+ agent_dict (`dict[str, Any]`): Dictionary representation of the agent.
977
+ **kwargs: Additional keyword arguments that will override agent_dict values.
978
+
979
+ Returns:
980
+ `MultiStepAgent`: Instance of the agent class.
981
+ """
982
+ # Load model
983
+ model_info = agent_dict["model"]
984
+ model_class = getattr(importlib.import_module("smolagents.models"), model_info["class"])
985
+ model = model_class.from_dict(model_info["data"])
986
+ # Load tools
987
+ tools = []
988
+ for tool_info in agent_dict["tools"]:
989
+ tools.append(Tool.from_code(tool_info["code"]))
990
+ # Load managed agents
991
+ managed_agents = []
992
+ for managed_agent_name, managed_agent_class_name in agent_dict["managed_agents"].items():
993
+ managed_agent_class = getattr(importlib.import_module("smolagents.agents"), managed_agent_class_name)
994
+ managed_agents.append(managed_agent_class.from_dict(agent_dict["managed_agents"][managed_agent_name]))
995
+ # Extract base agent parameters
996
+ agent_args = {
997
+ "model": model,
998
+ "tools": tools,
999
+ "prompt_templates": agent_dict.get("prompt_templates"),
1000
+ "max_steps": agent_dict.get("max_steps"),
1001
+ "verbosity_level": agent_dict.get("verbosity_level"),
1002
+ "grammar": agent_dict.get("grammar"),
1003
+ "planning_interval": agent_dict.get("planning_interval"),
1004
+ "name": agent_dict.get("name"),
1005
+ "description": agent_dict.get("description"),
1006
+ }
1007
+ # Filter out None values to use defaults from __init__
1008
+ agent_args = {k: v for k, v in agent_args.items() if v is not None}
1009
+ # Update with any additional kwargs
1010
+ agent_args.update(kwargs)
1011
+ # Create agent instance
1012
+ return cls(**agent_args)
1013
+
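+
+ # Illustrative round-trip sketch, pairing to_dict with from_dict (assumes an
+ # existing `agent` instance; the kwarg override is the documented behavior):
+ #
+ #     spec = agent.to_dict()
+ #     clone = type(agent).from_dict(spec, max_steps=5)  # kwargs override stored values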
1014
+ @classmethod
1015
+ def from_hub(
1016
+ cls,
1017
+ repo_id: str,
1018
+ token: str | None = None,
1019
+ trust_remote_code: bool = False,
1020
+ **kwargs,
1021
+ ):
1022
+ """
1023
+ Loads an agent defined on the Hub.
1024
+
1025
+ <Tip warning={true}>
1026
+
1027
+ Loading an agent from the Hub means that you'll download the agent and execute it locally.
1028
+ ALWAYS inspect the agent you're downloading before loading it within your runtime, as you would do when
1029
+ installing a package using pip/npm/apt.
1030
+
1031
+ </Tip>
1032
+
1033
+ Args:
1034
+ repo_id (`str`):
1035
+ The name of the repo on the Hub where your agent is defined.
1036
+ token (`str`, *optional*):
1037
+ The token to identify you on hf.co. If unset, will use the token generated when running
1038
+ `huggingface-cli login` (stored in `~/.huggingface`).
1039
+ trust_remote_code(`bool`, *optional*, defaults to False):
1040
+ This flag marks that you understand the risk of running remote code and that you trust this agent.
1041
+ If this is not set to True, loading the agent from the Hub will fail.
1042
+ kwargs (additional keyword arguments, *optional*):
1043
+ Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
1044
+ `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your agent, and the
1045
+ others will be passed along to its init.
1046
+ """
1047
+ if not trust_remote_code:
1048
+ raise ValueError(
1049
+ "Loading an agent from Hub requires to acknowledge you trust its code: to do so, pass `trust_remote_code=True`."
1050
+ )
1051
+
1052
+ # Get the agent's Hub folder.
1053
+ download_kwargs = {"token": token, "repo_type": "space"} | {
1054
+ key: kwargs.pop(key)
1055
+ for key in [
1056
+ "cache_dir",
1057
+ "force_download",
1058
+ "proxies",
1059
+ "revision",
1060
+ "local_files_only",
1061
+ ]
1062
+ if key in kwargs
1063
+ }
1064
+
1065
+ download_folder = Path(snapshot_download(repo_id=repo_id, **download_kwargs))
1066
+ return cls.from_folder(download_folder, **kwargs)
1067
+
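+
+ # Illustrative usage sketch (hypothetical repo id):
+ #
+ #     agent = CodeAgent.from_hub("username/my-agent", trust_remote_code=True)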
1068
+ @classmethod
1069
+ def from_folder(cls, folder: str | Path, **kwargs):
1070
+ """Loads an agent from a local folder.
1071
+
1072
+ Args:
1073
+ folder (`str` or `Path`): The folder where the agent is saved.
1074
+ **kwargs: Additional keyword arguments that will be passed to the agent's init.
1075
+ """
1076
+ # Load agent.json
1077
+ folder = Path(folder)
1078
+ agent_dict = json.loads((folder / "agent.json").read_text())
1079
+
1080
+ # Load managed agents from their respective folders, recursively
1081
+ managed_agents = []
1082
+ for managed_agent_name, managed_agent_class_name in agent_dict["managed_agents"].items():
1083
+ agent_cls = getattr(importlib.import_module("smolagents.agents"), managed_agent_class_name)
1084
+ managed_agents.append(agent_cls.from_folder(folder / "managed_agents" / managed_agent_name))
1085
+ agent_dict["managed_agents"] = {}
1086
+
1087
+ # Load tools
1088
+ tools = []
1089
+ for tool_name in agent_dict["tools"]:
1090
+ tool_code = (folder / "tools" / f"{tool_name}.py").read_text()
1091
+ tools.append({"name": tool_name, "code": tool_code})
1092
+ agent_dict["tools"] = tools
1093
+
1094
+ # Add managed agents to kwargs to override the empty list in from_dict
1095
+ if managed_agents:
1096
+ kwargs["managed_agents"] = managed_agents
1097
+
1098
+ return cls.from_dict(agent_dict, **kwargs)
1099
+
1100
+ def push_to_hub(
1101
+ self,
1102
+ repo_id: str,
1103
+ commit_message: str = "Upload agent",
1104
+ private: bool | None = None,
1105
+ token: bool | str | None = None,
1106
+ create_pr: bool = False,
1107
+ ) -> str:
1108
+ """
1109
+ Upload the agent to the Hub.
1110
+
1111
+ Parameters:
1112
+ repo_id (`str`):
1113
+ The name of the repository you want to push to. It should contain your organization name when
1114
+ pushing to a given organization.
1115
+ commit_message (`str`, *optional*, defaults to `"Upload agent"`):
1116
+ Message to commit while pushing.
1117
+ private (`bool`, *optional*, defaults to `None`):
1118
+ Whether to make the repo private. If `None`, the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
1119
+ token (`bool` or `str`, *optional*):
1120
+ The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
1121
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
1122
+ create_pr (`bool`, *optional*, defaults to `False`):
1123
+ Whether to create a PR with the uploaded files or directly commit.
1124
+ """
1125
+ repo_url = create_repo(
1126
+ repo_id=repo_id,
1127
+ token=token,
1128
+ private=private,
1129
+ exist_ok=True,
1130
+ repo_type="space",
1131
+ space_sdk="gradio",
1132
+ )
1133
+ repo_id = repo_url.repo_id
1134
+ metadata_update(
1135
+ repo_id,
1136
+ {"tags": ["smolagents", "agent"]},
1137
+ repo_type="space",
1138
+ token=token,
1139
+ overwrite=True,
1140
+ )
1141
+
1142
+ with tempfile.TemporaryDirectory() as work_dir:
1143
+ self.save(work_dir)
1144
+ logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
1145
+ return upload_folder(
1146
+ repo_id=repo_id,
1147
+ commit_message=commit_message,
1148
+ folder_path=work_dir,
1149
+ token=token,
1150
+ create_pr=create_pr,
1151
+ repo_type="space",
1152
+ )
1153
+
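+
+ # Illustrative usage sketch (hypothetical repo id): saves the agent to a
+ # temporary folder and uploads it to a Gradio Space:
+ #
+ #     agent.push_to_hub("username/my-agent", commit_message="First version")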
1154
+
1155
+ class ToolCallingAgent(MultiStepAgent):
1156
+ """
1157
+ This agent uses JSON-like tool calls, passing the available tools to the model via `tools_to_call_from` to leverage the LLM engine's tool-calling capabilities.
1158
+
1159
+ Args:
1160
+ tools (`list[Tool]`): [`Tool`]s that the agent can use.
1161
+ model (`Model`): Model that will generate the agent's actions.
1162
+ prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
1163
+ planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
1164
+ stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution.
1165
+ max_tool_threads (`int`, *optional*): Maximum number of threads for parallel tool calls.
1166
+ Higher values increase concurrency, but also resource usage.
1167
+ Defaults to `ThreadPoolExecutor`'s default.
1168
+ **kwargs: Additional keyword arguments.
1169
+ """
1170
+
1171
+ def __init__(
1172
+ self,
1173
+ tools: list[Tool],
1174
+ model: Model,
1175
+ prompt_templates: PromptTemplates | None = None,
1176
+ planning_interval: int | None = None,
1177
+ stream_outputs: bool = False,
1178
+ max_tool_threads: int | None = None,
1179
+ **kwargs,
1180
+ ):
1181
+ prompt_templates = prompt_templates or yaml.safe_load(
1182
+ importlib.resources.files("smolagents.prompts").joinpath("toolcalling_agent.yaml").read_text()
1183
+ )
1184
+ super().__init__(
1185
+ tools=tools,
1186
+ model=model,
1187
+ prompt_templates=prompt_templates,
1188
+ planning_interval=planning_interval,
1189
+ **kwargs,
1190
+ )
1191
+ # Streaming setup
1192
+ self.stream_outputs = stream_outputs
1193
+ if self.stream_outputs and not hasattr(self.model, "generate_stream"):
1194
+ raise ValueError(
1195
+ "`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
1196
+ )
1197
+ # Tool calling setup
1198
+ self.max_tool_threads = max_tool_threads
1199
+
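+
+ # Illustrative instantiation sketch (assumes a model supporting tool calling):
+ #
+ #     agent = ToolCallingAgent(tools=[DuckDuckGoSearchTool()],
+ #                              model=InferenceClientModel(),
+ #                              max_tool_threads=4)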
1200
+ @property
1201
+ def tools_and_managed_agents(self):
1202
+ """Returns a combined list of tools and managed agents."""
1203
+ return list(self.tools.values()) + list(self.managed_agents.values())
1204
+
1205
+ def initialize_system_prompt(self) -> str:
1206
+ system_prompt = populate_template(
1207
+ self.prompt_templates["system_prompt"],
1208
+ variables={
1209
+ "tools": self.tools,
1210
+ "managed_agents": self.managed_agents,
1211
+ "custom_instructions": self.instructions,
1212
+ },
1213
+ )
1214
+ return system_prompt
1215
+
1216
+ def _step_stream(
1217
+ self, memory_step: ActionStep
1218
+ ) -> Generator[ChatMessageStreamDelta | ToolCall | ToolOutput | ActionOutput]:
1219
+ """
1220
+ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
1221
+ Yields ChatMessageStreamDelta during the run if streaming is enabled.
1222
+ At the end, yields either None if the step is not final, or the final answer.
1223
+ """
1224
+ memory_messages = self.write_memory_to_messages()
1225
+
1226
+ input_messages = memory_messages.copy()
1227
+
1228
+ # Add new step in logs
1229
+ memory_step.model_input_messages = input_messages
1230
+
1231
+ try:
1232
+ if self.stream_outputs and hasattr(self.model, "generate_stream"):
1233
+ output_stream = self.model.generate_stream(
1234
+ input_messages,
1235
+ stop_sequences=["Observation:", "Calling tools:"],
1236
+ tools_to_call_from=self.tools_and_managed_agents,
1237
+ )
1238
+
1239
+ chat_message_stream_deltas: list[ChatMessageStreamDelta] = []
1240
+ with Live("", console=self.logger.console, vertical_overflow="visible") as live:
1241
+ for event in output_stream:
1242
+ chat_message_stream_deltas.append(event)
1243
+ live.update(
1244
+ Markdown(agglomerate_stream_deltas(chat_message_stream_deltas).render_as_markdown())
1245
+ )
1246
+ yield event
1247
+ chat_message = agglomerate_stream_deltas(chat_message_stream_deltas)
1248
+ else:
1249
+ chat_message: ChatMessage = self.model.generate(
1250
+ input_messages,
1251
+ stop_sequences=["Observation:", "Calling tools:"],
1252
+ tools_to_call_from=self.tools_and_managed_agents,
1253
+ )
1254
+ if chat_message.content is None and chat_message.raw is not None:
1255
+ log_content = str(chat_message.raw)
1256
+ else:
1257
+ log_content = str(chat_message.content) or ""
1258
+
1259
+ self.logger.log_markdown(
1260
+ content=log_content,
1261
+ title="Output message of the LLM:",
1262
+ level=LogLevel.DEBUG,
1263
+ )
1264
+
1265
+ # Record model output
1266
+ memory_step.model_output_message = chat_message
1267
+ memory_step.model_output = chat_message.content
1268
+ memory_step.token_usage = chat_message.token_usage
1269
+ except Exception as e:
1270
+ raise AgentGenerationError(f"Error while generating output:\n{e}", self.logger) from e
1271
+
1272
+ if chat_message.tool_calls is None or len(chat_message.tool_calls) == 0:
1273
+ try:
1274
+ chat_message = self.model.parse_tool_calls(chat_message)
1275
+ except Exception as e:
1276
+ raise AgentParsingError(f"Error while parsing tool call from model output: {e}", self.logger)
1277
+ else:
1278
+ for tool_call in chat_message.tool_calls:
1279
+ tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
1280
+ final_answer, got_final_answer = None, False
1281
+ for output in self.process_tool_calls(chat_message, memory_step):
1282
+ yield output
1283
+ if isinstance(output, ToolOutput):
1284
+ if output.is_final_answer:
1285
+ if got_final_answer:
1286
+ raise AgentToolExecutionError(
1287
+ "You returned multiple final answers. Please return only one single final answer!",
1288
+ self.logger,
1289
+ )
1290
+ final_answer = output.output
1291
+ got_final_answer = True
1292
+
1293
+ # Manage state variables
1294
+ if isinstance(final_answer, str) and final_answer in self.state.keys():
1295
+ final_answer = self.state[final_answer]
1296
+ yield ActionOutput(
1297
+ output=final_answer,
1298
+ is_final_answer=got_final_answer,
1299
+ )
1300
+
1301
+ def process_tool_calls(
1302
+ self, chat_message: ChatMessage, memory_step: ActionStep
1303
+ ) -> Generator[ToolCall | ToolOutput]:
1304
+ """Process tool calls from the model output and update agent memory.
1305
+
1306
+ Args:
1307
+ chat_message (`ChatMessage`): Chat message containing tool calls from the model.
1308
+ memory_step (`ActionStep`): Memory ActionStep to update with results.
1309
+
1310
+ Yields:
1311
+ `ToolCall | ToolOutput`: The tool call or tool output.
1312
+ """
1313
+ parallel_calls: dict[str, ToolCall] = {}
1314
+ assert chat_message.tool_calls is not None
1315
+ for chat_tool_call in chat_message.tool_calls:
1316
+ tool_call = ToolCall(
1317
+ name=chat_tool_call.function.name, arguments=chat_tool_call.function.arguments, id=chat_tool_call.id
1318
+ )
1319
+ yield tool_call
1320
+ parallel_calls[tool_call.id] = tool_call
1321
+
1322
+ # Helper function to process a single tool call
1323
+ def process_single_tool_call(tool_call: ToolCall) -> ToolOutput:
1324
+ tool_name = tool_call.name
1325
+ tool_arguments = tool_call.arguments or {}
1326
+ self.logger.log(
1327
+ Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")),
1328
+ level=LogLevel.INFO,
1329
+ )
1330
+ tool_call_result = self.execute_tool_call(tool_name, tool_arguments)
1331
+ tool_call_result_type = type(tool_call_result)
1332
+ if tool_call_result_type in [AgentImage, AgentAudio]:
1333
+ if tool_call_result_type == AgentImage:
1334
+ observation_name = "image.png"
1335
+ elif tool_call_result_type == AgentAudio:
1336
+ observation_name = "audio.mp3"
1337
+ # TODO: tool_call_result naming could allow for different names of same type
1338
+ self.state[observation_name] = tool_call_result
1339
+ observation = f"Stored '{observation_name}' in memory."
1340
+ else:
1341
+ observation = str(tool_call_result).strip()
1342
+ self.logger.log(
1343
+ f"Observations: {observation.replace('[', '|')}", # escape potential rich-tag-like components
1344
+ level=LogLevel.INFO,
1345
+ )
1346
+ is_final_answer = tool_name == "final_answer"
1347
+
1348
+ return ToolOutput(
1349
+ id=tool_call.id,
1350
+ output=tool_call_result,
1351
+ is_final_answer=is_final_answer,
1352
+ observation=observation,
1353
+ tool_call=tool_call,
1354
+ )
1355
+
1356
+ # Process tool calls in parallel
1357
+ outputs = {}
1358
+ if len(parallel_calls) == 1:
1359
+ # If there's only one call, process it directly
1360
+ tool_call = list(parallel_calls.values())[0]
1361
+ tool_output = process_single_tool_call(tool_call)
1362
+ outputs[tool_output.id] = tool_output
1363
+ yield tool_output
1364
+ else:
1365
+ # If multiple tool calls, process them in parallel
1366
+ with ThreadPoolExecutor(self.max_tool_threads) as executor:
1367
+ futures = [
1368
+ executor.submit(process_single_tool_call, tool_call) for tool_call in parallel_calls.values()
1369
+ ]
1370
+ for future in as_completed(futures):
1371
+ tool_output = future.result()
1372
+ outputs[tool_output.id] = tool_output
1373
+ yield tool_output
1374
+
1375
+ memory_step.tool_calls = [parallel_calls[k] for k in sorted(parallel_calls.keys())]
1376
+ memory_step.model_output = memory_step.model_output or ""
1377
+ memory_step.observations = memory_step.observations or ""
1378
+ for tool_output in [outputs[k] for k in sorted(outputs.keys())]:
1379
+ message = f"Tool call {tool_output.id}: calling '{tool_output.tool_call.name}' with arguments: {tool_output.tool_call.arguments}\n"
1380
+ memory_step.model_output += message
1381
+ memory_step.observations += tool_output.observation + "\n"
1382
+ memory_step.model_output = memory_step.model_output.rstrip("\n")
1383
+ memory_step.observations = (
1384
+ memory_step.observations.rstrip("\n") if memory_step.observations else memory_step.observations
1385
+ )
1386
+
1387
+ def _substitute_state_variables(self, arguments: dict[str, str] | str) -> dict[str, Any] | str:
1388
+ """Replace string values in arguments with their corresponding state values if they exist."""
1389
+ if isinstance(arguments, dict):
1390
+ return {
1391
+ key: self.state.get(value, value) if isinstance(value, str) else value
1392
+ for key, value in arguments.items()
1393
+ }
1394
+ return arguments
1395
+
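+
+ # Illustrative sketch of state substitution (hypothetical values): string
+ # arguments that match a state key are swapped for the stored object, e.g. an
+ # image stored under "image.png" by an earlier tool call:
+ #
+ #     agent.state["image.png"] = pil_image
+ #     agent._substitute_state_variables({"image": "image.png"})
+ #     # -> {"image": pil_image}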
1396
+ def execute_tool_call(self, tool_name: str, arguments: dict[str, str] | str) -> Any:
1397
+ """
1398
+ Execute a tool or managed agent with the provided arguments.
1399
+
1400
+ The arguments are replaced with the actual values from the state if they refer to state variables.
1401
+
1402
+ Args:
1403
+ tool_name (`str`): Name of the tool or managed agent to execute.
1404
+ arguments (dict[str, str] | str): Arguments passed to the tool call.
1405
+ """
1406
+ # Check if the tool exists
1407
+ available_tools = {**self.tools, **self.managed_agents}
1408
+ if tool_name not in available_tools:
1409
+ raise AgentToolExecutionError(
1410
+ f"Unknown tool {tool_name}, should be one of: {', '.join(available_tools)}.", self.logger
1411
+ )
1412
+
1413
+ # Get the tool and substitute state variables in arguments
1414
+ tool = available_tools[tool_name]
1415
+ arguments = self._substitute_state_variables(arguments)
1416
+ is_managed_agent = tool_name in self.managed_agents
1417
+
1418
+ error_msg = validate_tool_arguments(tool, arguments)
1419
+ if error_msg:
1420
+ raise AgentToolCallError(error_msg, self.logger)
1421
+
1422
+ try:
1423
+ # Call tool with appropriate arguments
1424
+ if isinstance(arguments, dict):
1425
+ return tool(**arguments) if is_managed_agent else tool(**arguments, sanitize_inputs_outputs=True)
1426
+ else:
1427
+ return tool(arguments) if is_managed_agent else tool(arguments, sanitize_inputs_outputs=True)
1428
+
1429
+ except Exception as e:
1430
+ # Handle execution errors
1431
+ if is_managed_agent:
1432
+ error_msg = (
1433
+ f"Error executing request to team member '{tool_name}' with arguments {str(arguments)}: {e}\n"
1434
+ "Please try again or request to another team member"
1435
+ )
1436
+ else:
1437
+ error_msg = (
1438
+ f"Error executing tool '{tool_name}' with arguments {str(arguments)}: {type(e).__name__}: {e}\n"
1439
+ "Please try again or use another tool"
1440
+ )
1441
+ raise AgentToolExecutionError(error_msg, self.logger) from e
1442
+
1443
+
1444
+ class CodeAgent(MultiStepAgent):
1445
+ """
1446
+ In this agent, the tool calls will be formulated by the LLM in code format, then parsed and executed.
1447
+
1448
+ Args:
1449
+ tools (`list[Tool]`): [`Tool`]s that the agent can use.
1450
+ model (`Model`): Model that will generate the agent's actions.
1451
+ prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
1452
+ additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
1453
+ planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
1454
+ executor_type (`str`, default `"local"`): Which executor type to use between `"local"`, `"e2b"`, or `"docker"`.
1455
+ executor_kwargs (`dict`, *optional*): Additional arguments to pass to initialize the executor.
1456
+ max_print_outputs_length (`int`, *optional*): Maximum length of the print outputs.
1457
+ stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution.
1458
+ use_structured_outputs_internally (`bool`, default `False`): Whether to use structured generation at each action step: improves performance for many models.
1459
+
1460
+ <Added version="1.17.0"/>
1461
+ grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
1462
+ <Deprecated version="1.17.0">
1463
+ Parameter `grammar` is deprecated and will be removed in version 1.20.
1464
+ </Deprecated>
1465
+ **kwargs: Additional keyword arguments.
1466
+ """
1467
+
1468
+ def __init__(
1469
+ self,
1470
+ tools: list[Tool],
1471
+ model: Model,
1472
+ prompt_templates: PromptTemplates | None = None,
1473
+ additional_authorized_imports: list[str] | None = None,
1474
+ planning_interval: int | None = None,
1475
+ executor_type: str | None = "local",
1476
+ executor_kwargs: dict[str, Any] | None = None,
1477
+ max_print_outputs_length: int | None = None,
1478
+ stream_outputs: bool = False,
1479
+ use_structured_outputs_internally: bool = False,
1480
+ grammar: dict[str, str] | None = None,
1481
+ **kwargs,
1482
+ ):
1483
+ self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
1484
+ self.authorized_imports = sorted(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
1485
+ self.max_print_outputs_length = max_print_outputs_length
1486
+ self._use_structured_outputs_internally = use_structured_outputs_internally
1487
+ if use_structured_outputs_internally:
1488
+ prompt_templates = prompt_templates or yaml.safe_load(
1489
+ importlib.resources.files("smolagents.prompts").joinpath("structured_code_agent.yaml").read_text()
1490
+ )
1491
+ else:
1492
+ prompt_templates = prompt_templates or yaml.safe_load(
1493
+ importlib.resources.files("smolagents.prompts").joinpath("code_agent.yaml").read_text()
1494
+ )
1495
+ if grammar and use_structured_outputs_internally:
1496
+ raise ValueError("You cannot use 'grammar' and 'use_structured_outputs_internally' at the same time.")
1497
+ super().__init__(
1498
+ tools=tools,
1499
+ model=model,
1500
+ prompt_templates=prompt_templates,
1501
+ grammar=grammar,
1502
+ planning_interval=planning_interval,
1503
+ **kwargs,
1504
+ )
1505
+ self.stream_outputs = stream_outputs
1506
+ if self.stream_outputs and not hasattr(self.model, "generate_stream"):
1507
+ raise ValueError(
1508
+ "`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
1509
+ )
1510
+ if "*" in self.additional_authorized_imports:
1511
+ self.logger.log(
1512
+ "Caution: you set an authorization for all imports, meaning your agent can decide to import any package it deems necessary. This might raise issues if the package is not installed in your environment.",
1513
+ level=LogLevel.INFO,
1514
+ )
1515
+ self.executor_type = executor_type or "local"
1516
+ self.executor_kwargs = executor_kwargs or {}
1517
+ self.python_executor = self.create_python_executor()
1518
+
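+
+ # Illustrative instantiation sketch (local executor with extra authorized imports):
+ #
+ #     agent = CodeAgent(tools=[], model=InferenceClientModel(),
+ #                       additional_authorized_imports=["numpy", "pandas"],
+ #                       executor_type="local")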
1519
+ def __enter__(self):
1520
+ return self
1521
+
1522
+ def __exit__(self, exc_type, exc_value, traceback):
1523
+ self.cleanup()
1524
+
1525
+ def cleanup(self):
1526
+ """Clean up resources used by the agent, such as the remote Python executor."""
1527
+ if hasattr(self.python_executor, "cleanup"):
1528
+ self.python_executor.cleanup()
1529
+
1530
+ def create_python_executor(self) -> PythonExecutor:
1531
+ match self.executor_type:
1532
+ case "e2b" | "docker":
1533
+ if self.managed_agents:
1534
+ raise Exception("Managed agents are not yet supported with remote code execution.")
1535
+ if self.executor_type == "e2b":
1536
+ return E2BExecutor(self.additional_authorized_imports, self.logger, **self.executor_kwargs)
1537
+ else:
1538
+ return DockerExecutor(self.additional_authorized_imports, self.logger, **self.executor_kwargs)
1539
+ case "local":
1540
+ return LocalPythonExecutor(
1541
+ self.additional_authorized_imports,
1542
+ **{"max_print_outputs_length": self.max_print_outputs_length} | self.executor_kwargs,
1543
+ )
1544
+ case _:
1545
+ raise ValueError(f"Unsupported executor type: {self.executor_type}")
1546
+
1547
+ def initialize_system_prompt(self) -> str:
1548
+ system_prompt = populate_template(
1549
+ self.prompt_templates["system_prompt"],
1550
+ variables={
1551
+ "tools": self.tools,
1552
+ "managed_agents": self.managed_agents,
1553
+ "authorized_imports": (
1554
+ "You can import from any package you want."
1555
+ if "*" in self.authorized_imports
1556
+ else str(self.authorized_imports)
1557
+ ),
1558
+ "custom_instructions": self.instructions,
1559
+ },
1560
+ )
1561
+ return system_prompt
1562
+
1563
+ def _step_stream(
1564
+ self, memory_step: ActionStep
1565
+ ) -> Generator[ChatMessageStreamDelta | ToolCall | ToolOutput | ActionOutput]:
1566
+ """
1567
+ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
1568
+ Yields ChatMessageStreamDelta during the run if streaming is enabled.
1569
+ At the end, yields either None if the step is not final, or the final answer.
1570
+ """
1571
+ memory_messages = self.write_memory_to_messages()
1572
+
1573
+ input_messages = memory_messages.copy()
1574
+ ### Generate model output ###
1575
+ memory_step.model_input_messages = input_messages
1576
+ try:
1577
+ additional_args: dict[str, Any] = {}
1578
+ if self.grammar:
1579
+ additional_args["grammar"] = self.grammar
1580
+ if self._use_structured_outputs_internally:
1581
+ additional_args["response_format"] = CODEAGENT_RESPONSE_FORMAT
1582
+ if self.stream_outputs:
1583
+ output_stream = self.model.generate_stream(
1584
+ input_messages,
1585
+ stop_sequences=["<end_code>", "Observation:", "Calling tools:"],
1586
+ **additional_args,
1587
+ )
1588
+ chat_message_stream_deltas: list[ChatMessageStreamDelta] = []
1589
+ with Live("", console=self.logger.console, vertical_overflow="visible") as live:
1590
+ for event in output_stream:
1591
+ chat_message_stream_deltas.append(event)
1592
+ live.update(
1593
+ Markdown(agglomerate_stream_deltas(chat_message_stream_deltas).render_as_markdown())
1594
+ )
1595
+ yield event
1596
+ chat_message = agglomerate_stream_deltas(chat_message_stream_deltas)
1597
+ memory_step.model_output_message = chat_message
1598
+ output_text = chat_message.content
1599
+ else:
1600
+ chat_message: ChatMessage = self.model.generate(
1601
+ input_messages,
1602
+ stop_sequences=["<end_code>", "Observation:", "Calling tools:"],
1603
+ **additional_args,
1604
+ )
1605
+ memory_step.model_output_message = chat_message
1606
+ output_text = chat_message.content
1607
+ self.logger.log_markdown(
1608
+ content=output_text,
1609
+ title="Output message of the LLM:",
1610
+ level=LogLevel.DEBUG,
1611
+ )
1612
+
1613
+ # This adds <end_code> sequence to the history.
1614
+ # This will nudge subsequent LLM calls to finish with <end_code>, thus efficiently stopping generation.
1615
+ if output_text and output_text.strip().endswith("```"):
1616
+ output_text += "<end_code>"
1617
+ memory_step.model_output_message.content = output_text
1618
+
1619
+ memory_step.token_usage = chat_message.token_usage
1620
+ memory_step.model_output = output_text
1621
+ except Exception as e:
1622
+ raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e
1623
+
1624
+ ### Parse output ###
1625
+ try:
1626
+ if self._use_structured_outputs_internally:
1627
+ code_action = json.loads(output_text)["code"]
1628
+ code_action = extract_code_from_text(code_action) or code_action
1629
+ else:
1630
+ code_action = parse_code_blobs(output_text)
1631
+ code_action = fix_final_answer_code(code_action)
1632
+ memory_step.code_action = code_action
1633
+ except Exception as e:
1634
+ error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
1635
+ raise AgentParsingError(error_msg, self.logger)
1636
+
1637
+ tool_call = ToolCall(
1638
+ name="python_interpreter",
1639
+ arguments=code_action,
1640
+ id=f"call_{len(self.memory.steps)}",
1641
+ )
1642
+ yield tool_call
1643
+ memory_step.tool_calls = [tool_call]
1644
+
1645
+ ### Execute action ###
1646
+ self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
1647
+ is_final_answer = False
1648
+ try:
1649
+ output, execution_logs, is_final_answer = self.python_executor(code_action)
1650
+ execution_outputs_console = []
1651
+ if len(execution_logs) > 0:
1652
+ execution_outputs_console += [
1653
+ Text("Execution logs:", style="bold"),
1654
+ Text(execution_logs),
1655
+ ]
1656
+ observation = "Execution logs:\n" + execution_logs
1657
+ except Exception as e:
1658
+ if hasattr(self.python_executor, "state") and "_print_outputs" in self.python_executor.state:
1659
+ execution_logs = str(self.python_executor.state["_print_outputs"])
1660
+ if len(execution_logs) > 0:
1661
+ execution_outputs_console = [
1662
+ Text("Execution logs:", style="bold"),
1663
+ Text(execution_logs),
1664
+ ]
1665
+ memory_step.observations = "Execution logs:\n" + execution_logs
1666
+ self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
1667
+ error_msg = str(e)
1668
+ if "Import of " in error_msg and " is not allowed" in error_msg:
1669
+ self.logger.log(
1670
+ "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
1671
+ level=LogLevel.INFO,
1672
+ )
1673
+ raise AgentExecutionError(error_msg, self.logger)
1674
+
1675
+ truncated_output = truncate_content(str(output))
1676
+ observation += "Last output from code snippet:\n" + truncated_output
1677
+ memory_step.observations = observation
1678
+
1679
+ if not is_final_answer:
1680
+ execution_outputs_console += [
1681
+ Text(
1682
+ f"Out: {truncated_output}",
1683
+ ),
1684
+ ]
1685
+ self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
1686
+ memory_step.action_output = output
1687
+ yield ActionOutput(output=output, is_final_answer=is_final_answer)
1688
+
1689
+ def to_dict(self) -> dict[str, Any]:
1690
+ """Convert the agent to a dictionary representation.
1691
+
1692
+ Returns:
1693
+ `dict`: Dictionary representation of the agent.
1694
+ """
1695
+ agent_dict = super().to_dict()
1696
+ agent_dict["authorized_imports"] = self.authorized_imports
1697
+ agent_dict["executor_type"] = self.executor_type
1698
+ agent_dict["executor_kwargs"] = self.executor_kwargs
1699
+ agent_dict["max_print_outputs_length"] = self.max_print_outputs_length
1700
+ return agent_dict
1701
+
1702
+ @classmethod
1703
+ def from_dict(cls, agent_dict: dict[str, Any], **kwargs) -> "CodeAgent":
1704
+ """Create CodeAgent from a dictionary representation.
1705
+
1706
+ Args:
1707
+ agent_dict (`dict[str, Any]`): Dictionary representation of the agent.
1708
+ **kwargs: Additional keyword arguments that will override agent_dict values.
1709
+
1710
+ Returns:
1711
+ `CodeAgent`: Instance of the CodeAgent class.
1712
+ """
1713
+ # Add CodeAgent-specific parameters to kwargs
1714
+ code_agent_kwargs = {
1715
+ "additional_authorized_imports": agent_dict.get("authorized_imports"),
1716
+ "executor_type": agent_dict.get("executor_type"),
1717
+ "executor_kwargs": agent_dict.get("executor_kwargs"),
1718
+ "max_print_outputs_length": agent_dict.get("max_print_outputs_length"),
1719
+ }
1720
+ # Filter out None values
1721
+ code_agent_kwargs = {k: v for k, v in code_agent_kwargs.items() if v is not None}
1722
+ # Update with any additional kwargs
1723
+ code_agent_kwargs.update(kwargs)
1724
+ # Call the parent class's from_dict method
1725
+ return super().from_dict(agent_dict, **code_agent_kwargs)
src/smolagents/cli.py ADDED
@@ -0,0 +1,164 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import argparse
18
+ import os
19
+
20
+ from dotenv import load_dotenv
21
+
22
+ from smolagents import CodeAgent, InferenceClientModel, LiteLLMModel, Model, OpenAIServerModel, Tool, TransformersModel
23
+ from smolagents.default_tools import TOOL_MAPPING
24
+
25
+
26
+ leopard_prompt = "How many seconds would it take for a leopard at full speed to run through Pont des Arts?"
27
+
28
+
29
+ def parse_arguments():
30
+ parser = argparse.ArgumentParser(description="Run a CodeAgent with all specified parameters")
31
+ parser.add_argument(
32
+ "prompt",
33
+ type=str,
34
+ nargs="?", # Makes it optional
35
+ default=leopard_prompt,
36
+ help="The prompt to run with the agent",
37
+ )
38
+ parser.add_argument(
39
+ "--model-type",
40
+ type=str,
41
+ default="InferenceClientModel",
42
+ help="The model type to use (e.g., InferenceClientModel, OpenAIServerModel, LiteLLMModel, TransformersModel)",
43
+ )
44
+ parser.add_argument(
45
+ "--model-id",
46
+ type=str,
47
+ default="Qwen/Qwen2.5-Coder-32B-Instruct",
48
+ help="The model ID to use for the specified model type",
49
+ )
50
+ parser.add_argument(
51
+ "--imports",
52
+ nargs="*", # accepts zero or more arguments
53
+ default=[],
54
+ help="Space-separated list of imports to authorize (e.g., 'numpy pandas')",
55
+ )
56
+ parser.add_argument(
57
+ "--tools",
58
+ nargs="*",
59
+ default=["web_search"],
60
+ help="Space-separated list of tools that the agent can use (e.g., 'tool1 tool2 tool3')",
61
+ )
62
+ parser.add_argument(
63
+ "--verbosity-level",
64
+ type=int,
65
+ default=1,
66
+ help="The verbosity level, as an int in [0, 1, 2].",
67
+ )
68
+ group = parser.add_argument_group("api options", "Options for API-based model types")
69
+ group.add_argument(
70
+ "--provider",
71
+ type=str,
72
+ default=None,
73
+ help="The inference provider to use for the model",
74
+ )
75
+ group.add_argument(
76
+ "--api-base",
77
+ type=str,
78
+ help="The base URL for the model",
79
+ )
80
+ group.add_argument(
81
+ "--api-key",
82
+ type=str,
83
+ help="The API key for the model",
84
+ )
85
+ return parser.parse_args()
86
+
87
+
88
+ def load_model(
89
+ model_type: str,
90
+ model_id: str,
91
+ api_base: str | None = None,
92
+ api_key: str | None = None,
93
+ provider: str | None = None,
94
+ ) -> Model:
95
+ if model_type == "OpenAIServerModel":
96
+ return OpenAIServerModel(
97
+ api_key=api_key or os.getenv("FIREWORKS_API_KEY"),
98
+ api_base=api_base or "https://api.fireworks.ai/inference/v1",
99
+ model_id=model_id,
100
+ )
101
+ elif model_type == "LiteLLMModel":
102
+ return LiteLLMModel(
103
+ model_id=model_id,
104
+ api_key=api_key,
105
+ api_base=api_base,
106
+ )
107
+ elif model_type == "TransformersModel":
108
+ return TransformersModel(model_id=model_id, device_map="auto")
109
+ elif model_type == "InferenceClientModel":
110
+ return InferenceClientModel(
111
+ model_id=model_id,
112
+ token=api_key or os.getenv("HF_API_KEY"),
113
+ provider=provider,
114
+ )
115
+ else:
116
+ raise ValueError(f"Unsupported model type: {model_type}")
117
+
118
+
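+
+ # Illustrative sketch (assumes the relevant API key is available in the
+ # environment, e.g. HF_API_KEY for InferenceClientModel):
+ #
+ #     model = load_model("InferenceClientModel", "Qwen/Qwen2.5-Coder-32B-Instruct")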
119
+ def run_smolagent(
120
+ prompt: str,
121
+ tools: list[str],
122
+ model_type: str,
123
+ model_id: str,
124
+ api_base: str | None = None,
125
+ api_key: str | None = None,
126
+ imports: list[str] | None = None,
127
+ provider: str | None = None,
128
+ ) -> None:
129
+ load_dotenv()
130
+
131
+ model = load_model(model_type, model_id, api_base=api_base, api_key=api_key, provider=provider)
132
+
133
+ available_tools = []
134
+ for tool_name in tools:
135
+ if "/" in tool_name:
136
+ available_tools.append(Tool.from_space(tool_name))
137
+ else:
138
+ if tool_name in TOOL_MAPPING:
139
+ available_tools.append(TOOL_MAPPING[tool_name]())
140
+ else:
141
+ raise ValueError(f"Tool {tool_name} is not recognized either as a default tool or a Space.")
142
+
143
+ print(f"Running agent with these tools: {tools}")
144
+ agent = CodeAgent(tools=available_tools, model=model, additional_authorized_imports=imports)
145
+
146
+ agent.run(prompt)
147
+
148
+
149
+ def main() -> None:
150
+ args = parse_arguments()
151
+ run_smolagent(
152
+ args.prompt,
153
+ args.tools,
154
+ args.model_type,
155
+ args.model_id,
156
+ provider=args.provider,
157
+ api_base=args.api_base,
158
+ api_key=args.api_key,
159
+ imports=args.imports,
160
+ )
161
+
162
+
163
+ if __name__ == "__main__":
164
+ main()
src/smolagents/default_tools.py ADDED
@@ -0,0 +1,577 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ from dataclasses import dataclass
18
+ from typing import Any
19
+
20
+ from .local_python_executor import (
21
+ BASE_BUILTIN_MODULES,
22
+ BASE_PYTHON_TOOLS,
23
+ evaluate_python_code,
24
+ )
25
+ from .tools import PipelineTool, Tool
26
+
27
+
28
+ @dataclass
29
+ class PreTool:
30
+ name: str
31
+ inputs: dict[str, str]
32
+ output_type: type
33
+ task: str
34
+ description: str
35
+ repo_id: str
36
+
37
+
38
+ class PythonInterpreterTool(Tool):
39
+ name = "python_interpreter"
40
+ description = "This is a tool that evaluates python code. It can be used to perform calculations."
41
+ inputs = {
42
+ "code": {
43
+ "type": "string",
44
+ "description": "The python code to run in interpreter",
45
+ }
46
+ }
47
+ output_type = "string"
48
+
49
+ def __init__(self, *args, authorized_imports=None, **kwargs):
50
+ if authorized_imports is None:
51
+ self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
52
+ else:
53
+ self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(authorized_imports))
54
+ self.inputs = {
55
+ "code": {
56
+ "type": "string",
57
+ "description": (
58
+ "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
59
+ f"else you will get an error. This code can only import the following python libraries: {self.authorized_imports}."
60
+ ),
61
+ }
62
+ }
63
+ self.base_python_tools = BASE_PYTHON_TOOLS
64
+ self.python_evaluator = evaluate_python_code
65
+ super().__init__(*args, **kwargs)
66
+
67
+ def forward(self, code: str) -> str:
68
+ state = {}
69
+ output = str(
70
+ self.python_evaluator(
71
+ code,
72
+ state=state,
73
+ static_tools=self.base_python_tools,
74
+ authorized_imports=self.authorized_imports,
75
+ )[0]  # The second element is the boolean is_final_answer
76
+ )
77
+ return f"Stdout:\n{str(state['_print_outputs'])}\nOutput: {output}"
78
+
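+
+ # Illustrative usage sketch:
+ #
+ #     tool = PythonInterpreterTool(authorized_imports=["math"])
+ #     print(tool("import math\nprint(math.sqrt(2))"))  # prints captured stdout + final output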
79
+
80
+ class FinalAnswerTool(Tool):
81
+ name = "final_answer"
82
+ description = "Provides a final answer to the given problem."
83
+ inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
84
+ output_type = "any"
85
+
86
+ def forward(self, answer: Any) -> Any:
87
+ return answer
88
+
89
+
90
+ class UserInputTool(Tool):
91
+ name = "user_input"
92
+ description = "Asks for user's input on a specific question"
93
+ inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
94
+ output_type = "string"
95
+
96
+ def forward(self, question):
97
+ user_input = input(f"{question} => Type your answer here:")
98
+ return user_input
99
+
100
+
101
+ class DuckDuckGoSearchTool(Tool):
102
+ name = "web_search"
103
+ description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."""
104
+ inputs = {"query": {"type": "string", "description": "The search query to perform."}}
105
+ output_type = "string"
106
+
107
+ def __init__(self, max_results=10, **kwargs):
108
+ super().__init__()
109
+ self.max_results = max_results
110
+ try:
111
+ from duckduckgo_search import DDGS
112
+ except ImportError as e:
113
+ raise ImportError(
114
+ "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
115
+ ) from e
116
+ self.ddgs = DDGS(**kwargs)
117
+
118
+ def forward(self, query: str) -> str:
119
+ results = self.ddgs.text(query, max_results=self.max_results)
120
+ if len(results) == 0:
121
+ raise Exception("No results found! Try a less restrictive/shorter query.")
122
+ postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
123
+ return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
124
+
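+
+ # Illustrative usage sketch (requires `pip install duckduckgo-search`):
+ #
+ #     search = DuckDuckGoSearchTool(max_results=5)
+ #     print(search("pont des arts length"))  # markdown-formatted top results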
125
+
126
+ class GoogleSearchTool(Tool):
127
+ name = "web_search"
128
+ description = """Performs a google web search for your query then returns a string of the top search results."""
129
+ inputs = {
130
+ "query": {"type": "string", "description": "The search query to perform."},
131
+ "filter_year": {
132
+ "type": "integer",
133
+ "description": "Optionally restrict results to a certain year",
134
+ "nullable": True,
135
+ },
136
+ }
137
+ output_type = "string"
138
+
139
+ def __init__(self, provider: str = "serpapi"):
140
+ super().__init__()
141
+ import os
142
+
143
+ self.provider = provider
144
+ if provider == "serpapi":
145
+ self.organic_key = "organic_results"
146
+ api_key_env_name = "SERPAPI_API_KEY"
147
+ else:
148
+ self.organic_key = "organic"
149
+ api_key_env_name = "SERPER_API_KEY"
150
+ self.api_key = os.getenv(api_key_env_name)
151
+ if self.api_key is None:
152
+ raise ValueError(f"Missing API key. Make sure you have '{api_key_env_name}' in your env variables.")
153
+
154
+ def forward(self, query: str, filter_year: int | None = None) -> str:
155
+ import requests
156
+
157
+ if self.provider == "serpapi":
158
+ params = {
159
+ "q": query,
160
+ "api_key": self.api_key,
161
+ "engine": "google",
162
+ "google_domain": "google.com",
163
+ }
164
+ base_url = "https://serpapi.com/search.json"
165
+ else:
166
+ params = {
167
+ "q": query,
168
+ "api_key": self.api_key,
169
+ }
170
+ base_url = "https://google.serper.dev/search"
171
+ if filter_year is not None:
172
+ params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
173
+
174
+ response = requests.get(base_url, params=params)
175
+
176
+ if response.status_code == 200:
177
+ results = response.json()
178
+ else:
179
+ raise ValueError(response.json())
180
+
181
+ if self.organic_key not in results.keys():
182
+ if filter_year is not None:
183
+ raise Exception(
184
+ f"No results found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
185
+ )
186
+ else:
187
+ raise Exception(f"No results found for query: '{query}'. Use a less restrictive query.")
188
+ if len(results[self.organic_key]) == 0:
189
+ year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
190
+ return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
191
+
192
+ web_snippets = []
193
+ if self.organic_key in results:
194
+ for idx, page in enumerate(results[self.organic_key]):
195
+ date_published = ""
196
+ if "date" in page:
197
+ date_published = "\nDate published: " + page["date"]
198
+
199
+ source = ""
200
+ if "source" in page:
201
+ source = "\nSource: " + page["source"]
202
+
203
+ snippet = ""
204
+ if "snippet" in page:
205
+ snippet = "\n" + page["snippet"]
206
+
207
+ redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
208
+ web_snippets.append(redacted_version)
209
+
210
+ return "## Search Results\n" + "\n\n".join(web_snippets)
211
+
212
+
213
+ class ApiWebSearchTool(Tool):
214
+ name = "web_search"
215
+ description = "Performs a web search for a query and returns a string of the top search results formatted as markdown with titles, URLs, and descriptions."
216
+ inputs = {"query": {"type": "string", "description": "The search query to perform."}}
217
+ output_type = "string"
218
+
219
+ def __init__(
220
+ self, endpoint: str = "", api_key: str = "", api_key_name: str = "", headers: dict = None, params: dict = None
221
+ ):
222
+ import os
223
+
224
+ super().__init__()
225
+ self.endpoint = endpoint or "https://api.search.brave.com/res/v1/web/search"
226
+ self.api_key = api_key or os.getenv(api_key_name)
227
+ self.headers = headers or {"X-Subscription-Token": self.api_key}
228
+ self.params = params or {"count": 10}
229
+
230
+ def forward(self, query: str) -> str:
231
+ import requests
232
+
233
+ params = {**self.params, "q": query}
234
+ response = requests.get(self.endpoint, headers=self.headers, params=params)
235
+ response.raise_for_status()
236
+ data = response.json()
237
+ results = self.extract_results(data)
238
+ return self.format_markdown(results)
239
+
240
+ def extract_results(self, data: dict) -> list:
241
+ results = []
242
+ for result in data.get("web", {}).get("results", []):
243
+ results.append(
244
+ {"title": result["title"], "url": result["url"], "description": result.get("description", "")}
245
+ )
246
+ return results
247
+
248
+ def format_markdown(self, results: list) -> str:
249
+ if not results:
250
+ return "No results found."
251
+ return "## Search Results\n\n" + "\n\n".join(
252
+ [
253
+ f"{idx}. [{result['title']}]({result['url']})\n{result['description']}"
254
+ for idx, result in enumerate(results, start=1)
255
+ ]
256
+ )
257
+
258
+
259
+ class WebSearchTool(Tool):
260
+ name = "web_search"
261
+ description = "Performs a web search for a query and returns a string of the top search results formatted as markdown with titles, links, and descriptions."
262
+ inputs = {"query": {"type": "string", "description": "The search query to perform."}}
263
+ output_type = "string"
264
+
265
+ def __init__(self, max_results: int = 10, engine: str = "duckduckgo"):
266
+ super().__init__()
267
+ self.max_results = max_results
268
+ self.engine = engine
269
+
270
+ def forward(self, query: str) -> str:
271
+ results = self.search(query)
272
+ if len(results) == 0:
273
+ raise Exception("No results found! Try a less restrictive/shorter query.")
274
+ return self.parse_results(results)
275
+
276
+ def search(self, query: str) -> list:
277
+ if self.engine == "duckduckgo":
278
+ return self.search_duckduckgo(query)
279
+ elif self.engine == "bing":
280
+ return self.search_bing(query)
281
+ else:
282
+ raise ValueError(f"Unsupported engine: {self.engine}")
283
+
284
+ def parse_results(self, results: list) -> str:
285
+ return "## Search Results\n\n" + "\n\n".join(
286
+ [f"[{result['title']}]({result['link']})\n{result['description']}" for result in results]
287
+ )
288
+
289
+ def search_duckduckgo(self, query: str) -> list:
290
+ import requests
291
+
292
+ response = requests.get(
293
+ "https://lite.duckduckgo.com/lite/",
294
+ params={"q": query},
295
+ headers={"User-Agent": "Mozilla/5.0"},
296
+ )
297
+ response.raise_for_status()
298
+ parser = self._create_duckduckgo_parser()
299
+ parser.feed(response.text)
300
+ return parser.results
301
+
302
+ def _create_duckduckgo_parser(self):
303
+ from html.parser import HTMLParser
304
+
305
+ class SimpleResultParser(HTMLParser):
306
+ def __init__(self):
307
+ super().__init__()
308
+ self.results = []
309
+ self.current = {}
310
+ self.capture_title = False
311
+ self.capture_description = False
312
+ self.capture_link = False
313
+
314
+ def handle_starttag(self, tag, attrs):
315
+ attrs = dict(attrs)
316
+ if tag == "a" and attrs.get("class") == "result-link":
317
+ self.capture_title = True
318
+ elif tag == "td" and attrs.get("class") == "result-snippet":
319
+ self.capture_description = True
320
+ elif tag == "span" and attrs.get("class") == "link-text":
321
+ self.capture_link = True
322
+
323
+ def handle_endtag(self, tag):
324
+ if tag == "a" and self.capture_title:
325
+ self.capture_title = False
326
+ elif tag == "td" and self.capture_description:
327
+ self.capture_description = False
328
+ elif tag == "span" and self.capture_link:
329
+ self.capture_link = False
330
+ elif tag == "tr":
331
+ # Store current result if all parts are present
332
+ if {"title", "description", "link"} <= self.current.keys():
333
+ self.current["description"] = " ".join(self.current["description"])
334
+ self.results.append(self.current)
335
+ self.current = {}
336
+
337
+ def handle_data(self, data):
338
+ if self.capture_title:
339
+ self.current["title"] = data.strip()
340
+ elif self.capture_description:
341
+ self.current.setdefault("description", [])
342
+ self.current["description"].append(data.strip())
343
+ elif self.capture_link:
344
+ self.current["link"] = "https://" + data.strip()
345
+
346
+ return SimpleResultParser()
347
+
348
+ def search_bing(self, query: str) -> list:
349
+ import xml.etree.ElementTree as ET
350
+
351
+ import requests
352
+
353
+ response = requests.get(
354
+ "https://www.bing.com/search",
355
+ params={"q": query, "format": "rss"},
356
+ )
357
+ response.raise_for_status()
358
+ root = ET.fromstring(response.text)
359
+ items = root.findall(".//item")
360
+ results = [
361
+ {
362
+ "title": item.findtext("title"),
363
+ "link": item.findtext("link"),
364
+ "description": item.findtext("description"),
365
+ }
366
+ for item in items[: self.max_results]
367
+ ]
368
+ return results
369
+
370
+
371
+ class VisitWebpageTool(Tool):
372
+ name = "visit_webpage"
373
+ description = (
374
+ "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
375
+ )
376
+ inputs = {
377
+ "url": {
378
+ "type": "string",
379
+ "description": "The url of the webpage to visit.",
380
+ }
381
+ }
382
+ output_type = "string"
383
+
384
+ def __init__(self, max_output_length: int = 40000):
385
+ super().__init__()
386
+ self.max_output_length = max_output_length
387
+
388
+ def _truncate_content(self, content: str, max_length: int) -> str:
389
+ if len(content) <= max_length:
390
+ return content
391
+ return (
392
+ content[: max_length // 2]
393
+ + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
394
+ + content[-max_length // 2 :]
395
+ )
396
+
397
+ def forward(self, url: str) -> str:
398
+ try:
399
+ import re
400
+
401
+ import requests
402
+ from markdownify import markdownify
403
+ from requests.exceptions import RequestException
404
+ except ImportError as e:
405
+ raise ImportError(
406
+ "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
407
+ ) from e
408
+ try:
409
+ # Send a GET request to the URL with a 20-second timeout
410
+ response = requests.get(url, timeout=20)
411
+ response.raise_for_status() # Raise an exception for bad status codes
412
+
413
+ # Convert the HTML content to Markdown
414
+ markdown_content = markdownify(response.text).strip()
415
+
416
+ # Remove multiple line breaks
417
+ markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
418
+
419
+ return self._truncate_content(markdown_content, self.max_output_length)
420
+
421
+ except requests.exceptions.Timeout:
422
+ return "The request timed out. Please try again later or check the URL."
423
+ except RequestException as e:
424
+ return f"Error fetching the webpage: {str(e)}"
425
+ except Exception as e:
426
+ return f"An unexpected error occurred: {str(e)}"
427
+
428
+
429
+ class WikipediaSearchTool(Tool):
430
+ """
431
+ WikipediaSearchTool searches Wikipedia and returns a summary or full text of the given topic, along with the page URL.
432
+
433
+ Attributes:
434
+ user_agent (str): A custom user-agent string to identify the project. This is required as per Wikipedia API policies, read more here: http://github.com/martin-majlis/Wikipedia-API/blob/master/README.rst
435
+ language (str): The language in which to retrieve Wikipedia articles.
436
+ http://meta.wikimedia.org/wiki/List_of_Wikipedias
437
+ content_type (str): Defines the content to fetch. Can be "summary" for a short summary or "text" for the full article.
438
+ extract_format (str): Defines the output format. Can be `"WIKI"` or `"HTML"`.
439
+
440
+ Example:
441
+ >>> from smolagents import CodeAgent, InferenceClientModel, WikipediaSearchTool
442
+ >>> agent = CodeAgent(
443
+ >>> tools=[
444
+ >>> WikipediaSearchTool(
445
+ >>> user_agent="MyResearchBot ([email protected])",
446
+ >>> language="en",
447
+ >>> content_type="summary", # or "text"
448
+ >>> extract_format="WIKI",
449
+ >>> )
450
+ >>> ],
451
+ >>> model=InferenceClientModel(),
452
+ >>> )
453
+ >>> agent.run("Python_(programming_language)")
454
+ """
455
+
456
+ name = "wikipedia_search"
457
+ description = "Searches Wikipedia and returns a summary or full text of the given topic, along with the page URL."
458
+ inputs = {
459
+ "query": {
460
+ "type": "string",
461
+ "description": "The topic to search on Wikipedia.",
462
+ }
463
+ }
464
+ output_type = "string"
465
+
466
+ def __init__(
467
+ self,
468
+ user_agent: str = "Smolagents ([email protected])",
469
+ language: str = "en",
470
+ content_type: str = "text",
471
+ extract_format: str = "WIKI",
472
+ ):
473
+ super().__init__()
474
+ try:
475
+ import wikipediaapi
476
+ except ImportError as e:
477
+ raise ImportError(
478
+ "You must install `wikipedia-api` to run this tool: for instance run `pip install wikipedia-api`"
479
+ ) from e
480
+ if not user_agent:
481
+ raise ValueError("User-agent is required. Provide a meaningful identifier for your project.")
482
+
483
+ self.user_agent = user_agent
484
+ self.language = language
485
+ self.content_type = content_type
486
+
487
+ # Map string format to wikipediaapi.ExtractFormat
488
+ extract_format_map = {
489
+ "WIKI": wikipediaapi.ExtractFormat.WIKI,
490
+ "HTML": wikipediaapi.ExtractFormat.HTML,
491
+ }
492
+
493
+ if extract_format not in extract_format_map:
494
+ raise ValueError("Invalid extract_format. Choose between 'WIKI' or 'HTML'.")
495
+
496
+ self.extract_format = extract_format_map[extract_format]
497
+
498
+ self.wiki = wikipediaapi.Wikipedia(
499
+ user_agent=self.user_agent, language=self.language, extract_format=self.extract_format
500
+ )
501
+
502
+ def forward(self, query: str) -> str:
503
+ try:
504
+ page = self.wiki.page(query)
505
+
506
+ if not page.exists():
507
+ return f"No Wikipedia page found for '{query}'. Try a different query."
508
+
509
+ title = page.title
510
+ url = page.fullurl
511
+
512
+ if self.content_type == "summary":
513
+ text = page.summary
514
+ elif self.content_type == "text":
515
+ text = page.text
516
+ else:
517
+ return "⚠️ Invalid `content_type`. Use either 'summary' or 'text'."
518
+
519
+ return f"✅ **Wikipedia Page:** {title}\n\n**Content:** {text}\n\n🔗 **Read more:** {url}"
520
+
521
+ except Exception as e:
522
+ return f"Error fetching Wikipedia summary: {str(e)}"
523
+
524
+
525
+ class SpeechToTextTool(PipelineTool):
526
+ default_checkpoint = "openai/whisper-large-v3-turbo"
527
+ description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
528
+ name = "transcriber"
529
+ inputs = {
530
+ "audio": {
531
+ "type": "audio",
532
+ "description": "The audio to transcribe. Can be a local path, an url, or a tensor.",
533
+ }
534
+ }
535
+ output_type = "string"
536
+
537
+ def __new__(cls, *args, **kwargs):
538
+ from transformers.models.whisper import WhisperForConditionalGeneration, WhisperProcessor
539
+
540
+ cls.pre_processor_class = WhisperProcessor
541
+ cls.model_class = WhisperForConditionalGeneration
542
+ return super().__new__(cls)
543
+
544
+ def encode(self, audio):
545
+ from .agent_types import AgentAudio
546
+
547
+ audio = AgentAudio(audio).to_raw()
548
+ return self.pre_processor(audio, return_tensors="pt")
549
+
550
+ def forward(self, inputs):
551
+ return self.model.generate(inputs["input_features"])
552
+
553
+ def decode(self, outputs):
554
+ return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
555
+
556
+
557
+ TOOL_MAPPING = {
558
+ tool_class.name: tool_class
559
+ for tool_class in [
560
+ PythonInterpreterTool,
561
+ DuckDuckGoSearchTool,
562
+ VisitWebpageTool,
563
+ ]
564
+ }
565
+
566
+ __all__ = [
567
+ "ApiWebSearchTool",
568
+ "PythonInterpreterTool",
569
+ "FinalAnswerTool",
570
+ "UserInputTool",
571
+ "WebSearchTool",
572
+ "DuckDuckGoSearchTool",
573
+ "GoogleSearchTool",
574
+ "VisitWebpageTool",
575
+ "WikipediaSearchTool",
576
+ "SpeechToTextTool",
577
+ ]
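
For orientation, here is a minimal sketch of exercising these tools directly, outside an agent loop. The query and URL are illustrative, network access is assumed, and `VisitWebpageTool` additionally needs `markdownify` and `requests`:

```python
# Minimal sketch: calling the default tools directly (outside an agent loop).
from smolagents.default_tools import FinalAnswerTool, VisitWebpageTool, WebSearchTool

search = WebSearchTool(max_results=5, engine="duckduckgo")  # no extra dependency required
print(search.forward("smolagents"))  # markdown "## Search Results" block

page = VisitWebpageTool(max_output_length=10_000)
print(page.forward("https://github.com/huggingface/smolagents"))  # page as truncated markdown

assert FinalAnswerTool().forward("done") == "done"  # passes the answer through unchanged
```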
src/smolagents/gradio_ui.py ADDED
@@ -0,0 +1,508 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import re
+ import shutil
+ from pathlib import Path
+ from typing import Generator
+
+ from smolagents.agent_types import AgentAudio, AgentImage, AgentText
+ from smolagents.agents import MultiStepAgent, PlanningStep
+ from smolagents.memory import ActionStep, FinalAnswerStep
+ from smolagents.models import ChatMessageStreamDelta, MessageRole, agglomerate_stream_deltas
+ from smolagents.utils import _is_package_available
+
+
+ def get_step_footnote_content(step_log: ActionStep | PlanningStep, step_name: str) -> str:
+     """Get a footnote string for a step log, with duration and token information."""
+     step_footnote = f"**{step_name}**"
+     if step_log.token_usage is not None:
+         step_footnote += f" | Input tokens: {step_log.token_usage.input_tokens:,} | Output tokens: {step_log.token_usage.output_tokens:,}"
+     step_footnote += f" | Duration: {round(float(step_log.timing.duration), 2)}s" if step_log.timing.duration else ""
+     step_footnote_content = f"""<span style="color: #bbbbc2; font-size: 12px;">{step_footnote}</span> """
+     return step_footnote_content
+
+
+ def _clean_model_output(model_output: str) -> str:
+     """
+     Clean up model output by removing trailing tags and extra backticks.
+
+     Args:
+         model_output (`str`): Raw model output.
+
+     Returns:
+         `str`: Cleaned model output.
+     """
+     if not model_output:
+         return ""
+     model_output = model_output.strip()
+     # Remove any trailing <end_code> and extra backticks, handling multiple possible formats
+     model_output = re.sub(r"```\s*<end_code>", "```", model_output)  # handles ```<end_code>
+     model_output = re.sub(r"<end_code>\s*```", "```", model_output)  # handles <end_code>```
+     model_output = re.sub(r"```\s*\n\s*<end_code>", "```", model_output)  # handles ```\n<end_code>
+     return model_output.strip()
+
+
+ def _format_code_content(content: str) -> str:
+     """
+     Format code content as a Python code block if it's not already formatted.
+
+     Args:
+         content (`str`): Code content to format.
+
+     Returns:
+         `str`: Code content formatted as a Python code block.
+     """
+     content = content.strip()
+     # Remove existing code blocks and end_code tags
+     content = re.sub(r"```.*?\n", "", content)
+     content = re.sub(r"\s*<end_code>\s*", "", content)
+     content = content.strip()
+     # Add Python code block formatting if not already present
+     if not content.startswith("```python"):
+         content = f"```python\n{content}\n```"
+     return content
+
+
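
As a quick sanity check of the two helpers above, a sketch of their effect on typical model output (the `fence` indirection just avoids writing literal triple backticks inside this example):

```python
# Sketch: effect of the cleanup helpers on typical model output.
fence = "`" * 3  # build the code fence programmatically to keep this example readable

raw = f"Thought: add the numbers.\n{fence}py\nprint(1 + 1)\n{fence}<end_code>"
cleaned = _clean_model_output(raw)
assert cleaned.endswith(fence)  # trailing <end_code> stripped, closing fence preserved

formatted = _format_code_content("print(1 + 1)\n<end_code>")
assert formatted == f"{fence}python\nprint(1 + 1)\n{fence}"  # wrapped as a Python block
```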
+ def _process_action_step(step_log: ActionStep, skip_model_outputs: bool = False) -> Generator:
+     """
+     Process an [`ActionStep`] and yield appropriate Gradio ChatMessage objects.
+
+     Args:
+         step_log ([`ActionStep`]): ActionStep to process.
+         skip_model_outputs (`bool`): Whether to skip model outputs.
+
+     Yields:
+         `gradio.ChatMessage`: Gradio ChatMessages representing the action step.
+     """
+     import gradio as gr
+
+     # Output the step number
+     step_number = f"Step {step_log.step_number}"
+     if not skip_model_outputs:
+         yield gr.ChatMessage(role=MessageRole.ASSISTANT, content=f"**{step_number}**", metadata={"status": "done"})
+
+     # First yield the thought/reasoning from the LLM
+     if not skip_model_outputs and getattr(step_log, "model_output", ""):
+         model_output = _clean_model_output(step_log.model_output)
+         yield gr.ChatMessage(role=MessageRole.ASSISTANT, content=model_output, metadata={"status": "done"})
+
+     # For tool calls, create a parent message
+     if getattr(step_log, "tool_calls", []):
+         first_tool_call = step_log.tool_calls[0]
+         used_code = first_tool_call.name == "python_interpreter"
+
+         # Process arguments based on type
+         args = first_tool_call.arguments
+         if isinstance(args, dict):
+             content = str(args.get("answer", str(args)))
+         else:
+             content = str(args).strip()
+
+         # Format code content if needed
+         if used_code:
+             content = _format_code_content(content)
+
+         # Create the tool call message
+         parent_message_tool = gr.ChatMessage(
+             role=MessageRole.ASSISTANT,
+             content=content,
+             metadata={
+                 "title": f"🛠️ Used tool {first_tool_call.name}",
+                 "status": "done",
+             },
+         )
+         yield parent_message_tool
+
+     # Display execution logs if they exist
+     if getattr(step_log, "observations", "") and step_log.observations.strip():
+         log_content = step_log.observations.strip()
+         if log_content:
+             log_content = re.sub(r"^Execution logs:\s*", "", log_content)
+             yield gr.ChatMessage(
+                 role=MessageRole.ASSISTANT,
+                 content=f"```bash\n{log_content}\n",
+                 metadata={"title": "📝 Execution Logs", "status": "done"},
+             )
+
+     # Display any images in observations
+     if getattr(step_log, "observations_images", []):
+         for image in step_log.observations_images:
+             path_image = AgentImage(image).to_string()
+             yield gr.ChatMessage(
+                 role=MessageRole.ASSISTANT,
+                 content={"path": path_image, "mime_type": f"image/{path_image.split('.')[-1]}"},
+                 metadata={"title": "🖼️ Output Image", "status": "done"},
+             )
+
+     # Handle errors
+     if getattr(step_log, "error", None):
+         yield gr.ChatMessage(
+             role=MessageRole.ASSISTANT, content=str(step_log.error), metadata={"title": "💥 Error", "status": "done"}
+         )
+
+     # Add step footnote and separator
+     yield gr.ChatMessage(
+         role=MessageRole.ASSISTANT,
+         content=get_step_footnote_content(step_log, step_number),
+         metadata={"status": "done"},
+     )
+     yield gr.ChatMessage(role=MessageRole.ASSISTANT, content="-----", metadata={"status": "done"})
+
+
+ def _process_planning_step(step_log: PlanningStep, skip_model_outputs: bool = False) -> Generator:
+     """
+     Process a [`PlanningStep`] and yield appropriate gradio.ChatMessage objects.
+
+     Args:
+         step_log ([`PlanningStep`]): PlanningStep to process.
+
+     Yields:
+         `gradio.ChatMessage`: Gradio ChatMessages representing the planning step.
+     """
+     import gradio as gr
+
+     if not skip_model_outputs:
+         yield gr.ChatMessage(role=MessageRole.ASSISTANT, content="**Planning step**", metadata={"status": "done"})
+         yield gr.ChatMessage(role=MessageRole.ASSISTANT, content=step_log.plan, metadata={"status": "done"})
+     yield gr.ChatMessage(
+         role=MessageRole.ASSISTANT,
+         content=get_step_footnote_content(step_log, "Planning step"),
+         metadata={"status": "done"},
+     )
+     yield gr.ChatMessage(role=MessageRole.ASSISTANT, content="-----", metadata={"status": "done"})
+
+
+ def _process_final_answer_step(step_log: FinalAnswerStep) -> Generator:
+     """
+     Process a [`FinalAnswerStep`] and yield appropriate gradio.ChatMessage objects.
+
+     Args:
+         step_log ([`FinalAnswerStep`]): FinalAnswerStep to process.
+
+     Yields:
+         `gradio.ChatMessage`: Gradio ChatMessages representing the final answer.
+     """
+     import gradio as gr
+
+     final_answer = step_log.output
+     if isinstance(final_answer, AgentText):
+         yield gr.ChatMessage(
+             role=MessageRole.ASSISTANT,
+             content=f"**Final answer:**\n{final_answer.to_string()}\n",
+             metadata={"status": "done"},
+         )
+     elif isinstance(final_answer, AgentImage):
+         yield gr.ChatMessage(
+             role=MessageRole.ASSISTANT,
+             content={"path": final_answer.to_string(), "mime_type": "image/png"},
+             metadata={"status": "done"},
+         )
+     elif isinstance(final_answer, AgentAudio):
+         yield gr.ChatMessage(
+             role=MessageRole.ASSISTANT,
+             content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
+             metadata={"status": "done"},
+         )
+     else:
+         yield gr.ChatMessage(
+             role=MessageRole.ASSISTANT, content=f"**Final answer:** {str(final_answer)}", metadata={"status": "done"}
+         )
+
+
+ def pull_messages_from_step(step_log: ActionStep | PlanningStep | FinalAnswerStep, skip_model_outputs: bool = False):
+     """Extract Gradio ChatMessage objects from agent steps with proper nesting.
+
+     Args:
+         step_log: The step log to display as gr.ChatMessage objects.
+         skip_model_outputs: If True, skip the model outputs when creating the gr.ChatMessage objects.
+             This is used for instance when streaming model outputs have already been displayed.
+     """
+     if not _is_package_available("gradio"):
+         raise ModuleNotFoundError(
+             "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
+         )
+     if isinstance(step_log, ActionStep):
+         yield from _process_action_step(step_log, skip_model_outputs)
+     elif isinstance(step_log, PlanningStep):
+         yield from _process_planning_step(step_log, skip_model_outputs)
+     elif isinstance(step_log, FinalAnswerStep):
+         yield from _process_final_answer_step(step_log)
+     else:
+         raise ValueError(f"Unsupported step type: {type(step_log)}")
+
+
+ def stream_to_gradio(
+     agent,
+     task: str,
+     task_images: list | None = None,
+     reset_agent_memory: bool = False,
+     additional_args: dict | None = None,
+ ) -> Generator:
+     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
+
+     if not _is_package_available("gradio"):
+         raise ModuleNotFoundError(
+             "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
+         )
+     accumulated_events: list[ChatMessageStreamDelta] = []
+     for event in agent.run(
+         task, images=task_images, stream=True, reset=reset_agent_memory, additional_args=additional_args
+     ):
+         if isinstance(event, ActionStep | PlanningStep | FinalAnswerStep):
+             for message in pull_messages_from_step(
+                 event,
+                 # If we're streaming model outputs, no need to display them twice
+                 skip_model_outputs=getattr(agent, "stream_outputs", False),
+             ):
+                 yield message
+             accumulated_events = []
+         elif isinstance(event, ChatMessageStreamDelta):
+             accumulated_events.append(event)
+             text = agglomerate_stream_deltas(accumulated_events).render_as_markdown()
+             yield text
+
+
+ class GradioUI:
+     """
+     Gradio interface for interacting with a [`MultiStepAgent`].
+
+     This class provides a web interface to interact with the agent in real-time, allowing users to submit prompts, upload files, and receive responses in a chat-like format.
+     It can reset the agent's memory at the start of each interaction if desired.
+     It supports file uploads, which are saved to a specified folder.
+     It uses the [`gradio.Chatbot`] component to display the conversation history.
+     This class requires the `gradio` extra to be installed: `smolagents[gradio]`.
+
+     Args:
+         agent ([`MultiStepAgent`]): The agent to interact with.
+         file_upload_folder (`str`, *optional*): The folder where uploaded files will be saved.
+             If not provided, file uploads are disabled.
+         reset_agent_memory (`bool`, *optional*, defaults to `False`): Whether to reset the agent's memory at the start of each interaction.
+             If `True`, the agent will not remember previous interactions.
+
+     Raises:
+         ModuleNotFoundError: If the `gradio` extra is not installed.
+
+     Example:
+         ```python
+         from smolagents import CodeAgent, GradioUI, InferenceClientModel
+
+         model = InferenceClientModel(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct")
+         agent = CodeAgent(tools=[], model=model)
+         gradio_ui = GradioUI(agent, file_upload_folder="uploads", reset_agent_memory=True)
+         gradio_ui.launch()
+         ```
+     """
+
+     def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None, reset_agent_memory: bool = False):
+         if not _is_package_available("gradio"):
+             raise ModuleNotFoundError(
+                 "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
+             )
+         self.agent = agent
+         self.file_upload_folder = Path(file_upload_folder) if file_upload_folder is not None else None
+         self.reset_agent_memory = reset_agent_memory
+         self.name = getattr(agent, "name") or "Agent interface"
+         self.description = getattr(agent, "description", None)
+         if self.file_upload_folder is not None:
+             if not self.file_upload_folder.exists():
+                 self.file_upload_folder.mkdir(parents=True, exist_ok=True)
+
+     def interact_with_agent(self, prompt, messages, session_state):
+         import gradio as gr
+
+         # Get the agent type from the template agent
+         if "agent" not in session_state:
+             session_state["agent"] = self.agent
+
+         try:
+             messages.append(gr.ChatMessage(role="user", content=prompt, metadata={"status": "done"}))
+             yield messages
+
+             for msg in stream_to_gradio(
+                 session_state["agent"], task=prompt, reset_agent_memory=self.reset_agent_memory
+             ):
+                 if isinstance(msg, gr.ChatMessage):
+                     messages[-1].metadata["status"] = "done"
+                     messages.append(msg)
+                 elif isinstance(msg, str):  # Then it's only a completion delta
+                     msg = msg.replace("<", r"\<").replace(">", r"\>")  # HTML tags seem to break Gradio Chatbot
+                     if messages[-1].metadata["status"] == "pending":
+                         messages[-1].content = msg
+                     else:
+                         messages.append(
+                             gr.ChatMessage(role=MessageRole.ASSISTANT, content=msg, metadata={"status": "pending"})
+                         )
+                 yield messages
+
+             yield messages
+         except Exception as e:
+             yield messages
+             raise gr.Error(f"Error in interaction: {str(e)}")
+
+     def upload_file(self, file, file_uploads_log, allowed_file_types=None):
+         """
+         Upload a file and add it to the list of uploaded files in the session state.
+
+         The file is saved to the `self.file_upload_folder` folder.
+         If the file type is not allowed, it returns a message indicating the disallowed file type.
+
+         Args:
+             file (`gradio.File`): The uploaded file.
+             file_uploads_log (`list`): A list to log uploaded files.
+             allowed_file_types (`list`, *optional*): List of allowed file extensions. Defaults to [".pdf", ".docx", ".txt"].
+         """
+         import gradio as gr
+
+         if file is None:
+             return gr.Textbox(value="No file uploaded", visible=True), file_uploads_log
+
+         if allowed_file_types is None:
+             allowed_file_types = [".pdf", ".docx", ".txt"]
+
+         file_ext = os.path.splitext(file.name)[1].lower()
+         if file_ext not in allowed_file_types:
+             return gr.Textbox("File type disallowed", visible=True), file_uploads_log
+
+         # Sanitize file name
+         original_name = os.path.basename(file.name)
+         sanitized_name = re.sub(
+             r"[^\w\-.]", "_", original_name
+         )  # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
+
+         # Save the uploaded file to the specified folder
+         file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
+         shutil.copy(file.name, file_path)
+
+         return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]
+
+     def log_user_message(self, text_input, file_uploads_log):
+         import gradio as gr
+
+         return (
+             text_input
+             + (
+                 f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
+                 if len(file_uploads_log) > 0
+                 else ""
+             ),
+             "",
+             gr.Button(interactive=False),
+         )
+
+     def launch(self, share: bool = True, **kwargs):
+         """
+         Launch the Gradio app with the agent interface.
+
+         Args:
+             share (`bool`, defaults to `True`): Whether to share the app publicly.
+             **kwargs: Additional keyword arguments to pass to the Gradio launch method.
+         """
+         self.create_app().launch(debug=True, share=share, **kwargs)
+
+     def create_app(self):
+         import gradio as gr
+
+         with gr.Blocks(theme="ocean", fill_height=True) as demo:
+             # Add session state to store session-specific data
+             session_state = gr.State({})
+             stored_messages = gr.State([])
+             file_uploads_log = gr.State([])
+
+             with gr.Sidebar():
+                 gr.Markdown(
+                     f"# {self.name.replace('_', ' ').capitalize()}"
+                     "\n> This web ui allows you to interact with a `smolagents` agent that can use tools and execute steps to complete tasks."
+                     + (f"\n\n**Agent description:**\n{self.description}" if self.description else "")
+                 )
+
+                 with gr.Group():
+                     gr.Markdown("**Your request**", container=True)
+                     text_input = gr.Textbox(
+                         lines=3,
+                         label="Chat Message",
+                         container=False,
+                         placeholder="Enter your prompt here and press Shift+Enter or press the button",
+                     )
+                     submit_btn = gr.Button("Submit", variant="primary")
+
+                 # If an upload folder is provided, enable the upload feature
+                 if self.file_upload_folder is not None:
+                     upload_file = gr.File(label="Upload a file")
+                     upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
+                     upload_file.change(
+                         self.upload_file,
+                         [upload_file, file_uploads_log],
+                         [upload_status, file_uploads_log],
+                     )
+
+                 gr.HTML(
+                     "<br><br><h4><center>Powered by <a target='_blank' href='https://github.com/huggingface/smolagents'><b>smolagents</b></a></center></h4>"
+                 )
+
+             # Main chat interface
+             chatbot = gr.Chatbot(
+                 label="Agent",
+                 type="messages",
+                 avatar_images=(
+                     None,
+                     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
+                 ),
+                 resizeable=True,
+                 scale=1,
+                 latex_delimiters=[
+                     {"left": r"$$", "right": r"$$", "display": True},
+                     {"left": r"$", "right": r"$", "display": False},
+                     {"left": r"\[", "right": r"\]", "display": True},
+                     {"left": r"\(", "right": r"\)", "display": False},
+                 ],
+             )
+
+             # Set up event handlers
+             text_input.submit(
+                 self.log_user_message,
+                 [text_input, file_uploads_log],
+                 [stored_messages, text_input, submit_btn],
+             ).then(self.interact_with_agent, [stored_messages, chatbot, session_state], [chatbot]).then(
+                 lambda: (
+                     gr.Textbox(
+                         interactive=True, placeholder="Enter your prompt here and press Shift+Enter or the button"
+                     ),
+                     gr.Button(interactive=True),
+                 ),
+                 None,
+                 [text_input, submit_btn],
+             )
+
+             submit_btn.click(
+                 self.log_user_message,
+                 [text_input, file_uploads_log],
+                 [stored_messages, text_input, submit_btn],
+             ).then(self.interact_with_agent, [stored_messages, chatbot, session_state], [chatbot]).then(
+                 lambda: (
+                     gr.Textbox(
+                         interactive=True, placeholder="Enter your prompt here and press Shift+Enter or the button"
+                     ),
+                     gr.Button(interactive=True),
+                 ),
+                 None,
+                 [text_input, submit_btn],
+             )
+
+         return demo
+
+
+ __all__ = ["stream_to_gradio", "GradioUI"]
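
A sketch of consuming `stream_to_gradio` directly, without the `GradioUI` wrapper; the agent setup mirrors the docstring example above and is only illustrative:

```python
# Sketch: driving stream_to_gradio outside the GradioUI class.
from smolagents import CodeAgent, InferenceClientModel
from smolagents.gradio_ui import stream_to_gradio

agent = CodeAgent(tools=[], model=InferenceClientModel())
for message in stream_to_gradio(agent, task="What is 2 + 2?", reset_agent_memory=True):
    # Each item is either a gr.ChatMessage (a finished step) or a str (a streaming delta)
    print(message)
```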
src/smolagents/local_python_executor.py ADDED
@@ -0,0 +1,1611 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import ast
+ import builtins
+ import difflib
+ import inspect
+ import logging
+ import math
+ import re
+ from collections.abc import Callable, Mapping
+ from functools import wraps
+ from importlib import import_module
+ from types import BuiltinFunctionType, FunctionType, ModuleType
+ from typing import Any
+
+ from .tools import Tool
+ from .utils import BASE_BUILTIN_MODULES, truncate_content
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class InterpreterError(ValueError):
+     """
+     An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
+     operations.
+     """
+
+     pass
+
+
+ ERRORS = {
+     name: getattr(builtins, name)
+     for name in dir(builtins)
+     if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
+ }
+
+ DEFAULT_MAX_LEN_OUTPUT = 50000
+ MAX_OPERATIONS = 10000000
+ MAX_WHILE_ITERATIONS = 1000000
+
+
+ def custom_print(*args):
+     return None
+
+
+ def nodunder_getattr(obj, name, default=None):
+     if name.startswith("__") and name.endswith("__"):
+         raise InterpreterError(f"Forbidden access to dunder attribute: {name}")
+     return getattr(obj, name, default)
+
+
+ BASE_PYTHON_TOOLS = {
+     "print": custom_print,
+     "isinstance": isinstance,
+     "range": range,
+     "float": float,
+     "int": int,
+     "bool": bool,
+     "str": str,
+     "set": set,
+     "list": list,
+     "dict": dict,
+     "tuple": tuple,
+     "round": round,
+     "ceil": math.ceil,
+     "floor": math.floor,
+     "log": math.log,
+     "exp": math.exp,
+     "sin": math.sin,
+     "cos": math.cos,
+     "tan": math.tan,
+     "asin": math.asin,
+     "acos": math.acos,
+     "atan": math.atan,
+     "atan2": math.atan2,
+     "degrees": math.degrees,
+     "radians": math.radians,
+     "pow": pow,
+     "sqrt": math.sqrt,
+     "len": len,
+     "sum": sum,
+     "max": max,
+     "min": min,
+     "abs": abs,
+     "enumerate": enumerate,
+     "zip": zip,
+     "reversed": reversed,
+     "sorted": sorted,
+     "all": all,
+     "any": any,
+     "map": map,
+     "filter": filter,
+     "ord": ord,
+     "chr": chr,
+     "next": next,
+     "iter": iter,
+     "divmod": divmod,
+     "callable": callable,
+     "getattr": nodunder_getattr,
+     "hasattr": hasattr,
+     "setattr": setattr,
+     "issubclass": issubclass,
+     "type": type,
+     "complex": complex,
+ }
+
+ # Non-exhaustive list of dangerous modules that should not be imported
+ DANGEROUS_MODULES = [
+     "builtins",
+     "io",
+     "multiprocessing",
+     "os",
+     "pathlib",
+     "pty",
+     "shutil",
+     "socket",
+     "subprocess",
+     "sys",
+ ]
+
+ DANGEROUS_FUNCTIONS = [
+     "builtins.compile",
+     "builtins.eval",
+     "builtins.exec",
+     "builtins.globals",
+     "builtins.locals",
+     "builtins.__import__",
+     "os.popen",
+     "os.system",
+     "posix.system",
+ ]
+
+
+ def check_safer_result(result: Any, static_tools: dict[str, Callable] = None, authorized_imports: list[str] = None):
+     """
+     Check whether a result is safe to return, according to the authorized imports and static tools.
+
+     Args:
+         result (Any): The result to check.
+         static_tools (dict[str, Callable]): Dictionary of static tools.
+         authorized_imports (list[str]): List of authorized imports.
+
+     Raises:
+         InterpreterError: If the result is not safe.
+     """
+     if isinstance(result, ModuleType):
+         if not check_import_authorized(result.__name__, authorized_imports):
+             raise InterpreterError(f"Forbidden access to module: {result.__name__}")
+     elif isinstance(result, dict) and result.get("__spec__"):
+         if not check_import_authorized(result["__name__"], authorized_imports):
+             raise InterpreterError(f"Forbidden access to module: {result['__name__']}")
+     elif isinstance(result, (FunctionType, BuiltinFunctionType)):
+         for qualified_function_name in DANGEROUS_FUNCTIONS:
+             module_name, function_name = qualified_function_name.rsplit(".", 1)
+             if (
+                 (static_tools is None or function_name not in static_tools)
+                 and result.__name__ == function_name
+                 and result.__module__ == module_name
+             ):
+                 raise InterpreterError(f"Forbidden access to function: {function_name}")
+
+
+ def safer_eval(func: Callable):
+     """
+     Decorator to enhance the security of an evaluation function by checking its return value.
+
+     Args:
+         func (Callable): Evaluation function to be made safer.
+
+     Returns:
+         Callable: Safer evaluation function with return value check.
+     """
+
+     @wraps(func)
+     def _check_return(
+         expression,
+         state,
+         static_tools,
+         custom_tools,
+         authorized_imports=BASE_BUILTIN_MODULES,
+     ):
+         result = func(expression, state, static_tools, custom_tools, authorized_imports=authorized_imports)
+         check_safer_result(result, static_tools, authorized_imports)
+         return result
+
+     return _check_return
+
+
+ def safer_func(
+     func: Callable,
+     static_tools: dict[str, Callable] = BASE_PYTHON_TOOLS,
+     authorized_imports: list[str] = BASE_BUILTIN_MODULES,
+ ):
+     """
+     Decorator to enhance the security of a function call by checking its return value.
+
+     Args:
+         func (Callable): Function to be made safer.
+         static_tools (dict[str, Callable]): Dictionary of static tools.
+         authorized_imports (list[str]): List of authorized imports.
+
+     Returns:
+         Callable: Safer function with return value check.
+     """
+     # If the function is a type, return it directly without wrapping
+     if isinstance(func, type):
+         return func
+
+     @wraps(func)
+     def _check_return(*args, **kwargs):
+         result = func(*args, **kwargs)
+         check_safer_result(result, static_tools, authorized_imports)
+         return result
+
+     return _check_return
+
+
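
To make the guard concrete, here is a sketch of `safer_func` rejecting a call whose return value is a blacklisted builtin; `getattr` is a handy example because it can hand one back:

```python
# Sketch: safer_func blocks calls whose *return value* is a dangerous function.
guarded_getattr = safer_func(getattr)

guarded_getattr("abc", "upper")  # fine: returns a harmless bound method
try:
    guarded_getattr(builtins, "eval")  # would return builtins.eval, which is blacklisted
except InterpreterError as e:
    print(e)  # Forbidden access to function: eval
```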
+ class PrintContainer:
+     def __init__(self):
+         self.value = ""
+
+     def append(self, text):
+         self.value += text
+         return self
+
+     def __iadd__(self, other):
+         """Implements the += operator"""
+         self.value += str(other)
+         return self
+
+     def __str__(self):
+         """String representation"""
+         return self.value
+
+     def __repr__(self):
+         """Representation for debugging"""
+         return f"PrintContainer({self.value})"
+
+     def __len__(self):
+         """Implements len() function support"""
+         return len(self.value)
+
+
+ class BreakException(Exception):
+     pass
+
+
+ class ContinueException(Exception):
+     pass
+
+
+ class ReturnException(Exception):
+     def __init__(self, value):
+         self.value = value
+
+
+ def get_iterable(obj):
+     if isinstance(obj, list):
+         return obj
+     elif hasattr(obj, "__iter__"):
+         return list(obj)
+     else:
+         raise InterpreterError("Object is not iterable")
+
+
+ def fix_final_answer_code(code: str) -> str:
+     """
+     Sometimes an LLM can try to assign a variable to final_answer, which would break the final_answer() tool.
+     This function fixes this behaviour by replacing variable assignments to final_answer with final_answer_variable,
+     while preserving function calls to final_answer().
+     """
+     # First, find if there's a direct assignment to final_answer
+     # Use word boundary and negative lookbehind to ensure it's not an object attribute
+     assignment_pattern = r"(?<!\.)(?<!\w)\bfinal_answer\s*="
+     if "final_answer(" not in code or not re.search(assignment_pattern, code):
+         # If the final_answer tool is not called in this blob, doing the replacement is hazardous because it could
+         # corrupt the model's memory for next steps.
+         # Let's not modify the code and let the subsequent assignment error happen.
+         return code
+
+     # Pattern for replacing variable assignments
+     # Looks for 'final_answer' followed by '=' with optional whitespace
+     # Negative lookbehind ensures we don't match object attributes
+     assignment_regex = r"(?<!\.)(?<!\w)(\bfinal_answer)(\s*=)"
+     code = re.sub(assignment_regex, r"final_answer_variable\2", code)
+
+     # Pattern for replacing variable usage but not function calls
+     # Negative lookahead (?!\s*\() ensures we don't match function calls
+     # Negative lookbehind (?<!\.|\w) ensures we don't match object methods or other variables
+     variable_regex = r"(?<!\.)(?<!\w)(\bfinal_answer\b)(?!\s*\()"
+     code = re.sub(variable_regex, "final_answer_variable", code)
+     return code
+
+
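
The rewrite above only fires when a snippet both assigns to `final_answer` and calls the tool; a sketch of its effect (`compute` is a hypothetical helper):

```python
# Sketch: the rewrite fires only when final_answer is both assigned and called.
snippet = "final_answer = compute()\nfinal_answer(final_answer)"
assert fix_final_answer_code(snippet) == (
    "final_answer_variable = compute()\nfinal_answer(final_answer_variable)"
)

assignment_only = "final_answer = compute()"  # the tool is never called: left untouched
assert fix_final_answer_code(assignment_only) == assignment_only
```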
+ def build_import_tree(authorized_imports: list[str]) -> dict[str, Any]:
+     tree = {}
+     for import_path in authorized_imports:
+         parts = import_path.split(".")
+         current = tree
+         for part in parts:
+             if part not in current:
+                 current[part] = {}
+             current = current[part]
+     return tree
+
+
+ def check_import_authorized(import_to_check: str, authorized_imports: list[str]) -> bool:
+     current_node = build_import_tree(authorized_imports)
+     for part in import_to_check.split("."):
+         if "*" in current_node:
+             return True
+         if part not in current_node:
+             return False
+         current_node = current_node[part]
+     return True
+
+
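
Authorization walks the dotted import path through the tree, with `"*"` acting as a wildcard for any subpath; a sketch:

```python
# Sketch: dotted-path matching with the "*" wildcard.
authorized = ["math", "numpy.*", "os.path"]

assert check_import_authorized("math", authorized)
assert check_import_authorized("numpy.linalg", authorized)  # covered by "numpy.*"
assert check_import_authorized("os.path", authorized)
assert not check_import_authorized("subprocess", authorized)  # never authorized
assert not check_import_authorized("os.environ", authorized)  # sibling of os.path
```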
332
+ def evaluate_attribute(
333
+ expression: ast.Attribute,
334
+ state: dict[str, Any],
335
+ static_tools: dict[str, Callable],
336
+ custom_tools: dict[str, Callable],
337
+ authorized_imports: list[str],
338
+ ) -> Any:
339
+ if expression.attr.startswith("__") and expression.attr.endswith("__"):
340
+ raise InterpreterError(f"Forbidden access to dunder attribute: {expression.attr}")
341
+ value = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
342
+ return getattr(value, expression.attr)
343
+
344
+
345
+ def evaluate_unaryop(
346
+ expression: ast.UnaryOp,
347
+ state: dict[str, Any],
348
+ static_tools: dict[str, Callable],
349
+ custom_tools: dict[str, Callable],
350
+ authorized_imports: list[str],
351
+ ) -> Any:
352
+ operand = evaluate_ast(expression.operand, state, static_tools, custom_tools, authorized_imports)
353
+ if isinstance(expression.op, ast.USub):
354
+ return -operand
355
+ elif isinstance(expression.op, ast.UAdd):
356
+ return operand
357
+ elif isinstance(expression.op, ast.Not):
358
+ return not operand
359
+ elif isinstance(expression.op, ast.Invert):
360
+ return ~operand
361
+ else:
362
+ raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
363
+
364
+
365
+ def evaluate_lambda(
366
+ lambda_expression: ast.Lambda,
367
+ state: dict[str, Any],
368
+ static_tools: dict[str, Callable],
369
+ custom_tools: dict[str, Callable],
370
+ authorized_imports: list[str],
371
+ ) -> Callable:
372
+ args = [arg.arg for arg in lambda_expression.args.args]
373
+
374
+ def lambda_func(*values: Any) -> Any:
375
+ new_state = state.copy()
376
+ for arg, value in zip(args, values):
377
+ new_state[arg] = value
378
+ return evaluate_ast(
379
+ lambda_expression.body,
380
+ new_state,
381
+ static_tools,
382
+ custom_tools,
383
+ authorized_imports,
384
+ )
385
+
386
+ return lambda_func
387
+
388
+
389
+ def evaluate_while(
390
+ while_loop: ast.While,
391
+ state: dict[str, Any],
392
+ static_tools: dict[str, Callable],
393
+ custom_tools: dict[str, Callable],
394
+ authorized_imports: list[str],
395
+ ) -> None:
396
+ iterations = 0
397
+ while evaluate_ast(while_loop.test, state, static_tools, custom_tools, authorized_imports):
398
+ for node in while_loop.body:
399
+ try:
400
+ evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
401
+ except BreakException:
402
+ return None
403
+ except ContinueException:
404
+ break
405
+ iterations += 1
406
+ if iterations > MAX_WHILE_ITERATIONS:
407
+ raise InterpreterError(f"Maximum number of {MAX_WHILE_ITERATIONS} iterations in While loop exceeded")
408
+ return None
409
+
410
+
411
+ def create_function(
412
+ func_def: ast.FunctionDef,
413
+ state: dict[str, Any],
414
+ static_tools: dict[str, Callable],
415
+ custom_tools: dict[str, Callable],
416
+ authorized_imports: list[str],
417
+ ) -> Callable:
418
+ source_code = ast.unparse(func_def)
419
+
420
+ def new_func(*args: Any, **kwargs: Any) -> Any:
421
+ func_state = state.copy()
422
+ arg_names = [arg.arg for arg in func_def.args.args]
423
+ default_values = [
424
+ evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) for d in func_def.args.defaults
425
+ ]
426
+
427
+ # Apply default values
428
+ defaults = dict(zip(arg_names[-len(default_values) :], default_values))
429
+
430
+ # Set positional arguments
431
+ for name, value in zip(arg_names, args):
432
+ func_state[name] = value
433
+
434
+ # Set keyword arguments
435
+ for name, value in kwargs.items():
436
+ func_state[name] = value
437
+
438
+ # Handle variable arguments
439
+ if func_def.args.vararg:
440
+ vararg_name = func_def.args.vararg.arg
441
+ func_state[vararg_name] = args
442
+
443
+ if func_def.args.kwarg:
444
+ kwarg_name = func_def.args.kwarg.arg
445
+ func_state[kwarg_name] = kwargs
446
+
447
+ # Set default values for arguments that were not provided
448
+ for name, value in defaults.items():
449
+ if name not in func_state:
450
+ func_state[name] = value
451
+
452
+ # Update function state with self and __class__
453
+ if func_def.args.args and func_def.args.args[0].arg == "self":
454
+ if args:
455
+ func_state["self"] = args[0]
456
+ func_state["__class__"] = args[0].__class__
457
+
458
+ result = None
459
+ try:
460
+ for stmt in func_def.body:
461
+ result = evaluate_ast(stmt, func_state, static_tools, custom_tools, authorized_imports)
462
+ except ReturnException as e:
463
+ result = e.value
464
+
465
+ if func_def.name == "__init__":
466
+ return None
467
+
468
+ return result
469
+
470
+ # Store original AST, source code, and name
471
+ new_func.__ast__ = func_def
472
+ new_func.__source__ = source_code
473
+ new_func.__name__ = func_def.name
474
+
475
+ return new_func
476
+
477
+
478
+ def evaluate_function_def(
479
+ func_def: ast.FunctionDef,
480
+ state: dict[str, Any],
481
+ static_tools: dict[str, Callable],
482
+ custom_tools: dict[str, Callable],
483
+ authorized_imports: list[str],
484
+ ) -> Callable:
485
+ custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools, authorized_imports)
486
+ return custom_tools[func_def.name]
487
+
488
+
489
+ def evaluate_class_def(
490
+ class_def: ast.ClassDef,
491
+ state: dict[str, Any],
492
+ static_tools: dict[str, Callable],
493
+ custom_tools: dict[str, Callable],
494
+ authorized_imports: list[str],
495
+ ) -> type:
496
+ class_name = class_def.name
497
+ bases = [evaluate_ast(base, state, static_tools, custom_tools, authorized_imports) for base in class_def.bases]
498
+ class_dict = {}
499
+
500
+ for stmt in class_def.body:
501
+ if isinstance(stmt, ast.FunctionDef):
502
+ class_dict[stmt.name] = evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
503
+ elif isinstance(stmt, ast.AnnAssign):
504
+ if stmt.value:
505
+ value = evaluate_ast(stmt.value, state, static_tools, custom_tools, authorized_imports)
506
+ target = stmt.target
507
+ # Handle target types for annotation
508
+ if isinstance(target, ast.Name):
509
+ # Simple variable annotation like "x: int"
510
+ annotation = evaluate_ast(stmt.annotation, state, static_tools, custom_tools, authorized_imports)
511
+ class_dict.setdefault("__annotations__", {})[target.id] = annotation
512
+ # Assign value if provided
513
+ if stmt.value:
514
+ class_dict[target.id] = value
515
+ elif isinstance(target, ast.Attribute):
516
+ # Attribute annotation like "obj.attr: int"
517
+ obj = evaluate_ast(target.value, class_dict, static_tools, custom_tools, authorized_imports)
518
+ # If there's a value assignment, set the attribute
519
+ if stmt.value:
520
+ setattr(obj, target.attr, value)
521
+ elif isinstance(target, ast.Subscript):
522
+ # Subscript annotation like "dict[key]: int"
523
+ container = evaluate_ast(target.value, class_dict, static_tools, custom_tools, authorized_imports)
524
+ index = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
525
+ # If there's a value assignment, set the item
526
+ if stmt.value:
527
+ container[index] = value
528
+ else:
529
+ raise InterpreterError(f"Unsupported AnnAssign target in class body: {type(target).__name__}")
530
+ elif isinstance(stmt, ast.Assign):
531
+ value = evaluate_ast(stmt.value, state, static_tools, custom_tools, authorized_imports)
532
+ for target in stmt.targets:
533
+ if isinstance(target, ast.Name):
534
+ class_dict[target.id] = value
535
+ elif isinstance(target, ast.Attribute):
536
+ obj = evaluate_ast(target.value, class_dict, static_tools, custom_tools, authorized_imports)
537
+ setattr(obj, target.attr, value)
538
+ elif isinstance(stmt, ast.Pass):
539
+ pass
540
+ elif (
541
+ isinstance(stmt, ast.Expr)
542
+ and stmt == class_def.body[0]
543
+ and isinstance(stmt.value, ast.Constant)
544
+ and isinstance(stmt.value.value, str)
545
+ ):
546
+ # Check if it is a docstring: first statement in class body which is a string literal expression
547
+ class_dict["__doc__"] = stmt.value.value
548
+ else:
549
+ raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
550
+
551
+ new_class = type(class_name, tuple(bases), class_dict)
552
+ state[class_name] = new_class
553
+ return new_class
554
+
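A hedged sketch of what `evaluate_class_def` enables (assuming `smolagents` is installed): the class body is evaluated into a plain dict and assembled with `type(name, bases, dict)`, so docstrings, methods, and attribute assignments all round-trip, and `__init__` deliberately returns `None`.

```python
# Sketch, assuming smolagents is installed: class bodies are interpreted into
# a dict and assembled via type(); methods are create_function closures.
from smolagents.local_python_executor import evaluate_python_code

code = """
class Counter:
    '''A tiny counter.'''

    def __init__(self, start=0):
        self.value = start

    def bump(self):
        self.value += 1
        return self.value

c = Counter(41)
c.bump()
"""
result, _ = evaluate_python_code(code)
print(result)  # 42
```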
555
+
556
+ def evaluate_annassign(
557
+ annassign: ast.AnnAssign,
558
+ state: dict[str, Any],
559
+ static_tools: dict[str, Callable],
560
+ custom_tools: dict[str, Callable],
561
+ authorized_imports: list[str],
562
+ ) -> Any:
563
+ # If there's a value to assign, evaluate it
564
+ if annassign.value:
565
+ value = evaluate_ast(annassign.value, state, static_tools, custom_tools, authorized_imports)
566
+ # Set the value for the target
567
+ set_value(annassign.target, value, state, static_tools, custom_tools, authorized_imports)
568
+ return value
569
+ # For declarations without values (x: int), just return None
570
+ return None
571
+
572
+
573
+ def evaluate_augassign(
574
+ expression: ast.AugAssign,
575
+ state: dict[str, Any],
576
+ static_tools: dict[str, Callable],
577
+ custom_tools: dict[str, Callable],
578
+ authorized_imports: list[str],
579
+ ) -> Any:
580
+ def get_current_value(target: ast.AST) -> Any:
581
+ if isinstance(target, ast.Name):
582
+ return state.get(target.id, 0)
583
+ elif isinstance(target, ast.Subscript):
584
+ obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
585
+ key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
586
+ return obj[key]
587
+ elif isinstance(target, ast.Attribute):
588
+ obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
589
+ return getattr(obj, target.attr)
590
+ elif isinstance(target, ast.Tuple):
591
+ return tuple(get_current_value(elt) for elt in target.elts)
592
+ elif isinstance(target, ast.List):
593
+ return [get_current_value(elt) for elt in target.elts]
594
+ else:
595
+ raise InterpreterError(f"AugAssign not supported for {type(target)} targets.")
596
+
597
+ current_value = get_current_value(expression.target)
598
+ value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
599
+
600
+ if isinstance(expression.op, ast.Add):
601
+ if isinstance(current_value, list):
602
+ if not isinstance(value_to_add, list):
603
+ raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
604
+ current_value += value_to_add
605
+ else:
606
+ current_value += value_to_add
607
+ elif isinstance(expression.op, ast.Sub):
608
+ current_value -= value_to_add
609
+ elif isinstance(expression.op, ast.Mult):
610
+ current_value *= value_to_add
611
+ elif isinstance(expression.op, ast.Div):
612
+ current_value /= value_to_add
613
+ elif isinstance(expression.op, ast.Mod):
614
+ current_value %= value_to_add
615
+ elif isinstance(expression.op, ast.Pow):
616
+ current_value **= value_to_add
617
+ elif isinstance(expression.op, ast.FloorDiv):
618
+ current_value //= value_to_add
619
+ elif isinstance(expression.op, ast.BitAnd):
620
+ current_value &= value_to_add
621
+ elif isinstance(expression.op, ast.BitOr):
622
+ current_value |= value_to_add
623
+ elif isinstance(expression.op, ast.BitXor):
624
+ current_value ^= value_to_add
625
+ elif isinstance(expression.op, ast.LShift):
626
+ current_value <<= value_to_add
627
+ elif isinstance(expression.op, ast.RShift):
628
+ current_value >>= value_to_add
629
+ else:
630
+ raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
631
+
632
+ # Write the result back to the target (lists were already updated in-place above)
633
+ set_value(
634
+ expression.target,
635
+ current_value,
636
+ state,
637
+ static_tools,
638
+ custom_tools,
639
+ authorized_imports,
640
+ )
641
+
642
+ return current_value
643
+
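A short sketch of the augmented-assignment path (assuming `smolagents` is installed): `+=` reads the current value, applies the operator, then writes the result back through `set_value`.

```python
# Sketch, assuming smolagents is installed: `+=` works on both mutable and
# immutable values; for lists the update happens in place.
from smolagents.local_python_executor import evaluate_python_code

state = {}
evaluate_python_code("xs = [1, 2]\nxs += [3]\nn = 0\nn += 5", state=state)
print(state["xs"], state["n"])  # [1, 2, 3] 5
```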
644
+
645
+ def evaluate_boolop(
646
+ node: ast.BoolOp,
647
+ state: dict[str, Any],
648
+ static_tools: dict[str, Callable],
649
+ custom_tools: dict[str, Callable],
650
+ authorized_imports: list[str],
651
+ ) -> Any:
652
+ # Determine which value should trigger short-circuit based on operation type:
653
+ # - 'and' returns the first falsy value encountered (or the last value if all are truthy)
654
+ # - 'or' returns the first truthy value encountered (or the last value if all are falsy)
655
+ is_short_circuit_value = (lambda x: not x) if isinstance(node.op, ast.And) else (lambda x: bool(x))
656
+ for value in node.values:
657
+ result = evaluate_ast(value, state, static_tools, custom_tools, authorized_imports)
658
+ # Short-circuit: return immediately if the condition is met
659
+ if is_short_circuit_value(result):
660
+ return result
661
+ # If no short-circuit occurred, return the last evaluated value
662
+ return result
663
+
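The short-circuit rule above reproduces CPython semantics: `and`/`or` return the operand that decided the outcome, not a coerced bool. A minimal sketch (assuming `smolagents` is installed):

```python
# Sketch, assuming smolagents is installed: boolean operators return the
# deciding operand, exactly like CPython.
from smolagents.local_python_executor import evaluate_python_code

print(evaluate_python_code("0 or '' or 'fallback'")[0])  # fallback
print(evaluate_python_code("1 and [] and 'never'")[0])   # []
```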
664
+
665
+ def evaluate_binop(
666
+ binop: ast.BinOp,
667
+ state: dict[str, Any],
668
+ static_tools: dict[str, Callable],
669
+ custom_tools: dict[str, Callable],
670
+ authorized_imports: list[str],
671
+ ) -> Any:
672
+ # Recursively evaluate the left and right operands
673
+ left_val = evaluate_ast(binop.left, state, static_tools, custom_tools, authorized_imports)
674
+ right_val = evaluate_ast(binop.right, state, static_tools, custom_tools, authorized_imports)
675
+
676
+ # Determine the operation based on the type of the operator in the BinOp
677
+ if isinstance(binop.op, ast.Add):
678
+ return left_val + right_val
679
+ elif isinstance(binop.op, ast.Sub):
680
+ return left_val - right_val
681
+ elif isinstance(binop.op, ast.Mult):
682
+ return left_val * right_val
683
+ elif isinstance(binop.op, ast.Div):
684
+ return left_val / right_val
685
+ elif isinstance(binop.op, ast.Mod):
686
+ return left_val % right_val
687
+ elif isinstance(binop.op, ast.Pow):
688
+ return left_val**right_val
689
+ elif isinstance(binop.op, ast.FloorDiv):
690
+ return left_val // right_val
691
+ elif isinstance(binop.op, ast.BitAnd):
692
+ return left_val & right_val
693
+ elif isinstance(binop.op, ast.BitOr):
694
+ return left_val | right_val
695
+ elif isinstance(binop.op, ast.BitXor):
696
+ return left_val ^ right_val
697
+ elif isinstance(binop.op, ast.LShift):
698
+ return left_val << right_val
699
+ elif isinstance(binop.op, ast.RShift):
700
+ return left_val >> right_val
701
+ else:
702
+ raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
703
+
704
+
705
+ def evaluate_assign(
706
+ assign: ast.Assign,
707
+ state: dict[str, Any],
708
+ static_tools: dict[str, Callable],
709
+ custom_tools: dict[str, Callable],
710
+ authorized_imports: list[str],
711
+ ) -> Any:
712
+ result = evaluate_ast(assign.value, state, static_tools, custom_tools, authorized_imports)
713
+ if len(assign.targets) == 1:
714
+ target = assign.targets[0]
715
+ set_value(target, result, state, static_tools, custom_tools, authorized_imports)
716
+ else:
717
+ expanded_values = []
718
+ for tgt in assign.targets:
719
+ if isinstance(tgt, ast.Starred):
720
+ expanded_values.extend(result)
721
+ else:
722
+ expanded_values.append(result)
723
+
724
+ for tgt, val in zip(assign.targets, expanded_values):
725
+ set_value(tgt, val, state, static_tools, custom_tools, authorized_imports)
726
+ return result
727
+
728
+
729
+ def set_value(
730
+ target: ast.AST,
731
+ value: Any,
732
+ state: dict[str, Any],
733
+ static_tools: dict[str, Callable],
734
+ custom_tools: dict[str, Callable],
735
+ authorized_imports: list[str],
736
+ ) -> None:
737
+ if isinstance(target, ast.Name):
738
+ if target.id in static_tools:
739
+ raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
740
+ state[target.id] = value
741
+ elif isinstance(target, ast.Tuple):
742
+ if not isinstance(value, tuple):
743
+ if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
744
+ value = tuple(value)
745
+ else:
746
+ raise InterpreterError("Cannot unpack non-tuple value")
747
+ if len(target.elts) != len(value):
748
+ raise InterpreterError("Cannot unpack tuple of wrong size")
749
+ for i, elem in enumerate(target.elts):
750
+ set_value(elem, value[i], state, static_tools, custom_tools, authorized_imports)
751
+ elif isinstance(target, ast.Subscript):
752
+ obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
753
+ key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
754
+ obj[key] = value
755
+ elif isinstance(target, ast.Attribute):
756
+ obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
757
+ setattr(obj, target.attr, value)
758
+
759
+
760
+ def evaluate_call(
761
+ call: ast.Call,
762
+ state: dict[str, Any],
763
+ static_tools: dict[str, Callable],
764
+ custom_tools: dict[str, Callable],
765
+ authorized_imports: list[str],
766
+ ) -> Any:
767
+ if not isinstance(call.func, (ast.Call, ast.Lambda, ast.Attribute, ast.Name, ast.Subscript)):
768
+ raise InterpreterError(f"This is not a correct function: {call.func}.")
769
+
770
+ func, func_name = None, None
771
+
772
+ if isinstance(call.func, ast.Call):
773
+ func = evaluate_ast(call.func, state, static_tools, custom_tools, authorized_imports)
774
+ elif isinstance(call.func, ast.Lambda):
775
+ func = evaluate_ast(call.func, state, static_tools, custom_tools, authorized_imports)
776
+ elif isinstance(call.func, ast.Attribute):
777
+ obj = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
778
+ func_name = call.func.attr
779
+ if not hasattr(obj, func_name):
780
+ raise InterpreterError(f"Object {obj} has no attribute {func_name}")
781
+ func = getattr(obj, func_name)
782
+ elif isinstance(call.func, ast.Name):
783
+ func_name = call.func.id
784
+ if func_name in state:
785
+ func = state[func_name]
786
+ elif func_name in static_tools:
787
+ func = static_tools[func_name]
788
+ elif func_name in custom_tools:
789
+ func = custom_tools[func_name]
790
+ elif func_name in ERRORS:
791
+ func = ERRORS[func_name]
792
+ else:
793
+ raise InterpreterError(
794
+ f"Forbidden function evaluation: '{call.func.id}' is not among the explicitly allowed tools or defined/imported in the preceding code"
795
+ )
796
+ elif isinstance(call.func, ast.Subscript):
797
+ func = evaluate_ast(call.func, state, static_tools, custom_tools, authorized_imports)
798
+ if not callable(func):
799
+ raise InterpreterError(f"This is not a correct function: {call.func}.")
800
+ func_name = None
801
+
802
+ args = []
803
+ for arg in call.args:
804
+ if isinstance(arg, ast.Starred):
805
+ args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools, authorized_imports))
806
+ else:
807
+ args.append(evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports))
808
+
809
+ kwargs = {
810
+ keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools, authorized_imports)
811
+ for keyword in call.keywords
812
+ }
813
+
814
+ if func_name == "super":
815
+ if not args:
816
+ if "__class__" in state and "self" in state:
817
+ return super(state["__class__"], state["self"])
818
+ else:
819
+ raise InterpreterError("super() needs at least one argument")
820
+ cls = args[0]
821
+ if not isinstance(cls, type):
822
+ raise InterpreterError("super() argument 1 must be type")
823
+ if len(args) == 1:
824
+ return super(cls)
825
+ elif len(args) == 2:
826
+ instance = args[1]
827
+ return super(cls, instance)
828
+ else:
829
+ raise InterpreterError("super() takes at most 2 arguments")
830
+ elif func_name == "print":
831
+ state["_print_outputs"] += " ".join(map(str, args)) + "\n"
832
+ return None
833
+ else: # Assume it's a callable object
834
+ if (inspect.getmodule(func) == builtins) and inspect.isbuiltin(func) and (func not in static_tools.values()):
835
+ raise InterpreterError(
836
+ f"Invoking a builtin function that has not been explicitly added as a tool is not allowed ({func_name})."
837
+ )
838
+ return func(*args, **kwargs)
839
+
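The builtin guard at the end of `evaluate_call` is the core sandboxing rule: builtins that were not explicitly handed in as static tools are refused. A hedged sketch (assuming `smolagents` is installed; the exact error text is wrapped by the top-level `evaluate_python_code` handler):

```python
# Sketch, assuming smolagents is installed: calling a builtin that was not
# provided as a tool raises InterpreterError instead of executing.
from smolagents.local_python_executor import InterpreterError, evaluate_python_code

try:
    evaluate_python_code("open('/etc/passwd')")
except InterpreterError as e:
    print(e)  # ... Forbidden function evaluation: 'open' is not among the explicitly allowed tools ...
```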
840
+
841
+ def evaluate_subscript(
842
+ subscript: ast.Subscript,
843
+ state: dict[str, Any],
844
+ static_tools: dict[str, Callable],
845
+ custom_tools: dict[str, Callable],
846
+ authorized_imports: list[str],
847
+ ) -> Any:
848
+ index = evaluate_ast(subscript.slice, state, static_tools, custom_tools, authorized_imports)
849
+ value = evaluate_ast(subscript.value, state, static_tools, custom_tools, authorized_imports)
850
+ try:
851
+ return value[index]
852
+ except (KeyError, IndexError, TypeError) as e:
853
+ error_message = f"Could not index {value} with '{index}': {type(e).__name__}: {e}"
854
+ if isinstance(index, str) and isinstance(value, Mapping):
855
+ close_matches = difflib.get_close_matches(index, list(value.keys()))
856
+ if len(close_matches) > 0:
857
+ error_message += f". Maybe you meant one of these indexes instead: {str(close_matches)}"
858
+ raise InterpreterError(error_message) from e
859
+
860
+
861
+ def evaluate_name(
862
+ name: ast.Name,
863
+ state: dict[str, Any],
864
+ static_tools: dict[str, Callable],
865
+ custom_tools: dict[str, Callable],
866
+ authorized_imports: list[str],
867
+ ) -> Any:
868
+ if name.id in state:
869
+ return state[name.id]
870
+ elif name.id in static_tools:
871
+ return safer_func(static_tools[name.id], static_tools=static_tools, authorized_imports=authorized_imports)
872
+ elif name.id in custom_tools:
873
+ return custom_tools[name.id]
874
+ elif name.id in ERRORS:
875
+ return ERRORS[name.id]
876
+ close_matches = difflib.get_close_matches(name.id, list(state.keys()))
877
+ if len(close_matches) > 0:
878
+ return state[close_matches[0]]
879
+ raise InterpreterError(f"The variable `{name.id}` is not defined.")
880
+
881
+
882
+ def evaluate_condition(
883
+ condition: ast.Compare,
884
+ state: dict[str, Any],
885
+ static_tools: dict[str, Callable],
886
+ custom_tools: dict[str, Callable],
887
+ authorized_imports: list[str],
888
+ ) -> bool | object:
889
+ result = True
890
+ left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
891
+ for i, (op, comparator) in enumerate(zip(condition.ops, condition.comparators)):
892
+ op = type(op)
893
+ right = evaluate_ast(comparator, state, static_tools, custom_tools, authorized_imports)
894
+ if op == ast.Eq:
895
+ current_result = left == right
896
+ elif op == ast.NotEq:
897
+ current_result = left != right
898
+ elif op == ast.Lt:
899
+ current_result = left < right
900
+ elif op == ast.LtE:
901
+ current_result = left <= right
902
+ elif op == ast.Gt:
903
+ current_result = left > right
904
+ elif op == ast.GtE:
905
+ current_result = left >= right
906
+ elif op == ast.Is:
907
+ current_result = left is right
908
+ elif op == ast.IsNot:
909
+ current_result = left is not right
910
+ elif op == ast.In:
911
+ current_result = left in right
912
+ elif op == ast.NotIn:
913
+ current_result = left not in right
914
+ else:
915
+ raise InterpreterError(f"Unsupported comparison operator: {op}")
916
+
917
+ if current_result is False:
918
+ return False
919
+ result = current_result if i == 0 else (result and current_result)
920
+ left = right
921
+ return result
922
+
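Chained comparisons are evaluated left to right with short-circuiting, as in CPython. A minimal sketch (assuming `smolagents` is installed):

```python
# Sketch, assuming smolagents is installed: chained comparisons short-circuit
# on the first False link.
from smolagents.local_python_executor import evaluate_python_code

print(evaluate_python_code("x = 5\n1 < x < 10")[0])   # True
print(evaluate_python_code("x = 20\n1 < x < 10")[0])  # False
```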
923
+
924
+ def evaluate_if(
925
+ if_statement: ast.If,
926
+ state: dict[str, Any],
927
+ static_tools: dict[str, Callable],
928
+ custom_tools: dict[str, Callable],
929
+ authorized_imports: list[str],
930
+ ) -> Any:
931
+ result = None
932
+ test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools, authorized_imports)
933
+ if test_result:
934
+ for line in if_statement.body:
935
+ line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
936
+ if line_result is not None:
937
+ result = line_result
938
+ else:
939
+ for line in if_statement.orelse:
940
+ line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
941
+ if line_result is not None:
942
+ result = line_result
943
+ return result
944
+
945
+
946
+ def evaluate_for(
947
+ for_loop: ast.For,
948
+ state: dict[str, Any],
949
+ static_tools: dict[str, Callable],
950
+ custom_tools: dict[str, Callable],
951
+ authorized_imports: list[str],
952
+ ) -> Any:
953
+ result = None
954
+ iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools, authorized_imports)
955
+ for counter in iterator:
956
+ set_value(
957
+ for_loop.target,
958
+ counter,
959
+ state,
960
+ static_tools,
961
+ custom_tools,
962
+ authorized_imports,
963
+ )
964
+ for node in for_loop.body:
965
+ try:
966
+ line_result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
967
+ if line_result is not None:
968
+ result = line_result
969
+ except BreakException:
970
+ break
971
+ except ContinueException:
972
+ continue
973
+ else:
974
+ continue
975
+ break
976
+ return result
977
+
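`break` and `continue` are implemented as `BreakException`/`ContinueException`, and the `for ... else ... break` dance above turns a caught `BreakException` into an exit from the iteration. A sketch (assuming `smolagents` is installed):

```python
# Sketch, assuming smolagents is installed: break exits the interpreted loop
# via BreakException caught inside evaluate_for.
from smolagents.local_python_executor import evaluate_python_code

code = """
found = None
for n in [1, 3, 8, 5]:
    if n % 2 == 0:
        found = n
        break
found
"""
print(evaluate_python_code(code)[0])  # 8
```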
978
+
979
+ def evaluate_listcomp(
980
+ listcomp: ast.ListComp,
981
+ state: dict[str, Any],
982
+ static_tools: dict[str, Callable],
983
+ custom_tools: dict[str, Callable],
984
+ authorized_imports: list[str],
985
+ ) -> list[Any]:
986
+ def inner_evaluate(generators: list[ast.comprehension], index: int, current_state: dict[str, Any]) -> list[Any]:
987
+ if index >= len(generators):
988
+ return [
989
+ evaluate_ast(
990
+ listcomp.elt,
991
+ current_state,
992
+ static_tools,
993
+ custom_tools,
994
+ authorized_imports,
995
+ )
996
+ ]
997
+ generator = generators[index]
998
+ iter_value = evaluate_ast(
999
+ generator.iter,
1000
+ current_state,
1001
+ static_tools,
1002
+ custom_tools,
1003
+ authorized_imports,
1004
+ )
1005
+ result = []
1006
+ for value in iter_value:
1007
+ new_state = current_state.copy()
1008
+ if isinstance(generator.target, ast.Tuple):
1009
+ for idx, elem in enumerate(generator.target.elts):
1010
+ new_state[elem.id] = value[idx]
1011
+ else:
1012
+ new_state[generator.target.id] = value
1013
+ if all(
1014
+ evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
1015
+ for if_clause in generator.ifs
1016
+ ):
1017
+ result.extend(inner_evaluate(generators, index + 1, new_state))
1018
+ return result
1019
+
1020
+ return inner_evaluate(listcomp.generators, 0, state)
1021
+
1022
+
1023
+ def evaluate_setcomp(
1024
+ setcomp: ast.SetComp,
1025
+ state: dict[str, Any],
1026
+ static_tools: dict[str, Callable],
1027
+ custom_tools: dict[str, Callable],
1028
+ authorized_imports: list[str],
1029
+ ) -> set[Any]:
1030
+ result = set()
1031
+ for gen in setcomp.generators:
1032
+ iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools, authorized_imports)
1033
+ for value in iter_value:
1034
+ new_state = state.copy()
1035
+ set_value(
1036
+ gen.target,
1037
+ value,
1038
+ new_state,
1039
+ static_tools,
1040
+ custom_tools,
1041
+ authorized_imports,
1042
+ )
1043
+ if all(
1044
+ evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
1045
+ for if_clause in gen.ifs
1046
+ ):
1047
+ element = evaluate_ast(
1048
+ setcomp.elt,
1049
+ new_state,
1050
+ static_tools,
1051
+ custom_tools,
1052
+ authorized_imports,
1053
+ )
1054
+ result.add(element)
1055
+ return result
1056
+
1057
+
1058
+ def evaluate_try(
1059
+ try_node: ast.Try,
1060
+ state: dict[str, Any],
1061
+ static_tools: dict[str, Callable],
1062
+ custom_tools: dict[str, Callable],
1063
+ authorized_imports: list[str],
1064
+ ) -> None:
1065
+ try:
1066
+ for stmt in try_node.body:
1067
+ evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
1068
+ except Exception as e:
1069
+ matched = False
1070
+ for handler in try_node.handlers:
1071
+ if handler.type is None or isinstance(
1072
+ e,
1073
+ evaluate_ast(handler.type, state, static_tools, custom_tools, authorized_imports),
1074
+ ):
1075
+ matched = True
1076
+ if handler.name:
1077
+ state[handler.name] = e
1078
+ for stmt in handler.body:
1079
+ evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
1080
+ break
1081
+ if not matched:
1082
+ raise e
1083
+ else:
1084
+ if try_node.orelse:
1085
+ for stmt in try_node.orelse:
1086
+ evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
1087
+ finally:
1088
+ if try_node.finalbody:
1089
+ for stmt in try_node.finalbody:
1090
+ evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
1091
+
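A sketch of the try/except/finally handling (assuming `smolagents` is installed): exceptions raised in the interpreted body are matched against evaluated handler types, and `finally` always runs.

```python
# Sketch, assuming smolagents is installed: handler types resolve through
# evaluate_name (builtin exceptions live in the ERRORS mapping).
from smolagents.local_python_executor import evaluate_python_code

code = """
log = []
try:
    1 / 0
except ZeroDivisionError:
    log.append('caught')
finally:
    log.append('cleanup')
log
"""
print(evaluate_python_code(code)[0])  # ['caught', 'cleanup']
```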
1092
+
1093
+ def evaluate_raise(
1094
+ raise_node: ast.Raise,
1095
+ state: dict[str, Any],
1096
+ static_tools: dict[str, Callable],
1097
+ custom_tools: dict[str, Callable],
1098
+ authorized_imports: list[str],
1099
+ ) -> None:
1100
+ if raise_node.exc is not None:
1101
+ exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools, authorized_imports)
1102
+ else:
1103
+ exc = None
1104
+ if raise_node.cause is not None:
1105
+ cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools, authorized_imports)
1106
+ else:
1107
+ cause = None
1108
+ if exc is not None:
1109
+ if cause is not None:
1110
+ raise exc from cause
1111
+ else:
1112
+ raise exc
1113
+ else:
1114
+ raise InterpreterError("Re-raise is not supported without an active exception")
1115
+
1116
+
1117
+ def evaluate_assert(
1118
+ assert_node: ast.Assert,
1119
+ state: dict[str, Any],
1120
+ static_tools: dict[str, Callable],
1121
+ custom_tools: dict[str, Callable],
1122
+ authorized_imports: list[str],
1123
+ ) -> None:
1124
+ test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools, authorized_imports)
1125
+ if not test_result:
1126
+ if assert_node.msg:
1127
+ msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools, authorized_imports)
1128
+ raise AssertionError(msg)
1129
+ else:
1130
+ # Include the failing condition in the assertion message
1131
+ test_code = ast.unparse(assert_node.test)
1132
+ raise AssertionError(f"Assertion failed: {test_code}")
1133
+
1134
+
1135
+ def evaluate_with(
1136
+ with_node: ast.With,
1137
+ state: dict[str, Any],
1138
+ static_tools: dict[str, Callable],
1139
+ custom_tools: dict[str, Callable],
1140
+ authorized_imports: list[str],
1141
+ ) -> None:
1142
+ contexts = []
1143
+ for item in with_node.items:
1144
+ context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools, authorized_imports)
1145
+ if item.optional_vars:
1146
+ state[item.optional_vars.id] = context_expr.__enter__()
1147
+ contexts.append(state[item.optional_vars.id])
1148
+ else:
1149
+ context_var = context_expr.__enter__()
1150
+ contexts.append(context_var)
1151
+
1152
+ try:
1153
+ for stmt in with_node.body:
1154
+ evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
1155
+ except Exception as e:
1156
+ for context in reversed(contexts):
1157
+ context.__exit__(type(e), e, e.__traceback__)
1158
+ raise
1159
+ else:
1160
+ for context in reversed(contexts):
1161
+ context.__exit__(None, None, None)
1162
+
1163
+
1164
+ def get_safe_module(raw_module, authorized_imports, visited=None):
1165
+ """Create a safe copy of a module, or return the input unchanged if it is not a module."""
1166
+ # If it's a function or non-module object, return it directly
1167
+ if not isinstance(raw_module, ModuleType):
1168
+ return raw_module
1169
+
1170
+ # Handle circular references: Initialize visited set for the first call
1171
+ if visited is None:
1172
+ visited = set()
1173
+
1174
+ module_id = id(raw_module)
1175
+ if module_id in visited:
1176
+ return raw_module # Return original for circular refs
1177
+
1178
+ visited.add(module_id)
1179
+
1180
+ # Create new module for actual modules
1181
+ safe_module = ModuleType(raw_module.__name__)
1182
+
1183
+ # Copy all attributes by reference, recursively checking modules
1184
+ for attr_name in dir(raw_module):
1185
+ try:
1186
+ attr_value = getattr(raw_module, attr_name)
1187
+ except (ImportError, AttributeError) as e:
1188
+ # lazy / dynamic loading module -> INFO log and skip
1189
+ logger.info(
1190
+ f"Skipping import error while copying {raw_module.__name__}.{attr_name}: {type(e).__name__} - {e}"
1191
+ )
1192
+ continue
1193
+ # Recursively process nested modules, passing visited set
1194
+ if isinstance(attr_value, ModuleType):
1195
+ attr_value = get_safe_module(attr_value, authorized_imports, visited=visited)
1196
+
1197
+ setattr(safe_module, attr_name, attr_value)
1198
+
1199
+ return safe_module
1200
+
1201
+
1202
+ def evaluate_import(expression, state, authorized_imports):
1203
+ if isinstance(expression, ast.Import):
1204
+ for alias in expression.names:
1205
+ if check_import_authorized(alias.name, authorized_imports):
1206
+ raw_module = import_module(alias.name)
1207
+ state[alias.asname or alias.name] = get_safe_module(raw_module, authorized_imports)
1208
+ else:
1209
+ raise InterpreterError(
1210
+ f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
1211
+ )
1212
+ return None
1213
+ elif isinstance(expression, ast.ImportFrom):
1214
+ if check_import_authorized(expression.module, authorized_imports):
1215
+ raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
1216
+ module = get_safe_module(raw_module, authorized_imports)
1217
+ if expression.names[0].name == "*": # Handle "from module import *"
1218
+ if hasattr(module, "__all__"): # If module has __all__, import only those names
1219
+ for name in module.__all__:
1220
+ state[name] = getattr(module, name)
1221
+ else: # If no __all__, import all public names (those not starting with '_')
1222
+ for name in dir(module):
1223
+ if not name.startswith("_"):
1224
+ state[name] = getattr(module, name)
1225
+ else: # regular from imports
1226
+ for alias in expression.names:
1227
+ if hasattr(module, alias.name):
1228
+ state[alias.asname or alias.name] = getattr(module, alias.name)
1229
+ else:
1230
+ raise InterpreterError(f"Module {expression.module} has no attribute {alias.name}")
1231
+ else:
1232
+ raise InterpreterError(
1233
+ f"Import from {expression.module} is not allowed. Authorized imports are: {str(authorized_imports)}"
1234
+ )
1235
+ return None
1236
+
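Imports are gated by `authorized_imports`, and authorized modules are wrapped by `get_safe_module` before landing in the state. A hedged sketch (assuming `smolagents` is installed):

```python
# Sketch, assuming smolagents is installed: only whitelisted modules import;
# anything else raises InterpreterError.
from smolagents.local_python_executor import InterpreterError, evaluate_python_code

print(evaluate_python_code("import math\nmath.sqrt(16)", authorized_imports=["math"])[0])  # 4.0

try:
    evaluate_python_code("import os", authorized_imports=["math"])
except InterpreterError as e:
    print(e)  # ... Import of os is not allowed ...
```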
1237
+
1238
+ def evaluate_dictcomp(
1239
+ dictcomp: ast.DictComp,
1240
+ state: dict[str, Any],
1241
+ static_tools: dict[str, Callable],
1242
+ custom_tools: dict[str, Callable],
1243
+ authorized_imports: list[str],
1244
+ ) -> dict[Any, Any]:
1245
+ result = {}
1246
+ for gen in dictcomp.generators:
1247
+ iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools, authorized_imports)
1248
+ for value in iter_value:
1249
+ new_state = state.copy()
1250
+ set_value(
1251
+ gen.target,
1252
+ value,
1253
+ new_state,
1254
+ static_tools,
1255
+ custom_tools,
1256
+ authorized_imports,
1257
+ )
1258
+ if all(
1259
+ evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
1260
+ for if_clause in gen.ifs
1261
+ ):
1262
+ key = evaluate_ast(
1263
+ dictcomp.key,
1264
+ new_state,
1265
+ static_tools,
1266
+ custom_tools,
1267
+ authorized_imports,
1268
+ )
1269
+ val = evaluate_ast(
1270
+ dictcomp.value,
1271
+ new_state,
1272
+ static_tools,
1273
+ custom_tools,
1274
+ authorized_imports,
1275
+ )
1276
+ result[key] = val
1277
+ return result
1278
+
1279
+
1280
+ def evaluate_delete(
1281
+ delete_node: ast.Delete,
1282
+ state: dict[str, Any],
1283
+ static_tools: dict[str, Callable],
1284
+ custom_tools: dict[str, Callable],
1285
+ authorized_imports: list[str],
1286
+ ) -> None:
1287
+ """
1288
+ Evaluate a delete statement (del x, del x[y]).
1289
+
1290
+ Args:
1291
+ delete_node: The AST Delete node to evaluate
1292
+ state: The current state dictionary
1293
+ static_tools: Dictionary of static tools
1294
+ custom_tools: Dictionary of custom tools
1295
+ authorized_imports: List of authorized imports
1296
+ """
1297
+ for target in delete_node.targets:
1298
+ if isinstance(target, ast.Name):
1299
+ # Handle simple variable deletion (del x)
1300
+ if target.id in state:
1301
+ del state[target.id]
1302
+ else:
1303
+ raise InterpreterError(f"Cannot delete name '{target.id}': name is not defined")
1304
+ elif isinstance(target, ast.Subscript):
1305
+ # Handle index/key deletion (del x[y])
1306
+ obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
1307
+ index = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
1308
+ try:
1309
+ del obj[index]
1310
+ except (TypeError, KeyError, IndexError) as e:
1311
+ raise InterpreterError(f"Cannot delete index/key: {str(e)}")
1312
+ else:
1313
+ raise InterpreterError(f"Deletion of {type(target).__name__} targets is not supported")
1314
+
1315
+
1316
+ @safer_eval
1317
+ def evaluate_ast(
1318
+ expression: ast.AST,
1319
+ state: dict[str, Any],
1320
+ static_tools: dict[str, Callable],
1321
+ custom_tools: dict[str, Callable],
1322
+ authorized_imports: list[str] = BASE_BUILTIN_MODULES,
1323
+ ):
1324
+ """
1325
+ Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
1326
+ set of functions.
1327
+
1328
+ This function will recurse through the nodes of the tree provided.
1329
+
1330
+ Args:
1331
+ expression (`ast.AST`):
1332
+ The code to evaluate, as an abstract syntax tree.
1333
+ state (`Dict[str, Any]`):
1334
+ A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
1335
+ encounters assignments.
1336
+ static_tools (`Dict[str, Callable]`):
1337
+ Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
1338
+ custom_tools (`Dict[str, Callable]`):
1339
+ Functions that may be called during the evaluation. These custom_tools can be overwritten.
1340
+ authorized_imports (`List[str]`):
1341
+ The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
1342
+ If it contains "*", it will authorize any import. Use this at your own risk!
1343
+ """
1344
+ if state.setdefault("_operations_count", {"counter": 0})["counter"] >= MAX_OPERATIONS:
1345
+ raise InterpreterError(
1346
+ f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking for too many calculations."
1347
+ )
1348
+ state["_operations_count"]["counter"] += 1
1349
+ common_params = (state, static_tools, custom_tools, authorized_imports)
1350
+ if isinstance(expression, ast.Assign):
1351
+ # Assignment -> we evaluate the assignment which should update the state
1352
+ # We return the variable assigned as it may be used to determine the final result.
1353
+ return evaluate_assign(expression, *common_params)
1354
+ elif isinstance(expression, ast.AnnAssign):
1355
+ return evaluate_annassign(expression, *common_params)
1356
+ elif isinstance(expression, ast.AugAssign):
1357
+ return evaluate_augassign(expression, *common_params)
1358
+ elif isinstance(expression, ast.Call):
1359
+ # Function call -> we return the value of the function call
1360
+ return evaluate_call(expression, *common_params)
1361
+ elif isinstance(expression, ast.Constant):
1362
+ # Constant -> just return the value
1363
+ return expression.value
1364
+ elif isinstance(expression, ast.Tuple):
1365
+ return tuple((evaluate_ast(elt, *common_params) for elt in expression.elts))
1366
+ elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
1367
+ return evaluate_listcomp(expression, *common_params)
1368
+ elif isinstance(expression, ast.DictComp):
1369
+ return evaluate_dictcomp(expression, *common_params)
1370
+ elif isinstance(expression, ast.SetComp):
1371
+ return evaluate_setcomp(expression, *common_params)
1372
+ elif isinstance(expression, ast.UnaryOp):
1373
+ return evaluate_unaryop(expression, *common_params)
1374
+ elif isinstance(expression, ast.Starred):
1375
+ return evaluate_ast(expression.value, *common_params)
1376
+ elif isinstance(expression, ast.BoolOp):
1377
+ # Boolean operation -> evaluate the operation
1378
+ return evaluate_boolop(expression, *common_params)
1379
+ elif isinstance(expression, ast.Break):
1380
+ raise BreakException()
1381
+ elif isinstance(expression, ast.Continue):
1382
+ raise ContinueException()
1383
+ elif isinstance(expression, ast.BinOp):
1384
+ # Binary operation -> execute operation
1385
+ return evaluate_binop(expression, *common_params)
1386
+ elif isinstance(expression, ast.Compare):
1387
+ # Comparison -> evaluate the comparison
1388
+ return evaluate_condition(expression, *common_params)
1389
+ elif isinstance(expression, ast.Lambda):
1390
+ return evaluate_lambda(expression, *common_params)
1391
+ elif isinstance(expression, ast.FunctionDef):
1392
+ return evaluate_function_def(expression, *common_params)
1393
+ elif isinstance(expression, ast.Dict):
1394
+ # Dict -> evaluate all keys and values
1395
+ keys = (evaluate_ast(k, *common_params) for k in expression.keys)
1396
+ values = (evaluate_ast(v, *common_params) for v in expression.values)
1397
+ return dict(zip(keys, values))
1398
+ elif isinstance(expression, ast.Expr):
1399
+ # Expression -> evaluate the content
1400
+ return evaluate_ast(expression.value, *common_params)
1401
+ elif isinstance(expression, ast.For):
1402
+ # For loop -> execute the loop
1403
+ return evaluate_for(expression, *common_params)
1404
+ elif isinstance(expression, ast.FormattedValue):
1405
+ # Formatted value (part of f-string) -> evaluate the content and format it
1406
+ value = evaluate_ast(expression.value, *common_params)
1407
+ # Early return if no format spec
1408
+ if not expression.format_spec:
1409
+ return value
1410
+ # Apply format specification
1411
+ format_spec = evaluate_ast(expression.format_spec, *common_params)
1412
+ return format(value, format_spec)
1413
+ elif isinstance(expression, ast.If):
1414
+ # If -> execute the right branch
1415
+ return evaluate_if(expression, *common_params)
1416
+ elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
1417
+ return evaluate_ast(expression.value, *common_params)
1418
+ elif isinstance(expression, ast.JoinedStr):
1419
+ return "".join([str(evaluate_ast(v, *common_params)) for v in expression.values])
1420
+ elif isinstance(expression, ast.List):
1421
+ # List -> evaluate all elements
1422
+ return [evaluate_ast(elt, *common_params) for elt in expression.elts]
1423
+ elif isinstance(expression, ast.Name):
1424
+ # Name -> pick up the value in the state
1425
+ return evaluate_name(expression, *common_params)
1426
+ elif isinstance(expression, ast.Subscript):
1427
+ # Subscript -> return the value of the indexing
1428
+ return evaluate_subscript(expression, *common_params)
1429
+ elif isinstance(expression, ast.IfExp):
1430
+ test_val = evaluate_ast(expression.test, *common_params)
1431
+ if test_val:
1432
+ return evaluate_ast(expression.body, *common_params)
1433
+ else:
1434
+ return evaluate_ast(expression.orelse, *common_params)
1435
+ elif isinstance(expression, ast.Attribute):
1436
+ return evaluate_attribute(expression, *common_params)
1437
+ elif isinstance(expression, ast.Slice):
1438
+ return slice(
1439
+ evaluate_ast(expression.lower, *common_params) if expression.lower is not None else None,
1440
+ evaluate_ast(expression.upper, *common_params) if expression.upper is not None else None,
1441
+ evaluate_ast(expression.step, *common_params) if expression.step is not None else None,
1442
+ )
1443
+ elif isinstance(expression, ast.While):
1444
+ return evaluate_while(expression, *common_params)
1445
+ elif isinstance(expression, (ast.Import, ast.ImportFrom)):
1446
+ return evaluate_import(expression, state, authorized_imports)
1447
+ elif isinstance(expression, ast.ClassDef):
1448
+ return evaluate_class_def(expression, *common_params)
1449
+ elif isinstance(expression, ast.Try):
1450
+ return evaluate_try(expression, *common_params)
1451
+ elif isinstance(expression, ast.Raise):
1452
+ return evaluate_raise(expression, *common_params)
1453
+ elif isinstance(expression, ast.Assert):
1454
+ return evaluate_assert(expression, *common_params)
1455
+ elif isinstance(expression, ast.With):
1456
+ return evaluate_with(expression, *common_params)
1457
+ elif isinstance(expression, ast.Set):
1458
+ return set((evaluate_ast(elt, *common_params) for elt in expression.elts))
1459
+ elif isinstance(expression, ast.Return):
1460
+ raise ReturnException(evaluate_ast(expression.value, *common_params) if expression.value else None)
1461
+ elif isinstance(expression, ast.Pass):
1462
+ return None
1463
+ elif isinstance(expression, ast.Delete):
1464
+ return evaluate_delete(expression, *common_params)
1465
+ else:
1466
+ # For now we refuse anything else. Let's add things as we need them.
1467
+ raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
1468
+
1469
+
1470
+ class FinalAnswerException(Exception):
1471
+ def __init__(self, value):
1472
+ self.value = value
1473
+
1474
+
1475
+ def evaluate_python_code(
1476
+ code: str,
1477
+ static_tools: dict[str, Callable] | None = None,
1478
+ custom_tools: dict[str, Callable] | None = None,
1479
+ state: dict[str, Any] | None = None,
1480
+ authorized_imports: list[str] = BASE_BUILTIN_MODULES,
1481
+ max_print_outputs_length: int = DEFAULT_MAX_LEN_OUTPUT,
1482
+ ):
1483
+ """
1484
+ Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
1485
+ of functions.
1486
+
1487
+ This function will recurse through the nodes of the tree provided.
1488
+
1489
+ Args:
1490
+ code (`str`):
1491
+ The code to evaluate.
1492
+ static_tools (`Dict[str, Callable]`):
1493
+ The functions that may be called during the evaluation. These can also be agents in a multiagent setting.
1494
+ These tools cannot be overwritten in the code: any assignment to their name will raise an error.
1495
+ custom_tools (`Dict[str, Callable]`):
1496
+ The functions that may be called during the evaluation.
1497
+ These tools can be overwritten in the code: any assignment to their name will overwrite them.
1498
+ state (`Dict[str, Any]`):
1499
+ A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
1500
+ updated by this function to contain all variables as they are evaluated.
1501
+ The print outputs will be stored in the state under the key "_print_outputs".
1502
+ authorized_imports (`List[str]`):
+ The list of modules that can be imported by the code. If it contains "*", any import is authorized. Use at your own risk!
+ max_print_outputs_length (`int`):
+ Maximum length of the captured print outputs before truncation.
+ """
1503
+ try:
1504
+ expression = ast.parse(code)
1505
+ except SyntaxError as e:
1506
+ raise InterpreterError(
1507
+ f"Code parsing failed on line {e.lineno} due to: {type(e).__name__}\n"
1508
+ f"{e.text}"
1509
+ f"{' ' * (e.offset or 0)}^\n"
1510
+ f"Error: {str(e)}"
1511
+ )
1512
+
1513
+ if state is None:
1514
+ state = {}
1515
+ static_tools = static_tools.copy() if static_tools is not None else {}
1516
+ custom_tools = custom_tools if custom_tools is not None else {}
1517
+ result = None
1518
+ state["_print_outputs"] = PrintContainer()
1519
+ state["_operations_count"] = {"counter": 0}
1520
+
1521
+ if "final_answer" in static_tools:
1522
+ previous_final_answer = static_tools["final_answer"]
1523
+
1524
+ def final_answer(*args, **kwargs): # Allow arbitrary arguments to be passed
1525
+ raise FinalAnswerException(previous_final_answer(*args, **kwargs))
1526
+
1527
+ static_tools["final_answer"] = final_answer
1528
+
1529
+ try:
1530
+ for node in expression.body:
1531
+ result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
1532
+ state["_print_outputs"].value = truncate_content(
1533
+ str(state["_print_outputs"]), max_length=max_print_outputs_length
1534
+ )
1535
+ is_final_answer = False
1536
+ return result, is_final_answer
1537
+ except FinalAnswerException as e:
1538
+ state["_print_outputs"].value = truncate_content(
1539
+ str(state["_print_outputs"]), max_length=max_print_outputs_length
1540
+ )
1541
+ is_final_answer = True
1542
+ return e.value, is_final_answer
1543
+ except Exception as e:
1544
+ state["_print_outputs"].value = truncate_content(
1545
+ str(state["_print_outputs"]), max_length=max_print_outputs_length
1546
+ )
1547
+ raise InterpreterError(
1548
+ f"Code execution failed at line '{ast.get_source_segment(code, node)}' due to: {type(e).__name__}: {e}"
1549
+ )
1550
+
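A sketch of the `final_answer` mechanism just defined (assuming `smolagents` is installed): the provided tool is wrapped so that calling it raises `FinalAnswerException`, which `evaluate_python_code` reports back as `is_final_answer=True`.

```python
# Sketch, assuming smolagents is installed: final_answer short-circuits
# execution and flags the result as final.
from smolagents.local_python_executor import evaluate_python_code

result, is_final_answer = evaluate_python_code(
    "final_answer(6 * 7)",
    static_tools={"final_answer": lambda x: x},
)
print(result, is_final_answer)  # 42 True
```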
1551
+
1552
+ class PythonExecutor:
1553
+ pass
1554
+
1555
+
1556
+ class LocalPythonExecutor(PythonExecutor):
1557
+ """
1558
+ Executor of Python code in a local environment.
1559
+
1560
+ This executor evaluates Python code with restricted access to imports and built-in functions,
1561
+ making it suitable for running untrusted code. It maintains state between executions,
1562
+ allows for custom tools and functions to be made available to the code, and captures
1563
+ print outputs separately from return values.
1564
+
1565
+ Args:
1566
+ additional_authorized_imports (`list[str]`):
1567
+ Additional authorized imports for the executor.
1568
+ max_print_outputs_length (`int`, defaults to `DEFAULT_MAX_LEN_OUTPUT=50_000`):
1569
+ Maximum length of the print outputs.
1570
+ additional_functions (`dict[str, Callable]`, *optional*):
1571
+ Additional Python functions to be added to the executor.
1572
+ """
1573
+
1574
+ def __init__(
1575
+ self,
1576
+ additional_authorized_imports: list[str],
1577
+ max_print_outputs_length: int | None = None,
1578
+ additional_functions: dict[str, Callable] | None = None,
1579
+ ):
1580
+ self.custom_tools = {}
1581
+ self.state = {"__name__": "__main__"}
1582
+ self.max_print_outputs_length = max_print_outputs_length
1583
+ if max_print_outputs_length is None:
1584
+ self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
1585
+ self.additional_authorized_imports = additional_authorized_imports
1586
+ self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
1587
+ # TODO: assert self.authorized imports are all installed locally
1588
+ self.static_tools = None
1589
+ self.additional_functions = additional_functions or {}
1590
+
1591
+ def __call__(self, code_action: str) -> tuple[Any, str, bool]:
1592
+ output, is_final_answer = evaluate_python_code(
1593
+ code_action,
1594
+ static_tools=self.static_tools,
1595
+ custom_tools=self.custom_tools,
1596
+ state=self.state,
1597
+ authorized_imports=self.authorized_imports,
1598
+ max_print_outputs_length=self.max_print_outputs_length,
1599
+ )
1600
+ logs = str(self.state["_print_outputs"])
1601
+ return output, logs, is_final_answer
1602
+
1603
+ def send_variables(self, variables: dict):
1604
+ self.state.update(variables)
1605
+
1606
+ def send_tools(self, tools: dict[str, Tool]):
1607
+ # Combine agent tools, base Python tools, and additional Python functions
1608
+ self.static_tools = {**tools, **BASE_PYTHON_TOOLS.copy(), **self.additional_functions}
1609
+
1610
+
1611
+ __all__ = ["evaluate_python_code", "LocalPythonExecutor"]
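Typical executor usage, as a hedged sketch (assuming `smolagents` is installed): state persists across calls, print output comes back as logs, and `send_tools()` must be called once so the base Python tools (including `print`) are installed as static tools.

```python
# Sketch, assuming smolagents is installed: the executor wraps
# evaluate_python_code with persistent state and captured logs.
from smolagents.local_python_executor import LocalPythonExecutor

executor = LocalPythonExecutor(additional_authorized_imports=["math"])
executor.send_tools({})  # no agent tools; installs BASE_PYTHON_TOOLS
output, logs, is_final_answer = executor("import math\nx = math.pi\nprint(x)\nx")
print(output)  # 3.141592653589793
print(logs)    # 3.141592653589793
```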
src/smolagents/mcp_client.py ADDED
@@ -0,0 +1,154 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from __future__ import annotations
19
+
20
+ import warnings
21
+ from types import TracebackType
22
+ from typing import TYPE_CHECKING, Any
23
+
24
+ from smolagents.tools import Tool
25
+
26
+
27
+ __all__ = ["MCPClient"]
28
+
29
+ if TYPE_CHECKING:
30
+ from mcpadapt.core import StdioServerParameters
31
+
32
+
33
+ class MCPClient:
34
+ """Manages the connection to an MCP server and makes its tools available to SmolAgents.
35
+
36
+ Note: tools can only be accessed after the connection has been started with the
37
+ `connect()` method, which is done during init. If you don't use the context manager,
38
+ we strongly encourage you to use "try ... finally" to ensure the connection is cleaned up.
39
+
40
+ Args:
41
+ server_parameters (StdioServerParameters | dict[str, Any] | list[StdioServerParameters | dict[str, Any]]):
42
+ Configuration parameters to connect to the MCP server. Can be a list if you want to connect multiple MCPs at once.
43
+
44
+ - An instance of `mcp.StdioServerParameters` for connecting a Stdio MCP server via standard input/output using a subprocess.
45
+
46
+ - A `dict` with at least:
47
+ - "url": URL of the server.
48
+ - "transport": Transport protocol to use, one of:
49
+ - "streamable-http": (recommended) Streamable HTTP transport.
50
+ - "sse": Legacy HTTP+SSE transport (deprecated).
51
+ If "transport" is omitted, the legacy "sse" transport is assumed (a deprecation warning will be issued).
52
+
53
+ <Deprecated version="1.17.0">
54
+ The HTTP+SSE transport is deprecated and future behavior will default to the Streamable HTTP transport.
55
+ Please pass explicitly the "transport" key.
56
+ </Deprecated>
57
+
58
+ Example:
59
+ ```python
60
+ # fully managed context manager + stdio
61
+ with MCPClient(...) as tools:
62
+ # tools are now available
63
+
64
+ # context manager + Streamable HTTP transport:
65
+ with MCPClient({"url": "http://localhost:8000/mcp", "transport": "streamable-http"}) as tools:
66
+ # tools are now available
67
+
68
+ # manually manage the connection via the mcp_client object:
69
+ try:
70
+ mcp_client = MCPClient(...)
71
+ tools = mcp_client.get_tools()
72
+
73
+ # use your tools here.
74
+ finally:
75
+ mcp_client.disconnect()
76
+ ```
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ server_parameters: "StdioServerParameters" | dict[str, Any] | list["StdioServerParameters" | dict[str, Any]],
82
+ ):
83
+ try:
84
+ from mcpadapt.core import MCPAdapt
85
+ from mcpadapt.smolagents_adapter import SmolAgentsAdapter
86
+ except ModuleNotFoundError:
87
+ raise ModuleNotFoundError("Please install 'mcp' extra to use MCPClient: `pip install 'smolagents[mcp]'`")
88
+ if isinstance(server_parameters, dict):
89
+ transport = server_parameters.get("transport")
90
+ if transport is None:
91
+ warnings.warn(
92
+ "Passing a dict as server_parameters without specifying the 'transport' key is deprecated. "
93
+ "For now, it defaults to the legacy 'sse' (HTTP+SSE) transport, but this default will change "
94
+ "to 'streamable-http' in version 1.20. Please add the 'transport' key explicitly. ",
95
+ FutureWarning,
96
+ )
97
+ transport = "sse"
98
+ server_parameters["transport"] = transport
99
+ if transport not in {"sse", "streamable-http"}:
100
+ raise ValueError(
101
+ f"Unsupported transport: {transport}. Supported transports are 'streamable-http' and 'sse'."
102
+ )
103
+ self._adapter = MCPAdapt(server_parameters, SmolAgentsAdapter())
104
+ self._tools: list[Tool] | None = None
105
+ self.connect()
106
+
107
+ def connect(self):
108
+ """Connect to the MCP server and initialize the tools."""
109
+ self._tools: list[Tool] = self._adapter.__enter__()
110
+
111
+ def disconnect(
112
+ self,
113
+ exc_type: type[BaseException] | None = None,
114
+ exc_value: BaseException | None = None,
115
+ exc_traceback: TracebackType | None = None,
116
+ ):
117
+ """Disconnect from the MCP server"""
118
+ self._adapter.__exit__(exc_type, exc_value, exc_traceback)
119
+
120
+ def get_tools(self) -> list[Tool]:
121
+ """The SmolAgents tools available from the MCP server.
122
+
123
+ Note: for now, this always returns the tools available at the creation of the session,
124
+ but a future release will also return any new tools the MCP server exposes
125
+ at call time.
126
+
127
+ Raises:
128
+ ValueError: If the MCP server tools are None (usually meaning the server has not been started).
129
+
130
+ Returns:
131
+ list[Tool]: The SmolAgents tools available from the MCP server.
132
+ """
133
+ if self._tools is None:
134
+ raise ValueError(
135
+ "Couldn't retrieve tools from MCP server, run `mcp_client.connect()` first before accessing `tools`"
136
+ )
137
+ return self._tools
138
+
139
+ def __enter__(self) -> list[Tool]:
140
+ """Connect to the MCP server and return the tools directly.
141
+
142
+ Note that because of the `.connect` in the init, the mcp_client
143
+ is already connected at this point.
144
+ """
145
+ return self._tools
146
+
147
+ def __exit__(
148
+ self,
149
+ exc_type: type[BaseException] | None,
150
+ exc_value: BaseException | None,
151
+ exc_traceback: TracebackType | None,
152
+ ):
153
+ """Disconnect from the MCP server."""
154
+ self.disconnect(exc_type, exc_value, exc_traceback)
src/smolagents/memory.py ADDED
@@ -0,0 +1,257 @@
1
+ from dataclasses import asdict, dataclass
2
+ from logging import getLogger
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ from smolagents.models import ChatMessage, MessageRole
6
+ from smolagents.monitoring import AgentLogger, LogLevel, Timing, TokenUsage
7
+ from smolagents.utils import AgentError, make_json_serializable
8
+
9
+
10
+ if TYPE_CHECKING:
11
+ import PIL.Image
12
+
13
+ from smolagents.models import ChatMessage
14
+ from smolagents.monitoring import AgentLogger
15
+
16
+
17
+ logger = getLogger(__name__)
18
+
19
+
20
+ @dataclass
21
+ class ToolCall:
22
+ name: str
23
+ arguments: Any
24
+ id: str
25
+
26
+ def dict(self):
27
+ return {
28
+ "id": self.id,
29
+ "type": "function",
30
+ "function": {
31
+ "name": self.name,
32
+ "arguments": make_json_serializable(self.arguments),
33
+ },
34
+ }
35
+
36
+
37
+ @dataclass
38
+ class MemoryStep:
39
+ def dict(self):
40
+ return asdict(self)
41
+
42
+ def to_messages(self, summary_mode: bool = False) -> list[ChatMessage]:
43
+ raise NotImplementedError
44
+
45
+
46
+ @dataclass
47
+ class ActionStep(MemoryStep):
48
+ step_number: int
49
+ timing: Timing
50
+ model_input_messages: list[ChatMessage] | None = None
51
+ tool_calls: list[ToolCall] | None = None
52
+ error: AgentError | None = None
53
+ model_output_message: ChatMessage | None = None
54
+ model_output: str | list[dict[str, Any]] | None = None
55
+ code_action: str | None = None
56
+ observations: str | None = None
57
+ observations_images: list["PIL.Image.Image"] | None = None
58
+ action_output: Any = None
59
+ token_usage: TokenUsage | None = None
60
+ is_final_answer: bool = False
61
+
62
+ def dict(self):
63
+ # We overwrite the method to parse the tool_calls and action_output manually
64
+ return {
65
+ "step_number": self.step_number,
66
+ "timing": self.timing.dict(),
67
+ "model_input_messages": self.model_input_messages,
68
+ "tool_calls": [tc.dict() for tc in self.tool_calls] if self.tool_calls else [],
69
+ "error": self.error.dict() if self.error else None,
70
+ "model_output_message": self.model_output_message.dict() if self.model_output_message else None,
71
+ "model_output": self.model_output,
72
+ "code_action": self.code_action,
73
+ "observations": self.observations,
74
+ "observations_images": [image.tobytes() for image in self.observations_images]
75
+ if self.observations_images
76
+ else None,
77
+ "action_output": make_json_serializable(self.action_output),
78
+ "token_usage": asdict(self.token_usage) if self.token_usage else None,
79
+ "is_final_answer": self.is_final_answer,
80
+ }
81
+
82
+ def to_messages(self, summary_mode: bool = False) -> list[ChatMessage]:
83
+ messages = []
84
+ if self.model_output is not None and not summary_mode:
85
+ messages.append(
86
+ ChatMessage(role=MessageRole.ASSISTANT, content=[{"type": "text", "text": self.model_output.strip()}])
87
+ )
88
+
89
+ if self.tool_calls is not None:
90
+ messages.append(
91
+ ChatMessage(
92
+ role=MessageRole.TOOL_CALL,
93
+ content=[
94
+ {
95
+ "type": "text",
96
+ "text": "Calling tools:\n" + str([tc.dict() for tc in self.tool_calls]),
97
+ }
98
+ ],
99
+ )
100
+ )
101
+
102
+ if self.observations_images:
103
+ messages.append(
104
+ ChatMessage(
105
+ role=MessageRole.USER,
106
+ content=[
107
+ {
108
+ "type": "image",
109
+ "image": image,
110
+ }
111
+ for image in self.observations_images
112
+ ],
113
+ )
114
+ )
115
+
116
+ if self.observations is not None:
117
+ messages.append(
118
+ ChatMessage(
119
+ role=MessageRole.TOOL_RESPONSE,
120
+ content=[
121
+ {
122
+ "type": "text",
123
+ "text": f"Observation:\n{self.observations}",
124
+ }
125
+ ],
126
+ )
127
+ )
128
+ if self.error is not None:
129
+ error_message = (
130
+ "Error:\n"
131
+ + str(self.error)
132
+ + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
133
+ )
134
+ message_content = f"Call id: {self.tool_calls[0].id}\n" if self.tool_calls else ""
135
+ message_content += error_message
136
+ messages.append(
137
+ ChatMessage(role=MessageRole.TOOL_RESPONSE, content=[{"type": "text", "text": message_content}])
138
+ )
139
+
140
+ return messages
141
+
142
+
143
+ @dataclass
144
+ class PlanningStep(MemoryStep):
145
+ model_input_messages: list[ChatMessage]
146
+ model_output_message: ChatMessage
147
+ plan: str
148
+ timing: Timing
149
+ token_usage: TokenUsage | None = None
150
+
151
+ def to_messages(self, summary_mode: bool = False) -> list[ChatMessage]:
152
+ if summary_mode:
153
+ return []
154
+ return [
155
+ ChatMessage(role=MessageRole.ASSISTANT, content=[{"type": "text", "text": self.plan.strip()}]),
156
+ ChatMessage(
157
+ role=MessageRole.USER, content=[{"type": "text", "text": "Now proceed and carry out this plan."}]
158
+ ),
159
+ # This second message creates a role change to prevent models from simply continuing the plan message
160
+ ]
161
+
162
+
163
+ @dataclass
164
+ class TaskStep(MemoryStep):
165
+ task: str
166
+ task_images: list["PIL.Image.Image"] | None = None
167
+
168
+ def to_messages(self, summary_mode: bool = False) -> list[ChatMessage]:
169
+ content = [{"type": "text", "text": f"New task:\n{self.task}"}]
170
+ if self.task_images:
171
+ content.extend([{"type": "image", "image": image} for image in self.task_images])
172
+
173
+ return [ChatMessage(role=MessageRole.USER, content=content)]
174
+
175
+
176
+ @dataclass
177
+ class SystemPromptStep(MemoryStep):
178
+ system_prompt: str
179
+
180
+ def to_messages(self, summary_mode: bool = False) -> list[ChatMessage]:
181
+ if summary_mode:
182
+ return []
183
+ return [ChatMessage(role=MessageRole.SYSTEM, content=[{"type": "text", "text": self.system_prompt}])]
184
+
185
+
186
+ @dataclass
187
+ class FinalAnswerStep(MemoryStep):
188
+ output: Any
189
+
190
+
191
+ class AgentMemory:
192
+ """Memory for the agent, containing the system prompt and all steps taken by the agent.
193
+
194
+ This class is used to store the agent's steps, including tasks, actions, and planning steps.
195
+ It allows for resetting the memory, retrieving succinct or full step information, and replaying the agent's steps.
196
+
197
+ Args:
198
+ system_prompt (`str`): System prompt for the agent, which sets the context and instructions for the agent's behavior.
199
+
200
+ **Attributes**:
201
+ - **system_prompt** (`SystemPromptStep`) -- System prompt step for the agent.
202
+ - **steps** (`list[TaskStep | ActionStep | PlanningStep]`) -- List of steps taken by the agent, which can include tasks, actions, and planning steps.
203
+ """
204
+
205
+ def __init__(self, system_prompt: str):
206
+ self.system_prompt: SystemPromptStep = SystemPromptStep(system_prompt=system_prompt)
207
+ self.steps: list[TaskStep | ActionStep | PlanningStep] = []
208
+
209
+ def reset(self):
210
+ """Reset the agent's memory, clearing all steps and keeping the system prompt."""
211
+ self.steps = []
212
+
213
+ def get_succinct_steps(self) -> list[dict]:
214
+ """Return a succinct representation of the agent's steps, excluding model input messages."""
215
+ return [
216
+ {key: value for key, value in step.dict().items() if key != "model_input_messages"} for step in self.steps
217
+ ]
218
+
219
+ def get_full_steps(self) -> list[dict]:
220
+ """Return a full representation of the agent's steps, including model input messages."""
221
+ if len(self.steps) == 0:
222
+ return []
223
+ return [step.dict() for step in self.steps]
224
+
225
+ def replay(self, logger: AgentLogger, detailed: bool = False):
226
+ """Prints a pretty replay of the agent's steps.
227
+
228
+ Args:
229
+ logger (`AgentLogger`): The logger to print replay logs to.
230
+ detailed (`bool`, default `False`): If True, also displays the memory at each step.
231
+ Careful: this will increase the log length considerably. Use only for debugging.
232
+ """
233
+ logger.console.log("Replaying the agent's steps:")
234
+ logger.log_markdown(title="System prompt", content=self.system_prompt.system_prompt, level=LogLevel.ERROR)
235
+ for step in self.steps:
236
+ if isinstance(step, TaskStep):
237
+ logger.log_task(step.task, "", level=LogLevel.ERROR)
238
+ elif isinstance(step, ActionStep):
239
+ logger.log_rule(f"Step {step.step_number}", level=LogLevel.ERROR)
240
+ if detailed and step.model_input_messages is not None:
241
+ logger.log_messages(step.model_input_messages, level=LogLevel.ERROR)
242
+ if step.model_output is not None:
243
+ logger.log_markdown(title="Agent output:", content=step.model_output, level=LogLevel.ERROR)
244
+ elif isinstance(step, PlanningStep):
245
+ logger.log_rule("Planning step", level=LogLevel.ERROR)
246
+ if detailed and step.model_input_messages is not None:
247
+ logger.log_messages(step.model_input_messages, level=LogLevel.ERROR)
248
+ logger.log_markdown(title="Agent output:", content=step.plan, level=LogLevel.ERROR)
249
+
250
+ def return_full_code(self) -> str:
251
+ """Returns all code actions from the agent's steps, concatenated as a single script."""
252
+ return "\n\n".join(
253
+ [step.code_action for step in self.steps if isinstance(step, ActionStep) and step.code_action is not None]
254
+ )
255
+
256
+
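# Minimal usage sketch for AgentMemory, assuming the MemoryStep base class defined
# earlier in this file provides the `dict()` serialization used by get_full_steps().
memory = AgentMemory(system_prompt="You are a helpful agent.")
memory.steps.append(TaskStep(task="Say hello"))
len(memory.get_full_steps())  # 1
memory.reset()
memory.get_succinct_steps()   # []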
257
+ __all__ = ["AgentMemory"]
src/smolagents/models.py ADDED
@@ -0,0 +1,1882 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import logging
16
+ import os
17
+ import re
18
+ import uuid
19
+ import warnings
20
+ from collections.abc import Generator
21
+ from copy import deepcopy
22
+ from dataclasses import asdict, dataclass
23
+ from enum import Enum
24
+ from threading import Thread
25
+ from typing import TYPE_CHECKING, Any
26
+
27
+ from .monitoring import TokenUsage
28
+ from .tools import Tool
29
+ from .utils import _is_package_available, encode_image_base64, make_image_url, parse_json_blob
30
+
31
+
32
+ if TYPE_CHECKING:
33
+ from transformers import StoppingCriteriaList
34
+
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+ STRUCTURED_GENERATION_PROVIDERS = ["cerebras", "fireworks-ai"]
39
+ CODEAGENT_RESPONSE_FORMAT = {
40
+ "type": "json_schema",
41
+ "json_schema": {
42
+ "schema": {
43
+ "additionalProperties": False,
44
+ "properties": {
45
+ "thought": {
46
+ "description": "A free form text description of the thought process.",
47
+ "title": "Thought",
48
+ "type": "string",
49
+ },
50
+ "code": {
51
+ "description": "Valid Python code snippet implementing the thought.",
52
+ "title": "Code",
53
+ "type": "string",
54
+ },
55
+ },
56
+ "required": ["thought", "code"],
57
+ "title": "ThoughtAndCodeAnswer",
58
+ "type": "object",
59
+ },
60
+ "name": "ThoughtAndCodeAnswer",
61
+ "strict": True,
62
+ },
63
+ }
64
+
65
+
66
+ def get_dict_from_nested_dataclasses(obj, ignore_key=None):
67
+ def convert(obj):
68
+ if hasattr(obj, "__dataclass_fields__"):
69
+ return {k: convert(v) for k, v in asdict(obj).items() if k != ignore_key}
70
+ return obj
71
+
72
+ return convert(obj)
73
+
74
+
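# Sketch with a hypothetical dataclass (using the `dataclass` decorator imported
# above): `ignore_key` drops a top-level field, while nested dataclasses are
# flattened to plain dicts by `asdict`.
@dataclass
class _Reply:
    text: str
    raw: object | None = None

get_dict_from_nested_dataclasses(_Reply(text="hi", raw=object()), ignore_key="raw")
# -> {"text": "hi"}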
75
+ @dataclass
76
+ class ChatMessageToolCallFunction:
77
+ arguments: Any
78
+ name: str
79
+ description: str | None = None
80
+
81
+
82
+ @dataclass
83
+ class ChatMessageToolCall:
84
+ function: ChatMessageToolCallFunction
85
+ id: str
86
+ type: str
87
+
88
+ def __str__(self) -> str:
89
+ return f"Call: {self.id}: Calling {str(self.function.name)} with arguments: {str(self.function.arguments)}"
90
+
91
+
92
+ class MessageRole(str, Enum):
93
+ USER = "user"
94
+ ASSISTANT = "assistant"
95
+ SYSTEM = "system"
96
+ TOOL_CALL = "tool-call"
97
+ TOOL_RESPONSE = "tool-response"
98
+
99
+ @classmethod
100
+ def roles(cls):
101
+ return [r.value for r in cls]
102
+
103
+
104
+ @dataclass
105
+ class ChatMessage:
106
+ role: MessageRole
107
+ content: str | list[dict[str, Any]] | None = None
108
+ tool_calls: list[ChatMessageToolCall] | None = None
109
+ raw: Any | None = None # Stores the raw output from the API
110
+ token_usage: TokenUsage | None = None
111
+
112
+ def model_dump_json(self):
113
+ return json.dumps(get_dict_from_nested_dataclasses(self, ignore_key="raw"))
114
+
115
+ @classmethod
116
+ def from_dict(cls, data: dict, raw: Any | None = None, token_usage: TokenUsage | None = None) -> "ChatMessage":
117
+ if data.get("tool_calls"):
118
+ tool_calls = [
119
+ ChatMessageToolCall(
120
+ function=ChatMessageToolCallFunction(**tc["function"]), id=tc["id"], type=tc["type"]
121
+ )
122
+ for tc in data["tool_calls"]
123
+ ]
124
+ data["tool_calls"] = tool_calls
125
+ return cls(
126
+ role=data["role"],
127
+ content=data.get("content"),
128
+ tool_calls=data.get("tool_calls"),
129
+ raw=raw,
130
+ token_usage=token_usage,
131
+ )
132
+
133
+ def dict(self):
134
+ return get_dict_from_nested_dataclasses(self)
135
+
136
+ def render_as_markdown(self) -> str:
137
+ rendered = str(self.content) if self.content is not None else ""
138
+ if self.tool_calls:
139
+ rendered += "\n".join(
140
+ [
141
+ json.dumps({"tool": tool.function.name, "arguments": tool.function.arguments})
142
+ for tool in self.tool_calls
143
+ ]
144
+ )
145
+ return rendered
146
+
147
+
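# Sketch: a message round-trips through dict() / from_dict(); model_dump_json() is
# the variant that also strips the non-serializable `raw` payload.
msg = ChatMessage(role=MessageRole.ASSISTANT, content="All done.")
ChatMessage.from_dict(msg.dict()).content  # "All done."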
148
+ def parse_json_if_needed(arguments: str | dict) -> str | dict:
149
+ if isinstance(arguments, dict):
150
+ return arguments
151
+ else:
152
+ try:
153
+ return json.loads(arguments)
154
+ except Exception:
155
+ return arguments
156
+
157
+
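# Sketch: JSON strings are decoded, dicts pass through untouched, and non-JSON
# strings are returned as-is instead of raising.
parse_json_if_needed('{"query": "tree"}')  # {'query': 'tree'}
parse_json_if_needed({"query": "tree"})    # {'query': 'tree'}
parse_json_if_needed("plain text")         # 'plain text'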
158
+ @dataclass
159
+ class ChatMessageToolCallStreamDelta:
160
+ """Represents a streaming delta for tool calls during generation."""
161
+
162
+ index: int | None = None
163
+ id: str | None = None
164
+ type: str | None = None
165
+ function: ChatMessageToolCallFunction | None = None
166
+
167
+
168
+ @dataclass
169
+ class ChatMessageStreamDelta:
170
+ content: str | None = None
171
+ tool_calls: list[ChatMessageToolCallStreamDelta] | None = None
172
+ token_usage: TokenUsage | None = None
173
+
174
+
175
+ def agglomerate_stream_deltas(
176
+ stream_deltas: list[ChatMessageStreamDelta], role: MessageRole = MessageRole.ASSISTANT
177
+ ) -> ChatMessage:
178
+ """
179
+ Agglomerate a list of stream deltas into a single chat message.
180
+ """
181
+ accumulated_tool_calls: dict[int, ChatMessageToolCallStreamDelta] = {}
182
+ accumulated_content = ""
183
+ total_input_tokens = 0
184
+ total_output_tokens = 0
185
+ for stream_delta in stream_deltas:
186
+ if stream_delta.token_usage:
187
+ total_input_tokens += stream_delta.token_usage.input_tokens
188
+ total_output_tokens += stream_delta.token_usage.output_tokens
189
+ if stream_delta.content:
190
+ accumulated_content += stream_delta.content
191
+ if stream_delta.tool_calls:
192
+ for tool_call_delta in stream_delta.tool_calls: # Normally there should be only one call at a time
193
+ # Extend accumulated_tool_calls list to accommodate the new tool call if needed
194
+ if tool_call_delta.index is not None:
195
+ if tool_call_delta.index not in accumulated_tool_calls:
196
+ accumulated_tool_calls[tool_call_delta.index] = ChatMessageToolCallStreamDelta(
197
+ id=tool_call_delta.id,
198
+ type=tool_call_delta.type,
199
+ function=ChatMessageToolCallFunction(name="", arguments=""),
200
+ )
201
+ # Update the tool call at the specific index
202
+ tool_call = accumulated_tool_calls[tool_call_delta.index]
203
+ if tool_call_delta.id:
204
+ tool_call.id = tool_call_delta.id
205
+ if tool_call_delta.type:
206
+ tool_call.type = tool_call_delta.type
207
+ if tool_call_delta.function:
208
+ if tool_call_delta.function.name and len(tool_call_delta.function.name) > 0:
209
+ tool_call.function.name = tool_call_delta.function.name
210
+ if tool_call_delta.function.arguments:
211
+ tool_call.function.arguments += tool_call_delta.function.arguments
212
+ else:
213
+ raise ValueError(f"Tool call index is not provided in tool delta: {tool_call_delta}")
214
+
215
+ return ChatMessage(
216
+ role=role,
217
+ content=accumulated_content,
218
+ tool_calls=[
219
+ ChatMessageToolCall(
220
+ function=ChatMessageToolCallFunction(
221
+ name=tool_call_stream_delta.function.name,
222
+ arguments=tool_call_stream_delta.function.arguments,
223
+ ),
224
+ id=tool_call_stream_delta.id or "",
225
+ type="function",
226
+ )
227
+ for tool_call_stream_delta in accumulated_tool_calls.values()
228
+ if tool_call_stream_delta.function
229
+ ],
230
+ token_usage=TokenUsage(
231
+ input_tokens=total_input_tokens,
232
+ output_tokens=total_output_tokens,
233
+ ),
234
+ )
235
+
236
+
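# Sketch: content chunks concatenate and token counts accumulate across deltas.
deltas = [
    ChatMessageStreamDelta(content="Hello, ", token_usage=TokenUsage(input_tokens=5, output_tokens=2)),
    ChatMessageStreamDelta(content="world!", token_usage=TokenUsage(input_tokens=0, output_tokens=2)),
]
message = agglomerate_stream_deltas(deltas)
message.content                    # 'Hello, world!'
message.token_usage.output_tokens  # 4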
237
+ tool_role_conversions = {
238
+ MessageRole.TOOL_CALL: MessageRole.ASSISTANT,
239
+ MessageRole.TOOL_RESPONSE: MessageRole.USER,
240
+ }
241
+
242
+
243
+ def get_tool_json_schema(tool: Tool) -> dict:
244
+ properties = deepcopy(tool.inputs)
245
+ required = []
246
+ for key, value in properties.items():
247
+ if value["type"] == "any":
248
+ value["type"] = "string"
249
+ if not ("nullable" in value and value["nullable"]):
250
+ required.append(key)
251
+ return {
252
+ "type": "function",
253
+ "function": {
254
+ "name": tool.name,
255
+ "description": tool.description,
256
+ "parameters": {
257
+ "type": "object",
258
+ "properties": properties,
259
+ "required": required,
260
+ },
261
+ },
262
+ }
263
+
264
+
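# Sketch, assuming a hypothetical `my_tool` whose `inputs` attribute is
# {"query": {"type": "string"}, "extra": {"type": "any", "nullable": True}}:
# "any" is widened to "string", and only non-nullable inputs become required.
schema = get_tool_json_schema(my_tool)
schema["function"]["parameters"]["required"]                     # ['query']
schema["function"]["parameters"]["properties"]["extra"]["type"]  # 'string'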
265
+ def remove_stop_sequences(content: str, stop_sequences: list[str]) -> str:
266
+ for stop_seq in stop_sequences:
267
+ if content[-len(stop_seq) :] == stop_seq:
268
+ content = content[: -len(stop_seq)]
269
+ return content
270
+
271
+
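# Sketch: only a stop sequence sitting at the very end of the content is trimmed.
remove_stop_sequences("final answer<end_code>", ["<end_code>"])  # 'final answer'
remove_stop_sequences("no stop here", ["<end_code>"])            # 'no stop here'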
272
+ def get_clean_message_list(
273
+ message_list: list[ChatMessage],
274
+ role_conversions: dict[MessageRole, MessageRole] | dict[str, str] = {},
275
+ convert_images_to_image_urls: bool = False,
276
+ flatten_messages_as_text: bool = False,
277
+ ) -> list[dict[str, Any]]:
278
+ """
279
+ Creates a list of messages to give as input to the LLM. These messages are dictionaries, compatible with the chat template format used by transformers LLMs.
281
+ Subsequent messages with the same role will be concatenated into a single message.
281
+
282
+ Args:
283
+ message_list (`list[dict[str, str]]`): List of chat messages.
284
+ role_conversions (`dict[MessageRole, MessageRole]`, *optional* ): Mapping to convert roles.
285
+ convert_images_to_image_urls (`bool`, default `False`): Whether to convert images to image URLs.
286
+ flatten_messages_as_text (`bool`, default `False`): Whether to flatten messages as text.
287
+ """
288
+ output_message_list: list[dict[str, Any]] = []
289
+ message_list = deepcopy(message_list) # Avoid modifying the original list
290
+ for message in message_list:
291
+ role = message.role
292
+ if role not in MessageRole.roles():
293
+ raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
294
+
295
+ if role in role_conversions:
296
+ message.role = role_conversions[role] # type: ignore
297
+ # encode images if needed
298
+ if isinstance(message.content, list):
299
+ for element in message.content:
300
+ assert isinstance(element, dict), "Error: this element should be a dict:" + str(element)
301
+ if element["type"] == "image":
302
+ assert not flatten_messages_as_text, f"Cannot use images with {flatten_messages_as_text=}"
303
+ if convert_images_to_image_urls:
304
+ element.update(
305
+ {
306
+ "type": "image_url",
307
+ "image_url": {"url": make_image_url(encode_image_base64(element.pop("image")))},
308
+ }
309
+ )
310
+ else:
311
+ element["image"] = encode_image_base64(element["image"])
312
+
313
+ if len(output_message_list) > 0 and message.role == output_message_list[-1]["role"]:
314
+ assert isinstance(message.content, list), "Error: wrong content:" + str(message.content)
315
+ if flatten_messages_as_text:
316
+ output_message_list[-1]["content"] += "\n" + message.content[0]["text"]
317
+ else:
318
+ for el in message.content:
319
+ if el["type"] == "text" and output_message_list[-1]["content"][-1]["type"] == "text":
320
+ # Merge consecutive text messages rather than creating new ones
321
+ output_message_list[-1]["content"][-1]["text"] += "\n" + el["text"]
322
+ else:
323
+ output_message_list[-1]["content"].append(el)
324
+ else:
325
+ if flatten_messages_as_text:
326
+ content = message.content[0]["text"]
327
+ else:
328
+ content = message.content
329
+ output_message_list.append(
330
+ {
331
+ "role": message.role,
332
+ "content": content,
333
+ }
334
+ )
335
+ return output_message_list
336
+
337
+
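# Sketch: with `tool_role_conversions`, a TOOL_RESPONSE message becomes a USER
# message, and because it then follows another USER message the two text chunks
# are merged into one message.
msgs = [
    ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hi"}]),
    ChatMessage(role=MessageRole.TOOL_RESPONSE, content=[{"type": "text", "text": "Observation: ok"}]),
]
get_clean_message_list(msgs, role_conversions=tool_role_conversions)
# -> a single "user" message: [{'type': 'text', 'text': 'Hi\nObservation: ok'}]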
338
+ def get_tool_call_from_text(text: str, tool_name_key: str, tool_arguments_key: str) -> ChatMessageToolCall:
339
+ tool_call_dictionary, _ = parse_json_blob(text)
340
+ try:
341
+ tool_name = tool_call_dictionary[tool_name_key]
342
+ except Exception as e:
343
+ raise ValueError(
344
+ f"Key {tool_name_key=} not found in the generated tool call. Got keys: {list(tool_call_dictionary.keys())} instead"
345
+ ) from e
346
+ tool_arguments = tool_call_dictionary.get(tool_arguments_key, None)
347
+ if isinstance(tool_arguments, str):
348
+ tool_arguments = parse_json_if_needed(tool_arguments)
349
+ return ChatMessageToolCall(
350
+ id=str(uuid.uuid4()),
351
+ type="function",
352
+ function=ChatMessageToolCallFunction(name=tool_name, arguments=tool_arguments),
353
+ )
354
+
355
+
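# Sketch, assuming `parse_json_blob` (imported from .utils) extracts the JSON
# object embedded in the model's free-form text:
text = 'Action: {"name": "web_search", "arguments": {"query": "smolagents"}}'
call = get_tool_call_from_text(text, tool_name_key="name", tool_arguments_key="arguments")
call.function.name       # 'web_search'
call.function.arguments  # {'query': 'smolagents'}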
356
+ def supports_stop_parameter(model_id: str) -> bool:
357
+ """
358
+ Check if the model supports the `stop` parameter.
359
+
360
+ Not supported with reasoning models openai/o3 and openai/o4-mini (and their versioned variants).
361
+
362
+ Args:
363
+ model_id (`str`): Model identifier (e.g. "openai/o3", "o4-mini-2025-04-16")
364
+
365
+ Returns:
366
+ bool: True if the model supports the stop parameter, False otherwise
367
+ """
368
+ model_name = model_id.split("/")[-1]
369
+ # o3 and o4-mini (including versioned variants like o3-2025-04-16) don't support the stop parameter
370
+ pattern = r"^(o3[-\d]*|o4-mini[-\d]*)$"
371
+ return not re.match(pattern, model_name)
372
+
373
+
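# Sketch of the matching behavior:
supports_stop_parameter("openai/gpt-4o")  # True
supports_stop_parameter("openai/o3")      # False
supports_stop_parameter("o3-2025-04-16")  # False
supports_stop_parameter("azure/o4-mini")  # False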
374
+ class Model:
375
+ def __init__(
376
+ self,
377
+ flatten_messages_as_text: bool = False,
378
+ tool_name_key: str = "name",
379
+ tool_arguments_key: str = "arguments",
380
+ model_id: str | None = None,
381
+ **kwargs,
382
+ ):
383
+ self.flatten_messages_as_text = flatten_messages_as_text
384
+ self.tool_name_key = tool_name_key
385
+ self.tool_arguments_key = tool_arguments_key
386
+ self.kwargs = kwargs
387
+ self._last_input_token_count: int | None = None
388
+ self._last_output_token_count: int | None = None
389
+ self.model_id: str | None = model_id
390
+
391
+ @property
392
+ def last_input_token_count(self) -> int | None:
393
+ warnings.warn(
394
+ "Attribute last_input_token_count is deprecated and will be removed in version 1.20. "
395
+ "Please use TokenUsage.input_tokens instead.",
396
+ FutureWarning,
397
+ )
398
+ return self._last_input_token_count
399
+
400
+ @property
401
+ def last_output_token_count(self) -> int | None:
402
+ warnings.warn(
403
+ "Attribute last_output_token_count is deprecated and will be removed in version 1.20. "
404
+ "Please use TokenUsage.output_tokens instead.",
405
+ FutureWarning,
406
+ )
407
+ return self._last_output_token_count
408
+
409
+ def _prepare_completion_kwargs(
410
+ self,
411
+ messages: list[ChatMessage],
412
+ stop_sequences: list[str] | None = None,
413
+ response_format: dict[str, str] | None = None,
414
+ tools_to_call_from: list[Tool] | None = None,
415
+ custom_role_conversions: dict[str, str] | None = None,
416
+ convert_images_to_image_urls: bool = False,
417
+ tool_choice: str | dict | None = "required", # Configurable tool_choice parameter
418
+ **kwargs,
419
+ ) -> dict[str, Any]:
420
+ """
421
+ Prepare parameters required for model invocation, handling parameter priorities.
422
+
423
+ Parameter priority from high to low:
424
+ 1. Explicitly passed kwargs
425
+ 2. Specific parameters (stop_sequences, response_format, etc.)
426
+ 3. Default values in self.kwargs
427
+ """
428
+ # Clean and standardize the message list
429
+ flatten_messages_as_text = kwargs.pop("flatten_messages_as_text", self.flatten_messages_as_text)
430
+ messages_as_dicts = get_clean_message_list(
431
+ messages,
432
+ role_conversions=custom_role_conversions or tool_role_conversions,
433
+ convert_images_to_image_urls=convert_images_to_image_urls,
434
+ flatten_messages_as_text=flatten_messages_as_text,
435
+ )
436
+ # Use self.kwargs as the base configuration
437
+ completion_kwargs = {
438
+ **self.kwargs,
439
+ "messages": messages_as_dicts,
440
+ }
441
+
442
+ # Handle specific parameters
443
+ if stop_sequences is not None:
444
+ # Some models do not support stop parameter
445
+ if supports_stop_parameter(self.model_id or ""):
446
+ completion_kwargs["stop"] = stop_sequences
447
+ if response_format is not None:
448
+ completion_kwargs["response_format"] = response_format
449
+
450
+ # Handle tools parameter
451
+ if tools_to_call_from:
452
+ tools_config = {
453
+ "tools": [get_tool_json_schema(tool) for tool in tools_to_call_from],
454
+ }
455
+ if tool_choice is not None:
456
+ tools_config["tool_choice"] = tool_choice
457
+ completion_kwargs.update(tools_config)
458
+
459
+ # Finally, use the passed-in kwargs to override all settings
460
+ completion_kwargs.update(kwargs)
461
+
462
+ return completion_kwargs
463
+
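# Sketch of the priority order, using `temperature` as an example and assuming
# `msgs` is a list of ChatMessage: call-time kwargs beat the specific parameters,
# which beat the defaults stored in `self.kwargs`.
#   m = Model(model_id="some-model", temperature=0.7)
#   m._prepare_completion_kwargs(messages=msgs)["temperature"]                   # 0.7
#   m._prepare_completion_kwargs(messages=msgs, temperature=0.2)["temperature"]  # 0.2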
464
+ def generate(
465
+ self,
466
+ messages: list[ChatMessage],
467
+ stop_sequences: list[str] | None = None,
468
+ response_format: dict[str, str] | None = None,
469
+ tools_to_call_from: list[Tool] | None = None,
470
+ **kwargs,
471
+ ) -> ChatMessage:
472
+ """Process the input messages and return the model's response.
473
+
474
+ Parameters:
475
+ messages (`list[dict[str, str | list[dict]]] | list[ChatMessage]`):
476
+ A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`.
477
+ stop_sequences (`List[str]`, *optional*):
478
+ A list of strings that will stop the generation if encountered in the model's output.
479
+ response_format (`dict[str, str]`, *optional*):
480
+ The response format to use in the model's response.
481
+ tools_to_call_from (`List[Tool]`, *optional*):
482
+ A list of tools that the model can use to generate responses.
483
+ **kwargs:
484
+ Additional keyword arguments to be passed to the underlying model.
485
+
486
+ Returns:
487
+ `ChatMessage`: A chat message object containing the model's response.
488
+ """
489
+ raise NotImplementedError("This method must be implemented in child classes")
490
+
491
+ def __call__(self, *args, **kwargs):
492
+ return self.generate(*args, **kwargs)
493
+
494
+ def parse_tool_calls(self, message: ChatMessage) -> ChatMessage:
495
+ """Sometimes APIs do not return the tool call as a specific object, so we need to parse it."""
496
+ message.role = MessageRole.ASSISTANT # Overwrite role if needed
497
+ if not message.tool_calls:
498
+ assert message.content is not None, "Message contains no content and no tool calls"
499
+ message.tool_calls = [
500
+ get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_key)
501
+ ]
502
+ assert len(message.tool_calls) > 0, "No tool call was found in the model output"
503
+ for tool_call in message.tool_calls:
504
+ tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
505
+ return message
506
+
507
+ def to_dict(self) -> dict:
508
+ """
509
+ Converts the model into a JSON-compatible dictionary.
510
+ """
511
+ model_dictionary = {
512
+ **self.kwargs,
513
+ "model_id": self.model_id,
514
+ }
515
+ for attribute in [
516
+ "custom_role_conversion",
517
+ "temperature",
518
+ "max_tokens",
519
+ "provider",
520
+ "timeout",
521
+ "api_base",
522
+ "torch_dtype",
523
+ "device_map",
524
+ "organization",
525
+ "project",
526
+ "azure_endpoint",
527
+ ]:
528
+ if hasattr(self, attribute):
529
+ model_dictionary[attribute] = getattr(self, attribute)
530
+
531
+ dangerous_attributes = ["token", "api_key"]
532
+ for attribute_name in dangerous_attributes:
533
+ if hasattr(self, attribute_name):
534
+ print(
535
+ f"For security reasons, we do not export the `{attribute_name}` attribute of your model. Please export it manually."
536
+ )
537
+ return model_dictionary
538
+
539
+ @classmethod
540
+ def from_dict(cls, model_dictionary: dict[str, Any]) -> "Model":
541
+ return cls(**{k: v for k, v in model_dictionary.items()})
542
+
543
+
544
+ class VLLMModel(Model):
545
+ """Model to use [vLLM](https://docs.vllm.ai/) for fast LLM inference and serving.
546
+
547
+ Parameters:
548
+ model_id (`str`):
549
+ The Hugging Face model ID to be used for inference.
550
+ This can be a path or model identifier from the Hugging Face model hub.
551
+ model_kwargs (`dict[str, Any]`, *optional*):
552
+ Additional keyword arguments to pass to the vLLM model (like revision, max_model_len, etc.).
553
+ """
554
+
555
+ def __init__(
556
+ self,
557
+ model_id,
558
+ model_kwargs: dict[str, Any] | None = None,
559
+ **kwargs,
560
+ ):
561
+ if not _is_package_available("vllm"):
562
+ raise ModuleNotFoundError("Please install 'vllm' extra to use VLLMModel: `pip install 'smolagents[vllm]'`")
563
+
564
+ from vllm import LLM # type: ignore
565
+ from vllm.transformers_utils.tokenizer import get_tokenizer # type: ignore
566
+
567
+ self.model_kwargs = model_kwargs or {}
568
+ super().__init__(**kwargs)
569
+ self.model_id = model_id
570
+ self.model = LLM(model=model_id, **self.model_kwargs)
571
+ assert self.model is not None
572
+ self.tokenizer = get_tokenizer(model_id)
573
+ self._is_vlm = False # VLLMModel does not support vision models yet.
574
+
575
+ def cleanup(self):
576
+ import gc
577
+
578
+ import torch
579
+ from vllm.distributed.parallel_state import ( # type: ignore
580
+ destroy_distributed_environment,
581
+ destroy_model_parallel,
582
+ )
583
+
584
+ destroy_model_parallel()
585
+ if self.model is not None:
586
+ # taken from https://github.com/vllm-project/vllm/issues/1908#issuecomment-2076870351
587
+ del self.model.llm_engine.model_executor.driver_worker
588
+ gc.collect()
589
+ destroy_distributed_environment()
590
+ torch.cuda.empty_cache()
591
+
592
+ def generate(
593
+ self,
594
+ messages: list[ChatMessage],
595
+ stop_sequences: list[str] | None = None,
596
+ response_format: dict[str, str] | None = None,
597
+ tools_to_call_from: list[Tool] | None = None,
598
+ **kwargs,
599
+ ) -> ChatMessage:
600
+ from vllm import SamplingParams # type: ignore
601
+
602
+ completion_kwargs = self._prepare_completion_kwargs(
603
+ messages=messages,
604
+ flatten_messages_as_text=(not self._is_vlm),
605
+ stop_sequences=stop_sequences,
606
+ tools_to_call_from=tools_to_call_from,
607
+ **kwargs,
608
+ )
609
+ # Override the OpenAI schema for VLLM compatibility
610
+ guided_options_request = {"guided_json": response_format["json_schema"]["schema"]} if response_format else None
611
+
612
+ messages = completion_kwargs.pop("messages")
613
+ prepared_stop_sequences = completion_kwargs.pop("stop", [])
614
+ tools = completion_kwargs.pop("tools", None)
615
+ completion_kwargs.pop("tool_choice", None)
616
+
617
+ prompt = self.tokenizer.apply_chat_template(
618
+ messages,
619
+ tools=tools,
620
+ add_generation_prompt=True,
621
+ tokenize=False,
622
+ )
623
+
624
+ sampling_params = SamplingParams(
625
+ n=kwargs.get("n", 1),
626
+ temperature=kwargs.get("temperature", 0.0),
627
+ max_tokens=kwargs.get("max_tokens", 2048),
628
+ stop=prepared_stop_sequences,
629
+ )
630
+
631
+ out = self.model.generate(
632
+ prompt,
633
+ sampling_params=sampling_params,
634
+ guided_options_request=guided_options_request,
635
+ )
636
+
637
+ output_text = out[0].outputs[0].text
638
+ self._last_input_token_count = len(out[0].prompt_token_ids)
639
+ self._last_output_token_count = len(out[0].outputs[0].token_ids)
640
+ return ChatMessage(
641
+ role=MessageRole.ASSISTANT,
642
+ content=output_text,
643
+ raw={"out": output_text, "completion_kwargs": completion_kwargs},
644
+ token_usage=TokenUsage(
645
+ input_tokens=len(out[0].prompt_token_ids),
646
+ output_tokens=len(out[0].outputs[0].token_ids),
647
+ ),
648
+ )
649
+
650
+
651
+ class MLXModel(Model):
652
+ """A class to interact with models loaded using MLX on Apple silicon.
653
+
654
+ > [!TIP]
655
+ > You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not already installed.
656
+
657
+ Parameters:
658
+ model_id (str):
659
+ The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
660
+ tool_name_key (str):
661
+ The key, which can usually be found in the model's chat template, for retrieving a tool name.
662
+ tool_arguments_key (str):
663
+ The key, which can usually be found in the model's chat template, for retrieving tool arguments.
664
+ trust_remote_code (bool, default `False`):
665
+ Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
666
+ load_kwargs (dict[str, Any], *optional*):
667
+ Additional keyword arguments to pass to the `mlx.lm.load` method when loading the model and tokenizer.
668
+ apply_chat_template_kwargs (dict, *optional*):
669
+ Additional keyword arguments to pass to the `apply_chat_template` method of the tokenizer.
670
+ kwargs (dict, *optional*):
671
+ Any additional keyword arguments that you want to use in model.generate(), for instance `max_tokens`.
672
+
673
+ Example:
674
+ ```python
675
+ >>> engine = MLXModel(
676
+ ... model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
677
+ ... max_tokens=10000,
678
+ ... )
679
+ >>> messages = [
680
+ ... {
681
+ ... "role": "user",
682
+ ... "content": "Explain quantum mechanics in simple terms."
683
+ ... }
684
+ ... ]
685
+ >>> response = engine(messages, stop_sequences=["END"])
686
+ >>> print(response)
687
+ "Quantum mechanics is the branch of physics that studies..."
688
+ ```
689
+ """
690
+
691
+ def __init__(
692
+ self,
693
+ model_id: str,
694
+ trust_remote_code: bool = False,
695
+ load_kwargs: dict[str, Any] | None = None,
696
+ apply_chat_template_kwargs: dict[str, Any] | None = None,
697
+ **kwargs,
698
+ ):
699
+ if not _is_package_available("mlx_lm"):
700
+ raise ModuleNotFoundError(
701
+ "Please install 'mlx-lm' extra to use 'MLXModel': `pip install 'smolagents[mlx-lm]'`"
702
+ )
703
+ import mlx_lm
704
+
705
+ self.load_kwargs = load_kwargs or {}
706
+ self.load_kwargs.setdefault("tokenizer_config", {}).setdefault("trust_remote_code", trust_remote_code)
707
+ self.apply_chat_template_kwargs = apply_chat_template_kwargs or {}
708
+ self.apply_chat_template_kwargs.setdefault("add_generation_prompt", True)
709
+ # mlx-lm doesn't support vision models: flatten_messages_as_text=True
710
+ super().__init__(model_id=model_id, flatten_messages_as_text=True, **kwargs)
711
+
712
+ self.model, self.tokenizer = mlx_lm.load(self.model_id, **self.load_kwargs)
713
+ self.stream_generate = mlx_lm.stream_generate
714
+ self.is_vlm = False # mlx-lm doesn't support vision models
715
+
716
+ def generate(
717
+ self,
718
+ messages: list[ChatMessage],
719
+ stop_sequences: list[str] | None = None,
720
+ response_format: dict[str, str] | None = None,
721
+ tools_to_call_from: list[Tool] | None = None,
722
+ **kwargs,
723
+ ) -> ChatMessage:
724
+ if response_format is not None:
725
+ raise ValueError("MLX does not support structured outputs.")
726
+ completion_kwargs = self._prepare_completion_kwargs(
727
+ messages=messages,
728
+ stop_sequences=stop_sequences,
729
+ tools_to_call_from=tools_to_call_from,
730
+ **kwargs,
731
+ )
732
+ messages = completion_kwargs.pop("messages")
733
+ stops = completion_kwargs.pop("stop", [])
734
+ tools = completion_kwargs.pop("tools", None)
735
+ completion_kwargs.pop("tool_choice", None)
736
+
737
+ prompt_ids = self.tokenizer.apply_chat_template(messages, tools=tools, **self.apply_chat_template_kwargs)
738
+
739
+ output_tokens = 0
740
+ text = ""
741
+ for response in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
742
+ output_tokens += 1
743
+ text += response.text
744
+ if any((stop_index := text.rfind(stop)) != -1 for stop in stops):
745
+ text = text[:stop_index]
746
+ break
747
+
748
+ self._last_input_token_count = len(prompt_ids)
749
+ self._last_output_token_count = output_tokens
750
+ return ChatMessage(
751
+ role=MessageRole.ASSISTANT,
752
+ content=text,
753
+ raw={"out": text, "completion_kwargs": completion_kwargs},
754
+ token_usage=TokenUsage(
755
+ input_tokens=len(prompt_ids),
756
+ output_tokens=output_tokens,
757
+ ),
758
+ )
759
+
760
+
761
+ class TransformersModel(Model):
762
+ """A class that uses Hugging Face's Transformers library for language model interaction.
763
+
764
+ This model allows you to load and use Hugging Face's models locally using the Transformers library. It supports features like stop sequences and grammar customization.
765
+
766
+ > [!TIP]
767
+ > You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if they're not already installed.
768
+
769
+ Parameters:
770
+ model_id (`str`):
771
+ The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
772
+ For example, `"Qwen/Qwen2.5-Coder-32B-Instruct"`.
773
+ device_map (`str`, *optional*):
774
+ The device_map to initialize your model with.
775
+ torch_dtype (`str`, *optional*):
776
+ The torch_dtype to initialize your model with.
777
+ trust_remote_code (bool, default `False`):
778
+ Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
779
+ **kwargs:
781
+ Additional keyword arguments to pass to `model.generate()`, for instance `max_new_tokens` or `device`.
783
+ Raises:
784
+ ValueError:
785
+ If the model name is not provided.
786
+
787
+ Example:
788
+ ```python
789
+ >>> engine = TransformersModel(
790
+ ... model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
791
+ ... device="cuda",
792
+ ... max_new_tokens=5000,
793
+ ... )
794
+ >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
795
+ >>> response = engine(messages, stop_sequences=["END"])
796
+ >>> print(response)
797
+ "Quantum mechanics is the branch of physics that studies..."
798
+ ```
799
+ """
800
+
801
+ def __init__(
802
+ self,
803
+ model_id: str | None = None,
804
+ device_map: str | None = None,
805
+ torch_dtype: str | None = None,
806
+ trust_remote_code: bool = False,
807
+ **kwargs,
808
+ ):
809
+ try:
810
+ import torch
811
+ from transformers import (
812
+ AutoModelForCausalLM,
813
+ AutoModelForImageTextToText,
814
+ AutoProcessor,
815
+ AutoTokenizer,
816
+ TextIteratorStreamer,
817
+ )
818
+ except ModuleNotFoundError:
819
+ raise ModuleNotFoundError(
820
+ "Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`"
821
+ )
822
+
823
+ if not model_id:
824
+ warnings.warn(
825
+ "The 'model_id' parameter will be required in version 2.0.0. "
826
+ "Please update your code to pass this parameter to avoid future errors. "
827
+ "For now, it defaults to 'HuggingFaceTB/SmolLM2-1.7B-Instruct'.",
828
+ FutureWarning,
829
+ )
830
+ model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
831
+
832
+ default_max_tokens = 4096
833
+ max_new_tokens = kwargs.get("max_new_tokens") or kwargs.get("max_tokens")
834
+ if not max_new_tokens:
835
+ kwargs["max_new_tokens"] = default_max_tokens
836
+ logger.warning(
837
+ f"`max_new_tokens` not provided, using this default value for `max_new_tokens`: {default_max_tokens}"
838
+ )
839
+
840
+ if device_map is None:
841
+ device_map = "cuda" if torch.cuda.is_available() else "cpu"
842
+ logger.info(f"Using device: {device_map}")
843
+ self._is_vlm = False
844
+ try:
845
+ self.model = AutoModelForImageTextToText.from_pretrained(
846
+ model_id,
847
+ device_map=device_map,
848
+ torch_dtype=torch_dtype,
849
+ trust_remote_code=trust_remote_code,
850
+ )
851
+ self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=trust_remote_code)
852
+ self._is_vlm = True
853
+ self.streamer = TextIteratorStreamer(self.processor.tokenizer, skip_prompt=True, skip_special_tokens=True) # type: ignore
854
+
855
+ except ValueError as e:
856
+ if "Unrecognized configuration class" in str(e):
857
+ self.model = AutoModelForCausalLM.from_pretrained(
858
+ model_id,
859
+ device_map=device_map,
860
+ torch_dtype=torch_dtype,
861
+ trust_remote_code=trust_remote_code,
862
+ )
863
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
864
+ self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) # type: ignore
865
+ else:
866
+ raise e
867
+ except Exception as e:
868
+ raise ValueError(f"Failed to load tokenizer and model for {model_id=}: {e}") from e
869
+ super().__init__(flatten_messages_as_text=not self._is_vlm, model_id=model_id, **kwargs)
870
+
871
+ def make_stopping_criteria(self, stop_sequences: list[str], tokenizer) -> "StoppingCriteriaList":
872
+ from transformers import StoppingCriteria, StoppingCriteriaList
873
+
874
+ class StopOnStrings(StoppingCriteria):
875
+ def __init__(self, stop_strings: list[str], tokenizer):
876
+ self.stop_strings = stop_strings
877
+ self.tokenizer = tokenizer
878
+ self.stream = ""
879
+
880
+ def reset(self):
881
+ self.stream = ""
882
+
883
+ def __call__(self, input_ids, scores, **kwargs):
884
+ generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
885
+ self.stream += generated
886
+ if any(self.stream.endswith(stop_string) for stop_string in self.stop_strings):
887
+ return True
888
+ return False
889
+
890
+ return StoppingCriteriaList([StopOnStrings(stop_sequences, tokenizer)])
891
+
892
+ def _prepare_completion_args(
893
+ self,
894
+ messages: list[ChatMessage],
895
+ stop_sequences: list[str] | None = None,
896
+ tools_to_call_from: list[Tool] | None = None,
897
+ **kwargs,
898
+ ) -> dict[str, Any]:
899
+ completion_kwargs = self._prepare_completion_kwargs(
900
+ messages=messages,
901
+ stop_sequences=stop_sequences,
902
+ **kwargs,
903
+ )
904
+
905
+ messages = completion_kwargs.pop("messages")
906
+ stop_sequences = completion_kwargs.pop("stop", None)
907
+ tools = completion_kwargs.pop("tools", None)
908
+
909
+ max_new_tokens = (
910
+ kwargs.get("max_new_tokens")
911
+ or kwargs.get("max_tokens")
912
+ or self.kwargs.get("max_new_tokens")
913
+ or self.kwargs.get("max_tokens")
914
+ or 1024
915
+ )
916
+ prompt_tensor = (self.processor if hasattr(self, "processor") else self.tokenizer).apply_chat_template(
917
+ messages,
918
+ tools=tools,
919
+ return_tensors="pt",
920
+ add_generation_prompt=True,
921
+ tokenize=True,
922
+ return_dict=True,
923
+ )
924
+ prompt_tensor = prompt_tensor.to(self.model.device) # type: ignore
925
+ if hasattr(prompt_tensor, "input_ids"):
926
+ prompt_tensor = prompt_tensor["input_ids"]
927
+
928
+ model_tokenizer = self.processor.tokenizer if hasattr(self, "processor") else self.tokenizer
929
+ stopping_criteria = (
930
+ self.make_stopping_criteria(stop_sequences, tokenizer=model_tokenizer) if stop_sequences else None
931
+ )
932
+ completion_kwargs["max_new_tokens"] = max_new_tokens
933
+ return dict(
934
+ inputs=prompt_tensor,
935
+ use_cache=True,
936
+ stopping_criteria=stopping_criteria,
937
+ **completion_kwargs,
938
+ )
939
+
940
+ def generate(
941
+ self,
942
+ messages: list[ChatMessage],
943
+ stop_sequences: list[str] | None = None,
944
+ response_format: dict[str, str] | None = None,
945
+ tools_to_call_from: list[Tool] | None = None,
946
+ **kwargs,
947
+ ) -> ChatMessage:
948
+ if response_format is not None:
949
+ raise ValueError("Transformers does not support structured outputs, use VLLMModel for this.")
950
+ generation_kwargs = self._prepare_completion_args(
951
+ messages=messages,
952
+ stop_sequences=stop_sequences,
953
+ tools_to_call_from=tools_to_call_from,
954
+ **kwargs,
955
+ )
956
+ count_prompt_tokens = generation_kwargs["inputs"].shape[1] # type: ignore
957
+ out = self.model.generate(
958
+ **generation_kwargs,
959
+ )
960
+ generated_tokens = out[0, count_prompt_tokens:]
961
+ if hasattr(self, "processor"):
962
+ output_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
963
+ else:
964
+ output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
965
+
966
+ if stop_sequences is not None:
967
+ output_text = remove_stop_sequences(output_text, stop_sequences)
968
+
969
+ self._last_input_token_count = count_prompt_tokens
970
+ self._last_output_token_count = len(generated_tokens)
971
+ return ChatMessage(
972
+ role=MessageRole.ASSISTANT,
973
+ content=output_text,
974
+ raw={
975
+ "out": output_text,
976
+ "completion_kwargs": {key: value for key, value in generation_kwargs.items() if key != "inputs"},
977
+ },
978
+ token_usage=TokenUsage(
979
+ input_tokens=count_prompt_tokens,
980
+ output_tokens=len(generated_tokens),
981
+ ),
982
+ )
983
+
984
+ def generate_stream(
985
+ self,
986
+ messages: list[ChatMessage],
987
+ stop_sequences: list[str] | None = None,
988
+ response_format: dict[str, str] | None = None,
989
+ tools_to_call_from: list[Tool] | None = None,
990
+ **kwargs,
991
+ ) -> Generator[ChatMessageStreamDelta]:
992
+ if response_format is not None:
993
+ raise ValueError("Transformers does not support structured outputs, use VLLMModel for this.")
994
+ generation_kwargs = self._prepare_completion_args(
995
+ messages=messages,
996
+ stop_sequences=stop_sequences,
997
+ response_format=response_format,
998
+ tools_to_call_from=tools_to_call_from,
999
+ **kwargs,
1000
+ )
1001
+ count_prompt_tokens = generation_kwargs["inputs"].shape[1] # type: ignore
1002
+
1003
+ thread = Thread(target=self.model.generate, kwargs={"streamer": self.streamer, **generation_kwargs})
1004
+ thread.start()
1005
+
1006
+ # Generate with streaming
1007
+ for new_text in self.streamer:
1008
+ self._last_input_token_count = count_prompt_tokens
1009
+ self._last_output_token_count = 1
1010
+ yield ChatMessageStreamDelta(
1011
+ content=new_text,
1012
+ tool_calls=None,
1013
+ token_usage=TokenUsage(input_tokens=count_prompt_tokens, output_tokens=1),
1014
+ )
1015
+ thread.join()
1016
+
1017
+
1018
+ class ApiModel(Model):
1019
+ """
1020
+ Base class for API-based language models.
1021
+
1022
+ This class serves as a foundation for implementing models that interact with
1023
+ external APIs. It handles the common functionality for managing model IDs,
1024
+ custom role mappings, and API client connections.
1025
+
1026
+ Parameters:
1027
+ model_id (`str`):
1028
+ The identifier for the model to be used with the API.
1029
+ custom_role_conversions (`dict[str, str]`, *optional*):
1030
+ Mapping to convert between internal role names and API-specific role names. Defaults to None.
1031
+ client (`Any`, *optional*):
1032
+ Pre-configured API client instance. If not provided, a default client will be created. Defaults to None.
1033
+ **kwargs: Additional keyword arguments to pass to the parent class.
1034
+ """
1035
+
1036
+ def __init__(
1037
+ self, model_id: str, custom_role_conversions: dict[str, str] | None = None, client: Any | None = None, **kwargs
1038
+ ):
1039
+ super().__init__(model_id=model_id, **kwargs)
1040
+ self.custom_role_conversions = custom_role_conversions or {}
1041
+ self.client = client or self.create_client()
1042
+
1043
+ def create_client(self):
1044
+ """Create the API client for the specific service."""
1045
+ raise NotImplementedError("Subclasses must implement this method to create a client")
1046
+
1047
+
1048
+ class LiteLLMModel(ApiModel):
1049
+ """Model to use [LiteLLM Python SDK](https://docs.litellm.ai/docs/#litellm-python-sdk) to access hundreds of LLMs.
1050
+
1051
+ Parameters:
1052
+ model_id (`str`):
1053
+ The model identifier to use on the server (e.g. "gpt-3.5-turbo").
1054
+ api_base (`str`, *optional*):
1055
+ The base URL of the provider API to call the model.
1056
+ api_key (`str`, *optional*):
1057
+ The API key to use for authentication.
1058
+ custom_role_conversions (`dict[str, str]`, *optional*):
1059
+ Custom role conversion mapping to convert message roles into others.
1060
+ Useful for specific models that do not support specific message roles like "system".
1061
+ flatten_messages_as_text (`bool`, *optional*): Whether to flatten messages as text.
1062
+ Defaults to `True` for models that start with "ollama", "groq", "cerebras".
1063
+ **kwargs:
1064
+ Additional keyword arguments to pass to the OpenAI API.
1065
+ """
1066
+
1067
+ def __init__(
1068
+ self,
1069
+ model_id: str | None = None,
1070
+ api_base: str | None = None,
1071
+ api_key: str | None = None,
1072
+ custom_role_conversions: dict[str, str] | None = None,
1073
+ flatten_messages_as_text: bool | None = None,
1074
+ **kwargs,
1075
+ ):
1076
+ if not model_id:
1077
+ warnings.warn(
1078
+ "The 'model_id' parameter will be required in version 2.0.0. "
1079
+ "Please update your code to pass this parameter to avoid future errors. "
1080
+ "For now, it defaults to 'anthropic/claude-3-5-sonnet-20240620'.",
1081
+ FutureWarning,
1082
+ )
1083
+ model_id = "anthropic/claude-3-5-sonnet-20240620"
1084
+ self.api_base = api_base
1085
+ self.api_key = api_key
1086
+ flatten_messages_as_text = (
1087
+ flatten_messages_as_text
1088
+ if flatten_messages_as_text is not None
1089
+ else model_id.startswith(("ollama", "groq", "cerebras"))
1090
+ )
1091
+ super().__init__(
1092
+ model_id=model_id,
1093
+ custom_role_conversions=custom_role_conversions,
1094
+ flatten_messages_as_text=flatten_messages_as_text,
1095
+ **kwargs,
1096
+ )
1097
+
1098
+ def create_client(self):
1099
+ """Create the LiteLLM client."""
1100
+ try:
1101
+ import litellm
1102
+ except ModuleNotFoundError as e:
1103
+ raise ModuleNotFoundError(
1104
+ "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
1105
+ ) from e
1106
+
1107
+ return litellm
1108
+
1109
+ def generate(
1110
+ self,
1111
+ messages: list[ChatMessage],
1112
+ stop_sequences: list[str] | None = None,
1113
+ response_format: dict[str, str] | None = None,
1114
+ tools_to_call_from: list[Tool] | None = None,
1115
+ **kwargs,
1116
+ ) -> ChatMessage:
1117
+ completion_kwargs = self._prepare_completion_kwargs(
1118
+ messages=messages,
1119
+ stop_sequences=stop_sequences,
1120
+ response_format=response_format,
1121
+ tools_to_call_from=tools_to_call_from,
1122
+ model=self.model_id,
1123
+ api_base=self.api_base,
1124
+ api_key=self.api_key,
1125
+ convert_images_to_image_urls=True,
1126
+ custom_role_conversions=self.custom_role_conversions,
1127
+ **kwargs,
1128
+ )
1129
+
1130
+ response = self.client.completion(**completion_kwargs)
1131
+
1132
+ self._last_input_token_count = response.usage.prompt_tokens
1133
+ self._last_output_token_count = response.usage.completion_tokens
1134
+ return ChatMessage.from_dict(
1135
+ response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}),
1136
+ raw=response,
1137
+ token_usage=TokenUsage(
1138
+ input_tokens=response.usage.prompt_tokens,
1139
+ output_tokens=response.usage.completion_tokens,
1140
+ ),
1141
+ )
1142
+
1143
+ def generate_stream(
1144
+ self,
1145
+ messages: list[ChatMessage],
1146
+ stop_sequences: list[str] | None = None,
1147
+ response_format: dict[str, str] | None = None,
1148
+ tools_to_call_from: list[Tool] | None = None,
1149
+ **kwargs,
1150
+ ) -> Generator[ChatMessageStreamDelta]:
1151
+ completion_kwargs = self._prepare_completion_kwargs(
1152
+ messages=messages,
1153
+ stop_sequences=stop_sequences,
1154
+ response_format=response_format,
1155
+ tools_to_call_from=tools_to_call_from,
1156
+ model=self.model_id,
1157
+ api_base=self.api_base,
1158
+ api_key=self.api_key,
1159
+ custom_role_conversions=self.custom_role_conversions,
1160
+ convert_images_to_image_urls=True,
1161
+ **kwargs,
1162
+ )
1163
+ for event in self.client.completion(**completion_kwargs, stream=True, stream_options={"include_usage": True}):
1164
+ if getattr(event, "usage", None):
1165
+ self._last_input_token_count = event.usage.prompt_tokens
1166
+ self._last_output_token_count = event.usage.completion_tokens
1167
+ yield ChatMessageStreamDelta(
1168
+ content="",
1169
+ token_usage=TokenUsage(
1170
+ input_tokens=event.usage.prompt_tokens,
1171
+ output_tokens=event.usage.completion_tokens,
1172
+ ),
1173
+ )
1174
+ if event.choices:
1175
+ choice = event.choices[0]
1176
+ if choice.delta:
1177
+ yield ChatMessageStreamDelta(
1178
+ content=choice.delta.content,
1179
+ tool_calls=[
1180
+ ChatMessageToolCallStreamDelta(
1181
+ index=delta.index,
1182
+ id=delta.id,
1183
+ type=delta.type,
1184
+ function=delta.function,
1185
+ )
1186
+ for delta in choice.delta.tool_calls
1187
+ ]
1188
+ if choice.delta.tool_calls
1189
+ else None,
1190
+ )
1191
+ else:
1192
+ if not getattr(choice, "finish_reason", None):
1193
+ raise ValueError(f"No content or tool calls in event: {event}")
1194
+
1195
+
1196
+ class LiteLLMRouterModel(LiteLLMModel):
1197
+ """Router‑based client for interacting with the [LiteLLM Python SDK Router](https://docs.litellm.ai/docs/routing).
1198
+
1199
+ This class provides a high-level interface for distributing requests among multiple language models using
1200
+ the LiteLLM SDK's routing capabilities. It is responsible for initializing and configuring the router client,
1201
+ applying custom role conversions, and managing message formatting to ensure seamless integration with various LLMs.
1202
+
1203
+ Parameters:
1204
+ model_id (`str`):
1205
+ Identifier for the model group to use from the model list (e.g., "model-group-1").
1206
+ model_list (`list[dict[str, Any]]`):
1207
+ Model configurations to be used for routing.
1208
+ Each configuration should include the model group name and any necessary parameters.
1209
+ For more details, refer to the [LiteLLM Routing](https://docs.litellm.ai/docs/routing#quick-start) documentation.
1210
+ client_kwargs (`dict[str, Any]`, *optional*):
1211
+ Additional configuration parameters for the Router client. For more details, see the
1212
+ [LiteLLM Routing Configurations](https://docs.litellm.ai/docs/routing).
1213
+ custom_role_conversions (`dict[str, str]`, *optional*):
1214
+ Custom role conversion mapping to convert message roles into others.
1215
+ Useful for specific models that do not support specific message roles like "system".
1216
+ flatten_messages_as_text (`bool`, *optional*): Whether to flatten messages as text.
1217
+ Defaults to `True` for models that start with "ollama", "groq", "cerebras".
1218
+ **kwargs:
1219
+ Additional keyword arguments to pass to the LiteLLM Router completion method.
1220
+
1221
+ Example:
1222
+ ```python
1223
+ >>> import os
1224
+ >>> from smolagents import CodeAgent, WebSearchTool, LiteLLMRouterModel
1225
+ >>> os.environ["OPENAI_API_KEY"] = ""
1226
+ >>> os.environ["AWS_ACCESS_KEY_ID"] = ""
1227
+ >>> os.environ["AWS_SECRET_ACCESS_KEY"] = ""
1228
+ >>> os.environ["AWS_REGION"] = ""
1229
+ >>> llm_loadbalancer_model_list = [
1230
+ ... {
1231
+ ... "model_name": "model-group-1",
1232
+ ... "litellm_params": {
1233
+ ... "model": "gpt-4o-mini",
1234
+ ... "api_key": os.getenv("OPENAI_API_KEY"),
1235
+ ... },
1236
+ ... },
1237
+ ... {
1238
+ ... "model_name": "model-group-1",
1239
+ ... "litellm_params": {
1240
+ ... "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
1241
+ ... "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
1242
+ ... "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
1243
+ ... "aws_region_name": os.getenv("AWS_REGION"),
1244
+ ... },
1245
+ ... },
1246
+ >>> ]
1247
+ >>> model = LiteLLMRouterModel(
1248
+ ... model_id="model-group-1",
1249
+ ... model_list=llm_loadbalancer_model_list,
1250
+ ... client_kwargs={
1251
+ ... "routing_strategy":"simple-shuffle"
1252
+ ... }
1253
+ >>> )
1254
+ >>> agent = CodeAgent(tools=[WebSearchTool()], model=model)
1255
+ >>> agent.run("How many seconds would it take for a leopard at full speed to run through Pont des Arts?")
1256
+ ```
1257
+ """
1258
+
1259
+ def __init__(
1260
+ self,
1261
+ model_id: str,
1262
+ model_list: list[dict[str, Any]],
1263
+ client_kwargs: dict[str, Any] | None = None,
1264
+ custom_role_conversions: dict[str, str] | None = None,
1265
+ flatten_messages_as_text: bool | None = None,
1266
+ **kwargs,
1267
+ ):
1268
+ self.client_kwargs = {
1269
+ "model_list": model_list,
1270
+ **(client_kwargs or {}),
1271
+ }
1272
+ super().__init__(
1273
+ model_id=model_id,
1274
+ custom_role_conversions=custom_role_conversions,
1275
+ flatten_messages_as_text=flatten_messages_as_text,
1276
+ **kwargs,
1277
+ )
1278
+
1279
+ def create_client(self):
1280
+ try:
1281
+ from litellm.router import Router
1282
+ except ModuleNotFoundError as e:
1283
+ raise ModuleNotFoundError(
1284
+ "Please install 'litellm' extra to use LiteLLMRouterModel: `pip install 'smolagents[litellm]'`"
1285
+ ) from e
1286
+ return Router(**self.client_kwargs)
1287
+
1288
+
1289
+ class InferenceClientModel(ApiModel):
1290
+ """A class to interact with Hugging Face's Inference Providers for language model interaction.
1291
+
1292
+ This model allows you to communicate with Hugging Face's models using Inference Providers. It can be used in serverless mode, with a dedicated endpoint, or even with a local URL, supporting features like stop sequences and grammar customization.
1293
+
1294
+ Providers include Cerebras, Cohere, Fal, Fireworks, HF-Inference, Hyperbolic, Nebius, Novita, Replicate, SambaNova, Together, and more.
1295
+
1296
+ Parameters:
1297
+ model_id (`str`, *optional*, default `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
1298
+ The Hugging Face model ID to be used for inference.
1299
+ This can be a model identifier from the Hugging Face model hub or a URL to a deployed Inference Endpoint.
1300
+ Currently, it defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`, but this may change in the future.
1301
+ provider (`str`, *optional*):
1302
+ Name of the provider to use for inference. A list of supported providers can be found in the [Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
1303
+ Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order [here](https://hf.co/settings/inference-providers).
1304
+ If `base_url` is passed, then `provider` is not used.
1305
+ token (`str`, *optional*):
1306
+ Token used by the Hugging Face API for authentication. This token needs to be authorized to 'Make calls to the serverless Inference Providers'.
1307
+ If the model is gated (like Llama-3 models), the token also needs 'Read access to contents of all public gated repos you can access'.
1308
+ If not provided, the class will try to use environment variable 'HF_TOKEN', else use the token stored in the Hugging Face CLI configuration.
1309
+ timeout (`int`, *optional*, defaults to 120):
1310
+ Timeout for the API request, in seconds.
1311
+ client_kwargs (`dict[str, Any]`, *optional*):
1312
+ Additional keyword arguments to pass to the Hugging Face InferenceClient.
1313
+ custom_role_conversions (`dict[str, str]`, *optional*):
1314
+ Custom role conversion mapping to convert message roles into others.
1315
+ Useful for specific models that do not support specific message roles like "system".
1316
+ api_key (`str`, *optional*):
1317
+ Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClientModel`]
1318
+ follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
1319
+ bill_to (`str`, *optional*):
1320
+ The billing account to use for the requests. By default the requests are billed on the user's account. Requests can only be billed to
1321
+ an organization the user is a member of, and which has subscribed to Enterprise Hub.
1322
+ base_url (`str`, *optional*):
1323
+ Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClientModel`]
1324
+ follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
1325
+ **kwargs:
1326
+ Additional keyword arguments to pass to the Hugging Face InferenceClient.
1327
+
1328
+ Raises:
1329
+ ValueError:
1330
+ If the model name is not provided.
1331
+
1332
+ Example:
1333
+ ```python
1334
+ >>> engine = InferenceClientModel(
1335
+ ... model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
1336
+ ... provider="nebius",
1337
+ ... token="your_hf_token_here",
1338
+ ... max_tokens=5000,
1339
+ ... )
1340
+ >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
1341
+ >>> response = engine(messages, stop_sequences=["END"])
1342
+ >>> print(response)
1343
+ "Quantum mechanics is the branch of physics that studies..."
1344
+ ```
1345
+ """
1346
+
1347
+ def __init__(
1348
+ self,
1349
+ model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
1350
+ provider: str | None = None,
1351
+ token: str | None = None,
1352
+ timeout: int = 120,
1353
+ client_kwargs: dict[str, Any] | None = None,
1354
+ custom_role_conversions: dict[str, str] | None = None,
1355
+ api_key: str | None = None,
1356
+ bill_to: str | None = None,
1357
+ base_url: str | None = None,
1358
+ **kwargs,
1359
+ ):
1360
+ if token is not None and api_key is not None:
1361
+ raise ValueError(
1362
+ "Received both `token` and `api_key` arguments. Please provide only one of them."
1363
+ " `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
1364
+ " It has the exact same behavior as `token`."
1365
+ )
1366
+ token = token if token is not None else api_key
1367
+ if token is None:
1368
+ token = os.getenv("HF_TOKEN")
1369
+ self.client_kwargs = {
1370
+ **(client_kwargs or {}),
1371
+ "model": model_id,
1372
+ "provider": provider,
1373
+ "token": token,
1374
+ "timeout": timeout,
1375
+ "bill_to": bill_to,
1376
+ "base_url": base_url,
1377
+ }
1378
+ super().__init__(model_id=model_id, custom_role_conversions=custom_role_conversions, **kwargs)
1379
+
1380
+ def create_client(self):
1381
+ """Create the Hugging Face client."""
1382
+ from huggingface_hub import InferenceClient
1383
+
1384
+ return InferenceClient(**self.client_kwargs)
1385
+
1386
+ def generate(
1387
+ self,
1388
+ messages: list[ChatMessage],
1389
+ stop_sequences: list[str] | None = None,
1390
+ response_format: dict[str, str] | None = None,
1391
+ tools_to_call_from: list[Tool] | None = None,
1392
+ **kwargs,
1393
+ ) -> ChatMessage:
1394
+ if response_format is not None and self.client_kwargs["provider"] not in STRUCTURED_GENERATION_PROVIDERS:
1395
+ raise ValueError(
1396
+ "InferenceClientModel only supports structured outputs with these providers:"
1397
+ + ", ".join(STRUCTURED_GENERATION_PROVIDERS)
1398
+ )
1399
+ completion_kwargs = self._prepare_completion_kwargs(
1400
+ messages=messages,
1401
+ stop_sequences=stop_sequences,
1402
+ tools_to_call_from=tools_to_call_from,
1403
+ response_format=response_format,
1404
+ convert_images_to_image_urls=True,
1405
+ custom_role_conversions=self.custom_role_conversions,
1406
+ **kwargs,
1407
+ )
1408
+ response = self.client.chat_completion(**completion_kwargs)
1409
+
1410
+ self._last_input_token_count = response.usage.prompt_tokens
1411
+ self._last_output_token_count = response.usage.completion_tokens
1412
+ return ChatMessage.from_dict(
1413
+ asdict(response.choices[0].message),
1414
+ raw=response,
1415
+ token_usage=TokenUsage(
1416
+ input_tokens=response.usage.prompt_tokens,
1417
+ output_tokens=response.usage.completion_tokens,
1418
+ ),
1419
+ )
1420
+
1421
+ def generate_stream(
1422
+ self,
1423
+ messages: list[ChatMessage],
1424
+ stop_sequences: list[str] | None = None,
1425
+ response_format: dict[str, str] | None = None,
1426
+ tools_to_call_from: list[Tool] | None = None,
1427
+ **kwargs,
1428
+ ) -> Generator[ChatMessageStreamDelta]:
1429
+ completion_kwargs = self._prepare_completion_kwargs(
1430
+ messages=messages,
1431
+ stop_sequences=stop_sequences,
1432
+ response_format=response_format,
1433
+ tools_to_call_from=tools_to_call_from,
1434
+ model=self.model_id,
1435
+ custom_role_conversions=self.custom_role_conversions,
1436
+ convert_images_to_image_urls=True,
1437
+ **kwargs,
1438
+ )
1439
+ for event in self.client.chat.completions.create(
1440
+ **completion_kwargs, stream=True, stream_options={"include_usage": True}
1441
+ ):
1442
+ if getattr(event, "usage", None):
1443
+ self._last_input_token_count = event.usage.prompt_tokens
1444
+ self._last_output_token_count = event.usage.completion_tokens
1445
+ yield ChatMessageStreamDelta(
1446
+ content="",
1447
+ token_usage=TokenUsage(
1448
+ input_tokens=event.usage.prompt_tokens,
1449
+ output_tokens=event.usage.completion_tokens,
1450
+ ),
1451
+ )
1452
+ if event.choices:
1453
+ choice = event.choices[0]
1454
+ if choice.delta:
1455
+ yield ChatMessageStreamDelta(
1456
+ content=choice.delta.content,
1457
+ tool_calls=[
1458
+ ChatMessageToolCallStreamDelta(
1459
+ index=delta.index,
1460
+ id=delta.id,
1461
+ type=delta.type,
1462
+ function=delta.function,
1463
+ )
1464
+ for delta in choice.delta.tool_calls
1465
+ ]
1466
+ if choice.delta.tool_calls
1467
+ else None,
1468
+ )
1469
+ else:
1470
+ if not getattr(choice, "finish_reason", None):
1471
+ raise ValueError(f"No content or tool calls in event: {event}")
1472
+
1473
+
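A minimal sketch of consuming `generate_stream` above, assuming an `InferenceClientModel` instance named `engine` configured as in the docstring example:

    messages = [ChatMessage(role="user", content=[{"type": "text", "text": "Hello!"}])]
    for delta in engine.generate_stream(messages):
        if delta.content:                # text arrives in incremental chunks
            print(delta.content, end="", flush=True)
        if delta.token_usage:            # usage is reported on a separate, content-free delta
            print(f"\n[{delta.token_usage.total_tokens} tokens]")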
1474
+ class OpenAIServerModel(ApiModel):
1475
+ """This model connects to an OpenAI-compatible API server.
1476
+
1477
+ Parameters:
1478
+ model_id (`str`):
1479
+ The model identifier to use on the server (e.g. "gpt-3.5-turbo").
1480
+ api_base (`str`, *optional*):
1481
+ The base URL of the OpenAI-compatible API server.
1482
+ api_key (`str`, *optional*):
1483
+ The API key to use for authentication.
1484
+ organization (`str`, *optional*):
1485
+ The organization to use for the API request.
1486
+ project (`str`, *optional*):
1487
+ The project to use for the API request.
1488
+ client_kwargs (`dict[str, Any]`, *optional*):
1489
+ Additional keyword arguments to pass to the OpenAI client (like organization, project, max_retries etc.).
1490
+ custom_role_conversions (`dict[str, str]`, *optional*):
1491
+ Custom role conversion mapping to convert message roles into others.
1492
+ Useful for specific models that do not support specific message roles like "system".
1493
+ flatten_messages_as_text (`bool`, default `False`):
1494
+ Whether to flatten messages as text.
1495
+ **kwargs:
1496
+ Additional keyword arguments to pass to the OpenAI API.
1497
+ """
1498
+
1499
+ def __init__(
1500
+ self,
1501
+ model_id: str,
1502
+ api_base: str | None = None,
1503
+ api_key: str | None = None,
1504
+ organization: str | None = None,
1505
+ project: str | None = None,
1506
+ client_kwargs: dict[str, Any] | None = None,
1507
+ custom_role_conversions: dict[str, str] | None = None,
1508
+ flatten_messages_as_text: bool = False,
1509
+ **kwargs,
1510
+ ):
1511
+ self.client_kwargs = {
1512
+ **(client_kwargs or {}),
1513
+ "api_key": api_key,
1514
+ "base_url": api_base,
1515
+ "organization": organization,
1516
+ "project": project,
1517
+ }
1518
+ super().__init__(
1519
+ model_id=model_id,
1520
+ custom_role_conversions=custom_role_conversions,
1521
+ flatten_messages_as_text=flatten_messages_as_text,
1522
+ **kwargs,
1523
+ )
1524
+
1525
+ def create_client(self):
1526
+ try:
1527
+ import openai
1528
+ except ModuleNotFoundError as e:
1529
+ raise ModuleNotFoundError(
1530
+ "Please install 'openai' extra to use OpenAIServerModel: `pip install 'smolagents[openai]'`"
1531
+ ) from e
1532
+
1533
+ return openai.OpenAI(**self.client_kwargs)
1534
+
1535
+ def generate_stream(
1536
+ self,
1537
+ messages: list[ChatMessage],
1538
+ stop_sequences: list[str] | None = None,
1539
+ response_format: dict[str, str] | None = None,
1540
+ tools_to_call_from: list[Tool] | None = None,
1541
+ **kwargs,
1542
+ ) -> Generator[ChatMessageStreamDelta]:
1543
+ completion_kwargs = self._prepare_completion_kwargs(
1544
+ messages=messages,
1545
+ stop_sequences=stop_sequences,
1546
+ response_format=response_format,
1547
+ tools_to_call_from=tools_to_call_from,
1548
+ model=self.model_id,
1549
+ custom_role_conversions=self.custom_role_conversions,
1550
+ convert_images_to_image_urls=True,
1551
+ **kwargs,
1552
+ )
1553
+ for event in self.client.chat.completions.create(
1554
+ **completion_kwargs, stream=True, stream_options={"include_usage": True}
1555
+ ):
1556
+ if event.usage:
1557
+ self._last_input_token_count = event.usage.prompt_tokens
1558
+ self._last_output_token_count = event.usage.completion_tokens
1559
+ yield ChatMessageStreamDelta(
1560
+ content="",
1561
+ token_usage=TokenUsage(
1562
+ input_tokens=event.usage.prompt_tokens,
1563
+ output_tokens=event.usage.completion_tokens,
1564
+ ),
1565
+ )
1566
+ if event.choices:
1567
+ choice = event.choices[0]
1568
+ if choice.delta:
1569
+ yield ChatMessageStreamDelta(
1570
+ content=choice.delta.content,
1571
+ tool_calls=[
1572
+ ChatMessageToolCallStreamDelta(
1573
+ index=delta.index,
1574
+ id=delta.id,
1575
+ type=delta.type,
1576
+ function=delta.function,
1577
+ )
1578
+ for delta in choice.delta.tool_calls
1579
+ ]
1580
+ if choice.delta.tool_calls
1581
+ else None,
1582
+ )
1583
+ else:
1584
+ if not getattr(choice, "finish_reason", None):
1585
+ raise ValueError(f"No content or tool calls in event: {event}")
1586
+
1587
+ def generate(
1588
+ self,
1589
+ messages: list[ChatMessage],
1590
+ stop_sequences: list[str] | None = None,
1591
+ response_format: dict[str, str] | None = None,
1592
+ tools_to_call_from: list[Tool] | None = None,
1593
+ **kwargs,
1594
+ ) -> ChatMessage:
1595
+ completion_kwargs = self._prepare_completion_kwargs(
1596
+ messages=messages,
1597
+ stop_sequences=stop_sequences,
1598
+ response_format=response_format,
1599
+ tools_to_call_from=tools_to_call_from,
1600
+ model=self.model_id,
1601
+ custom_role_conversions=self.custom_role_conversions,
1602
+ convert_images_to_image_urls=True,
1603
+ **kwargs,
1604
+ )
1605
+ response = self.client.chat.completions.create(**completion_kwargs)
1606
+
1607
+ # Reported that `response.usage` can be None in some cases when using OpenRouter: see GH-1401
1608
+ self._last_input_token_count = getattr(response.usage, "prompt_tokens", 0)
1609
+ self._last_output_token_count = getattr(response.usage, "completion_tokens", 0)
1610
+ return ChatMessage.from_dict(
1611
+ response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}),
1612
+ raw=response,
1613
+ token_usage=TokenUsage(
1614
+ input_tokens=self._last_input_token_count,
1615
+ output_tokens=self._last_output_token_count,
1616
+ ),
1617
+ )
1618
+
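A minimal usage sketch for this class; the model name is a placeholder, and any OpenAI-compatible server (vLLM, Ollama, a llama.cpp server, ...) can stand in for the official API via `api_base`:

    import os

    model = OpenAIServerModel(
        model_id="gpt-4o-mini",                      # placeholder model name
        api_base="https://api.openai.com/v1",        # or e.g. a local vLLM endpoint
        api_key=os.getenv("OPENAI_API_KEY"),
    )
    message = model.generate(
        [ChatMessage(role="user", content=[{"type": "text", "text": "Say hello."}])]
    )
    print(message.content)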
1619
+
1620
+ OpenAIModel = OpenAIServerModel
1621
+
1622
+
1623
+ class AzureOpenAIServerModel(OpenAIServerModel):
1624
+ """This model connects to an Azure OpenAI deployment.
1625
+
1626
+ Parameters:
1627
+ model_id (`str`):
1628
+ The model deployment name to use when connecting (e.g. "gpt-4o-mini").
1629
+ azure_endpoint (`str`, *optional*):
1630
+ The Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`. If not provided, it will be inferred from the `AZURE_OPENAI_ENDPOINT` environment variable.
1631
+ api_key (`str`, *optional*):
1632
+ The API key to use for authentication. If not provided, it will be inferred from the `AZURE_OPENAI_API_KEY` environment variable.
1633
+ api_version (`str`, *optional*):
1634
+ The API version to use. If not provided, it will be inferred from the `OPENAI_API_VERSION` environment variable.
1635
+ client_kwargs (`dict[str, Any]`, *optional*):
1636
+ Additional keyword arguments to pass to the AzureOpenAI client (like organization, project, max_retries etc.).
1637
+ custom_role_conversions (`dict[str, str]`, *optional*):
1638
+ Custom role conversion mapping to convert message roles into others.
1639
+ Useful for specific models that do not support specific message roles like "system".
1640
+ **kwargs:
1641
+ Additional keyword arguments to pass to the Azure OpenAI API.
1642
+ """
1643
+
1644
+ def __init__(
1645
+ self,
1646
+ model_id: str,
1647
+ azure_endpoint: str | None = None,
1648
+ api_key: str | None = None,
1649
+ api_version: str | None = None,
1650
+ client_kwargs: dict[str, Any] | None = None,
1651
+ custom_role_conversions: dict[str, str] | None = None,
1652
+ **kwargs,
1653
+ ):
1654
+ client_kwargs = client_kwargs or {}
1655
+ client_kwargs.update(
1656
+ {
1657
+ "api_version": api_version,
1658
+ "azure_endpoint": azure_endpoint,
1659
+ }
1660
+ )
1661
+ super().__init__(
1662
+ model_id=model_id,
1663
+ api_key=api_key,
1664
+ client_kwargs=client_kwargs,
1665
+ custom_role_conversions=custom_role_conversions,
1666
+ **kwargs,
1667
+ )
1668
+
1669
+ def create_client(self):
1670
+ try:
1671
+ import openai
1672
+ except ModuleNotFoundError as e:
1673
+ raise ModuleNotFoundError(
1674
+ "Please install 'openai' extra to use AzureOpenAIServerModel: `pip install 'smolagents[openai]'`"
1675
+ ) from e
1676
+
1677
+ return openai.AzureOpenAI(**self.client_kwargs)
1678
+
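And a minimal sketch for the Azure variant; all values are placeholders, and `model_id` must be the *deployment* name, not the base model name. As the docstring notes, the three settings can instead come from `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_API_KEY`, and `OPENAI_API_VERSION`:

    import os

    model = AzureOpenAIServerModel(
        model_id="my-gpt-4o-mini-deployment",
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version="2024-06-01",  # placeholder; match your deployment
    )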
1679
+
1680
+ AzureOpenAIModel = AzureOpenAIServerModel
1681
+
1682
+
1683
+ class AmazonBedrockServerModel(ApiModel):
1684
+ """
1685
+ A model class for interacting with Amazon Bedrock Server models through the Bedrock API.
1686
+
1687
+ This class provides an interface to interact with various Bedrock language models,
1688
+ allowing for customized model inference, guardrail configuration, message handling,
1689
+ and other parameters allowed by boto3 API.
1690
+
1691
+ Parameters:
1692
+ model_id (`str`):
1693
+ The model identifier to use on Bedrock (e.g. "us.amazon.nova-pro-v1:0").
1694
+ client (`boto3.client`, *optional*):
1695
+ A custom boto3 client for AWS interactions. If not provided, a default client will be created.
1696
+ client_kwargs (dict[str, Any], *optional*):
1697
+ Keyword arguments used to configure the boto3 client if it needs to be created internally.
1698
+ Examples include `region_name`, `config`, or `endpoint_url`.
1699
+ custom_role_conversions (`dict[str, str]`, *optional*):
1700
+ Custom role conversion mapping to convert message roles into others.
1701
+ Useful for specific models that do not support specific message roles like "system".
1702
+ Defaults to converting all roles to "user" role to enable using all the Bedrock models.
1703
+ flatten_messages_as_text (`bool`, default `False`):
1704
+ Whether to flatten messages as text.
1705
+ **kwargs:
1706
+ Additional keyword arguments passed directly to the underlying API calls.
1707
+
1708
+ Example:
1709
+ Creating a model instance with default settings:
1710
+ >>> bedrock_model = AmazonBedrockServerModel(
1711
+ ... model_id='us.amazon.nova-pro-v1:0'
1712
+ ... )
1713
+
1714
+ Creating a model instance with a custom boto3 client:
1715
+ >>> import boto3
1716
+ >>> client = boto3.client('bedrock-runtime', region_name='us-west-2')
1717
+ >>> bedrock_model = AmazonBedrockServerModel(
1718
+ ... model_id='us.amazon.nova-pro-v1:0',
1719
+ ... client=client
1720
+ ... )
1721
+
1722
+ Creating a model instance with client_kwargs for internal client creation:
1723
+ >>> bedrock_model = AmazonBedrockServerModel(
1724
+ ... model_id='us.amazon.nova-pro-v1:0',
1725
+ ... client_kwargs={'region_name': 'us-west-2', 'endpoint_url': 'https://custom-endpoint.com'}
1726
+ ... )
1727
+
1728
+ Creating a model instance with inference and guardrail configurations:
1729
+ >>> additional_api_config = {
1730
+ ... "inferenceConfig": {
1731
+ ... "maxTokens": 3000
1732
+ ... },
1733
+ ... "guardrailConfig": {
1734
+ ... "guardrailIdentifier": "identify1",
1735
+ ... "guardrailVersion": 'v1'
1736
+ ... },
1737
+ ... }
1738
+ >>> bedrock_model = AmazonBedrockServerModel(
1739
+ ... model_id='anthropic.claude-3-haiku-20240307-v1:0',
1740
+ ... **additional_api_config
1741
+ ... )
1742
+ """
1743
+
1744
+ def __init__(
1745
+ self,
1746
+ model_id: str,
1747
+ client=None,
1748
+ client_kwargs: dict[str, Any] | None = None,
1749
+ custom_role_conversions: dict[str, str] | None = None,
1750
+ **kwargs,
1751
+ ):
1752
+ self.client_kwargs = client_kwargs or {}
1753
+
1754
+ # Bedrock only supports `assistant` and `user` roles.
1755
+ # Many Bedrock models do not allow conversations to start with the `assistant` role, so by default every role is converted to `user`.
1756
+ # This parameter is retained for future model implementations and extended support.
1757
+ custom_role_conversions = custom_role_conversions or {
1758
+ MessageRole.SYSTEM: MessageRole.USER,
1759
+ MessageRole.ASSISTANT: MessageRole.USER,
1760
+ MessageRole.TOOL_CALL: MessageRole.USER,
1761
+ MessageRole.TOOL_RESPONSE: MessageRole.USER,
1762
+ }
1763
+
1764
+ super().__init__(
1765
+ model_id=model_id,
1766
+ custom_role_conversions=custom_role_conversions,
1767
+ flatten_messages_as_text=False, # Bedrock API doesn't support flatten messages, must be a list of messages
1768
+ client=client,
1769
+ **kwargs,
1770
+ )
1771
+
1772
+ def _prepare_completion_kwargs(
1773
+ self,
1774
+ messages: list[ChatMessage],
1775
+ stop_sequences: list[str] | None = None,
1776
+ response_format: dict[str, str] | None = None,
1777
+ tools_to_call_from: list[Tool] | None = None,
1778
+ custom_role_conversions: dict[str, str] | None = None,
1779
+ convert_images_to_image_urls: bool = False,
1780
+ tool_choice: str | dict[Any, Any] | None = None,
1781
+ **kwargs,
1782
+ ) -> dict:
1783
+ """
1784
+ Overrides the base method to handle Bedrock-specific configurations.
1785
+
1786
+ This implementation adapts the completion keyword arguments to align with
1787
+ Bedrock's requirements, ensuring compatibility with its unique setup and
1788
+ constraints.
1789
+ """
1790
+ completion_kwargs = super()._prepare_completion_kwargs(
1791
+ messages=messages,
1792
+ stop_sequences=None,  # Bedrock supports stop sequences via its inferenceConfig instead
1793
+ tools_to_call_from=tools_to_call_from,
1794
+ custom_role_conversions=custom_role_conversions,
1795
+ convert_images_to_image_urls=convert_images_to_image_urls,
1796
+ **kwargs,
1797
+ )
1798
+
1799
+ # Not all models in Bedrock support `toolConfig`. Also, smolagents already includes the tool call in the prompt,
1800
+ # so adding `toolConfig` could cause conflicts. We remove it to avoid issues.
1801
+ completion_kwargs.pop("toolConfig", None)
1802
+
1803
+ # The Bedrock API does not support the `type` key in requests.
1804
+ # This block of code modifies the object to meet Bedrock's requirements.
1805
+ for message in completion_kwargs.get("messages", []):
1806
+ for content in message.get("content", []):
1807
+ if "type" in content:
1808
+ del content["type"]
1809
+
1810
+ return {
1811
+ "modelId": self.model_id,
1812
+ **completion_kwargs,
1813
+ }
1814
+
1815
+ def create_client(self):
1816
+ try:
1817
+ import boto3 # type: ignore
1818
+ except ModuleNotFoundError as e:
1819
+ raise ModuleNotFoundError(
1820
+ "Please install 'bedrock' extra to use AmazonBedrockServerModel: `pip install 'smolagents[bedrock]'`"
1821
+ ) from e
1822
+
1823
+ return boto3.client("bedrock-runtime", **self.client_kwargs)
1824
+
1825
+ def generate(
1826
+ self,
1827
+ messages: list[ChatMessage],
1828
+ stop_sequences: list[str] | None = None,
1829
+ response_format: dict[str, str] | None = None,
1830
+ tools_to_call_from: list[Tool] | None = None,
1831
+ **kwargs,
1832
+ ) -> ChatMessage:
1833
+ if response_format is not None:
1834
+ raise ValueError("Amazon Bedrock does not support response_format")
1835
+ completion_kwargs: dict = self._prepare_completion_kwargs(
1836
+ messages=messages,
1837
+ tools_to_call_from=tools_to_call_from,
1838
+ custom_role_conversions=self.custom_role_conversions,
1839
+ convert_images_to_image_urls=True,
1840
+ **kwargs,
1841
+ )
1842
+
1843
+ # self.client is created in ApiModel class
1844
+ response = self.client.converse(**completion_kwargs)
1845
+
1846
+ # Keep only the text of the first content block
1847
+ response["output"]["message"]["content"] = response["output"]["message"]["content"][0]["text"]
1848
+
1849
+ self._last_input_token_count = response["usage"]["inputTokens"]
1850
+ self._last_output_token_count = response["usage"]["outputTokens"]
1851
+ return ChatMessage.from_dict(
1852
+ response["output"]["message"],
1853
+ raw=response,
1854
+ token_usage=TokenUsage(
1855
+ input_tokens=response["usage"]["inputTokens"],
1856
+ output_tokens=response["usage"]["outputTokens"],
1857
+ ),
1858
+ )
1859
+
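A sketch of a full call through `generate` above, assuming AWS credentials are configured; as in the docstring example, extra keyword arguments such as `inferenceConfig` are forwarded to `client.converse`. Note that with the default role conversions, every message reaches Bedrock under the `user` role:

    model = AmazonBedrockServerModel(
        model_id="us.amazon.nova-pro-v1:0",
        client_kwargs={"region_name": "us-east-1"},
        inferenceConfig={"maxTokens": 512},
    )
    message = model.generate(
        [ChatMessage(role="user", content=[{"type": "text", "text": "Hello!"}])]
    )
    print(message.content)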
1860
+
1861
+ AmazonBedrockModel = AmazonBedrockServerModel
1862
+
1863
+ __all__ = [
1864
+ "MessageRole",
1865
+ "tool_role_conversions",
1866
+ "get_clean_message_list",
1867
+ "Model",
1868
+ "MLXModel",
1869
+ "TransformersModel",
1870
+ "ApiModel",
1871
+ "InferenceClientModel",
1872
+ "LiteLLMModel",
1873
+ "LiteLLMRouterModel",
1874
+ "OpenAIServerModel",
1875
+ "OpenAIModel",
1876
+ "VLLMModel",
1877
+ "AzureOpenAIServerModel",
1878
+ "AzureOpenAIModel",
1879
+ "AmazonBedrockServerModel",
1880
+ "AmazonBedrockModel",
1881
+ "ChatMessage",
1882
+ ]
src/smolagents/monitoring.py ADDED
@@ -0,0 +1,265 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import json
18
+ from dataclasses import dataclass, field
19
+ from enum import IntEnum
20
+
21
+ from rich import box
22
+ from rich.console import Console, Group
23
+ from rich.panel import Panel
24
+ from rich.rule import Rule
25
+ from rich.syntax import Syntax
26
+ from rich.table import Table
27
+ from rich.text import Text
28
+ from rich.tree import Tree
29
+
30
+ from smolagents.utils import escape_code_brackets
31
+
32
+
33
+ __all__ = ["AgentLogger", "LogLevel", "Monitor", "TokenUsage", "Timing"]
34
+
35
+
36
+ @dataclass
37
+ class TokenUsage:
38
+ """
39
+ Contains the token usage information for a given step or run.
40
+ """
41
+
42
+ input_tokens: int
43
+ output_tokens: int
44
+ total_tokens: int = field(init=False)
45
+
46
+ def __post_init__(self):
47
+ self.total_tokens = self.input_tokens + self.output_tokens
48
+
49
+ def dict(self):
50
+ return {
51
+ "input_tokens": self.input_tokens,
52
+ "output_tokens": self.output_tokens,
53
+ "total_tokens": self.total_tokens,
54
+ }
55
+
56
+
57
+ @dataclass
58
+ class Timing:
59
+ """
60
+ Contains the timing information for a given step or run.
61
+ """
62
+
63
+ start_time: float
64
+ end_time: float | None = None
65
+
66
+ @property
67
+ def duration(self):
68
+ return None if self.end_time is None else self.end_time - self.start_time
69
+
70
+ def dict(self):
71
+ return {
72
+ "start_time": self.start_time,
73
+ "end_time": self.end_time,
74
+ "duration": self.duration,
75
+ }
76
+
77
+ def __repr__(self) -> str:
78
+ return f"Timing(start_time={self.start_time}, end_time={self.end_time}, duration={self.duration})"
79
+
80
+
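A quick behavioral sketch of the two dataclasses above:

    import time

    usage = TokenUsage(input_tokens=120, output_tokens=45)
    assert usage.total_tokens == 165     # computed in __post_init__

    timing = Timing(start_time=time.time())
    assert timing.duration is None       # stays None until end_time is set
    timing.end_time = time.time()
    print(f"step took {timing.duration:.4f}s")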
81
+ class Monitor:
82
+ def __init__(self, tracked_model, logger):
83
+ self.step_durations = []
84
+ self.tracked_model = tracked_model
85
+ self.logger = logger
86
+ self.total_input_token_count = 0
87
+ self.total_output_token_count = 0
88
+
89
+ def get_total_token_counts(self) -> TokenUsage:
90
+ return TokenUsage(
91
+ input_tokens=self.total_input_token_count,
92
+ output_tokens=self.total_output_token_count,
93
+ )
94
+
95
+ def reset(self):
96
+ self.step_durations = []
97
+ self.total_input_token_count = 0
98
+ self.total_output_token_count = 0
99
+
100
+ def update_metrics(self, step_log):
101
+ """Update the metrics of the monitor.
102
+
103
+ Args:
104
+ step_log ([`MemoryStep`]): Step log to update the monitor with.
105
+ """
106
+ step_duration = step_log.timing.duration
107
+ self.step_durations.append(step_duration)
108
+ console_outputs = f"[Step {len(self.step_durations)}: Duration {step_duration:.2f} seconds"
109
+
110
+ if step_log.token_usage is not None:
111
+ self.total_input_token_count += step_log.token_usage.input_tokens
112
+ self.total_output_token_count += step_log.token_usage.output_tokens
113
+ console_outputs += (
114
+ f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
115
+ )
116
+ console_outputs += "]"
117
+ self.logger.log(Text(console_outputs, style="dim"), level=1)
118
+
119
+
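A sketch of feeding the monitor a step log by hand; a real run passes a `MemoryStep`, for which `SimpleNamespace` stands in here (`AgentLogger` and `LogLevel` are defined just below):

    from types import SimpleNamespace

    logger = AgentLogger(level=LogLevel.INFO)
    monitor = Monitor(tracked_model=None, logger=logger)
    step = SimpleNamespace(
        timing=Timing(start_time=0.0, end_time=1.5),
        token_usage=TokenUsage(input_tokens=100, output_tokens=20),
    )
    monitor.update_metrics(step)
    print(monitor.get_total_token_counts().total_tokens)  # 120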
120
+ class LogLevel(IntEnum):
121
+ OFF = -1 # No output
122
+ ERROR = 0 # Only errors
123
+ INFO = 1 # Normal output (default)
124
+ DEBUG = 2 # Detailed output
125
+
126
+
127
+ YELLOW_HEX = "#d4b702"
128
+
129
+
130
+ class AgentLogger:
131
+ def __init__(self, level: LogLevel = LogLevel.INFO, console: Console | None = None):
132
+ self.level = level
133
+ if console is None:
134
+ self.console = Console()
135
+ else:
136
+ self.console = console
137
+
138
+ def log(self, *args, level: int | str | LogLevel = LogLevel.INFO, **kwargs) -> None:
139
+ """Logs a message to the console.
140
+
141
+ Args:
142
+ level (LogLevel, optional): Defaults to LogLevel.INFO.
143
+ """
144
+ if isinstance(level, str):
145
+ level = LogLevel[level.upper()]
146
+ if level <= self.level:
147
+ self.console.print(*args, **kwargs)
148
+
149
+ def log_error(self, error_message: str) -> None:
150
+ self.log(escape_code_brackets(error_message), style="bold red", level=LogLevel.ERROR)
151
+
152
+ def log_markdown(self, content: str, title: str | None = None, level=LogLevel.INFO, style=YELLOW_HEX) -> None:
153
+ markdown_content = Syntax(
154
+ content,
155
+ lexer="markdown",
156
+ theme="github-dark",
157
+ word_wrap=True,
158
+ )
159
+ if title:
160
+ self.log(
161
+ Group(
162
+ Rule(
163
+ "[bold italic]" + title,
164
+ align="left",
165
+ style=style,
166
+ ),
167
+ markdown_content,
168
+ ),
169
+ level=level,
170
+ )
171
+ else:
172
+ self.log(markdown_content, level=level)
173
+
174
+ def log_code(self, title: str, content: str, level: int = LogLevel.INFO) -> None:
175
+ self.log(
176
+ Panel(
177
+ Syntax(
178
+ content,
179
+ lexer="python",
180
+ theme="monokai",
181
+ word_wrap=True,
182
+ ),
183
+ title="[bold]" + title,
184
+ title_align="left",
185
+ box=box.HORIZONTALS,
186
+ ),
187
+ level=level,
188
+ )
189
+
190
+ def log_rule(self, title: str, level: int = LogLevel.INFO) -> None:
191
+ self.log(
192
+ Rule(
193
+ "[bold]" + title,
194
+ characters="━",
195
+ style=YELLOW_HEX,
196
+ ),
197
+ level=level,
198
+ )
199
+
200
+ def log_task(self, content: str, subtitle: str, title: str | None = None, level: LogLevel = LogLevel.INFO) -> None:
201
+ self.log(
202
+ Panel(
203
+ f"\n[bold]{escape_code_brackets(content)}\n",
204
+ title="[bold]New run" + (f" - {title}" if title else ""),
205
+ subtitle=subtitle,
206
+ border_style=YELLOW_HEX,
207
+ subtitle_align="left",
208
+ ),
209
+ level=level,
210
+ )
211
+
212
+ def log_messages(self, messages: list[dict], level: LogLevel = LogLevel.DEBUG) -> None:
213
+ messages_as_string = "\n".join([json.dumps(dict(message), indent=4) for message in messages])
214
+ self.log(
215
+ Syntax(
216
+ messages_as_string,
217
+ lexer="markdown",
218
+ theme="github-dark",
219
+ word_wrap=True,
220
+ ),
221
+ level=level,
222
+ )
223
+
224
+ def visualize_agent_tree(self, agent):
225
+ def create_tools_section(tools_dict):
226
+ table = Table(show_header=True, header_style="bold")
227
+ table.add_column("Name", style="#1E90FF")
228
+ table.add_column("Description")
229
+ table.add_column("Arguments")
230
+
231
+ for name, tool in tools_dict.items():
232
+ args = [
233
+ f"{arg_name} (`{info.get('type', 'Any')}`{', optional' if info.get('optional') else ''}): {info.get('description', '')}"
234
+ for arg_name, info in getattr(tool, "inputs", {}).items()
235
+ ]
236
+ table.add_row(name, getattr(tool, "description", str(tool)), "\n".join(args))
237
+
238
+ return Group("🛠️ [italic #1E90FF]Tools:[/italic #1E90FF]", table)
239
+
240
+ def get_agent_headline(agent, name: str | None = None):
241
+ name_headline = f"{name} | " if name else ""
242
+ return f"[bold {YELLOW_HEX}]{name_headline}{agent.__class__.__name__} | {agent.model.model_id}"
243
+
244
+ def build_agent_tree(parent_tree, agent_obj):
245
+ """Recursively builds the agent tree."""
246
+ parent_tree.add(create_tools_section(agent_obj.tools))
247
+
248
+ if agent_obj.managed_agents:
249
+ agents_branch = parent_tree.add("🤖 [italic #1E90FF]Managed agents:")
250
+ for name, managed_agent in agent_obj.managed_agents.items():
251
+ agent_tree = agents_branch.add(get_agent_headline(managed_agent, name))
252
+ if managed_agent.__class__.__name__ == "CodeAgent":
253
+ agent_tree.add(
254
+ f"✅ [italic #1E90FF]Authorized imports:[/italic #1E90FF] {managed_agent.additional_authorized_imports}"
255
+ )
256
+ agent_tree.add(f"📝 [italic #1E90FF]Description:[/italic #1E90FF] {managed_agent.description}")
257
+ build_agent_tree(agent_tree, managed_agent)
258
+
259
+ main_tree = Tree(get_agent_headline(agent))
260
+ if agent.__class__.__name__ == "CodeAgent":
261
+ main_tree.add(
262
+ f"✅ [italic #1E90FF]Authorized imports:[/italic #1E90FF] {agent.additional_authorized_imports}"
263
+ )
264
+ build_agent_tree(main_tree, agent)
265
+ self.console.print(main_tree)
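A short sketch of the logger's main entry points; rendering is delegated to the rich console:

    logger = AgentLogger(level=LogLevel.DEBUG)
    logger.log_rule("Step 1")
    logger.log_code("Executing parsed code:", "print('hello')")
    logger.log_markdown("Some **model output**", title="Output message")
    logger.log_error("Something went wrong")  # shown at every level except OFF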
src/smolagents/remote_executors.py ADDED
@@ -0,0 +1,451 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import base64
18
+ import inspect
19
+ import json
20
+ import pickle
21
+ import time
22
+ from io import BytesIO
23
+ from pathlib import Path
24
+ from textwrap import dedent
25
+ from typing import Any
26
+
27
+ import PIL.Image
28
+ import requests
29
+
30
+ from .default_tools import FinalAnswerTool
31
+ from .local_python_executor import PythonExecutor
32
+ from .monitoring import LogLevel
33
+ from .tools import Tool, get_tools_definition_code
34
+ from .utils import AgentError
35
+
36
+
37
+ try:
38
+ from dotenv import load_dotenv
39
+
40
+ load_dotenv()
41
+ except ModuleNotFoundError:
42
+ pass
43
+
44
+
45
+ class RemotePythonExecutor(PythonExecutor):
46
+ FINAL_ANSWER_EXCEPTION = "FinalAnswerException"
47
+
48
+ def __init__(self, additional_imports: list[str], logger):
49
+ self.additional_imports = additional_imports
50
+ self.logger = logger
51
+ self.logger.log("Initializing executor, hold on...")
52
+ self.installed_packages = []
53
+
54
+ def run_code_raise_errors(self, code: str) -> tuple[Any, str, bool]:
55
+ """
56
+ Execute code, return the result and output, also determining if
57
+ the result is the final answer.
58
+ """
59
+ raise NotImplementedError
60
+
61
+ def send_tools(self, tools: dict[str, Tool]):
62
+ if "final_answer" in tools:
63
+ self._patch_final_answer_with_exception(tools["final_answer"])
64
+ # Install tool packages
65
+ packages_to_install = {
66
+ pkg
67
+ for tool in tools.values()
68
+ for pkg in tool.to_dict()["requirements"]
69
+ if pkg not in self.installed_packages + ["smolagents"]
70
+ }
71
+ if packages_to_install:
72
+ self.installed_packages += self.install_packages(list(packages_to_install))
73
+ # Get tool definitions
74
+ code = get_tools_definition_code(tools)
75
+ if code:
76
+ execution = self.run_code_raise_errors(code)
77
+ self.logger.log(execution[1])
78
+
79
+ def send_variables(self, variables: dict):
80
+ """
81
+ Send variables to the kernel namespace using pickle.
82
+ """
83
+ pickled_vars = base64.b64encode(pickle.dumps(variables)).decode()
84
+ code = f"""
85
+ import pickle, base64
86
+ vars_dict = pickle.loads(base64.b64decode('{pickled_vars}'))
87
+ locals().update(vars_dict)
88
+ """
89
+ self.run_code_raise_errors(code)
90
+
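The transport used by `send_variables` above, in isolation; only picklable values survive the trip into the kernel:

    import base64
    import pickle

    payload = base64.b64encode(pickle.dumps({"x": 1, "rows": [2, 3]})).decode()
    # `payload` is the ASCII-safe string embedded into the generated code...
    restored = pickle.loads(base64.b64decode(payload))
    assert restored == {"x": 1, "rows": [2, 3]}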
91
+ def __call__(self, code_action: str) -> tuple[Any, str, bool]:
92
+ """Run the code and determine if it is the final answer."""
93
+ return self.run_code_raise_errors(code_action)
94
+
95
+ def install_packages(self, additional_imports: list[str]):
96
+ if additional_imports:
97
+ _, execution_logs, _ = self.run_code_raise_errors(f"!pip install {' '.join(additional_imports)}")
98
+ self.logger.log(execution_logs)
99
+ return additional_imports
100
+
101
+ def _patch_final_answer_with_exception(self, final_answer_tool: FinalAnswerTool):
102
+ """Patch the FinalAnswerTool to raise an exception.
103
+
104
+ This is necessary because the remote executors
105
+ rely on the FinalAnswerTool to detect the final answer.
106
+ It modifies the `forward` method of the FinalAnswerTool to raise
107
+ a `FinalAnswerException` with the final answer as a pickled value.
108
+ This allows the executor to catch this exception and return the final answer.
109
+
110
+ Args:
111
+ final_answer_tool (`FinalAnswerTool`): FinalAnswerTool instance to patch.
112
+ """
113
+
114
+ # Create a new class that inherits from the original FinalAnswerTool
115
+ class _FinalAnswerTool(final_answer_tool.__class__):
116
+ pass
117
+
118
+ # Add a new forward method that raises the FinalAnswerException
119
+ # - Define the new forward method function
120
+ def forward(self, *args, **kwargs) -> Any:
121
+ import base64
122
+ import pickle
123
+
124
+ class FinalAnswerException(Exception):
125
+ def __init__(self, value):
126
+ self.value = value
127
+
128
+ raise FinalAnswerException(base64.b64encode(pickle.dumps(self._forward(*args, **kwargs))).decode())
129
+
130
+ # - Set the new forward method function to the _FinalAnswerTool class
131
+ _FinalAnswerTool.forward = forward
132
+
133
+ # Rename the original forward method to _forward
134
+ # - Get the original forward method function from the final_answer_tool instance
135
+ original_forward_function = final_answer_tool.forward.__func__
136
+ # - Set the new _forward method function to the _FinalAnswerTool class
137
+ _FinalAnswerTool._forward = original_forward_function
138
+ # - Update the source code of the new forward method to match the original but with the new name
139
+ _FinalAnswerTool._forward.__source__ = inspect.getsource(original_forward_function).replace(
140
+ "def forward(", "def _forward("
141
+ )
142
+
143
+ # Set the new class as the class of the final_answer_tool instance
144
+ final_answer_tool.__class__ = _FinalAnswerTool
145
+
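The consumer side of this patch, sketched locally; the remote executors below do the same thing by matching the exception *name* in the execution error and unpickling its value (`patched_tool` is a hypothetical `FinalAnswerTool` instance that has gone through `_patch_final_answer_with_exception`):

    try:
        patched_tool.forward("42")
    except Exception as e:
        if type(e).__name__ == RemotePythonExecutor.FINAL_ANSWER_EXCEPTION:
            final_answer = pickle.loads(base64.b64decode(e.value))
            assert final_answer == "42"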
146
+
147
+ class E2BExecutor(RemotePythonExecutor):
148
+ """
149
+ Executes Python code using E2B.
150
+
151
+ Args:
152
+ additional_imports (`list[str]`): Additional imports to install.
153
+ logger (`Logger`): Logger to use.
154
+ **kwargs: Additional arguments to pass to the E2B Sandbox.
155
+ """
156
+
157
+ def __init__(self, additional_imports: list[str], logger, **kwargs):
158
+ super().__init__(additional_imports, logger)
159
+ try:
160
+ from e2b_code_interpreter import Sandbox
161
+ except ModuleNotFoundError:
162
+ raise ModuleNotFoundError(
163
+ """Please install 'e2b' extra to use E2BExecutor: `pip install 'smolagents[e2b]'`"""
164
+ )
165
+ self.sandbox = Sandbox(**kwargs)
166
+ self.installed_packages = self.install_packages(additional_imports)
167
+ self.logger.log("E2B is running", level=LogLevel.INFO)
168
+
169
+ def run_code_raise_errors(self, code: str) -> tuple[Any, str, bool]:
170
+ execution = self.sandbox.run_code(code)
171
+ execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
172
+
173
+ # Handle errors
174
+ if execution.error:
175
+ # Check if the error is a FinalAnswerException
176
+ if execution.error.name == RemotePythonExecutor.FINAL_ANSWER_EXCEPTION:
177
+ final_answer = pickle.loads(base64.b64decode(execution.error.value))
178
+ return final_answer, execution_logs, True
179
+
180
+ # Construct error message
181
+ error_message = (
182
+ f"{execution_logs}\n"
183
+ f"Executing code yielded an error:\n"
184
+ f"{execution.error.name}\n"
185
+ f"{execution.error.value}\n"
186
+ f"{execution.error.traceback}"
187
+ )
188
+ raise AgentError(error_message, self.logger)
189
+
190
+ # Handle results
191
+ if not execution.results:
192
+ return None, execution_logs, False
193
+
194
+ for result in execution.results:
195
+ if not result.is_main_result:
196
+ continue
197
+ # Handle image outputs
198
+ for attribute_name in ["jpeg", "png"]:
199
+ img_data = getattr(result, attribute_name, None)
200
+ if img_data is not None:
201
+ decoded_bytes = base64.b64decode(img_data.encode("utf-8"))
202
+ return PIL.Image.open(BytesIO(decoded_bytes)), execution_logs, False
203
+ # Handle other data formats
204
+ for attribute_name in [
205
+ "chart",
206
+ "data",
207
+ "html",
208
+ "javascript",
209
+ "json",
210
+ "latex",
211
+ "markdown",
212
+ "pdf",
213
+ "svg",
214
+ "text",
215
+ ]:
216
+ data = getattr(result, attribute_name, None)
217
+ if data is not None:
218
+ return data, execution_logs, False
219
+ # If no main result found, return None
220
+ return None, execution_logs, False
221
+
222
+ def cleanup(self):
223
+ """Clean up the E2B sandbox and resources."""
224
+ try:
225
+ if hasattr(self, "sandbox"):
226
+ self.logger.log("Shutting down sandbox...", level=LogLevel.INFO)
227
+ self.sandbox.kill()
228
+ self.logger.log("Sandbox cleanup completed", level=LogLevel.INFO)
229
+ del self.sandbox
230
+ except Exception as e:
231
+ self.logger.log_error(f"Error during cleanup: {e}")
232
+
233
+
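A minimal sketch of driving this executor, assuming the `e2b` extra is installed and `E2B_API_KEY` is set; extra keyword arguments go straight to the E2B `Sandbox`:

    from smolagents.monitoring import AgentLogger, LogLevel

    logger = AgentLogger(level=LogLevel.INFO)
    executor = E2BExecutor(additional_imports=["numpy"], logger=logger)
    result, logs, is_final_answer = executor("import numpy as np; print(np.ones(3).sum())")
    executor.cleanup()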
234
+ class DockerExecutor(RemotePythonExecutor):
235
+ """
236
+ Executes Python code using Jupyter Kernel Gateway in a Docker container.
237
+ """
238
+
239
+ def __init__(
240
+ self,
241
+ additional_imports: list[str],
242
+ logger,
243
+ host: str = "127.0.0.1",
244
+ port: int = 8888,
245
+ image_name: str = "jupyter-kernel",
246
+ build_new_image: bool = True,
247
+ container_run_kwargs: dict[str, Any] | None = None,
248
+ ):
249
+ """
250
+ Initialize the Docker-based Jupyter Kernel Gateway executor.
251
+
252
+ Args:
253
+ additional_imports: Additional imports to install.
254
+ logger: Logger to use.
255
+ host: Host to bind to.
256
+ port: Port to bind to.
257
+ image_name: Name of the Docker image to use. If the image doesn't exist, it will be built.
258
+ build_new_image: If True, the image will be rebuilt even if it already exists.
259
+ container_run_kwargs: Additional keyword arguments to pass to the Docker container run command.
260
+ """
261
+ super().__init__(additional_imports, logger)
262
+ try:
263
+ import docker
264
+ from websocket import create_connection
265
+ except ModuleNotFoundError:
266
+ raise ModuleNotFoundError(
267
+ "Please install 'docker' extra to use DockerExecutor: `pip install 'smolagents[docker]'`"
268
+ )
269
+ self.host = host
270
+ self.port = port
271
+ self.image_name = image_name
272
+
273
+ # Initialize Docker
274
+ try:
275
+ self.client = docker.from_env()
276
+ except docker.errors.DockerException as e:
277
+ raise RuntimeError("Could not connect to Docker daemon: make sure Docker is running.") from e
278
+
279
+ # Build and start container
280
+ try:
281
+ # Check if image exists, unless forced to rebuild
282
+ if not build_new_image:
283
+ try:
284
+ self.client.images.get(self.image_name)
285
+ self.logger.log(f"Using existing Docker image: {self.image_name}", level=LogLevel.INFO)
286
+ except docker.errors.ImageNotFound:
287
+ self.logger.log(f"Image {self.image_name} not found, building...", level=LogLevel.INFO)
288
+ build_new_image = True
289
+
290
+ if build_new_image:
291
+ self.logger.log(f"Building Docker image {self.image_name}...", level=LogLevel.INFO)
292
+ dockerfile_path = Path(__file__).parent / "Dockerfile"
293
+ if not dockerfile_path.exists():
294
+ with open(dockerfile_path, "w") as f:
295
+ f.write(
296
+ dedent(
297
+ """\
298
+ FROM python:3.12-slim
299
+
300
+ RUN pip install jupyter_kernel_gateway jupyter_client
301
+
302
+ EXPOSE 8888
303
+ CMD ["jupyter", "kernelgateway", "--KernelGatewayApp.ip='0.0.0.0'", "--KernelGatewayApp.port=8888", "--KernelGatewayApp.allow_origin='*'"]
304
+ """
305
+ )
306
+ )
307
+ _, build_logs = self.client.images.build(
308
+ path=str(dockerfile_path.parent), dockerfile=str(dockerfile_path), tag=self.image_name
309
+ )
310
+ for log_chunk in build_logs:
311
+ # Only log non-empty messages
312
+ if log_message := log_chunk.get("stream", "").rstrip():
313
+ self.logger.log(log_message, level=LogLevel.DEBUG)
314
+
315
+ self.logger.log(f"Starting container on {host}:{port}...", level=LogLevel.INFO)
316
+ # Create base container parameters
317
+ container_kwargs = {}
318
+ if container_run_kwargs:
319
+ container_kwargs.update(container_run_kwargs)
320
+
321
+ # Ensure required port mapping and background running
322
+ if not isinstance(container_kwargs.get("ports"), dict):
323
+ container_kwargs["ports"] = {}
324
+ container_kwargs["ports"]["8888/tcp"] = (host, port)
325
+ container_kwargs["detach"] = True
326
+
327
+ self.container = self.client.containers.run(self.image_name, **container_kwargs)
328
+
329
+ retries = 0
330
+ while self.container.status != "running" and retries < 5:
331
+ self.logger.log(f"Container status: {self.container.status}, waiting...", level=LogLevel.INFO)
332
+ time.sleep(1)
333
+ self.container.reload()
334
+ retries += 1
335
+
336
+ self.base_url = f"http://{host}:{port}"
337
+
338
+ # Create new kernel via HTTP
339
+ r = requests.post(f"{self.base_url}/api/kernels")
340
+ if r.status_code != 201:
341
+ error_details = {
342
+ "status_code": r.status_code,
343
+ "headers": dict(r.headers),
344
+ "url": r.url,
345
+ "body": r.text,
346
+ "request_method": r.request.method,
347
+ "request_headers": dict(r.request.headers),
348
+ "request_body": r.request.body,
349
+ }
350
+ self.logger.log_error(f"Failed to create kernel. Details: {json.dumps(error_details, indent=2)}")
351
+ raise RuntimeError(f"Failed to create kernel: Status {r.status_code}\nResponse: {r.text}") from None
352
+
353
+ self.kernel_id = r.json()["id"]
354
+
355
+ ws_url = f"ws://{host}:{port}/api/kernels/{self.kernel_id}/channels"
356
+ self.ws = create_connection(ws_url)
357
+
358
+ self.installed_packages = self.install_packages(additional_imports)
359
+ self.logger.log(
360
+ f"Container {self.container.short_id} is running with kernel {self.kernel_id}", level=LogLevel.INFO
361
+ )
362
+
363
+ except Exception as e:
364
+ self.cleanup()
365
+ raise RuntimeError(f"Failed to initialize Jupyter kernel: {e}") from e
366
+
367
+ def run_code_raise_errors(self, code_action: str) -> tuple[Any, str, bool]:
368
+ try:
369
+ # Send execute request
370
+ msg_id = self._send_execute_request(code_action)
371
+
372
+ # Collect output and results
373
+ outputs = []
374
+ result = None
375
+ is_final_answer = False
376
+
377
+ while True:
378
+ msg = json.loads(self.ws.recv())
379
+ parent_msg_id = msg.get("parent_header", {}).get("msg_id")
380
+ # Skip unrelated messages
381
+ if parent_msg_id != msg_id:
382
+ continue
383
+ msg_type = msg.get("msg_type", "")
384
+ msg_content = msg.get("content", {})
385
+ if msg_type == "stream":
386
+ outputs.append(msg_content["text"])
387
+ elif msg_type == "execute_result":
388
+ result = msg_content["data"].get("text/plain", None)
389
+ elif msg_type == "error":
390
+ if msg_content.get("ename", "") == RemotePythonExecutor.FINAL_ANSWER_EXCEPTION:
391
+ result = pickle.loads(base64.b64decode(msg_content.get("evalue", "")))
392
+ is_final_answer = True
393
+ else:
394
+ raise AgentError("\n".join(msg_content.get("traceback", [])), self.logger)
395
+ elif msg_type == "status" and msg_content["execution_state"] == "idle":
396
+ break
397
+
398
+ return result, "".join(outputs), is_final_answer
399
+
400
+ except Exception as e:
401
+ self.logger.log_error(f"Code execution failed: {e}")
402
+ raise
403
+
404
+ def _send_execute_request(self, code: str) -> str:
405
+ """Send code execution request to kernel."""
406
+ import uuid
407
+
408
+ # Generate a unique message ID
409
+ msg_id = str(uuid.uuid4())
410
+
411
+ # Create execute request
412
+ execute_request = {
413
+ "header": {
414
+ "msg_id": msg_id,
415
+ "username": "anonymous",
416
+ "session": str(uuid.uuid4()),
417
+ "msg_type": "execute_request",
418
+ "version": "5.0",
419
+ },
420
+ "parent_header": {},
421
+ "metadata": {},
422
+ "content": {
423
+ "code": code,
424
+ "silent": False,
425
+ "store_history": True,
426
+ "user_expressions": {},
427
+ "allow_stdin": False,
428
+ },
429
+ }
430
+
431
+ self.ws.send(json.dumps(execute_request))
432
+ return msg_id
433
+
434
+ def cleanup(self):
435
+ """Clean up the Docker container and resources."""
436
+ try:
437
+ if hasattr(self, "container"):
438
+ self.logger.log(f"Stopping and removing container {self.container.short_id}...", level=LogLevel.INFO)
439
+ self.container.stop()
440
+ self.container.remove()
441
+ self.logger.log("Container cleanup completed", level=LogLevel.INFO)
442
+ del self.container
443
+ except Exception as e:
444
+ self.logger.log_error(f"Error during cleanup: {e}")
445
+
446
+ def delete(self):
447
+ """Ensure cleanup on deletion."""
448
+ self.cleanup()
449
+
450
+
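And the Docker counterpart, assuming a running Docker daemon and the `docker` extra; `container_run_kwargs` is passed through to docker-py's `containers.run`, so any of its options work (reusing the `logger` from the sketch above):

    executor = DockerExecutor(
        additional_imports=["pandas"],
        logger=logger,
        port=8889,                                  # avoid clashing with a local Jupyter
        container_run_kwargs={"mem_limit": "1g"},   # any docker-py run() option
    )
    result, logs, is_final_answer = executor("print(21 * 2)")
    executor.cleanup()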
451
+ __all__ = ["E2BExecutor", "DockerExecutor"]
src/smolagents/tool_validation.py ADDED
@@ -0,0 +1,266 @@
1
+ import ast
2
+ import builtins
3
+ from itertools import zip_longest
4
+
5
+ from .utils import BASE_BUILTIN_MODULES, get_source, is_valid_name
6
+
7
+
8
+ _BUILTIN_NAMES = set(vars(builtins))
9
+
10
+
11
+ class MethodChecker(ast.NodeVisitor):
12
+ """
13
+ Checks that a method
14
+ - only uses defined names
15
+ - contains no local imports (e.g. numpy is ok but local_script is not)
16
+ """
17
+
18
+ def __init__(self, class_attributes: set[str], check_imports: bool = True):
19
+ self.undefined_names = set()
20
+ self.imports = {}
21
+ self.from_imports = {}
22
+ self.assigned_names = set()
23
+ self.arg_names = set()
24
+ self.class_attributes = class_attributes
25
+ self.errors = []
26
+ self.check_imports = check_imports
27
+ self.typing_names = {"Any"}
28
+ self.defined_classes = set()
29
+
30
+ def visit_arguments(self, node):
31
+ """Collect function arguments"""
32
+ self.arg_names = {arg.arg for arg in node.args}
33
+ if node.kwarg:
34
+ self.arg_names.add(node.kwarg.arg)
35
+ if node.vararg:
36
+ self.arg_names.add(node.vararg.arg)
37
+
38
+ def visit_Import(self, node):
39
+ for name in node.names:
40
+ actual_name = name.asname or name.name
41
+ self.imports[actual_name] = name.name
42
+
43
+ def visit_ImportFrom(self, node):
44
+ module = node.module or ""
45
+ for name in node.names:
46
+ actual_name = name.asname or name.name
47
+ self.from_imports[actual_name] = (module, name.name)
48
+
49
+ def visit_Assign(self, node):
50
+ for target in node.targets:
51
+ if isinstance(target, ast.Name):
52
+ self.assigned_names.add(target.id)
53
+ elif isinstance(target, (ast.Tuple, ast.List)):
54
+ for elt in target.elts:
55
+ if isinstance(elt, ast.Name):
56
+ self.assigned_names.add(elt.id)
57
+ self.visit(node.value)
58
+
59
+ def visit_With(self, node):
60
+ """Track aliases in 'with' statements (the 'y' in 'with X as y')"""
61
+ for item in node.items:
62
+ if item.optional_vars: # This is the 'y' in 'with X as y'
63
+ if isinstance(item.optional_vars, ast.Name):
64
+ self.assigned_names.add(item.optional_vars.id)
65
+ self.generic_visit(node)
66
+
67
+ def visit_ExceptHandler(self, node):
68
+ """Track exception aliases (the 'e' in 'except Exception as e')"""
69
+ if node.name: # This is the 'e' in 'except Exception as e'
70
+ self.assigned_names.add(node.name)
71
+ self.generic_visit(node)
72
+
73
+ def visit_AnnAssign(self, node):
74
+ """Track annotated assignments."""
75
+ if isinstance(node.target, ast.Name):
76
+ self.assigned_names.add(node.target.id)
77
+ if node.value:
78
+ self.visit(node.value)
79
+
80
+ def visit_For(self, node):
81
+ target = node.target
82
+ if isinstance(target, ast.Name):
83
+ self.assigned_names.add(target.id)
84
+ elif isinstance(target, ast.Tuple):
85
+ for elt in target.elts:
86
+ if isinstance(elt, ast.Name):
87
+ self.assigned_names.add(elt.id)
88
+ self.generic_visit(node)
89
+
90
+ def _handle_comprehension_generators(self, generators):
91
+ """Helper method to handle generators in all types of comprehensions"""
92
+ for generator in generators:
93
+ if isinstance(generator.target, ast.Name):
94
+ self.assigned_names.add(generator.target.id)
95
+ elif isinstance(generator.target, ast.Tuple):
96
+ for elt in generator.target.elts:
97
+ if isinstance(elt, ast.Name):
98
+ self.assigned_names.add(elt.id)
99
+
100
+ def visit_ListComp(self, node):
101
+ """Track variables in list comprehensions"""
102
+ self._handle_comprehension_generators(node.generators)
103
+ self.generic_visit(node)
104
+
105
+ def visit_DictComp(self, node):
106
+ """Track variables in dictionary comprehensions"""
107
+ self._handle_comprehension_generators(node.generators)
108
+ self.generic_visit(node)
109
+
110
+ def visit_SetComp(self, node):
111
+ """Track variables in set comprehensions"""
112
+ self._handle_comprehension_generators(node.generators)
113
+ self.generic_visit(node)
114
+
115
+ def visit_Attribute(self, node):
116
+ if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
117
+ self.generic_visit(node)
118
+
119
+ def visit_ClassDef(self, node):
120
+ """Track class definitions"""
121
+ self.defined_classes.add(node.name)
122
+ self.generic_visit(node)
123
+
124
+ def visit_Name(self, node):
125
+ if isinstance(node.ctx, ast.Load):
126
+ if not (
127
+ node.id in _BUILTIN_NAMES
128
+ or node.id in BASE_BUILTIN_MODULES
129
+ or node.id in self.arg_names
130
+ or node.id == "self"
131
+ or node.id in self.class_attributes
132
+ or node.id in self.imports
133
+ or node.id in self.from_imports
134
+ or node.id in self.assigned_names
135
+ or node.id in self.typing_names
136
+ or node.id in self.defined_classes
137
+ ):
138
+ self.errors.append(f"Name '{node.id}' is undefined.")
139
+
140
+ def visit_Call(self, node):
141
+ if isinstance(node.func, ast.Name):
142
+ if not (
143
+ node.func.id in _BUILTIN_NAMES
144
+ or node.func.id in BASE_BUILTIN_MODULES
145
+ or node.func.id in self.arg_names
146
+ or node.func.id == "self"
147
+ or node.func.id in self.class_attributes
148
+ or node.func.id in self.imports
149
+ or node.func.id in self.from_imports
150
+ or node.func.id in self.assigned_names
151
+ or node.func.id in self.defined_classes
152
+ ):
153
+ self.errors.append(f"Name '{node.func.id}' is undefined.")
154
+ self.generic_visit(node)
155
+
156
+
157
+ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
158
+ """
159
+ Validates that a Tool class follows the proper patterns:
160
+ 0. Any argument of __init__ should have a default.
161
+ Args chosen at init are not traceable, so we cannot rebuild the source code for them, thus any important arg should be defined as a class attribute.
162
+ 1. About the class:
163
+ - Class attributes should only be strings or dicts
164
+ - Class attributes cannot be complex attributes
165
+ 2. About all class methods:
166
+ - Imports must be from packages, not local files
167
+ - All methods must be self-contained
168
+
169
+ Raises all errors encountered, if no error returns None.
170
+ """
171
+
172
+ class ClassLevelChecker(ast.NodeVisitor):
173
+ def __init__(self):
174
+ self.imported_names = set()
175
+ self.complex_attributes = set()
176
+ self.class_attributes = set()
177
+ self.non_defaults = set()
178
+ self.non_literal_defaults = set()
179
+ self.in_method = False
180
+ self.invalid_attributes = []
181
+
182
+ def visit_FunctionDef(self, node):
183
+ if node.name == "__init__":
184
+ self._check_init_function_parameters(node)
185
+ old_context = self.in_method
186
+ self.in_method = True
187
+ self.generic_visit(node)
188
+ self.in_method = old_context
189
+
190
+ def visit_Assign(self, node):
191
+ if self.in_method:
192
+ return
193
+ # Track class attributes
194
+ for target in node.targets:
195
+ if isinstance(target, ast.Name):
196
+ self.class_attributes.add(target.id)
197
+
198
+ # Check if the assignment is more complex than simple literals
199
+ if not all(
200
+ isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set))
201
+ for val in ast.walk(node.value)
202
+ ):
203
+ for target in node.targets:
204
+ if isinstance(target, ast.Name):
205
+ self.complex_attributes.add(target.id)
206
+
207
+ # Check specific class attributes
208
+ if getattr(node.targets[0], "id", "") == "name":
209
+ if not isinstance(node.value, ast.Constant):
210
+ self.invalid_attributes.append(f"Class attribute 'name' must be a constant, found '{node.value}'")
211
+ elif not isinstance(node.value.value, str):
212
+ self.invalid_attributes.append(
213
+ f"Class attribute 'name' must be a string, found '{node.value.value}'"
214
+ )
215
+ elif not is_valid_name(node.value.value):
216
+ self.invalid_attributes.append(
217
+ f"Class attribute 'name' must be a valid Python identifier and not a reserved keyword, found '{node.value.value}'"
218
+ )
219
+
220
+ def _check_init_function_parameters(self, node):
221
+ # Check defaults in parameters
222
+ for arg, default in reversed(list(zip_longest(reversed(node.args.args), reversed(node.args.defaults)))):
223
+ if default is None:
224
+ if arg.arg != "self":
225
+ self.non_defaults.add(arg.arg)
226
+ elif not isinstance(default, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)):
227
+ self.non_literal_defaults.add(arg.arg)
228
+
229
+ class_level_checker = ClassLevelChecker()
230
+ source = get_source(cls)
231
+ tree = ast.parse(source)
232
+ class_node = tree.body[0]
233
+ if not isinstance(class_node, ast.ClassDef):
234
+ raise ValueError("Source code must define a class")
235
+ class_level_checker.visit(class_node)
236
+
237
+ errors = []
238
+ # Check invalid class attributes
239
+ if class_level_checker.invalid_attributes:
240
+ errors += class_level_checker.invalid_attributes
241
+ if class_level_checker.complex_attributes:
242
+ errors.append(
243
+ f"Complex attributes should be defined in __init__, not as class attributes: "
244
+ f"{', '.join(class_level_checker.complex_attributes)}"
245
+ )
246
+ if class_level_checker.non_defaults:
247
+ errors.append(
248
+ f"Parameters in __init__ must have default values, found required parameters: "
249
+ f"{', '.join(class_level_checker.non_defaults)}"
250
+ )
251
+ if class_level_checker.non_literal_defaults:
252
+ errors.append(
253
+ f"Parameters in __init__ must have literal default values, found non-literal defaults: "
254
+ f"{', '.join(class_level_checker.non_literal_defaults)}"
255
+ )
256
+
257
+ # Run checks on all methods
258
+ for node in class_node.body:
259
+ if isinstance(node, ast.FunctionDef):
260
+ method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
261
+ method_checker.visit(node)
262
+ errors += [f"- {node.name}: {error}" for error in method_checker.errors]
263
+
264
+ if errors:
265
+ raise ValueError(f"Tool validation failed for {cls.__name__}:\n" + "\n".join(errors))
266
+ return
src/smolagents/tools.py ADDED
@@ -0,0 +1,1239 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from __future__ import annotations
+
+ import ast
+ import inspect
+ import json
+ import logging
+ import os
+ import sys
+ import tempfile
+ import textwrap
+ import types
+ import warnings
+ from collections.abc import Callable
+ from contextlib import contextmanager
+ from functools import wraps
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any
+
+ from huggingface_hub import (
+     CommitOperationAdd,
+     create_commit,
+     create_repo,
+     get_collection,
+     hf_hub_download,
+     metadata_update,
+ )
+
+ from ._function_type_hints_utils import (
+     TypeHintParsingException,
+     _convert_type_hints_to_json_schema,
+     _get_json_schema_type,
+     get_imports,
+     get_json_schema,
+ )
+ from .agent_types import handle_agent_input_types, handle_agent_output_types
+ from .tool_validation import MethodChecker, validate_tool_attributes
+ from .utils import (
+     BASE_BUILTIN_MODULES,
+     _is_package_available,
+     get_source,
+     instance_to_source,
+     is_valid_name,
+ )
+
+
+ if TYPE_CHECKING:
+     import mcp
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def validate_after_init(cls):
+     original_init = cls.__init__
+
+     @wraps(original_init)
+     def new_init(self, *args, **kwargs):
+         original_init(self, *args, **kwargs)
+         self.validate_arguments()
+
+     cls.__init__ = new_init
+     return cls
+
+
+ AUTHORIZED_TYPES = [
+     "string",
+     "boolean",
+     "integer",
+     "number",
+     "image",
+     "audio",
+     "array",
+     "object",
+     "any",
+     "null",
+ ]
+
+ CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
+
+
+ class Tool:
+     """
+     A base class for the functions used by the agent. Subclass this and implement the `forward` method as well as the
+     following class attributes:
+
+     - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
+       will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
+       returns the text contained in the file'.
+     - **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
+       `"text-classifier"` or `"image_generator"`.
+     - **inputs** (`Dict[str, Dict[str, Union[str, type, bool]]]`) -- The dict of modalities expected for the inputs.
+       It has one `type` key and a `description` key.
+       This is used by `launch_gradio_demo` or to make a nice space from your tool, and also can be used in the generated
+       description for your tool.
+     - **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
+       or to make a nice space from your tool, and also can be used in the generated description for your tool.
+
+     You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
+     usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
+     instantiation.
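+
+     Example (a minimal sketch; the tool name and logic are illustrative):
+     ```py
+     class GreetingTool(Tool):
+         name = "greeting"
+         description = "Returns a greeting for the given name."
+         inputs = {"name": {"type": "string", "description": "The name of the person to greet."}}
+         output_type = "string"
+
+         def forward(self, name: str) -> str:
+             return f"Hello, {name}!"
+     ```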
+     """
+
+     name: str
+     description: str
+     inputs: dict[str, dict[str, str | type | bool]]
+     output_type: str
+
+     def __init__(self, *args, **kwargs):
+         self.is_initialized = False
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         validate_after_init(cls)
+
+     def validate_arguments(self):
+         required_attributes = {
+             "description": str,
+             "name": str,
+             "inputs": dict,
+             "output_type": str,
+         }
+         # Validate class attributes
+         for attr, expected_type in required_attributes.items():
+             attr_value = getattr(self, attr, None)
+             if attr_value is None:
+                 raise TypeError(f"You must set an attribute {attr}.")
+             if not isinstance(attr_value, expected_type):
+                 raise TypeError(
+                     f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
+                 )
+         # - Validate name
+         if not is_valid_name(self.name):
+             raise Exception(
+                 f"Invalid Tool name '{self.name}': must be a valid Python identifier and not a reserved keyword"
+             )
+         # Validate inputs
+         for input_name, input_content in self.inputs.items():
+             assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
+             assert "type" in input_content and "description" in input_content, (
+                 f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
+             )
+             if input_content["type"] not in AUTHORIZED_TYPES:
+                 raise Exception(
+                     f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}."
+                 )
+         # Validate output type
+         assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
+
+         # Validate forward function signature, except for Tools that use a "generic" signature (PipelineTool, SpaceToolWrapper, LangChainToolWrapper)
+         if not (
+             hasattr(self, "skip_forward_signature_validation")
+             and getattr(self, "skip_forward_signature_validation") is True
+         ):
+             signature = inspect.signature(self.forward)
+             actual_keys = set(key for key in signature.parameters.keys() if key != "self")
+             expected_keys = set(self.inputs.keys())
+             if actual_keys != expected_keys:
+                 raise Exception(
+                     f"In tool '{self.name}', 'forward' method parameters were {actual_keys}, but expected {expected_keys}. "
+                     f"It should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
+                 )
+
+             json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[
+                 "properties"
+             ]  # This function will not raise an error on missing docstrings, contrary to get_json_schema
+             for key, value in self.inputs.items():
+                 assert key in json_schema, (
+                     f"Input '{key}' should be present in function signature, found only {json_schema.keys()}"
+                 )
+                 if "nullable" in value:
+                     assert "nullable" in json_schema[key], (
+                         f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
+                     )
+                 if key in json_schema and "nullable" in json_schema[key]:
+                     assert "nullable" in value, (
+                         f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
+                     )
+
+     def forward(self, *args, **kwargs):
+         raise NotImplementedError("Write this method in your subclass of `Tool`.")
+
+     def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
+         if not self.is_initialized:
+             self.setup()
+
+         # Handle arguments that might be passed as a single dictionary
+         if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict):
+             potential_kwargs = args[0]
+
+             # If the dictionary keys match our input parameters, convert it to kwargs
+             if all(key in self.inputs for key in potential_kwargs):
+                 args = ()
+                 kwargs = potential_kwargs
+
+         if sanitize_inputs_outputs:
+             args, kwargs = handle_agent_input_types(*args, **kwargs)
+         outputs = self.forward(*args, **kwargs)
+         if sanitize_inputs_outputs:
+             outputs = handle_agent_output_types(outputs, self.output_type)
+         return outputs
+
+     def setup(self):
+         """
+         Override this method for any operation that is expensive and needs to be executed before you start using
+         your tool, such as loading a big model.
+         """
+         self.is_initialized = True
+
+     def to_dict(self) -> dict:
+         """Returns a dictionary representing the tool"""
+         class_name = self.__class__.__name__
+         if type(self).__name__ == "SimpleTool":
+             # Check that imports are self-contained
+             source_code = get_source(self.forward).replace("@tool", "")
+             forward_node = ast.parse(source_code)
+             # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code
+             method_checker = MethodChecker(set())
+             method_checker.visit(forward_node)
+
+             if len(method_checker.errors) > 0:
+                 errors = [f"- {error}" for error in method_checker.errors]
+                 raise ValueError(f"SimpleTool validation failed for {self.name}:\n" + "\n".join(errors))
+
+             forward_source_code = get_source(self.forward)
+             tool_code = textwrap.dedent(
+                 f"""
+                 from smolagents import Tool
+                 from typing import Any, Optional
+
+                 class {class_name}(Tool):
+                     name = "{self.name}"
+                     description = {json.dumps(textwrap.dedent(self.description).strip())}
+                     inputs = {repr(self.inputs)}
+                     output_type = "{self.output_type}"
+                 """
+             ).strip()
+             import re
+
+             def add_self_argument(source_code: str) -> str:
+                 """Add 'self' as first argument to a function definition if not present."""
+                 pattern = r"def forward\(((?!self)[^)]*)\)"
+
+                 def replacement(match):
+                     args = match.group(1).strip()
+                     if args:  # If there are other arguments
+                         return f"def forward(self, {args})"
+                     return "def forward(self)"
+
+                 return re.sub(pattern, replacement, source_code)
+
+             forward_source_code = forward_source_code.replace(self.name, "forward")
+             forward_source_code = add_self_argument(forward_source_code)
+             forward_source_code = forward_source_code.replace("@tool", "").strip()
+             tool_code += "\n\n" + textwrap.indent(forward_source_code, "    ")
+
+         else:  # If the tool was not created by the @tool decorator, it was made by subclassing Tool
+             if type(self).__name__ in [
+                 "SpaceToolWrapper",
+                 "LangChainToolWrapper",
+                 "GradioToolWrapper",
+             ]:
+                 raise ValueError(
+                     "Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors."
+                 )
+
+             validate_tool_attributes(self.__class__)
+
+             tool_code = "from typing import Any, Optional\n" + instance_to_source(self, base_cls=Tool)
+
+         requirements = {el for el in get_imports(tool_code) if el not in sys.stdlib_module_names} | {"smolagents"}
+
+         return {"name": self.name, "code": tool_code, "requirements": sorted(requirements)}
+
+     @classmethod
+     def from_dict(cls, tool_dict: dict[str, Any], **kwargs) -> "Tool":
+         """
+         Create tool from a dictionary representation.
+
+         Args:
+             tool_dict (`dict[str, Any]`): Dictionary representation of the tool.
+             **kwargs: Additional keyword arguments to pass to the tool's constructor.
+
+         Returns:
+             `Tool`: Tool object.
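+
+         Example (a sketch, assuming `tool_code` holds the source of a Tool subclass):
+         ```py
+         >>> tool = Tool.from_dict({"code": tool_code})
+         ```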
+         """
+         if "code" not in tool_dict:
+             raise ValueError("Tool dictionary must contain 'code' key with the tool source code")
+         return cls.from_code(tool_dict["code"], **kwargs)
+
+     def save(self, output_dir: str | Path, tool_file_name: str = "tool", make_gradio_app: bool = True):
+         """
+         Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
+         tool into `output_dir` as well as autogenerate:
+
+         - a `{tool_file_name}.py` file containing the logic for your tool.
+         If you pass `make_gradio_app=True`, this will also write:
+         - an `app.py` file providing a UI for your tool when it is exported to a Space with `tool.push_to_hub()`
+         - a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its
+           code)
+
+         Args:
+             output_dir (`str` or `Path`): The folder in which you want to save your tool.
+             tool_file_name (`str`, *optional*): The file name under which you want to save your tool.
+             make_gradio_app (`bool`, *optional*, defaults to True): Whether to also export an `app.py` Gradio UI and a `requirements.txt` file.
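+
+         Example:
+         ```py
+         >>> tool.save("./my_tool")  # writes ./my_tool/tool.py, app.py and requirements.txt
+         ```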
+         """
+         # Ensure output directory exists
+         output_path = Path(output_dir)
+         output_path.mkdir(parents=True, exist_ok=True)
+         # Save tool file
+         self._write_file(output_path / f"{tool_file_name}.py", self._get_tool_code())
+         if make_gradio_app:
+             # Save app file
+             self._write_file(output_path / "app.py", self._get_gradio_app_code(tool_module_name=tool_file_name))
+             # Save requirements file
+             self._write_file(output_path / "requirements.txt", self._get_requirements())
+
+     def _write_file(self, file_path: Path, content: str) -> None:
+         """Writes content to a file with UTF-8 encoding."""
+         file_path.write_text(content, encoding="utf-8")
+
+     def push_to_hub(
+         self,
+         repo_id: str,
+         commit_message: str = "Upload tool",
+         private: bool | None = None,
+         token: bool | str | None = None,
+         create_pr: bool = False,
+     ) -> str:
+         """
+         Upload the tool to the Hub.
+
+         Parameters:
+             repo_id (`str`):
+                 The name of the repository you want to push your tool to. It should contain your organization name when
+                 pushing to a given organization.
+             commit_message (`str`, *optional*, defaults to `"Upload tool"`):
+                 Message to commit while pushing.
+             private (`bool`, *optional*):
+                 Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
+             token (`bool` or `str`, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
+                 when running `huggingface-cli login` (stored in `~/.huggingface`).
+             create_pr (`bool`, *optional*, defaults to `False`):
+                 Whether to create a PR with the uploaded files or directly commit.
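+
+         Example (the repo id is illustrative):
+         ```py
+         >>> tool.push_to_hub("your-username/my-tool", private=True)
+         ```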
+         """
+         # Initialize repository
+         repo_id = self._initialize_hub_repo(repo_id, token, private)
+         # Prepare files for commit
+         additions = self._prepare_hub_files()
+         # Create commit
+         return create_commit(
+             repo_id=repo_id,
+             operations=additions,
+             commit_message=commit_message,
+             token=token,
+             create_pr=create_pr,
+             repo_type="space",
+         )
+
+     @staticmethod
+     def _initialize_hub_repo(repo_id: str, token: bool | str | None, private: bool | None) -> str:
+         """Initialize repository on Hugging Face Hub."""
+         repo_url = create_repo(
+             repo_id=repo_id,
+             token=token,
+             private=private,
+             exist_ok=True,
+             repo_type="space",
+             space_sdk="gradio",
+         )
+         metadata_update(repo_url.repo_id, {"tags": ["smolagents", "tool"]}, repo_type="space", token=token)
+         return repo_url.repo_id
+
+     def _prepare_hub_files(self) -> list:
+         """Prepare files for Hub commit."""
+         additions = [
+             # Add tool code
+             CommitOperationAdd(
+                 path_in_repo="tool.py",
+                 path_or_fileobj=self._get_tool_code().encode(),
+             ),
+             # Add Gradio app
+             CommitOperationAdd(
+                 path_in_repo="app.py",
+                 path_or_fileobj=self._get_gradio_app_code().encode(),
+             ),
+             # Add requirements
+             CommitOperationAdd(
+                 path_in_repo="requirements.txt",
+                 path_or_fileobj=self._get_requirements().encode(),
+             ),
+         ]
+         return additions
+
+     def _get_tool_code(self) -> str:
+         """Get the tool's code."""
+         return self.to_dict()["code"]
+
+     def _get_gradio_app_code(self, tool_module_name: str = "tool") -> str:
+         """Get the Gradio app code."""
+         class_name = self.__class__.__name__
+         return textwrap.dedent(
+             f"""\
+             from smolagents import launch_gradio_demo
+             from {tool_module_name} import {class_name}
+
+             tool = {class_name}()
+             launch_gradio_demo(tool)
+             """
+         )
+
+     def _get_requirements(self) -> str:
+         """Get the requirements."""
+         return "\n".join(self.to_dict()["requirements"])
+
+     @classmethod
+     def from_hub(
+         cls,
+         repo_id: str,
+         token: str | None = None,
+         trust_remote_code: bool = False,
+         **kwargs,
+     ):
+         """
+         Loads a tool defined on the Hub.
+
+         <Tip warning={true}>
+
+         Loading a tool from the Hub means that you'll download the tool and execute it locally.
+         ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
+         installing a package using pip/npm/apt.
+
+         </Tip>
+
+         Args:
+             repo_id (`str`):
+                 The name of the Space repo on the Hub where your tool is defined.
+             token (`str`, *optional*):
+                 The token to identify you on hf.co. If unset, will use the token generated when running
+                 `huggingface-cli login` (stored in `~/.huggingface`).
+             trust_remote_code (`bool`, *optional*, defaults to False):
+                 This flag marks that you understand the risk of running remote code and that you trust this tool.
+                 If you do not set this to True, loading the tool from the Hub will fail.
+             kwargs (additional keyword arguments, *optional*):
+                 Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
+                 `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
+                 others will be passed along to its init.
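+
+         Example (the repo id is illustrative):
+         ```py
+         >>> tool = Tool.from_hub("username/my-tool", trust_remote_code=True)
+         ```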
+         """
+         if not trust_remote_code:
+             raise ValueError(
+                 "Loading a tool from the Hub requires you to acknowledge that you trust its code: to do so, pass `trust_remote_code=True`."
+             )
+
+         # Get the tool's tool.py file.
+         tool_file = hf_hub_download(
+             repo_id,
+             "tool.py",
+             token=token,
+             repo_type="space",
+             cache_dir=kwargs.get("cache_dir"),
+             force_download=kwargs.get("force_download"),
+             proxies=kwargs.get("proxies"),
+             revision=kwargs.get("revision"),
+             subfolder=kwargs.get("subfolder"),
+             local_files_only=kwargs.get("local_files_only"),
+         )
+
+         tool_code = Path(tool_file).read_text()
+         return Tool.from_code(tool_code, **kwargs)
+
+     @classmethod
+     def from_code(cls, tool_code: str, **kwargs):
+         module = types.ModuleType("dynamic_tool")
+
+         exec(tool_code, module.__dict__)
+
+         # Find the Tool subclass
+         tool_class = next(
+             (
+                 obj
+                 for _, obj in inspect.getmembers(module, inspect.isclass)
+                 if issubclass(obj, Tool) and obj is not Tool
+             ),
+             None,
+         )
+
+         if tool_class is None:
+             raise ValueError("No Tool subclass found in the code.")
+
+         if not isinstance(tool_class.inputs, dict):
+             tool_class.inputs = ast.literal_eval(tool_class.inputs)
+
+         return tool_class(**kwargs)
+
+     @staticmethod
+     def from_space(
+         space_id: str,
+         name: str,
+         description: str,
+         api_name: str | None = None,
+         token: str | None = None,
+     ):
+         """
+         Creates a [`Tool`] from a Space given its id on the Hub.
+
+         Args:
+             space_id (`str`):
+                 The id of the Space on the Hub.
+             name (`str`):
+                 The name of the tool.
+             description (`str`):
+                 The description of the tool.
+             api_name (`str`, *optional*):
+                 The specific api_name to use, if the Space has several tabs. If not specified, will default to the first available API.
+             token (`str`, *optional*):
+                 Add your token to access private Spaces or increase your GPU quotas.
+         Returns:
+             [`Tool`]:
+                 The Space, as a tool.
+
+         Examples:
+         ```py
+         >>> image_generator = Tool.from_space(
+         ...     space_id="black-forest-labs/FLUX.1-schnell",
+         ...     name="image-generator",
+         ...     description="Generate an image from a prompt"
+         ... )
+         >>> image = image_generator("Generate an image of a cool surfer in Tahiti")
+         ```
+         ```py
+         >>> face_swapper = Tool.from_space(
+         ...     "tuan2308/face-swap",
+         ...     "face_swapper",
+         ...     "Tool that puts the face shown on the first image on the second image. You can give it paths to images.",
+         ... )
+         >>> image = face_swapper('./aymeric.jpeg', './ruth.jpg')
+         ```
+         """
+         from gradio_client import Client, handle_file
+
+         class SpaceToolWrapper(Tool):
+             skip_forward_signature_validation = True
+
+             def __init__(
+                 self,
+                 space_id: str,
+                 name: str,
+                 description: str,
+                 api_name: str | None = None,
+                 token: str | None = None,
+             ):
+                 self.name = name
+                 self.description = description
+                 self.client = Client(space_id, hf_token=token)
+                 space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
+
+                 # If api_name is not defined, take the first of the available APIs for this Space
+                 if api_name is None:
+                     api_name = list(space_description.keys())[0]
+                     logger.warning(
+                         f"Since `api_name` was not defined, it was automatically set to the first available API: `{api_name}`."
+                     )
+                 self.api_name = api_name
+
+                 try:
+                     space_description_api = space_description[api_name]
+                 except KeyError:
+                     raise KeyError(f"Could not find specified {api_name=} among available api names.")
+
+                 self.inputs = {}
+                 for parameter in space_description_api["parameters"]:
+                     if not parameter["parameter_has_default"]:
+                         parameter_type = parameter["type"]["type"]
+                         if parameter_type == "object":
+                             parameter_type = "any"
+                         self.inputs[parameter["parameter_name"]] = {
+                             "type": parameter_type,
+                             "description": parameter["python_type"]["description"],
+                         }
+                 output_component = space_description_api["returns"][0]["component"]
+                 if output_component == "Image":
+                     self.output_type = "image"
+                 elif output_component == "Audio":
+                     self.output_type = "audio"
+                 else:
+                     self.output_type = "any"
+                 self.is_initialized = True
+
+             def sanitize_argument_for_prediction(self, arg):
+                 from gradio_client.utils import is_http_url_like
+                 from PIL.Image import Image
+
+                 if isinstance(arg, Image):
+                     temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+                     arg.save(temp_file.name)
+                     arg = temp_file.name
+                 if (
+                     (isinstance(arg, str) and os.path.isfile(arg))
+                     or (isinstance(arg, Path) and arg.exists() and arg.is_file())
+                     or is_http_url_like(arg)
+                 ):
+                     arg = handle_file(arg)
+                 return arg
+
+             def forward(self, *args, **kwargs):
+                 # Preprocess args and kwargs:
+                 args = list(args)
+                 for i, arg in enumerate(args):
+                     args[i] = self.sanitize_argument_for_prediction(arg)
+                 for arg_name, arg in kwargs.items():
+                     kwargs[arg_name] = self.sanitize_argument_for_prediction(arg)
+
+                 output = self.client.predict(*args, api_name=self.api_name, **kwargs)
+                 if isinstance(output, (tuple, list)):
+                     # Sometimes the Space also returns the generation seed, in which case the result is at index 0
+                     return output[0]
+                 return output
+
+         return SpaceToolWrapper(
+             space_id=space_id,
+             name=name,
+             description=description,
+             api_name=api_name,
+             token=token,
+         )
+
+     @staticmethod
+     def from_gradio(gradio_tool):
+         """
+         Creates a [`Tool`] from a gradio tool.
+         """
+         import inspect
+
+         class GradioToolWrapper(Tool):
+             def __init__(self, _gradio_tool):
+                 self.name = _gradio_tool.name
+                 self.description = _gradio_tool.description
+                 self.output_type = "string"
+                 self._gradio_tool = _gradio_tool
+                 func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
+                 self.inputs = {
+                     key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
+                 }
+                 self.forward = self._gradio_tool.run
+
+         return GradioToolWrapper(gradio_tool)
+
+     @staticmethod
+     def from_langchain(langchain_tool):
+         """
+         Creates a [`Tool`] from a langchain tool.
+         """
+
+         class LangChainToolWrapper(Tool):
+             skip_forward_signature_validation = True
+
+             def __init__(self, _langchain_tool):
+                 self.name = _langchain_tool.name.lower()
+                 self.description = _langchain_tool.description
+                 self.inputs = _langchain_tool.args.copy()
+                 for input_content in self.inputs.values():
+                     if "title" in input_content:
+                         input_content.pop("title")
+                     input_content["description"] = ""
+                 self.output_type = "string"
+                 self.langchain_tool = _langchain_tool
+                 self.is_initialized = True
+
+             def forward(self, *args, **kwargs):
+                 tool_input = kwargs.copy()
+                 for index, argument in enumerate(args):
+                     if index < len(self.inputs):
+                         input_key = next(iter(self.inputs))
+                         tool_input[input_key] = argument
+                 return self.langchain_tool.run(tool_input)
+
+         return LangChainToolWrapper(langchain_tool)
+
+
+ def launch_gradio_demo(tool: Tool):
+     """
+     Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
+     `inputs` and `output_type`.
+
+     Args:
+         tool (`Tool`): The tool for which to launch the demo.
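+
+     Example:
+     ```py
+     >>> launch_gradio_demo(tool)  # assuming `tool` is an instantiated Tool with valid `inputs` and `output_type`
+     ```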
+     """
+     try:
+         import gradio as gr
+     except ImportError:
+         raise ImportError("Gradio should be installed in order to launch a gradio demo.")
+
+     TYPE_TO_COMPONENT_CLASS_MAPPING = {
+         "boolean": gr.Checkbox,
+         "image": gr.Image,
+         "audio": gr.Audio,
+         "string": gr.Textbox,
+         "integer": gr.Textbox,
+         "number": gr.Textbox,
+     }
+
+     def tool_forward(*args, **kwargs):
+         return tool(*args, sanitize_inputs_outputs=True, **kwargs)
+
+     tool_forward.__signature__ = inspect.signature(tool.forward)
+
+     gradio_inputs = []
+     for input_name, input_details in tool.inputs.items():
+         input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
+         new_component = input_gradio_component_class(label=input_name)
+         gradio_inputs.append(new_component)
+
+     output_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[tool.output_type]
+     gradio_output = output_gradio_component_class(label="Output")
+
+     gr.Interface(
+         fn=tool_forward,
+         inputs=gradio_inputs,
+         outputs=gradio_output,
+         title=tool.name,
+         description=tool.description,
+         api_name=tool.name,
+     ).launch()
+
+
+ def load_tool(
+     repo_id,
+     model_repo_id: str | None = None,
+     token: str | None = None,
+     trust_remote_code: bool = False,
+     **kwargs,
+ ):
+     """
+     Main function to quickly load a tool from the Hub.
+
+     <Tip warning={true}>
+
+     Loading a tool means that you'll download the tool and execute it locally.
+     ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
+     installing a package using pip/npm/apt.
+
+     </Tip>
+
+     Args:
+         repo_id (`str`):
+             Space repo ID of a tool on the Hub.
+         model_repo_id (`str`, *optional*):
+             Use this argument to use a different model than the default one for the tool you selected.
+         token (`str`, *optional*):
+             The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
+             login` (stored in `~/.huggingface`).
+         trust_remote_code (`bool`, *optional*, defaults to False):
+             This needs to be accepted in order to load a tool from the Hub.
+         kwargs (additional keyword arguments, *optional*):
+             Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
+             `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
+             will be passed along to its init.
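+
+     Example (the repo id is illustrative):
+     ```py
+     >>> from smolagents import load_tool
+     >>> tool = load_tool("username/my-tool", trust_remote_code=True)
+     ```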
+     """
+     return Tool.from_hub(
+         repo_id,
+         model_repo_id=model_repo_id,
+         token=token,
+         trust_remote_code=trust_remote_code,
+         **kwargs,
+     )
+
+
+ def add_description(description):
+     """
+     A decorator that adds a description to a function.
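+
+     Example:
+     ```py
+     >>> @add_description("Greets the user.")
+     ... def greet():
+     ...     return "Hello!"
+     >>> greet.description
+     'Greets the user.'
+     ```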
+     """
+
+     def inner(func):
+         func.description = description
+         func.name = func.__name__
+         return func
+
+     return inner
+
+
+ class ToolCollection:
+     """
+     Tool collections enable loading a collection of tools in the agent's toolbox.
+
+     Collections can be loaded from a collection in the Hub or from an MCP server, see:
+     - [`ToolCollection.from_hub`]
+     - [`ToolCollection.from_mcp`]
+
+     For example and usage, see: [`ToolCollection.from_hub`] and [`ToolCollection.from_mcp`]
+     """
+
+     def __init__(self, tools: list[Tool]):
+         self.tools = tools
+
+     @classmethod
+     def from_hub(
+         cls,
+         collection_slug: str,
+         token: str | None = None,
+         trust_remote_code: bool = False,
+     ) -> "ToolCollection":
+         """Loads a tool collection from the Hub.
+
+         It adds a collection of tools from all Spaces in the collection to the agent's toolbox.
+
+         > [!NOTE]
+         > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
+         > like for this collection to showcase them.
+
+         Args:
+             collection_slug (str): The collection slug referencing the collection.
+             token (str, *optional*): The authentication token if the collection is private.
+             trust_remote_code (bool, *optional*, defaults to False): Whether to trust the remote code.
+
+         Returns:
+             ToolCollection: A tool collection instance loaded with the tools.
+
+         Example:
+         ```py
+         >>> from smolagents import ToolCollection, CodeAgent
+
+         >>> image_tool_collection = ToolCollection.from_hub("huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
+         >>> agent = CodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
+
+         >>> agent.run("Please draw me a picture of rivers and lakes.")
+         ```
+         """
+         _collection = get_collection(collection_slug, token=token)
+         _hub_repo_ids = {item.item_id for item in _collection.items if item.item_type == "space"}
+
+         tools = {Tool.from_hub(repo_id, token, trust_remote_code) for repo_id in _hub_repo_ids}
+
+         return cls(tools)
+
+     @classmethod
+     @contextmanager
+     def from_mcp(
+         cls, server_parameters: "mcp.StdioServerParameters" | dict, trust_remote_code: bool = False
+     ) -> "ToolCollection":
+         """Automatically load a tool collection from an MCP server.
+
+         This method supports Stdio, Streamable HTTP, and legacy HTTP+SSE MCP servers. Look at the `server_parameters`
+         argument for more details on how to connect to each MCP server.
+
+         Note: a separate thread will be spawned to run an asyncio event loop handling
+         the MCP server.
+
+         Args:
+             server_parameters (`mcp.StdioServerParameters` or `dict`):
+                 Configuration parameters to connect to the MCP server. This can be:
+
+                 - An instance of `mcp.StdioServerParameters` for connecting a Stdio MCP server via standard input/output using a subprocess.
+
+                 - A `dict` with at least:
+                   - "url": URL of the server.
+                   - "transport": Transport protocol to use, one of:
+                     - "streamable-http": (recommended) Streamable HTTP transport.
+                     - "sse": Legacy HTTP+SSE transport (deprecated).
+                   If "transport" is omitted, the legacy "sse" transport is assumed (a deprecation warning will be issued).
+
+                 <Deprecated version="1.17.0">
+                 The HTTP+SSE transport is deprecated and future behavior will default to the Streamable HTTP transport.
+                 Please pass explicitly the "transport" key.
+                 </Deprecated>
+             trust_remote_code (`bool`, *optional*, defaults to `False`):
+                 Whether to trust the execution of code from tools defined on the MCP server.
+                 This option should only be set to `True` if you trust the MCP server,
+                 and understand the risks associated with running remote code on your local machine.
+                 If set to `False`, loading tools from MCP will fail.
+
+
+         Returns:
+             ToolCollection: A tool collection instance.
+
+         Example with a Stdio MCP server:
+         ```py
+         >>> import os
+         >>> from smolagents import ToolCollection, CodeAgent, InferenceClientModel
+         >>> from mcp import StdioServerParameters
+
+         >>> model = InferenceClientModel()
+
+         >>> server_parameters = StdioServerParameters(
+         >>>     command="uvx",
+         >>>     args=["--quiet", "[email protected]"],
+         >>>     env={"UV_PYTHON": "3.12", **os.environ},
+         >>> )
+
+         >>> with ToolCollection.from_mcp(server_parameters, trust_remote_code=True) as tool_collection:
+         >>>     agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True, model=model)
+         >>>     agent.run("Please find a remedy for hangover.")
+         ```
+
+         Example with a Streamable HTTP MCP server:
+         ```py
+         >>> with ToolCollection.from_mcp({"url": "http://127.0.0.1:8000/mcp", "transport": "streamable-http"}, trust_remote_code=True) as tool_collection:
+         >>>     agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True, model=model)
+         >>>     agent.run("Please find a remedy for hangover.")
+         ```
+         """
+         try:
+             from mcpadapt.core import MCPAdapt
+             from mcpadapt.smolagents_adapter import SmolAgentsAdapter
+         except ImportError:
+             raise ImportError(
+                 """Please install 'mcp' extra to use ToolCollection.from_mcp: `pip install "smolagents[mcp]"`."""
+             )
+         if isinstance(server_parameters, dict):
+             transport = server_parameters.get("transport")
+             if transport is None:
+                 warnings.warn(
+                     "Passing a dict as server_parameters without specifying the 'transport' key is deprecated. "
+                     "For now, it defaults to the legacy 'sse' (HTTP+SSE) transport, but this default will change "
+                     "to 'streamable-http' in version 1.20. Please add the 'transport' key explicitly. ",
+                     FutureWarning,
+                 )
+                 transport = "sse"
+                 server_parameters["transport"] = transport
+             if transport not in {"sse", "streamable-http"}:
+                 raise ValueError(
+                     f"Unsupported transport: {transport}. Supported transports are 'streamable-http' and 'sse'."
+                 )
+         if not trust_remote_code:
+             raise ValueError(
+                 "Loading tools from MCP requires you to acknowledge that you trust the MCP server, "
+                 "as it will execute code on your local machine: pass `trust_remote_code=True`."
+             )
+         with MCPAdapt(server_parameters, SmolAgentsAdapter()) as tools:
+             yield cls(tools)
+
+
+ def tool(tool_function: Callable) -> Tool:
+     """
+     Convert a function into an instance of a dynamically created Tool subclass.
+
+     Args:
+         tool_function (`Callable`): Function to convert into a Tool subclass.
+             Should have type hints for each input and a type hint for the output.
+             Should also have a docstring including the description of the function
+             and an 'Args:' part where each argument is described.
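+
+     Example (a minimal sketch):
+     ```py
+     >>> @tool
+     ... def to_upper(text: str) -> str:
+     ...     '''Converts text to uppercase.
+     ...
+     ...     Args:
+     ...         text: The text to convert.
+     ...     '''
+     ...     return text.upper()
+     >>> to_upper("hello")
+     'HELLO'
+     ```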
+     """
+     tool_json_schema = get_json_schema(tool_function)["function"]
+     if "return" not in tool_json_schema:
+         raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
+
+     class SimpleTool(Tool):
+         def __init__(self):
+             self.is_initialized = True
+
+     # Set the class attributes
+     SimpleTool.name = tool_json_schema["name"]
+     SimpleTool.description = tool_json_schema["description"]
+     SimpleTool.inputs = tool_json_schema["parameters"]["properties"]
+     SimpleTool.output_type = tool_json_schema["return"]["type"]
+
+     @wraps(tool_function)
+     def wrapped_function(*args, **kwargs):
+         return tool_function(*args, **kwargs)
+
+     # Bind the copied function to the forward method
+     SimpleTool.forward = staticmethod(wrapped_function)
+
+     # Get the signature parameters of the tool function
+     sig = inspect.signature(tool_function)
+     # - Add "self" as first parameter to tool_function signature
+     new_sig = sig.replace(
+         parameters=[inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(sig.parameters.values())
+     )
+     # - Set the signature of the forward method
+     SimpleTool.forward.__signature__ = new_sig
+
+     # Create and attach the source code of the dynamically created tool class and forward method
+     # - Get the source code of tool_function
+     tool_source = inspect.getsource(tool_function)
+     # - Remove the tool decorator and function definition line
+     tool_source_body = "\n".join(tool_source.split("\n")[2:])
+     # - Dedent
+     tool_source_body = textwrap.dedent(tool_source_body)
+     # - Create the forward method source, including def line and indentation
+     forward_method_source = f"def forward{str(new_sig)}:\n{textwrap.indent(tool_source_body, '    ')}"
+     # - Create the class source
+     class_source = (
+         textwrap.dedent(f"""
+         class SimpleTool(Tool):
+             name: str = "{tool_json_schema["name"]}"
+             description: str = {json.dumps(textwrap.dedent(tool_json_schema["description"]).strip())}
+             inputs: dict[str, dict[str, str]] = {tool_json_schema["parameters"]["properties"]}
+             output_type: str = "{tool_json_schema["return"]["type"]}"
+
+             def __init__(self):
+                 self.is_initialized = True
+
+         """)
+         + textwrap.indent(forward_method_source, "    ")  # indent for class method
+     )
+     # - Store the source code on both class and method for inspection
+     SimpleTool.__source__ = class_source
+     SimpleTool.forward.__source__ = forward_method_source
+
+     simple_tool = SimpleTool()
+     return simple_tool
+
+
+ class PipelineTool(Tool):
+     """
+     A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
+     need to specify:
+
+     - **model_class** (`type`) -- The class to use to load the model in this tool.
+     - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
+     - **pre_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the
+       pre-processor
+     - **post_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the
+       post-processor (when different from the pre-processor).
+
+     Args:
+         model (`str` or [`transformers.PreTrainedModel`], *optional*):
+             The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
+             value of the class attribute `default_checkpoint`.
+         pre_processor (`str` or `Any`, *optional*):
+             The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
+             tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
+             unset.
+         post_processor (`str` or `Any`, *optional*):
+             The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
+             tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
+             unset.
+         device (`int`, `str` or `torch.device`, *optional*):
+             The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
+             CPU otherwise.
+         device_map (`str` or `dict`, *optional*):
+             If passed along, will be used to instantiate the model.
+         model_kwargs (`dict`, *optional*):
+             Any keyword argument to send to the model instantiation.
+         token (`str`, *optional*):
+             The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
+             running `huggingface-cli login` (stored in `~/.huggingface`).
+         hub_kwargs (additional keyword arguments, *optional*):
+             Any additional keyword argument to send to the methods that will load the data from the Hub.
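+
+     Example (a hedged sketch; the checkpoint name and model class are illustrative, not defaults of this library):
+     ```py
+     from transformers import AutoModelForSeq2SeqLM
+
+     class SummarizerTool(PipelineTool):
+         name = "summarizer"
+         description = "Summarizes the given text."
+         inputs = {"text": {"type": "string", "description": "The text to summarize."}}
+         output_type = "string"
+         default_checkpoint = "org/some-summarization-checkpoint"  # illustrative
+         model_class = AutoModelForSeq2SeqLM
+
+         def encode(self, text: str):
+             return self.pre_processor(text, return_tensors="pt")
+
+         def forward(self, inputs):
+             # Override the default forward to use generation instead of a plain model call
+             return self.model.generate(**inputs)
+
+         def decode(self, outputs):
+             return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
+     ```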
+     """
+
+     pre_processor_class = None
+     model_class = None
+     post_processor_class = None
+     default_checkpoint = None
+     description = "This is a pipeline tool"
+     name = "pipeline"
+     inputs = {"prompt": str}
+     output_type = str
+     skip_forward_signature_validation = True
+
+     def __init__(
+         self,
+         model=None,
+         pre_processor=None,
+         post_processor=None,
+         device=None,
+         device_map=None,
+         model_kwargs=None,
+         token=None,
+         **hub_kwargs,
+     ):
+         if not _is_package_available("accelerate") or not _is_package_available("torch"):
+             raise ModuleNotFoundError(
+                 "Please install 'transformers' extra to use a PipelineTool: `pip install 'smolagents[transformers]'`"
+             )
+
+         if model is None:
+             if self.default_checkpoint is None:
+                 raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
+             model = self.default_checkpoint
+         if pre_processor is None:
+             pre_processor = model
+
+         self.model = model
+         self.pre_processor = pre_processor
+         self.post_processor = post_processor
+         self.device = device
+         self.device_map = device_map
+         self.model_kwargs = {} if model_kwargs is None else model_kwargs
+         if device_map is not None:
+             self.model_kwargs["device_map"] = device_map
+         self.hub_kwargs = hub_kwargs
+         self.hub_kwargs["token"] = token
+
+         super().__init__()
+
+     def setup(self):
+         """
+         Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
+         """
+         if isinstance(self.pre_processor, str):
+             if self.pre_processor_class is None:
+                 from transformers import AutoProcessor
+
+                 self.pre_processor_class = AutoProcessor
+             self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
+
+         if isinstance(self.model, str):
+             self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
+
+         if self.post_processor is None:
+             self.post_processor = self.pre_processor
+         elif isinstance(self.post_processor, str):
+             if self.post_processor_class is None:
+                 from transformers import AutoProcessor
+
+                 self.post_processor_class = AutoProcessor
+             self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
+
+         if self.device is None:
+             if self.device_map is not None:
+                 self.device = list(self.model.hf_device_map.values())[0]
+             else:
+                 from accelerate import PartialState
+
+                 self.device = PartialState().default_device
+
+         if self.device_map is None:
+             self.model.to(self.device)
+
+         super().setup()
+
+     def encode(self, raw_inputs):
+         """
+         Uses the `pre_processor` to prepare the inputs for the `model`.
+         """
+         return self.pre_processor(raw_inputs)
+
+     def forward(self, inputs):
+         """
+         Sends the inputs through the `model`.
+         """
+         import torch
+
+         with torch.no_grad():
+             return self.model(**inputs)
+
+     def decode(self, outputs):
+         """
+         Uses the `post_processor` to decode the model output.
+         """
+         return self.post_processor(outputs)
+
+     def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
+         import torch
+         from accelerate.utils import send_to_device
+
+         if not self.is_initialized:
+             self.setup()
+
+         if sanitize_inputs_outputs:
+             args, kwargs = handle_agent_input_types(*args, **kwargs)
+         encoded_inputs = self.encode(*args, **kwargs)
+
+         tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
+         non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
+
+         encoded_inputs = send_to_device(tensor_inputs, self.device)
+         outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
+         outputs = send_to_device(outputs, "cpu")
+         decoded_outputs = self.decode(outputs)
+         if sanitize_inputs_outputs:
+             decoded_outputs = handle_agent_output_types(decoded_outputs, self.output_type)
+         return decoded_outputs
+
+
+ def get_tools_definition_code(tools: dict[str, Tool]) -> str:
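+     """Builds a self-contained source string that redefines the given tools (with a stub `Tool` base class), so they can be recreated inside an isolated executor."""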
+     tool_codes = []
+     for tool in tools.values():
+         validate_tool_attributes(tool.__class__, check_imports=False)
+         tool_code = instance_to_source(tool, base_cls=Tool)
+         tool_code = tool_code.replace("from smolagents.tools import Tool", "")
+         tool_code += f"\n\n{tool.name} = {tool.__class__.__name__}()\n"
+         tool_codes.append(tool_code)
+
+     tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
+     tool_definition_code += textwrap.dedent(
+         """
+     from typing import Any
+
+     class Tool:
+         def __call__(self, *args, **kwargs):
+             return self.forward(*args, **kwargs)
+
+         def forward(self, *args, **kwargs):
+             pass  # to be implemented in child class
+     """
+     )
+     tool_definition_code += "\n\n".join(tool_codes)
+     return tool_definition_code
+
+
+ def validate_tool_arguments(tool: Tool, arguments: Any) -> str | None:
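+     """Checks `arguments` against `tool.inputs`; returns an error message string if invalid, else None."""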
+     if isinstance(arguments, dict):
+         for key, value in arguments.items():
+             if key not in tool.inputs:
+                 return f"Argument {key} is not in the tool's input schema."
+
+             parsed_type = _get_json_schema_type(type(value))["type"]
+
+             if parsed_type != tool.inputs[key]["type"] and not tool.inputs[key]["type"] == "any":
+                 return f"Argument {key} has type '{parsed_type}' but should be '{tool.inputs[key]['type']}'."
+         for key in tool.inputs:
+             if key not in arguments:
+                 return f"Argument {key} is required."
+         return None
+     else:
+         expected_type = list(tool.inputs.values())[0]["type"]
+         if _get_json_schema_type(type(arguments))["type"] != expected_type and not expected_type == "any":
+             return f"Argument has type '{type(arguments).__name__}' but should be '{expected_type}'."
+         return None
+
+
+ __all__ = [
+     "AUTHORIZED_TYPES",
+     "Tool",
+     "tool",
+     "load_tool",
+     "launch_gradio_demo",
+     "ToolCollection",
+ ]
src/smolagents/utils.py ADDED
@@ -0,0 +1,500 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import ast
18
+ import base64
19
+ import importlib.metadata
20
+ import importlib.util
21
+ import inspect
22
+ import json
23
+ import keyword
24
+ import os
25
+ import re
26
+ import types
27
+ from functools import lru_cache
28
+ from io import BytesIO
29
+ from pathlib import Path
30
+ from textwrap import dedent
31
+ from typing import TYPE_CHECKING, Any
32
+
33
+
34
+ if TYPE_CHECKING:
35
+ from smolagents.memory import AgentLogger
36
+
37
+
38
+ __all__ = ["AgentError"]
39
+
40
+
41
+ @lru_cache
42
+ def _is_package_available(package_name: str) -> bool:
43
+ try:
44
+ importlib.metadata.version(package_name)
45
+ return True
46
+ except importlib.metadata.PackageNotFoundError:
47
+ return False
48
+
49
+
50
+ BASE_BUILTIN_MODULES = [
51
+ "collections",
52
+ "datetime",
53
+ "itertools",
54
+ "math",
55
+ "queue",
56
+ "random",
57
+ "re",
58
+ "stat",
59
+ "statistics",
60
+ "time",
61
+ "unicodedata",
62
+ ]
63
+
64
+
65
+ def escape_code_brackets(text: str) -> str:
66
+ """Escapes square brackets in code segments while preserving Rich styling tags."""
67
+
68
+ def replace_bracketed_content(match):
69
+ content = match.group(1)
70
+ cleaned = re.sub(
71
+ r"bold|red|green|blue|yellow|magenta|cyan|white|black|italic|dim|\s|#[0-9a-fA-F]{6}", "", content
72
+ )
73
+ return f"\\[{content}\\]" if cleaned.strip() else f"[{content}]"
74
+
75
+ return re.sub(r"\[([^\]]*)\]", replace_bracketed_content, text)
76
+
77
+
78
+ class AgentError(Exception):
79
+ """Base class for other agent-related exceptions"""
80
+
81
+ def __init__(self, message, logger: "AgentLogger"):
82
+ super().__init__(message)
83
+ self.message = message
84
+ logger.log_error(message)
85
+
86
+ def dict(self) -> dict[str, str]:
87
+ return {"type": self.__class__.__name__, "message": str(self.message)}
88
+
89
+
90
+ class AgentParsingError(AgentError):
91
+ """Exception raised for errors in parsing in the agent"""
92
+
93
+ pass
94
+
95
+
96
+ class AgentExecutionError(AgentError):
97
+ """Exception raised for errors in execution in the agent"""
98
+
99
+ pass
100
+
101
+
102
+ class AgentMaxStepsError(AgentError):
103
+ """Exception raised for errors in execution in the agent"""
104
+
105
+ pass
106
+
107
+
108
+ class AgentToolCallError(AgentExecutionError):
109
+ """Exception raised for errors when incorrect arguments are passed to the tool"""
110
+
111
+ pass
112
+
113
+
114
+ class AgentToolExecutionError(AgentExecutionError):
115
+ """Exception raised for errors when executing a tool"""
116
+
117
+ pass
118
+
119
+
120
+ class AgentGenerationError(AgentError):
121
+ """Exception raised for errors in generation in the agent"""
122
+
123
+ pass
124
+
125
+
126
+ def make_json_serializable(obj: Any) -> Any:
127
+ """Recursive function to make objects JSON serializable"""
128
+ if obj is None:
129
+ return None
130
+ elif isinstance(obj, (str, int, float, bool)):
131
+ # Try to parse string as JSON if it looks like a JSON object/array
132
+ if isinstance(obj, str):
133
+ try:
134
+ if (obj.startswith("{") and obj.endswith("}")) or (obj.startswith("[") and obj.endswith("]")):
135
+ parsed = json.loads(obj)
136
+ return make_json_serializable(parsed)
137
+ except json.JSONDecodeError:
138
+ pass
139
+ return obj
140
+ elif isinstance(obj, (list, tuple)):
141
+ return [make_json_serializable(item) for item in obj]
142
+ elif isinstance(obj, dict):
143
+ return {str(k): make_json_serializable(v) for k, v in obj.items()}
144
+ elif hasattr(obj, "__dict__"):
145
+ # For custom objects, convert their __dict__ to a serializable format
146
+ return {"_type": obj.__class__.__name__, **{k: make_json_serializable(v) for k, v in obj.__dict__.items()}}
147
+ else:
148
+ # For any other type, convert to string
149
+ return str(obj)
150
+
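An illustrative round-trip (`Point` is a made-up class, not part of this module): containers recurse, JSON-looking strings are re-parsed, and custom objects are tagged with their class name.

    class Point:
        def __init__(self):
            self.x, self.y = 1, 2

    make_json_serializable({"p": Point(), "raw": '{"a": 1}', "t": (1, 2)})
    # -> {"p": {"_type": "Point", "x": 1, "y": 2}, "raw": {"a": 1}, "t": [1, 2]}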
151
+
152
+ def parse_json_blob(json_blob: str) -> tuple[dict[str, str], str]:
153
+ "Extracts the JSON blob from the input and returns the JSON data and the rest of the input."
154
+ try:
155
+ first_accolade_index = json_blob.find("{")
156
+ last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
157
+ json_data = json_blob[first_accolade_index : last_accolade_index + 1]
158
+ json_data = json.loads(json_data, strict=False)
159
+ return json_data, json_blob[:first_accolade_index]
160
+ except IndexError:
161
+ raise ValueError("The model output does not contain any JSON blob.")
162
+ except json.JSONDecodeError as e:
163
+ place = e.pos
164
+ if json_blob[place - 1 : place + 2] == "},\n":
165
+ raise ValueError(
166
+ "JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
167
+ )
168
+ raise ValueError(
169
+ f"The JSON blob you used is invalid due to the following error: {e}.\n"
170
+ f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
171
+ f"'{json_blob[place - 4 : place + 5]}'."
172
+ )
173
+
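An assumed example of the contract (not from the file itself): the parsed dict is returned together with the text that preceded the first brace.

    data, prefix = parse_json_blob('Thought: search.\n{"name": "web_search", "arguments": {"q": "foo"}}')
    # data   -> {"name": "web_search", "arguments": {"q": "foo"}}
    # prefix -> "Thought: search.\n"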
174
+
175
+ def extract_code_from_text(text: str) -> str | None:
176
+ """Extract code from the LLM's output."""
177
+ pattern = r"<code>(.*?)</code>"
178
+ matches = re.findall(pattern, text, re.DOTALL)
179
+ if matches:
180
+ return "\n\n".join(match.strip() for match in matches)
181
+ return None
182
+
183
+
184
+ def parse_code_blobs(text: str) -> str:
185
+ """Extract code blocs from the LLM's output.
186
+
187
+ If a valid code block is passed, it returns it directly.
188
+
189
+ Args:
190
+ text (`str`): LLM's output text to parse.
191
+
192
+ Returns:
193
+ `str`: Extracted code block.
194
+
195
+ Raises:
196
+ ValueError: If no valid code block is found in the text.
197
+ """
198
+ matches = extract_code_from_text(text)
199
+ if matches:
200
+ return matches
201
+ # Maybe the LLM output a code blob directly
202
+ try:
203
+ ast.parse(text)
204
+ return text
205
+ except SyntaxError:
206
+ pass
207
+
208
+ if "final" in text and "answer" in text:
209
+ raise ValueError(
210
+ dedent(
211
+ f"""
212
+ Your code snippet is invalid because the regex pattern <code>(.*?)</code> was not found in it.
213
+ Here is your code snippet:
214
+ {text}
215
+ It seems like you're trying to return the final answer; you can do it as follows:
216
+ <code>
217
+ final_answer("YOUR FINAL ANSWER HERE")
218
+ </code>
219
+ """
220
+ ).strip()
221
+ )
222
+ raise ValueError(
223
+ dedent(
224
+ f"""
225
+ Your code snippet is invalid because the regex pattern <code>(.*?)</code> was not found in it.
226
+ Here is your code snippet:
227
+ {text}
228
+ Make sure to include code with the correct pattern, for instance:
229
+ Thoughts: Your thoughts
230
+ <code>
231
+ # Your python code here
232
+ </code>
233
+ """
234
+ ).strip()
235
+ )
236
+
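Two illustrative cases (not from the file): tagged code is extracted and stripped, and bare text that already parses as Python is returned unchanged.

    parse_code_blobs("Thoughts: add.\n<code>\nx = 1 + 1\n</code>")  # -> "x = 1 + 1"
    parse_code_blobs("x = 1 + 1")                                   # -> "x = 1 + 1"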
237
+
238
+ MAX_LENGTH_TRUNCATE_CONTENT = 20000
239
+
240
+
241
+ def truncate_content(content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT) -> str:
242
+ if len(content) <= max_length:
243
+ return content
244
+ else:
245
+ return (
246
+ content[: max_length // 2]
247
+ + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
248
+ + content[-max_length // 2 :]
249
+ )
250
+
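A minimal sanity check (illustrative): half of the budget is kept from each end with a marker line spliced in between, so very long tool outputs cannot flood the agent's memory.

    truncate_content("a" * 50, max_length=10)
    # -> "aaaaa\n..._This content has been truncated to stay below 10 characters_...\naaaaa"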
251
+
252
+ class ImportFinder(ast.NodeVisitor):
253
+ def __init__(self):
254
+ self.packages = set()
255
+
256
+ def visit_Import(self, node):
257
+ for alias in node.names:
258
+ # Get the base package name (before any dots)
259
+ base_package = alias.name.split(".")[0]
260
+ self.packages.add(base_package)
261
+
262
+ def visit_ImportFrom(self, node):
263
+ if node.module: # for "from x import y" statements
264
+ # Get the base package name (before any dots)
265
+ base_package = node.module.split(".")[0]
266
+ self.packages.add(base_package)
267
+
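An illustrative run (not part of the file): only base package names are collected, for both import styles.

    finder = ImportFinder()
    finder.visit(ast.parse("import numpy.linalg\nfrom PIL import Image"))
    finder.packages  # -> {"numpy", "PIL"}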
268
+
269
+ def get_method_source(method):
270
+ """Get source code for a method, including bound methods."""
271
+ if isinstance(method, types.MethodType):
272
+ method = method.__func__
273
+ return get_source(method)
274
+
275
+
276
+ def is_same_method(method1, method2):
277
+ """Compare two methods by their source code."""
278
+ try:
279
+ source1 = get_method_source(method1)
280
+ source2 = get_method_source(method2)
281
+
282
+ # Remove method decorators if any
283
+ source1 = "\n".join(line for line in source1.split("\n") if not line.strip().startswith("@"))
284
+ source2 = "\n".join(line for line in source2.split("\n") if not line.strip().startswith("@"))
285
+
286
+ return source1 == source2
287
+ except (TypeError, OSError):
288
+ return False
289
+
290
+
291
+ def is_same_item(item1, item2):
292
+ """Compare two class items (methods or attributes) for equality."""
293
+ if callable(item1) and callable(item2):
294
+ return is_same_method(item1, item2)
295
+ else:
296
+ return item1 == item2
297
+
298
+
299
+ def instance_to_source(instance, base_cls=None):
300
+ """Convert an instance to its class source code representation."""
301
+ cls = instance.__class__
302
+ class_name = cls.__name__
303
+
304
+ # Start building class lines
305
+ class_lines = []
306
+ if base_cls:
307
+ class_lines.append(f"class {class_name}({base_cls.__name__}):")
308
+ else:
309
+ class_lines.append(f"class {class_name}:")
310
+
311
+ # Add docstring if it exists and differs from base
312
+ if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__):
313
+ class_lines.append(f' """{cls.__doc__}"""')
314
+
315
+ # Add class-level attributes
316
+ class_attrs = {
317
+ name: value
318
+ for name, value in cls.__dict__.items()
319
+ if not name.startswith("__")
320
+ and not callable(value)
321
+ and not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
322
+ }
323
+
324
+ for name, value in class_attrs.items():
325
+ if isinstance(value, str):
326
+ # multiline value
327
+ if "\n" in value:
328
+ escaped_value = value.replace('"""', r"\"\"\"") # Escape triple quotes
329
+ class_lines.append(f' {name} = """{escaped_value}"""')
330
+ else:
331
+ class_lines.append(f" {name} = {json.dumps(value)}")
332
+ else:
333
+ class_lines.append(f" {name} = {repr(value)}")
334
+
335
+ if class_attrs:
336
+ class_lines.append("")
337
+
338
+ # Add methods
339
+ methods = {
340
+ name: func.__wrapped__ if hasattr(func, "__wrapped__") else func
341
+ for name, func in cls.__dict__.items()
342
+ if callable(func)
343
+ and (
344
+ not base_cls
345
+ or not hasattr(base_cls, name)
346
+ or (
347
+ isinstance(func, (staticmethod, classmethod))
348
+ or (getattr(base_cls, name).__code__.co_code != func.__code__.co_code)
349
+ )
350
+ )
351
+ }
352
+
353
+ for name, method in methods.items():
354
+ method_source = get_source(method)
355
+ # Clean up the indentation
356
+ method_lines = method_source.split("\n")
357
+ first_line = method_lines[0]
358
+ indent = len(first_line) - len(first_line.lstrip())
359
+ method_lines = [line[indent:] for line in method_lines]
360
+ method_source = "\n".join([" " + line if line.strip() else line for line in method_lines])
361
+ class_lines.append(method_source)
362
+ class_lines.append("")
363
+
364
+ # Find required imports using ImportFinder
365
+ import_finder = ImportFinder()
366
+ import_finder.visit(ast.parse("\n".join(class_lines)))
367
+ required_imports = import_finder.packages
368
+
369
+ # Build final code with imports
370
+ final_lines = []
371
+
372
+ # Add base class import if needed
373
+ if base_cls:
374
+ final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")
375
+
376
+ # Add discovered imports
377
+ for package in required_imports:
378
+ final_lines.append(f"import {package}")
379
+
380
+ if final_lines: # Add empty line after imports
381
+ final_lines.append("")
382
+
383
+ # Add the class code
384
+ final_lines.extend(class_lines)
385
+
386
+ return "\n".join(final_lines)
387
+
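A hedged usage sketch (`my_tool` is hypothetical, and the `Tool` import path is an assumption): the generated source opens with the base-class import and any discovered imports, followed by a class definition reproducing the instance's class attributes and overridden methods.

    from smolagents.tools import Tool  # assumed location of the Tool base class

    source = instance_to_source(my_tool, base_cls=Tool)  # `my_tool` is a hypothetical Tool instance
    # `source` starts with "from smolagents.tools import Tool", then "class MyTool(Tool):" and its body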
388
+
389
+ def get_source(obj) -> str:
390
+ """Get the source code of a class or callable object (e.g.: function, method).
391
+ First attempts to get the source code using `inspect.getsource`.
392
+ In a dynamic environment (e.g.: Jupyter, IPython), if this fails,
393
+ falls back to retrieving the source code from the current interactive shell session.
394
+
395
+ Args:
396
+ obj: A class or callable object (e.g.: function, method)
397
+
398
+ Returns:
399
+ str: The source code of the object, dedented and stripped
400
+
401
+ Raises:
402
+ TypeError: If object is not a class or callable
403
+ OSError: If source code cannot be retrieved from any source
404
+ ValueError: If source cannot be found in IPython history
405
+
406
+ Note:
407
+ TODO: handle Python standard REPL
408
+ """
409
+ if not (isinstance(obj, type) or callable(obj)):
410
+ raise TypeError(f"Expected class or callable, got {type(obj)}")
411
+
412
+ inspect_error = None
413
+ try:
414
+ # Handle dynamically created classes
415
+ source = getattr(obj, "__source__", None) or inspect.getsource(obj)
416
+ return dedent(source).strip()
417
+ except OSError as e:
418
+ # let's keep track of the exception to raise it if all further methods fail
419
+ inspect_error = e
420
+ try:
421
+ import IPython
422
+
423
+ shell = IPython.get_ipython()
424
+ if not shell:
425
+ raise ImportError("No active IPython shell found")
426
+ all_cells = "\n".join(shell.user_ns.get("In", [])).strip()
427
+ if not all_cells:
428
+ raise ValueError("No code cells found in IPython session")
429
+
430
+ tree = ast.parse(all_cells)
431
+ for node in ast.walk(tree):
432
+ if isinstance(node, (ast.ClassDef, ast.FunctionDef)) and node.name == obj.__name__:
433
+ return dedent("\n".join(all_cells.split("\n")[node.lineno - 1 : node.end_lineno])).strip()
434
+ raise ValueError(f"Could not find source code for {obj.__name__} in IPython history")
435
+ except ImportError:
436
+ # IPython is not available, let's just raise the original inspect error
437
+ raise inspect_error
438
+ except ValueError as e:
439
+ # IPython is available but we couldn't find the source code, let's raise the error
440
+ raise e from inspect_error
441
+
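A quick check of the happy path (illustrative): for an ordinary function this reduces to `inspect.getsource` plus dedent/strip; the IPython fallback only matters for objects defined interactively.

    def double(x):
        return 2 * x

    get_source(double)  # -> "def double(x):\n    return 2 * x"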
442
+
443
+ def encode_image_base64(image):
444
+ buffered = BytesIO()
445
+ image.save(buffered, format="PNG")
446
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
447
+
448
+
449
+ def make_image_url(base64_image):
450
+ return f"data:image/png;base64,{base64_image}"
451
+
452
+
453
+ def make_init_file(folder: str | Path):
454
+ os.makedirs(folder, exist_ok=True)
455
+ # Create __init__
456
+ with open(os.path.join(folder, "__init__.py"), "w"):
457
+ pass
458
+
459
+
460
+ def is_valid_name(name: str) -> bool:
461
+ return isinstance(name, str) and name.isidentifier() and not keyword.iskeyword(name)
462
+
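Illustrative cases (not from the file): identifiers pass, while keywords and non-strings are rejected.

    is_valid_name("my_tool")  # True
    is_valid_name("class")    # False: Python keyword
    is_valid_name(123)        # False: not a str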
463
+
464
+ AGENT_GRADIO_APP_TEMPLATE = """import yaml
465
+ import os
466
+ from smolagents import GradioUI, {{ class_name }}, {{ agent_dict['model']['class'] }}
467
+
468
+ # Get current directory path
469
+ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
470
+
471
+ {% for tool in tools.values() -%}
472
+ from {{managed_agent_relative_path}}tools.{{ tool.name }} import {{ tool.__class__.__name__ }} as {{ tool.name | camelcase }}
473
+ {% endfor %}
474
+ {% for managed_agent in managed_agents.values() -%}
475
+ from {{managed_agent_relative_path}}managed_agents.{{ managed_agent.name }}.app import agent_{{ managed_agent.name }}
476
+ {% endfor %}
477
+
478
+ model = {{ agent_dict['model']['class'] }}(
479
+ {% for key in agent_dict['model']['data'] if key not in ['class', 'last_input_token_count', 'last_output_token_count'] -%}
480
+ {{ key }}={{ agent_dict['model']['data'][key]|repr }},
481
+ {% endfor %})
482
+
483
+ {% for tool in tools.values() -%}
484
+ {{ tool.name }} = {{ tool.name | camelcase }}()
485
+ {% endfor %}
486
+
487
+ with open(os.path.join(CURRENT_DIR, "prompts.yaml"), 'r') as stream:
488
+ prompt_templates = yaml.safe_load(stream)
489
+
490
+ {{ agent_name }} = {{ class_name }}(
491
+ model=model,
492
+ tools=[{% for tool_name in tools.keys() if tool_name != "final_answer" %}{{ tool_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
493
+ managed_agents=[{% for subagent_name in managed_agents.keys() %}agent_{{ subagent_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
494
+ {% for attribute_name, value in agent_dict.items() if attribute_name not in ["model", "tools", "prompt_templates", "authorized_imports", "managed_agents", "requirements"] -%}
495
+ {{ attribute_name }}={{ value|repr }},
496
+ {% endfor %}prompt_templates=prompt_templates
497
+ )
498
+ if __name__ == "__main__":
499
+ GradioUI({{ agent_name }}).launch()
500
+ """.strip()
src/smolagents/vision_web_browser.py ADDED
@@ -0,0 +1,228 @@
1
+ import argparse
2
+ from io import BytesIO
3
+ from time import sleep
4
+
5
+ import helium
6
+ import PIL.Image
7
+ from dotenv import load_dotenv
8
+ from selenium import webdriver
9
+ from selenium.webdriver.common.by import By
10
+ from selenium.webdriver.common.keys import Keys
11
+
12
+ from smolagents import CodeAgent, WebSearchTool, tool
13
+ from smolagents.agents import ActionStep
14
+ from smolagents.cli import load_model
15
+
16
+
17
+ github_request = """
18
+ I'm trying to find how hard I have to work to get a repo in github.com/trending.
19
+ Can you navigate to the profile for the top author of the top trending repo, and give me their total number of commits over the last year?
20
+ """ # The agent is able to achieve this request only when powered by GPT-4o or Claude-3.5-sonnet.
21
+
22
+ search_request = """
23
+ Please navigate to https://en.wikipedia.org/wiki/Chicago and give me a sentence containing the word "1992" that mentions a construction accident.
24
+ """
25
+
26
+
27
+ def parse_arguments():
28
+ parser = argparse.ArgumentParser(description="Run a web browser automation script with a specified model.")
29
+ parser.add_argument(
30
+ "prompt",
31
+ type=str,
32
+ nargs="?", # Makes it optional
33
+ default=search_request,
34
+ help="The prompt to run with the agent",
35
+ )
36
+ parser.add_argument(
37
+ "--model-type",
38
+ type=str,
39
+ default="LiteLLMModel",
40
+ help="The model type to use (e.g., OpenAIServerModel, LiteLLMModel, TransformersModel, InferenceClientModel)",
41
+ )
42
+ parser.add_argument(
43
+ "--model-id",
44
+ type=str,
45
+ default="gpt-4o",
46
+ help="The model ID to use for the specified model type",
47
+ )
48
+ parser.add_argument(
49
+ "--provider",
50
+ type=str,
51
+ help="The inference provider to use for the model",
52
+ )
53
+ parser.add_argument(
54
+ "--api-base",
55
+ type=str,
56
+ help="The API base to use for the model",
57
+ )
58
+ parser.add_argument(
59
+ "--api-key",
60
+ type=str,
61
+ help="The API key to use for the model",
62
+ )
63
+ return parser.parse_args()
64
+
65
+
66
+ def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> None:
67
+ sleep(1.0) # Let JavaScript animations happen before taking the screenshot
68
+ driver = helium.get_driver()
69
+ current_step = memory_step.step_number
70
+ if driver is not None:
71
+ for previous_memory_step in agent.memory.steps: # Remove previous screenshots from logs for lean processing
72
+ if isinstance(previous_memory_step, ActionStep) and previous_memory_step.step_number <= current_step - 2:
73
+ previous_memory_step.observations_images = None
74
+ png_bytes = driver.get_screenshot_as_png()
75
+ image = PIL.Image.open(BytesIO(png_bytes))
76
+ print(f"Captured a browser screenshot: {image.size} pixels")
77
+ memory_step.observations_images = [image.copy()] # Create a copy to ensure it persists, important!
78
+
79
+ # Update observations with current URL
80
+ url_info = f"Current url: {driver.current_url}"
81
+ memory_step.observations = (
82
+ url_info if memory_step.observations is None else memory_step.observations + "\n" + url_info
83
+ )
84
+ return
85
+
86
+
87
+ @tool
88
+ def search_item_ctrl_f(text: str, nth_result: int = 1) -> str:
89
+ """
90
+ Searches for text on the current page via Ctrl + F and jumps to the nth occurrence.
91
+ Args:
92
+ text: The text to search for
93
+ nth_result: Which occurrence to jump to (default: 1)
94
+ """
95
+ elements = driver.find_elements(By.XPATH, f"//*[contains(text(), '{text}')]")
96
+ if nth_result > len(elements):
97
+ raise Exception(f"Match n°{nth_result} not found (only {len(elements)} matches found)")
98
+ result = f"Found {len(elements)} matches for '{text}'."
99
+ elem = elements[nth_result - 1]
100
+ driver.execute_script("arguments[0].scrollIntoView(true);", elem)
101
+ result += f"Focused on element {nth_result} of {len(elements)}"
102
+ return result
103
+
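An assumed call from agent code (the match count depends on the live page):

    search_item_ctrl_f("Chicago", nth_result=2)
    # -> e.g. "Found 7 matches for 'Chicago'. Focused on element 2 of 7."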
104
+
105
+ @tool
106
+ def go_back() -> None:
107
+ """Goes back to previous page."""
108
+ driver.back()
109
+
110
+
111
+ @tool
112
+ def close_popups() -> str:
113
+ """
114
+ Closes any visible modal or pop-up on the page. Use this to dismiss pop-up windows! This does not work on cookie consent banners.
115
+ """
116
+ webdriver.ActionChains(driver).send_keys(Keys.ESCAPE).perform()
+ return "Popups closed."  # return a confirmation string so the declared -> str annotation holds
117
+
118
+
119
+ def initialize_driver():
120
+ """Initialize the Selenium WebDriver."""
121
+ chrome_options = webdriver.ChromeOptions()
122
+ chrome_options.add_argument("--force-device-scale-factor=1")
123
+ chrome_options.add_argument("--window-size=1000,1350")
124
+ chrome_options.add_argument("--disable-pdf-viewer")
125
+ chrome_options.add_argument("--window-position=0,0")
126
+ return helium.start_chrome(headless=False, options=chrome_options)
127
+
128
+
129
+ def initialize_agent(model):
130
+ """Initialize the CodeAgent with the specified model."""
131
+ return CodeAgent(
132
+ tools=[WebSearchTool(), go_back, close_popups, search_item_ctrl_f],
133
+ model=model,
134
+ additional_authorized_imports=["helium"],
135
+ step_callbacks=[save_screenshot],
136
+ max_steps=20,
137
+ verbosity_level=2,
138
+ )
139
+
140
+
141
+ helium_instructions = """
142
+ Use your web_search tool when you want to get Google search results.
143
+ Then you can use helium to access websites. Don't use helium for Google search, only for navigating websites!
144
+ Don't worry about the helium driver, it's already managed.
145
+ We've already run "from helium import *".
146
+ Then you can go to pages!
147
+ <code>
148
+ go_to('github.com/trending')
149
+ </code>
150
+
151
+ You can directly click clickable elements by inputting the text that appears on them.
152
+ <code>
153
+ click("Top products")
154
+ </code>
155
+
156
+ If it's a link:
157
+ <code>
158
+ click(Link("Top products"))
159
+ </code>
160
+
161
+ If you try to interact with an element and it's not found, you'll get a LookupError.
162
+ In general, stop your action after each button click to see what happens in your screenshot.
163
+ Never try to log in to a page.
164
+
165
+ To scroll up or down, use scroll_down or scroll_up, passing the number of pixels to scroll as an argument.
166
+ <code>
167
+ scroll_down(num_pixels=1200) # This will scroll one viewport down
168
+ </code>
169
+
170
+ When you have pop-ups with a cross icon to close, don't try to click the close icon by finding its element or targeting an 'X' element (this most often fails).
171
+ Just use your built-in tool `close_popups` to close them:
172
+ <code>
173
+ close_popups()
174
+ </code>
175
+
176
+ You can use .exists() to check for the existence of an element. For example:
177
+ <code>
178
+ if Text('Accept cookies?').exists():
179
+ click('I accept')
180
+ </code>
181
+
182
+ Proceed in several steps rather than trying to solve the task in one shot.
183
+ And at the end, only when you have your answer, return your final answer.
184
+ <code>
185
+ final_answer("YOUR_ANSWER_HERE")
186
+ </code>
187
+
188
+ If pages seem stuck on loading, you might have to wait; for instance, `import time` and run `time.sleep(5.0)`. But don't overuse this!
189
+ To list elements on the page, DO NOT try code-based element searches like 'contributors = find_all(S("ol > li"))': just look at the latest screenshot you have and read it visually, or use your tool search_item_ctrl_f.
190
+ Of course, you can act on buttons like a user would do when navigating.
191
+ After each code blob you write, you will be automatically provided with an updated screenshot of the browser and the current browser url.
192
+ But beware that the screenshot will only be taken at the end of the whole action; it won't see intermediate states.
193
+ Don't kill the browser.
194
+ When you have modals or cookie banners on screen, you should get rid of them before you can click anything else.
195
+ """
196
+
197
+
198
+ def run_webagent(
199
+ prompt: str,
200
+ model_type: str,
201
+ model_id: str,
202
+ provider: str | None = None,
203
+ api_base: str | None = None,
204
+ api_key: str | None = None,
205
+ ) -> None:
206
+ # Load environment variables
207
+ load_dotenv()
208
+
209
+ # Initialize the model based on the provided arguments
210
+ model = load_model(model_type, model_id, provider=provider, api_base=api_base, api_key=api_key)
211
+
212
+ global driver
213
+ driver = initialize_driver()
214
+ agent = initialize_agent(model)
215
+
216
+ # Run the agent with the provided prompt
217
+ agent.python_executor("from helium import *")
218
+ agent.run(prompt + helium_instructions)
219
+
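A hedged programmatic invocation (the model id and key below are placeholders, not values this script prescribes):

    run_webagent(
        prompt=search_request,
        model_type="LiteLLMModel",
        model_id="gpt-4o",
        api_key="sk-...",  # placeholder credential
    )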
220
+
221
+ def main() -> None:
222
+ # Parse command line arguments
223
+ args = parse_arguments()
224
+ run_webagent(args.prompt, args.model_type, args.model_id, args.provider, args.api_base, args.api_key)
225
+
226
+
227
+ if __name__ == "__main__":
228
+ main()