# Source snapshot: commit 9c31777 (5,530 bytes).
import ast
import re
from textwrap import dedent

import pytest

from smolagents.default_tools import (
    DuckDuckGoSearchTool,
    GoogleSearchTool,
    SpeechToTextTool,
    VisitWebpageTool,
    WebSearchTool,
)
from smolagents.tool_validation import MethodChecker, validate_tool_attributes
from smolagents.tools import Tool, tool
# Module-level name reused by two invalid fixtures: as a non-literal __init__
# default (InvalidToolNonLiteralDefaultParam) and as a name a tool method
# references without defining it (InvalidToolUndefinedNames).
UNDEFINED_VARIABLE: str = "undefined_variable"
@pytest.mark.parametrize(
    "tool_class",
    [
        DuckDuckGoSearchTool,
        GoogleSearchTool,
        SpeechToTextTool,
        VisitWebpageTool,
        WebSearchTool,
    ],
)
def test_validate_tool_attributes_with_default_tools(tool_class):
    """Each tool imported from ``smolagents.default_tools`` must pass attribute validation."""
    validation_result = validate_tool_attributes(tool_class)
    assert validation_result is None, f"failed for {tool_class.name} tool"
# Positive fixture: a hand-written Tool subclass expected to pass
# validate_tool_attributes (see test_validate_tool_attributes_valid).
# NOTE(review): the validator appears to analyze this class's source/AST (compare the
# expected messages in test_validate_tool_attributes_exceptions), so the statements
# below are the test input; only '#' comments are added because ast discards them.
class ValidTool(Tool):
    name = "valid_tool"
    description = "A valid tool"
    inputs = {"input": {"type": "string", "description": "input"}}
    output_type = "string"
    # Extra class attributes that must be accepted: a plain string and a literal dict.
    simple_attr = "string"
    dict_attr = {"key": "value"}

    # Every __init__ parameter has a literal default, the shape the validator accepts
    # (contrast InvalidToolRequiredParams and InvalidToolNonLiteralDefaultParam).
    def __init__(self, optional_param="default"):
        super().__init__()
        self.param = optional_param

    def forward(self, input: str) -> str:
        return input.upper()
# Positive fixture built with the @tool decorator; the class it produces
# (valid_tool_function.__class__) is validated in test_validate_tool_attributes_valid.
# NOTE(review): the docstring is kept verbatim because @tool presumably reads its
# summary and "Args:" section to describe the tool — TODO confirm before editing it.
@tool
def valid_tool_function(input: str) -> str:
    """A valid tool function.
    Args:
        input (str): Input string.
    """
    return input.upper()
@pytest.mark.parametrize(
    "tool_class",
    [ValidTool, valid_tool_function.__class__],
)
def test_validate_tool_attributes_valid(tool_class):
    """A hand-written Tool subclass and the class behind an ``@tool`` function both validate cleanly."""
    outcome = validate_tool_attributes(tool_class)
    assert outcome is None
# Invalid fixture: 'name' contains spaces, so it is not a valid Python identifier.
# test_validate_tool_attributes_exceptions expects the "Class attribute 'name' must be
# a valid Python identifier ..." error. Only '#' comments are added so the parsed
# source seen by the validator is unchanged.
class InvalidToolName(Tool):
    name = "invalid tool name"
    description = "Tool with invalid name"
    inputs = {"input": {"type": "string", "description": "input"}}
    output_type = "string"

    def __init__(self):
        super().__init__()

    def forward(self, input: str) -> str:
        return input
# Invalid fixture: a class attribute computed by a list comprehension instead of a
# literal. test_validate_tool_attributes_exceptions expects "Complex attributes should
# be defined in __init__, not as class attributes".
class InvalidToolComplexAttrs(Tool):
    name = "invalid_tool"
    description = "Tool with complex class attributes"
    inputs = {"input": {"type": "string", "description": "input"}}
    output_type = "string"
    complex_attr = [x for x in range(3)]  # Complex (non-literal) class attribute

    def __init__(self):
        super().__init__()

    def forward(self, input: str) -> str:
        return input
# Invalid fixture: __init__ takes a parameter without a default value.
# test_validate_tool_attributes_exceptions expects "Parameters in __init__ must have
# default values, found required parameters".
class InvalidToolRequiredParams(Tool):
    name = "invalid_tool"
    description = "Tool with required params"
    inputs = {"input": {"type": "string", "description": "input"}}
    output_type = "string"

    def __init__(self, required_param, kwarg1=1):  # required_param has no default (kwarg1 does)
        super().__init__()
        self.param = required_param

    def forward(self, input: str) -> str:
        return input
