Duibonduil commited on
Commit
b05d426
·
verified ·
1 Parent(s): a5d4031

Upload 2 files

Browse files
Files changed (2) hide show
  1. aworld/events/manager.py +115 -0
  2. aworld/events/util.py +56 -0
aworld/events/manager.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ from typing import Dict, Any, List, Callable
4
+
5
+ from aworld.core.context.base import Context
6
+ from aworld.core.event import eventbus
7
+ from aworld.core.event.base import Constants, Message
8
+
9
+
10
+ class EventManager:
11
+ """The event manager is now used to build an event bus instance and store the messages recently."""
12
+
13
+ def __init__(self, context: Context, **kwargs):
14
+ # use conf to build event bus instance
15
+ self.event_bus = eventbus
16
+ self.context = context
17
+ # Record events in memory for re-consume.
18
+ self.messages: Dict[str, List[Message]] = {'None': []}
19
+ self.max_len = kwargs.get('max_len', 1000)
20
+
21
+ async def emit(
22
+ self,
23
+ data: Any,
24
+ sender: str,
25
+ receiver: str = None,
26
+ topic: str = None,
27
+ session_id: str = None,
28
+ event_type: str = Constants.TASK
29
+ ):
30
+ """Send data to the event bus.
31
+
32
+ Args:
33
+ data: Message payload.
34
+ sender: The sender name of the message.
35
+ receiver: The receiver name of the message.
36
+ topic: The topic to which the message belongs.
37
+ session_id: Special session id.
38
+ event_type: Event type.
39
+ """
40
+ event = Message(
41
+ payload=data,
42
+ session_id=session_id if session_id else self.context.session_id,
43
+ sender=sender,
44
+ receiver=receiver,
45
+ topic=topic,
46
+ category=event_type,
47
+ )
48
+ return await self.emit_message(event)
49
+
50
+ async def emit_message(self, event: Message):
51
+ """Send the message to the event bus."""
52
+ key = event.key()
53
+ if key not in self.messages:
54
+ self.messages[key] = []
55
+ self.messages[key].append(event)
56
+ if len(self.messages) > self.max_len:
57
+ self.messages = self.messages[-self.max_len:]
58
+
59
+ await self.event_bus.publish(event)
60
+ return True
61
+
62
+ async def consume(self, nowait: bool = False):
63
+ msg = Message(session_id=self.context.session_id, sender="", category="", payload="")
64
+ msg.context = self.context
65
+ if nowait:
66
+ return await self.event_bus.consume_nowait(msg)
67
+ return await self.event_bus.consume(msg)
68
+
69
+ async def done(self):
70
+ await self.event_bus.done(self.context.task_id)
71
+
72
+ async def register(self, event_type: str, topic: str, handler: Callable[..., Any], **kwargs):
73
+ await self.event_bus.subscribe(event_type, topic, handler, **kwargs)
74
+
75
+ async def unregister(self, event_type: str, topic: str, handler: Callable[..., Any], **kwargs):
76
+ await self.event_bus.unsubscribe(event_type, topic, handler, **kwargs)
77
+
78
+ async def register_transformer(self, event_type: str, topic: str, handler: Callable[..., Any], **kwargs):
79
+ await self.event_bus.subscribe(event_type, topic, handler, transformer=True, **kwargs)
80
+
81
+ async def unregister_transformer(self, event_type: str, topic: str, handler: Callable[..., Any], **kwargs):
82
+ await self.event_bus.unsubscribe(event_type, topic, handler, transformer=True, **kwargs)
83
+
84
+ def messages_by_key(self, key: str) -> List[Message]:
85
+ return self.messages.get(key, [])
86
+
87
+ def messages_by_sender(self, sender: str, key: str):
88
+ results = []
89
+ for res in self.messages.get(key, []):
90
+ if res.sender == sender:
91
+ results.append(res)
92
+ return results
93
+
94
+ def messages_by_topic(self, topic: str, key: str):
95
+ results = []
96
+ for res in self.messages.get(key, []):
97
+ if res.topic == topic:
98
+ results.append(res)
99
+ return results
100
+
101
+ def session_messages(self, session_id: str) -> List[Message]:
102
+ return [m for k, msg in self.messages.items() for m in msg if m.session_id == session_id]
103
+
104
+ @staticmethod
105
+ def mark_valid(messages: List[Message]):
106
+ for msg in messages:
107
+ msg.is_valid = True
108
+
109
+ @staticmethod
110
+ def mark_invalid(messages: List[Message]):
111
+ for msg in messages:
112
+ msg.is_valid = False
113
+
114
+ def clear_messages(self):
115
+ self.messages = []
aworld/events/util.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import asyncio
4
+ from typing import Callable, Any
5
+
6
+ from aworld.core.context.base import Context
7
+ from aworld.core.event import eventbus
8
+ from aworld.core.event.base import Message, Constants
9
+ from aworld.events.manager import EventManager
10
+ from aworld.utils.common import sync_exec
11
+
12
+
13
+ def subscribe(key: str, category: str = None):
14
+ """Subscribe the special event to handle.
15
+
16
+ Examples:
17
+ >>> cate = Constants.TOOL or Constants.AGENT; key = "topic"
18
+ >>> @subscribe(category=cate, key=key)
19
+ >>> def example(message: Message) -> Message | None:
20
+ >>> print("do something")
21
+
22
+ Args:
23
+ key: The index key of the handler.
24
+ category: Types of subscription events, the value is `agent` or `tool`.
25
+ """
26
+ def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
27
+ if category is None:
28
+ sync_exec(eventbus.subscribe, Constants.TOOL, key, func)
29
+ sync_exec(eventbus.subscribe, Constants.AGENT, key, func)
30
+ else:
31
+ sync_exec(eventbus.subscribe, category, key, func)
32
+ return func
33
+
34
+ return decorator
35
+
36
+ async def _send_message(msg: Message) -> str:
37
+ context = msg.context
38
+ if not context:
39
+ context = Context()
40
+
41
+ event_mng = context.event_manager
42
+ if not event_mng:
43
+ event_mng = EventManager(context)
44
+
45
+ await event_mng.emit_message(msg)
46
+ return msg.id
47
+
48
+
49
+ async def send_message(msg: Message) -> asyncio.Task:
50
+ """Utility function of send event.
51
+
52
+ Args:
53
+ msg: The content and meta information to be sent.
54
+ """
55
+ task = asyncio.create_task(_send_message(msg), name=msg.id)
56
+ return task