Duibonduil commited on
Commit
2814685
·
verified ·
1 Parent(s): 4416f87

Upload 5 files

Browse files
aworld/runners/hook/agent_hooks.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import abc
4
+ from typing import AsyncGenerator
5
+ from aworld.core.context.base import Context, AgentContext
6
+ from aworld.core.event.base import Message
7
+ from aworld.runners.hook.hook_factory import HookFactory
8
+ from aworld.runners.hook.hooks import PostLLMCallHook, PreLLMCallHook
9
+ from aworld.utils.common import convert_to_snake
10
+
11
+ @HookFactory.register(name="PreLLMCallContextProcessHook",
12
+ desc="PreLLMCallContextProcessHook")
13
+ class PreLLMCallContextProcessHook(PreLLMCallHook):
14
+ """Process in the hook point of the pre_llm_call."""
15
+ __metaclass__ = abc.ABCMeta
16
+
17
+ def name(self):
18
+ return convert_to_snake("PreLLMCallContextProcessHook")
19
+
20
+ async def exec(self, message: Message, context: Context = None) -> Message:
21
+ ''' context.get_agent_context(message.sender) ''' # get agent context
22
+ # and do something
23
+
24
+ @HookFactory.register(name="PostLLMCallContextProcessHook",
25
+ desc="PostLLMCallContextProcessHook")
26
+ class PostLLMCallContextProcessHook(PostLLMCallHook):
27
+ """Process in the hook point of the post_llm_call."""
28
+ __metaclass__ = abc.ABCMeta
29
+
30
+ def name(self):
31
+ return convert_to_snake("PostLLMCallContextProcessHook")
32
+
33
+ async def exec(self, message: Message, context: Context = None) -> Message:
34
+ '''context.get_agent_context(message.sender)''' # get agent context
35
+
aworld/runners/hook/hook_factory.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import sys
4
+ from typing import Dict, List
5
+
6
+ from aworld.core.factory import Factory
7
+ from aworld.logs.util import logger
8
+ from aworld.runners.hook.hooks import Hook, StartHook, HookPoint
9
+
10
+
11
+ class HookManager(Factory):
12
+ def __init__(self, type_name: str = None):
13
+ super(HookManager, self).__init__(type_name)
14
+
15
+ def __call__(self, name: str, **kwargs):
16
+ if name is None:
17
+ raise ValueError("hook name is None")
18
+
19
+ try:
20
+ if name in self._cls:
21
+ act = self._cls[name](**kwargs)
22
+ else:
23
+ raise RuntimeError("The hook was not registered.\nPlease confirm the package has been imported.")
24
+ except Exception:
25
+ err = sys.exc_info()
26
+ logger.warning(f"Failed to create hook with name {name}:\n{err[1]}")
27
+ act = None
28
+ return act
29
+
30
+ def hooks(self, name: str = None) -> Dict[str, List[Hook]]:
31
+ vals = list(filter(lambda s: not s.startswith('__'), dir(HookPoint)))
32
+ results = {val.lower(): [] for val in vals}
33
+
34
+ for k, v in self._cls.items():
35
+ hook = v()
36
+ if name and hook.point() != name:
37
+ continue
38
+
39
+ results.get(hook.point(), []).append(hook)
40
+
41
+ return results
42
+
43
+
44
+ HookFactory = HookManager("hook_type")
aworld/runners/hook/hooks.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import abc
4
+ from typing import AsyncGenerator
5
+ from aworld.core.context.base import Context, AgentContext
6
+ from aworld.core.event.base import Message
7
+
8
+
9
+ class HookPoint:
10
+ START = "start"
11
+ FINISHED = "finished"
12
+ ERROR = "error"
13
+ PRE_LLM_CALL = "pre_llm_call"
14
+ POST_LLM_CALL = "post_llm_call"
15
+
16
+ class Hook:
17
+ """Runner hook."""
18
+ __metaclass__ = abc.ABCMeta
19
+
20
+ @abc.abstractmethod
21
+ def point(self):
22
+ """Hook point."""
23
+
24
+ @abc.abstractmethod
25
+ async def exec(self, message: Message, context: Context = None) -> Message:
26
+ """Execute hook function."""
27
+
28
+
29
+ class StartHook(Hook):
30
+ """Process in the hook point of the start."""
31
+ __metaclass__ = abc.ABCMeta
32
+
33
+ def point(self):
34
+ return HookPoint.START
35
+
36
+
37
+ class FinishedHook(Hook):
38
+ """Process in the hook point of the finished."""
39
+ __metaclass__ = abc.ABCMeta
40
+
41
+ def point(self):
42
+ return HookPoint.FINISHED
43
+
44
+
45
+ class ErrorHook(Hook):
46
+ """Process in the hook point of the error."""
47
+ __metaclass__ = abc.ABCMeta
48
+
49
+ def point(self):
50
+ return HookPoint.ERROR
51
+
52
+ class PreLLMCallHook(Hook):
53
+ """Process in the hook point of the pre_llm_call."""
54
+ __metaclass__ = abc.ABCMeta
55
+
56
+ def point(self):
57
+ return HookPoint.PRE_LLM_CALL
58
+
59
+ class PostLLMCallHook(Hook):
60
+ """Process in the hook point of the post_llm_call."""
61
+ __metaclass__ = abc.ABCMeta
62
+
63
+ def point(self):
64
+ return HookPoint.POST_LLM_CALL
aworld/runners/hook/template.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+
4
+ HOOK_TEMPLATE = """
5
+ import traceback
6
+
7
+ from aworld.core.context.base import Context
8
+
9
+ from aworld.core.event.base import Message, Constants, TopicType
10
+ from aworld.runners.hook.hooks import *
11
+ from aworld.runners.hook.hook_factory import HookFactory
12
+ from aworld.logs.util import logger
13
+
14
+ from aworld.utils.common import convert_to_snake
15
+
16
+
17
+ @HookFactory.register(name="{name}",
18
+ desc="{desc}")
19
+ class {name}({point}Hook):
20
+ def name(self):
21
+ return convert_to_snake("{name}")
22
+
23
+ async def exec(self, message: Message) -> Message:
24
+ {func_import}import {func}
25
+ try:
26
+ res = {func}(message)
27
+ if not res:
28
+ raise ValueError(f"{func} no result return.")
29
+ return Message(payload=res,
30
+ session_id=Context.instance().session_id,
31
+ sender="{name}",
32
+ category=Constants.TASK,
33
+ topic="{topic}")
34
+ except Exception as e:
35
+ logger.error(traceback.format_exc())
36
+ return Message(payload=str(e),
37
+ session_id=Context.instance().session_id,
38
+ sender="{name}",
39
+ category=Constants.TASK,
40
+ topic=TopicType.ERROR)
41
+ """
aworld/runners/hook/utils.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+
4
+ import importlib
5
+ import inspect
6
+ import os
7
+ from typing import Callable, Any
8
+
9
+ from aworld.runners.hook.template import HOOK_TEMPLATE
10
+ from aworld.utils.common import snake_to_camel
11
+
12
+
13
+ def hook(hook_point: str, name: str = None):
14
+ """Hook decorator.
15
+
16
+ NOTE: Hooks can be annotated, but they need to comply with the protocol agreement.
17
+ The input parameter of the hook function is `Message` type, and the @hook needs to specify `hook_point`.
18
+
19
+ Examples:
20
+ >>> @hook(hook_point=HookPoint.ERROR)
21
+ >>> def error_process(message: Message) -> Message | None:
22
+ >>> print("process error")
23
+ The function `error_process` will be executed when an error message appears in the task,
24
+ you can choose return nothing or return a message.
25
+
26
+ Args:
27
+ hook_point: Hook point that wants to process the message.
28
+ name: Hook name.
29
+ """
30
+
31
+ def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
32
+ # converts python function into a hoop with associated hoop point
33
+ func_import = func.__module__
34
+ if func_import == '__main__':
35
+ path = inspect.getsourcefile(func)
36
+ package = path.replace(os.getcwd(), '').replace('.py', '')
37
+ if package[0] == '/':
38
+ package = package[1:]
39
+ func_import = f"from {package} "
40
+ else:
41
+ func_import = f"from {func_import} "
42
+
43
+ real_name = name if name else func.__name__
44
+ con = HOOK_TEMPLATE.format(func_import=func_import,
45
+ func=func.__name__,
46
+ point=snake_to_camel(hook_point),
47
+ name=real_name,
48
+ topic=hook_point,
49
+ desc='')
50
+ with open(f"{real_name}.py", 'w+') as write:
51
+ write.writelines(con)
52
+ importlib.import_module(real_name)
53
+ return func
54
+
55
+ return decorator