Duibonduil committed
Commit 4b677a1 · verified · Parent: cdf0092

Upload 5 files

examples/replay_buffer/multi_processing.py ADDED
@@ -0,0 +1,98 @@
+ import time
+ import traceback
+ import multiprocessing
+ from aworld.core.common import ActionModel, Observation
+ from aworld.replay_buffer.base import ReplayBuffer, DataRow, ExpMeta, Experience
+ from aworld.replay_buffer.query_filter import QueryBuilder
+ from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage
+ from aworld.logs.util import logger
+
+
+ def write_processing(replay_buffer: ReplayBuffer, task_id: str):
+     """Store one experience row per second into the shared replay buffer."""
+     for i in range(10):
+         try:
+             data = DataRow(
+                 exp_meta=ExpMeta(
+                     task_id=task_id,
+                     task_name=task_id,
+                     agent_id=f"agent_{i+1}",
+                     step=i,
+                     execute_time=time.time()
+                 ),
+                 exp_data=Experience(state=Observation(),
+                                     actions=[ActionModel()])
+             )
+             replay_buffer.store(data)
+         except Exception as e:
+             stack_trace = traceback.format_exc()
+             logger.error(
+                 f"write_processing error: {e}\nStack trace:\n{stack_trace}")
+         time.sleep(1)
+
+
+ def read_processing_by_task(replay_buffer: ReplayBuffer, task_id: str):
+     """Poll the buffer once per second for rows belonging to one task."""
+     while True:
+         try:
+             query_condition = QueryBuilder().eq("exp_meta.task_id", task_id).build()
+             data = replay_buffer.sample_task(
+                 query_condition=query_condition, batch_size=2)
+             logger.info(f"read data of task[{task_id}]: {data}")
+         except Exception as e:
+             stack_trace = traceback.format_exc()
+             logger.error(
+                 f"read_processing_by_task error: {e}\nStack trace:\n{stack_trace}")
+         time.sleep(1)
+
+
+ def read_processing_by_agent(replay_buffer: ReplayBuffer, agent_id: str):
+     """Poll the buffer once per second for rows produced by one agent."""
+     while True:
+         try:
+             query_condition = QueryBuilder().eq("exp_meta.agent_id", agent_id).build()
+             data = replay_buffer.sample_task(
+                 query_condition=query_condition, batch_size=2)
+             logger.info(f"read data of agent[{agent_id}]: {data}")
+         except Exception as e:
+             stack_trace = traceback.format_exc()
+             logger.error(
+                 f"read_processing_by_agent error: {e}\nStack trace:\n{stack_trace}")
+         time.sleep(1)
+
+
+ if __name__ == "__main__":
+     multiprocessing.freeze_support()
+     multiprocessing.set_start_method('spawn')
+     manager = multiprocessing.Manager()
+
+     # Manager-owned dict/list/lock let every child process share one buffer.
+     replay_buffer = ReplayBuffer(storage=MultiProcMemoryStorage(
+         data_dict=manager.dict(),
+         fifo_queue=manager.list(),
+         lock=manager.Lock(),
+         max_capacity=10000
+     ))
+
+     processes = [
+         multiprocessing.Process(target=write_processing,
+                                 args=(replay_buffer, "task_1",)),
+         multiprocessing.Process(target=write_processing,
+                                 args=(replay_buffer, "task_2",)),
+         multiprocessing.Process(target=write_processing,
+                                 args=(replay_buffer, "task_3",)),
+         multiprocessing.Process(target=write_processing,
+                                 args=(replay_buffer, "task_4",)),
+         # multiprocessing.Process(
+         #     target=read_processing_by_task, args=(replay_buffer, "task_1",)),
+         multiprocessing.Process(
+             target=read_processing_by_agent, args=(replay_buffer, "agent_3",))
+     ]
+     for p in processes:
+         p.start()
+
+     try:
+         for p in processes:
+             p.join()
+     except KeyboardInterrupt:
+         for p in processes:
+             p.terminate()
+         for p in processes:
+             p.join()
+     finally:
+         logger.info("Processes terminated.")
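A quick way to exercise this example, assuming the aworld package is importable from your environment:

    python examples/replay_buffer/multi_processing.py

The four writers each store ten rows while the reader samples batches for agent_3 forever, so stop the run with Ctrl+C (handled by the KeyboardInterrupt branch above). The 'spawn' start method keeps behavior consistent across Linux, macOS, and Windows.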
examples/replay_buffer/query_builder.py ADDED
@@ -0,0 +1,96 @@
+ from aworld.replay_buffer.query_filter import QueryBuilder
+ from aworld.logs.util import logger
+
+
+ def example():
+     '''
+     expression: task_id = "123"
+     return:
+     {
+         'field': 'task_id',
+         'value': '123',
+         'op': 'eq'
+     }
+     '''
+     qb = QueryBuilder()
+     query = qb.eq("task_id", "123").build()
+     logger.info(query)
+
+
+ def example1():
+     '''
+     expression: (task_id = "123" and agent_id = "111") or (task_id = "456" and agent_id = "222")
+     return:
+     {
+         'or_': [{
+             'and_': [{
+                 'field': 'task_id',
+                 'value': '123',
+                 'op': 'eq'
+             }, {
+                 'field': 'agent_id',
+                 'value': '111',
+                 'op': 'eq'
+             }]
+         }, {
+             'and_': [{
+                 'field': 'task_id',
+                 'value': '456',
+                 'op': 'eq'
+             }, {
+                 'field': 'agent_id',
+                 'value': '222',
+                 'op': 'eq'
+             }]
+         }]
+     }
+     '''
+     qb = QueryBuilder()
+     query = (qb.eq("task_id", "123")
+              .and_()
+              .eq("agent_id", "111")
+              .or_()
+              .nested(QueryBuilder()
+                      .eq("task_id", "456")
+                      .and_()
+                      .eq("agent_id", "222"))
+              .build())
+     logger.info(query)
+
+
+ def example2():
+     '''
+     expression: task_id = "123" and (agent_id = "111" or agent_id = "222")
+     return:
+     {
+         'and_': [{
+             'field': 'task_id',
+             'value': '123',
+             'op': 'eq'
+         }, {
+             'or_': [{
+                 'field': 'agent_id',
+                 'value': '111',
+                 'op': 'eq'
+             }, {
+                 'field': 'agent_id',
+                 'value': '222',
+                 'op': 'eq'
+             }]
+         }]
+     }
+     '''
+     qb = QueryBuilder()
+     query = (qb.eq("task_id", "123")
+              .and_()
+              .nested(QueryBuilder()
+                      .eq("agent_id", "111")
+                      .or_()
+                      .eq("agent_id", "222"))
+              .build())
+     logger.info(query)
+
+
+ if __name__ == "__main__":
+     example()
+     example1()
+     example2()
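The docstrings above map each builder chain to the dict it produces. Note from example1 and example2 that a chained .and_() pair groups together before an .or_(), while any grouping on the other side of the or_ needs an explicit nested(...). QueryBuilder also offers comparison operators; storage_redis.py below chains .gt(...) the same way. A minimal sketch, assuming gt takes (field, value) like eq does:

    query = QueryBuilder().eq("task_id", "123").and_().gt("step", 3).build()
    logger.info(query)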
examples/replay_buffer/query_filter.py ADDED
@@ -0,0 +1,42 @@
+ import time
+ from aworld.replay_buffer.base import DataRow, ExpMeta, Experience
+ from aworld.core.common import ActionModel, Observation
+ from aworld.replay_buffer.query_filter import QueryBuilder, QueryFilter
+ from aworld.logs.util import logger
+
+
+ def test_filter():
+     row = DataRow(
+         exp_meta=ExpMeta(
+             task_id="task_1",
+             task_name="default_task_name",
+             agent_id="agent_1",
+             step=1,
+             execute_time=time.time(),
+         ),
+         exp_data=Experience(state=Observation(), actions=[ActionModel()])
+     )
+
+     # A matching task_id passes the filter.
+     query = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
+     filter1 = QueryFilter(query)
+     assert filter1.check_condition(row)
+
+     # A non-matching task_id fails.
+     query = QueryBuilder().eq("exp_meta.task_id", "task_2").build()
+     filter2 = QueryFilter(query)
+     assert not filter2.check_condition(row)
+
+     # and_ requires both conditions; the agent_id does not match.
+     query = QueryBuilder().eq("exp_meta.task_id", "task_1").and_().eq(
+         "exp_meta.agent_id", "agent_2").build()
+     filter3 = QueryFilter(query)
+     assert not filter3.check_condition(row)
+
+
+ if __name__ == "__main__":
+     test_filter()
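The dotted field paths ("exp_meta.task_id") are resolved against the nested attributes of the DataRow. One more case as a sketch, reusing the row above and only calls that already appear in these examples; the or_ query should pass because its second branch matches:

    query = (QueryBuilder().eq("exp_meta.agent_id", "agent_9")
             .or_().eq("exp_meta.task_id", "task_1").build())
    assert QueryFilter(query).check_condition(row)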
examples/replay_buffer/storage_odps.py ADDED
@@ -0,0 +1,70 @@
+ import time
+ from aworld.core.common import ActionModel, Observation
+ from aworld.replay_buffer.base import (
+     DataRow,
+     DefaultConverter,
+     ReplayBuffer,
+     ExpMeta,
+     Experience,
+     RandomTaskSample
+ )
+ from aworld.replay_buffer.query_filter import QueryBuilder
+ from aworld.logs.util import logger
+ from aworld.replay_buffer.storage.odps import OdpsStorage
+
+
+ # Fill in the ODPS endpoint and credentials for your own project before running.
+ buffer = ReplayBuffer(storage=OdpsStorage(
+     table_name="adm_aworld_replay_buffer",
+     project="alifin_jtest_dev",
+     endpoint="",
+     access_id="",
+     access_key=""
+ ))
+
+
+ def write_data():
+     rows = []
+     for t in range(5):
+         task_id = f"task_{t+1}"
+         for i in range(5):
+             agent_id = f"agent_{i+1}"
+             for j in range(5):
+                 step = j + 1
+                 execute_time = time.time() + j
+                 row = DataRow(
+                     exp_meta=ExpMeta(
+                         task_id=task_id,
+                         task_name="default_task_name",
+                         agent_id=agent_id,
+                         step=step,
+                         execute_time=execute_time,
+                         pre_agent="pre_agent_id"
+                     ),
+                     exp_data=Experience(state=Observation(),
+                                         actions=[ActionModel()])
+                 )
+                 rows.append(row)
+     buffer.store_batch(rows)
+
+
+ def read_data():
+     query = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
+     samples = buffer.sample_task(query_condition=query,
+                                  sampler=RandomTaskSample(),
+                                  converter=DefaultConverter(),
+                                  batch_size=1)
+     for data in samples:
+         logger.info(f"task_1 data: {data}")
+
+     query = QueryBuilder().eq("exp_meta.agent_id", "agent_5").build()
+     samples = buffer.sample_task(query_condition=query,
+                                  sampler=RandomTaskSample(),
+                                  converter=DefaultConverter(),
+                                  batch_size=2)
+     for data in samples:
+         logger.info(f"agent_5 data: {data}")
+
+
+ if __name__ == "__main__":
+     # write_data()
+     read_data()
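Note that endpoint, access_id, and access_key are deliberately left blank; supply your own ODPS connection details before running. write_data() populates the table with 125 rows (5 tasks x 5 agents x 5 steps) and is commented out in the main guard so that repeated runs only read; uncomment it for the first run against an empty table.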
examples/replay_buffer/storage_redis.py ADDED
@@ -0,0 +1,69 @@
+ import time
+ from aworld.replay_buffer.base import DataRow, ExpMeta, Experience
+ from aworld.replay_buffer.storage.redis import RedisStorage
+ from aworld.replay_buffer.query_filter import QueryBuilder
+ from aworld.core.common import Observation, ActionModel
+ from aworld.logs.util import logger
+
+ # Requires a running Redis server; an existing index is reused rather than rebuilt.
+ storage = RedisStorage(host="localhost", port=6379,
+                        recreate_idx_if_exists=False)
+
+
+ def generate_data_row() -> list[DataRow]:
+     rows: list[DataRow] = []
+     for t in range(5):
+         task_id = f"task_{t+1}"
+         for i in range(5):
+             agent_id = f"agent_{i+1}"
+             for j in range(5):
+                 step = j + 1
+                 execute_time = time.time() + j
+                 row = DataRow(
+                     exp_meta=ExpMeta(
+                         task_id=task_id,
+                         task_name="default_task_name",
+                         agent_id=agent_id,
+                         step=step,
+                         execute_time=execute_time,
+                         pre_agent="pre_agent_id"
+                     ),
+                     exp_data=Experience(state=Observation(),
+                                         actions=[ActionModel()])
+                 )
+                 rows.append(row)
+     return rows
+
+
+ def write_data():
+     storage.clear()
+     rows = generate_data_row()
+     storage.add_batch(rows)
+     logger.info(f"Added {len(rows)} rows to storage.")
+
+
+ def read_data():
+     # (task_1 and agent_1) or (task_4 and agent_3 and step > 4)
+     query_condition = (QueryBuilder()
+                        .eq("exp_meta.task_id", "task_1")
+                        .and_()
+                        .eq("exp_meta.agent_id", "agent_1")
+                        .or_()
+                        .nested(QueryBuilder()
+                                .eq("exp_meta.task_id", "task_4")
+                                .and_()
+                                .eq("exp_meta.agent_id", "agent_3")
+                                .and_()
+                                .gt("exp_meta.step", 4)).build())
+
+     rows = storage.get_all(query_condition)
+     for row in rows:
+         logger.info(row)
+
+     rows = storage.get_paginated(
+         page=2, page_size=2, query_condition=query_condition)
+     for row in rows:
+         logger.info(f"get_paginated: {row}")
+
+
+ if __name__ == "__main__":
+     # write_data()
+     read_data()
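This example expects a Redis server listening on localhost:6379. Because recreate_idx_if_exists=False keeps any existing index, uncomment write_data() in the main guard on the first run to clear the storage and load the 125 generated rows; later runs can read only. The query matches either (task_1, agent_1) rows or (task_4, agent_3) rows with step > 4, and get_paginated then fetches the second page of the same result set, two rows per page.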