Duibonduil committed on
Commit 7c117ed · verified · 1 Parent(s): b7cf4ad

Upload 5 files

aworld/replay_buffer/README.md ADDED
@@ -0,0 +1,111 @@
1
+ # Replay Buffer
2
+
3
+ A multi-process capable replay buffer system for storing and sampling experience data.
4
+
5
+ ## Features
6
+
7
+ - **Multi-process Support**: Safe concurrent access using shared memory and locks
8
+ - **Flexible Querying**: Powerful query builder for filtering stored data
9
+ - **Task-based Organization**: Data organized by task_id and agent_id
10
+ - **Capacity Management**: FIFO eviction of the oldest tasks when the configured max capacity is reached (see the snippet below)
11
+ - **Custom Sampling**: Implement custom sampling logic through Sampler interface
12
+ - **Data Conversion**: Custom data conversion through Converter interface
13
+
14
+ ## Basic Usage
15
+
16
+ ### Writing Data
17
+
18
+ ```python
19
+ import time
+
+ from aworld.replay_buffer.base import ReplayBuffer, DataRow, ExpMeta, Experience
20
+ from aworld.core.common import ActionModel, Observation
21
+
22
+ # Create a data row
23
+ data = DataRow(
24
+ exp_meta=ExpMeta(
25
+ task_id="task_1",
26
+ task_name="my_task",
27
+ agent_id="agent_1",
28
+ step=1,
29
+ execute_time=time.time(),
+ pre_agent=""
30
+ ),
31
+ exp_data=Experience(
32
+ state=Observation(),
33
+ actions=[ActionModel()]
34
+ )
35
+ )
36
+
37
+ # Store data
38
+ replay_buffer = ReplayBuffer()
+ replay_buffer.store(data)
39
+ ```
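+
+ Multiple rows can be written in a single call with `store_batch`:
+
+ ```python
+ replay_buffer.store_batch([data])
+ ```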
40
+
41
+ ### Reading Data
42
+
43
+ ```python
44
+ from aworld.replay_buffer.base import ReplayBuffer, RandomTaskSample, DefaultConverter
+ from aworld.replay_buffer.query_filter import QueryBuilder
45
+
46
+ # Basic example
47
+ replay_buffer = ReplayBuffer()
48
+ query_condition = QueryBuilder().eq("exp_meta.task_name", "test_task").build()
49
+ data = replay_buffer.sample(sampler=RandomTaskSample(),
50
+ query_condition=query_condition,
51
+ converter=DefaultConverter(),
52
+ batch_size=1000)
53
+
54
+ # Query Task by task_id
55
+ query = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
56
+ data = replay_buffer.sample_task(query_condition=query, batch_size=10)
57
+
58
+ # Query Task by agent_id
59
+ query = QueryBuilder().eq("exp_meta.agent_id", "agent_1").build()
60
+ data = replay_buffer.sample_task(query_condition=query, batch_size=5)
61
+ ```
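+
+ Range operators can be combined as well; for example (illustrative values), sampling tasks whose experiences were recorded within the last hour:
+
+ ```python
+ import time
+
+ query = (QueryBuilder()
+          .eq("exp_meta.agent_id", "agent_1")
+          .and_()
+          .gte("exp_meta.execute_time", time.time() - 3600)
+          .build())
+ recent = replay_buffer.sample_task(query_condition=query, batch_size=10)
+ ```
+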
62
+ ## Multi-processing Example
63
+
64
+ ```python
65
+ import multiprocessing
66
+ from aworld.replay_buffer.base import ReplayBuffer
+ from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage
67
+
68
+ manager = multiprocessing.Manager()
69
+ replay_buffer = ReplayBuffer(
70
+ storage=MultiProcMemoryStorage(
71
+ data_dict=manager.dict(),
72
+ fifo_queue=manager.list(),
73
+ lock=manager.Lock(),
74
+ max_capacity=10000
75
+ )
76
+ )
77
+
78
+ # Start writer processes
79
+ processes = [
80
+ multiprocessing.Process(target=write_processing, args=(replay_buffer, f"task_{i}"))
81
+ for i in range(4)
82
+ ]
83
+ ```
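+
+ The snippet above references a `write_processing` helper that is not part of the library. A minimal sketch of such a writer (task/agent names and the step count are illustrative, and it must be defined before the process list is built), plus starting and joining the workers:
+
+ ```python
+ import time
+
+ from aworld.core.common import ActionModel, Observation
+ from aworld.replay_buffer.base import DataRow, ExpMeta, Experience
+
+ def write_processing(replay_buffer, task_id):
+     # Store a few steps of experience for a single task.
+     for step in range(10):
+         replay_buffer.store(DataRow(
+             exp_meta=ExpMeta(
+                 task_id=task_id,
+                 task_name="demo_task",
+                 agent_id="agent_1",
+                 step=step,
+                 execute_time=time.time(),
+                 pre_agent=""
+             ),
+             exp_data=Experience(state=Observation(), actions=[ActionModel()])
+         ))
+
+ # Start and wait for the writer processes created above.
+ for p in processes:
+     p.start()
+ for p in processes:
+     p.join()
+ ```
+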
84
+ ## Query Builder Examples
85
+
86
+ ### Simple Equality
87
+ ```python
88
+ QueryBuilder().eq("exp_meta.task_id", "123").build()
89
+ ```
90
+
91
+ ### Complex Conditions
92
+ ```python
93
+ condition = (QueryBuilder()
94
+ .eq("exp_meta.task_id", "123")
95
+ .and_()
96
+ .eq("exp_meta.agent_id", "456")
97
+ .build())
98
+ ```
99
+ ### Nested Conditions
100
+ ```python
101
+ condition = (QueryBuilder()
102
+ .eq("exp_meta.task_id", "123")
103
+ .and_()
104
+ .nested(
105
+ QueryBuilder()
106
+ .eq("exp_meta.agent_id", "111")
107
+ .or_()
108
+ .eq("exp_meta.agent_id", "222")
109
+ )
110
+ .build())
111
+ ```
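+
+ ## Custom Sampler and Converter
+
+ Custom sampling and conversion are implemented by subclassing `TaskSampler` (implement `sample_task_ids`) and `Converter` (implement `to_dataset_row`). A minimal sketch (class names are illustrative):
+
+ ```python
+ from typing import Dict, List
+
+ from aworld.replay_buffer.base import Converter, DataRow, TaskSampler
+
+ class FirstNTaskSampler(TaskSampler):
+     """Pick the first `batch_size` distinct task_ids found in storage."""
+
+     def sample_task_ids(self, storage, batch_size, query_condition=None) -> List[str]:
+         task_ids = []
+         for row in storage.get_all(query_condition):
+             if row.exp_meta.task_id not in task_ids:
+                 task_ids.append(row.exp_meta.task_id)
+             if len(task_ids) >= batch_size:
+                 break
+         return task_ids
+
+ class DictConverter(Converter):
+     """Convert each task's rows into plain dicts."""
+
+     def to_dataset_row(self, task_experience: List[DataRow]) -> List[Dict]:
+         return [row.to_dict() for row in task_experience]
+
+ rows = replay_buffer.sample_task(sampler=FirstNTaskSampler(),
+                                  converter=DictConverter(),
+                                  batch_size=10)
+ ```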
aworld/replay_buffer/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
aworld/replay_buffer/base.py ADDED
@@ -0,0 +1,409 @@
1
+ import random
2
+ import uuid
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, List, TypeVar
5
+ from abc import ABC, abstractmethod
6
+ from math import ceil
7
+
8
+ from aworld.core.common import ActionModel, Observation
9
+ from aworld.replay_buffer.query_filter import QueryCondition, QueryFilter
10
+ from aworld.logs.util import logger
11
+
12
+
13
+ T = TypeVar('T')
14
+
15
+
16
+ @dataclass
17
+ class Experience:
18
+ '''
19
+ Experience of agent.
20
+ '''
21
+ state: Observation
22
+ actions: List[ActionModel]
23
+ reward_t: float = None
24
+ adv_t: float = None
25
+ v_t: float = None
26
+ messages: List[Dict] = None
27
+
28
+ def to_dict(self):
29
+ return {
30
+ "state": self.state,
31
+ "actions": self.actions,
32
+ "reward_t": self.reward_t,
33
+ "adv_t": self.adv_t,
34
+ "v_t": self.v_t,
35
+ "messages": self.messages
36
+ }
37
+
38
+
39
+ @dataclass
40
+ class ExpMeta:
41
+ '''
42
+ Experience meta data.
43
+ '''
44
+ task_id: str
45
+ task_name: str
46
+ agent_id: str
47
+ step: int
48
+ execute_time: float
49
+ pre_agent: str
50
+
51
+ def to_dict(self):
52
+ return {
53
+ "task_id": self.task_id,
54
+ "task_name": self.task_name,
55
+ "agent_id": self.agent_id,
56
+ "step": self.step,
57
+ "execute_time": self.execute_time,
58
+ "pre_agent": self.pre_agent
59
+ }
60
+ @dataclass
61
+ class DataRow:
62
+ '''
63
+ Data row for storing data.
64
+ '''
65
+ exp_meta: ExpMeta
66
+ exp_data: Experience
67
+ id: str = field(default_factory=lambda: str(uuid.uuid4()))
68
+
69
+ def to_dict(self):
70
+ return {
71
+ "exp_meta": self.exp_meta.to_dict(),
72
+ "exp_data": self.exp_data.to_dict(),
73
+ "id": self.id
74
+ }
75
+
76
+
77
+ class Storage(ABC):
78
+ '''
79
+ Storage for storing and sampling data.
80
+ '''
81
+
82
+ @abstractmethod
83
+ def add(self, data: DataRow):
84
+ '''
85
+ Add data to the storage.
86
+ Args:
87
+ data (DataRow): Data to add.
88
+ '''
89
+
90
+ @abstractmethod
91
+ def add_batch(self, data_batch: List[DataRow]):
92
+ '''
93
+ Add batch of data to the storage.
94
+ Args:
95
+ data_batch (List[DataRow]): List of data to add.
96
+ '''
97
+
98
+ @abstractmethod
99
+ def size(self, query_condition: QueryCondition = None) -> int:
100
+ '''
101
+ Get the size of the storage.
102
+ Returns:
103
+ int: Size of the storage.
104
+ '''
105
+
106
+ @abstractmethod
107
+ def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
108
+ '''
109
+ Get paginated data from the storage.
110
+ Args:
111
+ page (int): Page number.
112
+ page_size (int): Number of data per page.
113
+ Returns:
114
+ List[DataRow]: List of data.
115
+ '''
116
+
117
+ @abstractmethod
118
+ def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
119
+ '''
120
+ Get all data from the storage.
121
+ Returns:
122
+ List[DataRow]: List of data.
123
+ '''
124
+
125
+ @abstractmethod
126
+ def get_by_task_id(self, task_id: str) -> List[DataRow]:
127
+ '''
128
+ Get data by task_id from the storage.
129
+ Args:
130
+ task_id (str): Task id.
131
+ Returns:
132
+ List[DataRow]: List of data.
133
+ '''
134
+
135
+ @abstractmethod
136
+ def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
137
+ '''
138
+ Get batch of data by task_ids from the storage.
139
+ Args:
140
+ task_ids (List[str]): List of task ids.
141
+ Returns:
142
+ Dict[str, List[DataRow]]: Dictionary of data.
143
+ The key is the task_id and the value is the list of data.
144
+ The list of data is sorted by step.
145
+ '''
146
+
147
+
148
+ class Sampler(ABC):
149
+ '''
150
+ Sample data from the storage.
151
+ '''
152
+
153
+ @abstractmethod
+ def sample(self,
154
+ storage: Storage,
155
+ batch_size: int,
156
+ query_condition: QueryCondition = None) -> List[DataRow]:
157
+ '''
158
+ Sample data from the storage.
159
+ Args:
160
+ storage (Storage): Storage to sample from.
161
+ batch_size (int): Number of data to sample.
162
+ query_condition (QueryCondition, optional): Query condition. Defaults to None.
163
+ Returns:
164
+ List[DataRow]
165
+ '''
166
+
167
+
168
+ class TaskSampler(Sampler):
169
+ '''
170
+ Sample task data from storage, returns Dict[str, List[DataRow]] where:
171
+ - key is task_id
172
+ - value is list of task all data rows
173
+ '''
174
+
175
+ def sorted_by_step(self, task_experience: List[DataRow]) -> List[DataRow]:
176
+ '''
177
+ Sort the task experience by step and execute_time.
178
+ Args:
179
+ task_experience (List[DataRow]): List of task experience.
180
+ Returns:
181
+ List[DataRow]: List of task experience sorted by step and execute_time.
182
+ '''
183
+ return sorted(task_experience, key=lambda x: (x.exp_meta.step, x.exp_meta.execute_time))
184
+
185
+ def sample(self,
186
+ storage: Storage,
187
+ batch_size: int,
188
+ query_condition: QueryCondition = None) -> List[DataRow]:
189
+ task_ids = self.sample_task_ids(storage, batch_size, query_condition)
190
+ rows_by_task = storage.get_bacth_by_task_ids(task_ids)
+ return [row for rows in rows_by_task.values() for row in rows]
191
+
192
+ def sample_tasks(self,
193
+ storage: Storage,
194
+ batch_size: int,
195
+ query_condition: QueryCondition = None) -> Dict[str, List[DataRow]]:
196
+ '''
197
+ Sample data from the storage.
198
+ Args:
199
+ storage (Storage): Storage to sample from.
200
+ batch_size (int): Number of data to sample.
201
+ query_condition (QueryCondition, optional): Query condition. Defaults to None.
202
+ Returns:
203
+ Dict[str, List[DataRow]]: Dictionary of sampled data.
204
+ The key is the task_id and the value is the list of data.
205
+ The list of data is sorted by step.
206
+ '''
207
+ task_ids = self.sample_task_ids(storage, batch_size, query_condition)
208
+ rows_by_task = storage.get_bacth_by_task_ids(task_ids)
209
+ return {task_id: self.sorted_by_step(rows) for task_id, rows in rows_by_task.items()}
210
+
211
+ @abstractmethod
212
+ def sample_task_ids(self,
213
+ storage: Storage,
214
+ batch_size: int,
215
+ query_condition: QueryCondition = None) -> List[str]:
216
+ '''
217
+ Sample task_ids from the storage.
218
+ Args:
219
+ storage (Storage): Storage to sample from.
220
+ batch_size (int): Number of task_ids to sample.
221
+ query_condition (QueryCondition, optional): Query condition. Defaults to None.
222
+ Returns:
223
+ List[str]: List of task_ids.
224
+ '''
225
+
226
+
227
+ class Converter(ABC):
228
+ '''
229
+ Convert data to dataset row.
230
+ '''
231
+
232
+ @abstractmethod
233
+ def to_dataset_row(self, task_experience: List[DataRow]) -> T:
234
+ '''
235
+ Convert task experience to dataset row.
236
+ Args:
237
+ task_experience (List[DataRow]): List of task experience.
238
+ Returns:
239
+ T: type of dataset row.
240
+ '''
241
+
242
+
243
+ class InMemoryStorage(Storage):
244
+ '''
245
+ In-memory storage for storing and sampling data.
246
+ '''
247
+
248
+ def __init__(self, max_capacity: int = 10000):
249
+ self._data: Dict[str, List[DataRow]] = {}
250
+ self._max_capacity = max_capacity
251
+ self._fifo_queue = [] # (task_id)
252
+
253
+ def add(self, data: DataRow):
254
+ if not data:
255
+ raise ValueError("Data is required")
256
+ if not data.exp_meta:
257
+ raise ValueError("exp_meta is required")
258
+
259
+ while self.size() >= self._max_capacity and self._fifo_queue:
260
+ oldest_task_id = self._fifo_queue.pop(0)
261
+ if oldest_task_id in self._data:
262
+ del self._data[oldest_task_id]
263
+
264
+ if data.exp_meta.task_id not in self._data:
265
+ self._data[data.exp_meta.task_id] = []
266
+ self._data[data.exp_meta.task_id].append(data)
267
+ self._fifo_queue.append(data.exp_meta.task_id)
268
+
269
+ if data.exp_meta.task_id not in self._data:
270
+ self._data[data.exp_meta.task_id] = []
271
+ self._data[data.exp_meta.task_id].append(data)
272
+
273
+ def add_batch(self, data_batch: List[DataRow]):
274
+ for data in data_batch:
275
+ self.add(data)
276
+
277
+ def size(self, query_condition: QueryCondition = None) -> int:
278
+ return len(self.get_all(query_condition))
279
+
280
+ def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
281
+ if page < 1:
282
+ raise ValueError("Page must be greater than 0")
283
+ if page_size < 1:
284
+ raise ValueError("Page size must be greater than 0")
285
+ all_data = self.get_all(query_condition)
286
+ start_index = (page - 1) * page_size
287
+ end_index = start_index + page_size
288
+ return all_data[start_index:end_index]
289
+
290
+ def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
291
+ all_data = []
292
+ query_filter = None
293
+ if query_condition:
294
+ query_filter = QueryFilter(query_condition)
295
+ for data in self._data.values():
296
+ if query_filter:
297
+ all_data.extend(query_filter.filter(data))
298
+ else:
299
+ all_data.extend(data)
300
+ return all_data
301
+
302
+ def get_by_task_id(self, task_id: str) -> List[DataRow]:
303
+ return self._data.get(task_id, [])
304
+
305
+ def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
306
+ return {task_id: self._data.get(task_id, []) for task_id in task_ids}
307
+
308
+ def clear(self):
309
+ self._data = {}
310
+ self._fifo_queue = []
311
+
312
+
313
+ class RandomTaskSample(TaskSampler):
314
+ '''
315
+ Randomly sample data from the storage.
316
+ '''
317
+
318
+ def sample_task_ids(self,
319
+ storage: Storage,
320
+ batch_size: int,
321
+ query_condition: QueryCondition = None) -> List[str]:
322
+ total_size = storage.size(query_condition)
323
+ if total_size <= batch_size:
324
+ return storage.get_all(query_condition)
325
+
326
+ sampled_task_ids = set()
327
+ page_size = min(100, batch_size * 2)
328
+ total_pages = ceil(total_size/page_size)
329
+ visited_pages = set()
330
+ while len(sampled_task_ids) < batch_size and len(visited_pages) < total_pages:
331
+ page = random.choice(
332
+ [p for p in range(1, total_pages+1) if p not in visited_pages])
333
+ visited_pages.add(page)
334
+
335
+ current_page = storage.get_paginated(
336
+ page, page_size, query_condition)
337
+ if not current_page:
338
+ continue
339
+ current_page_task_ids = set(
340
+ [data.exp_meta.task_id for data in current_page if data.exp_meta.task_id not in sampled_task_ids])
341
+ sample_count = min(len(current_page_task_ids),
342
+ batch_size - len(sampled_task_ids))
343
+ sampled_task_ids.update(random.sample(
344
+ list(current_page_task_ids), sample_count))
345
+
346
+ return list(sampled_task_ids)
347
+
348
+
349
+ class DefaultConverter(Converter):
350
+ '''
351
+ Default converter do nothing.
352
+ '''
353
+
354
+ def to_dataset_row(self, task_experience: List[DataRow]) -> List[DataRow]:
355
+ return task_experience
356
+
357
+
358
+ class ReplayBuffer:
359
+ '''
360
+ Replay buffer for storing and sampling data.
361
+ '''
362
+
363
+ def __init__(
364
+ self,
365
+ storage: Storage = None
366
+ ):
367
+ self._storage = storage if storage is not None else InMemoryStorage()
368
+
369
+ def store(self, data: DataRow):
370
+ '''
371
+ Store data in the replay buffer.
372
+ '''
373
+ if not data:
374
+ raise ValueError("Data is required")
375
+ self._storage.add(data)
376
+
377
+ def store_batch(self, data_batch: List[DataRow]):
378
+ '''
379
+ Store batch of data in the replay buffer.
380
+ '''
381
+ if not data_batch:
382
+ raise ValueError("Data batch is required")
383
+ self._storage.add_batch(data_batch)
384
+
385
+ def sample_task(self,
386
+ sampler: TaskSampler = RandomTaskSample(),
387
+ query_condition: QueryCondition = None,
388
+ converter: Converter = DefaultConverter(),
389
+ batch_size: int = 1000) -> List[T]:
390
+ '''
391
+ Sample Task from the replay buffer and convert to dataset row.
392
+ DefaultConverter return List[DataRow]
393
+ '''
394
+ sampled_task = sampler.sample_tasks(
395
+ self._storage, batch_size, query_condition)
396
+ return [converter.to_dataset_row(task_experiences) for task_experiences in sampled_task.values()]
397
+
398
+ def sample(self,
399
+ sampler: Sampler = RandomTaskSample(),
400
+ query_condition: QueryCondition = None,
401
+ converter: Converter = DefaultConverter(),
402
+ batch_size: int = 1000) -> List[T]:
403
+ '''
404
+ Sample data from the replay buffer and convert to dataset row.
405
+ DefaultConverter return List[DataRow]
406
+ '''
407
+ sampled_data = sampler.sample(
408
+ self._storage, batch_size, query_condition)
409
+ return converter.to_dataset_row(sampled_data)
aworld/replay_buffer/processor.py ADDED
@@ -0,0 +1,190 @@
1
+ # coding: utf-8
2
+ """
3
+ processor.py
4
+ Cleans raw trace data into the standard storage structure used for reinforcement-learning training.
5
+ """
6
+ import json
7
+ import os
8
+ import datetime
9
+ from typing import Any
10
+ import threading
11
+
12
+ from aworld.utils import import_package
13
+ from aworld.replay_buffer.base import DataRow, Experience, ExpMeta
14
+ from aworld.logs.util import logger
15
+ from aworld.utils.common import get_local_ip
16
+
17
+
18
+ class ReplayBufferExporter:
19
+ def __init__(self):
20
+ """Initialize ReplayBufferExporter instance"""
21
+ self._file_locks = {}
22
+ self._lock_dict_lock = threading.Lock()
23
+ self._task_output_paths = {}
24
+
25
+ def _get_file_lock(self, file_path):
26
+ """Get the lock for the specified file"""
27
+ with self._lock_dict_lock:
28
+ if file_path not in self._file_locks:
29
+ self._file_locks[file_path] = threading.Lock()
30
+ return self._file_locks[file_path]
31
+
32
+ def replay_buffer_exporter(self, spans: list[dict[str, Any]], output_dir: str):
33
+ """
34
+ Process spans: only spans whose name starts with 'step_execution_' are handled; they are grouped by task_id and written to separate files.
35
+
36
+ Args:
37
+ spans: span data list
38
+ output_dir: output directory path
39
+ """
40
+ # Lazily import the OSS SDK, then ensure the output directory exists
41
+ import_package("oss2")
42
+ import oss2
43
+
44
+ os.makedirs(output_dir, exist_ok=True)
45
+
46
+ # Get OSS credentials from environment variables
47
+ enable_oss_export = os.getenv("EXPORT_REPLAY_TRACE_TO_OSS", "false").lower() == "true"
48
+ access_key_id = os.getenv('OSS_ACCESS_KEY_ID')
49
+ access_key_secret = os.getenv('OSS_ACCESS_KEY_SECRET')
50
+ endpoint = os.getenv('OSS_ENDPOINT')
51
+ bucket_name = os.getenv('OSS_BUCKET_NAME')
52
+ bucket = None
53
+
54
+ if not all([access_key_id, access_key_secret, endpoint, bucket_name]):
55
+ enable_oss_export = False
56
+ logger.warn("Missing required OSS environment variables")
57
+ else:
58
+ try:
59
+ # Initialize OSS client
60
+ auth = oss2.Auth(access_key_id, access_key_secret)
61
+ bucket = oss2.Bucket(auth, endpoint, bucket_name)
62
+ except Exception as e:
63
+ enable_oss_export = False
64
+ logger.warn(f"Failed to initialize OSS client, endpoint: {endpoint}, bucket: {bucket_name}. Error: {str(e)}")
65
+
66
+ # Group by task_id
67
+ task_groups = {}
68
+
69
+ for span_data in spans:
70
+ # Only process spans with 'step_execution_' prefix
71
+ if not span_data['name'].startswith('step_execution_'):
72
+ continue
73
+
74
+ attr = span_data.get('attributes', {})
75
+ exp_id = attr.get('exp_id')
76
+ task_id = attr.get('task_id', '')
77
+
78
+ if not exp_id or not task_id:
79
+ continue
80
+
81
+ if task_id not in task_groups:
82
+ task_groups[task_id] = {}
83
+
84
+ if exp_id not in task_groups[task_id]:
85
+ task_groups[task_id][exp_id] = {
86
+ 'exp_meta': None,
87
+ 'exp_data': None
88
+ }
89
+
90
+ # Process step_execution span
91
+ task_name = attr.get('task_name', '')
92
+ agent_id = attr.get('agent_id', '')
93
+ step = attr.get('step', 0)
94
+ execute_time = float(str(span_data.get('start_time', '0')).split('.')[0].replace(' ', '').replace('-', '').replace(':', ''))
95
+
96
+ observation = {}
97
+ action = []
98
+ messages = []
99
+ pre_agent = None
100
+ if 'observation' in attr:
101
+ try:
102
+ observation = json.loads(attr['observation'])
103
+ except Exception:
104
+ observation = attr['observation']
105
+
106
+ if 'actions' in attr:
107
+ try:
108
+ action = json.loads(attr['actions'])
109
+ except Exception:
110
+ action = attr['actions']
111
+
112
+ if 'messages' in attr:
113
+ try:
114
+ messages = json.loads(attr['messages'])
115
+ except Exception:
116
+ messages = attr['messages']
117
+
118
+ pre_agent = attr.get('pre_agent', '')
119
+ reward = attr.get('reward', 0.0)
120
+ adv = attr.get('adv_t', 0.0)
121
+ v = attr.get('v_t', 0.0)
122
+
123
+ exp_meta = ExpMeta(task_id, task_name, agent_id, step, execute_time, pre_agent)
124
+ exp_data = Experience(observation, action, reward, adv, v, messages)
125
+
126
+ task_groups[task_id][exp_id]['exp_meta'] = exp_meta
127
+ task_groups[task_id][exp_id]['exp_data'] = exp_data
128
+
129
+ # Process data for each task_id
130
+ for task_id, exp_groups in task_groups.items():
131
+ # Merge data and generate final Experience object
132
+ data_rows = []
133
+
134
+ # Read existing data (if any)
135
+ output_path = self._task_output_paths.get(task_id)
136
+ if not output_path:
137
+ timestamp = datetime.datetime.now().strftime("%Y%m%d")
138
+ replay_dir = os.path.join(output_dir or "./trace_data", timestamp, get_local_ip(), "replays")
139
+ replay_dataset_path = os.getenv("REPLAY_TRACE_DATASET_PATH", replay_dir)
140
+ export_dir = os.path.abspath(replay_dataset_path)
141
+ os.makedirs(export_dir, exist_ok=True)
142
+ output_path = os.path.join(export_dir, f"task_replay_{task_id}.json")
143
+ self._task_output_paths[task_id] = output_path
144
+
145
+ # Use thread lock to protect read and write operations
146
+ file_lock = self._get_file_lock(output_path)
147
+ with file_lock:
148
+ if os.path.exists(output_path):
149
+ try:
150
+ with open(output_path, 'r', encoding='utf-8') as f:
151
+ existing_data = json.load(f)
152
+ data_rows.extend([DataRow(
153
+ ExpMeta(**row['exp_meta']),
154
+ Experience(**row['exp_data']),
155
+ row['id']
156
+ ) for row in existing_data])
157
+ except Exception as e:
158
+ print(f"Failed to read existing file {output_path}: {str(e)}")
159
+
160
+ # Add new data
161
+ for exp_id, group in exp_groups.items():
162
+ if group['exp_meta'] and group['exp_data']:
163
+ row = DataRow(group['exp_meta'], group['exp_data'], exp_id)
164
+ data_rows.append(row)
165
+
166
+ # Sort by execute_time
167
+ data_rows.sort(key=lambda x: x.exp_meta.execute_time)
168
+
169
+ # Export to json
170
+ with open(output_path, 'w', encoding='utf-8') as f:
171
+ json.dump([row.to_dict() for row in data_rows], f, ensure_ascii=False, indent=2)
172
+ logger.info(f"Processing completed, exported {len(data_rows)} experiences to {output_path}")
173
+
174
+ if enable_oss_export:
175
+ # Upload to OSS
176
+ try:
177
+ # Get the relative path
178
+ abs_path = os.path.abspath(output_path)
179
+ path_parts = abs_path.split(os.sep)
180
+ if len(path_parts) >= 4:
181
+ # Get the last 4 parts of the path
182
+ relative_path = os.sep.join(path_parts[-4:])
183
+ oss_key = relative_path
184
+ else:
185
+ oss_key = f"replay_buffer/{os.path.basename(output_path)}"
186
+ bucket.put_object_from_file(oss_key, output_path)
187
+ logger.info(f"Successfully uploaded {output_path} to OSS: {oss_key}")
188
+ except Exception as e:
189
+ logger.warn(f"Failed to upload {output_path} to OSS: {str(e)}")
190
+
aworld/replay_buffer/query_filter.py ADDED
@@ -0,0 +1,228 @@
1
+ from typing import Any, List, TypeVar, Union, Literal, TypedDict, Dict
2
+
3
+ DataRow = TypeVar('DataRow')
4
+
5
+
6
+ class BaseCondition(TypedDict):
7
+ field: str
8
+ value: Any
9
+ op: Literal[
10
+ 'eq', 'ne', 'gt', 'gte', 'lt', 'lte',
11
+ 'in', 'not_in', 'like', 'not_like',
12
+ 'is_null', 'is_not_null'
13
+ ]
14
+
15
+
16
+ class LogicalCondition(TypedDict, total=False):
17
+ and_: List['QueryCondition']
18
+ or_: List['QueryCondition']
19
+
20
+
21
+ QueryCondition = Union[BaseCondition, LogicalCondition]
22
+
23
+
24
+ class QueryBuilder:
25
+ '''
26
+ Query builder for replay buffer. result example:
27
+ {
28
+ "and": [
29
+ {"field": "field1", "value": "value1", "op": "eq"},
30
+ {"or": [{"field": "field2", "value": "value2", "op": "eq"}, {"field": "field3", "value": "value3", "op": "eq"}]}
31
+ ]
32
+ }
33
+ '''
34
+
35
+ def __init__(self) -> None:
36
+ self.conditions: List[Dict[str, any]] = []
37
+ self.logical_ops: List[str] = []
38
+
39
+ def eq(self, field: str, value: any) -> 'QueryBuilder':
40
+ self.conditions.append({"field": field, "value": value, "op": "eq"})
41
+ return self
42
+
43
+ def ne(self, field: str, value: any) -> 'QueryBuilder':
44
+ self.conditions.append({"field": field, "value": value, "op": "ne"})
45
+ return self
46
+
47
+ def gt(self, field: str, value: any) -> 'QueryBuilder':
48
+ self.conditions.append({"field": field, "value": value, "op": "gt"})
49
+ return self
50
+
51
+ def gte(self, field: str, value: any) -> 'QueryBuilder':
52
+ self.conditions.append({"field": field, "value": value, "op": "gte"})
53
+ return self
54
+
55
+ def lt(self, field: str, value: any) -> 'QueryBuilder':
56
+ self.conditions.append({"field": field, "value": value, "op": "lt"})
57
+ return self
58
+
59
+ def lte(self, field: str, value: any) -> 'QueryBuilder':
60
+ self.conditions.append({"field": field, "value": value, "op": "lte"})
61
+ return self
62
+
63
+ def in_(self, field: str, value: any) -> 'QueryBuilder':
64
+ self.conditions.append({"field": field, "value": value, "op": "in"})
65
+ return self
66
+
67
+ def not_in(self, field: str, value: any) -> 'QueryBuilder':
68
+ self.conditions.append(
69
+ {"field": field, "value": value, "op": "not_in"})
70
+ return self
71
+
72
+ def like(self, field: str, value: any) -> 'QueryBuilder':
73
+ self.conditions.append({"field": field, "value": value, "op": "like"})
74
+ return self
75
+
76
+ def not_like(self, field: str, value: any) -> 'QueryBuilder':
77
+ self.conditions.append(
78
+ {"field": field, "value": value, "op": "not_like"})
79
+ return self
80
+
81
+ def is_null(self, field: str) -> 'QueryBuilder':
82
+ self.conditions.append({"field": field, "op": "is_null"})
83
+ return self
84
+
85
+ def is_not_null(self, field: str) -> 'QueryBuilder':
86
+ self.conditions.append({"field": field, "op": "is_not_null"})
87
+ return self
88
+
89
+ def and_(self) -> 'QueryBuilder':
90
+ self.logical_ops.append("and_")
91
+ return self
92
+
93
+ def or_(self) -> 'QueryBuilder':
94
+ self.logical_ops.append("or_")
95
+ return self
96
+
97
+ def nested(self, builder: 'QueryBuilder') -> 'QueryBuilder':
98
+ self.conditions.append({"nested": builder.build()})
99
+ return self
100
+
101
+ def build(self) -> QueryCondition:
102
+ conditions = self.conditions # all conditions(including nested)
103
+ operators = self.logical_ops
104
+
105
+ # Validate condition and operator counts (n conditions need n-1 operators)
106
+ if len(operators) != len(conditions) - 1:
107
+ raise ValueError("Mismatch between condition and operator counts")
108
+
109
+ # Use stack to handle operator precedence (simplified version supporting and/or)
110
+ stack: List[Union[Dict[str, any], str]] = []
111
+
112
+ for i, item in enumerate(conditions):
113
+ if i == 0:
114
+ # First element goes directly to stack (condition or nested)
115
+ stack.append(item)
116
+ continue
117
+
118
+ # Pop stack top as left operand
119
+ left = stack.pop()
120
+ op = operators[i-1] # Current operator (and/or)
121
+ right = item # Right operand (current condition)
122
+
123
+ # Build logical expression: {op: [left, right]}
124
+ expr = {op: [left, right]}
125
+ # Push result back to stack for further operations
126
+ stack.append(expr)
127
+
128
+ # Process nested conditions (recursive unfolding)
129
+ def process_nested(cond: any) -> any:
130
+ if isinstance(cond, dict):
131
+ if "nested" in cond:
132
+ # Recursively process sub-conditions
133
+ return process_nested(cond["nested"])
134
+ # Recursively process child elements
135
+ return {k: process_nested(v) for k, v in cond.items()}
136
+ elif isinstance(cond, list):
137
+ return [process_nested(item) for item in cond]
138
+ return cond
139
+
140
+ # Final result: only one element left in stack, return after processing nested
141
+ result = stack[0] if stack else None
142
+ return process_nested(result) if result else None
143
+
144
+
145
+ class QueryFilter:
146
+ '''
147
+ Query filter for replay buffer.
148
+ '''
149
+
150
+ def __init__(self, query_condition: QueryCondition) -> None:
151
+ self.query_condition = query_condition
152
+
153
+ def _get_field_value(self, row: DataRow, field: str) -> Any:
154
+ '''
155
+ Get field value from row.
156
+ '''
157
+ obj = row
158
+ for part in field.split('.'):
159
+ obj = getattr(obj, part, None)
160
+ if obj is None:
161
+ break
162
+ return obj
163
+
164
+ def _do_check(self, row: DataRow, condition: QueryCondition) -> bool:
165
+ """
166
+ check if row match condition
167
+ """
168
+ if condition is None:
169
+ return True
170
+ if "field" in condition and "op" in condition:
171
+ field_val = self._get_field_value(row, condition["field"])
172
+ op = condition["op"]
173
+ target_val = condition["value"]
174
+
175
+ if op == "eq":
176
+ return field_val == target_val
177
+ if op == "ne":
178
+ return field_val != target_val
179
+ if op == "gt":
180
+ return field_val > target_val
181
+ if op == "gte":
182
+ return field_val >= target_val
183
+ if op == "lt":
184
+ return field_val < target_val
185
+ if op == "lte":
186
+ return field_val <= target_val
187
+ if op == "in":
188
+ return field_val in target_val
189
+ if op == "not_in":
190
+ return field_val not in target_val
191
+ if op == "like":
192
+ return target_val in field_val
193
+ if op == "not_like":
194
+ return target_val not in field_val
195
+ if op == "is_null":
196
+ return field_val is None
197
+ if op == "is_not_null":
198
+ return field_val is not None
199
+
200
+ return False
201
+
202
+ elif "and_" in condition or "or_" in condition:
203
+ if "and_" in condition:
204
+ return all(self._do_check(row, c) for c in condition["and_"])
205
+ if "or_" in condition:
206
+ return any(self._do_check(row, c) for c in condition["or_"])
207
+ return False
208
+
209
+ return False
210
+
211
+ def check_condition(self, row: DataRow) -> bool:
212
+ """
213
+ check if row match condition
214
+ """
215
+ return self._do_check(row, self.query_condition)
216
+
217
+ def filter(self, rows: List[DataRow]) -> List[DataRow]:
218
+ """filter rows by condition
219
+ Args:
220
+ rows (List[DataRow]): List of rows to filter.
222
+ Returns:
223
+ List[DataRow]: List of rows that match the condition.
224
+ """
225
+ condition = self.query_condition
226
+ if not condition:
227
+ return rows
228
+ return [row for row in rows if self.check_condition(row)]