|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import shutil |
|
import tempfile |
|
import unittest |
|
from unittest.mock import Mock, patch |
|
|
|
import pytest |
|
|
|
from smolagents.agent_types import AgentAudio, AgentImage, AgentText |
|
from smolagents.gradio_ui import GradioUI, pull_messages_from_step, stream_to_gradio |
|
from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, ToolCall |
|
from smolagents.models import ChatMessageStreamDelta |
|
from smolagents.monitoring import Timing, TokenUsage |
|
|
|
|
|
class GradioUITester(unittest.TestCase):
    """Tests for ``GradioUI.upload_file`` run against a throwaway upload folder."""

    def setUp(self):
        """Initialize test environment"""
        self.temp_dir = tempfile.mkdtemp()
        self.mock_agent = Mock()
        self.ui = GradioUI(agent=self.mock_agent, file_upload_folder=self.temp_dir)
        # Extensions the tests below expect to be accepted when no explicit
        # allow-list is passed to upload_file.
        self.allowed_types = [".pdf", ".docx", ".txt"]

    def tearDown(self):
        """Clean up test environment"""
        shutil.rmtree(self.temp_dir)

    def test_upload_file_default_types(self):
        """Test default allowed file types"""
        for file_type in self.allowed_types:
            # subTest labels a failure with the extension that triggered it.
            with self.subTest(file_type=file_type):
                with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
                    mock_file = Mock()
                    mock_file.name = temp_file.name

                    textbox, uploads_log = self.ui.upload_file(mock_file, [])

                    self.assertIn("File uploaded:", textbox.value)
                    self.assertEqual(len(uploads_log), 1)
                    self.assertTrue(
                        os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name)))
                    )

    def test_upload_file_default_types_disallowed(self):
        """Test default disallowed file types"""
        disallowed_types = [".exe", ".sh", ".py", ".jpg"]
        for file_type in disallowed_types:
            # subTest labels a failure with the extension that triggered it.
            with self.subTest(file_type=file_type):
                with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
                    mock_file = Mock()
                    mock_file.name = temp_file.name

                    textbox, uploads_log = self.ui.upload_file(mock_file, [])

                    self.assertEqual(textbox.value, "File type disallowed")
                    self.assertEqual(len(uploads_log), 0)

    def test_upload_file_success(self):
        """Test successful file upload scenario"""
        with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
            mock_file = Mock()
            mock_file.name = temp_file.name

            textbox, uploads_log = self.ui.upload_file(mock_file, [])

            # The copy is expected inside the upload folder under the same base name.
            expected_path = os.path.join(self.temp_dir, os.path.basename(temp_file.name))
            self.assertIn("File uploaded:", textbox.value)
            self.assertEqual(len(uploads_log), 1)
            self.assertTrue(os.path.exists(expected_path))
            self.assertEqual(uploads_log[0], expected_path)

    def test_upload_file_none(self):
        """Test scenario when no file is selected"""
        textbox, uploads_log = self.ui.upload_file(None, [])

        self.assertEqual(textbox.value, "No file uploaded")
        self.assertEqual(len(uploads_log), 0)

    def test_upload_file_invalid_type(self):
        """Test disallowed file type"""
        with tempfile.NamedTemporaryFile(suffix=".exe") as temp_file:
            mock_file = Mock()
            mock_file.name = temp_file.name

            textbox, uploads_log = self.ui.upload_file(mock_file, [])

            self.assertEqual(textbox.value, "File type disallowed")
            self.assertEqual(len(uploads_log), 0)

    def test_upload_file_special_chars(self):
        """Test scenario with special characters in filename"""
        # A private directory keeps the oddly named source file out of the
        # shared system temp dir and is removed automatically on exit.
        with tempfile.TemporaryDirectory() as source_dir:
            special_char_name = os.path.join(source_dir, "test@#$%^&*.txt")
            with open(special_char_name, "wb"):
                pass  # an empty file is enough

            mock_file = Mock()
            mock_file.name = special_char_name

            # shutil.copy is stubbed out, so the assertions below only inspect
            # the values returned by upload_file.
            with patch("shutil.copy"):
                textbox, uploads_log = self.ui.upload_file(mock_file, [])

            self.assertIn("File uploaded:", textbox.value)
            self.assertEqual(len(uploads_log), 1)
            self.assertIn("test_____", uploads_log[0])

    def test_upload_file_custom_types(self):
        """Test custom allowed file types"""
        with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file:
            mock_file = Mock()
            mock_file.name = temp_file.name

            textbox, uploads_log = self.ui.upload_file(mock_file, [], allowed_file_types=[".csv"])

            self.assertIn("File uploaded:", textbox.value)
            self.assertEqual(len(uploads_log), 1)
