# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the ``smolagents.utils`` helpers and for the files written by ``Tool.save``.

Covered helpers: ``parse_code_blobs``, ``get_source``, ``instance_to_source``,
``parse_json_blob`` and ``is_valid_name``.

Several tests compare generated source text verbatim against the source of tools
defined in this file, so the layout of those tool bodies (spacing, blank lines,
comments) is part of the expected values and must be edited together with them.
"""

import inspect
import os
import textwrap
import unittest

import pytest
from IPython.core.interactiveshell import InteractiveShell

from smolagents import Tool
from smolagents.tools import tool
from smolagents.utils import get_source, instance_to_source, is_valid_name, parse_code_blobs, parse_json_blob


# Markdown code fence (three backticks), assembled at runtime so that this file
# never contains a literal fence sequence that markdown-aware tooling could strip.
CODE_FENCE = "`" * 3


# Class-based tool fixture. Its methods are echoed verbatim into VALID_TOOL_SOURCE,
# so they deliberately carry no docstrings or comments.
class ValidTool(Tool):
    name = "valid_tool"
    description = "A valid tool"
    inputs = {"input": {"type": "string", "description": "input"}}
    output_type = "string"
    simple_attr = "string"
    dict_attr = {"key": "value"}

    def __init__(self, optional_param="default"):
        super().__init__()
        self.param = optional_param

    def forward(self, input: str) -> str:
        return input.upper()


# Function-based tool fixture. Its body, docstring included, is echoed verbatim into
# VALID_TOOL_FUNCTION_SOURCE.
@tool
def valid_tool_function(input: str) -> str:
    """A valid tool function.

    Args:
        input (str): Input string.
    """
    return input.upper()


# Expected result of instance_to_source(ValidTool(), base_cls=Tool).
VALID_TOOL_SOURCE = """\
from smolagents.tools import Tool

class ValidTool(Tool):
    name = "valid_tool"
    description = "A valid tool"
    inputs = {'input': {'type': 'string', 'description': 'input'}}
    output_type = "string"
    simple_attr = "string"
    dict_attr = {'key': 'value'}

    def __init__(self, optional_param="default"):
        super().__init__()
        self.param = optional_param

    def forward(self, input: str) -> str:
        return input.upper()
