Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files
examples/replay_buffer/multi_processing.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import traceback
|
| 3 |
+
import multiprocessing
|
| 4 |
+
from aworld import replay_buffer
|
| 5 |
+
from aworld.core.common import ActionModel, Observation
|
| 6 |
+
from aworld.replay_buffer.base import ReplayBuffer, DataRow, ExpMeta, Experience
|
| 7 |
+
from aworld.replay_buffer.query_filter import QueryBuilder
|
| 8 |
+
from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage
|
| 9 |
+
from aworld.logs.util import logger
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def write_processing(replay_buffer: ReplayBuffer, task_id: str):
    """Producer loop: store ten experience rows for *task_id*, one per second.

    Each row is tagged with a distinct agent id ("agent_1" .. "agent_10") and
    the loop index as its step.  Storage errors are logged with a stack trace
    and do not stop the loop.
    """
    for step_idx in range(10):
        try:
            meta = ExpMeta(
                task_id=task_id,
                task_name=task_id,
                agent_id=f"agent_{step_idx + 1}",
                step=step_idx,
                execute_time=time.time(),
            )
            payload = Experience(state=Observation(), actions=[ActionModel()])
            replay_buffer.store(DataRow(exp_meta=meta, exp_data=payload))
        except Exception as e:
            logger.error(
                f"write_processing error: {e}\nStack trace:\n{traceback.format_exc()}")
        time.sleep(1)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def read_processing_by_task(replay_buffer: ReplayBuffer, task_id: str):
    """Consumer loop: sample a batch for *task_id* once a second, forever.

    Sampling errors are logged with a full stack trace and the loop keeps
    running.
    """
    while True:
        try:
            condition = QueryBuilder().eq("exp_meta.task_id", task_id).build()
            batch = replay_buffer.sample_task(query_condition=condition,
                                              batch_size=2)
            logger.info(f"read data of task[{task_id}]: {batch}")
        except Exception as e:
            trace = traceback.format_exc()
            logger.error(
                f"read_processing_by_task error: {e}\nStack trace:\n{trace}")
        time.sleep(1)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def read_processing_by_agent(replay_buffer: ReplayBuffer, agent_id: str):
    """Consumer loop: sample a batch for *agent_id* once a second, forever.

    Mirrors read_processing_by_task: sampling errors are logged at ERROR
    level with a stack trace and the loop keeps running.
    """
    while True:
        try:
            query_condition = QueryBuilder().eq("exp_meta.agent_id", agent_id).build()
            data = replay_buffer.sample_task(
                query_condition=query_condition, batch_size=2)
            logger.info(f"read data of agent[{agent_id}]: {data}")
        except Exception as e:
            # Consistency fix: log at ERROR with the stack trace, like
            # read_processing_by_task, instead of a bare INFO message.
            stack_trace = traceback.format_exc()
            logger.error(
                f"read_processing_by_agent error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
    multiprocessing.freeze_support()
    multiprocessing.set_start_method('spawn')
    manager = multiprocessing.Manager()

    # Manager-backed proxies make the in-memory storage shareable across
    # spawned processes.
    replay_buffer = ReplayBuffer(storage=MultiProcMemoryStorage(
        data_dict=manager.dict(),
        fifo_queue=manager.list(),
        lock=manager.Lock(),
        max_capacity=10000
    ))

    # One writer per task; the four duplicated Process(...) constructions are
    # collapsed into a comprehension.
    processes = [
        multiprocessing.Process(target=write_processing,
                                args=(replay_buffer, f"task_{n}",))
        for n in range(1, 5)
    ]
    # processes.append(multiprocessing.Process(
    #     target=read_processing_by_task, args=(replay_buffer, "task_1",)))
    processes.append(multiprocessing.Process(
        target=read_processing_by_agent, args=(replay_buffer, "agent_3",)))

    for p in processes:
        p.start()

    try:
        for p in processes:
            p.join()
    except KeyboardInterrupt:
        # Ctrl-C: stop all children first, then wait for them to exit.
        for p in processes:
            p.terminate()
        for p in processes:
            p.join()
    finally:
        logger.info("Processes terminated.")
|
examples/replay_buffer/query_builder.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from aworld.replay_buffer.query_filter import QueryBuilder
|
| 2 |
+
from aworld.logs.util import logger
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def example():
    '''
    Build a single equality condition.

    expression: task_id = "123"
    return :
    {
        'field': 'task_id',
        'value': '123',
        'op': 'eq'
    }
    '''
    built = QueryBuilder().eq("task_id", "123").build()
    logger.info(built)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def example1():
    '''
    Combine two AND groups with OR, using nested() for the right-hand group.

    expression: (task_id = "123" and agent_id = "111") or (task_id = "456" and agent_id = "222")
    return :
    {
        'or_': [{
            'and_': [
                {'field': 'task_id', 'value': '123', 'op': 'eq'},
                {'field': 'agent_id', 'value': '111', 'op': 'eq'}
            ]
        }, {
            'and_': [
                {'field': 'task_id', 'value': '456', 'op': 'eq'},
                {'field': 'agent_id', 'value': '222', 'op': 'eq'}
            ]
        }]
    }
    '''
    rhs = (QueryBuilder()
           .eq("task_id", "456")
           .and_()
           .eq("agent_id", "222"))
    query = (QueryBuilder()
             .eq("task_id", "123")
             .and_()
             .eq("agent_id", "111")
             .or_()
             .nested(rhs)
             .build())
    logger.info(query)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def example2():
    '''
    Nest an OR group inside an AND condition.

    expression: task_id = "123" and (agent_id = "111" or agent_id = "222")
    return :
    {
        'and_': [{
            'field': 'task_id',
            'value': '123',
            'op': 'eq'
        }, {
            'or_': [{
                'field': 'agent_id',
                'value': '111',
                'op': 'eq'
            }, {
                'field': 'agent_id',
                'value': '222',
                'op': 'eq'
            }]
        }]
    }
    '''
    # Docstring fix only: the example return dict previously had unbalanced
    # brackets (missing "]" closers for the or_/and_ lists).
    qb = QueryBuilder()
    query = (qb.eq("task_id", "123")
             .and_()
             .nested(QueryBuilder()
                     .eq("agent_id", "111")
                     .or_()
                     .eq("agent_id", "222"))
             .build())
    logger.info(query)
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
    # Run every example in order.
    for fn in (example, example1, example2):
        fn()
|
examples/replay_buffer/query_filter.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from aworld.replay_buffer.base import (
|
| 3 |
+
DataRow,
|
| 4 |
+
DefaultConverter,
|
| 5 |
+
ReplayBuffer,
|
| 6 |
+
ExpMeta,
|
| 7 |
+
Experience,
|
| 8 |
+
RandomSample
|
| 9 |
+
)
|
| 10 |
+
from aworld.core.common import ActionModel, Observation
|
| 11 |
+
from aworld.replay_buffer.query_filter import QueryBuilder, QueryFilter
|
| 12 |
+
from aworld.logs.util import logger
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_filter():
    """Exercise QueryFilter.check_condition against a single DataRow."""
    row = DataRow(
        exp_meta=ExpMeta(
            task_id="task_1",
            task_name="default_task_name",
            agent_id="agent_1",
            step=1,
            execute_time=time.time(),
        ),
        # Fix: Experience is constructed with `actions` (a list) in every
        # other replay_buffer example; `action=ActionModel()` was a typo.
        exp_data=Experience(state=Observation(), actions=[ActionModel()])
    )

    # Matching task_id passes.
    query = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
    filter1 = QueryFilter(query)
    assert filter1.check_condition(row)

    # Non-matching task_id fails.
    query = QueryBuilder().eq("exp_meta.task_id", "task_2").build()
    filter2 = QueryFilter(query)
    assert not filter2.check_condition(row)

    # AND with a non-matching agent_id fails.
    query = QueryBuilder().eq("exp_meta.task_id", "task_1").and_().eq(
        "exp_meta.agent_id", "agent_2").build()
    filter3 = QueryFilter(query)
    assert not filter3.check_condition(row)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
    # Manual smoke test; raises AssertionError on failure.
    test_filter()
|
examples/replay_buffer/storage_odps.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from aworld.core.common import ActionModel, Observation
|
| 3 |
+
from aworld.replay_buffer.base import (
|
| 4 |
+
DataRow,
|
| 5 |
+
DefaultConverter,
|
| 6 |
+
ReplayBuffer,
|
| 7 |
+
ExpMeta,
|
| 8 |
+
Experience,
|
| 9 |
+
RandomTaskSample
|
| 10 |
+
)
|
| 11 |
+
from aworld.replay_buffer.query_filter import QueryBuilder
|
| 12 |
+
from aworld.logs.util import logger
|
| 13 |
+
from aworld.replay_buffer.storage.odps import OdpsStorage
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ODPS-backed buffer for the examples below.  The endpoint and credentials
# are intentionally blank here and must be filled in before running.
_storage = OdpsStorage(
    table_name="adm_aworld_replay_buffer",
    project="alifin_jtest_dev",
    endpoint="",
    access_id="",
    access_key="",
)
buffer = ReplayBuffer(storage=_storage)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def write_data():
    """Populate the buffer with 5 tasks x 5 agents x 5 steps of dummy rows.

    All rows are collected first and written in a single store_batch call.
    """
    rows = []
    for task_idx in range(5):  # renamed from `id`: it shadowed the builtin
        task_id = f"task_{task_idx + 1}"
        for agent_idx in range(5):
            agent_id = f"agent_{agent_idx + 1}"
            for j in range(5):
                step = j + 1
                execute_time = time.time() + j
                row = DataRow(
                    exp_meta=ExpMeta(
                        task_id=task_id,
                        task_name="default_task_name",
                        agent_id=agent_id,
                        step=step,
                        execute_time=execute_time,
                        pre_agent="pre_agent_id"
                    ),
                    exp_data=Experience(state=Observation(),
                                        actions=[ActionModel()])
                )
                rows.append(row)
    buffer.store_batch(rows)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def read_data():
    """Sample a batch for task_1 and a batch for agent_5, logging each row."""
    for field, value, batch in (("exp_meta.task_id", "task_1", 1),
                                ("exp_meta.agent_id", "agent_5", 2)):
        query = QueryBuilder().eq(field, value).build()
        datas = buffer.sample_task(query_condition=query,
                                   sampler=RandomTaskSample(),
                                   converter=DefaultConverter(),
                                   batch_size=batch)
        for data in datas:
            logger.info(f"{value} data: {data}")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
    # Uncomment to (re)populate the ODPS table first.
    # write_data()
    read_data()
|
examples/replay_buffer/storage_redis.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from aworld.replay_buffer.base import DataRow, ExpMeta, Experience
|
| 3 |
+
from aworld.replay_buffer.storage.redis import RedisStorage
|
| 4 |
+
from aworld.replay_buffer.query_filter import QueryBuilder
|
| 5 |
+
from aworld.core.common import Observation, ActionModel
|
| 6 |
+
from aworld.logs.util import logger
|
| 7 |
+
|
| 8 |
+
# Shared Redis-backed storage for the examples below; an existing search
# index is reused rather than recreated.
storage = RedisStorage(host="localhost",
                       port=6379,
                       recreate_idx_if_exists=False)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def generate_data_row() -> list[DataRow]:
    """Build and return 5 tasks x 5 agents x 5 steps of dummy DataRow objects."""
    rows: list[DataRow] = []
    for task_idx in range(5):  # renamed from `id`: it shadowed the builtin
        task_id = f"task_{task_idx + 1}"
        for agent_idx in range(5):
            agent_id = f"agent_{agent_idx + 1}"
            for j in range(5):
                step = j + 1
                execute_time = time.time() + j
                row = DataRow(
                    exp_meta=ExpMeta(
                        task_id=task_id,
                        task_name="default_task_name",
                        agent_id=agent_id,
                        step=step,
                        execute_time=execute_time,
                        pre_agent="pre_agent_id"
                    ),
                    exp_data=Experience(state=Observation(),
                                        actions=[ActionModel()])
                )
                rows.append(row)
    return rows
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def write_data():
    """Clear the storage, then insert a freshly generated batch of rows."""
    storage.clear()
    rows = generate_data_row()
    storage.add_batch(rows)
    logger.info(f"Add {len(rows)} rows to storage.")


# Backward-compatible alias: the original misspelled name is preserved so any
# existing callers of `wriete_data` keep working.
wriete_data = write_data
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def read_data():
    """Query rows with a compound condition, then page through the matches.

    Condition: (task_1 AND agent_1) OR (task_4 AND agent_3 AND step > 4).
    """
    inner = (QueryBuilder()
             .eq("exp_meta.task_id", "task_4")
             .and_()
             .eq("exp_meta.agent_id", "agent_3")
             .and_()
             .gt("exp_meta.step", 4))
    query_condition = (QueryBuilder()
                       .eq("exp_meta.task_id", "task_1")
                       .and_()
                       .eq("exp_meta.agent_id", "agent_1")
                       .or_()
                       .nested(inner)
                       .build())

    for row in storage.get_all(query_condition):
        logger.info(row)

    paged = storage.get_paginated(
        page=2, page_size=2, query_condition=query_condition)
    for row in paged:
        logger.info(f"get_paginated: {row}")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
    # Uncomment to (re)populate Redis first.
    # wriete_data()
    read_data()
|