#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import base64
import importlib.metadata
import importlib.util
import inspect
import json
import keyword
import os
import re
import types
from functools import lru_cache
from io import BytesIO
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
    # Imported for static type checkers only; at runtime the name is referenced solely through the
    # string annotation on `AgentError.__init__`.
    from smolagents.memory import AgentLogger


# Names exported by a star-import of this module: only `AgentError`.
__all__ = ["AgentError"]
@lru_cache
def _is_package_available(package_name: str) -> bool:
try:
importlib.metadata.version(package_name)
return True
except importlib.metadata.PackageNotFoundError:
return False
# Names of standard-library modules, kept in alphabetical order.
# NOTE(review): presumably the default allow-list of imports for agent-executed code — the consumers
# are not visible in this file, confirm against callers.
BASE_BUILTIN_MODULES = [
    "collections",
    "datetime",
    "itertools",
    "math",
    "queue",
    "random",
    "re",
    "stat",
    "statistics",
    "time",
    "unicodedata",
]
def escape_code_brackets(text: str) -> str:
    """Escapes square brackets in code segments while preserving Rich styling tags.

    Each `[...]` span is inspected: if nothing is left after removing the recognised style words,
    whitespace and `#rrggbb` colours, the span is kept verbatim; otherwise both brackets are
    prefixed with a backslash.

    Args:
        text (`str`): Text that may contain bracketed spans.

    Returns:
        `str`: The text with non-style bracketed spans escaped.
    """
    style_tokens = r"bold|red|green|blue|yellow|magenta|cyan|white|black|italic|dim|\s|#[0-9a-fA-F]{6}"

    def _rewrite(found):
        inner = found.group(1)
        leftover = re.sub(style_tokens, "", inner)
        if leftover.strip():
            # Something other than style tokens is inside the brackets: escape them.
            return f"\\[{inner}\\]"
        return f"[{inner}]"

    return re.sub(r"\[([^\]]*)\]", _rewrite, text)
class AgentError(Exception):
    """Base class for other agent-related exceptions.

    Creating the exception keeps `message` on the instance and reports it straight away through
    `logger.log_error`.

    Args:
        message: Error payload; passed to `Exception.__init__` and to the logger unchanged.
        logger (`AgentLogger`): Object whose `log_error` method receives the message.
    """

    def __init__(self, message, logger: "AgentLogger"):
        self.message = message
        super().__init__(message)
        logger.log_error(message)

    def dict(self) -> dict[str, str]:
        """Return the concrete class name and the stringified message as a plain mapping."""
        kind = self.__class__.__name__
        return {"type": kind, "message": str(self.message)}
class AgentParsingError(AgentError):
    """Agent error raised when parsing fails."""
class AgentExecutionError(AgentError):
    """Agent error raised when execution fails."""
class AgentMaxStepsError(AgentError):
    """Agent error related to the agent's step limit (meaning taken from the class name)."""
class AgentToolCallError(AgentExecutionError):
    """Execution error raised when a tool is called with incorrect arguments."""
class AgentToolExecutionError(AgentExecutionError):
    """Execution error raised when running a tool fails."""
class AgentGenerationError(AgentError):
    """Agent error raised when generation fails."""