"""

# Expected result of instance_to_source(valid_tool_function, base_cls=Tool).
VALID_TOOL_FUNCTION_SOURCE = '''\
from smolagents.tools import Tool

class SimpleTool(Tool):
    name = "valid_tool_function"
    description = "A valid tool function."
    inputs = {'input': {'type': 'string', 'description': 'Input string.'}}
    output_type = "string"

    def __init__(self):
        self.is_initialized = True

    def forward(self, input: str) -> str:
        """A valid tool function.

        Args:
            input (str): Input string.
        """
        return input.upper()
'''


class AgentTextTests(unittest.TestCase):
    """Checks for ``parse_code_blobs`` on raw code, fenced markdown and invalid text."""

    def test_parse_code_blobs(self):
        """A single blob is extracted from markdown; bare code passes through unchanged."""
        with pytest.raises(ValueError):
            parse_code_blobs("Wrong blob!")

        # Parsing markdown with code blobs should work
        output = parse_code_blobs(
            "\nHere is how to solve the problem:\n"
            f"{CODE_FENCE}py\n"
            "import numpy as np\n"
            f"{CODE_FENCE}\n"
        )
        assert output == "import numpy as np"

        # Parsing code blobs should work
        code_blob = "import numpy as np"
        output = parse_code_blobs(code_blob)
        assert output == code_blob

        # Allow whitespaces after header
        output = parse_code_blobs(f"{CODE_FENCE}py    \ncode_a\n{CODE_FENCE}")
        assert output == "code_a"

    def test_multiple_code_blobs(self):
        """Several fenced blobs are returned joined by a blank line."""
        test_input = (
            f"{CODE_FENCE}\nFoo\n{CODE_FENCE}\n\n"
            f"{CODE_FENCE}py\ncode_a\n{CODE_FENCE}\n\n"
            f"{CODE_FENCE}python\ncode_b\n{CODE_FENCE}"
        )
        result = parse_code_blobs(test_input)
        assert result == "Foo\n\ncode_a\n\ncode_b"


@pytest.fixture(scope="function")
def ipython_shell():
    """Reset IPython shell before and after each test."""
    shell = InteractiveShell.instance()
    shell.reset()  # Clean before test
    yield shell
    shell.reset()  # Clean after test


@pytest.mark.parametrize(
    "obj_name, code_blob",
    [
        ("test_func", "def test_func():\n    return 42"),
        ("TestClass", "class TestClass:\n    ..."),
    ],
)
def test_get_source_ipython(ipython_shell, obj_name, code_blob):
    """``get_source`` returns the cell text for objects defined in an IPython cell."""
    ipython_shell.run_cell(code_blob, store_history=True)
    obj = ipython_shell.user_ns[obj_name]
    assert get_source(obj) == code_blob


def test_get_source_standard_class():
    """``get_source`` agrees with ``inspect.getsource`` for a regular class."""
    # Must stay on one line: the assertions below compare against this exact text.
    class TestClass: ...

    source = get_source(TestClass)
    assert source == "class TestClass: ..."
    assert source == textwrap.dedent(inspect.getsource(TestClass)).strip()


def test_get_source_standard_function():
    """``get_source`` agrees with ``inspect.getsource`` for a regular function."""
    # Must stay on one line: the assertions below compare against this exact text.
    def test_func(): ...

    source = get_source(test_func)
    assert source == "def test_func(): ..."
    assert source == textwrap.dedent(inspect.getsource(test_func)).strip()


def test_get_source_ipython_errors_empty_cells(ipython_shell):
    """``get_source`` raises ``ValueError`` when the IPython input history holds no code."""
    test_code = textwrap.dedent("""class TestClass:\n    ...""").strip()
    ipython_shell.user_ns["In"] = [""]
    ipython_shell.run_cell(test_code, store_history=True)
    with pytest.raises(ValueError, match="No code cells found in IPython session"):
        get_source(ipython_shell.user_ns["TestClass"])


def test_get_source_ipython_errors_definition_not_found(ipython_shell):
    """``get_source`` raises ``ValueError`` when no history cell defines the object."""
    test_code = textwrap.dedent("""class TestClass:\n    ...""").strip()
    ipython_shell.user_ns["In"] = ["", "print('No class definition here')"]
    ipython_shell.run_cell(test_code, store_history=True)
    with pytest.raises(ValueError, match="Could not find source code for TestClass in IPython history"):
        get_source(ipython_shell.user_ns["TestClass"])


def test_get_source_ipython_errors_type_error():
    """``get_source`` rejects arguments that are neither a class nor a callable."""
    with pytest.raises(TypeError, match="Expected class or callable"):
        get_source(None)


@pytest.mark.parametrize(
    "tool, expected_tool_source", [(ValidTool(), VALID_TOOL_SOURCE), (valid_tool_function, VALID_TOOL_FUNCTION_SOURCE)]
)
def test_instance_to_source(tool, expected_tool_source):
    """``instance_to_source`` regenerates the expected class source for both tool kinds."""
    tool_source = instance_to_source(tool, base_cls=Tool)
    assert tool_source == expected_tool_source


def test_e2e_class_tool_save(tmp_path):
    """Saving a class-based tool writes ``tool.py``, ``app.py`` and ``requirements.txt``."""

    class TestTool(Tool):
        name = "test_tool"
        description = "Test tool description"
        inputs = {
            "task": {
                "type": "string",
                "description": "tool input",
            }
        }
        output_type = "string"

        def forward(self, task: str):
            import IPython  # noqa: F401

            return task

    test_tool = TestTool()
    test_tool.save(tmp_path, make_gradio_app=True)
    assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
    assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
        """\
        from typing import Any, Optional
        from smolagents.tools import Tool
        import IPython

        class TestTool(Tool):
            name = "test_tool"
            description = "Test tool description"
            inputs = {'task': {'type': 'string', 'description': 'tool input'}}
            output_type = "string"

            def forward(self, task: str):
                import IPython  # noqa: F401

                return task

            def __init__(self, *args, **kwargs):
                self.is_initialized = False
        """
    )
    requirements = set((tmp_path / "requirements.txt").read_text().split())
    assert requirements == {"IPython", "smolagents"}
    assert (tmp_path / "app.py").read_text() == textwrap.dedent(
        """\
        from smolagents import launch_gradio_demo
        from tool import TestTool

        tool = TestTool()
        launch_gradio_demo(tool)
        """
    )


def test_e2e_ipython_class_tool_save(tmp_path):
    """Same as ``test_e2e_class_tool_save`` but with the tool defined and saved in an IPython cell."""
    shell = InteractiveShell.instance()
    # The target directory is spliced in with ``!r`` so that backslashes or quotes in
    # the path (e.g. on Windows) still yield a valid string literal in the cell.
    code_blob = textwrap.dedent(
        f"""\
        from smolagents.tools import Tool

        class TestTool(Tool):
            name = "test_tool"
            description = "Test tool description"
            inputs = {{
                "task": {{
                    "type": "string",
                    "description": "tool input",
                }}
            }}
            output_type = "string"

            def forward(self, task: str):
                import IPython  # noqa: F401

                return task

        TestTool().save({str(tmp_path)!r}, make_gradio_app=True)
        """
    )
    assert shell.run_cell(code_blob, store_history=True).success
    assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
    assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
        """\
        from typing import Any, Optional
        from smolagents.tools import Tool
        import IPython

        class TestTool(Tool):
            name = "test_tool"
            description = "Test tool description"
            inputs = {'task': {'type': 'string', 'description': 'tool input'}}
            output_type = "string"

            def forward(self, task: str):
                import IPython  # noqa: F401

                return task

            def __init__(self, *args, **kwargs):
                self.is_initialized = False
        """
    )
    requirements = set((tmp_path / "requirements.txt").read_text().split())
    assert requirements == {"IPython", "smolagents"}
    assert (tmp_path / "app.py").read_text() == textwrap.dedent(
        """\
        from smolagents import launch_gradio_demo
        from tool import TestTool

        tool = TestTool()
        launch_gradio_demo(tool)
        """
    )


def test_e2e_function_tool_save(tmp_path):
    """Saving a ``@tool`` function writes a ``SimpleTool`` class plus app and requirements files."""

    @tool
    def test_tool(task: str) -> str:
        """
        Test tool description

        Args:
            task: tool input
        """
        import IPython  # noqa: F401

        return task

    test_tool.save(tmp_path, make_gradio_app=True)
    assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
    assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
        """\
        from smolagents import Tool
        from typing import Any, Optional

        class SimpleTool(Tool):
            name = "test_tool"
            description = "Test tool description"
            inputs = {'task': {'type': 'string', 'description': 'tool input'}}
            output_type = "string"

            def forward(self, task: str) -> str:
                \"""
                Test tool description

                Args:
                    task: tool input
                \"""
                import IPython  # noqa: F401

                return task"""
    )
    requirements = set((tmp_path / "requirements.txt").read_text().split())
    assert requirements == {"smolagents"}  # FIXME: IPython should be in the requirements
    assert (tmp_path / "app.py").read_text() == textwrap.dedent(
        """\
        from smolagents import launch_gradio_demo
        from tool import SimpleTool

        tool = SimpleTool()
        launch_gradio_demo(tool)
        """
    )


