Duibonduil committed on
Commit b7cf4ad · verified · 1 Parent(s): 2cc8fcb

Upload 3 files

aworld/replay_buffer/storage/multi_proc_mem.py ADDED
@@ -0,0 +1,155 @@
+ import multiprocessing
+ import multiprocessing.shared_memory
+ import traceback
+ import pickle
+ from typing import Dict, List
+ from aworld.replay_buffer.base import Storage, DataRow
+ from aworld.replay_buffer.query_filter import QueryCondition, QueryFilter
+ from aworld.logs.util import logger
+
+
+ class MultiProcMemoryStorage(Storage):
+     """
+     Memory storage shared across processes via multiprocessing shared memory.
+     """
+
+     def __init__(self,
+                  data_dict: Dict[str, str],
+                  fifo_queue: List[str],
+                  lock: multiprocessing.Lock,
+                  max_capacity: int = 10000):
+         self._data: Dict[str, str] = data_dict
+         self._fifo_queue = fifo_queue
+         self._max_capacity = max_capacity
+         self._lock = lock
+
+     def _save_to_shared_memory(self, data, task_id):
+         serialized_data = pickle.dumps(data)
+         try:
+             if task_id not in self._data or not self._data[task_id]:
+                 # No segment for this task yet: create one and remember its name.
+                 shm = multiprocessing.shared_memory.SharedMemory(
+                     create=True, size=len(serialized_data))
+                 shm.buf[:len(serialized_data)] = serialized_data
+                 self._data[task_id] = shm.name
+                 shm.close()
+                 return
+             shm = multiprocessing.shared_memory.SharedMemory(
+                 name=self._data[task_id], create=False)
+             if len(serialized_data) > shm.size:
+                 # Existing segment is too small: replace it with a larger one.
+                 shm.close()
+                 shm.unlink()
+                 shm = multiprocessing.shared_memory.SharedMemory(
+                     create=True, size=len(serialized_data))
+                 shm.buf[:len(serialized_data)] = serialized_data
+                 self._data[task_id] = shm.name
+             else:
+                 shm.buf[:len(serialized_data)] = serialized_data
+             shm.close()
+         except FileNotFoundError:
+             # Segment name is stale (already unlinked): create a fresh one.
+             shm = multiprocessing.shared_memory.SharedMemory(
+                 create=True, size=len(serialized_data))
+             shm.buf[:len(serialized_data)] = serialized_data
+             self._data[task_id] = shm.name
+             shm.close()
+
+     def _load_from_shared_memory(self, task_id):
+         try:
+             if task_id not in self._data or not self._data[task_id]:
+                 return []
+             try:
+                 multiprocessing.shared_memory.SharedMemory(
+                     name=self._data[task_id], create=False)
+             except FileNotFoundError:
+                 return []
+             shm = multiprocessing.shared_memory.SharedMemory(
+                 name=self._data[task_id])
+             data = pickle.loads(shm.buf.tobytes())
+             shm.close()
+             return data
+         except Exception as e:
+             stack_trace = traceback.format_exc()
+             logger.error(
+                 f"_load_from_shared_memory error: {e}\nStack trace:\n{stack_trace}")
+             return []
+
+     def _delete_from_shared_memory(self, task_id):
+         try:
+             if task_id not in self._data or not self._data[task_id]:
+                 return
+             shm = multiprocessing.shared_memory.SharedMemory(
+                 name=self._data[task_id])
+             shm.close()
+             shm.unlink()
+             del self._data[task_id]
+         except FileNotFoundError:
+             pass
+
+     def add(self, data: DataRow):
+         if not data:
+             raise ValueError("Data is required")
+         if not data.exp_meta:
+             raise ValueError("exp_meta is required")
+
+         with self._lock:
+             current_size = sum(len(self._load_from_shared_memory(task_id))
+                                for task_id in self._data.keys())
+             # Evict the oldest tasks (FIFO) until there is room for the new row.
+             while current_size >= self._max_capacity and self._fifo_queue:
+                 oldest_task_id = self._fifo_queue.pop(0)
+                 if oldest_task_id in self._data.keys():
+                     current_size -= len(self._load_from_shared_memory(oldest_task_id))
+                     self._delete_from_shared_memory(oldest_task_id)
+
+             task_id = data.exp_meta.task_id
+             existing_data = self._load_from_shared_memory(task_id)
+             existing_data.append(data)
+             self._save_to_shared_memory(existing_data, task_id)
+             self._fifo_queue.append(task_id)
+
+     def add_batch(self, data_batch: List[DataRow]):
+         # add() acquires the lock itself; taking it here as well would deadlock
+         # a non-reentrant multiprocessing.Lock.
+         for data in data_batch:
+             self.add(data)
+
+     def size(self, query_condition: QueryCondition = None) -> int:
+         with self._lock:
+             return len(self._get_all_without_lock(query_condition))
+
+     def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
+         with self._lock:
+             if page < 1:
+                 raise ValueError("Page must be greater than 0")
+             if page_size < 1:
+                 raise ValueError("Page size must be greater than 0")
+             all_data = self._get_all_without_lock(query_condition)
+             start_index = (page - 1) * page_size
+             end_index = start_index + page_size
+             return all_data[start_index:end_index]
+
+     def _get_all_without_lock(self, query_condition: QueryCondition = None) -> List[DataRow]:
+         all_data = []
+         query_filter = None
+         if query_condition:
+             query_filter = QueryFilter(query_condition)
+         for task_id in self._data.keys():
+             local_data = self._load_from_shared_memory(task_id)
+             if query_filter:
+                 all_data.extend(query_filter.filter(local_data))
+             else:
+                 all_data.extend(local_data)
+         return all_data
+
+     def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
+         with self._lock:
+             return self._get_all_without_lock(query_condition)
+
+     def get_by_task_id(self, task_id: str) -> List[DataRow]:
+         with self._lock:
+             if task_id in self._data.keys():
+                 return self._load_from_shared_memory(task_id)
+             return []
+
+     def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
+         with self._lock:
+             result = {}
+             for task_id in task_ids:
+                 if task_id in self._data.keys():
+                     result[task_id] = self._load_from_shared_memory(task_id)
+             return result
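Usage note (not part of the diff): the constructor expects objects that every process can see. A minimal wiring sketch, assuming the shared dict, FIFO list, and lock are meant to come from a multiprocessing.Manager; the worker function and process count are illustrative.

    from multiprocessing import Manager, Process
    from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage

    def worker(storage):
        # Every process reads the same task_id -> shared-memory-segment mapping.
        print(storage.size())

    if __name__ == "__main__":
        manager = Manager()
        storage = MultiProcMemoryStorage(
            data_dict=manager.dict(),    # shared task_id -> segment-name map
            fifo_queue=manager.list(),   # shared eviction order
            lock=manager.Lock(),         # guards capacity checks and writes
            max_capacity=10000,
        )
        procs = [Process(target=worker, args=(storage,)) for _ in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()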
aworld/replay_buffer/storage/odps.py ADDED
@@ -0,0 +1,223 @@
+ import json
+ from pydantic import parse_obj_as
+ from typing import Any, List, Dict
+ from aworld.replay_buffer.base import Storage, DataRow, ExpMeta, Experience
+ from aworld.replay_buffer.query_filter import QueryCondition, QueryBuilder
+ from aworld.core.common import Observation, ActionModel
+ from aworld.logs.util import logger
+ from aworld.utils.import_package import import_package
+ import_package("odps")  # noqa
+ from odps import ODPS  # noqa
+ from odps.models.record import Record  # noqa
+
+
+ class OdpsSQLBuilder:
+     ''' Example:
+     query_condition = QueryBuilder().eq("field1", "value1").and_().eq("field2", "value2")
+     sql_builder = OdpsSQLBuilder(query_condition)
+     sql = sql_builder.build_sql()
+     print(sql)  # Output: "field1 = 'value1' AND field2 = 'value2'"
+     '''
+
+     def __init__(self, query_condition: QueryCondition):
+         self.query_condition = query_condition
+
+     def _build_condition(self, condition: QueryCondition) -> str:
+         if condition is None:
+             return ""
+
+         if "field" in condition and "op" in condition:
+             field = condition["field"].split('.')[-1]
+             op = condition["op"]
+             value = condition.get("value")
+
+             if op == "eq":
+                 return f"{field} = {self._format_value(value)}"
+             elif op == "ne":
+                 return f"{field} != {self._format_value(value)}"
+             elif op == "gt":
+                 return f"{field} > {self._format_value(value)}"
+             elif op == "gte":
+                 return f"{field} >= {self._format_value(value)}"
+             elif op == "lt":
+                 return f"{field} < {self._format_value(value)}"
+             elif op == "lte":
+                 return f"{field} <= {self._format_value(value)}"
+             elif op == "in":
+                 return f"{field} IN ({self._format_value(value)})"
+             elif op == "not_in":
+                 return f"{field} NOT IN ({self._format_value(value)})"
+             elif op == "like":
+                 return f"{field} LIKE '{value}'"
+             elif op == "not_like":
+                 return f"{field} NOT LIKE '{value}'"
+             elif op == "is_null":
+                 return f"{field} IS NULL"
+             elif op == "is_not_null":
+                 return f"{field} IS NOT NULL"
+
+         elif "and_" in condition:
+             return f"({' AND '.join(self._build_condition(c) for c in condition['and_'])})"
+         elif "or_" in condition:
+             return f"({' OR '.join(self._build_condition(c) for c in condition['or_'])})"
+
+         return ""
+
+     def _format_value(self, value: Any) -> str:
+         if isinstance(value, str):
+             return f"'{value}'"
+         elif isinstance(value, (list, tuple)):
+             return ", ".join(self._format_value(v) for v in value)
+         return str(value)
+
+     def build_sql(self) -> str:
+         if not self.query_condition:
+             return ""
+         return self._build_condition(self.query_condition)
+
+
+ class OdpsStorage(Storage):
+     '''
+     Aliyun ODPS storage.
+     Table schema:
+         id: int
+         task_id: string
+         task_name: string
+         agent_id: string
+         step: int
+         execute_time: string
+         state: string
+         actions: string
+         reward_t: string
+         adv_t: string
+         v_t: string
+     '''
+
+     def __init__(self, table_name: str, project: str, endpoint: str, access_id: str, access_key: str, **kwargs):
+         self.table_name = table_name
+         self.project = project
+         self.endpoint = endpoint
+         self.access_id = access_id
+         self.access_key = access_key
+         self.kwargs = kwargs
+         self._init_odps()
+
+     def _init_odps(self):
+         self.odps = ODPS(self.access_id, self.access_key,
+                          self.project, self.endpoint)
+
+     def _get_table(self):
+         return self.odps.get_table(self.table_name)
+
+     def _convert_row_to_record(self, row: DataRow) -> Record:
+         table = self._get_table()
+         record = table.new_record()
+         record["id"] = row.id
+         record["task_id"] = row.exp_meta.task_id
+         record["task_name"] = row.exp_meta.task_name
+         record["agent_id"] = row.exp_meta.agent_id
+         record["step"] = row.exp_meta.step
+         record["execute_time"] = row.exp_meta.execute_time
+         if row.exp_data.state:
+             record["state"] = row.exp_data.state.model_dump_json()
+         if row.exp_data.actions:
+             record["actions"] = "[" + ", ".join(action.model_dump_json()
+                                                 for action in row.exp_data.actions) + "]"
+         # Compare against None so that zero-valued metrics are still written.
+         if row.exp_data.reward_t is not None:
+             record["reward_t"] = row.exp_data.reward_t
+         if row.exp_data.adv_t is not None:
+             record["adv_t"] = row.exp_data.adv_t
+         if row.exp_data.v_t is not None:
+             record["v_t"] = row.exp_data.v_t
+         return record
+
+     def _convert_record_to_row(self, record: Record) -> DataRow:
+         return DataRow(
+             id=record['id'],
+             exp_meta=ExpMeta(
+                 task_id=record['task_id'],
+                 task_name=record['task_name'],
+                 agent_id=record['agent_id'],
+                 step=record['step'],
+                 execute_time=record['execute_time'],
+                 pre_agent=record['pre_agent'] if 'pre_agent' in record else None
+             ),
+             exp_data=Experience(
+                 state=parse_obj_as(Observation, json.loads(record['state'])),
+                 actions=[parse_obj_as(ActionModel, item)
+                          for item in json.loads(record['actions'])],
+                 reward_t=record['reward_t'] if 'reward_t' in record else None,
+                 adv_t=record['adv_t'] if 'adv_t' in record else None,
+                 v_t=record['v_t'] if 'v_t' in record else None,
+             )
+         )
+
+     def _build_paginated_sql(self, page: int = None, page_size: int = None):
+         if page and page_size:
+             offset = (page - 1) * page_size
+             limit = page_size
+             return f" LIMIT {offset}, {limit}"
+         return ""
+
+     def _build_sql(self, query_condition: QueryCondition, page: int = None, page_size: int = None):
+         if not query_condition:
+             return f"SELECT * FROM {self.table_name}" + self._build_paginated_sql(page, page_size)
+         where_builder = OdpsSQLBuilder(query_condition)
+         sql = (f"SELECT * FROM {self.table_name} WHERE {where_builder.build_sql()}"
+                + self._build_paginated_sql(page, page_size))
+         return sql
+
+     def _build_count_sql(self, query_condition: QueryCondition):
+         if not query_condition:
+             return f"SELECT count(1) as count FROM {self.table_name}"
+         where_builder = OdpsSQLBuilder(query_condition)
+         sql = f"SELECT count(1) as count FROM {self.table_name} WHERE {where_builder.build_sql()}"
+         return sql
+
+     def add(self, row: DataRow):
+         record = self._convert_row_to_record(row)
+         self.odps.write_table(self.table_name, [record])
+
+     def add_batch(self, rows: List[DataRow]):
+         records = [self._convert_row_to_record(row) for row in rows]
+         self.odps.write_table(self.table_name, records)
+
+     def size(self, query_condition: QueryCondition = None) -> int:
+         sql = self._build_count_sql(query_condition)
+         with self.odps.execute_sql(sql).open_reader() as reader:
+             return reader[0]["count"]
+
+     def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
+         sql = self._build_sql(query_condition)
+         logger.info(f"get_all sql: {sql}")
+         with self.odps.execute_sql(sql).open_reader(tunnel=True) as reader:
+             rows = []
+             for record in reader:
+                 rows.append(self._convert_record_to_row(record))
+             return rows
+
+     def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
+         sql = self._build_sql(query_condition, page, page_size)
+         logger.info(f"get_paginated sql: {sql}")
+         with self.odps.execute_sql(sql).open_reader(tunnel=True) as reader:
+             rows = []
+             for record in reader:
+                 rows.append(self._convert_record_to_row(record))
+             return rows
+
+     def get_by_task_id(self, task_id: str) -> List[DataRow]:
+         query_condition = QueryBuilder().eq("task_id", task_id).build()
+         return self.get_all(query_condition)
+
+     def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
+         query_condition = QueryBuilder().in_("task_id", task_ids).build()
+         sql = self._build_sql(query_condition)
+         logger.info(f"get_bacth_by_task_ids sql: {sql}")
+         result = {}
+         with self.odps.execute_sql(sql).open_reader(tunnel=True) as reader:
+             for record in reader:
+                 row = self._convert_record_to_row(record)
+                 if row.exp_meta.task_id not in result:
+                     result[row.exp_meta.task_id] = []
+                 result[row.exp_meta.task_id].append(row)
+         return result
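Usage note (not part of the diff): a brief sketch of how OdpsStorage might be constructed and queried. The table name, project, endpoint, and credentials are placeholders, and the ODPS table is assumed to already exist with the schema listed in the class docstring.

    from aworld.replay_buffer.storage.odps import OdpsStorage
    from aworld.replay_buffer.query_filter import QueryBuilder

    # Placeholder project/endpoint/credentials -- substitute real values.
    storage = OdpsStorage(
        table_name="replay_buffer",
        project="my_project",
        endpoint="https://service.odps.aliyun.com/api",
        access_id="<ACCESS_ID>",
        access_key="<ACCESS_KEY>",
    )

    rows = storage.get_by_task_id("task_001")                                # SELECT ... WHERE task_id = 'task_001'
    count = storage.size(QueryBuilder().eq("agent_id", "agent_1").build())   # SELECT count(1) ... WHERE agent_id = 'agent_1'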
aworld/replay_buffer/storage/redis.py ADDED
@@ -0,0 +1,262 @@
+ import json
+ from typing import Dict, List
+ from aworld.replay_buffer.base import Storage, DataRow, ExpMeta, Experience
+ from aworld.logs.util import logger
+ from aworld.utils.import_package import import_package
+ from aworld.replay_buffer.query_filter import QueryCondition, QueryBuilder
+ from aworld.core.common import Observation, ActionModel
+ import_package("redis")  # noqa
+ from redis import Redis  # noqa
+ from redis.commands.json.path import Path  # noqa
+ import redis.commands.search.aggregation as aggregations  # noqa
+ import redis.commands.search.reducers as reducers  # noqa
+ from redis.commands.search.field import TextField, NumericField, TagField  # noqa
+ from redis.commands.search.index_definition import IndexDefinition, IndexType  # noqa
+ from redis.commands.search.query import Query  # noqa
+ import redis.exceptions  # noqa
+
+
+ class RedisSearchQueryBuilder:
+     """
+     Build a RediSearch query from a query condition.
+     """
+
+     def __init__(self, query_condition: QueryCondition):
+         self.query_condition = query_condition
+
+     def _build_condition(self, condition: QueryCondition) -> str:
+         if condition is None:
+             return ""
+
+         if "field" in condition and "op" in condition:
+             field = condition["field"].split('.')[-1]
+             op = condition["op"]
+             value = condition.get("value")
+
+             if op == "eq":
+                 return f"@{field}:{{{value}}}"
+             elif op == "ne":
+                 return f"-@{field}:{{{value}}}"
+             elif op == "gt":
+                 # "(" makes the bound exclusive in RediSearch numeric ranges.
+                 return f"@{field}:[({value} +inf]"
+             elif op == "gte":
+                 return f"@{field}:[{value} +inf]"
+             elif op == "lt":
+                 return f"@{field}:[-inf ({value}]"
+             elif op == "lte":
+                 return f"@{field}:[-inf {value}]"
+             elif op == "in":
+                 return f"@{field}:{{{'|'.join(str(v) for v in value)}}}"
+             elif op == "not_in":
+                 return f"-@{field}:{{{'|'.join(str(v) for v in value)}}}"
+             elif op == "like":
+                 return f"@{field}:*{value}*"
+             elif op == "not_like":
+                 return f"-@{field}:*{value}*"
+             elif op == "is_null":
+                 return f"-@{field}:*"
+             elif op == "is_not_null":
+                 return f"@{field}:*"
+
+         elif "and_" in condition:
+             conditions = [self._build_condition(c) for c in condition["and_"]]
+             return " ".join(conditions)
+         elif "or_" in condition:
+             conditions = [self._build_condition(c) for c in condition["or_"]]
+             return f"({'|'.join(conditions)})"
+
+         return ""
+
+     def build(self) -> Query:
+         query_str = self._build_condition(self.query_condition)
+         logger.info(f"redis search query: {query_str}")
+         return Query(query_str)
+
+
+ class RedisStorage(Storage):
+     def __init__(self,
+                  host: str = 'localhost',
+                  port: int = 6379,
+                  db: int = 0,
+                  password: str = None,
+                  key_prefix: str = 'AWORLD:RB:',
+                  index_name: str = 'idx:AWORLD:RB',
+                  recreate_idx_if_exists=False):
+         self._redis = Redis(host=host, port=port, db=db, password=password)
+         self._key_prefix = key_prefix
+         self._index_name = index_name
+         self._recreate_idx_if_exists = recreate_idx_if_exists
+         self._create_index()
+
+     def _create_index(self):
+         try:
+             existing_indices = self._redis.execute_command('FT._LIST')
+             if self._index_name.encode('utf-8') in existing_indices:
+                 logger.info(f"Index {self._index_name} already exists")
+                 if self._recreate_idx_if_exists:
+                     self._redis.ft(self._index_name).dropindex()
+                     logger.info(f"Index {self._index_name} dropped")
+                 else:
+                     return
+             self._redis.ft(self._index_name).create_index(
+                 (
+                     TagField("id"),
+                     TagField("task_id"),
+                     TextField("task_name"),
+                     TagField("agent_id"),
+                     NumericField("step"),
+                     NumericField("execute_time"),
+                     TagField("pre_agent")
+                 ),
+                 definition=IndexDefinition(
+                     prefix=[self._key_prefix], index_type=IndexType.HASH)
+             )
+         except redis.exceptions.ResponseError as e:
+             logger.error(f"Create index {self._index_name} failed. {e}")
+
+     def _get_object_key(self, key: str) -> str:
+         return f"{self._key_prefix}{key}"
+
+     def _serialize_to_str(self, value) -> str:
+         # None is stored as an empty string; everything else is stringified.
+         if value is None:
+             return ""
+         return str(value)
+
+     def _serialize(self, data: DataRow) -> Dict[str, str]:
+         dict_data = {
+             'id': data.id,
+             'task_id': data.exp_meta.task_id,
+             'task_name': data.exp_meta.task_name,
+             'agent_id': data.exp_meta.agent_id,
+             'step': data.exp_meta.step,
+             'execute_time': data.exp_meta.execute_time,
+             'pre_agent': data.exp_meta.pre_agent,
+             'state': data.exp_data.state.model_dump_json(),
+             'actions': "[" + ", ".join(action.model_dump_json()
+                                        for action in data.exp_data.actions) + "]",
+             'reward_t': data.exp_data.reward_t,
+             'adv_t': data.exp_data.adv_t,
+             'v_t': data.exp_data.v_t
+         }
+         return {k: self._serialize_to_str(v) for k, v in dict_data.items()}
+
+     def _deserialize(self, data: Dict) -> DataRow:
+         if not data:
+             return None
+         return DataRow(
+             id=data.get('id'),
+             exp_meta=ExpMeta(
+                 task_id=data.get('task_id'),
+                 task_name=data.get('task_name'),
+                 agent_id=data.get('agent_id'),
+                 step=int(data.get('step', 0)),
+                 execute_time=float(data.get('execute_time', 0)),
+                 pre_agent=data.get('pre_agent')
+             ),
+             exp_data=Experience(
+                 state=Observation.model_validate_json(data.get('state', '{}')),
+                 actions=[ActionModel.model_validate_json(json.dumps(action))
+                          for action in json.loads(data.get('actions', '[]'))],
+                 # Empty strings mark values that were None at serialization time.
+                 reward_t=float(data.get('reward_t', 0)) if data.get('reward_t') != '' else None,
+                 adv_t=float(data.get('adv_t', 0)) if data.get('adv_t') != '' else None,
+                 v_t=float(data.get('v_t', 0)) if data.get('v_t') != '' else None
+             )
+         )
+
+     def add(self, data: DataRow):
+         key = self._get_object_key(data.id)
+         self._redis.hset(key, mapping=self._serialize(data))
+
+     def add_batch(self, data_batch: List[DataRow]):
+         pipeline = self._redis.pipeline()
+         for data in data_batch:
+             if not data or not data.exp_meta:
+                 continue
+             key = self._get_object_key(data.id)
+             pipeline.hset(key, mapping=self._serialize(data))
+         pipeline.execute()
+
+     def search(self, key: str, value: str) -> DataRow:
+         result = self._redis.ft(self._index_name).search(
+             Query(f"@{key}:{{{value}}}"))
+         logger.info(f"Search result: {result}")
+         # Return the first matching row, if any.
+         return self._deserialize(result.docs[0].__dict__) if result.docs else None
+
+     def size(self, query_condition: QueryCondition = None) -> int:
+         '''
+         Get the size of the storage.
+         Returns:
+             int: Size of the storage.
+         '''
+         if not query_condition:
+             return int(self._redis.ft(self._index_name).info()['num_docs'])
+         query_builder = RedisSearchQueryBuilder(query_condition)
+         query = query_builder.build()
+         return self._redis.ft(self._index_name).search(query).total
+
+     def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
+         '''
+         Get paginated data from the storage.
+         Args:
+             page (int): Page number, starting at 1.
+             page_size (int): Number of data rows per page.
+         Returns:
+             List[DataRow]: List of data.
+         '''
+         # Query.paging() expects an offset, not a page number.
+         offset = (page - 1) * page_size
+         if not query_condition:
+             result = self._redis.ft(self._index_name).search(
+                 Query("*").paging(offset, page_size))
+         else:
+             query_builder = RedisSearchQueryBuilder(query_condition)
+             query = query_builder.build().paging(offset, page_size)
+             result = self._redis.ft(self._index_name).search(query)
+         return [self._deserialize(doc.__dict__) for doc in result.docs]
+
+     def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
+         '''
+         Get all data from the storage.
+         Returns:
+             List[DataRow]: List of data.
+         '''
+         if not query_condition:
+             result = self._redis.ft(self._index_name).search(Query("*"))
+         else:
+             query_builder = RedisSearchQueryBuilder(query_condition)
+             query = query_builder.build()
+             result = self._redis.ft(self._index_name).search(query)
+         return [self._deserialize(doc.__dict__) for doc in result.docs]
+
+     def get_by_task_id(self, task_id: str) -> List[DataRow]:
+         '''
+         Get data by task_id from the storage.
+         Args:
+             task_id (str): Task id.
+         Returns:
+             List[DataRow]: List of data.
+         '''
+         query_condition = QueryBuilder().eq("task_id", task_id).build()
+         return self.get_all(query_condition)
+
+     def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
+         '''
+         Get data by task_ids from the storage.
+         Args:
+             task_ids (List[str]): List of task ids.
+         Returns:
+             Dict[str, List[DataRow]]: Dict of task id and list of data.
+         '''
+         query_condition = QueryBuilder().in_("task_id", task_ids).build()
+         result = self.get_all(query_condition)
+         return {task_id: [data for data in result if data.exp_meta.task_id == task_id] for task_id in task_ids}
+
+     def clear(self):
+         '''
+         Clear the storage.
+         '''
+         keys = self._redis.keys(f"{self._key_prefix}*")
+         if keys:
+             self._redis.delete(*keys)
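Usage note (not part of the diff): a short sketch of RedisStorage. It assumes a Redis server with the RediSearch module loaded, since the index is created through FT commands; the host, task id, and agent id below are made up.

    from aworld.replay_buffer.storage.redis import RedisStorage
    from aworld.replay_buffer.query_filter import QueryBuilder

    # Connects to Redis and creates the search index if it does not exist yet.
    storage = RedisStorage(host="localhost", port=6379)

    rows_for_task = storage.get_by_task_id("task_001")
    first_page = storage.get_paginated(page=1, page_size=20)
    by_agent = storage.get_all(QueryBuilder().eq("agent_id", "agent_1").build())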