Duibonduil commited on
Commit
e0f8ec7
·
verified ·
1 Parent(s): cc54e11

Upload __init__.py

Browse files
Files changed (1) hide show
  1. aworld/__init__.py +240 -0
aworld/__init__.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional, List
2
+ import copy
3
+ import uuid
4
+ from datetime import datetime, timezone
5
+ from abc import ABC, abstractmethod
6
+ import asyncio
7
+ from pydantic import BaseModel, Field
8
+
9
+ class CheckpointMetadata(BaseModel):
10
+ """
11
+ Metadata for a checkpoint, including session and task identifiers.
12
+
13
+ Attributes:
14
+ session_id (str): The session identifier (required).
15
+ task_id (Optional[str]): The task identifier (optional).
16
+ """
17
+ session_id: str = Field(..., description="The session identifier.")
18
+ task_id: Optional[str] = Field(None, description="The task identifier.")
19
+
20
+ class Checkpoint(BaseModel):
21
+ """
22
+ Core structure for a state checkpoint.
23
+
24
+ Attributes:
25
+ id (str): Unique identifier for the checkpoint.
26
+ ts (str): Timestamp of the checkpoint.
27
+ metadata (CheckpointMetadata): Metadata associated with the checkpoint.
28
+ values (dict[str, Any]): State values stored in the checkpoint.
29
+ version (str): Version of the checkpoint format.
30
+ parent_id (Optional[str]): Parent checkpoint identifier, if any.
31
+ namespace (str): Namespace for the checkpoint, default is 'aworld'.
32
+ """
33
+ id: str = Field(..., description="Unique identifier for the checkpoint.")
34
+ ts: str = Field(..., description="Timestamp of the checkpoint.")
35
+ metadata: CheckpointMetadata = Field(..., description="Metadata associated with the checkpoint.")
36
+ values: Dict[str, Any] = Field(..., description="State values stored in the checkpoint.")
37
+ version: int = Field(..., description="Version of the checkpoint format.")
38
+ parent_id: Optional[str] = Field(default=None, description="Parent checkpoint identifier, if any.")
39
+ namespace: str = Field(default="aworld", description="Namespace for the checkpoint, default is 'aworld'.")
40
+
41
+ def empty_checkpoint() -> Checkpoint:
42
+ """
43
+ Create an empty checkpoint with default values.
44
+
45
+ Returns:
46
+ Checkpoint: An empty checkpoint structure.
47
+ """
48
+ return Checkpoint(
49
+ id=str(uuid.uuid4()),
50
+ ts=datetime.now(timezone.utc).isoformat(),
51
+ metadata=CheckpointMetadata(session_id="", task_id=None),
52
+ values={},
53
+ version=1,
54
+ parent_id=None,
55
+ namespace="aworld",
56
+ )
57
+
58
+ def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint:
59
+ """
60
+ Create a deep copy of a checkpoint.
61
+
62
+ Args:
63
+ checkpoint (Checkpoint): The checkpoint to copy.
64
+ Returns:
65
+ Checkpoint: A deep copy of the provided checkpoint.
66
+ """
67
+ return copy.deepcopy(checkpoint)
68
+
69
+ def create_checkpoint(
70
+ values: Dict[str, Any],
71
+ metadata: CheckpointMetadata,
72
+ parent_id: Optional[str] = None,
73
+ version: int = 1,
74
+ namespace: str = 'aworld',
75
+ ) -> Checkpoint:
76
+ """
77
+ Create a new checkpoint from provided state values and metadata.
78
+
79
+ Args:
80
+ values (dict[str, Any]): State values to store in the checkpoint.
81
+ metadata (CheckpointMetadata): Metadata for the checkpoint.
82
+ parent_id (Optional[str]): Parent checkpoint identifier, if any.
83
+ version (str): Version of the checkpoint format.
84
+ namespace (str): Namespace for the checkpoint.
85
+ Returns:
86
+ Checkpoint: The newly created checkpoint.
87
+ """
88
+ return Checkpoint(
89
+ id=str(uuid.uuid4()),
90
+ ts=datetime.now(timezone.utc).isoformat(),
91
+ metadata=metadata,
92
+ values=values,
93
+ version=VersionUtils.get_next_version(version),
94
+ parent_id=parent_id,
95
+ namespace=namespace,
96
+ )
97
+
98
+ class BaseCheckpointRepository(ABC):
99
+ """
100
+ Abstract base class for a checkpoint repository.
101
+ Provides synchronous and asynchronous methods for checkpoint management.
102
+ """
103
+
104
+ @abstractmethod
105
+ def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
106
+ """
107
+ Retrieve a checkpoint by its unique identifier.
108
+
109
+ Args:
110
+ checkpoint_id (str): The unique identifier of the checkpoint.
111
+ Returns:
112
+ Optional[Checkpoint]: The checkpoint if found, otherwise None.
113
+ """
114
+ pass
115
+
116
+ @abstractmethod
117
+ def list(self, params: Dict[str, Any]) -> List[Checkpoint]:
118
+ """
119
+ List checkpoints matching the given parameters.
120
+
121
+ Args:
122
+ params (dict): Parameters to filter checkpoints.
123
+ Returns:
124
+ List[Checkpoint]: List of matching checkpoints.
125
+ """
126
+ pass
127
+
128
+ @abstractmethod
129
+ def put(self, checkpoint: Checkpoint) -> None:
130
+ """
131
+ Store a checkpoint.
132
+
133
+ Args:
134
+ checkpoint (Checkpoint): The checkpoint to store.
135
+ """
136
+ pass
137
+
138
+ @abstractmethod
139
+ def get_by_session(self, session_id: str) -> Optional[Checkpoint]:
140
+ """
141
+ Get the latest checkpoint for a session.
142
+
143
+ Args:
144
+ session_id (str): The session identifier.
145
+ Returns:
146
+ Optional[Checkpoint]: The latest checkpoint if found, otherwise None.
147
+ """
148
+ pass
149
+
150
+ @abstractmethod
151
+ def delete_by_session(self, session_id: str) -> None:
152
+ """
153
+ Delete all checkpoints related to a session.
154
+
155
+ Args:
156
+ session_id (str): The session identifier.
157
+ """
158
+ pass
159
+
160
+ # Async methods
161
+ async def aget(self, checkpoint_id: str) -> Optional[Checkpoint]:
162
+ """
163
+ Asynchronously retrieve a checkpoint by its unique identifier.
164
+
165
+ Args:
166
+ checkpoint_id (str): The unique identifier of the checkpoint.
167
+ Returns:
168
+ Optional[Checkpoint]: The checkpoint if found, otherwise None.
169
+ """
170
+ return await asyncio.to_thread(self.get, checkpoint_id)
171
+
172
+ async def alist(self, params: Dict[str, Any]) -> List[Checkpoint]:
173
+ """
174
+ Asynchronously list checkpoints matching the given parameters.
175
+
176
+ Args:
177
+ params (dict): Parameters to filter checkpoints.
178
+ Returns:
179
+ List[Checkpoint]: List of matching checkpoints.
180
+ """
181
+ return await asyncio.to_thread(self.list, params)
182
+
183
+ async def aput(self, checkpoint: Checkpoint) -> None:
184
+ """
185
+ Asynchronously store a checkpoint.
186
+
187
+ Args:
188
+ checkpoint (Checkpoint): The checkpoint to store.
189
+ """
190
+ await asyncio.to_thread(self.put, checkpoint)
191
+
192
+ async def aget_by_session(self, session_id: str) -> Optional[Checkpoint]:
193
+ """
194
+ Asynchronously get the latest checkpoint for a session.
195
+
196
+ Args:
197
+ session_id (str): The session identifier.
198
+ Returns:
199
+ Optional[Checkpoint]: The latest checkpoint if found, otherwise None.
200
+ """
201
+ return await asyncio.to_thread(self.get_by_session, session_id)
202
+
203
+ async def adelete_by_session(self, session_id: str) -> None:
204
+ """
205
+ Asynchronously delete all checkpoints related to a session.
206
+
207
+ Args:
208
+ session_id (str): The session identifier.
209
+ """
210
+ await asyncio.to_thread(self.delete_by_session, session_id)
211
+
212
+ class VersionUtils:
213
+
214
+ @staticmethod
215
+ def get_next_version(version: int) -> int:
216
+ """
217
+ Get the next version of the checkpoint.
218
+ """
219
+ return version + 1
220
+
221
+ @staticmethod
222
+ def get_previous_version(version: int) -> int:
223
+ """
224
+ Get the previous version of the checkpoint.
225
+ """
226
+ return version - 1
227
+
228
+ @staticmethod
229
+ def is_version_greater(checkpoint: Checkpoint, version: int) -> bool:
230
+ """
231
+ Check if the checkpoint version is greater than the given version.
232
+ """
233
+ return checkpoint.version > version
234
+
235
+ @staticmethod
236
+ def is_version_less(checkpoint: Checkpoint, version: int) -> bool:
237
+ """
238
+ Check if the checkpoint version is less than the given version.
239
+ """
240
+ return checkpoint.version < version