def test_e2e_ipython_function_tool_save(tmp_path):
    """Same as ``test_e2e_function_tool_save`` but with the tool defined and saved in an IPython cell."""
    shell = InteractiveShell.instance()
    # The target directory is spliced in with ``!r`` so that backslashes or quotes in
    # the path (e.g. on Windows) still yield a valid string literal in the cell.
    code_blob = textwrap.dedent(
        f"""
        from smolagents import tool

        @tool
        def test_tool(task: str) -> str:
            \"""
            Test tool description

            Args:
                task: tool input
            \"""
            import IPython  # noqa: F401

            return task

        test_tool.save({str(tmp_path)!r}, make_gradio_app=True)
        """
    )
    assert shell.run_cell(code_blob, store_history=True).success
    assert set(os.listdir(tmp_path)) == {"requirements.txt", "app.py", "tool.py"}
    assert (tmp_path / "tool.py").read_text() == textwrap.dedent(
        """\
        from smolagents import Tool
        from typing import Any, Optional

        class SimpleTool(Tool):
            name = "test_tool"
            description = "Test tool description"
            inputs = {'task': {'type': 'string', 'description': 'tool input'}}
            output_type = "string"

            def forward(self, task: str) -> str:
                \"""
                Test tool description

                Args:
                    task: tool input
                \"""
                import IPython  # noqa: F401

                return task"""
    )
    requirements = set((tmp_path / "requirements.txt").read_text().split())
    assert requirements == {"smolagents"}  # FIXME: IPython should be in the requirements
    assert (tmp_path / "app.py").read_text() == textwrap.dedent(
        """\
        from smolagents import launch_gradio_demo
        from tool import SimpleTool

        tool = SimpleTool()
        launch_gradio_demo(tool)
        """
    )