|
|
|
|
|
class TestStreamToGradio:
    """Checks on what stream_to_gradio yields and how it invokes ``agent.run``."""

    @patch("smolagents.gradio_ui.pull_messages_from_step")
    def test_stream_to_gradio_memory_step(self, mock_pull_messages):
        """An ActionStep event is forwarded through pull_messages_from_step."""
        rendered = Mock()
        mock_pull_messages.return_value = [rendered]
        agent = Mock(
            run=Mock(return_value=[Mock(spec=ActionStep)]),
            model=Mock(last_input_token_count=100, last_output_token_count=200),
        )

        streamed = [*stream_to_gradio(agent, "test task")]

        mock_pull_messages.assert_called_once()
        assert streamed == [rendered]

    def test_stream_to_gradio_stream_delta(self):
        """A single ChatMessageStreamDelta is yielded as its text content."""
        delta = ChatMessageStreamDelta(content="Hello")
        agent = Mock(
            run=Mock(return_value=[delta]),
            model=Mock(last_input_token_count=100, last_output_token_count=200),
        )

        streamed = [*stream_to_gradio(agent, "test task")]

        assert streamed == ["Hello"]

    def test_stream_to_gradio_multiple_deltas(self):
        """Successive deltas are yielded as the text accumulated so far."""
        deltas = [
            ChatMessageStreamDelta(content="Hello"),
            ChatMessageStreamDelta(content=" world"),
        ]
        agent = Mock(
            run=Mock(return_value=deltas),
            model=Mock(last_input_token_count=100, last_output_token_count=200),
        )

        streamed = [*stream_to_gradio(agent, "test task")]

        assert streamed == ["Hello", "Hello world"]

    @pytest.mark.parametrize(
        "task,task_images,reset_memory,additional_args",
        [
            ("simple task", None, False, None),
            ("task with images", ["image1.png", "image2.png"], False, None),
            ("task with reset", None, True, None),
            ("task with args", None, False, {"arg1": "value1"}),
            ("complex task", ["image.png"], True, {"arg1": "value1", "arg2": "value2"}),
        ],
    )
    def test_stream_to_gradio_parameters(self, task, task_images, reset_memory, additional_args):
        """The arguments given to stream_to_gradio reach ``agent.run`` unchanged."""
        agent = Mock()
        agent.run = Mock(return_value=[])

        stream = stream_to_gradio(
            agent,
            task=task,
            task_images=task_images,
            reset_agent_memory=reset_memory,
            additional_args=additional_args,
        )
        # Drain the generator so that agent.run is actually invoked.
        for _ in stream:
            pass

        agent.run.assert_called_once_with(
            task, images=task_images, stream=True, reset=reset_memory, additional_args=additional_args
        )
