Spaces:
Sleeping
Sleeping
| import time | |
| from aworld.core.common import ActionModel, Observation | |
| from aworld.replay_buffer.base import ( | |
| DataRow, | |
| DefaultConverter, | |
| ReplayBuffer, | |
| ExpMeta, | |
| Experience, | |
| RandomTaskSample | |
| ) | |
| from aworld.replay_buffer.query_filter import QueryBuilder | |
| from aworld.logs.util import logger | |
# Module-level ReplayBuffer shared by write_data() (which stores rows into it)
# and read_data() (which samples from it).
buffer: ReplayBuffer = ReplayBuffer()
def write_data(num_tasks: int = 5, steps_per_task: int = 10):
    """Populate the module-level ``buffer`` with synthetic experience rows.

    For each task id ``task_0`` .. ``task_{num_tasks - 1}``, store
    ``steps_per_task`` rows. Row ``i`` (0-based) of a task gets agent id
    ``agent_{i + 1}``, step ``i + 1`` and an execute time of
    ``time.time() + i``. Each row carries an empty ``Observation`` and a
    single default ``ActionModel``.

    Args:
        num_tasks: number of distinct task ids to generate (default 5).
        steps_per_task: number of rows stored per task (default 10).
    """
    for task_idx in range(num_tasks):
        # Build the id once per task from a separate loop variable. Rebinding
        # the loop variable itself inside the inner loop would compound the
        # prefix on every step ("task_task_0", "task_task_task_0", ...).
        task_id = f"task_{task_idx}"
        for i in range(steps_per_task):
            agent_id = f"agent_{i + 1}"
            step = i + 1
            execute_time = time.time() + i
            row = DataRow(
                exp_meta=ExpMeta(
                    task_id=task_id,
                    task_name="default_task_name",
                    agent_id=agent_id,
                    step=step,
                    execute_time=execute_time,
                ),
                exp_data=Experience(state=Observation(),
                                    actions=[ActionModel()]),
            )
            buffer.store(row)
def read_data():
    """Query the module-level ``buffer`` twice and log every sampled item.

    The lookups run in order: first rows whose ``exp_meta.task_id`` equals
    ``"task_1"``, then rows whose ``exp_meta.agent_id`` equals ``"agent_5"``.
    Each lookup calls ``buffer.sample_task`` with a fresh ``RandomTaskSample``,
    a fresh ``DefaultConverter`` and ``batch_size=2``. Each returned item is
    logged as ``"<value> data: <item>"``.
    """
    lookups = (
        ("exp_meta.task_id", "task_1"),
        ("exp_meta.agent_id", "agent_5"),
    )
    for field_path, expected in lookups:
        condition = QueryBuilder().eq(field_path, expected).build()
        sampled = buffer.sample_task(
            query_condition=condition,
            sampler=RandomTaskSample(),
            converter=DefaultConverter(),
            batch_size=2,
        )
        for item in sampled:
            logger.info(f"{expected} data: {item}")
if __name__ == "__main__":
    # Order matters: the buffer must be populated before it is sampled.
    for stage in (write_data, read_data):
        stage()