def make_json_serializable(obj: Any) -> Any:
    """Recursively convert `obj` into something `json` can serialize.

    Conversion rules, applied in this order:
      - `None`, numbers and booleans are returned as-is.
      - A string wrapped in `{...}` or `[...]` is decoded as JSON and the decoded value is converted
        recursively; if decoding fails, or for any other string, the string is returned unchanged.
      - Lists and tuples become lists of converted items.
      - Dicts get stringified keys and converted values.
      - Objects with a `__dict__` become a dict holding `"_type"` (the class name) plus their
        converted attributes.
      - Everything else is passed through `str()`.

    Args:
        obj (`Any`): Value to convert.

    Returns:
        `Any`: A structure built from `None`, `str`, numbers, `bool`, `list` and `dict` only.
    """
    if obj is None:
        return None

    if isinstance(obj, str):
        looks_like_object = obj.startswith("{") and obj.endswith("}")
        looks_like_array = obj.startswith("[") and obj.endswith("]")
        if not (looks_like_object or looks_like_array):
            return obj
        try:
            decoded = json.loads(obj)
        except json.JSONDecodeError:
            # Only looked like JSON: keep the raw string.
            return obj
        return make_json_serializable(decoded)

    if isinstance(obj, (int, float, bool)):
        return obj

    if isinstance(obj, (list, tuple)):
        return list(map(make_json_serializable, obj))

    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[str(key)] = make_json_serializable(value)
        return converted

    if hasattr(obj, "__dict__"):
        # Custom object: record its class name, then its attributes.
        described = {"_type": obj.__class__.__name__}
        for attr, value in obj.__dict__.items():
            described[attr] = make_json_serializable(value)
        return described

    # Fallback for any other type
    return str(obj)
def parse_json_blob(json_blob: str) -> tuple[dict[str, str], str]:
    """Extract the JSON blob from the input and return the JSON data and the text preceding it.

    The blob is taken to span from the first `{` to the last `}` of `json_blob` and is decoded with
    `strict=False`, so control characters such as raw newlines are accepted inside strings.

    Args:
        json_blob (`str`): Text expected to contain one JSON object.

    Returns:
        `tuple`: The decoded JSON data and the part of `json_blob` located before the first `{`.

    Raises:
        ValueError: If no `{...}` span exists in the input, or if the span is not valid JSON. In the
            latter case the message includes the part of the blob where decoding failed.
    """
    first_accolade_index = json_blob.find("{")
    last_accolade_index = json_blob.rfind("}")
    if first_accolade_index == -1 or last_accolade_index < first_accolade_index:
        raise ValueError("The model output does not contain any JSON blob.")

    candidate = json_blob[first_accolade_index : last_accolade_index + 1]
    try:
        json_data = json.loads(candidate, strict=False)
    except json.JSONDecodeError as e:
        # `e.pos` is an offset into the string given to `json.loads` (i.e. `candidate`), not into
        # `json_blob`, so the surrounding text must be looked up in `candidate`.
        place = e.pos
        # Clamp slice starts at 0: a negative start would wrap around to the end of the string.
        if candidate[max(place - 1, 0) : place + 2] == "},\n":
            raise ValueError(
                "JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
            ) from e
        raise ValueError(
            f"The JSON blob you used is invalid due to the following error: {e}.\n"
            f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
            f"'{candidate[max(place - 4, 0) : place + 5]}'."
        ) from e
    return json_data, json_blob[:first_accolade_index]
def extract_code_from_text(text: str) -> str | None:
"""Extract code from the LLM's output."""
pattern = r"(.*?)
"
matches = re.findall(pattern, text, re.DOTALL)
if matches:
return "\n\n".join(match.strip() for match in matches)
return None
def parse_code_blobs(text: str) -> str:
"""Extract code blocs from the LLM's output.
If a valid code block is passed, it returns it directly.
Args:
text (`str`): LLM's output text to parse.
Returns:
`str`: Extracted code block.
Raises:
ValueError: If no valid code block is found in the text.
"""
matches = extract_code_from_text(text)
if matches:
return matches
# Maybe the LLM outputted a code blob directly
try:
ast.parse(text)
return text
except SyntaxError:
pass
if "final" in text and "answer" in text:
raise ValueError(
dedent(
f"""
Your code snippet is invalid, because the regex pattern (.*?)
was not found in it.
Here is your code snippet:
{text}
It seems like you're trying to return the final answer, you can do it as follows:
final_answer("YOUR FINAL ANSWER HERE")
"""
).strip()
)
raise ValueError(
dedent(
f"""
Your code snippet is invalid, because the regex pattern (.*?)
was not found in it.
Here is your code snippet:
{text}
Make sure to include code with the correct pattern, for instance:
Thoughts: Your thoughts
# Your python code here
"""
).strip()
)
# Default character budget used by `truncate_content`.
MAX_LENGTH_TRUNCATE_CONTENT = 20000


