Duibonduil commited on
Commit
81ec5d0
·
verified ·
1 Parent(s): 0ab5ea3

Upload 7 files

Browse files
tests/base_test.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ class BaseTest:
4
+
5
+ # Custom assert methods implementation
6
+ def assertIsNotNone(self, value, msg=None):
7
+ """Assert that value is not None"""
8
+ if value is None:
9
+ raise AssertionError(msg or f"Expected not None, but got None")
10
+
11
+ def assertEqual(self, first, second, msg=None):
12
+ """Assert that first equals second"""
13
+ if first != second:
14
+ raise AssertionError(msg or f"Expected {first} == {second}, but {first} != {second}")
15
+
16
+ def assertTrue(self, expr, msg=None):
17
+ """Assert that expr is True"""
18
+ if not expr:
19
+ raise AssertionError(msg or f"Expected True, but got {expr}")
20
+
21
+ def assertFalse(self, expr, msg=None):
22
+ """Assert that expr is False"""
23
+ if expr:
24
+ raise AssertionError(msg or f"Expected False, but got {expr}")
25
+
26
+ def assertAlmostEqual(self, first, second, places=7, msg=None):
27
+ """Assert that first and second are approximately equal"""
28
+ if round(abs(second - first), places) != 0:
29
+ raise AssertionError(msg or f"Expected {first} ~= {second} (within {places} decimal places)")
30
+
31
+ def assertIs(self, first, second, msg=None):
32
+ """Assert that first is second (same object identity)"""
33
+ if first is not second:
34
+ raise AssertionError(msg or f"Expected {first} is {second}, but they are different objects")
35
+
36
+ def assertIn(self, member, container, msg=None):
37
+ """Assert that member is in container"""
38
+ if member not in container:
39
+ raise AssertionError(msg or f"Expected {member} in {container}")
40
+
41
+ def assertIsInstance(self, obj, cls, msg=None):
42
+ """Assert that obj is an instance of cls"""
43
+ if not isinstance(obj, cls):
44
+ raise AssertionError(msg or f"Expected {obj} to be instance of {cls}, but got {type(obj)}")
45
+
46
+ def assertIsNone(self, value, msg=None):
47
+ """Assert that value is None"""
48
+ if value is not None:
49
+ raise AssertionError(msg or f"Expected None, but got {value}")
50
+
51
+ def assertNotEqual(self, first, second, msg=None):
52
+ """Assert that first does not equal second"""
53
+ if first == second:
54
+ raise AssertionError(msg or f"Expected {first} != {second}, but they are equal")
55
+
56
+ def assertGreater(self, first, second, msg=None):
57
+ """Assert that first is greater than second"""
58
+ if not first > second:
59
+ raise AssertionError(msg or f"Expected {first} > {second}")
60
+
61
+ def assertLess(self, first, second, msg=None):
62
+ """Assert that first is less than second"""
63
+ if not first < second:
64
+ raise AssertionError(msg or f"Expected {first} < {second}")
65
+
66
+ def assertGreaterEqual(self, first, second, msg=None):
67
+ """Assert that first is greater than or equal to second"""
68
+ if not first >= second:
69
+ raise AssertionError(msg or f"Expected {first} >= {second}")
70
+
71
+ def assertLessEqual(self, first, second, msg=None):
72
+ """Assert that first is less than or equal to second"""
73
+ if not first <= second:
74
+ raise AssertionError(msg or f"Expected {first} <= {second}")
75
+
76
+ def assertNotIn(self, member, container, msg=None):
77
+ """Assert that member is not in container"""
78
+ if member in container:
79
+ raise AssertionError(msg or f"Expected {member} not in {container}")
80
+
81
+ def assertIsNot(self, first, second, msg=None):
82
+ """Assert that first is not second (different object identity)"""
83
+ if first is second:
84
+ raise AssertionError(msg or f"Expected {first} is not {second}, but they are the same object")
85
+
86
+ def assertRaises(self, exception_class, callable_obj=None, *args, **kwargs):
87
+ """Assert that calling callable_obj raises exception_class"""
88
+ if callable_obj is None:
89
+ # Return a context manager for use with 'with' statement
90
+ return self._AssertRaisesContext(exception_class)
91
+ else:
92
+ try:
93
+ callable_obj(*args, **kwargs)
94
+ raise AssertionError(f"Expected {exception_class.__name__} to be raised, but no exception was raised")
95
+ except exception_class:
96
+ pass # Expected exception was raised
97
+ except Exception as e:
98
+ raise AssertionError(f"Expected {exception_class.__name__} to be raised, but got {type(e).__name__}: {e}")
99
+
tests/test.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"test": "test content"}
tests/test_context_compressor.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import sys
3
+ import os
4
+ from unittest.mock import Mock, patch, MagicMock
5
+ from pathlib import Path
6
+
7
+ # Add the project root to Python path
8
+ project_root = Path(__file__).parent.parent
9
+ sys.path.insert(0, str(project_root))
10
+
11
+ from tests.base_test import BaseTest
12
+
13
+ from aworld.config.conf import AgentConfig, ModelConfig, ContextRuleConfig, OptimizationConfig, LlmCompressionConfig
14
+ from aworld.core.context.processor import CompressionResult, CompressionType
15
+ from aworld.core.context.processor.llm_compressor import LLMCompressor
16
+ from aworld.core.context.processor.prompt_processor import PromptProcessor
17
+ from aworld.core.context.base import AgentContext, ContextUsage
18
+
19
+
20
+ class TestPromptCompressor(BaseTest):
21
+ """Test cases for PromptCompressor.compress_batch function"""
22
+
23
+ def __init__(self):
24
+ """Set up test fixtures"""
25
+ self.mock_model_name = "qwen/qwen3-1.7b"
26
+ self.mock_base_url = "http://localhost:1234/v1"
27
+ self.mock_api_key = "lm-studio"
28
+ self.mock_llm_config = ModelConfig(
29
+ llm_model_name=self.mock_model_name,
30
+ llm_base_url=self.mock_base_url,
31
+ llm_api_key=self.mock_api_key
32
+ )
33
+ os.environ["LLM_API_KEY"] = self.mock_api_key
34
+ os.environ["LLM_BASE_URL"] = self.mock_base_url
35
+ os.environ["LLM_MODEL_NAME"] = self.mock_model_name
36
+
37
+ def test_compress_batch_basic(self):
38
+
39
+ compressor = LLMCompressor(
40
+ llm_config=self.mock_llm_config
41
+ )
42
+
43
+ # Test data
44
+ contents = [
45
+ "[SYSTEM]You are a helpful assistant.\n[USER]This is the first long text content that needs compression. This is the first long text content that needs compression.",
46
+ ]
47
+
48
+ # Execute compress_batch
49
+ results = compressor.compress_batch(contents)
50
+
51
+ # Assertions
52
+ for result in results:
53
+ self.assertIsInstance(result, CompressionResult)
54
+ self.assertEqual(result.compression_type, CompressionType.LLM_BASED)
55
+ self.assertTrue('This is the first long text content that needs compression. This is the first long text content that needs compression.' not in result.compressed_content)
56
+
57
+ def test_compress_messages(self):
58
+ """Test compress_messages function from PromptProcessor"""
59
+
60
+ # Create context rule with compression enabled
61
+ context_rule = ContextRuleConfig(
62
+ optimization_config=OptimizationConfig(
63
+ enabled=True,
64
+ max_token_budget_ratio=0.8
65
+ ),
66
+ llm_compression_config=LlmCompressionConfig(
67
+ enabled=True,
68
+ trigger_compress_token_length=10, # Low threshold to trigger compression
69
+ compress_model=self.mock_llm_config
70
+ )
71
+ )
72
+
73
+ # Create agent context
74
+ agent_context = AgentContext(
75
+ agent_id="test_agent",
76
+ agent_name="test_agent",
77
+ agent_desc="Test agent for compression",
78
+ system_prompt="You are a helpful assistant.",
79
+ agent_prompt="You are a helpful assistant.",
80
+ model_config=self.mock_llm_config,
81
+ context_rule=context_rule,
82
+ context_usage=ContextUsage(total_context_length=4096)
83
+ )
84
+
85
+ # Create prompt processor
86
+ processor = PromptProcessor(agent_context)
87
+
88
+ # Test messages with repeated content that needs compression
89
+ messages = [
90
+ {
91
+ "role": "system",
92
+ "content": "You are a helpful assistant."
93
+ },
94
+ {
95
+ "role": "user",
96
+ "content": "This is the first long text content that needs compression. This is the first long text content that needs compression."
97
+ },
98
+ {
99
+ "role": "assistant",
100
+ "content": "I understand you want me to help with compression."
101
+ }
102
+ ]
103
+
104
+ # Execute compress_messages
105
+ compressed_messages = processor.compress_messages(messages)
106
+
107
+ # Assertions
108
+ self.assertIsInstance(compressed_messages, list)
109
+ self.assertEqual(len(compressed_messages), len(messages))
110
+
111
+ # Find the user message and verify it was processed
112
+ user_message = None
113
+ for msg in compressed_messages:
114
+ if msg.get("role") == "user":
115
+ user_message = msg
116
+ break
117
+
118
+ self.assertIsNotNone(user_message)
119
+ # The original repeated text should be compressed
120
+ original_content = "This is the first long text content that needs compression. This is the first long text content that needs compression."
121
+ self.assertNotEqual(user_message["content"], original_content)
122
+ # The compressed content should be shorter than original
123
+ self.assertLess(len(user_message["content"]), len(original_content))
124
+
125
+ if __name__ == '__main__':
126
+ testPromptCompressor = TestPromptCompressor()
127
+ testPromptCompressor.test_compress_batch_basic()
128
+ testPromptCompressor = TestPromptCompressor()
129
+ testPromptCompressor.test_compress_messages()
130
+
131
+
132
+
133
+
tests/test_context_hook.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Add the project root to Python path
3
+ from pathlib import Path
4
+ import sys
5
+
6
+ from aworld.logs.util import color_log
7
+
8
+
9
+ project_root = Path(__file__).parent.parent
10
+ sys.path.insert(0, str(project_root))
11
+
12
+ from aworld.core.task import Task
13
+ from aworld.core.agent.base import AgentFactory
14
+ from aworld.core.agent.swarm import Swarm
15
+ from aworld.runner import Runners
16
+ from aworld.agents.llm_agent import Agent
17
+ from aworld.config.conf import AgentConfig, ContextRuleConfig, ModelConfig, OptimizationConfig, LlmCompressionConfig
18
+ from aworld.core.context.base import Context
19
+ from aworld.core.event.base import Message
20
+ from aworld.runners.hook.hooks import PreLLMCallHook, PostLLMCallHook
21
+ from aworld.runners.hook.hook_factory import HookFactory
22
+ from aworld.utils.common import convert_to_snake
23
+ from tests.base_test import BaseTest
24
+
25
+ @HookFactory.register(name="CheckContextPreLLMHook", desc="Test pre-LLM hook")
26
+ class CheckContextPreLLMHook(PreLLMCallHook):
27
+ def name(self):
28
+ return convert_to_snake("CheckContextPreLLMHook")
29
+ async def exec(self, message: Message, context: Context = None) -> Message:
30
+ assert context.state.get("task") == "What is an agent."
31
+ color_log("CheckContextPreLLMHook test ok", color="green")
32
+ return message
tests/test_context_management.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ # Add the project root to Python path
7
+ project_root = Path(__file__).parent.parent
8
+ sys.path.insert(0, str(project_root))
9
+
10
+ from aworld.core.task import Task
11
+ from aworld.core.agent.swarm import Swarm
12
+ from aworld.runner import Runners
13
+ from aworld.agents.llm_agent import Agent
14
+ from aworld.config.conf import AgentConfig, ContextRuleConfig, ModelConfig, OptimizationConfig, LlmCompressionConfig
15
+ from aworld.core.context.base import Context
16
+ from aworld.runners.hook.hook_factory import HookFactory
17
+ from tests.base_test import BaseTest
18
+
19
+
20
+ class TestContextManagement(BaseTest):
21
+ """Test cases for Context Management system based on README examples"""
22
+
23
+ def __init__(self):
24
+ """Set up test fixtures"""
25
+ self.mock_model_name = "qwen/qwen3-1.7b"
26
+ self.mock_base_url = "http://localhost:1234/v1"
27
+ self.mock_api_key = "lm-studio"
28
+ os.environ["LLM_API_KEY"] = self.mock_api_key
29
+ os.environ["LLM_BASE_URL"] = self.mock_base_url
30
+ os.environ["LLM_MODEL_NAME"] = self.mock_model_name
31
+
32
+ def init_agent(self, config_type: str = "1", context_rule: ContextRuleConfig = None):
33
+ if config_type == "1":
34
+ conf = AgentConfig(
35
+ llm_model_name=self.mock_model_name,
36
+ llm_base_url=self.mock_base_url,
37
+ llm_api_key=self.mock_api_key
38
+ )
39
+ else:
40
+ conf = AgentConfig(
41
+ llm_config=ModelConfig(
42
+ llm_model_name=self.mock_model_name,
43
+ llm_base_url=self.mock_base_url,
44
+ llm_api_key=self.mock_api_key
45
+ )
46
+ )
47
+ return Agent(
48
+ conf=conf,
49
+ name="my_agent" + str(random.randint(0, 1000000)),
50
+ system_prompt="You are a helpful assistant.",
51
+ agent_prompt="You are a helpful assistant.",
52
+ context_rule=context_rule
53
+ )
54
+
55
+ class _AssertRaisesContext:
56
+ """Context manager for assertRaises"""
57
+
58
+ def __init__(self, expected_exception):
59
+ self.expected_exception = expected_exception
60
+
61
+ def __enter__(self):
62
+ return self
63
+
64
+ def __exit__(self, exc_type, exc_value, traceback):
65
+ if exc_type is None:
66
+ raise AssertionError(f"Expected {self.expected_exception.__name__} to be raised, but no exception was raised")
67
+ if not issubclass(exc_type, self.expected_exception):
68
+ raise AssertionError(f"Expected {self.expected_exception.__name__} to be raised, but got {exc_type.__name__}: {exc_value}")
69
+ return True # Suppress the exception
70
+
71
+ def fail(self, msg=None):
72
+ """Fail immediately with the given message"""
73
+ raise AssertionError(msg or "Test failed")
74
+
75
+ def run_agent(self, input, agent: Agent):
76
+ swarm = Swarm(agent, max_steps=1)
77
+ return Runners.sync_run(
78
+ input=input,
79
+ swarm=swarm
80
+ )
81
+
82
+ def run_multi_agent(self, input, agent1: Agent, agent2: Agent):
83
+ swarm = Swarm(agent1, agent2, max_steps=1)
84
+ return Runners.sync_run(
85
+ input=input,
86
+ swarm=swarm
87
+ )
88
+
89
+ def run_task(self, context: Context, agent: Agent):
90
+ swarm = Swarm(agent, max_steps=1)
91
+ task = Task(input="""What is an agent.""", swarm=swarm, context=context)
92
+ return Runners.sync_run_task(task)
93
+
94
+ def test_default_context_configuration(self):
95
+
96
+ # No need to explicitly configure context_rule, system automatically uses default configuration
97
+ # Default configuration is equivalent to:
98
+ # context_rule=ContextRuleConfig(
99
+ # optimization_config=OptimizationConfig(
100
+ # enabled=True,
101
+ # max_token_budget_ratio=1.0 # Use 100% of context window
102
+ # ),
103
+ # llm_compression_config=LlmCompressionConfig(
104
+ # enabled=False # Compression disabled by default
105
+ # )
106
+ # )
107
+ mock_agent = self.init_agent("1")
108
+ response = self.run_agent(input="""What is an agent. describe within 20 words""", agent=mock_agent)
109
+
110
+ self.assertIsNotNone(response.answer)
111
+ self.assertEqual(mock_agent.agent_context.model_config.llm_model_name, self.mock_model_name)
112
+
113
+ # Test default context rule behavior
114
+ self.assertIsNotNone(mock_agent.agent_context.context_rule)
115
+ self.assertIsNotNone(mock_agent.agent_context.context_rule.optimization_config)
116
+
117
+ def test_custom_context_configuration(self):
118
+ """Test custom context configuration (README Configuration example)"""
119
+ # Create custom context rules
120
+ mock_agent = self.init_agent(context_rule=ContextRuleConfig(
121
+ optimization_config=OptimizationConfig(
122
+ enabled=True,
123
+ max_token_budget_ratio=0.00015
124
+ ),
125
+ llm_compression_config=LlmCompressionConfig(
126
+ enabled=True,
127
+ trigger_compress_token_length=100,
128
+ compress_model=ModelConfig(
129
+ llm_model_name=self.mock_model_name,
130
+ llm_base_url=self.mock_base_url,
131
+ llm_api_key=self.mock_api_key,
132
+ )
133
+ )
134
+ ))
135
+
136
+ response = self.run_agent(input="""describe What is an agent in details""", agent=mock_agent)
137
+ self.assertIsNotNone(response.answer)
138
+
139
+ # Test configuration values
140
+ self.assertTrue(mock_agent.agent_context.context_rule.optimization_config.enabled)
141
+ self.assertTrue(mock_agent.agent_context.context_rule.llm_compression_config.enabled)
142
+
143
+ def test_state_management_and_recovery(self):
144
+ class StateModifyAgent(Agent):
145
+ async def async_policy(self, observation, info=None, **kwargs):
146
+ result = await super().async_policy(observation, info, **kwargs)
147
+ self.context.state['policy_executed'] = True
148
+ return result
149
+
150
+ class StateTrackingAgent(Agent):
151
+ async def async_policy(self, observation, info=None, **kwargs):
152
+ result = await super().async_policy(observation, info, **kwargs)
153
+ assert self.context.state['policy_executed'] == True
154
+ return result
155
+
156
+ # Create custom agent instance
157
+ custom_agent = StateModifyAgent(
158
+ conf=AgentConfig(
159
+ llm_model_name=self.mock_model_name,
160
+ llm_base_url=self.mock_base_url,
161
+ llm_api_key=self.mock_api_key
162
+ ),
163
+ name="state_modify_agent",
164
+ system_prompt="You are a Python expert who provides detailed and practical answers.",
165
+ agent_prompt="You are a Python expert who provides detailed and practical answers.",
166
+ )
167
+
168
+ # Create a second agent for multi-agent testing
169
+ second_agent = StateTrackingAgent(
170
+ conf=AgentConfig(
171
+ llm_model_name=self.mock_model_name,
172
+ llm_base_url=self.mock_base_url,
173
+ llm_api_key=self.mock_api_key
174
+ ),
175
+ name="state_tracking_agent",
176
+ system_prompt="You are a helpful assistant.",
177
+ agent_prompt="You are a helpful assistant.",
178
+ )
179
+
180
+ response = self.run_multi_agent(
181
+ input="What is an agent. describe within 20 words",
182
+ agent1=custom_agent,
183
+ agent2=second_agent
184
+ )
185
+ self.assertIsNotNone(response.answer)
186
+
187
+ # Verify state changes after execution
188
+ self.assertTrue(custom_agent.context.state.get('policy_executed', True))
189
+ self.assertTrue(second_agent.agent_context.state.get('policy_executed', True))
190
+
191
+
192
+ class TestHookSystem(TestContextManagement):
193
+ def __init__(self):
194
+ super().__init__()
195
+
196
+ def test_hook_registration(self):
197
+ from tests.test_llm_hook import TestPreLLMHook, TestPostLLMHook
198
+ """Test hook registration and retrieval"""
199
+ # Test that hooks are registered in _cls attribute
200
+ self.assertIn("TestPreLLMHook", HookFactory._cls)
201
+ self.assertIn("TestPostLLMHook", HookFactory._cls)
202
+
203
+ # Test hook creation using __call__ method
204
+ pre_hook = HookFactory("TestPreLLMHook")
205
+ post_hook = HookFactory("TestPostLLMHook")
206
+
207
+ self.assertIsInstance(pre_hook, TestPreLLMHook)
208
+ self.assertIsInstance(post_hook, TestPostLLMHook)
209
+
210
+ def test_hook_execution(self):
211
+ from tests.test_llm_hook import TestPreLLMHook, TestPostLLMHook
212
+
213
+ mock_agent = self.init_agent("1")
214
+ response = self.run_agent(input="""What is an agent. describe within 20 words""", agent=mock_agent)
215
+ self.assertIsNotNone(response.answer)
216
+
217
+ def test_task_context_transfer(self):
218
+ from tests.test_context_hook import CheckContextPreLLMHook
219
+
220
+ mock_agent = self.init_agent("1")
221
+ context = Context.instance()
222
+ context.state.update({"task": "What is an agent."})
223
+ self.run_task(context=context, agent=mock_agent)
224
+
225
+
226
+ if __name__ == '__main__':
227
+ testContextManagement = TestContextManagement()
228
+ testContextManagement.test_default_context_configuration()
229
+ testContextManagement.test_custom_context_configuration()
230
+ testContextManagement.test_state_management_and_recovery()
231
+ testHookSystem = TestHookSystem()
232
+ testHookSystem.test_hook_registration()
233
+ testHookSystem = TestHookSystem()
234
+ testHookSystem.test_hook_execution()
235
+ testHookSystem = TestHookSystem()
236
+ testHookSystem.test_task_context_transfer()
tests/test_demo.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+
4
+
5
+ class TestShellTool(unittest.TestCase):
6
+
7
+
8
+ def test_init(self):
9
+ """Test initialization"""
10
+ print("Test initialization")
11
+
tests/test_llm_hook.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Add the project root to Python path
3
+ from pathlib import Path
4
+ import sys
5
+
6
+ project_root = Path(__file__).parent.parent
7
+ sys.path.insert(0, str(project_root))
8
+
9
+ from aworld.core.task import Task
10
+ from aworld.core.agent.base import AgentFactory
11
+ from aworld.core.agent.swarm import Swarm
12
+ from aworld.runner import Runners
13
+ from aworld.agents.llm_agent import Agent
14
+ from aworld.config.conf import AgentConfig, ContextRuleConfig, ModelConfig, OptimizationConfig, LlmCompressionConfig
15
+ from aworld.core.context.base import Context
16
+ from aworld.core.event.base import Message
17
+ from aworld.runners.hook.hooks import PreLLMCallHook, PostLLMCallHook
18
+ from aworld.runners.hook.hook_factory import HookFactory
19
+ from aworld.utils.common import convert_to_snake
20
+ from tests.base_test import BaseTest
21
+
22
+
23
+ @HookFactory.register(name="TestPreLLMHook", desc="Test pre-LLM hook")
24
+ class TestPreLLMHook(PreLLMCallHook):
25
+ def name(self):
26
+ return convert_to_snake("TestPreLLMHook")
27
+ async def exec(self, message: Message, context: Context = None) -> Message:
28
+ agent = AgentFactory.agent_instance(message.sender)
29
+ agent_context = agent.agent_context
30
+ if agent_context is not None:
31
+ agent_context.step = 1
32
+ assert agent_context.step == 1 or agent_context.step == 2
33
+ return message
34
+
35
+ @HookFactory.register(name="TestPostLLMHook", desc="Test post-LLM hook")
36
+ class TestPostLLMHook(PostLLMCallHook):
37
+ def name(self):
38
+ return convert_to_snake("TestPostLLMHook")
39
+ async def exec(self, message: Message, context: Context = None) -> Message:
40
+ agent = AgentFactory.agent_instance(message.sender)
41
+ agent_context = agent.agent_context
42
+ if agent_context is not None and agent_context.llm_output is not None:
43
+ # Test dynamic prompt adjustment based on LLM output
44
+ if hasattr(agent_context.llm_output, 'content'):
45
+ content = agent_context.llm_output.content.lower()
46
+ if content is not None:
47
+ agent_context.agent_prompt = "Success mode activated"
48
+ assert agent_context.agent_prompt == "Success mode activated"
49
+ return message