Duibonduil commited on
Commit
8af6ba8
·
verified ·
1 Parent(s): b3e3855

Upload 2 files

Browse files
aworld/checkpoint/README.md ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Checkpoint Module
2
+
3
+ ## Overview
4
+ The Checkpoint module provides a robust and extensible framework for managing state snapshots (checkpoints) in Python applications. It is designed for scenarios where you need to persist, restore, and version the state of a process, session, or task.
5
+
6
+ ```mermaid
7
+ sequenceDiagram
8
+ participant Application
9
+ participant CheckpointRepository
10
+ participant BackendStorage
11
+
12
+ Note over Application,BackendStorage: Create and store a checkpoint
13
+ %% Create and store a checkpoint
14
+ Application->>CheckpointRepository: create checkpoint
15
+ CheckpointRepository->>BackendStorage: put(checkpoint)
16
+ BackendStorage-->>CheckpointRepository: success
17
+ CheckpointRepository-->>Application: ack
18
+
19
+ Note over Application,BackendStorage: Retrieve the latest checkpoint by session
20
+
21
+ %% Retrieve the latest checkpoint by session
22
+ Application->>CheckpointRepository: get checkpoint by session_id
23
+ CheckpointRepository->>BackendStorage: get_by_session(session_id)
24
+ BackendStorage-->>CheckpointRepository: Checkpoint
25
+ CheckpointRepository-->>Application: Checkpoint
26
+
27
+ ```
28
+
29
+ ## Key Features
30
+
31
+ - **Structured Data Model**: Uses Pydantic's `BaseModel` for strong typing and validation of checkpoint data and metadata.
32
+ - **Versioning Support**: Built-in version management utilities for checkpoint evolution and comparison.
33
+ - **Extensible Repository Pattern**: Abstract base class (`BaseCheckpointRepository`) defines a standard interface for checkpoint storage, supporting both synchronous and asynchronous operations.
34
+ - **In-Memory Implementation**: Includes a simple, ready-to-use in-memory repository for development and testing.
35
+ - **Utility Functions**: Helper methods for creating, copying, and managing checkpoints.
36
+
37
+ ## Data Structures
38
+
39
+ ```mermaid
40
+ classDiagram
41
+ class Application {
42
+ +CheckpointRepository repo
43
+ +create_checkpoint()
44
+ +get_checkpoint_by_session()
45
+ }
46
+ class CheckpointRepository {
47
+ +put(checkpoint)
48
+ +get_by_session(session_id)
49
+ +delete_by_session(session_id)
50
+ -BackendStorage backend
51
+ }
52
+ class BackendStorage {
53
+ +put(checkpoint)
54
+ +get_by_session(session_id)
55
+ +delete_by_session(session_id)
56
+ }
57
+ Application --> CheckpointRepository : uses
58
+ CheckpointRepository --> BackendStorage : delegates
59
+ class Checkpoint {
60
+ +id: str
61
+ +ts: str
62
+ +metadata: CheckpointMetadata
63
+ +values: dict
64
+ +version: int
65
+ +parent_id: str
66
+ +namespace: str
67
+ }
68
+ class CheckpointMetadata {
69
+ +session_id: str
70
+ +task_id: str
71
+ }
72
+ Checkpoint o-- CheckpointMetadata
73
+ CheckpointRepository o-- Checkpoint
74
+ BackendStorage o-- Checkpoint
75
+ ```
76
+
77
+
78
+ ## Usage Example
79
+
80
+ ```python
81
+ from aworld.checkpoint import (
82
+ Checkpoint, CheckpointMetadata, empty_checkpoint, create_checkpoint, InMemoryCheckpointRepository
83
+ )
84
+
85
+ # Create a new checkpoint
86
+ metadata = CheckpointMetadata(session_id="session-123", task_id="task-456")
87
+ values = {"step": 1, "score": 100}
88
+ checkpoint = create_checkpoint(values=values, metadata=metadata)
89
+
90
+ # Store and retrieve using the in-memory repository
91
+ repo = InMemoryCheckpointRepository()
92
+ repo.put(checkpoint)
93
+ restored = repo.get(checkpoint.id)
94
+ ```
95
+
96
+ ## Extensibility
97
+ - Implement custom repositories by inheriting from `BaseCheckpointRepository` (e.g., for database, file, or cloud storage).
98
+ - Extend versioning logic via the `VersionUtils` class.
aworld/checkpoint/inmemory.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+ from . import Checkpoint, BaseCheckpointRepository, VersionUtils
3
+
4
+ class InMemoryCheckpointRepository(BaseCheckpointRepository):
5
+ """
6
+ In-memory implementation of BaseCheckpointRepository.
7
+ Stores checkpoints in a simple in-memory dictionary.
8
+ Thread safety is not guaranteed.
9
+ """
10
+ def __init__(self) -> None:
11
+ """
12
+ Initialize the in-memory checkpoint repository.
13
+ """
14
+ self._checkpoints: Dict[str, Checkpoint] = {}
15
+ self._session_index: Dict[str, List[str]] = {}
16
+
17
+ def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
18
+ """
19
+ Retrieve a checkpoint by its unique identifier.
20
+ Args:
21
+ checkpoint_id (str): The unique identifier of the checkpoint.
22
+ Returns:
23
+ Optional[Checkpoint]: The checkpoint if found, otherwise None.
24
+ """
25
+ return self._checkpoints.get(checkpoint_id)
26
+
27
+ def list(self, params: Dict[str, Any]) -> List[Checkpoint]:
28
+ """
29
+ List checkpoints matching the given parameters.
30
+ Args:
31
+ params (dict): Parameters to filter checkpoints.
32
+ Returns:
33
+ List[Checkpoint]: List of matching checkpoints.
34
+ """
35
+ result = []
36
+ for cp in self._checkpoints.values():
37
+ match = True
38
+ for k, v in params.items():
39
+ if k == 'session_id':
40
+ if cp.metadata.session_id != v:
41
+ match = False
42
+ break
43
+ elif k == 'task_id':
44
+ if cp.metadata.task_id != v:
45
+ match = False
46
+ break
47
+ elif cp.get(k) != v:
48
+ match = False
49
+ break
50
+ if match:
51
+ result.append(cp)
52
+ return result
53
+
54
+ def put(self, checkpoint: Checkpoint) -> None:
55
+ """
56
+ Store a checkpoint.
57
+ Args:
58
+ checkpoint (Checkpoint): The checkpoint to store.
59
+ """
60
+ # Find last version checkpoint by session_id
61
+ last_checkpoint = self.get_by_session(checkpoint.metadata.session_id)
62
+
63
+ if last_checkpoint:
64
+ # Compare versions to ensure optimistic locking
65
+ if VersionUtils.is_version_less(checkpoint, last_checkpoint.version):
66
+ raise ValueError(f"New checkpoint version {checkpoint.version} must be greater than last version {last_checkpoint.version}")
67
+
68
+ # Store the new checkpoint
69
+ self._checkpoints[checkpoint.id] = checkpoint
70
+
71
+ # Update session index
72
+ session_id = checkpoint.metadata.session_id
73
+ if session_id:
74
+ if session_id not in self._session_index:
75
+ self._session_index[session_id] = []
76
+ self._session_index[session_id].append(checkpoint.id)
77
+
78
+ def get_by_session(self, session_id: str) -> Optional[Checkpoint]:
79
+ """
80
+ Get the latest checkpoint for a session.
81
+ Args:
82
+ session_id (str): The session identifier.
83
+ Returns:
84
+ Optional[Checkpoint]: The latest checkpoint if found, otherwise None.
85
+ """
86
+ ids = self._session_index.get(session_id, [])
87
+ if not ids:
88
+ return None
89
+ # Assume the last one is the latest
90
+ last_id = ids[-1]
91
+ return self._checkpoints.get(last_id)
92
+
93
+ def delete_by_session(self, session_id: str) -> None:
94
+ """
95
+ Delete all checkpoints related to a session.
96
+ Args:
97
+ session_id (str): The session identifier.
98
+ """
99
+ ids = self._session_index.pop(session_id, [])
100
+ for cid in ids:
101
+ self._checkpoints.pop(cid, None)