# Invalid fixture: the __init__ default is a reference to a module-level name rather
# than a literal. test_validate_tool_attributes_exceptions expects "Parameters in
# __init__ must have literal default values, found non-literal defaults".
class InvalidToolNonLiteralDefaultParam(Tool):
    name = "invalid_tool"
    description = "Tool with non-literal default parameter value"
    inputs = {"input": {"type": "string", "description": "input"}}
    output_type = "string"

    def __init__(self, default_param=UNDEFINED_VARIABLE):  # a name reference, not a literal, as default
        super().__init__()
        self.default_param = default_param

    def forward(self, input: str) -> str:
        return input
# Invalid fixture: forward() returns a module-level name that the tool itself never
# defines. test_validate_tool_attributes_exceptions expects
# "Name 'UNDEFINED_VARIABLE' is undefined", i.e. the validator is expected to treat
# module globals as unavailable to a self-contained tool method.
class InvalidToolUndefinedNames(Tool):
    name = "invalid_tool"
    description = "Tool with undefined names"
    inputs = {"input": {"type": "string", "description": "input"}}
    output_type = "string"

    def forward(self, input: str) -> str:
        return UNDEFINED_VARIABLE  # Undefined within the tool's own scope
@pytest.mark.parametrize(
    "tool_class, expected_error",
    [
        (
            InvalidToolName,
            "Class attribute 'name' must be a valid Python identifier and not a reserved keyword, found 'invalid tool name'",
        ),
        (InvalidToolComplexAttrs, "Complex attributes should be defined in __init__, not as class attributes"),
        (InvalidToolRequiredParams, "Parameters in __init__ must have default values, found required parameters"),
        (
            InvalidToolNonLiteralDefaultParam,
            "Parameters in __init__ must have literal default values, found non-literal defaults",
        ),
        (InvalidToolUndefinedNames, "Name 'UNDEFINED_VARIABLE' is undefined"),
    ],
)
def test_validate_tool_attributes_exceptions(tool_class, expected_error):
    """Each invalid fixture must make ``validate_tool_attributes`` raise ``ValueError``.

    Args:
        tool_class: An invalid Tool subclass fixture.
        expected_error: A literal substring expected in the error message.

    ``pytest.raises(match=...)`` runs ``re.search`` with its argument as a pattern, so
    the literal text is escaped. Otherwise regex metacharacters such as ``(``, ``.`` or
    ``[`` in a future message would be interpreted as a pattern rather than matched verbatim.
    """
    with pytest.raises(ValueError, match=re.escape(expected_error)):
        validate_tool_attributes(tool_class)
# Regression fixture: forward() binds two locals via tuple unpacking (a, b = ...).
# Both validate_tool_attributes and to_dict() must accept it (see
# test_validate_tool_attributes_multiple_assignments and
# test_tool_to_dict_validation_with_multiple_assignments), so the body must keep the
# unpacking form exactly as written.
class MultipleAssignmentsTool(Tool):
    name = "multiple_assignments_tool"
    description = "Tool with multiple assignments"
    inputs = {"input": {"type": "string", "description": "input"}}
    output_type = "string"

    def __init__(self):
        super().__init__()

    def forward(self, input: str) -> str:
        a, b = "1", "2"
        return a + b
def test_validate_tool_attributes_multiple_assignments():
    """Tuple-unpacking assignments inside a Tool method must pass validation.

    Asserts the ``None`` return explicitly, consistent with the other
    ``validate_tool_attributes`` tests, rather than only relying on the absence
    of an exception.
    """
    assert validate_tool_attributes(MultipleAssignmentsTool) is None
# @tool-decorated counterpart of MultipleAssignmentsTool: the body unpacks a tuple
# (a, b = ...), which test_tool_to_dict_validation_with_multiple_assignments
# requires to_dict() to accept.
# NOTE(review): the docstring is kept verbatim; @tool presumably parses its summary
# and "Args:" section to build the tool description/inputs — confirm before editing.
@tool
def tool_function_with_multiple_assignments(input: str) -> str:
    """A valid tool function.
    Args:
        input (str): Input string.
    """
    a, b = "1", "2"
    return input.upper() + a + b
@pytest.mark.parametrize(
    "tool_instance",
    [
        MultipleAssignmentsTool(),
        tool_function_with_multiple_assignments,
    ],
)
def test_tool_to_dict_validation_with_multiple_assignments(tool_instance):
    """Serializing tools whose code uses tuple unpacking must not raise.

    Covers both a class-based tool instance and an ``@tool``-decorated function.
    """
    # The payload is not inspected: the test passes as long as to_dict() completes.
    tool_instance.to_dict()
class TestMethodChecker:
    """Unit tests that drive ``MethodChecker`` directly on parsed source snippets."""

    def test_multiple_assignments(self):
        """A tuple-unpacking assignment must leave the checker with no recorded errors."""
        snippet = """
        def forward(self) -> str:
            a, b = "1", "2"
            return a + b
        """
        module_tree = ast.parse(dedent(snippet))

        checker = MethodChecker(set())
        checker.visit(module_tree)

        assert checker.errors == []
|