Spaces:
Sleeping
Sleeping
| from typing import Any, Dict, List, Optional | |
| from . import Checkpoint, BaseCheckpointRepository, VersionUtils | |
class InMemoryCheckpointRepository(BaseCheckpointRepository):
    """
    In-memory implementation of BaseCheckpointRepository.

    Checkpoints are kept in a dictionary keyed by checkpoint id, and a
    secondary index maps each session id to the ids of its checkpoints in
    insertion order (the last entry is treated as the latest).
    Thread safety is not guaranteed.
    """

    def __init__(self) -> None:
        """
        Initialize the in-memory checkpoint repository.
        """
        # checkpoint id -> checkpoint
        self._checkpoints: Dict[str, Checkpoint] = {}
        # session id -> checkpoint ids, oldest first
        self._session_index: Dict[str, List[str]] = {}

    def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Retrieve a checkpoint by its unique identifier.

        Args:
            checkpoint_id (str): The unique identifier of the checkpoint.

        Returns:
            Optional[Checkpoint]: The checkpoint if found, otherwise None.
        """
        return self._checkpoints.get(checkpoint_id)

    @staticmethod
    def _matches(cp: Checkpoint, filters: Dict[str, Any]) -> bool:
        """
        Return True when the checkpoint satisfies every filter entry.

        'session_id' and 'task_id' are compared against the checkpoint
        metadata; any other key is looked up with ``cp.get(key)``.
        """
        for key, expected in filters.items():
            if key == 'session_id':
                actual = cp.metadata.session_id
            elif key == 'task_id':
                actual = cp.metadata.task_id
            else:
                # NOTE(review): relies on Checkpoint exposing a dict-style
                # ``get``; confirm against the Checkpoint definition.
                actual = cp.get(key)
            if actual != expected:
                return False
        return True

    def list(self, params: Dict[str, Any]) -> List[Checkpoint]:
        """
        List checkpoints matching the given parameters.

        Args:
            params (dict): Parameters to filter checkpoints. All entries
                must match. An empty dict (or None) matches every
                checkpoint.

        Returns:
            List[Checkpoint]: List of matching checkpoints.
        """
        # Treat a missing filter set as "no filters" instead of crashing.
        filters = params or {}
        return [
            cp for cp in self._checkpoints.values()
            if self._matches(cp, filters)
        ]

    def _unindex(self, session_id: Optional[str], checkpoint_id: str) -> None:
        """
        Remove every occurrence of checkpoint_id from a session's index.
        """
        ids = self._session_index.get(session_id)
        if not ids:
            return
        if checkpoint_id in ids:
            ids[:] = [cid for cid in ids if cid != checkpoint_id]
        if not ids:
            # Do not keep empty index entries around.
            del self._session_index[session_id]

    def put(self, checkpoint: Checkpoint) -> None:
        """
        Store a checkpoint.

        Args:
            checkpoint (Checkpoint): The checkpoint to store.

        Raises:
            ValueError: If the session already has a checkpoint and
                VersionUtils.is_version_less reports the new checkpoint
                as older than that latest one.
        """
        session_id = checkpoint.metadata.session_id

        # Optimistic locking: compare against the session's latest checkpoint.
        last_checkpoint = self.get_by_session(session_id)
        if last_checkpoint and VersionUtils.is_version_less(checkpoint, last_checkpoint.version):
            raise ValueError(f"New checkpoint version {checkpoint.version} must be greater than last version {last_checkpoint.version}")

        # If this id was stored before, drop its old index entries first so
        # the index never holds duplicates or entries in a stale session.
        previous = self._checkpoints.get(checkpoint.id)
        if previous is not None:
            self._unindex(previous.metadata.session_id, checkpoint.id)
        self._unindex(session_id, checkpoint.id)

        # Store the new checkpoint
        self._checkpoints[checkpoint.id] = checkpoint

        # Update session index; appended last so it becomes the latest.
        if session_id:
            self._session_index.setdefault(session_id, []).append(checkpoint.id)

    def get_by_session(self, session_id: str) -> Optional[Checkpoint]:
        """
        Get the latest checkpoint for a session.

        Args:
            session_id (str): The session identifier.

        Returns:
            Optional[Checkpoint]: The latest checkpoint if found, otherwise None.
        """
        ids = self._session_index.get(session_id, [])
        if not ids:
            return None
        # The index is kept in insertion order, so the last id is the latest.
        return self._checkpoints.get(ids[-1])

    def delete_by_session(self, session_id: str) -> None:
        """
        Delete all checkpoints related to a session.

        Unknown session ids are ignored.

        Args:
            session_id (str): The session identifier.
        """
        for cid in self._session_index.pop(session_id, []):
            self._checkpoints.pop(cid, None)