def truncate_content(content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT) -> str:
    """Shorten `content` by cutting out its middle when it is longer than `max_length`.

    The first `max_length // 2` characters and the last `max_length - max_length // 2` characters
    are kept, with a marker line in between stating that the content was truncated. Because of that
    marker, the returned string is longer than `max_length`.

    Args:
        content (`str`): Text to shorten.
        max_length (`int`): Number of characters of `content` to keep. Values below zero keep
            nothing, like zero.

    Returns:
        `str`: `content` itself if it fits within `max_length`, otherwise head + marker + tail.
    """
    if len(content) <= max_length:
        return content

    budget = max(max_length, 0)
    head_length = budget // 2
    tail_length = budget - head_length
    # `content[-0:]` would be the whole string, so an empty tail has to be spelled out explicitly.
    tail = content[-tail_length:] if tail_length else ""
    return (
        content[:head_length]
        + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
        + tail
    )
class ImportFinder(ast.NodeVisitor):
    """AST visitor collecting the top-level package of every import statement it visits.

    After `visit(tree)`, `packages` holds the part before the first dot of each imported module
    name. `from . import y` style statements (no module name) contribute nothing.
    """

    def __init__(self):
        self.packages = set()

    def visit_Import(self, node):
        # "import a.b.c" contributes "a"
        self.packages.update(alias.name.partition(".")[0] for alias in node.names)

    def visit_ImportFrom(self, node):
        # `node.module` is empty for "from . import y"; such statements are skipped
        if not node.module:
            return
        root, _, _ = node.module.partition(".")
        self.packages.add(root)
def get_method_source(method):
    """Return the source of `method`; a bound method is unwrapped to its function first."""
    target = method.__func__ if isinstance(method, types.MethodType) else method
    return get_source(target)
def is_same_method(method1, method2):
    """Tell whether two methods have the same source code, ignoring decorator lines.

    Returns `False` when the source of either method cannot be retrieved (`TypeError` or `OSError`).
    """

    def _without_decorators(source):
        # Drop every line whose first non-blank character is "@"
        kept = [line for line in source.split("\n") if not line.strip().startswith("@")]
        return "\n".join(kept)

    try:
        first = _without_decorators(get_method_source(method1))
        second = _without_decorators(get_method_source(method2))
        return first == second
    except (TypeError, OSError):
        return False
def is_same_item(item1, item2):
    """Compare two class items: by source code when both are callable, with `==` otherwise."""
    if not (callable(item1) and callable(item2)):
        return item1 == item2
    return is_same_method(item1, item2)
