import random
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Optional, TypeVar
from abc import ABC, abstractmethod
from math import ceil

from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.query_filter import QueryCondition, QueryFilter
from aworld.logs.util import logger

T = TypeVar('T')


@dataclass
class Experience:
    '''
    A single piece of experience collected by an agent.
    '''
    state: Observation
    actions: List[ActionModel]
    reward_t: Optional[float] = None
    adv_t: Optional[float] = None
    v_t: Optional[float] = None
    messages: Optional[List[Dict]] = None

    def to_dict(self):
        return {
            "state": self.state,
            "actions": self.actions,
            "reward_t": self.reward_t,
            "adv_t": self.adv_t,
            "v_t": self.v_t,
            "messages": self.messages
        }


@dataclass
class ExpMeta:
    '''
    Experience metadata.
    '''
    task_id: str
    task_name: str
    agent_id: str
    step: int
    execute_time: float
    pre_agent: str

    def to_dict(self):
        return {
            "task_id": self.task_id,
            "task_name": self.task_name,
            "agent_id": self.agent_id,
            "step": self.step,
            "execute_time": self.execute_time,
            "pre_agent": self.pre_agent
        }


@dataclass
class DataRow:
    '''
    A single storable row: experience metadata plus the experience itself.
    '''
    exp_meta: ExpMeta
    exp_data: Experience
    id: str = field(default_factory=lambda: str(uuid.uuid4()))

    def to_dict(self):
        return {
            "exp_meta": self.exp_meta.to_dict(),
            "exp_data": self.exp_data.to_dict(),
            "id": self.id
        }
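

# The helper below is illustrative only: _make_row is a hypothetical
# convenience for the usage examples in this file. It assumes Observation()
# and ActionModel() can be constructed without required arguments; adapt the
# construction to the actual aworld.core.common definitions if they differ.
import time


def _make_row(task_id: str, step: int) -> DataRow:
    '''Build a minimal DataRow for the usage examples below.'''
    return DataRow(
        exp_meta=ExpMeta(
            task_id=task_id,
            task_name="demo_task",
            agent_id="agent_1",
            step=step,
            execute_time=time.time(),
            pre_agent="agent_0",
        ),
        exp_data=Experience(state=Observation(), actions=[ActionModel()]),
    )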


class Storage(ABC):
    '''
    Storage for storing and sampling data.
    '''

    @abstractmethod
    def add(self, data: DataRow):
        '''
        Add data to the storage.

        Args:
            data (DataRow): Data to add.
        '''

    @abstractmethod
    def add_batch(self, data_batch: List[DataRow]):
        '''
        Add a batch of data to the storage.

        Args:
            data_batch (List[DataRow]): List of data to add.
        '''

    @abstractmethod
    def size(self, query_condition: QueryCondition = None) -> int:
        '''
        Get the number of rows matching the query condition.

        Args:
            query_condition (QueryCondition, optional): Query condition. Defaults to None.

        Returns:
            int: Size of the storage.
        '''

    @abstractmethod
    def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
        '''
        Get paginated data from the storage.

        Args:
            page (int): Page number, starting at 1.
            page_size (int): Number of rows per page.
            query_condition (QueryCondition, optional): Query condition. Defaults to None.

        Returns:
            List[DataRow]: List of data.
        '''

    @abstractmethod
    def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
        '''
        Get all data from the storage.

        Returns:
            List[DataRow]: List of data.
        '''

    @abstractmethod
    def get_by_task_id(self, task_id: str) -> List[DataRow]:
        '''
        Get data by task_id from the storage.

        Args:
            task_id (str): Task id.

        Returns:
            List[DataRow]: List of data.
        '''

    @abstractmethod
    def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
        '''
        Get a batch of data by task_ids from the storage.

        Args:
            task_ids (List[str]): List of task ids.

        Returns:
            Dict[str, List[DataRow]]: Mapping from task_id to that task's
                rows, sorted by step.
        '''


class Sampler(ABC):
    '''
    Sample data from the storage.
    '''

    @abstractmethod
    def sample(self,
               storage: Storage,
               batch_size: int,
               query_condition: QueryCondition = None) -> List[DataRow]:
        '''
        Sample data from the storage.

        Args:
            storage (Storage): Storage to sample from.
            batch_size (int): Number of data to sample.
            query_condition (QueryCondition, optional): Query condition. Defaults to None.

        Returns:
            List[DataRow]: List of sampled data.
        '''


class TaskSampler(Sampler):
    '''
    Sample whole tasks from the storage. `sample_tasks` returns
    Dict[str, List[DataRow]] where:
    - key is task_id
    - value is the list of all data rows belonging to that task
    '''

    def sorted_by_step(self, task_experience: List[DataRow]) -> List[DataRow]:
        '''
        Sort the task experience by step and execute_time.

        Args:
            task_experience (List[DataRow]): List of task experience.

        Returns:
            List[DataRow]: Task experience sorted by step and execute_time.
        '''
        return sorted(task_experience, key=lambda x: (x.exp_meta.step, x.exp_meta.execute_time))

    def sample(self,
               storage: Storage,
               batch_size: int,
               query_condition: QueryCondition = None) -> List[DataRow]:
        task_ids = self.sample_task_ids(storage, batch_size, query_condition)
        rows = storage.get_bacth_by_task_ids(task_ids)
        # Flatten the per-task mapping so the return type matches Sampler.sample.
        return [row for task_rows in rows.values() for row in task_rows]

    def sample_tasks(self,
                     storage: Storage,
                     batch_size: int,
                     query_condition: QueryCondition = None) -> Dict[str, List[DataRow]]:
        '''
        Sample data from the storage.

        Args:
            storage (Storage): Storage to sample from.
            batch_size (int): Number of tasks to sample.
            query_condition (QueryCondition, optional): Query condition. Defaults to None.

        Returns:
            Dict[str, List[DataRow]]: Dictionary of sampled data.
                The key is the task_id and the value is the list of data,
                sorted by step.
        '''
        task_ids = self.sample_task_ids(storage, batch_size, query_condition)
        rows = storage.get_bacth_by_task_ids(task_ids)
        return {task_id: self.sorted_by_step(task_rows) for task_id, task_rows in rows.items()}

    @abstractmethod
    def sample_task_ids(self,
                        storage: Storage,
                        batch_size: int,
                        query_condition: QueryCondition = None) -> List[str]:
        '''
        Sample task_ids from the storage.

        Args:
            storage (Storage): Storage to sample from.
            batch_size (int): Number of task_ids to sample.
            query_condition (QueryCondition, optional): Query condition. Defaults to None.

        Returns:
            List[str]: List of task_ids.
        '''


class Converter(ABC):
    '''
    Convert data to a dataset row.
    '''

    @abstractmethod
    def to_dataset_row(self, task_experience: List[DataRow]) -> T:
        '''
        Convert task experience to a dataset row.

        Args:
            task_experience (List[DataRow]): List of task experience.

        Returns:
            T: The dataset row type.
        '''
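

# A hedged sketch of a custom Converter. DictConverter is not part of the
# aworld API; it only shows the shape of a Converter subclass by serializing
# each row through DataRow.to_dict() for downstream tooling.
class DictConverter(Converter):
    '''Convert one task's rows into plain dictionaries.'''

    def to_dataset_row(self, task_experience: List[DataRow]) -> List[Dict]:
        return [row.to_dict() for row in task_experience]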


class InMemoryStorage(Storage):
    '''
    In-memory storage for storing and sampling data.
    '''

    def __init__(self, max_capacity: int = 10000):
        self._data: Dict[str, List[DataRow]] = {}
        self._max_capacity = max_capacity
        self._fifo_queue = []  # task_ids in insertion order

    def add(self, data: DataRow):
        if not data:
            raise ValueError("Data is required")
        if not data.exp_meta:
            raise ValueError("exp_meta is required")
        # Evict whole tasks, oldest first, until there is room for the new row.
        while self.size() >= self._max_capacity and self._fifo_queue:
            oldest_task_id = self._fifo_queue.pop(0)
            if oldest_task_id in self._data:
                del self._data[oldest_task_id]
        if data.exp_meta.task_id not in self._data:
            self._data[data.exp_meta.task_id] = []
            self._fifo_queue.append(data.exp_meta.task_id)
        self._data[data.exp_meta.task_id].append(data)

    def add_batch(self, data_batch: List[DataRow]):
        for data in data_batch:
            self.add(data)

    def size(self, query_condition: QueryCondition = None) -> int:
        return len(self.get_all(query_condition))

    def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
        if page < 1:
            raise ValueError("Page must be greater than 0")
        if page_size < 1:
            raise ValueError("Page size must be greater than 0")
        all_data = self.get_all(query_condition)
        start_index = (page - 1) * page_size
        end_index = start_index + page_size
        return all_data[start_index:end_index]

    def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
        all_data = []
        query_filter = None
        if query_condition:
            query_filter = QueryFilter(query_condition)
        for data in self._data.values():
            if query_filter:
                all_data.extend(query_filter.filter(data))
            else:
                all_data.extend(data)
        return all_data

    def get_by_task_id(self, task_id: str) -> List[DataRow]:
        return self._data.get(task_id, [])

    def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
        return {task_id: self._data.get(task_id, []) for task_id in task_ids}

    def clear(self):
        self._data = {}
        self._fifo_queue = []
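

# Usage sketch for InMemoryStorage, relying on the hypothetical _make_row
# helper above. With max_capacity=2, adding a third row evicts the oldest
# task wholesale, so both task-1 rows disappear at once.
def _demo_in_memory_storage():
    storage = InMemoryStorage(max_capacity=2)
    storage.add(_make_row("task-1", step=0))
    storage.add(_make_row("task-1", step=1))
    storage.add(_make_row("task-2", step=0))  # triggers FIFO eviction of task-1
    assert storage.size() == 1
    assert storage.get_by_task_id("task-1") == []
    return storage.get_all()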


class RandomTaskSample(TaskSampler):
    '''
    Randomly sample task ids from the storage.
    '''

    def sample_task_ids(self,
                        storage: Storage,
                        batch_size: int,
                        query_condition: QueryCondition = None) -> List[str]:
        total_size = storage.size(query_condition)
        if total_size <= batch_size:
            # Fewer rows than requested: return every matching task id.
            return list({data.exp_meta.task_id for data in storage.get_all(query_condition)})
        sampled_task_ids = set()
        page_size = min(100, batch_size * 2)
        total_pages = ceil(total_size / page_size)
        visited_pages = set()
        while len(sampled_task_ids) < batch_size and len(visited_pages) < total_pages:
            # Visit unseen pages in random order, collecting new task ids.
            page = random.choice(
                [p for p in range(1, total_pages + 1) if p not in visited_pages])
            visited_pages.add(page)
            current_page = storage.get_paginated(
                page, page_size, query_condition)
            if not current_page:
                continue
            current_page_task_ids = set(
                data.exp_meta.task_id for data in current_page
                if data.exp_meta.task_id not in sampled_task_ids)
            sample_count = min(len(current_page_task_ids),
                               batch_size - len(sampled_task_ids))
            sampled_task_ids.update(random.sample(
                list(current_page_task_ids), sample_count))
        return list(sampled_task_ids)


class DefaultConverter(Converter):
    '''
    Default converter: returns the task experience unchanged.
    '''

    def to_dataset_row(self, task_experience: List[DataRow]) -> List[DataRow]:
        return task_experience


class ReplayBuffer:
    '''
    Replay buffer for storing and sampling data.
    '''

    def __init__(self, storage: Optional[Storage] = None):
        # Avoid a shared mutable default argument: each buffer gets its own
        # InMemoryStorage unless a backend is supplied.
        self._storage = storage if storage is not None else InMemoryStorage()

    def store(self, data: DataRow):
        '''
        Store data in the replay buffer.
        '''
        if not data:
            raise ValueError("Data is required")
        self._storage.add(data)

    def store_batch(self, data_batch: List[DataRow]):
        '''
        Store a batch of data in the replay buffer.
        '''
        if not data_batch:
            raise ValueError("Data batch is required")
        self._storage.add_batch(data_batch)

    def sample_task(self,
                    sampler: TaskSampler = RandomTaskSample(),
                    query_condition: QueryCondition = None,
                    converter: Converter = DefaultConverter(),
                    batch_size: int = 1000) -> List[T]:
        '''
        Sample whole tasks from the replay buffer and convert them to dataset rows.
        With DefaultConverter, each element is a List[DataRow].
        '''
        sampled_task = sampler.sample_tasks(
            self._storage, batch_size, query_condition)
        return [converter.to_dataset_row(task_experiences) for task_experiences in sampled_task.values()]

    def sample(self,
               sampler: Sampler = RandomTaskSample(),
               query_condition: QueryCondition = None,
               converter: Converter = DefaultConverter(),
               batch_size: int = 1000) -> List[T]:
        '''
        Sample data from the replay buffer and convert it to a dataset row.
        With DefaultConverter, the result is a List[DataRow].
        '''
        sampled_data = sampler.sample(
            self._storage, batch_size, query_condition)
        return converter.to_dataset_row(sampled_data)
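

# End-to-end sketch: store rows for a few tasks, then sample complete tasks.
# Built on the hypothetical _make_row helper above; no QueryCondition is
# passed, so nothing is assumed about the query_filter schema.
if __name__ == "__main__":
    buffer = ReplayBuffer(InMemoryStorage())
    buffer.store_batch([_make_row(f"task-{i}", step=s)
                        for i in range(3) for s in range(2)])
    # sample_task returns one entry per sampled task; with DefaultConverter,
    # each entry is that task's rows sorted by (step, execute_time).
    for task_rows in buffer.sample_task(batch_size=2):
        logger.info([(r.exp_meta.task_id, r.exp_meta.step) for r in task_rows])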