Duibonduil commited on
Commit
8499ea2
·
verified ·
1 Parent(s): d02b6a5

Upload db.py

Browse files
AWorld-main/aworlddistributed/aworldspace/db.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from datetime import datetime
3
+ from typing import Optional
4
+
5
+ from sqlalchemy import create_engine
6
+ from sqlalchemy.orm import sessionmaker
7
+
8
+ from base import AworldTask, AworldTaskResult
9
+ from aworldspace.db.models import (
10
+ Base, AworldTaskModel, AworldTaskResultModel,
11
+ orm_to_pydantic_task, pydantic_to_orm_task,
12
+ orm_to_pydantic_result, pydantic_to_orm_result
13
+ )
14
+
15
+
16
+ class AworldTaskDB(ABC):
17
+
18
+ @abstractmethod
19
+ async def query_task_by_id(self, task_id: str) -> AworldTask:
20
+ pass
21
+
22
+ @abstractmethod
23
+ async def query_latest_task_result_by_id(self, task_id: str) -> Optional[AworldTaskResult]:
24
+ pass
25
+
26
+ @abstractmethod
27
+ async def insert_task(self, task: AworldTask):
28
+ pass
29
+
30
+ @abstractmethod
31
+ async def query_tasks_by_status(self, status: str, nums: int) -> list[AworldTask]:
32
+ pass
33
+
34
+ @abstractmethod
35
+ async def update_task(self, task: AworldTask):
36
+ pass
37
+
38
+ @abstractmethod
39
+ async def page_query_tasks(self, filter: dict, page_size: int, page_num: int) -> dict:
40
+ pass
41
+
42
+ @abstractmethod
43
+ async def save_task_result(self, result: AworldTaskResult):
44
+ pass
45
+
46
+
47
+ class SqliteTaskDB(AworldTaskDB):
48
+ def __init__(self, db_path: str):
49
+ self.engine = create_engine(db_path, echo=False, future=True)
50
+ Base.metadata.create_all(self.engine)
51
+ self.Session = sessionmaker(bind=self.engine, expire_on_commit=False)
52
+
53
+ async def query_task_by_id(self, task_id: str) -> Optional[AworldTask]:
54
+ with self.Session() as session:
55
+ orm_task = session.query(AworldTaskModel).filter_by(task_id=task_id).first()
56
+ return orm_to_pydantic_task(orm_task) if orm_task else None
57
+
58
+ async def query_latest_task_result_by_id(self, task_id: str) -> Optional[AworldTaskResult]:
59
+ with self.Session() as session:
60
+ orm_result = (
61
+ session.query(AworldTaskResultModel)
62
+ .filter_by(task_id=task_id)
63
+ .order_by(AworldTaskResultModel.created_at.desc())
64
+ .first()
65
+ )
66
+ return orm_to_pydantic_result(orm_result) if orm_result else None
67
+
68
+ async def insert_task(self, task: AworldTask):
69
+ with self.Session() as session:
70
+ orm_task = pydantic_to_orm_task(task)
71
+ session.add(orm_task)
72
+ session.commit()
73
+
74
+ async def query_tasks_by_status(self, status: str, nums: int) -> list[AworldTask]:
75
+ with self.Session() as session:
76
+ orm_tasks = (
77
+ session.query(AworldTaskModel)
78
+ .filter_by(status=status)
79
+ .limit(nums)
80
+ .all()
81
+ )
82
+ return [orm_to_pydantic_task(t) for t in orm_tasks]
83
+
84
+ async def update_task(self, task: AworldTask):
85
+ with self.Session() as session:
86
+ orm_task = session.query(AworldTaskModel).filter_by(task_id=task.task_id).first()
87
+ if orm_task:
88
+ for k, v in task.model_dump().items():
89
+ setattr(orm_task, k, v)
90
+ orm_task.updated_at = datetime.utcnow()
91
+ session.commit()
92
+
93
+ async def save_task_result(self, result: AworldTaskResult):
94
+ with self.Session() as session:
95
+ orm_task = pydantic_to_orm_result(result)
96
+ session.add(orm_task)
97
+ session.commit()
98
+
99
+ async def page_query_tasks(self, filter: dict, page_size: int, page_num: int) -> dict:
100
+ with self.Session() as session:
101
+ query = session.query(AworldTaskModel)
102
+
103
+ # Handle special filters for time ranges
104
+ start_time = filter.pop('start_time', None)
105
+ end_time = filter.pop('end_time', None)
106
+
107
+ # Apply regular filters
108
+ for k, v in filter.items():
109
+ if hasattr(AworldTaskModel, k):
110
+ query = query.filter(getattr(AworldTaskModel, k) == v)
111
+
112
+ # Apply time range filters
113
+ if start_time:
114
+ query = query.filter(AworldTaskModel.created_at >= start_time)
115
+ if end_time:
116
+ query = query.filter(AworldTaskModel.created_at <= end_time)
117
+
118
+ total = query.count()
119
+ orm_tasks = query.offset((page_num - 1) * page_size).limit(page_size).all()
120
+ items = [orm_to_pydantic_task(t) for t in orm_tasks]
121
+ return {
122
+ "total": total,
123
+ "page_num": page_num,
124
+ "page_size": page_size,
125
+ "items": items
126
+ }
127
+
128
+
129
+ class PostgresTaskDB(AworldTaskDB):
130
+ def __init__(self, db_url: str):
131
+ # db_url example: 'postgresql+psycopg2://user:password@host:port/dbname'
132
+ self.engine = create_engine(db_url, echo=False, future=True)
133
+ Base.metadata.create_all(self.engine)
134
+ self.Session = sessionmaker(bind=self.engine, expire_on_commit=False)
135
+
136
+ async def query_task_by_id(self, task_id: str) -> Optional[AworldTask]:
137
+ with self.Session() as session:
138
+ orm_task = session.query(AworldTaskModel).filter_by(task_id=task_id).first()
139
+ return orm_to_pydantic_task(orm_task) if orm_task else None
140
+
141
+ async def query_latest_task_result_by_id(self, task_id: str) -> Optional[AworldTaskResult]:
142
+ with self.Session() as session:
143
+ orm_result = (
144
+ session.query(AworldTaskResultModel)
145
+ .filter_by(task_id=task_id)
146
+ .order_by(AworldTaskResultModel.created_at.desc())
147
+ .first()
148
+ )
149
+ return orm_to_pydantic_result(orm_result) if orm_result else None
150
+
151
+ async def insert_task(self, task: AworldTask):
152
+ with self.Session() as session:
153
+ orm_task = pydantic_to_orm_task(task)
154
+ session.add(orm_task)
155
+ session.commit()
156
+
157
+ async def query_tasks_by_status(self, status: str, nums: int) -> list[AworldTask]:
158
+ with self.Session() as session:
159
+ orm_tasks = (
160
+ session.query(AworldTaskModel)
161
+ .filter_by(status=status)
162
+ .limit(nums)
163
+ .all()
164
+ )
165
+ return [orm_to_pydantic_task(t) for t in orm_tasks]
166
+
167
+ async def update_task(self, task: AworldTask):
168
+ with self.Session() as session:
169
+ orm_task = session.query(AworldTaskModel).filter_by(task_id=task.task_id).first()
170
+ if orm_task:
171
+ for k, v in task.model_dump().items():
172
+ setattr(orm_task, k, v)
173
+ orm_task.updated_at = datetime.utcnow()
174
+ session.commit()
175
+
176
+ async def save_task_result(self, result: AworldTaskResult):
177
+ with self.Session() as session:
178
+ orm_task = pydantic_to_orm_result(result)
179
+ session.add(orm_task)
180
+ session.commit()
181
+
182
+ async def page_query_tasks(self, filter: dict, page_size: int, page_num: int) -> dict:
183
+ with self.Session() as session:
184
+ query = session.query(AworldTaskModel)
185
+
186
+ # Handle special filters for time ranges
187
+ start_time = filter.pop('start_time', None)
188
+ end_time = filter.pop('end_time', None)
189
+
190
+ # Apply regular filters
191
+ for k, v in filter.items():
192
+ if hasattr(AworldTaskModel, k):
193
+ query = query.filter(getattr(AworldTaskModel, k) == v)
194
+
195
+ # Apply time range filters
196
+ if start_time:
197
+ query = query.filter(AworldTaskModel.created_at >= start_time)
198
+ if end_time:
199
+ query = query.filter(AworldTaskModel.created_at <= end_time)
200
+
201
+ total = query.count()
202
+ orm_tasks = query.offset((page_num - 1) * page_size).limit(page_size).all()
203
+ items = [orm_to_pydantic_task(t) for t in orm_tasks]
204
+ return {
205
+ "total": total,
206
+ "page_num": page_num,
207
+ "page_size": page_size,
208
+ "items": items
209
+ }
210
+