def instance_to_source(instance, base_cls=None):
    """Convert an instance to its class source code representation.

    The generated code contains, in this order: the imports (the base class import, then one
    `import` line per package imported inside the methods, sorted by name so that the output is
    the same on every run), the class header, the class docstring, the class-level attributes and
    the methods defined directly on the class. Class members are indented by four spaces.

    When `base_cls` is given, members identical to the ones of `base_cls` are left out: the
    docstring and attributes are compared by value, methods by bytecode.

    Args:
        instance: Object whose class should be rendered.
        base_cls (`type`, *optional*): Base class to inherit from in the generated code.

    Returns:
        `str`: Source code of the class, preceded by the imports it needs.
    """
    cls = instance.__class__
    class_name = cls.__name__
    member_indent = "    "  # one indentation level inside the class body

    # Start building class lines
    class_lines = []
    if base_cls:
        class_lines.append(f"class {class_name}({base_cls.__name__}):")
    else:
        class_lines.append(f"class {class_name}:")

    # Add docstring if it exists and differs from base
    if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__):
        class_lines.append(f'{member_indent}"""{cls.__doc__}"""')

    # Add class-level attributes
    class_attrs = {
        name: value
        for name, value in cls.__dict__.items()
        if not name.startswith("__")
        and not callable(value)
        and not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
    }

    for name, value in class_attrs.items():
        if isinstance(value, str):
            # multiline value
            if "\n" in value:
                escaped_value = value.replace('"""', r"\"\"\"")  # Escape triple quotes
                class_lines.append(f'{member_indent}{name} = """{escaped_value}"""')
            else:
                class_lines.append(f"{member_indent}{name} = {json.dumps(value)}")
        else:
            class_lines.append(f"{member_indent}{name} = {repr(value)}")

    if class_attrs:
        class_lines.append("")

    # Add methods
    methods = {
        name: func.__wrapped__ if hasattr(func, "__wrapped__") else func
        for name, func in cls.__dict__.items()
        if callable(func)
        and (
            not base_cls
            or not hasattr(base_cls, name)
            or (
                isinstance(func, (staticmethod, classmethod))
                or (getattr(base_cls, name).__code__.co_code != func.__code__.co_code)
            )
        )
    }

    for method in methods.values():
        method_source = get_source(method)
        # Clean up the indentation: remove the first line's indent, then indent one class level
        method_lines = method_source.split("\n")
        first_line = method_lines[0]
        indent = len(first_line) - len(first_line.lstrip())
        method_lines = [line[indent:] for line in method_lines]
        method_source = "\n".join([member_indent + line if line.strip() else line for line in method_lines])
        class_lines.append(method_source)
        class_lines.append("")

    if len(class_lines) == 1:
        # Only the header was emitted: a class needs a body to be valid Python.
        class_lines.append(f"{member_indent}pass")

    # Find required imports using ImportFinder
    import_finder = ImportFinder()
    import_finder.visit(ast.parse("\n".join(class_lines)))
    required_imports = import_finder.packages

    # Build final code with imports
    final_lines = []

    # Add base class import if needed
    if base_cls:
        final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")

    # Add discovered imports, sorted because set iteration order differs between runs
    for package in sorted(required_imports):
        final_lines.append(f"import {package}")

    if final_lines:  # Add empty line after imports
        final_lines.append("")

    # Add the class code
    final_lines.extend(class_lines)

    return "\n".join(final_lines)
def get_source(obj) -> str:
    """Get the source code of a class or callable object (e.g.: function, method).

    First attempts to get the source code from the object's `__source__` attribute if it is set,
    otherwise using `inspect.getsource`.
    In a dynamic environment (e.g.: Jupyter, IPython), if this fails with an `OSError`,
    falls back to retrieving the source code from the current interactive shell session: the input
    cells are joined, parsed, and searched for a class or (non-async) function definition whose name
    equals `obj.__name__`.

    Args:
        obj: A class or callable object (e.g.: function, method)

    Returns:
        str: The source code of the object, dedented and stripped

    Raises:
        TypeError: If object is not a class or callable
        OSError: If source code cannot be retrieved from any source
        ValueError: If source cannot be found in IPython history

    Note:
        TODO: handle Python standard REPL
    """
    if not (isinstance(obj, type) or callable(obj)):
        raise TypeError(f"Expected class or callable, got {type(obj)}")

    inspect_error = None
    try:
        # Handle dynamically created classes
        source = getattr(obj, "__source__", None) or inspect.getsource(obj)
        return dedent(source).strip()
    except OSError as e:
        # let's keep track of the exception to raise it if all further methods fail
        inspect_error = e
    # Only reached when the attempt above raised OSError: try the IPython session history.
    try:
        import IPython

        shell = IPython.get_ipython()
        if not shell:
            # Handled by the `except ImportError` branch below, like a missing IPython install
            raise ImportError("No active IPython shell found")
        # `In` is read from the shell's user namespace and joined into one source text
        all_cells = "\n".join(shell.user_ns.get("In", [])).strip()
        if not all_cells:
            raise ValueError("No code cells found in IPython session")

        tree = ast.parse(all_cells)
        for node in ast.walk(tree):
            if isinstance(node, (ast.ClassDef, ast.FunctionDef)) and node.name == obj.__name__:
                # Slice the definition's lines (1-based `lineno` to `end_lineno`) out of the joined text
                return dedent("\n".join(all_cells.split("\n")[node.lineno - 1 : node.end_lineno])).strip()
        raise ValueError(f"Could not find source code for {obj.__name__} in IPython history")
    except ImportError:
        # IPython is not available, let's just raise the original inspect error
        raise inspect_error
    except ValueError as e:
        # IPython is available but we couldn't find the source code, let's raise the error
        raise e from inspect_error
