Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- aworld/checkpoint/README.md +98 -0
- aworld/checkpoint/inmemory.py +101 -0
aworld/checkpoint/README.md
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Checkpoint Module
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
The Checkpoint module provides a robust and extensible framework for managing state snapshots (checkpoints) in Python applications. It is designed for scenarios where you need to persist, restore, and version the state of a process, session, or task.
|
5 |
+
|
6 |
+
```mermaid
|
7 |
+
sequenceDiagram
|
8 |
+
participant Application
|
9 |
+
participant CheckpointRepository
|
10 |
+
participant BackendStorage
|
11 |
+
|
12 |
+
Note over Application,BackendStorage: Create and store a checkpoint
|
13 |
+
%% Create and store a checkpoint
|
14 |
+
Application->>CheckpointRepository: create checkpoint
|
15 |
+
CheckpointRepository->>BackendStorage: put(checkpoint)
|
16 |
+
BackendStorage-->>CheckpointRepository: success
|
17 |
+
CheckpointRepository-->>Application: ack
|
18 |
+
|
19 |
+
Note over Application,BackendStorage: Retrieve the latest checkpoint by session
|
20 |
+
|
21 |
+
%% Retrieve the latest checkpoint by session
|
22 |
+
Application->>CheckpointRepository: get checkpoint by session_id
|
23 |
+
CheckpointRepository->>BackendStorage: get_by_session(session_id)
|
24 |
+
BackendStorage-->>CheckpointRepository: Checkpoint
|
25 |
+
CheckpointRepository-->>Application: Checkpoint
|
26 |
+
|
27 |
+
```
|
28 |
+
|
29 |
+
## Key Features
|
30 |
+
|
31 |
+
- **Structured Data Model**: Uses Pydantic's `BaseModel` for strong typing and validation of checkpoint data and metadata.
|
32 |
+
- **Versioning Support**: Built-in version management utilities for checkpoint evolution and comparison.
|
33 |
+
- **Extensible Repository Pattern**: Abstract base class (`BaseCheckpointRepository`) defines a standard interface for checkpoint storage, supporting both synchronous and asynchronous operations.
|
34 |
+
- **In-Memory Implementation**: Includes a simple, ready-to-use in-memory repository for development and testing.
|
35 |
+
- **Utility Functions**: Helper methods for creating, copying, and managing checkpoints.
|
36 |
+
|
37 |
+
## Data Structures
|
38 |
+
|
39 |
+
```mermaid
|
40 |
+
classDiagram
|
41 |
+
class Application {
|
42 |
+
+CheckpointRepository repo
|
43 |
+
+create_checkpoint()
|
44 |
+
+get_checkpoint_by_session()
|
45 |
+
}
|
46 |
+
class CheckpointRepository {
|
47 |
+
+put(checkpoint)
|
48 |
+
+get_by_session(session_id)
|
49 |
+
+delete_by_session(session_id)
|
50 |
+
-BackendStorage backend
|
51 |
+
}
|
52 |
+
class BackendStorage {
|
53 |
+
+put(checkpoint)
|
54 |
+
+get_by_session(session_id)
|
55 |
+
+delete_by_session(session_id)
|
56 |
+
}
|
57 |
+
Application --> CheckpointRepository : uses
|
58 |
+
CheckpointRepository --> BackendStorage : delegates
|
59 |
+
class Checkpoint {
|
60 |
+
+id: str
|
61 |
+
+ts: str
|
62 |
+
+metadata: CheckpointMetadata
|
63 |
+
+values: dict
|
64 |
+
+version: int
|
65 |
+
+parent_id: str
|
66 |
+
+namespace: str
|
67 |
+
}
|
68 |
+
class CheckpointMetadata {
|
69 |
+
+session_id: str
|
70 |
+
+task_id: str
|
71 |
+
}
|
72 |
+
Checkpoint o-- CheckpointMetadata
|
73 |
+
CheckpointRepository o-- Checkpoint
|
74 |
+
BackendStorage o-- Checkpoint
|
75 |
+
```
|
76 |
+
|
77 |
+
|
78 |
+
## Usage Example
|
79 |
+
|
80 |
+
```python
|
81 |
+
from aworld.checkpoint import (
|
82 |
+
Checkpoint, CheckpointMetadata, empty_checkpoint, create_checkpoint, InMemoryCheckpointRepository
|
83 |
+
)
|
84 |
+
|
85 |
+
# Create a new checkpoint
|
86 |
+
metadata = CheckpointMetadata(session_id="session-123", task_id="task-456")
|
87 |
+
values = {"step": 1, "score": 100}
|
88 |
+
checkpoint = create_checkpoint(values=values, metadata=metadata)
|
89 |
+
|
90 |
+
# Store and retrieve using the in-memory repository
|
91 |
+
repo = InMemoryCheckpointRepository()
|
92 |
+
repo.put(checkpoint)
|
93 |
+
restored = repo.get(checkpoint.id)
|
94 |
+
```
|
95 |
+
|
96 |
+
## Extensibility
|
97 |
+
- Implement custom repositories by inheriting from `BaseCheckpointRepository` (e.g., for database, file, or cloud storage).
|
98 |
+
- Extend versioning logic via the `VersionUtils` class.
|
aworld/checkpoint/inmemory.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Optional
|
2 |
+
from . import Checkpoint, BaseCheckpointRepository, VersionUtils
|
3 |
+
|
4 |
+
class InMemoryCheckpointRepository(BaseCheckpointRepository):
|
5 |
+
"""
|
6 |
+
In-memory implementation of BaseCheckpointRepository.
|
7 |
+
Stores checkpoints in a simple in-memory dictionary.
|
8 |
+
Thread safety is not guaranteed.
|
9 |
+
"""
|
10 |
+
def __init__(self) -> None:
|
11 |
+
"""
|
12 |
+
Initialize the in-memory checkpoint repository.
|
13 |
+
"""
|
14 |
+
self._checkpoints: Dict[str, Checkpoint] = {}
|
15 |
+
self._session_index: Dict[str, List[str]] = {}
|
16 |
+
|
17 |
+
def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
|
18 |
+
"""
|
19 |
+
Retrieve a checkpoint by its unique identifier.
|
20 |
+
Args:
|
21 |
+
checkpoint_id (str): The unique identifier of the checkpoint.
|
22 |
+
Returns:
|
23 |
+
Optional[Checkpoint]: The checkpoint if found, otherwise None.
|
24 |
+
"""
|
25 |
+
return self._checkpoints.get(checkpoint_id)
|
26 |
+
|
27 |
+
def list(self, params: Dict[str, Any]) -> List[Checkpoint]:
|
28 |
+
"""
|
29 |
+
List checkpoints matching the given parameters.
|
30 |
+
Args:
|
31 |
+
params (dict): Parameters to filter checkpoints.
|
32 |
+
Returns:
|
33 |
+
List[Checkpoint]: List of matching checkpoints.
|
34 |
+
"""
|
35 |
+
result = []
|
36 |
+
for cp in self._checkpoints.values():
|
37 |
+
match = True
|
38 |
+
for k, v in params.items():
|
39 |
+
if k == 'session_id':
|
40 |
+
if cp.metadata.session_id != v:
|
41 |
+
match = False
|
42 |
+
break
|
43 |
+
elif k == 'task_id':
|
44 |
+
if cp.metadata.task_id != v:
|
45 |
+
match = False
|
46 |
+
break
|
47 |
+
elif cp.get(k) != v:
|
48 |
+
match = False
|
49 |
+
break
|
50 |
+
if match:
|
51 |
+
result.append(cp)
|
52 |
+
return result
|
53 |
+
|
54 |
+
def put(self, checkpoint: Checkpoint) -> None:
|
55 |
+
"""
|
56 |
+
Store a checkpoint.
|
57 |
+
Args:
|
58 |
+
checkpoint (Checkpoint): The checkpoint to store.
|
59 |
+
"""
|
60 |
+
# Find last version checkpoint by session_id
|
61 |
+
last_checkpoint = self.get_by_session(checkpoint.metadata.session_id)
|
62 |
+
|
63 |
+
if last_checkpoint:
|
64 |
+
# Compare versions to ensure optimistic locking
|
65 |
+
if VersionUtils.is_version_less(checkpoint, last_checkpoint.version):
|
66 |
+
raise ValueError(f"New checkpoint version {checkpoint.version} must be greater than last version {last_checkpoint.version}")
|
67 |
+
|
68 |
+
# Store the new checkpoint
|
69 |
+
self._checkpoints[checkpoint.id] = checkpoint
|
70 |
+
|
71 |
+
# Update session index
|
72 |
+
session_id = checkpoint.metadata.session_id
|
73 |
+
if session_id:
|
74 |
+
if session_id not in self._session_index:
|
75 |
+
self._session_index[session_id] = []
|
76 |
+
self._session_index[session_id].append(checkpoint.id)
|
77 |
+
|
78 |
+
def get_by_session(self, session_id: str) -> Optional[Checkpoint]:
|
79 |
+
"""
|
80 |
+
Get the latest checkpoint for a session.
|
81 |
+
Args:
|
82 |
+
session_id (str): The session identifier.
|
83 |
+
Returns:
|
84 |
+
Optional[Checkpoint]: The latest checkpoint if found, otherwise None.
|
85 |
+
"""
|
86 |
+
ids = self._session_index.get(session_id, [])
|
87 |
+
if not ids:
|
88 |
+
return None
|
89 |
+
# Assume the last one is the latest
|
90 |
+
last_id = ids[-1]
|
91 |
+
return self._checkpoints.get(last_id)
|
92 |
+
|
93 |
+
def delete_by_session(self, session_id: str) -> None:
|
94 |
+
"""
|
95 |
+
Delete all checkpoints related to a session.
|
96 |
+
Args:
|
97 |
+
session_id (str): The session identifier.
|
98 |
+
"""
|
99 |
+
ids = self._session_index.pop(session_id, [])
|
100 |
+
for cid in ids:
|
101 |
+
self._checkpoints.pop(cid, None)
|