Duibonduil commited on
Commit
912a768
·
verified ·
1 Parent(s): 5fc6c27

Upload buffer.py

Browse files
Files changed (1) hide show
  1. examples/buffer.py +59 -0
examples/buffer.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from aworld.core.common import ActionModel, Observation
3
+ from aworld.replay_buffer.base import (
4
+ DataRow,
5
+ DefaultConverter,
6
+ ReplayBuffer,
7
+ ExpMeta,
8
+ Experience,
9
+ RandomTaskSample
10
+ )
11
+ from aworld.replay_buffer.query_filter import QueryBuilder
12
+ from aworld.logs.util import logger
13
+
14
+
15
+ buffer = ReplayBuffer()
16
+
17
+
18
+ def write_data():
19
+ for task_id in range(5):
20
+ for i in range(10):
21
+ task_id = f"task_{task_id}"
22
+ agent_id = f"agent_{i+1}"
23
+ step = i + 1
24
+ execute_time = time.time() + i
25
+ row = DataRow(
26
+ exp_meta=ExpMeta(
27
+ task_id=task_id,
28
+ task_name="default_task_name",
29
+ agent_id=agent_id,
30
+ step=step,
31
+ execute_time=execute_time,
32
+ ),
33
+ exp_data=Experience(state=Observation(),
34
+ actions=[ActionModel()])
35
+ )
36
+ buffer.store(row)
37
+
38
+
39
+ def read_data():
40
+ query = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
41
+ datas = buffer.sample_task(query_condition=query,
42
+ sampler=RandomTaskSample(),
43
+ converter=DefaultConverter(),
44
+ batch_size=2)
45
+ for data in datas:
46
+ logger.info(f"task_1 data: {data}")
47
+
48
+ query = QueryBuilder().eq("exp_meta.agent_id", "agent_5").build()
49
+ datas = buffer.sample_task(query_condition=query,
50
+ sampler=RandomTaskSample(),
51
+ converter=DefaultConverter(),
52
+ batch_size=2)
53
+ for data in datas:
54
+ logger.info(f"agent_5 data: {data}")
55
+
56
+
57
+ if __name__ == "__main__":
58
+ write_data()
59
+ read_data()