Spaces:
Sleeping
Sleeping
Upload 7 files
Browse files- tests/base_test.py +99 -0
- tests/test.json +1 -0
- tests/test_context_compressor.py +133 -0
- tests/test_context_hook.py +32 -0
- tests/test_context_management.py +236 -0
- tests/test_demo.py +11 -0
- tests/test_llm_hook.py +49 -0
tests/base_test.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
class BaseTest:
|
4 |
+
|
5 |
+
# Custom assert methods implementation
|
6 |
+
def assertIsNotNone(self, value, msg=None):
|
7 |
+
"""Assert that value is not None"""
|
8 |
+
if value is None:
|
9 |
+
raise AssertionError(msg or f"Expected not None, but got None")
|
10 |
+
|
11 |
+
def assertEqual(self, first, second, msg=None):
|
12 |
+
"""Assert that first equals second"""
|
13 |
+
if first != second:
|
14 |
+
raise AssertionError(msg or f"Expected {first} == {second}, but {first} != {second}")
|
15 |
+
|
16 |
+
def assertTrue(self, expr, msg=None):
|
17 |
+
"""Assert that expr is True"""
|
18 |
+
if not expr:
|
19 |
+
raise AssertionError(msg or f"Expected True, but got {expr}")
|
20 |
+
|
21 |
+
def assertFalse(self, expr, msg=None):
|
22 |
+
"""Assert that expr is False"""
|
23 |
+
if expr:
|
24 |
+
raise AssertionError(msg or f"Expected False, but got {expr}")
|
25 |
+
|
26 |
+
def assertAlmostEqual(self, first, second, places=7, msg=None):
|
27 |
+
"""Assert that first and second are approximately equal"""
|
28 |
+
if round(abs(second - first), places) != 0:
|
29 |
+
raise AssertionError(msg or f"Expected {first} ~= {second} (within {places} decimal places)")
|
30 |
+
|
31 |
+
def assertIs(self, first, second, msg=None):
|
32 |
+
"""Assert that first is second (same object identity)"""
|
33 |
+
if first is not second:
|
34 |
+
raise AssertionError(msg or f"Expected {first} is {second}, but they are different objects")
|
35 |
+
|
36 |
+
def assertIn(self, member, container, msg=None):
|
37 |
+
"""Assert that member is in container"""
|
38 |
+
if member not in container:
|
39 |
+
raise AssertionError(msg or f"Expected {member} in {container}")
|
40 |
+
|
41 |
+
def assertIsInstance(self, obj, cls, msg=None):
|
42 |
+
"""Assert that obj is an instance of cls"""
|
43 |
+
if not isinstance(obj, cls):
|
44 |
+
raise AssertionError(msg or f"Expected {obj} to be instance of {cls}, but got {type(obj)}")
|
45 |
+
|
46 |
+
def assertIsNone(self, value, msg=None):
|
47 |
+
"""Assert that value is None"""
|
48 |
+
if value is not None:
|
49 |
+
raise AssertionError(msg or f"Expected None, but got {value}")
|
50 |
+
|
51 |
+
def assertNotEqual(self, first, second, msg=None):
|
52 |
+
"""Assert that first does not equal second"""
|
53 |
+
if first == second:
|
54 |
+
raise AssertionError(msg or f"Expected {first} != {second}, but they are equal")
|
55 |
+
|
56 |
+
def assertGreater(self, first, second, msg=None):
|
57 |
+
"""Assert that first is greater than second"""
|
58 |
+
if not first > second:
|
59 |
+
raise AssertionError(msg or f"Expected {first} > {second}")
|
60 |
+
|
61 |
+
def assertLess(self, first, second, msg=None):
|
62 |
+
"""Assert that first is less than second"""
|
63 |
+
if not first < second:
|
64 |
+
raise AssertionError(msg or f"Expected {first} < {second}")
|
65 |
+
|
66 |
+
def assertGreaterEqual(self, first, second, msg=None):
|
67 |
+
"""Assert that first is greater than or equal to second"""
|
68 |
+
if not first >= second:
|
69 |
+
raise AssertionError(msg or f"Expected {first} >= {second}")
|
70 |
+
|
71 |
+
def assertLessEqual(self, first, second, msg=None):
|
72 |
+
"""Assert that first is less than or equal to second"""
|
73 |
+
if not first <= second:
|
74 |
+
raise AssertionError(msg or f"Expected {first} <= {second}")
|
75 |
+
|
76 |
+
def assertNotIn(self, member, container, msg=None):
|
77 |
+
"""Assert that member is not in container"""
|
78 |
+
if member in container:
|
79 |
+
raise AssertionError(msg or f"Expected {member} not in {container}")
|
80 |
+
|
81 |
+
def assertIsNot(self, first, second, msg=None):
|
82 |
+
"""Assert that first is not second (different object identity)"""
|
83 |
+
if first is second:
|
84 |
+
raise AssertionError(msg or f"Expected {first} is not {second}, but they are the same object")
|
85 |
+
|
86 |
+
def assertRaises(self, exception_class, callable_obj=None, *args, **kwargs):
|
87 |
+
"""Assert that calling callable_obj raises exception_class"""
|
88 |
+
if callable_obj is None:
|
89 |
+
# Return a context manager for use with 'with' statement
|
90 |
+
return self._AssertRaisesContext(exception_class)
|
91 |
+
else:
|
92 |
+
try:
|
93 |
+
callable_obj(*args, **kwargs)
|
94 |
+
raise AssertionError(f"Expected {exception_class.__name__} to be raised, but no exception was raised")
|
95 |
+
except exception_class:
|
96 |
+
pass # Expected exception was raised
|
97 |
+
except Exception as e:
|
98 |
+
raise AssertionError(f"Expected {exception_class.__name__} to be raised, but got {type(e).__name__}: {e}")
|
99 |
+
|
tests/test.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"test": "test content"}
|
tests/test_context_compressor.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import sys
|
3 |
+
import os
|
4 |
+
from unittest.mock import Mock, patch, MagicMock
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
# Add the project root to Python path
|
8 |
+
project_root = Path(__file__).parent.parent
|
9 |
+
sys.path.insert(0, str(project_root))
|
10 |
+
|
11 |
+
from tests.base_test import BaseTest
|
12 |
+
|
13 |
+
from aworld.config.conf import AgentConfig, ModelConfig, ContextRuleConfig, OptimizationConfig, LlmCompressionConfig
|
14 |
+
from aworld.core.context.processor import CompressionResult, CompressionType
|
15 |
+
from aworld.core.context.processor.llm_compressor import LLMCompressor
|
16 |
+
from aworld.core.context.processor.prompt_processor import PromptProcessor
|
17 |
+
from aworld.core.context.base import AgentContext, ContextUsage
|
18 |
+
|
19 |
+
|
20 |
+
class TestPromptCompressor(BaseTest):
|
21 |
+
"""Test cases for PromptCompressor.compress_batch function"""
|
22 |
+
|
23 |
+
def __init__(self):
|
24 |
+
"""Set up test fixtures"""
|
25 |
+
self.mock_model_name = "qwen/qwen3-1.7b"
|
26 |
+
self.mock_base_url = "http://localhost:1234/v1"
|
27 |
+
self.mock_api_key = "lm-studio"
|
28 |
+
self.mock_llm_config = ModelConfig(
|
29 |
+
llm_model_name=self.mock_model_name,
|
30 |
+
llm_base_url=self.mock_base_url,
|
31 |
+
llm_api_key=self.mock_api_key
|
32 |
+
)
|
33 |
+
os.environ["LLM_API_KEY"] = self.mock_api_key
|
34 |
+
os.environ["LLM_BASE_URL"] = self.mock_base_url
|
35 |
+
os.environ["LLM_MODEL_NAME"] = self.mock_model_name
|
36 |
+
|
37 |
+
def test_compress_batch_basic(self):
|
38 |
+
|
39 |
+
compressor = LLMCompressor(
|
40 |
+
llm_config=self.mock_llm_config
|
41 |
+
)
|
42 |
+
|
43 |
+
# Test data
|
44 |
+
contents = [
|
45 |
+
"[SYSTEM]You are a helpful assistant.\n[USER]This is the first long text content that needs compression. This is the first long text content that needs compression.",
|
46 |
+
]
|
47 |
+
|
48 |
+
# Execute compress_batch
|
49 |
+
results = compressor.compress_batch(contents)
|
50 |
+
|
51 |
+
# Assertions
|
52 |
+
for result in results:
|
53 |
+
self.assertIsInstance(result, CompressionResult)
|
54 |
+
self.assertEqual(result.compression_type, CompressionType.LLM_BASED)
|
55 |
+
self.assertTrue('This is the first long text content that needs compression. This is the first long text content that needs compression.' not in result.compressed_content)
|
56 |
+
|
57 |
+
def test_compress_messages(self):
|
58 |
+
"""Test compress_messages function from PromptProcessor"""
|
59 |
+
|
60 |
+
# Create context rule with compression enabled
|
61 |
+
context_rule = ContextRuleConfig(
|
62 |
+
optimization_config=OptimizationConfig(
|
63 |
+
enabled=True,
|
64 |
+
max_token_budget_ratio=0.8
|
65 |
+
),
|
66 |
+
llm_compression_config=LlmCompressionConfig(
|
67 |
+
enabled=True,
|
68 |
+
trigger_compress_token_length=10, # Low threshold to trigger compression
|
69 |
+
compress_model=self.mock_llm_config
|
70 |
+
)
|
71 |
+
)
|
72 |
+
|
73 |
+
# Create agent context
|
74 |
+
agent_context = AgentContext(
|
75 |
+
agent_id="test_agent",
|
76 |
+
agent_name="test_agent",
|
77 |
+
agent_desc="Test agent for compression",
|
78 |
+
system_prompt="You are a helpful assistant.",
|
79 |
+
agent_prompt="You are a helpful assistant.",
|
80 |
+
model_config=self.mock_llm_config,
|
81 |
+
context_rule=context_rule,
|
82 |
+
context_usage=ContextUsage(total_context_length=4096)
|
83 |
+
)
|
84 |
+
|
85 |
+
# Create prompt processor
|
86 |
+
processor = PromptProcessor(agent_context)
|
87 |
+
|
88 |
+
# Test messages with repeated content that needs compression
|
89 |
+
messages = [
|
90 |
+
{
|
91 |
+
"role": "system",
|
92 |
+
"content": "You are a helpful assistant."
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"role": "user",
|
96 |
+
"content": "This is the first long text content that needs compression. This is the first long text content that needs compression."
|
97 |
+
},
|
98 |
+
{
|
99 |
+
"role": "assistant",
|
100 |
+
"content": "I understand you want me to help with compression."
|
101 |
+
}
|
102 |
+
]
|
103 |
+
|
104 |
+
# Execute compress_messages
|
105 |
+
compressed_messages = processor.compress_messages(messages)
|
106 |
+
|
107 |
+
# Assertions
|
108 |
+
self.assertIsInstance(compressed_messages, list)
|
109 |
+
self.assertEqual(len(compressed_messages), len(messages))
|
110 |
+
|
111 |
+
# Find the user message and verify it was processed
|
112 |
+
user_message = None
|
113 |
+
for msg in compressed_messages:
|
114 |
+
if msg.get("role") == "user":
|
115 |
+
user_message = msg
|
116 |
+
break
|
117 |
+
|
118 |
+
self.assertIsNotNone(user_message)
|
119 |
+
# The original repeated text should be compressed
|
120 |
+
original_content = "This is the first long text content that needs compression. This is the first long text content that needs compression."
|
121 |
+
self.assertNotEqual(user_message["content"], original_content)
|
122 |
+
# The compressed content should be shorter than original
|
123 |
+
self.assertLess(len(user_message["content"]), len(original_content))
|
124 |
+
|
125 |
+
if __name__ == '__main__':
|
126 |
+
testPromptCompressor = TestPromptCompressor()
|
127 |
+
testPromptCompressor.test_compress_batch_basic()
|
128 |
+
testPromptCompressor = TestPromptCompressor()
|
129 |
+
testPromptCompressor.test_compress_messages()
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
tests/test_context_hook.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Add the project root to Python path
|
3 |
+
from pathlib import Path
|
4 |
+
import sys
|
5 |
+
|
6 |
+
from aworld.logs.util import color_log
|
7 |
+
|
8 |
+
|
9 |
+
project_root = Path(__file__).parent.parent
|
10 |
+
sys.path.insert(0, str(project_root))
|
11 |
+
|
12 |
+
from aworld.core.task import Task
|
13 |
+
from aworld.core.agent.base import AgentFactory
|
14 |
+
from aworld.core.agent.swarm import Swarm
|
15 |
+
from aworld.runner import Runners
|
16 |
+
from aworld.agents.llm_agent import Agent
|
17 |
+
from aworld.config.conf import AgentConfig, ContextRuleConfig, ModelConfig, OptimizationConfig, LlmCompressionConfig
|
18 |
+
from aworld.core.context.base import Context
|
19 |
+
from aworld.core.event.base import Message
|
20 |
+
from aworld.runners.hook.hooks import PreLLMCallHook, PostLLMCallHook
|
21 |
+
from aworld.runners.hook.hook_factory import HookFactory
|
22 |
+
from aworld.utils.common import convert_to_snake
|
23 |
+
from tests.base_test import BaseTest
|
24 |
+
|
25 |
+
@HookFactory.register(name="CheckContextPreLLMHook", desc="Test pre-LLM hook")
|
26 |
+
class CheckContextPreLLMHook(PreLLMCallHook):
|
27 |
+
def name(self):
|
28 |
+
return convert_to_snake("CheckContextPreLLMHook")
|
29 |
+
async def exec(self, message: Message, context: Context = None) -> Message:
|
30 |
+
assert context.state.get("task") == "What is an agent."
|
31 |
+
color_log("CheckContextPreLLMHook test ok", color="green")
|
32 |
+
return message
|
tests/test_context_management.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
# Add the project root to Python path
|
7 |
+
project_root = Path(__file__).parent.parent
|
8 |
+
sys.path.insert(0, str(project_root))
|
9 |
+
|
10 |
+
from aworld.core.task import Task
|
11 |
+
from aworld.core.agent.swarm import Swarm
|
12 |
+
from aworld.runner import Runners
|
13 |
+
from aworld.agents.llm_agent import Agent
|
14 |
+
from aworld.config.conf import AgentConfig, ContextRuleConfig, ModelConfig, OptimizationConfig, LlmCompressionConfig
|
15 |
+
from aworld.core.context.base import Context
|
16 |
+
from aworld.runners.hook.hook_factory import HookFactory
|
17 |
+
from tests.base_test import BaseTest
|
18 |
+
|
19 |
+
|
20 |
+
class TestContextManagement(BaseTest):
|
21 |
+
"""Test cases for Context Management system based on README examples"""
|
22 |
+
|
23 |
+
def __init__(self):
|
24 |
+
"""Set up test fixtures"""
|
25 |
+
self.mock_model_name = "qwen/qwen3-1.7b"
|
26 |
+
self.mock_base_url = "http://localhost:1234/v1"
|
27 |
+
self.mock_api_key = "lm-studio"
|
28 |
+
os.environ["LLM_API_KEY"] = self.mock_api_key
|
29 |
+
os.environ["LLM_BASE_URL"] = self.mock_base_url
|
30 |
+
os.environ["LLM_MODEL_NAME"] = self.mock_model_name
|
31 |
+
|
32 |
+
def init_agent(self, config_type: str = "1", context_rule: ContextRuleConfig = None):
|
33 |
+
if config_type == "1":
|
34 |
+
conf = AgentConfig(
|
35 |
+
llm_model_name=self.mock_model_name,
|
36 |
+
llm_base_url=self.mock_base_url,
|
37 |
+
llm_api_key=self.mock_api_key
|
38 |
+
)
|
39 |
+
else:
|
40 |
+
conf = AgentConfig(
|
41 |
+
llm_config=ModelConfig(
|
42 |
+
llm_model_name=self.mock_model_name,
|
43 |
+
llm_base_url=self.mock_base_url,
|
44 |
+
llm_api_key=self.mock_api_key
|
45 |
+
)
|
46 |
+
)
|
47 |
+
return Agent(
|
48 |
+
conf=conf,
|
49 |
+
name="my_agent" + str(random.randint(0, 1000000)),
|
50 |
+
system_prompt="You are a helpful assistant.",
|
51 |
+
agent_prompt="You are a helpful assistant.",
|
52 |
+
context_rule=context_rule
|
53 |
+
)
|
54 |
+
|
55 |
+
class _AssertRaisesContext:
|
56 |
+
"""Context manager for assertRaises"""
|
57 |
+
|
58 |
+
def __init__(self, expected_exception):
|
59 |
+
self.expected_exception = expected_exception
|
60 |
+
|
61 |
+
def __enter__(self):
|
62 |
+
return self
|
63 |
+
|
64 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
65 |
+
if exc_type is None:
|
66 |
+
raise AssertionError(f"Expected {self.expected_exception.__name__} to be raised, but no exception was raised")
|
67 |
+
if not issubclass(exc_type, self.expected_exception):
|
68 |
+
raise AssertionError(f"Expected {self.expected_exception.__name__} to be raised, but got {exc_type.__name__}: {exc_value}")
|
69 |
+
return True # Suppress the exception
|
70 |
+
|
71 |
+
def fail(self, msg=None):
|
72 |
+
"""Fail immediately with the given message"""
|
73 |
+
raise AssertionError(msg or "Test failed")
|
74 |
+
|
75 |
+
def run_agent(self, input, agent: Agent):
|
76 |
+
swarm = Swarm(agent, max_steps=1)
|
77 |
+
return Runners.sync_run(
|
78 |
+
input=input,
|
79 |
+
swarm=swarm
|
80 |
+
)
|
81 |
+
|
82 |
+
def run_multi_agent(self, input, agent1: Agent, agent2: Agent):
|
83 |
+
swarm = Swarm(agent1, agent2, max_steps=1)
|
84 |
+
return Runners.sync_run(
|
85 |
+
input=input,
|
86 |
+
swarm=swarm
|
87 |
+
)
|
88 |
+
|
89 |
+
def run_task(self, context: Context, agent: Agent):
|
90 |
+
swarm = Swarm(agent, max_steps=1)
|
91 |
+
task = Task(input="""What is an agent.""", swarm=swarm, context=context)
|
92 |
+
return Runners.sync_run_task(task)
|
93 |
+
|
94 |
+
def test_default_context_configuration(self):
|
95 |
+
|
96 |
+
# No need to explicitly configure context_rule, system automatically uses default configuration
|
97 |
+
# Default configuration is equivalent to:
|
98 |
+
# context_rule=ContextRuleConfig(
|
99 |
+
# optimization_config=OptimizationConfig(
|
100 |
+
# enabled=True,
|
101 |
+
# max_token_budget_ratio=1.0 # Use 100% of context window
|
102 |
+
# ),
|
103 |
+
# llm_compression_config=LlmCompressionConfig(
|
104 |
+
# enabled=False # Compression disabled by default
|
105 |
+
# )
|
106 |
+
# )
|
107 |
+
mock_agent = self.init_agent("1")
|
108 |
+
response = self.run_agent(input="""What is an agent. describe within 20 words""", agent=mock_agent)
|
109 |
+
|
110 |
+
self.assertIsNotNone(response.answer)
|
111 |
+
self.assertEqual(mock_agent.agent_context.model_config.llm_model_name, self.mock_model_name)
|
112 |
+
|
113 |
+
# Test default context rule behavior
|
114 |
+
self.assertIsNotNone(mock_agent.agent_context.context_rule)
|
115 |
+
self.assertIsNotNone(mock_agent.agent_context.context_rule.optimization_config)
|
116 |
+
|
117 |
+
def test_custom_context_configuration(self):
|
118 |
+
"""Test custom context configuration (README Configuration example)"""
|
119 |
+
# Create custom context rules
|
120 |
+
mock_agent = self.init_agent(context_rule=ContextRuleConfig(
|
121 |
+
optimization_config=OptimizationConfig(
|
122 |
+
enabled=True,
|
123 |
+
max_token_budget_ratio=0.00015
|
124 |
+
),
|
125 |
+
llm_compression_config=LlmCompressionConfig(
|
126 |
+
enabled=True,
|
127 |
+
trigger_compress_token_length=100,
|
128 |
+
compress_model=ModelConfig(
|
129 |
+
llm_model_name=self.mock_model_name,
|
130 |
+
llm_base_url=self.mock_base_url,
|
131 |
+
llm_api_key=self.mock_api_key,
|
132 |
+
)
|
133 |
+
)
|
134 |
+
))
|
135 |
+
|
136 |
+
response = self.run_agent(input="""describe What is an agent in details""", agent=mock_agent)
|
137 |
+
self.assertIsNotNone(response.answer)
|
138 |
+
|
139 |
+
# Test configuration values
|
140 |
+
self.assertTrue(mock_agent.agent_context.context_rule.optimization_config.enabled)
|
141 |
+
self.assertTrue(mock_agent.agent_context.context_rule.llm_compression_config.enabled)
|
142 |
+
|
143 |
+
def test_state_management_and_recovery(self):
|
144 |
+
class StateModifyAgent(Agent):
|
145 |
+
async def async_policy(self, observation, info=None, **kwargs):
|
146 |
+
result = await super().async_policy(observation, info, **kwargs)
|
147 |
+
self.context.state['policy_executed'] = True
|
148 |
+
return result
|
149 |
+
|
150 |
+
class StateTrackingAgent(Agent):
|
151 |
+
async def async_policy(self, observation, info=None, **kwargs):
|
152 |
+
result = await super().async_policy(observation, info, **kwargs)
|
153 |
+
assert self.context.state['policy_executed'] == True
|
154 |
+
return result
|
155 |
+
|
156 |
+
# Create custom agent instance
|
157 |
+
custom_agent = StateModifyAgent(
|
158 |
+
conf=AgentConfig(
|
159 |
+
llm_model_name=self.mock_model_name,
|
160 |
+
llm_base_url=self.mock_base_url,
|
161 |
+
llm_api_key=self.mock_api_key
|
162 |
+
),
|
163 |
+
name="state_modify_agent",
|
164 |
+
system_prompt="You are a Python expert who provides detailed and practical answers.",
|
165 |
+
agent_prompt="You are a Python expert who provides detailed and practical answers.",
|
166 |
+
)
|
167 |
+
|
168 |
+
# Create a second agent for multi-agent testing
|
169 |
+
second_agent = StateTrackingAgent(
|
170 |
+
conf=AgentConfig(
|
171 |
+
llm_model_name=self.mock_model_name,
|
172 |
+
llm_base_url=self.mock_base_url,
|
173 |
+
llm_api_key=self.mock_api_key
|
174 |
+
),
|
175 |
+
name="state_tracking_agent",
|
176 |
+
system_prompt="You are a helpful assistant.",
|
177 |
+
agent_prompt="You are a helpful assistant.",
|
178 |
+
)
|
179 |
+
|
180 |
+
response = self.run_multi_agent(
|
181 |
+
input="What is an agent. describe within 20 words",
|
182 |
+
agent1=custom_agent,
|
183 |
+
agent2=second_agent
|
184 |
+
)
|
185 |
+
self.assertIsNotNone(response.answer)
|
186 |
+
|
187 |
+
# Verify state changes after execution
|
188 |
+
self.assertTrue(custom_agent.context.state.get('policy_executed', True))
|
189 |
+
self.assertTrue(second_agent.agent_context.state.get('policy_executed', True))
|
190 |
+
|
191 |
+
|
192 |
+
class TestHookSystem(TestContextManagement):
|
193 |
+
def __init__(self):
|
194 |
+
super().__init__()
|
195 |
+
|
196 |
+
def test_hook_registration(self):
|
197 |
+
from tests.test_llm_hook import TestPreLLMHook, TestPostLLMHook
|
198 |
+
"""Test hook registration and retrieval"""
|
199 |
+
# Test that hooks are registered in _cls attribute
|
200 |
+
self.assertIn("TestPreLLMHook", HookFactory._cls)
|
201 |
+
self.assertIn("TestPostLLMHook", HookFactory._cls)
|
202 |
+
|
203 |
+
# Test hook creation using __call__ method
|
204 |
+
pre_hook = HookFactory("TestPreLLMHook")
|
205 |
+
post_hook = HookFactory("TestPostLLMHook")
|
206 |
+
|
207 |
+
self.assertIsInstance(pre_hook, TestPreLLMHook)
|
208 |
+
self.assertIsInstance(post_hook, TestPostLLMHook)
|
209 |
+
|
210 |
+
def test_hook_execution(self):
|
211 |
+
from tests.test_llm_hook import TestPreLLMHook, TestPostLLMHook
|
212 |
+
|
213 |
+
mock_agent = self.init_agent("1")
|
214 |
+
response = self.run_agent(input="""What is an agent. describe within 20 words""", agent=mock_agent)
|
215 |
+
self.assertIsNotNone(response.answer)
|
216 |
+
|
217 |
+
def test_task_context_transfer(self):
|
218 |
+
from tests.test_context_hook import CheckContextPreLLMHook
|
219 |
+
|
220 |
+
mock_agent = self.init_agent("1")
|
221 |
+
context = Context.instance()
|
222 |
+
context.state.update({"task": "What is an agent."})
|
223 |
+
self.run_task(context=context, agent=mock_agent)
|
224 |
+
|
225 |
+
|
226 |
+
if __name__ == '__main__':
|
227 |
+
testContextManagement = TestContextManagement()
|
228 |
+
testContextManagement.test_default_context_configuration()
|
229 |
+
testContextManagement.test_custom_context_configuration()
|
230 |
+
testContextManagement.test_state_management_and_recovery()
|
231 |
+
testHookSystem = TestHookSystem()
|
232 |
+
testHookSystem.test_hook_registration()
|
233 |
+
testHookSystem = TestHookSystem()
|
234 |
+
testHookSystem.test_hook_execution()
|
235 |
+
testHookSystem = TestHookSystem()
|
236 |
+
testHookSystem.test_task_context_transfer()
|
tests/test_demo.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
class TestShellTool(unittest.TestCase):
|
6 |
+
|
7 |
+
|
8 |
+
def test_init(self):
|
9 |
+
"""Test initialization"""
|
10 |
+
print("Test initialization")
|
11 |
+
|
tests/test_llm_hook.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Add the project root to Python path
|
3 |
+
from pathlib import Path
|
4 |
+
import sys
|
5 |
+
|
6 |
+
project_root = Path(__file__).parent.parent
|
7 |
+
sys.path.insert(0, str(project_root))
|
8 |
+
|
9 |
+
from aworld.core.task import Task
|
10 |
+
from aworld.core.agent.base import AgentFactory
|
11 |
+
from aworld.core.agent.swarm import Swarm
|
12 |
+
from aworld.runner import Runners
|
13 |
+
from aworld.agents.llm_agent import Agent
|
14 |
+
from aworld.config.conf import AgentConfig, ContextRuleConfig, ModelConfig, OptimizationConfig, LlmCompressionConfig
|
15 |
+
from aworld.core.context.base import Context
|
16 |
+
from aworld.core.event.base import Message
|
17 |
+
from aworld.runners.hook.hooks import PreLLMCallHook, PostLLMCallHook
|
18 |
+
from aworld.runners.hook.hook_factory import HookFactory
|
19 |
+
from aworld.utils.common import convert_to_snake
|
20 |
+
from tests.base_test import BaseTest
|
21 |
+
|
22 |
+
|
23 |
+
@HookFactory.register(name="TestPreLLMHook", desc="Test pre-LLM hook")
|
24 |
+
class TestPreLLMHook(PreLLMCallHook):
|
25 |
+
def name(self):
|
26 |
+
return convert_to_snake("TestPreLLMHook")
|
27 |
+
async def exec(self, message: Message, context: Context = None) -> Message:
|
28 |
+
agent = AgentFactory.agent_instance(message.sender)
|
29 |
+
agent_context = agent.agent_context
|
30 |
+
if agent_context is not None:
|
31 |
+
agent_context.step = 1
|
32 |
+
assert agent_context.step == 1 or agent_context.step == 2
|
33 |
+
return message
|
34 |
+
|
35 |
+
@HookFactory.register(name="TestPostLLMHook", desc="Test post-LLM hook")
|
36 |
+
class TestPostLLMHook(PostLLMCallHook):
|
37 |
+
def name(self):
|
38 |
+
return convert_to_snake("TestPostLLMHook")
|
39 |
+
async def exec(self, message: Message, context: Context = None) -> Message:
|
40 |
+
agent = AgentFactory.agent_instance(message.sender)
|
41 |
+
agent_context = agent.agent_context
|
42 |
+
if agent_context is not None and agent_context.llm_output is not None:
|
43 |
+
# Test dynamic prompt adjustment based on LLM output
|
44 |
+
if hasattr(agent_context.llm_output, 'content'):
|
45 |
+
content = agent_context.llm_output.content.lower()
|
46 |
+
if content is not None:
|
47 |
+
agent_context.agent_prompt = "Success mode activated"
|
48 |
+
assert agent_context.agent_prompt == "Success mode activated"
|
49 |
+
return message
|