|
import pytest |
|
from PIL import Image |
|
|
|
from smolagents.agents import ToolCall |
|
from smolagents.memory import ( |
|
ActionStep, |
|
AgentMemory, |
|
ChatMessage, |
|
MemoryStep, |
|
MessageRole, |
|
PlanningStep, |
|
SystemPromptStep, |
|
TaskStep, |
|
) |
|
from smolagents.monitoring import Timing, TokenUsage |
|
|
|
|
|
class TestAgentMemory:
    """Tests for the ``AgentMemory`` container."""

    def test_initialization(self):
        """A new memory wraps the given system prompt and holds no steps."""
        prompt_text = "This is a system prompt."
        agent_memory = AgentMemory(system_prompt=prompt_text)

        assert agent_memory.system_prompt.system_prompt == prompt_text
        assert agent_memory.steps == []

    def test_return_all_code_actions(self):
        """``return_full_code`` joins non-empty code actions with a blank line."""
        agent_memory = AgentMemory(system_prompt="This is a system prompt.")
        # The middle step has no code action and must be skipped in the output.
        snippets = ["print('Hello')", None, "print('World')"]
        agent_memory.steps = [
            ActionStep(step_number=index, timing=Timing(start_time=0.0, end_time=1.0), code_action=snippet)
            for index, snippet in enumerate(snippets, start=1)
        ]

        expected_code = "print('Hello')\n\nprint('World')"
        assert agent_memory.return_full_code() == expected_code
|
|
|
|
|
class TestMemoryStep:
    """Tests for the ``MemoryStep`` base class."""

    def test_initialization(self):
        """The base step can be instantiated directly."""
        assert isinstance(MemoryStep(), MemoryStep)

    def test_dict(self):
        """A bare base step serialises to an empty dict."""
        assert MemoryStep().dict() == {}

    def test_to_messages(self):
        """The base class does not implement ``to_messages`` itself."""
        base_step = MemoryStep()
        with pytest.raises(NotImplementedError):
            base_step.to_messages()
|
|
|
|
|
def test_action_step_dict():
    """``ActionStep.dict()`` serialises the fields of a fully populated step."""
    action_step = ActionStep(
        model_input_messages=[ChatMessage(role=MessageRole.USER, content="Hello")],
        tool_calls=[
            ToolCall(id="id", name="get_weather", arguments={"location": "Paris"}),
        ],
        timing=Timing(start_time=0.0, end_time=1.0),
        step_number=1,
        error=None,
        model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
        model_output="Hi",
        observations="This is a nice observation",
        observations_images=[Image.new("RGB", (100, 100))],
        action_output="Output",
        token_usage=TokenUsage(input_tokens=10, output_tokens=20),
    )
    action_step_dict = action_step.dict()

    assert "model_input_messages" in action_step_dict
    assert action_step_dict["model_input_messages"] == [ChatMessage(role=MessageRole.USER, content="Hello")]

    # Tool calls are expected in the OpenAI-style "function" envelope.
    assert "tool_calls" in action_step_dict
    assert len(action_step_dict["tool_calls"]) == 1
    assert action_step_dict["tool_calls"][0] == {
        "id": "id",
        "type": "function",
        "function": {
            "name": "get_weather",
            "arguments": {"location": "Paris"},
        },
    }

    # Timing serialisation is expected to include the derived duration.
    assert "timing" in action_step_dict
    assert action_step_dict["timing"] == {"start_time": 0.0, "end_time": 1.0, "duration": 1.0}

    # Token usage is expected to include the derived total.
    assert "token_usage" in action_step_dict
    assert action_step_dict["token_usage"] == {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}

    assert "step_number" in action_step_dict
    assert action_step_dict["step_number"] == 1

    assert "error" in action_step_dict
    assert action_step_dict["error"] is None

    assert "model_output_message" in action_step_dict
    assert action_step_dict["model_output_message"] == {
        "role": "assistant",
        "content": "Hi",
        "tool_calls": None,
        "raw": None,
        "token_usage": None,
    }

    assert "model_output" in action_step_dict
    assert action_step_dict["model_output"] == "Hi"

    assert "observations" in action_step_dict
    assert action_step_dict["observations"] == "This is a nice observation"

    assert "observations_images" in action_step_dict
    # One image was attached above, so exactly one serialised entry must come back;
    # checking only key presence would miss dropped or duplicated images.
    assert len(action_step_dict["observations_images"]) == 1

    assert "action_output" in action_step_dict
    assert action_step_dict["action_output"] == "Output"
