Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- aworld/runners/hook/agent_hooks.py +35 -0
- aworld/runners/hook/hook_factory.py +44 -0
- aworld/runners/hook/hooks.py +64 -0
- aworld/runners/hook/template.py +41 -0
- aworld/runners/hook/utils.py +55 -0
aworld/runners/hook/agent_hooks.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
import abc
|
4 |
+
from typing import AsyncGenerator
|
5 |
+
from aworld.core.context.base import Context, AgentContext
|
6 |
+
from aworld.core.event.base import Message
|
7 |
+
from aworld.runners.hook.hook_factory import HookFactory
|
8 |
+
from aworld.runners.hook.hooks import PostLLMCallHook, PreLLMCallHook
|
9 |
+
from aworld.utils.common import convert_to_snake
|
10 |
+
|
11 |
+
@HookFactory.register(name="PreLLMCallContextProcessHook",
|
12 |
+
desc="PreLLMCallContextProcessHook")
|
13 |
+
class PreLLMCallContextProcessHook(PreLLMCallHook):
|
14 |
+
"""Process in the hook point of the pre_llm_call."""
|
15 |
+
__metaclass__ = abc.ABCMeta
|
16 |
+
|
17 |
+
def name(self):
|
18 |
+
return convert_to_snake("PreLLMCallContextProcessHook")
|
19 |
+
|
20 |
+
async def exec(self, message: Message, context: Context = None) -> Message:
|
21 |
+
''' context.get_agent_context(message.sender) ''' # get agent context
|
22 |
+
# and do something
|
23 |
+
|
24 |
+
@HookFactory.register(name="PostLLMCallContextProcessHook",
|
25 |
+
desc="PostLLMCallContextProcessHook")
|
26 |
+
class PostLLMCallContextProcessHook(PostLLMCallHook):
|
27 |
+
"""Process in the hook point of the post_llm_call."""
|
28 |
+
__metaclass__ = abc.ABCMeta
|
29 |
+
|
30 |
+
def name(self):
|
31 |
+
return convert_to_snake("PostLLMCallContextProcessHook")
|
32 |
+
|
33 |
+
async def exec(self, message: Message, context: Context = None) -> Message:
|
34 |
+
'''context.get_agent_context(message.sender)''' # get agent context
|
35 |
+
|
aworld/runners/hook/hook_factory.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
import sys
|
4 |
+
from typing import Dict, List
|
5 |
+
|
6 |
+
from aworld.core.factory import Factory
|
7 |
+
from aworld.logs.util import logger
|
8 |
+
from aworld.runners.hook.hooks import Hook, StartHook, HookPoint
|
9 |
+
|
10 |
+
|
11 |
+
class HookManager(Factory):
|
12 |
+
def __init__(self, type_name: str = None):
|
13 |
+
super(HookManager, self).__init__(type_name)
|
14 |
+
|
15 |
+
def __call__(self, name: str, **kwargs):
|
16 |
+
if name is None:
|
17 |
+
raise ValueError("hook name is None")
|
18 |
+
|
19 |
+
try:
|
20 |
+
if name in self._cls:
|
21 |
+
act = self._cls[name](**kwargs)
|
22 |
+
else:
|
23 |
+
raise RuntimeError("The hook was not registered.\nPlease confirm the package has been imported.")
|
24 |
+
except Exception:
|
25 |
+
err = sys.exc_info()
|
26 |
+
logger.warning(f"Failed to create hook with name {name}:\n{err[1]}")
|
27 |
+
act = None
|
28 |
+
return act
|
29 |
+
|
30 |
+
def hooks(self, name: str = None) -> Dict[str, List[Hook]]:
|
31 |
+
vals = list(filter(lambda s: not s.startswith('__'), dir(HookPoint)))
|
32 |
+
results = {val.lower(): [] for val in vals}
|
33 |
+
|
34 |
+
for k, v in self._cls.items():
|
35 |
+
hook = v()
|
36 |
+
if name and hook.point() != name:
|
37 |
+
continue
|
38 |
+
|
39 |
+
results.get(hook.point(), []).append(hook)
|
40 |
+
|
41 |
+
return results
|
42 |
+
|
43 |
+
|
44 |
+
HookFactory = HookManager("hook_type")
|
aworld/runners/hook/hooks.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
import abc
|
4 |
+
from typing import AsyncGenerator
|
5 |
+
from aworld.core.context.base import Context, AgentContext
|
6 |
+
from aworld.core.event.base import Message
|
7 |
+
|
8 |
+
|
9 |
+
class HookPoint:
|
10 |
+
START = "start"
|
11 |
+
FINISHED = "finished"
|
12 |
+
ERROR = "error"
|
13 |
+
PRE_LLM_CALL = "pre_llm_call"
|
14 |
+
POST_LLM_CALL = "post_llm_call"
|
15 |
+
|
16 |
+
class Hook:
|
17 |
+
"""Runner hook."""
|
18 |
+
__metaclass__ = abc.ABCMeta
|
19 |
+
|
20 |
+
@abc.abstractmethod
|
21 |
+
def point(self):
|
22 |
+
"""Hook point."""
|
23 |
+
|
24 |
+
@abc.abstractmethod
|
25 |
+
async def exec(self, message: Message, context: Context = None) -> Message:
|
26 |
+
"""Execute hook function."""
|
27 |
+
|
28 |
+
|
29 |
+
class StartHook(Hook):
|
30 |
+
"""Process in the hook point of the start."""
|
31 |
+
__metaclass__ = abc.ABCMeta
|
32 |
+
|
33 |
+
def point(self):
|
34 |
+
return HookPoint.START
|
35 |
+
|
36 |
+
|
37 |
+
class FinishedHook(Hook):
|
38 |
+
"""Process in the hook point of the finished."""
|
39 |
+
__metaclass__ = abc.ABCMeta
|
40 |
+
|
41 |
+
def point(self):
|
42 |
+
return HookPoint.FINISHED
|
43 |
+
|
44 |
+
|
45 |
+
class ErrorHook(Hook):
|
46 |
+
"""Process in the hook point of the error."""
|
47 |
+
__metaclass__ = abc.ABCMeta
|
48 |
+
|
49 |
+
def point(self):
|
50 |
+
return HookPoint.ERROR
|
51 |
+
|
52 |
+
class PreLLMCallHook(Hook):
|
53 |
+
"""Process in the hook point of the pre_llm_call."""
|
54 |
+
__metaclass__ = abc.ABCMeta
|
55 |
+
|
56 |
+
def point(self):
|
57 |
+
return HookPoint.PRE_LLM_CALL
|
58 |
+
|
59 |
+
class PostLLMCallHook(Hook):
|
60 |
+
"""Process in the hook point of the post_llm_call."""
|
61 |
+
__metaclass__ = abc.ABCMeta
|
62 |
+
|
63 |
+
def point(self):
|
64 |
+
return HookPoint.POST_LLM_CALL
|
aworld/runners/hook/template.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
|
4 |
+
HOOK_TEMPLATE = """
|
5 |
+
import traceback
|
6 |
+
|
7 |
+
from aworld.core.context.base import Context
|
8 |
+
|
9 |
+
from aworld.core.event.base import Message, Constants, TopicType
|
10 |
+
from aworld.runners.hook.hooks import *
|
11 |
+
from aworld.runners.hook.hook_factory import HookFactory
|
12 |
+
from aworld.logs.util import logger
|
13 |
+
|
14 |
+
from aworld.utils.common import convert_to_snake
|
15 |
+
|
16 |
+
|
17 |
+
@HookFactory.register(name="{name}",
|
18 |
+
desc="{desc}")
|
19 |
+
class {name}({point}Hook):
|
20 |
+
def name(self):
|
21 |
+
return convert_to_snake("{name}")
|
22 |
+
|
23 |
+
async def exec(self, message: Message) -> Message:
|
24 |
+
{func_import}import {func}
|
25 |
+
try:
|
26 |
+
res = {func}(message)
|
27 |
+
if not res:
|
28 |
+
raise ValueError(f"{func} no result return.")
|
29 |
+
return Message(payload=res,
|
30 |
+
session_id=Context.instance().session_id,
|
31 |
+
sender="{name}",
|
32 |
+
category=Constants.TASK,
|
33 |
+
topic="{topic}")
|
34 |
+
except Exception as e:
|
35 |
+
logger.error(traceback.format_exc())
|
36 |
+
return Message(payload=str(e),
|
37 |
+
session_id=Context.instance().session_id,
|
38 |
+
sender="{name}",
|
39 |
+
category=Constants.TASK,
|
40 |
+
topic=TopicType.ERROR)
|
41 |
+
"""
|
aworld/runners/hook/utils.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
|
4 |
+
import importlib
|
5 |
+
import inspect
|
6 |
+
import os
|
7 |
+
from typing import Callable, Any
|
8 |
+
|
9 |
+
from aworld.runners.hook.template import HOOK_TEMPLATE
|
10 |
+
from aworld.utils.common import snake_to_camel
|
11 |
+
|
12 |
+
|
13 |
+
def hook(hook_point: str, name: str = None):
|
14 |
+
"""Hook decorator.
|
15 |
+
|
16 |
+
NOTE: Hooks can be annotated, but they need to comply with the protocol agreement.
|
17 |
+
The input parameter of the hook function is `Message` type, and the @hook needs to specify `hook_point`.
|
18 |
+
|
19 |
+
Examples:
|
20 |
+
>>> @hook(hook_point=HookPoint.ERROR)
|
21 |
+
>>> def error_process(message: Message) -> Message | None:
|
22 |
+
>>> print("process error")
|
23 |
+
The function `error_process` will be executed when an error message appears in the task,
|
24 |
+
you can choose return nothing or return a message.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
hook_point: Hook point that wants to process the message.
|
28 |
+
name: Hook name.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
32 |
+
# converts python function into a hoop with associated hoop point
|
33 |
+
func_import = func.__module__
|
34 |
+
if func_import == '__main__':
|
35 |
+
path = inspect.getsourcefile(func)
|
36 |
+
package = path.replace(os.getcwd(), '').replace('.py', '')
|
37 |
+
if package[0] == '/':
|
38 |
+
package = package[1:]
|
39 |
+
func_import = f"from {package} "
|
40 |
+
else:
|
41 |
+
func_import = f"from {func_import} "
|
42 |
+
|
43 |
+
real_name = name if name else func.__name__
|
44 |
+
con = HOOK_TEMPLATE.format(func_import=func_import,
|
45 |
+
func=func.__name__,
|
46 |
+
point=snake_to_camel(hook_point),
|
47 |
+
name=real_name,
|
48 |
+
topic=hook_point,
|
49 |
+
desc='')
|
50 |
+
with open(f"{real_name}.py", 'w+') as write:
|
51 |
+
write.writelines(con)
|
52 |
+
importlib.import_module(real_name)
|
53 |
+
return func
|
54 |
+
|
55 |
+
return decorator
|