Upload 3 files
aworld/replay_buffer/storage/multi_proc_mem.py
ADDED
@@ -0,0 +1,155 @@
import multiprocessing
import multiprocessing.shared_memory  # SharedMemory lives in a submodule that must be imported explicitly
import traceback
import pickle
from typing import Dict, List

from aworld.replay_buffer.base import Storage, DataRow
from aworld.replay_buffer.query_filter import QueryCondition, QueryFilter
from aworld.logs.util import logger


class MultiProcMemoryStorage(Storage):
    """
    In-memory storage shared across processes.

    Rows are grouped by task_id; each task's rows are pickled into a
    multiprocessing shared-memory block whose name is kept in `data_dict`.
    """

    def __init__(self,
                 data_dict: Dict[str, str],
                 fifo_queue: List[str],
                 lock: multiprocessing.Lock,
                 max_capacity: int = 10000):
        self._data: Dict[str, str] = data_dict
        self._fifo_queue = fifo_queue
        self._max_capacity = max_capacity
        self._lock = lock

    def _save_to_shared_memory(self, data, task_id):
        serialized_data = pickle.dumps(data)
        try:
            if task_id not in self._data or not self._data[task_id]:
                shm = multiprocessing.shared_memory.SharedMemory(
                    create=True, size=len(serialized_data))
                shm.buf[:len(serialized_data)] = serialized_data
                self._data[task_id] = shm.name
                shm.close()
                return
            shm = multiprocessing.shared_memory.SharedMemory(
                name=self._data[task_id], create=False)
            if len(serialized_data) > shm.size:
                # Existing block is too small: replace it with a larger one.
                shm.close()
                shm.unlink()
                shm = multiprocessing.shared_memory.SharedMemory(
                    create=True, size=len(serialized_data))
                shm.buf[:len(serialized_data)] = serialized_data
                self._data[task_id] = shm.name
            else:
                shm.buf[:len(serialized_data)] = serialized_data
            shm.close()
        except FileNotFoundError:
            # The named block disappeared: recreate it from scratch.
            shm = multiprocessing.shared_memory.SharedMemory(
                create=True, size=len(serialized_data))
            shm.buf[:len(serialized_data)] = serialized_data
            self._data[task_id] = shm.name
            shm.close()

    def _load_from_shared_memory(self, task_id):
        try:
            if task_id not in self._data or not self._data[task_id]:
                return []
            try:
                multiprocessing.shared_memory.SharedMemory(
                    name=self._data[task_id], create=False)
            except FileNotFoundError:
                return []
            shm = multiprocessing.shared_memory.SharedMemory(
                name=self._data[task_id])
            data = pickle.loads(shm.buf.tobytes())
            shm.close()
            return data
        except Exception as e:
            stack_trace = traceback.format_exc()
            logger.error(
                f"_load_from_shared_memory error: {e}\nStack trace:\n{stack_trace}")
            return []

    def _delete_from_shared_memory(self, task_id):
        try:
            if task_id not in self._data or not self._data[task_id]:
                return
            shm = multiprocessing.shared_memory.SharedMemory(
                name=self._data[task_id])
            shm.close()
            shm.unlink()
            del self._data[task_id]
        except FileNotFoundError:
            pass

    def add(self, data: DataRow):
        if not data:
            raise ValueError("Data is required")
        if not data.exp_meta:
            raise ValueError("exp_meta is required")

        with self._lock:
            # Evict whole tasks in FIFO order until the total row count fits.
            current_size = sum(len(self._load_from_shared_memory(task_id))
                               for task_id in self._data.keys())
            while current_size >= self._max_capacity and self._fifo_queue:
                oldest_task_id = self._fifo_queue.pop(0)
                if oldest_task_id in self._data.keys():
                    current_size -= len(self._load_from_shared_memory(oldest_task_id))
                    self._delete_from_shared_memory(oldest_task_id)

            task_id = data.exp_meta.task_id
            existing_data = self._load_from_shared_memory(task_id)
            existing_data.append(data)
            self._save_to_shared_memory(existing_data, task_id)
            self._fifo_queue.append(task_id)

    def add_batch(self, data_batch: List[DataRow]):
        # add() acquires the lock per row; the multiprocessing lock is not
        # reentrant, so the batch loop must not hold it as well.
        for data in data_batch:
            self.add(data)

    def size(self, query_condition: QueryCondition = None) -> int:
        with self._lock:
            return len(self._get_all_without_lock(query_condition))

    def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
        with self._lock:
            if page < 1:
                raise ValueError("Page must be greater than 0")
            if page_size < 1:
                raise ValueError("Page size must be greater than 0")
            all_data = self._get_all_without_lock(query_condition)
            start_index = (page - 1) * page_size
            end_index = start_index + page_size
            return all_data[start_index:end_index]

    def _get_all_without_lock(self, query_condition: QueryCondition = None) -> List[DataRow]:
        all_data = []
        query_filter = None
        if query_condition:
            query_filter = QueryFilter(query_condition)
        for task_id in self._data.keys():
            local_data = self._load_from_shared_memory(task_id)
            if query_filter:
                all_data.extend(query_filter.filter(local_data))
            else:
                all_data.extend(local_data)
        return all_data

    def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
        with self._lock:
            return self._get_all_without_lock(query_condition)

    def get_by_task_id(self, task_id: str) -> List[DataRow]:
        with self._lock:
            if task_id in self._data.keys():
                return self._load_from_shared_memory(task_id)
            return []

    def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
        with self._lock:
            result = {}
            for task_id in task_ids:
                if task_id in self._data.keys():
                    result[task_id] = self._load_from_shared_memory(task_id)
            return result
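A minimal usage sketch (not part of the committed files) showing how this storage might be wired up with multiprocessing primitives. The Manager-backed dict/list/Lock and the DataRow/ExpMeta/Experience constructor fields are assumptions inferred from the other storages in this commit, not confirmed aworld APIs.

# Usage sketch: Manager proxies for data_dict/fifo_queue and the DataRow field
# names below are assumptions, not confirmed aworld APIs.
import multiprocessing

from aworld.core.common import Observation
from aworld.replay_buffer.base import DataRow, ExpMeta, Experience
from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage

if __name__ == "__main__":
    manager = multiprocessing.Manager()
    storage = MultiProcMemoryStorage(
        data_dict=manager.dict(),    # task_id -> shared-memory block name
        fifo_queue=manager.list(),   # task ids in insertion order, used for eviction
        lock=manager.Lock(),         # guards the dict/queue across processes
        max_capacity=1000)

    row = DataRow(
        id="row-1",
        exp_meta=ExpMeta(task_id="task-1", task_name="demo", agent_id="agent-1",
                         step=0, execute_time=0.0, pre_agent=None),
        exp_data=Experience(state=Observation(), actions=[]))  # assumes default-constructible Observation
    storage.add(row)
    print(storage.size(), len(storage.get_by_task_id("task-1")))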
aworld/replay_buffer/storage/odps.py
ADDED
@@ -0,0 +1,223 @@
import json
from pydantic import parse_obj_as
from typing import Any, List, Dict

from aworld.replay_buffer.base import Storage, DataRow, ExpMeta, Experience
from aworld.replay_buffer.query_filter import QueryCondition, QueryBuilder
from aworld.core.common import Observation, ActionModel
from aworld.logs.util import logger
from aworld.utils.import_package import import_package

import_package("odps")  # noqa
from odps import ODPS  # noqa
from odps.models.record import Record  # noqa


class OdpsSQLBuilder:
    '''Translates a QueryCondition into an ODPS SQL WHERE clause.

    Example:
        query_condition = QueryBuilder().eq("field1", "value1").and_().eq("field2", "value2")
        sql_builder = OdpsSQLBuilder(query_condition)
        sql = sql_builder.build_sql()
        print(sql)  # Output: "field1 = 'value1' AND field2 = 'value2'"
    '''

    def __init__(self, query_condition: QueryCondition):
        self.query_condition = query_condition

    def _build_condition(self, condition: QueryCondition) -> str:
        if condition is None:
            return ""

        if "field" in condition and "op" in condition:
            # Nested field paths ("exp_meta.task_id") map to flat column names.
            field = condition["field"].split('.')[-1]
            op = condition["op"]
            value = condition.get("value")

            if op == "eq":
                return f"{field} = {self._format_value(value)}"
            elif op == "ne":
                return f"{field} != {self._format_value(value)}"
            elif op == "gt":
                return f"{field} > {self._format_value(value)}"
            elif op == "gte":
                return f"{field} >= {self._format_value(value)}"
            elif op == "lt":
                return f"{field} < {self._format_value(value)}"
            elif op == "lte":
                return f"{field} <= {self._format_value(value)}"
            elif op == "in":
                return f"{field} IN ({self._format_value(value)})"
            elif op == "not_in":
                return f"{field} NOT IN ({self._format_value(value)})"
            elif op == "like":
                return f"{field} LIKE '{value}'"
            elif op == "not_like":
                return f"{field} NOT LIKE '{value}'"
            elif op == "is_null":
                return f"{field} IS NULL"
            elif op == "is_not_null":
                return f"{field} IS NOT NULL"

        elif "and_" in condition:
            return f"({' AND '.join(self._build_condition(c) for c in condition['and_'])})"
        elif "or_" in condition:
            return f"({' OR '.join(self._build_condition(c) for c in condition['or_'])})"

        return ""

    def _format_value(self, value: Any) -> str:
        if isinstance(value, str):
            return f"'{value}'"
        elif isinstance(value, (list, tuple)):
            return ", ".join(self._format_value(v) for v in value)
        return str(value)

    def build_sql(self) -> str:
        if not self.query_condition:
            return ""
        return self._build_condition(self.query_condition)


class OdpsStorage(Storage):
    '''
    Aliyun ODPS (MaxCompute) storage.

    Table schema:
        id: int
        task_id: string
        task_name: string
        agent_id: string
        step: int
        execute_time: string
        state: string
        actions: string
        reward_t: string
        adv_t: string
        v_t: string
    '''

    def __init__(self, table_name: str, project: str, endpoint: str, access_id: str, access_key: str, **kwargs):
        self.table_name = table_name
        self.project = project
        self.endpoint = endpoint
        self.access_id = access_id
        self.access_key = access_key
        self.kwargs = kwargs
        self._init_odps()

    def _init_odps(self):
        self.odps = ODPS(self.access_id, self.access_key,
                         self.project, self.endpoint)

    def _get_table(self):
        return self.odps.get_table(self.table_name)

    def _convert_row_to_record(self, row: DataRow) -> Record:
        table = self._get_table()
        record = table.new_record()
        record["id"] = row.id
        record["task_id"] = row.exp_meta.task_id
        record["task_name"] = row.exp_meta.task_name
        record["agent_id"] = row.exp_meta.agent_id
        record["step"] = row.exp_meta.step
        record["execute_time"] = row.exp_meta.execute_time
        if row.exp_data.state:
            record["state"] = row.exp_data.state.model_dump_json()
        if row.exp_data.actions:
            record["actions"] = "[" + ", ".join(action.model_dump_json()
                                                for action in row.exp_data.actions) + "]"
        if row.exp_data.reward_t:
            record["reward_t"] = row.exp_data.reward_t
        if row.exp_data.adv_t:
            record["adv_t"] = row.exp_data.adv_t
        if row.exp_data.v_t:
            record["v_t"] = row.exp_data.v_t
        return record

    def _convert_record_to_row(self, record: Record) -> DataRow:
        return DataRow(
            id=record['id'],
            exp_meta=ExpMeta(
                task_id=record['task_id'],
                task_name=record['task_name'],
                agent_id=record['agent_id'],
                step=record['step'],
                execute_time=record['execute_time'],
                pre_agent=record['pre_agent'] if 'pre_agent' in record else None
            ),
            exp_data=Experience(
                state=parse_obj_as(Observation, json.loads(record['state'])),
                actions=[parse_obj_as(ActionModel, item)
                         for item in json.loads(record['actions'])],
                reward_t=record['reward_t'] if 'reward_t' in record else None,
                adv_t=record['adv_t'] if 'adv_t' in record else None,
                v_t=record['v_t'] if 'v_t' in record else None,
            )
        )

    def _build_paginated_sql(self, page: int = None, page_size: int = None):
        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size
            return f" LIMIT {offset}, {limit}"
        return ""

    def _build_sql(self, query_condition: QueryCondition, page: int = None, page_size: int = None):
        if not query_condition:
            return f"SELECT * FROM {self.table_name}" + self._build_paginated_sql(page, page_size)
        where_builder = OdpsSQLBuilder(query_condition)
        sql = (f"SELECT * FROM {self.table_name} WHERE {where_builder.build_sql()}"
               + self._build_paginated_sql(page, page_size))
        return sql

    def _build_count_sql(self, query_condition: QueryCondition):
        if not query_condition:
            return f"SELECT count(1) as count FROM {self.table_name}"
        where_builder = OdpsSQLBuilder(query_condition)
        sql = f"SELECT count(1) as count FROM {self.table_name} WHERE {where_builder.build_sql()}"
        return sql

    def add(self, row: DataRow):
        record = self._convert_row_to_record(row)
        self.odps.write_table(self.table_name, [record])

    def add_batch(self, rows: List[DataRow]):
        records = [self._convert_row_to_record(row) for row in rows]
        self.odps.write_table(self.table_name, records)

    def size(self, query_condition: QueryCondition = None) -> int:
        sql = self._build_count_sql(query_condition)
        with self.odps.execute_sql(sql).open_reader() as reader:
            return reader[0]["count"]

    def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
        sql = self._build_sql(query_condition)
        logger.info(f"get_all sql: {sql}")
        with self.odps.execute_sql(sql).open_reader(tunnel=True) as reader:
            rows = []
            for record in reader:
                rows.append(self._convert_record_to_row(record))
            return rows

    def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
        sql = self._build_sql(query_condition, page, page_size)
        logger.info(f"get_paginated sql: {sql}")
        with self.odps.execute_sql(sql).open_reader(tunnel=True) as reader:
            rows = []
            for record in reader:
                rows.append(self._convert_record_to_row(record))
            return rows

    def get_by_task_id(self, task_id: str) -> List[DataRow]:
        query_condition = QueryBuilder().eq("task_id", task_id).build()
        return self.get_all(query_condition)

    def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
        query_condition = QueryBuilder().in_("task_id", task_ids).build()
        sql = self._build_sql(query_condition)
        logger.info(f"get_bacth_by_task_ids sql: {sql}")
        result = {}
        with self.odps.execute_sql(sql).open_reader(tunnel=True) as reader:
            for record in reader:
                row = self._convert_record_to_row(record)
                if row.exp_meta.task_id not in result:
                    result[row.exp_meta.task_id] = []
                result[row.exp_meta.task_id].append(row)
        return result
aworld/replay_buffer/storage/redis.py
ADDED
@@ -0,0 +1,262 @@
import json
from typing import Dict, List

from aworld.replay_buffer.base import Storage, DataRow, ExpMeta, Experience
from aworld.logs.util import logger
from aworld.utils.import_package import import_package
from aworld.replay_buffer.query_filter import QueryCondition, QueryBuilder
from aworld.core.common import Observation, ActionModel

import_package("redis")  # noqa
from redis import Redis  # noqa
from redis.commands.json.path import Path  # noqa
import redis.commands.search.aggregation as aggregations  # noqa
import redis.commands.search.reducers as reducers  # noqa
from redis.commands.search.field import TextField, NumericField, TagField  # noqa
from redis.commands.search.index_definition import IndexDefinition, IndexType  # noqa
from redis.commands.search.query import Query  # noqa
import redis.exceptions  # noqa


class RedisSearchQueryBuilder:
    """
    Build a RediSearch query from a query condition.
    """

    def __init__(self, query_condition: QueryCondition):
        self.query_condition = query_condition

    def _build_condition(self, condition: QueryCondition) -> str:
        if condition is None:
            return ""

        if "field" in condition and "op" in condition:
            field = condition["field"].split('.')[-1]
            op = condition["op"]
            value = condition.get("value")

            if op == "eq":
                return f"@{field}:{{{value}}}"
            elif op == "ne":
                return f"-@{field}:{{{value}}}"
            elif op == "gt":
                # "(" marks an exclusive bound in RediSearch numeric ranges.
                return f"@{field}:[({value} +inf]"
            elif op == "gte":
                return f"@{field}:[{value} +inf]"
            elif op == "lt":
                return f"@{field}:[-inf ({value}]"
            elif op == "lte":
                return f"@{field}:[-inf {value}]"
            elif op == "in":
                return f"@{field}:{{{'|'.join(str(v) for v in value)}}}"
            elif op == "not_in":
                return f"-@{field}:{{{'|'.join(str(v) for v in value)}}}"
            elif op == "like":
                return f"@{field}:*{value}*"
            elif op == "not_like":
                return f"-@{field}:*{value}*"
            elif op == "is_null":
                return f"-@{field}:*"
            elif op == "is_not_null":
                return f"@{field}:*"

        elif "and_" in condition:
            conditions = [self._build_condition(c) for c in condition["and_"]]
            return " ".join(conditions)
        elif "or_" in condition:
            conditions = [self._build_condition(c) for c in condition["or_"]]
            return f"({'|'.join(conditions)})"

        return ""

    def build(self) -> Query:
        query_str = self._build_condition(self.query_condition)
        logger.info(f"redis search query: {query_str}")
        return Query(query_str)


class RedisStorage(Storage):
    def __init__(self,
                 host: str = 'localhost',
                 port: int = 6379,
                 db: int = 0,
                 password: str = None,
                 key_prefix: str = 'AWORLD:RB:',
                 index_name: str = 'idx:AWORLD:RB',
                 recreate_idx_if_exists: bool = False):
        self._redis = Redis(host=host, port=port, db=db, password=password)
        self._key_prefix = key_prefix
        self._index_name = index_name
        self._recreate_idx_if_exists = recreate_idx_if_exists
        self._create_index()

    def _create_index(self):
        try:
            existing_indices = self._redis.execute_command('FT._LIST')
            if self._index_name.encode('utf-8') in existing_indices:
                logger.info(f"Index {self._index_name} already exists")
                if self._recreate_idx_if_exists:
                    self._redis.ft(self._index_name).dropindex()
                    logger.info(f"Index {self._index_name} dropped")
                else:
                    return
            self._redis.ft(self._index_name).create_index(
                (
                    TagField("id"),
                    TagField("task_id"),
                    TextField("task_name"),
                    TagField("agent_id"),
                    NumericField("step"),
                    NumericField("execute_time"),
                    TagField("pre_agent")
                ),
                definition=IndexDefinition(
                    prefix=[self._key_prefix], index_type=IndexType.HASH)
            )
        except redis.exceptions.ResponseError as e:
            logger.error(f"Create index {self._index_name} failed. {e}")

    def _get_object_key(self, key: str) -> str:
        return f"{self._key_prefix}{key}"

    def _serialize_to_str(self, value) -> str:
        # Redis hash values must be strings; None is stored as an empty string.
        if value is None:
            return ""
        return str(value)

    def _serialize(self, data: DataRow) -> Dict[str, str]:
        dict_data = {
            'id': data.id,
            'task_id': data.exp_meta.task_id,
            'task_name': data.exp_meta.task_name,
            'agent_id': data.exp_meta.agent_id,
            'step': data.exp_meta.step,
            'execute_time': data.exp_meta.execute_time,
            'pre_agent': data.exp_meta.pre_agent,
            'state': data.exp_data.state.model_dump_json(),
            'actions': "[" + ", ".join(action.model_dump_json()
                                       for action in data.exp_data.actions) + "]",
            'reward_t': data.exp_data.reward_t,
            'adv_t': data.exp_data.adv_t,
            'v_t': data.exp_data.v_t
        }
        return {k: self._serialize_to_str(v) for k, v in dict_data.items()}

    def _deserialize(self, data: Dict) -> DataRow:
        if not data:
            return None
        return DataRow(
            id=data.get('id'),
            exp_meta=ExpMeta(
                task_id=data.get('task_id'),
                task_name=data.get('task_name'),
                agent_id=data.get('agent_id'),
                step=int(data.get('step', 0)),
                execute_time=float(data.get('execute_time', 0)),
                pre_agent=data.get('pre_agent')
            ),
            exp_data=Experience(
                state=Observation.model_validate_json(data.get('state', '{}')),
                actions=[ActionModel.model_validate_json(json.dumps(action))
                         for action in json.loads(data.get('actions', '[]'))],
                reward_t=float(data.get('reward_t', 0)) if data.get('reward_t') != '' else None,
                adv_t=float(data.get('adv_t', 0)) if data.get('adv_t') != '' else None,
                v_t=float(data.get('v_t', 0)) if data.get('v_t') != '' else None
            )
        )

    def add(self, data: DataRow):
        key = self._get_object_key(data.id)
        self._redis.hset(key, mapping=self._serialize(data))

    def add_batch(self, data_batch: List[DataRow]):
        pipeline = self._redis.pipeline()
        for data in data_batch:
            if not data or not data.exp_meta:
                continue
            key = self._get_object_key(data.id)
            pipeline.hset(key, mapping=self._serialize(data))
        pipeline.execute()

    def search(self, key: str, value: str) -> DataRow:
        result = self._redis.ft(self._index_name).search(
            Query(f"@{key}:{{{value}}}"))
        logger.info(f"Search result: {result}")
        if result.docs:
            return self._deserialize(result.docs[0].__dict__)
        return None

    def size(self, query_condition: QueryCondition = None) -> int:
        '''
        Get the size of the storage.
        Returns:
            int: Size of the storage.
        '''
        if not query_condition:
            return int(self._redis.ft(self._index_name).info()['num_docs'])
        query_builder = RedisSearchQueryBuilder(query_condition)
        query = query_builder.build()
        return self._redis.ft(self._index_name).search(query).total

    def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
        '''
        Get paginated data from the storage.
        Args:
            page (int): Page number, starting from 1.
            page_size (int): Number of rows per page.
        Returns:
            List[DataRow]: List of data.
        '''
        offset = (page - 1) * page_size
        if not query_condition:
            result = self._redis.ft(self._index_name).search(
                Query("*").paging(offset, page_size))
        else:
            query_builder = RedisSearchQueryBuilder(query_condition)
            query = query_builder.build().paging(offset, page_size)
            result = self._redis.ft(self._index_name).search(query)
        return [self._deserialize(doc.__dict__) for doc in result.docs]

    def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
        '''
        Get all data from the storage.
        Returns:
            List[DataRow]: List of data.
        '''
        if not query_condition:
            result = self._redis.ft(self._index_name).search(Query("*"))
        else:
            query_builder = RedisSearchQueryBuilder(query_condition)
            query = query_builder.build()
            result = self._redis.ft(self._index_name).search(query)
        return [self._deserialize(doc.__dict__) for doc in result.docs]

    def get_by_task_id(self, task_id: str) -> List[DataRow]:
        '''
        Get data by task_id from the storage.
        Args:
            task_id (str): Task id.
        Returns:
            List[DataRow]: List of data.
        '''
        query_condition = QueryBuilder().eq("task_id", task_id).build()
        return self.get_all(query_condition)

    def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
        '''
        Get data by task_ids from the storage.
        Args:
            task_ids (List[str]): List of task ids.
        Returns:
            Dict[str, List[DataRow]]: Dict of task id and list of data.
        '''
        query_condition = QueryBuilder().in_("task_id", task_ids).build()
        result = self.get_all(query_condition)
        return {task_id: [data for data in result if data.exp_meta.task_id == task_id] for task_id in task_ids}

    def clear(self):
        '''
        Clear the storage.
        '''
        keys = self._redis.keys(f"{self._key_prefix}*")
        if keys:
            self._redis.delete(*keys)
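A hedged usage sketch for RedisStorage (not part of the commit): it assumes a reachable Redis instance with the RediSearch module loaded (e.g. Redis Stack) at the placeholder host/port, and that rows were already added via add/add_batch.

# Usage sketch: host/port are placeholders; the FT.* commands require RediSearch.
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.replay_buffer.storage.redis import RedisStorage

storage = RedisStorage(host="localhost", port=6379, recreate_idx_if_exists=True)

# RedisSearchQueryBuilder turns this into the tag query "@task_id:{task-1}".
condition = QueryBuilder().eq("task_id", "task-1").build()
print(storage.size(condition))

for row in storage.get_by_task_id("task-1"):
    print(row.exp_meta.step, row.exp_data.reward_t)

storage.clear()  # deletes every key under the AWORLD:RB: prefix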