Upload 5 files
- aworld/replay_buffer/README.md +111 -0
- aworld/replay_buffer/__init__.py +2 -0
- aworld/replay_buffer/base.py +409 -0
- aworld/replay_buffer/processor.py +190 -0
- aworld/replay_buffer/query_filter.py +228 -0
aworld/replay_buffer/README.md
ADDED
@@ -0,0 +1,111 @@
# Replay Buffer

A multi-process capable replay buffer system for storing and sampling experience data.

## Features

- **Multi-process Support**: Safe concurrent access using shared memory and locks
- **Flexible Querying**: Powerful query builder for filtering stored data
- **Task-based Organization**: Data organized by task_id and agent_id
- **Capacity Management**: FIFO eviction when reaching max capacity
- **Custom Sampling**: Implement custom sampling logic through the Sampler interface
- **Data Conversion**: Custom data conversion through the Converter interface

## Basic Usage

### Writing Data

```python
import time

from aworld.replay_buffer import ReplayBuffer, DataRow, ExpMeta, Experience
from aworld.core.common import ActionModel, Observation

replay_buffer = ReplayBuffer()

# Create a data row
data = DataRow(
    exp_meta=ExpMeta(
        task_id="task_1",
        task_name="my_task",
        agent_id="agent_1",
        step=1,
        execute_time=time.time(),
        pre_agent=""
    ),
    exp_data=Experience(
        state=Observation(),
        actions=[ActionModel()]
    )
)

# Store data
replay_buffer.store(data)
```

### Reading Data

```python
from aworld.replay_buffer import ReplayBuffer
from aworld.replay_buffer.base import RandomTaskSample, DefaultConverter
from aworld.replay_buffer.query_filter import QueryBuilder

# Basic example
replay_buffer = ReplayBuffer()
query_condition = QueryBuilder().eq("exp_meta.task_name", "test_task").build()
data = replay_buffer.sample(sampler=RandomTaskSample(),
                            query_condition=query_condition,
                            converter=DefaultConverter(),
                            batch_size=1000)

# Query tasks by task_id
query = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
data = replay_buffer.sample_task(query_condition=query, batch_size=10)

# Query tasks by agent_id
query = QueryBuilder().eq("exp_meta.agent_id", "agent_1").build()
data = replay_buffer.sample_task(query_condition=query, batch_size=5)
```
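
### Custom Converter

The `Sampler` and `Converter` interfaces from `aworld.replay_buffer.base` can be subclassed for custom behaviour. A minimal sketch of a custom converter, reusing the `replay_buffer` from the example above (the `StateActionConverter` name and the selected fields are illustrative, not part of the library):

```python
from typing import Dict, List

from aworld.replay_buffer.base import Converter, DataRow


class StateActionConverter(Converter):
    """Illustrative converter: keep only the state/action pairs of each task."""

    def to_dataset_row(self, task_experience: List[DataRow]) -> List[Dict]:
        return [
            {"state": row.exp_data.state, "actions": row.exp_data.actions}
            for row in task_experience
        ]


# Pass the converter when sampling tasks
data = replay_buffer.sample_task(converter=StateActionConverter(), batch_size=100)
```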

## Multi-processing Example

```python
import multiprocessing
from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage

manager = multiprocessing.Manager()
replay_buffer = ReplayBuffer(
    storage=MultiProcMemoryStorage(
        data_dict=manager.dict(),
        fifo_queue=manager.list(),
        lock=manager.Lock(),
        max_capacity=10000
    )
)

# Start writer processes (write_processing is a user-defined writer; see the sketch below)
processes = [
    multiprocessing.Process(target=write_processing, args=(replay_buffer, f"task_{i}"))
    for i in range(4)
]
```
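
The `write_processing` function above is not part of the library; it stands in for whatever user code produces experience. A minimal sketch of such a writer (task and agent names are illustrative):

```python
import time

from aworld.replay_buffer import ReplayBuffer, DataRow, ExpMeta, Experience
from aworld.core.common import ActionModel, Observation


def write_processing(replay_buffer: ReplayBuffer, task_id: str):
    """Illustrative writer: store a few steps of experience for one task."""
    for step in range(10):
        replay_buffer.store(DataRow(
            exp_meta=ExpMeta(
                task_id=task_id,
                task_name="demo_task",
                agent_id="agent_1",
                step=step,
                execute_time=time.time(),
                pre_agent=""
            ),
            exp_data=Experience(state=Observation(), actions=[ActionModel()])
        ))
```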

## Query Builder Examples

### Simple Equality
```python
QueryBuilder().eq("exp_meta.task_id", "123").build()
```

### Complex Conditions
```python
(QueryBuilder()
    .eq("exp_meta.task_id", "123")
    .and_()
    .eq("exp_meta.agent_id", "456")
    .build())
```

### Nested Conditions
```python
(QueryBuilder()
    .eq("exp_meta.task_id", "123")
    .and_()
    .nested(
        QueryBuilder()
        .eq("exp_meta.agent_id", "111")
        .or_()
        .eq("exp_meta.agent_id", "222")
    )
    .build())
```
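
### Applying a Built Condition

`build()` returns a plain dictionary using `and_`/`or_` keys, which `QueryFilter` evaluates against stored rows. A short sketch of applying a condition directly (the `rows` list is assumed to hold `DataRow` objects):

```python
from aworld.replay_buffer.query_filter import QueryBuilder, QueryFilter

condition = (
    QueryBuilder()
    .eq("exp_meta.task_id", "123")
    .and_()
    .eq("exp_meta.agent_id", "456")
    .build()
)
# condition == {
#     "and_": [
#         {"field": "exp_meta.task_id", "value": "123", "op": "eq"},
#         {"field": "exp_meta.agent_id", "value": "456", "op": "eq"},
#     ]
# }

matching = QueryFilter(condition).filter(rows)
```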
aworld/replay_buffer/__init__.py
ADDED
@@ -0,0 +1,2 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
aworld/replay_buffer/base.py
ADDED
@@ -0,0 +1,409 @@
import random
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, TypeVar
from abc import ABC, abstractmethod
from math import ceil

from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.query_filter import QueryCondition, QueryFilter
from aworld.logs.util import logger


T = TypeVar('T')


@dataclass
class Experience:
    '''
    Experience of an agent.
    '''
    state: Observation
    actions: List[ActionModel]
    reward_t: float = None
    adv_t: float = None
    v_t: float = None
    messages: List[Dict] = None

    def to_dict(self):
        return {
            "state": self.state,
            "actions": self.actions,
            "reward_t": self.reward_t,
            "adv_t": self.adv_t,
            "v_t": self.v_t,
            "messages": self.messages
        }


@dataclass
class ExpMeta:
    '''
    Experience metadata.
    '''
    task_id: str
    task_name: str
    agent_id: str
    step: int
    execute_time: float
    pre_agent: str

    def to_dict(self):
        return {
            "task_id": self.task_id,
            "task_name": self.task_name,
            "agent_id": self.agent_id,
            "step": self.step,
            "execute_time": self.execute_time,
            "pre_agent": self.pre_agent
        }


@dataclass
class DataRow:
    '''
    Data row for storing data.
    '''
    exp_meta: ExpMeta
    exp_data: Experience
    id: str = field(default_factory=lambda: str(uuid.uuid4()))

    def to_dict(self):
        return {
            "exp_meta": self.exp_meta.to_dict(),
            "exp_data": self.exp_data.to_dict(),
            "id": self.id
        }


class Storage(ABC):
    '''
    Storage for storing and sampling data.
    '''

    @abstractmethod
    def add(self, data: DataRow):
        '''
        Add data to the storage.
        Args:
            data (DataRow): Data to add.
        '''

    @abstractmethod
    def add_batch(self, data_batch: List[DataRow]):
        '''
        Add a batch of data to the storage.
        Args:
            data_batch (List[DataRow]): List of data to add.
        '''

    @abstractmethod
    def size(self, query_condition: QueryCondition = None) -> int:
        '''
        Get the size of the storage.
        Returns:
            int: Size of the storage.
        '''

    @abstractmethod
    def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
        '''
        Get paginated data from the storage.
        Args:
            page (int): Page number.
            page_size (int): Number of data rows per page.
        Returns:
            List[DataRow]: List of data.
        '''

    @abstractmethod
    def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
        '''
        Get all data from the storage.
        Returns:
            List[DataRow]: List of data.
        '''

    @abstractmethod
    def get_by_task_id(self, task_id: str) -> List[DataRow]:
        '''
        Get data by task_id from the storage.
        Args:
            task_id (str): Task id.
        Returns:
            List[DataRow]: List of data.
        '''

    @abstractmethod
    def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
        '''
        Get a batch of data by task_ids from the storage.
        Args:
            task_ids (List[str]): List of task ids.
        Returns:
            Dict[str, List[DataRow]]: Dictionary of data.
                The key is the task_id and the value is the list of data.
                The list of data is sorted by step.
        '''


class Sampler(ABC):
    '''
    Sample data from the storage.
    '''

    def sample(self,
               storage: Storage,
               batch_size: int,
               query_condition: QueryCondition = None) -> List[DataRow]:
        '''
        Sample data from the storage.
        Args:
            storage (Storage): Storage to sample from.
            batch_size (int): Number of data rows to sample.
            query_condition (QueryCondition, optional): Query condition. Defaults to None.
        Returns:
            List[DataRow]
        '''


class TaskSampler(Sampler):
    '''
    Sample task data from storage, returning Dict[str, List[DataRow]] where:
    - key is the task_id
    - value is the list of all data rows of that task
    '''

    def sorted_by_step(self, task_experience: List[DataRow]) -> List[DataRow]:
        '''
        Sort the task experience by step and execute_time.
        Args:
            task_experience (List[DataRow]): List of task experience.
        Returns:
            List[DataRow]: List of task experience sorted by step and execute_time.
        '''
        return sorted(task_experience, key=lambda x: (x.exp_meta.step, x.exp_meta.execute_time))

    def sample(self,
               storage: Storage,
               batch_size: int,
               query_condition: QueryCondition = None) -> List[DataRow]:
        task_ids = self.sample_task_ids(storage, batch_size, query_condition)
        return storage.get_bacth_by_task_ids(task_ids)

    def sample_tasks(self,
                     storage: Storage,
                     batch_size: int,
                     query_condition: QueryCondition = None) -> Dict[str, List[DataRow]]:
        '''
        Sample data from the storage.
        Args:
            storage (Storage): Storage to sample from.
            batch_size (int): Number of tasks to sample.
            query_condition (QueryCondition, optional): Query condition. Defaults to None.
        Returns:
            Dict[str, List[DataRow]]: Dictionary of sampled data.
                The key is the task_id and the value is the list of data.
                The list of data is sorted by step.
        '''
        task_ids = self.sample_task_ids(storage, batch_size, query_condition)
        rows_by_task = storage.get_bacth_by_task_ids(task_ids)
        return {task_id: self.sorted_by_step(rows) for task_id, rows in rows_by_task.items()}

    @abstractmethod
    def sample_task_ids(self,
                        storage: Storage,
                        batch_size: int,
                        query_condition: QueryCondition = None) -> List[str]:
        '''
        Sample task_ids from the storage.
        Args:
            storage (Storage): Storage to sample from.
            batch_size (int): Number of task_ids to sample.
            query_condition (QueryCondition, optional): Query condition. Defaults to None.
        Returns:
            List[str]: List of task_ids.
        '''


class Converter(ABC):
    '''
    Convert data to a dataset row.
    '''

    @abstractmethod
    def to_dataset_row(self, task_experience: List[DataRow]) -> T:
        '''
        Convert task experience to a dataset row.
        Args:
            task_experience (List[DataRow]): List of task experience.
        Returns:
            T: type of dataset row.
        '''


class InMemoryStorage(Storage):
    '''
    In-memory storage for storing and sampling data.
    '''

    def __init__(self, max_capacity: int = 10000):
        self._data: Dict[str, List[DataRow]] = {}
        self._max_capacity = max_capacity
        self._fifo_queue = []  # (task_id)

    def add(self, data: DataRow):
        if not data:
            raise ValueError("Data is required")
        if not data.exp_meta:
            raise ValueError("exp_meta is required")

        while self.size() >= self._max_capacity and self._fifo_queue:
            oldest_task_id = self._fifo_queue.pop(0)
            if oldest_task_id in self._data:
                del self._data[oldest_task_id]

        if data.exp_meta.task_id not in self._data:
            self._data[data.exp_meta.task_id] = []
            self._fifo_queue.append(data.exp_meta.task_id)
        self._data[data.exp_meta.task_id].append(data)

    def add_batch(self, data_batch: List[DataRow]):
        for data in data_batch:
            self.add(data)

    def size(self, query_condition: QueryCondition = None) -> int:
        return len(self.get_all(query_condition))

    def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
        if page < 1:
            raise ValueError("Page must be greater than 0")
        if page_size < 1:
            raise ValueError("Page size must be greater than 0")
        all_data = self.get_all(query_condition)
        start_index = (page - 1) * page_size
        end_index = start_index + page_size
        return all_data[start_index:end_index]

    def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
        all_data = []
        query_filter = None
        if query_condition:
            query_filter = QueryFilter(query_condition)
        for data in self._data.values():
            if query_filter:
                all_data.extend(query_filter.filter(data))
            else:
                all_data.extend(data)
        return all_data

    def get_by_task_id(self, task_id: str) -> List[DataRow]:
        return self._data.get(task_id, [])

    def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
        return {task_id: self._data.get(task_id, []) for task_id in task_ids}

    def clear(self):
        self._data = {}
        self._fifo_queue = []


class RandomTaskSample(TaskSampler):
    '''
    Randomly sample data from the storage.
    '''

    def sample_task_ids(self,
                        storage: Storage,
                        batch_size: int,
                        query_condition: QueryCondition = None) -> List[str]:
        total_size = storage.size(query_condition)
        if total_size <= batch_size:
            return list({data.exp_meta.task_id for data in storage.get_all(query_condition)})

        sampled_task_ids = set()
        page_size = min(100, batch_size * 2)
        total_pages = ceil(total_size / page_size)
        visited_pages = set()
        while len(sampled_task_ids) < batch_size and len(visited_pages) < total_pages:
            page = random.choice(
                [p for p in range(1, total_pages + 1) if p not in visited_pages])
            visited_pages.add(page)

            current_page = storage.get_paginated(
                page, page_size, query_condition)
            if not current_page:
                continue
            current_page_task_ids = set(
                [data.exp_meta.task_id for data in current_page if data.exp_meta.task_id not in sampled_task_ids])
            sample_count = min(len(current_page_task_ids),
                               batch_size - len(sampled_task_ids))
            sampled_task_ids.update(random.sample(
                list(current_page_task_ids), sample_count))

        return list(sampled_task_ids)


class DefaultConverter(Converter):
    '''
    Default converter; returns the data unchanged.
    '''

    def to_dataset_row(self, task_experience: List[DataRow]) -> List[DataRow]:
        return task_experience


class ReplayBuffer:
    '''
    Replay buffer for storing and sampling data.
    '''

    def __init__(
        self,
        storage: Storage = InMemoryStorage()
    ):
        self._storage = storage

    def store(self, data: DataRow):
        '''
        Store data in the replay buffer.
        '''
        if not data:
            raise ValueError("Data is required")
        self._storage.add(data)

    def store_batch(self, data_batch: List[DataRow]):
        '''
        Store a batch of data in the replay buffer.
        '''
        if not data_batch:
            raise ValueError("Data batch is required")
        self._storage.add_batch(data_batch)

    def sample_task(self,
                    sampler: TaskSampler = RandomTaskSample(),
                    query_condition: QueryCondition = None,
                    converter: Converter = DefaultConverter(),
                    batch_size: int = 1000) -> List[T]:
        '''
        Sample tasks from the replay buffer and convert them to dataset rows.
        DefaultConverter returns List[DataRow].
        '''
        sampled_task = sampler.sample_tasks(
            self._storage, batch_size, query_condition)
        return [converter.to_dataset_row(task_experiences) for task_experiences in sampled_task.values()]

    def sample(self,
               sampler: Sampler = RandomTaskSample(),
               query_condition: QueryCondition = None,
               converter: Converter = DefaultConverter(),
               batch_size: int = 1000) -> List[T]:
        '''
        Sample data from the replay buffer and convert it to dataset rows.
        DefaultConverter returns List[DataRow].
        '''
        sampled_data = sampler.sample(
            self._storage, batch_size, query_condition)
        return converter.to_dataset_row(sampled_data)
aworld/replay_buffer/processor.py
ADDED
@@ -0,0 +1,190 @@
# coding: utf-8
"""
processor.py
Used to clean raw trace data into a standard storage structure for reinforcement learning training.
"""
import json
import os
import datetime
from typing import Any
import threading

from aworld.utils import import_package
from aworld.replay_buffer.base import DataRow, Experience, ExpMeta
from aworld.logs.util import logger
from aworld.utils.common import get_local_ip


class ReplayBufferExporter:
    def __init__(self):
        """Initialize ReplayBufferExporter instance"""
        self._file_locks = {}
        self._lock_dict_lock = threading.Lock()
        self._task_output_paths = {}

    def _get_file_lock(self, file_path):
        """Get the lock for the specified file"""
        with self._lock_dict_lock:
            if file_path not in self._file_locks:
                self._file_locks[file_path] = threading.Lock()
            return self._file_locks[file_path]

    def replay_buffer_exporter(self, spans: list[dict[str, Any]], output_dir: str):
        """
        Process spans, keeping only spans with the 'step_execution_' prefix, and group them by task_id into separate output files.

        Args:
            spans: span data list
            output_dir: output directory path
        """
        # Ensure output directory exists
        import_package("oss2")
        import oss2

        os.makedirs(output_dir, exist_ok=True)

        # Get OSS credentials from environment variables
        enable_oss_export = os.getenv("EXPORT_REPLAY_TRACE_TO_OSS", "false").lower() == "true"
        access_key_id = os.getenv('OSS_ACCESS_KEY_ID')
        access_key_secret = os.getenv('OSS_ACCESS_KEY_SECRET')
        endpoint = os.getenv('OSS_ENDPOINT')
        bucket_name = os.getenv('OSS_BUCKET_NAME')
        bucket = None

        if not all([access_key_id, access_key_secret, endpoint, bucket_name]):
            enable_oss_export = False
            logger.warn("Missing required OSS environment variables")
        else:
            try:
                # Initialize OSS client
                auth = oss2.Auth(access_key_id, access_key_secret)
                bucket = oss2.Bucket(auth, endpoint, bucket_name)
            except Exception as e:
                enable_oss_export = False
                logger.warn(f"Failed to initialize OSS client, endpoint: {endpoint}, bucket: {bucket_name}. Error: {str(e)}")

        # Group by task_id
        task_groups = {}

        for span_data in spans:
            # Only process spans with 'step_execution_' prefix
            if not span_data['name'].startswith('step_execution_'):
                continue

            attr = span_data.get('attributes', {})
            exp_id = attr.get('exp_id')
            task_id = attr.get('task_id', '')

            if not exp_id or not task_id:
                continue

            if task_id not in task_groups:
                task_groups[task_id] = {}

            if exp_id not in task_groups[task_id]:
                task_groups[task_id][exp_id] = {
                    'exp_meta': None,
                    'exp_data': None
                }

            # Process step_execution span
            task_name = attr.get('task_name', '')
            agent_id = attr.get('agent_id', '')
            step = attr.get('step', 0)
            execute_time = float(span_data.get('start_time', 0).split('.')[0].replace(' ', '').replace('-', '').replace(':', ''))

            observation = {}
            action = []
            messages = []
            pre_agent = None
            if 'observation' in attr:
                try:
                    observation = json.loads(attr['observation'])
                except Exception:
                    observation = attr['observation']

            if 'actions' in attr:
                try:
                    action = json.loads(attr['actions'])
                except Exception:
                    action = attr['actions']

            if 'messages' in attr:
                try:
                    messages = json.loads(attr['messages'])
                except Exception:
                    messages = attr['messages']

            pre_agent = attr.get('pre_agent', '')
            reward = attr.get('reward', 0.0)
            adv = attr.get('adv_t', 0.0)
            v = attr.get('v_t', 0.0)

            exp_meta = ExpMeta(task_id, task_name, agent_id, step, execute_time, pre_agent)
            exp_data = Experience(observation, action, reward, adv, v, messages)

            task_groups[task_id][exp_id]['exp_meta'] = exp_meta
            task_groups[task_id][exp_id]['exp_data'] = exp_data

        # Process data for each task_id
        for task_id, exp_groups in task_groups.items():
            # Merge data and generate the final Experience objects
            data_rows = []

            # Read existing data (if any)
            output_path = self._task_output_paths.get(task_id)
            if not output_path:
                timestamp = datetime.datetime.now().strftime("%Y%m%d")
                replay_dir = os.path.join(output_dir or "./trace_data", timestamp, get_local_ip(), "replays")
                replay_dataset_path = os.getenv("REPLAY_TRACE_DATASET_PATH", replay_dir)
                export_dir = os.path.abspath(replay_dataset_path)
                os.makedirs(export_dir, exist_ok=True)
                output_path = os.path.join(export_dir, f"task_replay_{task_id}.json")
                self._task_output_paths[task_id] = output_path

            # Use a file lock to protect read and write operations
            file_lock = self._get_file_lock(output_path)
            with file_lock:
                if os.path.exists(output_path):
                    try:
                        with open(output_path, 'r', encoding='utf-8') as f:
                            existing_data = json.load(f)
                            data_rows.extend([DataRow(
                                ExpMeta(**row['exp_meta']),
                                Experience(**row['exp_data']),
                                row['id']
                            ) for row in existing_data])
                    except Exception as e:
                        logger.warn(f"Failed to read existing file {output_path}: {str(e)}")

                # Add new data
                for exp_id, group in exp_groups.items():
                    if group['exp_meta'] and group['exp_data']:
                        row = DataRow(group['exp_meta'], group['exp_data'], exp_id)
                        data_rows.append(row)

                # Sort by execute_time
                data_rows.sort(key=lambda x: x.exp_meta.execute_time)

                # Export to json
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump([row.to_dict() for row in data_rows], f, ensure_ascii=False, indent=2)
                logger.info(f"Processing completed, exported {len(data_rows)} experiences to {output_path}")

                if enable_oss_export:
                    # Upload to OSS
                    try:
                        # Get the relative path
                        abs_path = os.path.abspath(output_path)
                        path_parts = abs_path.split(os.sep)
                        if len(path_parts) >= 4:
                            # Take the last 4 parts of the path
                            relative_path = os.sep.join(path_parts[-4:])
                            oss_key = relative_path
                        else:
                            oss_key = f"replay_buffer/{os.path.basename(output_path)}"
                        bucket.put_object_from_file(oss_key, output_path)
                        logger.info(f"Successfully uploaded {output_path} to OSS: {oss_key}")
                    except Exception as e:
                        logger.warn(f"Failed to upload {output_path} to OSS: {str(e)}")
aworld/replay_buffer/query_filter.py
ADDED
@@ -0,0 +1,228 @@
from typing import Any, List, TypeVar, Union, Literal, TypedDict, Dict

DataRow = TypeVar('DataRow')


class BaseCondition(TypedDict):
    field: str
    value: Any
    op: Literal[
        'eq', 'ne', 'gt', 'gte', 'lt', 'lte',
        'in', 'not_in', 'like', 'not_like',
        'is_null', 'is_not_null'
    ]


class LogicalCondition(TypedDict):
    and_: List['QueryCondition']
    or_: List['QueryCondition']


QueryCondition = Union[BaseCondition, LogicalCondition]


class QueryBuilder:
    '''
    Query builder for the replay buffer. Result example:
    {
        "and_": [
            {"field": "field1", "value": "value1", "op": "eq"},
            {"or_": [{"field": "field2", "value": "value2", "op": "eq"}, {"field": "field3", "value": "value3", "op": "eq"}]}
        ]
    }
    '''

    def __init__(self) -> None:
        self.conditions: List[Dict[str, Any]] = []
        self.logical_ops: List[str] = []

    def eq(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "eq"})
        return self

    def ne(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "ne"})
        return self

    def gt(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "gt"})
        return self

    def gte(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "gte"})
        return self

    def lt(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "lt"})
        return self

    def lte(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "lte"})
        return self

    def in_(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "in"})
        return self

    def not_in(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append(
            {"field": field, "value": value, "op": "not_in"})
        return self

    def like(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "like"})
        return self

    def not_like(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append(
            {"field": field, "value": value, "op": "not_like"})
        return self

    def is_null(self, field: str) -> 'QueryBuilder':
        self.conditions.append({"field": field, "op": "is_null"})
        return self

    def is_not_null(self, field: str) -> 'QueryBuilder':
        self.conditions.append({"field": field, "op": "is_not_null"})
        return self

    def and_(self) -> 'QueryBuilder':
        self.logical_ops.append("and_")
        return self

    def or_(self) -> 'QueryBuilder':
        self.logical_ops.append("or_")
        return self

    def nested(self, builder: 'QueryBuilder') -> 'QueryBuilder':
        self.conditions.append({"nested": builder.build()})
        return self

    def build(self) -> QueryCondition:
        conditions = self.conditions  # all conditions (including nested)
        operators = self.logical_ops

        # Validate condition and operator counts (n conditions need n-1 operators)
        if len(operators) != len(conditions) - 1:
            raise ValueError("Mismatch between condition and operator counts")

        # Use a stack to handle operator precedence (simplified version supporting and/or)
        stack: List[Union[Dict[str, Any], str]] = []

        for i, item in enumerate(conditions):
            if i == 0:
                # First element goes directly onto the stack (condition or nested)
                stack.append(item)
                continue

            # Pop the stack top as the left operand
            left = stack.pop()
            op = operators[i - 1]  # Current operator (and/or)
            right = item  # Right operand (current condition)

            # Build logical expression: {op: [left, right]}
            expr = {op: [left, right]}
            # Push the result back onto the stack for further operations
            stack.append(expr)

        # Process nested conditions (recursive unfolding)
        def process_nested(cond: Any) -> Any:
            if isinstance(cond, dict):
                if "nested" in cond:
                    # Recursively process sub-conditions
                    return process_nested(cond["nested"])
                # Recursively process child elements
                return {k: process_nested(v) for k, v in cond.items()}
            elif isinstance(cond, list):
                return [process_nested(item) for item in cond]
            return cond

        # Final result: only one element remains on the stack; return it after processing nested conditions
        result = stack[0] if stack else None
        return process_nested(result) if result else None


class QueryFilter:
    '''
    Query filter for the replay buffer.
    '''

    def __init__(self, query_condition: QueryCondition) -> None:
        self.query_condition = query_condition

    def _get_field_value(self, row: DataRow, field: str) -> Any:
        '''
        Get a field value from a row, following dotted paths.
        '''
        obj = row
        for part in field.split('.'):
            obj = getattr(obj, part, None)
            if obj is None:
                break
        return obj

    def _do_check(self, row: DataRow, condition: QueryCondition) -> bool:
        """
        Check whether the row matches the condition.
        """
        if condition is None:
            return True
        if "field" in condition and "op" in condition:
            field_val = self._get_field_value(row, condition["field"])
            op = condition["op"]
            target_val = condition.get("value")

            if op == "eq":
                return field_val == target_val
            if op == "ne":
                return field_val != target_val
            if op == "gt":
                return field_val > target_val
            if op == "gte":
                return field_val >= target_val
            if op == "lt":
                return field_val < target_val
            if op == "lte":
                return field_val <= target_val
            if op == "in":
                return field_val in target_val
            if op == "not_in":
                return field_val not in target_val
            if op == "like":
                return target_val in field_val
            if op == "not_like":
                return target_val not in field_val
            if op == "is_null":
                return field_val is None
            if op == "is_not_null":
                return field_val is not None

            return False

        elif "and_" in condition or "or_" in condition:
            if "and_" in condition:
                return all(self._do_check(row, c) for c in condition["and_"])
            if "or_" in condition:
                return any(self._do_check(row, c) for c in condition["or_"])
            return False

        return False

    def check_condition(self, row: DataRow) -> bool:
        """
        Check whether the row matches the filter's condition.
        """
        return self._do_check(row, self.query_condition)

    def filter(self, rows: List[DataRow]) -> List[DataRow]:
        """Filter rows by the condition.
        Args:
            rows (List[DataRow]): List of rows to filter.
        Returns:
            List[DataRow]: List of rows that match the condition.
        """
        condition = self.query_condition
        if not condition:
            return rows
        return [row for row in rows if self.check_condition(row)]