|
|
|
|
|
class TestPullMessagesFromStep:
    """Tests for the messages pull_messages_from_step yields for each step type."""

    def test_action_step_basic(self):
        """Test basic ActionStep processing."""
        step = ActionStep(
            step_number=1,
            model_output="This is the model output",
            observations="Some execution logs",
            error=None,
            timing=Timing(start_time=1.0, end_time=3.5),
            token_usage=TokenUsage(input_tokens=100, output_tokens=50),
        )
        messages = list(pull_messages_from_step(step))
        assert len(messages) == 5
        # Expected substrings, listed in the order the messages are yielded.
        for message, expected_content in zip(
            messages,
            [
                "**Step 1**",
                "This is the model output",
                "execution logs",
                "Input tokens: 100 | Output tokens: 50 | Duration: 2.5",
                "-----",
            ],
        ):
            assert expected_content in message.content

    def test_action_step_with_tool_calls(self):
        """Test ActionStep with tool calls."""
        step = ActionStep(
            step_number=2,
            tool_calls=[ToolCall(name="test_tool", arguments={"answer": "Test answer"}, id="tool_call_1")],
            observations="Tool execution logs",
            timing=Timing(start_time=1.0, end_time=2.5),
            token_usage=TokenUsage(input_tokens=100, output_tokens=50),
        )
        messages = list(pull_messages_from_step(step))
        assert len(messages) == 5
        assert messages[1].content == "Test answer"
        assert "Used tool test_tool" in messages[1].metadata["title"]

    @pytest.mark.parametrize(
        "tool_name, args, expected",
        [
            ("python_interpreter", "print('Hello')", "```python\nprint('Hello')\n```"),
            ("regular_tool", {"key": "value"}, "{'key': 'value'}"),
            ("string_args_tool", "simple string", "simple string"),
        ],
    )
    def test_action_step_tool_call_formats(self, tool_name, args, expected):
        """Test different formats of tool calls."""
        tool_call = Mock()
        tool_call.name = tool_name
        tool_call.arguments = args
        step = ActionStep(
            step_number=1,
            tool_calls=[tool_call],
            timing=Timing(start_time=1.0, end_time=2.5),
            token_usage=TokenUsage(input_tokens=100, output_tokens=50),
        )
        messages = list(pull_messages_from_step(step))
        # Pick the tool-call message: an assistant message whose metadata title
        # starts with the tool emoji.
        tool_message = next(
            msg
            for msg in messages
            if msg.role == "assistant" and msg.metadata and msg.metadata.get("title", "").startswith("🛠️")
        )
        assert expected in tool_message.content

    def test_action_step_with_error(self):
        """Test ActionStep with error."""
        step = ActionStep(
            step_number=3,
            error="This is an error message",
            timing=Timing(start_time=1.0, end_time=2.0),
            token_usage=TokenUsage(input_tokens=100, output_tokens=200),
        )
        messages = list(pull_messages_from_step(step))
        # First message whose content mentions "error" (case-insensitive), if any.
        error_message = next((m for m in messages if "error" in str(m.content).lower()), None)
        assert error_message is not None
        assert "This is an error message" in error_message.content

    def test_action_step_with_images(self):
        """Test ActionStep with observation images."""
        step = ActionStep(
            step_number=4,
            observations_images=["image1.png", "image2.jpg"],
            token_usage=TokenUsage(input_tokens=100, output_tokens=200),
            timing=Timing(start_time=1.0, end_time=2.0),
        )
        # AgentImage is replaced inside gradio_ui so every image resolves to a fixed path.
        with patch("smolagents.gradio_ui.AgentImage") as mock_agent_image:
            mock_agent_image.return_value.to_string.side_effect = lambda: "path/to/image.png"
            messages = list(pull_messages_from_step(step))
            image_messages = [m for m in messages if "image" in str(m).lower()]
            assert len(image_messages) == 2
            assert "path/to/image.png" in str(image_messages[0])

    @pytest.mark.parametrize(
        "skip_model_outputs, expected_messages_length, token_usage",
        [(False, 4, TokenUsage(input_tokens=80, output_tokens=30)), (True, 2, None)],
    )
    def test_planning_step(self, skip_model_outputs, expected_messages_length, token_usage):
        """Test PlanningStep processing."""
        step = PlanningStep(
            plan="1. First step\n2. Second step",
            model_input_messages=Mock(),
            model_output_message=Mock(),
            token_usage=token_usage,
            timing=Timing(start_time=1.0, end_time=2.0),
        )
        messages = list(pull_messages_from_step(step, skip_model_outputs=skip_model_outputs))
        assert len(messages) == expected_messages_length
        expected_contents = [
            "**Planning step**",
            "1. First step\n2. Second step",
            # The empty string is contained in any content, so this entry is a
            # no-op when there is no token usage.
            "Input tokens: 80 | Output tokens: 30" if token_usage else "",
            "-----",
        ]
        # Compare against the trailing expectations only: in the two-message
        # case just the last two entries apply.
        for message, expected_content in zip(messages, expected_contents[-expected_messages_length:]):
            assert expected_content in message.content

        if not token_usage:
            # Check every message explicitly instead of relying on whatever the
            # loop variable above happened to be bound to last.
            for message in messages:
                assert "Input tokens: 80 | Output tokens: 30" not in message.content

    @pytest.mark.parametrize(
        "answer_type, answer_value, expected_content",
        [
            (AgentText, "This is a text answer", "**Final answer:**\nThis is a text answer\n"),
            (lambda: "Plain string", "Plain string", "**Final answer:** Plain string"),
        ],
    )
    def test_final_answer_step(self, answer_type, answer_value, expected_content):
        """Test FinalAnswerStep with different answer types."""
        # answer_type is first called with no arguments (the plain-string case
        # is a zero-argument lambda). If that raises TypeError, it is built from
        # answer_value with to_string patched during construction.
        try:
            final_answer = answer_type()
        except TypeError:
            with patch.object(answer_type, "to_string", return_value=answer_value):
                final_answer = answer_type(answer_value)
        step = FinalAnswerStep(
            output=final_answer,
        )
        messages = list(pull_messages_from_step(step))
        assert len(messages) == 1
        assert messages[0].content == expected_content

    def test_final_answer_step_image(self):
        """Test FinalAnswerStep with image answer."""
        # to_string is patched for the whole check so the path is returned as given.
        with patch.object(AgentImage, "to_string", return_value="path/to/image.png"):
            step = FinalAnswerStep(output=AgentImage("path/to/image.png"))
            messages = list(pull_messages_from_step(step))
            assert len(messages) == 1
            assert messages[0].content["path"] == "path/to/image.png"
            assert messages[0].content["mime_type"] == "image/png"

    def test_final_answer_step_audio(self):
        """Test FinalAnswerStep with audio answer."""
        # to_string is patched for the whole check so the path is returned as given.
        with patch.object(AgentAudio, "to_string", return_value="path/to/audio.wav"):
            step = FinalAnswerStep(output=AgentAudio("path/to/audio.wav"))
            messages = list(pull_messages_from_step(step))
            assert len(messages) == 1
            assert messages[0].content["path"] == "path/to/audio.wav"
            assert messages[0].content["mime_type"] == "audio/wav"

    def test_unsupported_step_type(self):
        """Test handling of unsupported step types."""

        class UnsupportedStep(Mock):
            pass

        step = UnsupportedStep()
        # The generator must be consumed for the error to surface.
        with pytest.raises(ValueError, match="Unsupported step type"):
            list(pull_messages_from_step(step))
|
|