randydev commited on
Commit
d66d424
·
verified ·
1 Parent(s): 0ec0e43

Create sql_db.py

Browse files
Files changed (1) hide show
  1. Akeno/utils/sql_db.py +214 -0
Akeno/utils/sql_db.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import sqlite3
4
+ import threading
5
+
6
+ import aiosqlite
7
+
8
+ from config import db_name
9
+
10
+ class Database:
11
+ def get(self, module: str, variable: str, default=None):
12
+ raise NotImplementedError
13
+
14
+ def set(self, module: str, variable: str, value):
15
+ raise NotImplementedError
16
+
17
+ def remove(self, module: str, variable: str):
18
+ raise NotImplementedError
19
+
20
+ def get_collection(self, module: str) -> dict:
21
+ raise NotImplementedError
22
+
23
+ def close(self):
24
+ raise NotImplementedError
25
+
26
+
27
+ class SqliteDatabase(Database):
28
+ def __init__(self, file):
29
+ self._conn = sqlite3.connect(file, check_same_thread=False)
30
+ self._conn.row_factory = sqlite3.Row
31
+ self._cursor = self._conn.cursor()
32
+ self._lock = threading.Lock()
33
+
34
+ @staticmethod
35
+ def _parse_row(row: sqlite3.Row):
36
+ parse_func = {
37
+ "bool": lambda x: x == "1",
38
+ "int": int,
39
+ "str": lambda x: x,
40
+ "json": json.loads,
41
+ }
42
+ return parse_func[row["type"]](row["val"])
43
+
44
+ def _create_table(self, module: str):
45
+ sql = f"""
46
+ CREATE TABLE IF NOT EXISTS '{module}' (
47
+ var TEXT UNIQUE NOT NULL,
48
+ val TEXT NOT NULL,
49
+ type TEXT NOT NULL
50
+ )
51
+ """
52
+ self._cursor.execute(sql)
53
+ self._conn.commit()
54
+
55
+ def _execute(self, module: str, sql, params=None):
56
+ with self._lock:
57
+ try:
58
+ return self._cursor.execute(sql, params)
59
+ except sqlite3.OperationalError as e:
60
+ if str(e).startswith("no such table"):
61
+ self._create_table(module)
62
+ return self._cursor.execute(sql, params)
63
+ raise e from None
64
+
65
+ def get(self, module: str, variable: str, default=None):
66
+ cur = self._execute(module, f"SELECT * FROM '{module}' WHERE var=:var", {"var": variable})
67
+ row = cur.fetchone()
68
+ return default if row is None else self._parse_row(row)
69
+
70
+ def set(self, module: str, variable: str, value) -> bool:
71
+ sql = f"""
72
+ INSERT INTO '{module}' VALUES ( :var, :val, :type )
73
+ ON CONFLICT (var) DO
74
+ UPDATE SET val=:val, type=:type WHERE var=:var
75
+ """
76
+
77
+ if isinstance(value, bool):
78
+ val = "1" if value else "0"
79
+ typ = "bool"
80
+ elif isinstance(value, str):
81
+ val = value
82
+ typ = "str"
83
+ elif isinstance(value, int):
84
+ val = str(value)
85
+ typ = "int"
86
+ else:
87
+ val = json.dumps(value)
88
+ typ = "json"
89
+
90
+ self._execute(module, sql, {"var": variable, "val": val, "type": typ})
91
+ self._conn.commit()
92
+
93
+ return True
94
+
95
+ def remove(self, module: str, variable: str):
96
+ sql = f"DELETE FROM '{module}' WHERE var=:var"
97
+ self._execute(module, sql, {"var": variable})
98
+ self._conn.commit()
99
+
100
+ def get_collection(self, module: str) -> dict:
101
+ sql = f"SELECT * FROM '{module}'"
102
+ cur = self._execute(module, sql)
103
+
104
+ return {row["var"]: self._parse_row(row) for row in cur}
105
+
106
+ def close(self):
107
+ self._conn.commit()
108
+ self._conn.close()
109
+
110
+
111
+ class AioSqliteDatabase(Database):
112
+ def __init__(self, file=db_name):
113
+ self._file = file
114
+ self._conn = None
115
+
116
+ async def _connect(self):
117
+ self._conn = await aiosqlite.connect(self._file)
118
+ self._conn.row_factory = aiosqlite.Row
119
+ self._cursor = await self._conn.cursor()
120
+ self._lock = asyncio.Lock()
121
+
122
+ @staticmethod
123
+ def _parse_row(row: aiosqlite.Row):
124
+ parse_func = {
125
+ "bool": lambda x: x == "1",
126
+ "int": int,
127
+ "str": lambda x: x,
128
+ "json": json.loads,
129
+ }
130
+ return parse_func[row["type"]](row["val"])
131
+
132
+ async def _create_table(self, module: str):
133
+ sql = f"""
134
+ CREATE TABLE IF NOT EXISTS '{module}' (
135
+ var TEXT UNIQUE NOT NULL,
136
+ val TEXT NOT NULL,
137
+ type TEXT NOT NULL
138
+ )
139
+ """
140
+ self._cursor.execute(sql)
141
+ self._conn.commit()
142
+
143
+ async def _execute(self, module: str, *args, **kwargs) -> aiosqlite.Cursor:
144
+ try:
145
+ return await self._cursor.execute(*args, **kwargs)
146
+ except aiosqlite.OperationalError as e:
147
+ if str(e).startswith("no such table"):
148
+ await self._create_table(module)
149
+ return await self._cursor.execute(*args, **kwargs)
150
+ raise e from None
151
+
152
+ async def get(self, module: str, variable: str, default=None):
153
+ if not self._conn:
154
+ await self._connect()
155
+ sql = f"SELECT * FROM '{module}' WHERE var=:var"
156
+ cur = await self._execute(module, sql, {"var": variable})
157
+
158
+ row = await cur.fetchone()
159
+ result = default if row is None else self._parse_row(row)
160
+ await self.close()
161
+ return result
162
+
163
+ async def set(self, module: str, variable: str, value) -> bool:
164
+ if not self._conn:
165
+ await self._connect()
166
+ sql = f"""
167
+ INSERT INTO '{module}' VALUES ( :var, :val, :type )
168
+ ON CONFLICT (var) DO
169
+ UPDATE SET val=:val, type=:type WHERE var=:var
170
+ """
171
+
172
+ if isinstance(value, bool):
173
+ val = "1" if value else "0"
174
+ typ = "bool"
175
+ elif isinstance(value, str):
176
+ val = value
177
+ typ = "str"
178
+ elif isinstance(value, int):
179
+ val = str(value)
180
+ typ = "int"
181
+ else:
182
+ val = json.dumps(value)
183
+ typ = "json"
184
+
185
+ await self._execute(module, sql, {"var": variable, "val": val, "type": typ})
186
+ await self._conn.commit()
187
+ await self.close()
188
+ return True
189
+
190
+ async def remove(self, module: str, variable: str):
191
+ if not self._conn:
192
+ await self._connect()
193
+ sql = f"DELETE FROM '{module}' WHERE var=:var"
194
+ await self._execute(module, sql, {"var": variable})
195
+ await self._conn.commit()
196
+ await self.close()
197
+
198
+ async def get_collection(self, module: str) -> dict:
199
+ if not self._conn:
200
+ await self._connect()
201
+ sql = f"SELECT * FROM '{module}'"
202
+ cur = await self._execute(module, sql)
203
+ result = {row["var"]: self._parse_row(row) async for row in cur}
204
+ await self.close()
205
+ return result
206
+
207
+ async def close(self):
208
+ await self._conn.commit()
209
+ await self._cursor.close()
210
+ await self._conn.close()
211
+ self._conn = None
212
+
213
+
214
+ sql_db = SqliteDatabase(db_name)