# Final_Assignment_Template3/tests/test_gradio_ui.py
# Uploaded by Duibonduil ("Upload 21 files"), commit 9c31777 (verified).
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import unittest
from unittest.mock import Mock, patch
import pytest
from smolagents.agent_types import AgentAudio, AgentImage, AgentText
from smolagents.gradio_ui import GradioUI, pull_messages_from_step, stream_to_gradio
from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, ToolCall
from smolagents.models import ChatMessageStreamDelta
from smolagents.monitoring import Timing, TokenUsage
class GradioUITester(unittest.TestCase):
def setUp(self):
"""Initialize test environment"""
self.temp_dir = tempfile.mkdtemp()
self.mock_agent = Mock()
self.ui = GradioUI(agent=self.mock_agent, file_upload_folder=self.temp_dir)
self.allowed_types = [".pdf", ".docx", ".txt"]
def tearDown(self):
"""Clean up test environment"""
shutil.rmtree(self.temp_dir)
def test_upload_file_default_types(self):
"""Test default allowed file types"""
default_types = [".pdf", ".docx", ".txt"]
for file_type in default_types:
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
def test_upload_file_default_types_disallowed(self):
"""Test default disallowed file types"""
disallowed_types = [".exe", ".sh", ".py", ".jpg"]
for file_type in disallowed_types:
with tempfile.NamedTemporaryFile(suffix=file_type) as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertEqual(textbox.value, "File type disallowed")
self.assertEqual(len(uploads_log), 0)
def test_upload_file_success(self):
"""Test successful file upload scenario"""
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertTrue(os.path.exists(os.path.join(self.temp_dir, os.path.basename(temp_file.name))))
self.assertEqual(uploads_log[0], os.path.join(self.temp_dir, os.path.basename(temp_file.name)))
def test_upload_file_none(self):
"""Test scenario when no file is selected"""
textbox, uploads_log = self.ui.upload_file(None, [])
self.assertEqual(textbox.value, "No file uploaded")
self.assertEqual(len(uploads_log), 0)
def test_upload_file_invalid_type(self):
"""Test disallowed file type"""
with tempfile.NamedTemporaryFile(suffix=".exe") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertEqual(textbox.value, "File type disallowed")
self.assertEqual(len(uploads_log), 0)
def test_upload_file_special_chars(self):
"""Test scenario with special characters in filename"""
with tempfile.NamedTemporaryFile(suffix=".txt") as temp_file:
# Create a new temporary file with special characters
special_char_name = os.path.join(os.path.dirname(temp_file.name), "test@#$%^&*.txt")
shutil.copy(temp_file.name, special_char_name)
try:
mock_file = Mock()
mock_file.name = special_char_name
with patch("shutil.copy"):
textbox, uploads_log = self.ui.upload_file(mock_file, [])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
self.assertIn("test_____", uploads_log[0])
finally:
# Clean up the special character file
if os.path.exists(special_char_name):
os.remove(special_char_name)
def test_upload_file_custom_types(self):
"""Test custom allowed file types"""
with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file:
mock_file = Mock()
mock_file.name = temp_file.name
textbox, uploads_log = self.ui.upload_file(mock_file, [], allowed_file_types=[".csv"])
self.assertIn("File uploaded:", textbox.value)
self.assertEqual(len(uploads_log), 1)
class TestStreamToGradio:
    """Tests for the stream_to_gradio function."""

    @staticmethod
    def _agent_returning(events):
        """Build a mock agent whose ``run`` returns *events* and whose model reports fixed token counts."""
        agent = Mock()
        agent.run = Mock(return_value=events)
        agent.model = Mock()
        agent.model.last_input_token_count = 100
        agent.model.last_output_token_count = 200
        return agent

    @patch("smolagents.gradio_ui.pull_messages_from_step")
    def test_stream_to_gradio_memory_step(self, mock_pull_messages):
        """A memory step is expanded via pull_messages_from_step and its messages are re-yielded."""
        sentinel_message = Mock()
        mock_pull_messages.return_value = [sentinel_message]
        agent = self._agent_returning([Mock(spec=ActionStep)])

        streamed = list(stream_to_gradio(agent, "test task"))

        mock_pull_messages.assert_called_once()
        assert streamed == [sentinel_message]

    def test_stream_to_gradio_stream_delta(self):
        """A single stream delta is yielded as its text content."""
        agent = self._agent_returning([ChatMessageStreamDelta(content="Hello")])

        assert list(stream_to_gradio(agent, "test task")) == ["Hello"]

    def test_stream_to_gradio_multiple_deltas(self):
        """Successive deltas are yielded as a growing, accumulated string."""
        deltas = [ChatMessageStreamDelta(content=fragment) for fragment in ("Hello", " world")]
        agent = self._agent_returning(deltas)

        assert list(stream_to_gradio(agent, "test task")) == ["Hello", "Hello world"]

    @pytest.mark.parametrize(
        "task,task_images,reset_memory,additional_args",
        [
            ("simple task", None, False, None),
            ("task with images", ["image1.png", "image2.png"], False, None),
            ("task with reset", None, True, None),
            ("task with args", None, False, {"arg1": "value1"}),
            ("complex task", ["image.png"], True, {"arg1": "value1", "arg2": "value2"}),
        ],
    )
    def test_stream_to_gradio_parameters(self, task, task_images, reset_memory, additional_args):
        """Arguments given to stream_to_gradio reach agent.run, which is asked to stream."""
        agent = Mock()
        agent.run = Mock(return_value=[])
        forwarded = {
            "task": task,
            "task_images": task_images,
            "reset_agent_memory": reset_memory,
            "additional_args": additional_args,
        }

        # Consume the stream fully so every step of stream_to_gradio runs before asserting.
        for _ in stream_to_gradio(agent, **forwarded):
            pass

        agent.run.assert_called_once_with(
            task, images=task_images, stream=True, reset=reset_memory, additional_args=additional_args
        )
