Upload 5 files
examples/replay_buffer/multi_processing.py
ADDED
@@ -0,0 +1,98 @@
import time
import traceback
import multiprocessing
from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.base import ReplayBuffer, DataRow, ExpMeta, Experience
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage
from aworld.logs.util import logger


def write_processing(replay_buffer: ReplayBuffer, task_id: str):
    for i in range(10):
        try:
            data = DataRow(
                exp_meta=ExpMeta(
                    task_id=task_id,
                    task_name=task_id,
                    agent_id=f"agent_{i+1}",
                    step=i,
                    execute_time=time.time()
                ),
                exp_data=Experience(state=Observation(),
                                    actions=[ActionModel()])
            )
            replay_buffer.store(data)
        except Exception as e:
            stack_trace = traceback.format_exc()
            logger.error(
                f"write_processing error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)


def read_processing_by_task(replay_buffer: ReplayBuffer, task_id: str):
    while True:
        try:
            query_condition = QueryBuilder().eq("exp_meta.task_id", task_id).build()
            data = replay_buffer.sample_task(
                query_condition=query_condition, batch_size=2)
            logger.info(f"read data of task[{task_id}]: {data}")
        except Exception as e:
            stack_trace = traceback.format_exc()
            logger.error(
                f"read_processing_by_task error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)


def read_processing_by_agent(replay_buffer: ReplayBuffer, agent_id: str):
    while True:
        try:
            query_condition = QueryBuilder().eq("exp_meta.agent_id", agent_id).build()
            data = replay_buffer.sample_task(
                query_condition=query_condition, batch_size=2)
            logger.info(f"read data of agent[{agent_id}]: {data}")
        except Exception as e:
            stack_trace = traceback.format_exc()
            logger.error(
                f"read_processing_by_agent error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)


if __name__ == "__main__":
    multiprocessing.freeze_support()
    multiprocessing.set_start_method('spawn')
    manager = multiprocessing.Manager()

    replay_buffer = ReplayBuffer(storage=MultiProcMemoryStorage(
        data_dict=manager.dict(),
        fifo_queue=manager.list(),
        lock=manager.Lock(),
        max_capacity=10000
    ))

    processes = [
        multiprocessing.Process(target=write_processing,
                                args=(replay_buffer, "task_1",)),
        multiprocessing.Process(target=write_processing,
                                args=(replay_buffer, "task_2",)),
        multiprocessing.Process(target=write_processing,
                                args=(replay_buffer, "task_3",)),
        multiprocessing.Process(target=write_processing,
                                args=(replay_buffer, "task_4",)),
        # multiprocessing.Process(
        #     target=read_processing_by_task, args=(replay_buffer, "task_1",)),
        multiprocessing.Process(
            target=read_processing_by_agent, args=(replay_buffer, "agent_3",))
    ]
    for p in processes:
        p.start()

    try:
        for p in processes:
            p.join()
    except KeyboardInterrupt:
        for p in processes:
            p.terminate()
        for p in processes:
            p.join()
    finally:
        logger.info("Processes terminated.")
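
# A hypothetical variant, not part of the upload: the readers above each filter
# on a single field, but QueryBuilder's and_() chaining (used in the other
# example files) lets one reader combine both dimensions. A minimal sketch,
# assuming the same buffer; it would sit alongside the readers above.
def read_processing_by_task_and_agent(replay_buffer: ReplayBuffer,
                                      task_id: str, agent_id: str):
    while True:
        try:
            # exp_meta.task_id == task_id AND exp_meta.agent_id == agent_id
            query_condition = (QueryBuilder()
                               .eq("exp_meta.task_id", task_id)
                               .and_()
                               .eq("exp_meta.agent_id", agent_id)
                               .build())
            data = replay_buffer.sample_task(
                query_condition=query_condition, batch_size=2)
            logger.info(f"read data of task[{task_id}]/agent[{agent_id}]: {data}")
        except Exception as e:
            logger.error(
                f"read_processing_by_task_and_agent error: {e}\n{traceback.format_exc()}")
        time.sleep(1)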
examples/replay_buffer/query_builder.py
ADDED
@@ -0,0 +1,96 @@
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.logs.util import logger


def example():
    '''
    expression: task_id = "123"
    return:
    {
        'field': 'task_id',
        'value': '123',
        'op': 'eq'
    }
    '''
    qb = QueryBuilder()
    query = qb.eq("task_id", "123").build()
    logger.info(query)


def example1():
    '''
    expression: (task_id = "123" and agent_id = "111") or (task_id = "456" and agent_id = "222")
    return:
    {
        'or_': [{
            'and_': [{
                'field': 'task_id',
                'value': '123',
                'op': 'eq'
            }, {
                'field': 'agent_id',
                'value': '111',
                'op': 'eq'
            }]
        }, {
            'and_': [{
                'field': 'task_id',
                'value': '456',
                'op': 'eq'
            }, {
                'field': 'agent_id',
                'value': '222',
                'op': 'eq'
            }]
        }]
    }
    '''
    qb = QueryBuilder()
    query = (qb.eq("task_id", "123")
             .and_()
             .eq("agent_id", "111")
             .or_()
             .nested(QueryBuilder()
                     .eq("task_id", "456")
                     .and_()
                     .eq("agent_id", "222"))
             .build())
    logger.info(query)


def example2():
    '''
    expression: task_id = "123" and (agent_id = "111" or agent_id = "222")
    return:
    {
        'and_': [{
            'field': 'task_id',
            'value': '123',
            'op': 'eq'
        }, {
            'or_': [{
                'field': 'agent_id',
                'value': '111',
                'op': 'eq'
            }, {
                'field': 'agent_id',
                'value': '222',
                'op': 'eq'
            }]
        }]
    }
    '''
    qb = QueryBuilder()
    query = (qb.eq("task_id", "123")
             .and_()
             .nested(QueryBuilder()
                     .eq("agent_id", "111")
                     .or_()
                     .eq("agent_id", "222"))
             .build())
    logger.info(query)


if __name__ == "__main__":
    example()
    example1()
    example2()
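
# A hypothetical addition for illustration: the dicts these builders produce
# are consumed by QueryFilter (see query_filter.py below). A minimal sketch of
# evaluating example2's condition against a row directly, assuming the
# DataRow/ExpMeta/Experience constructors used elsewhere in these examples.
import time
from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.base import DataRow, ExpMeta, Experience
from aworld.replay_buffer.query_filter import QueryFilter


def example3():
    row = DataRow(
        exp_meta=ExpMeta(task_id="123", task_name="demo_task", agent_id="111",
                         step=1, execute_time=time.time()),
        exp_data=Experience(state=Observation(), actions=[ActionModel()])
    )
    # task_id = "123" and (agent_id = "111" or agent_id = "222"), written with
    # the exp_meta.-prefixed field paths that rows are matched on.
    query = (QueryBuilder()
             .eq("exp_meta.task_id", "123")
             .and_()
             .nested(QueryBuilder()
                     .eq("exp_meta.agent_id", "111")
                     .or_()
                     .eq("exp_meta.agent_id", "222"))
             .build())
    assert QueryFilter(query).check_condition(row)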
examples/replay_buffer/query_filter.py
ADDED
@@ -0,0 +1,42 @@
import time
from aworld.replay_buffer.base import DataRow, ExpMeta, Experience
from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.query_filter import QueryBuilder, QueryFilter


def test_filter():
    row = DataRow(
        exp_meta=ExpMeta(
            task_id="task_1",
            task_name="default_task_name",
            agent_id="agent_1",
            step=1,
            execute_time=time.time(),
        ),
        exp_data=Experience(state=Observation(), actions=[ActionModel()])
    )

    query = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
    filter1 = QueryFilter(query)
    assert filter1.check_condition(row)

    query = QueryBuilder().eq("exp_meta.task_id", "task_2").build()
    filter2 = QueryFilter(query)
    assert not filter2.check_condition(row)

    query = QueryBuilder().eq("exp_meta.task_id", "task_1").and_().eq(
        "exp_meta.agent_id", "agent_2").build()
    filter3 = QueryFilter(query)
    assert not filter3.check_condition(row)


if __name__ == "__main__":
    test_filter()
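
# A hypothetical extension, not in the upload: QueryBuilder also supports
# numeric comparison via gt(), which storage_redis.py below uses. A minimal
# sketch of a step-threshold check against the same kind of row fixture.
def test_filter_gt():
    row = DataRow(
        exp_meta=ExpMeta(task_id="task_1", task_name="default_task_name",
                         agent_id="agent_1", step=1, execute_time=time.time()),
        exp_data=Experience(state=Observation(), actions=[ActionModel()])
    )

    # step > 0 holds for step=1 ...
    query = QueryBuilder().gt("exp_meta.step", 0).build()
    assert QueryFilter(query).check_condition(row)

    # ... while step > 4 does not.
    query = QueryBuilder().gt("exp_meta.step", 4).build()
    assert not QueryFilter(query).check_condition(row)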
examples/replay_buffer/storage_odps.py
ADDED
@@ -0,0 +1,70 @@
import time
from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.base import (
    DataRow,
    DefaultConverter,
    ReplayBuffer,
    ExpMeta,
    Experience,
    RandomTaskSample
)
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.logs.util import logger
from aworld.replay_buffer.storage.odps import OdpsStorage


buffer = ReplayBuffer(storage=OdpsStorage(
    table_name="adm_aworld_replay_buffer",
    project="alifin_jtest_dev",
    endpoint="",
    access_id="",
    access_key=""
))


def write_data():
    rows = []
    for id in range(5):
        task_id = f"task_{id+1}"
        for i in range(5):
            agent_id = f"agent_{i+1}"
            for j in range(5):
                step = j + 1
                execute_time = time.time() + j
                row = DataRow(
                    exp_meta=ExpMeta(
                        task_id=task_id,
                        task_name="default_task_name",
                        agent_id=agent_id,
                        step=step,
                        execute_time=execute_time,
                        pre_agent="pre_agent_id"
                    ),
                    exp_data=Experience(state=Observation(),
                                        actions=[ActionModel()])
                )
                rows.append(row)
    buffer.store_batch(rows)


def read_data():
    query = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
    datas = buffer.sample_task(query_condition=query,
                               sampler=RandomTaskSample(),
                               converter=DefaultConverter(),
                               batch_size=1)
    for data in datas:
        logger.info(f"task_1 data: {data}")

    query = QueryBuilder().eq("exp_meta.agent_id", "agent_5").build()
    datas = buffer.sample_task(query_condition=query,
                               sampler=RandomTaskSample(),
                               converter=DefaultConverter(),
                               batch_size=2)
    for data in datas:
        logger.info(f"agent_5 data: {data}")


if __name__ == "__main__":
    # write_data()
    read_data()
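
# A hypothetical helper for illustration: conditions can combine fields here
# too. A minimal sketch of sampling one task's rows for a single agent with
# the same sampler/converter as read_data(). Note the ODPS endpoint and
# credentials above are intentionally left blank and must be supplied before
# any of this runs.
def read_task_agent_data():
    query = (QueryBuilder()
             .eq("exp_meta.task_id", "task_1")
             .and_()
             .eq("exp_meta.agent_id", "agent_5")
             .build())
    datas = buffer.sample_task(query_condition=query,
                               sampler=RandomTaskSample(),
                               converter=DefaultConverter(),
                               batch_size=1)
    for data in datas:
        logger.info(f"task_1/agent_5 data: {data}")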
examples/replay_buffer/storage_redis.py
ADDED
@@ -0,0 +1,69 @@
import time
from aworld.replay_buffer.base import DataRow, ExpMeta, Experience
from aworld.replay_buffer.storage.redis import RedisStorage
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.core.common import Observation, ActionModel
from aworld.logs.util import logger

storage = RedisStorage(host="localhost", port=6379,
                       recreate_idx_if_exists=False)


def generate_data_row() -> list[DataRow]:
    rows: list[DataRow] = []
    for id in range(5):
        task_id = f"task_{id+1}"
        for i in range(5):
            agent_id = f"agent_{i+1}"
            for j in range(5):
                step = j + 1
                execute_time = time.time() + j
                row = DataRow(
                    exp_meta=ExpMeta(
                        task_id=task_id,
                        task_name="default_task_name",
                        agent_id=agent_id,
                        step=step,
                        execute_time=execute_time,
                        pre_agent="pre_agent_id"
                    ),
                    exp_data=Experience(state=Observation(),
                                        actions=[ActionModel()])
                )
                rows.append(row)
    return rows


def write_data():
    storage.clear()
    rows = generate_data_row()
    storage.add_batch(rows)
    logger.info(f"Add {len(rows)} rows to storage.")


def read_data():
    query_condition = (QueryBuilder()
                       .eq("exp_meta.task_id", "task_1")
                       .and_()
                       .eq("exp_meta.agent_id", "agent_1")
                       .or_()
                       .nested(QueryBuilder()
                               .eq("exp_meta.task_id", "task_4")
                               .and_()
                               .eq("exp_meta.agent_id", "agent_3")
                               .and_()
                               .gt("exp_meta.step", 4)).build())

    rows = storage.get_all(query_condition)
    for row in rows:
        logger.info(row)

    rows = storage.get_paginated(
        page=2, page_size=2, query_condition=query_condition)
    for row in rows:
        logger.info(f"get_paginated: {row}")


if __name__ == "__main__":
    # write_data()
    read_data()
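
# A hypothetical helper, not part of the upload: get_paginated returns one
# fixed-size page per call, so a full scan just walks pages until one comes
# back empty. A minimal sketch against the same storage and query style.
def read_all_pages(page_size: int = 10):
    query_condition = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
    page = 1
    while True:
        rows = storage.get_paginated(
            page=page, page_size=page_size, query_condition=query_condition)
        if not rows:
            break
        for row in rows:
            logger.info(f"page {page}: {row}")
        page += 1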