@pytest.mark.parametrize(
    "raw_json, expected_data, expected_blob",
    [
        (
            """{}""",
            {},
            "",
        ),
        (
            """Text{}""",
            {},
            "Text",
        ),
        (
            """{"simple": "json"}""",
            {"simple": "json"},
            "",
        ),
        (
            """With text here{"simple": "json"}""",
            {"simple": "json"},
            "With text here",
        ),
        (
            """{"simple": "json"}With text after""",
            {"simple": "json"},
            "",
        ),
        (
            """With text before{"simple": "json"}And text after""",
            {"simple": "json"},
            "With text before",
        ),
    ],
)
def test_parse_json_blob_with_valid_json(raw_json, expected_data, expected_blob):
    """``parse_json_blob`` returns the decoded JSON and the text preceding it."""
    data, blob = parse_json_blob(raw_json)

    assert data == expected_data
    assert blob == expected_blob


@pytest.mark.parametrize(
    "raw_json",
    [
        """simple": "json"}""",
        """With text here"simple": "json"}""",
        """{"simple": ""json"}With text after""",
        """{"simple": "json"With text after""",
        "}}",
    ],
)
def test_parse_json_blob_with_invalid_json(raw_json):
    """``parse_json_blob`` raises on input that holds no well-formed JSON object."""
    # NOTE(review): ``Exception`` is deliberately left broad here; narrow it once the
    # exception type raised by parse_json_blob has been confirmed.
    with pytest.raises(Exception):
        parse_json_blob(raw_json)


@pytest.mark.parametrize(
    "name,expected",
    [
        # Valid identifiers
        ("valid_name", True),
        ("ValidName", True),
        ("valid123", True),
        ("_private", True),
        # Invalid identifiers
        ("", False),
        ("123invalid", False),
        ("invalid-name", False),
        ("invalid name", False),
        ("invalid.name", False),
        # Python keywords
        ("if", False),
        ("for", False),
        ("class", False),
        ("return", False),
        # Non-string inputs
        (123, False),
        (None, False),
        ([], False),
        ({}, False),
    ],
)
def test_is_valid_name(name, expected):
    """Test the is_valid_name function with various inputs."""
    assert is_valid_name(name) is expected