class TestPullMessagesFromStep:
    """Tests for pull_messages_from_step, which converts memory steps into chat messages."""

    def test_action_step_basic(self):
        """Test basic ActionStep processing."""
        step = ActionStep(
            step_number=1,
            model_output="This is the model output",
            observations="Some execution logs",
            error=None,
            timing=Timing(start_time=1.0, end_time=3.5),
            token_usage=TokenUsage(input_tokens=100, output_tokens=50),
        )
        messages = list(pull_messages_from_step(step))
        assert len(messages) == 5  # step number, model_output, logs, footnote, divider
        for message, expected_content in zip(
            messages,
            [
                "**Step 1**",
                "This is the model output",
                "execution logs",
                "Input tokens: 100 | Output tokens: 50 | Duration: 2.5",
                "-----",
            ],
        ):
            assert expected_content in message.content

    def test_action_step_with_tool_calls(self):
        """Test ActionStep with tool calls."""
        step = ActionStep(
            step_number=2,
            tool_calls=[ToolCall(name="test_tool", arguments={"answer": "Test answer"}, id="tool_call_1")],
            observations="Tool execution logs",
            timing=Timing(start_time=1.0, end_time=2.5),
            token_usage=TokenUsage(input_tokens=100, output_tokens=50),
        )
        messages = list(pull_messages_from_step(step))
        assert len(messages) == 5  # step, tool call, logs, footnote, divider
        assert messages[1].content == "Test answer"
        assert "Used tool test_tool" in messages[1].metadata["title"]

    @pytest.mark.parametrize(
        "tool_name, args, expected",
        [
            ("python_interpreter", "print('Hello')", "```python\nprint('Hello')\n```"),
            ("regular_tool", {"key": "value"}, "{'key': 'value'}"),
            ("string_args_tool", "simple string", "simple string"),
        ],
    )
    def test_action_step_tool_call_formats(self, tool_name, args, expected):
        """Test different formats of tool calls."""
        tool_call = Mock()
        tool_call.name = tool_name
        tool_call.arguments = args
        step = ActionStep(
            step_number=1,
            tool_calls=[tool_call],
            timing=Timing(start_time=1.0, end_time=2.5),
            token_usage=TokenUsage(input_tokens=100, output_tokens=50),
        )
        messages = list(pull_messages_from_step(step))
        # The tool-call message is the assistant message whose metadata title starts with the tool emoji.
        tool_message = next(
            msg
            for msg in messages
            if msg.role == "assistant" and msg.metadata and msg.metadata.get("title", "").startswith("🛠️")
        )
        assert expected in tool_message.content

    def test_action_step_with_error(self):
        """Test ActionStep with error."""
        step = ActionStep(
            step_number=3,
            error="This is an error message",
            timing=Timing(start_time=1.0, end_time=2.0),
            token_usage=TokenUsage(input_tokens=100, output_tokens=200),
        )
        messages = list(pull_messages_from_step(step))
        error_message = next((m for m in messages if "error" in str(m.content).lower()), None)
        assert error_message is not None
        assert "This is an error message" in error_message.content

    def test_action_step_with_images(self):
        """Test ActionStep with observation images."""
        step = ActionStep(
            step_number=4,
            observations_images=["image1.png", "image2.jpg"],
            token_usage=TokenUsage(input_tokens=100, output_tokens=200),
            timing=Timing(start_time=1.0, end_time=2.0),
        )
        with patch("smolagents.gradio_ui.AgentImage") as mock_agent_image:
            # Every wrapped image reports the same path, regardless of its source file.
            mock_agent_image.return_value.to_string.return_value = "path/to/image.png"
            messages = list(pull_messages_from_step(step))
            image_messages = [m for m in messages if "image" in str(m).lower()]
            assert len(image_messages) == 2
            assert "path/to/image.png" in str(image_messages[0])

    @pytest.mark.parametrize(
        "skip_model_outputs, expected_messages_length, token_usage",
        [(False, 4, TokenUsage(input_tokens=80, output_tokens=30)), (True, 2, None)],
    )
    def test_planning_step(self, skip_model_outputs, expected_messages_length, token_usage):
        """Test PlanningStep processing."""
        step = PlanningStep(
            plan="1. First step\n2. Second step",
            model_input_messages=Mock(),
            model_output_message=Mock(),
            token_usage=token_usage,
            timing=Timing(start_time=1.0, end_time=2.0),
        )
        messages = list(pull_messages_from_step(step, skip_model_outputs=skip_model_outputs))
        assert len(messages) == expected_messages_length  # [header, plan,] footnote, divider
        expected_contents = [
            "**Planning step**",
            "1. First step\n2. Second step",
            "Input tokens: 80 | Output tokens: 30" if token_usage else "",
            "-----",
        ]
        # When outputs are skipped only the trailing footnote and divider remain,
        # so align the expectations from the end of the list.
        for message, expected_content in zip(messages, expected_contents[-expected_messages_length:]):
            assert expected_content in message.content
            if not token_usage:
                assert "Input tokens: 80 | Output tokens: 30" not in message.content

    @pytest.mark.parametrize(
        "answer_type, answer_value, expected_content",
        [
            (AgentText, "This is a text answer", "**Final answer:**\nThis is a text answer\n"),
            (str, "Plain string", "**Final answer:** Plain string"),
        ],
    )
    def test_final_answer_step(self, answer_type, answer_value, expected_content):
        """Test FinalAnswerStep with different answer types."""
        # Build the answer straight from its raw value. The real ``to_string`` is
        # what pull_messages_from_step should render, so nothing is patched here.
        final_answer = answer_type(answer_value)
        step = FinalAnswerStep(output=final_answer)
        messages = list(pull_messages_from_step(step))
        assert len(messages) == 1
        assert messages[0].content == expected_content

    def test_final_answer_step_image(self):
        """Test FinalAnswerStep with image answer."""
        with patch.object(AgentImage, "to_string", return_value="path/to/image.png"):
            step = FinalAnswerStep(output=AgentImage("path/to/image.png"))
            messages = list(pull_messages_from_step(step))
            assert len(messages) == 1
            assert messages[0].content["path"] == "path/to/image.png"
            assert messages[0].content["mime_type"] == "image/png"

    def test_final_answer_step_audio(self):
        """Test FinalAnswerStep with audio answer."""
        with patch.object(AgentAudio, "to_string", return_value="path/to/audio.wav"):
            step = FinalAnswerStep(output=AgentAudio("path/to/audio.wav"))
            messages = list(pull_messages_from_step(step))
            assert len(messages) == 1
            assert messages[0].content["path"] == "path/to/audio.wav"
            assert messages[0].content["mime_type"] == "audio/wav"

    def test_unsupported_step_type(self):
        """Test handling of unsupported step types."""

        class UnsupportedStep(Mock):
            pass

        step = UnsupportedStep()
        with pytest.raises(ValueError, match="Unsupported step type"):
            list(pull_messages_from_step(step))