|
|
|
|
|
def test_action_step_to_messages():
    """A step with output, tool calls, an image and observations yields four messages."""
    step = ActionStep(
        step_number=1,
        timing=Timing(start_time=0.0, end_time=1.0),
        model_input_messages=[ChatMessage(role=MessageRole.USER, content="Hello")],
        tool_calls=[ToolCall(id="id", name="get_weather", arguments={"location": "Paris"})],
        error=None,
        model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
        model_output="Hi",
        observations="This is a nice observation",
        observations_images=[Image.new("RGB", (100, 100))],
        action_output="Output",
        token_usage=TokenUsage(input_tokens=10, output_tokens=20),
    )

    messages = step.to_messages()
    assert len(messages) == 4
    assert all(isinstance(msg, ChatMessage) for msg in messages)

    assistant_message, tool_call_message, image_message, observation_message = messages

    # Assistant turn: a single text part carrying "Hi".
    assert assistant_message.role == MessageRole.ASSISTANT
    assert len(assistant_message.content) == 1
    assistant_part = assistant_message.content[0]
    assert assistant_part["type"] == "text"
    assert assistant_part["text"] == "Hi"

    # Tool-call turn: a single text part announcing the calls.
    assert tool_call_message.role == MessageRole.TOOL_CALL
    assert len(tool_call_message.content) == 1
    tool_call_part = tool_call_message.content[0]
    assert tool_call_part["type"] == "text"
    assert "Calling tools:" in tool_call_part["text"]

    # Image turn: the first content part is an image.
    assert image_message.content[0]["type"] == "image"

    # Observation turn: a tool response containing the observation text.
    assert observation_message.role == MessageRole.TOOL_RESPONSE
    assert "Observation:\nThis is a nice observation" in observation_message.content[0]["text"]
|
|
|
|
|
def test_action_step_to_messages_no_tool_calls_with_observations():
    """With only observations set, a single tool-response message is produced."""
    step = ActionStep(
        step_number=1,
        timing=Timing(start_time=0.0, end_time=1.0),
        model_input_messages=None,
        tool_calls=None,
        error=None,
        model_output_message=None,
        model_output=None,
        observations="This is an observation.",
        observations_images=None,
        action_output=None,
        token_usage=TokenUsage(input_tokens=10, output_tokens=20),
    )

    messages = step.to_messages()
    assert len(messages) == 1

    [only_message] = messages
    assert only_message.role == MessageRole.TOOL_RESPONSE
    first_text = only_message.content[0]["text"]
    assert "Observation:\nThis is an observation." in first_text
|
|
|
|
|
def test_planning_step_to_messages():
    """A planning step renders as an assistant message followed by a user message."""
    planning_step = PlanningStep(
        model_input_messages=[ChatMessage(role=MessageRole.USER, content="Hello")],
        model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Plan"),
        plan="This is a plan.",
        timing=Timing(start_time=0.0, end_time=1.0),
    )

    messages = planning_step.to_messages(summary_mode=False)
    assert len(messages) == 2

    # Every message holds exactly one text-style content part.
    for msg in messages:
        assert isinstance(msg, ChatMessage)
        assert isinstance(msg.content, list)
        assert len(msg.content) == 1
        for part in msg.content:
            assert isinstance(part, dict)
            assert part.keys() >= {"type", "text"}

    plan_message, follow_up_message = messages
    assert plan_message.role == MessageRole.ASSISTANT
    assert follow_up_message.role == MessageRole.USER
|
|
|
|
|
def test_task_step_to_messages():
    """A task with one image becomes a single user message: a text part then image parts."""
    task_step = TaskStep(task="This is a task.", task_images=[Image.new("RGB", (100, 100))])

    messages = task_step.to_messages(summary_mode=False)
    assert len(messages) == 1

    for msg in messages:
        assert isinstance(msg, ChatMessage)
        assert msg.role == MessageRole.USER
        assert isinstance(msg.content, list)
        assert len(msg.content) == 2

        text_part, *image_parts = msg.content
        assert isinstance(text_part, dict)
        assert "type" in text_part
        assert "text" in text_part

        for image_part in image_parts:
            assert isinstance(image_part, dict)
            assert "type" in image_part
            assert "image" in image_part
|
|
|
|
|
def test_system_prompt_step_to_messages():
    """The system prompt renders as one system message with a single text part."""
    prompt_step = SystemPromptStep(system_prompt="This is a system prompt.")

    messages = prompt_step.to_messages(summary_mode=False)
    assert len(messages) == 1

    for msg in messages:
        assert isinstance(msg, ChatMessage)
        assert msg.role == MessageRole.SYSTEM
        assert isinstance(msg.content, list)
        assert len(msg.content) == 1
        for part in msg.content:
            assert isinstance(part, dict)
            assert part.keys() >= {"type", "text"}
|
|