Spaces:
Sleeping
Sleeping
Upload db.py
Browse files
AWorld-main/aworlddistributed/aworldspace/db.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from datetime import datetime
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
from sqlalchemy import create_engine
|
6 |
+
from sqlalchemy.orm import sessionmaker
|
7 |
+
|
8 |
+
from base import AworldTask, AworldTaskResult
|
9 |
+
from aworldspace.db.models import (
|
10 |
+
Base, AworldTaskModel, AworldTaskResultModel,
|
11 |
+
orm_to_pydantic_task, pydantic_to_orm_task,
|
12 |
+
orm_to_pydantic_result, pydantic_to_orm_result
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
class AworldTaskDB(ABC):
|
17 |
+
|
18 |
+
@abstractmethod
|
19 |
+
async def query_task_by_id(self, task_id: str) -> AworldTask:
|
20 |
+
pass
|
21 |
+
|
22 |
+
@abstractmethod
|
23 |
+
async def query_latest_task_result_by_id(self, task_id: str) -> Optional[AworldTaskResult]:
|
24 |
+
pass
|
25 |
+
|
26 |
+
@abstractmethod
|
27 |
+
async def insert_task(self, task: AworldTask):
|
28 |
+
pass
|
29 |
+
|
30 |
+
@abstractmethod
|
31 |
+
async def query_tasks_by_status(self, status: str, nums: int) -> list[AworldTask]:
|
32 |
+
pass
|
33 |
+
|
34 |
+
@abstractmethod
|
35 |
+
async def update_task(self, task: AworldTask):
|
36 |
+
pass
|
37 |
+
|
38 |
+
@abstractmethod
|
39 |
+
async def page_query_tasks(self, filter: dict, page_size: int, page_num: int) -> dict:
|
40 |
+
pass
|
41 |
+
|
42 |
+
@abstractmethod
|
43 |
+
async def save_task_result(self, result: AworldTaskResult):
|
44 |
+
pass
|
45 |
+
|
46 |
+
|
47 |
+
class SqliteTaskDB(AworldTaskDB):
|
48 |
+
def __init__(self, db_path: str):
|
49 |
+
self.engine = create_engine(db_path, echo=False, future=True)
|
50 |
+
Base.metadata.create_all(self.engine)
|
51 |
+
self.Session = sessionmaker(bind=self.engine, expire_on_commit=False)
|
52 |
+
|
53 |
+
async def query_task_by_id(self, task_id: str) -> Optional[AworldTask]:
|
54 |
+
with self.Session() as session:
|
55 |
+
orm_task = session.query(AworldTaskModel).filter_by(task_id=task_id).first()
|
56 |
+
return orm_to_pydantic_task(orm_task) if orm_task else None
|
57 |
+
|
58 |
+
async def query_latest_task_result_by_id(self, task_id: str) -> Optional[AworldTaskResult]:
|
59 |
+
with self.Session() as session:
|
60 |
+
orm_result = (
|
61 |
+
session.query(AworldTaskResultModel)
|
62 |
+
.filter_by(task_id=task_id)
|
63 |
+
.order_by(AworldTaskResultModel.created_at.desc())
|
64 |
+
.first()
|
65 |
+
)
|
66 |
+
return orm_to_pydantic_result(orm_result) if orm_result else None
|
67 |
+
|
68 |
+
async def insert_task(self, task: AworldTask):
|
69 |
+
with self.Session() as session:
|
70 |
+
orm_task = pydantic_to_orm_task(task)
|
71 |
+
session.add(orm_task)
|
72 |
+
session.commit()
|
73 |
+
|
74 |
+
async def query_tasks_by_status(self, status: str, nums: int) -> list[AworldTask]:
|
75 |
+
with self.Session() as session:
|
76 |
+
orm_tasks = (
|
77 |
+
session.query(AworldTaskModel)
|
78 |
+
.filter_by(status=status)
|
79 |
+
.limit(nums)
|
80 |
+
.all()
|
81 |
+
)
|
82 |
+
return [orm_to_pydantic_task(t) for t in orm_tasks]
|
83 |
+
|
84 |
+
async def update_task(self, task: AworldTask):
|
85 |
+
with self.Session() as session:
|
86 |
+
orm_task = session.query(AworldTaskModel).filter_by(task_id=task.task_id).first()
|
87 |
+
if orm_task:
|
88 |
+
for k, v in task.model_dump().items():
|
89 |
+
setattr(orm_task, k, v)
|
90 |
+
orm_task.updated_at = datetime.utcnow()
|
91 |
+
session.commit()
|
92 |
+
|
93 |
+
async def save_task_result(self, result: AworldTaskResult):
|
94 |
+
with self.Session() as session:
|
95 |
+
orm_task = pydantic_to_orm_result(result)
|
96 |
+
session.add(orm_task)
|
97 |
+
session.commit()
|
98 |
+
|
99 |
+
async def page_query_tasks(self, filter: dict, page_size: int, page_num: int) -> dict:
|
100 |
+
with self.Session() as session:
|
101 |
+
query = session.query(AworldTaskModel)
|
102 |
+
|
103 |
+
# Handle special filters for time ranges
|
104 |
+
start_time = filter.pop('start_time', None)
|
105 |
+
end_time = filter.pop('end_time', None)
|
106 |
+
|
107 |
+
# Apply regular filters
|
108 |
+
for k, v in filter.items():
|
109 |
+
if hasattr(AworldTaskModel, k):
|
110 |
+
query = query.filter(getattr(AworldTaskModel, k) == v)
|
111 |
+
|
112 |
+
# Apply time range filters
|
113 |
+
if start_time:
|
114 |
+
query = query.filter(AworldTaskModel.created_at >= start_time)
|
115 |
+
if end_time:
|
116 |
+
query = query.filter(AworldTaskModel.created_at <= end_time)
|
117 |
+
|
118 |
+
total = query.count()
|
119 |
+
orm_tasks = query.offset((page_num - 1) * page_size).limit(page_size).all()
|
120 |
+
items = [orm_to_pydantic_task(t) for t in orm_tasks]
|
121 |
+
return {
|
122 |
+
"total": total,
|
123 |
+
"page_num": page_num,
|
124 |
+
"page_size": page_size,
|
125 |
+
"items": items
|
126 |
+
}
|
127 |
+
|
128 |
+
|
129 |
+
class PostgresTaskDB(AworldTaskDB):
|
130 |
+
def __init__(self, db_url: str):
|
131 |
+
# db_url example: 'postgresql+psycopg2://user:password@host:port/dbname'
|
132 |
+
self.engine = create_engine(db_url, echo=False, future=True)
|
133 |
+
Base.metadata.create_all(self.engine)
|
134 |
+
self.Session = sessionmaker(bind=self.engine, expire_on_commit=False)
|
135 |
+
|
136 |
+
async def query_task_by_id(self, task_id: str) -> Optional[AworldTask]:
|
137 |
+
with self.Session() as session:
|
138 |
+
orm_task = session.query(AworldTaskModel).filter_by(task_id=task_id).first()
|
139 |
+
return orm_to_pydantic_task(orm_task) if orm_task else None
|
140 |
+
|
141 |
+
async def query_latest_task_result_by_id(self, task_id: str) -> Optional[AworldTaskResult]:
|
142 |
+
with self.Session() as session:
|
143 |
+
orm_result = (
|
144 |
+
session.query(AworldTaskResultModel)
|
145 |
+
.filter_by(task_id=task_id)
|
146 |
+
.order_by(AworldTaskResultModel.created_at.desc())
|
147 |
+
.first()
|
148 |
+
)
|
149 |
+
return orm_to_pydantic_result(orm_result) if orm_result else None
|
150 |
+
|
151 |
+
async def insert_task(self, task: AworldTask):
|
152 |
+
with self.Session() as session:
|
153 |
+
orm_task = pydantic_to_orm_task(task)
|
154 |
+
session.add(orm_task)
|
155 |
+
session.commit()
|
156 |
+
|
157 |
+
async def query_tasks_by_status(self, status: str, nums: int) -> list[AworldTask]:
|
158 |
+
with self.Session() as session:
|
159 |
+
orm_tasks = (
|
160 |
+
session.query(AworldTaskModel)
|
161 |
+
.filter_by(status=status)
|
162 |
+
.limit(nums)
|
163 |
+
.all()
|
164 |
+
)
|
165 |
+
return [orm_to_pydantic_task(t) for t in orm_tasks]
|
166 |
+
|
167 |
+
async def update_task(self, task: AworldTask):
|
168 |
+
with self.Session() as session:
|
169 |
+
orm_task = session.query(AworldTaskModel).filter_by(task_id=task.task_id).first()
|
170 |
+
if orm_task:
|
171 |
+
for k, v in task.model_dump().items():
|
172 |
+
setattr(orm_task, k, v)
|
173 |
+
orm_task.updated_at = datetime.utcnow()
|
174 |
+
session.commit()
|
175 |
+
|
176 |
+
async def save_task_result(self, result: AworldTaskResult):
|
177 |
+
with self.Session() as session:
|
178 |
+
orm_task = pydantic_to_orm_result(result)
|
179 |
+
session.add(orm_task)
|
180 |
+
session.commit()
|
181 |
+
|
182 |
+
async def page_query_tasks(self, filter: dict, page_size: int, page_num: int) -> dict:
|
183 |
+
with self.Session() as session:
|
184 |
+
query = session.query(AworldTaskModel)
|
185 |
+
|
186 |
+
# Handle special filters for time ranges
|
187 |
+
start_time = filter.pop('start_time', None)
|
188 |
+
end_time = filter.pop('end_time', None)
|
189 |
+
|
190 |
+
# Apply regular filters
|
191 |
+
for k, v in filter.items():
|
192 |
+
if hasattr(AworldTaskModel, k):
|
193 |
+
query = query.filter(getattr(AworldTaskModel, k) == v)
|
194 |
+
|
195 |
+
# Apply time range filters
|
196 |
+
if start_time:
|
197 |
+
query = query.filter(AworldTaskModel.created_at >= start_time)
|
198 |
+
if end_time:
|
199 |
+
query = query.filter(AworldTaskModel.created_at <= end_time)
|
200 |
+
|
201 |
+
total = query.count()
|
202 |
+
orm_tasks = query.offset((page_num - 1) * page_size).limit(page_size).all()
|
203 |
+
items = [orm_to_pydantic_task(t) for t in orm_tasks]
|
204 |
+
return {
|
205 |
+
"total": total,
|
206 |
+
"page_num": page_num,
|
207 |
+
"page_size": page_size,
|
208 |
+
"items": items
|
209 |
+
}
|
210 |
+
|