def encode_image_base64(image):
    """Save `image` as PNG into memory and return the bytes as a base64 string.

    Args:
        image: Object with a `save(file, format=...)` method (presumably a PIL image — confirm
            against callers).

    Returns:
        `str`: Base64 text of the PNG data.
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    raw_png = png_buffer.getvalue()
    return base64.b64encode(raw_png).decode("utf-8")
def make_image_url(base64_image):
    """Wrap base64 image data in a `data:image/png;base64,` URL."""
    return "data:image/png;base64,{}".format(base64_image)
def make_init_file(folder: str | Path):
os.makedirs(folder, exist_ok=True)
# Create __init__
with open(os.path.join(folder, "__init__.py"), "w"):
pass
def is_valid_name(name: str) -> bool:
    """Tell whether `name` is a string usable as a Python identifier (and not a reserved keyword)."""
    if not isinstance(name, str):
        return False
    return name.isidentifier() and not keyword.iskeyword(name)
# Template of a generated `app.py` that rebuilds an agent and serves it through `GradioUI`.
# It uses `{{ ... }}` / `{% ... %}` placeholders and expects these variables: `class_name`,
# `agent_name`, `agent_dict`, `tools`, `managed_agents` and `managed_agent_relative_path`, plus
# `camelcase` and `repr` filters.
# NOTE(review): presumably rendered with Jinja2 with both filters registered by the caller — the
# rendering code is not visible in this file, confirm there.
# The bodies of the `with` and `if __name__` statements must stay indented: this text is emitted
# verbatim into the generated Python file.
AGENT_GRADIO_APP_TEMPLATE = """import yaml
import os
from smolagents import GradioUI, {{ class_name }}, {{ agent_dict['model']['class'] }}
# Get current directory path
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
{% for tool in tools.values() -%}
from {{managed_agent_relative_path}}tools.{{ tool.name }} import {{ tool.__class__.__name__ }} as {{ tool.name | camelcase }}
{% endfor %}
{% for managed_agent in managed_agents.values() -%}
from {{managed_agent_relative_path}}managed_agents.{{ managed_agent.name }}.app import agent_{{ managed_agent.name }}
{% endfor %}
model = {{ agent_dict['model']['class'] }}(
{% for key in agent_dict['model']['data'] if key not in ['class', 'last_input_token_count', 'last_output_token_count'] -%}
{{ key }}={{ agent_dict['model']['data'][key]|repr }},
{% endfor %})
{% for tool in tools.values() -%}
{{ tool.name }} = {{ tool.name | camelcase }}()
{% endfor %}
with open(os.path.join(CURRENT_DIR, "prompts.yaml"), 'r') as stream:
    prompt_templates = yaml.safe_load(stream)
{{ agent_name }} = {{ class_name }}(
model=model,
tools=[{% for tool_name in tools.keys() if tool_name != "final_answer" %}{{ tool_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
managed_agents=[{% for subagent_name in managed_agents.keys() %}agent_{{ subagent_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
{% for attribute_name, value in agent_dict.items() if attribute_name not in ["model", "tools", "prompt_templates", "authorized_imports", "managed_agents", "requirements"] -%}
{{ attribute_name }}={{ value|repr }},
{% endfor %}prompt_templates=prompt_templates
)
if __name__ == "__main__":
    GradioUI({{ agent_name }}).launch()
""".strip()