|
""" |
|
MIT License |
|
|
|
Copyright (C) 2023 ROCKY4546 |
|
https://github.com/rocky4546 |
|
|
|
This file is part of Cabernet |
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software |
|
and associated documentation files (the "Software"), to deal in the Software without restriction, |
|
including without limitation the rights to use, copy, modify, merge, publish, distribute, |
|
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software |
|
is furnished to do so, subject to the following conditions: |
|
|
|
The above copyright notice and this permission notice shall be included in all copies or |
|
substantial portions of the Software. |
|
""" |
|
|
|
import logging |
|
import os |
|
import pathlib |
|
import random |
|
import shutil |
|
import sqlite3 |
|
import threading |
|
import time |
|
|
|
# Module-wide lock; DB.update() and DB.get_dict() hold it while they run a statement.
LOCK = threading.Lock()

# File extensions: the sqlite database file itself, and its SQL-text backup.
DB_EXT: str = '.db'
BACKUP_EXT: str = '.sql'

# Keys (and key suffixes appended to a table name) used to look up SQL text
# in the ``sqlcmds`` mapping handed to DB.
SQL_CREATE_TABLES: str = 'ct'   # statements iterated by DB.create_tables()
SQL_DROP_TABLES: str = 'dt'     # statements iterated by DB.drop_tables()
SQL_ADD_ROW: str = '_add'       # '<table>_add'    -> DB.add()
SQL_UPDATE: str = '_update'     # '<table>_update' -> DB.update()
SQL_GET: str = '_get'           # '<table>_get'    -> DB.get()/get_dict()/get_init()
SQL_DELETE: str = '_del'        # '<table>_del'    -> DB.delete()

# Suffix of the backup archive name for a database's linked-file folder.
FILE_LINK_ZIP: str = '_filelinks'
|
|
|
class DB:
    """Wrapper around per-thread sqlite3 connections to one database file.

    Connections are cached in the class-level ``DB.conn`` dict, keyed first
    by database name and then by ``threading.get_ident()``, so each thread
    uses its own connection. SQL text is looked up in the ``sqlcmds``
    mapping using keys built from a table name plus one of the module's
    ``SQL_*`` suffixes.

    The add/delete/update/get/get_dict methods retry up to 10 times on
    ``sqlite3.OperationalError``. Between attempts they sleep 0.3 s plus a
    random 0-0.49 s.
    """

    # {db_name: {thread_ident: sqlite3.Connection}}; shared by every instance.
    conn = {}

    def __init__(self, _config, _db_name, _sqlcmds):
        """Open the database ``<db_dir>/<_db_name>.db``, creating its tables if the file is new.

        :param _config: mapping providing ``['paths']['db_dir']``.
        :param _db_name: database name; also the first-level key into ``DB.conn``.
        :param _sqlcmds: mapping of command keys to SQL text. It must contain
            the ``SQL_CREATE_TABLES`` key (an iterable of statements) when the
            database file does not exist yet.
        """
        self.logger = logging.getLogger(__name__ + str(threading.get_ident()))
        self.config = _config
        self.db_name = _db_name
        self.sqlcmds = _sqlcmds
        # State used by get_init()/get_dict_next() for row-by-row iteration.
        self.cur = None
        self.offset = -1
        self.where = None
        self.sqlcmd = None
        self.db_fullpath = pathlib.Path(self.config['paths']['db_dir']) \
            .joinpath(_db_name + DB_EXT)
        if not os.path.exists(self.db_fullpath):
            self.logger.debug('Creating new database: {} {}'.format(_db_name, self.db_fullpath))
            self.create_tables()
        self.check_connection()
        self._conn().commit()

    def _conn(self):
        """Return the calling thread's cached connection (call check_connection() first)."""
        return DB.conn[self.db_name][threading.get_ident()]

    def _connect(self):
        """Open a new sqlite3 connection to this database file."""
        return sqlite3.connect(
            self.db_fullpath, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES)

    def _abort_attempt(self, _cur):
        """Roll back this thread's connection and close ``_cur`` after a failed attempt.

        Tolerates a missing connection, because opening it may be what failed.
        """
        conn = DB.conn.get(self.db_name, {}).get(threading.get_ident())
        if conn is not None:
            conn.rollback()
        if _cur is not None:
            _cur.close()

    def sql_exec(self, _sqlcmd, _bindings=None, _cursor=None):
        """Execute one SQL statement on ``_cursor`` or on this thread's connection.

        :param _sqlcmd: SQL text of a single statement.
        :param _bindings: optional statement parameters; a falsy value means
            the statement is executed without bindings.
        :param _cursor: optional cursor to execute on. Without it, the
            thread's connection executes the statement directly.
        :return: the cursor returned by ``execute``.
        :raises sqlite3.IntegrityError: re-raised after this thread's
            connection is closed and discarded, so the next call reopens it.
        """
        try:
            self.check_connection()
            target = _cursor if _cursor else self._conn()
            if _bindings:
                return target.execute(_sqlcmd, _bindings)
            return target.execute(_sqlcmd)
        except sqlite3.IntegrityError:
            thread_id = threading.get_ident()
            DB.conn[self.db_name][thread_id].close()
            del DB.conn[self.db_name][thread_id]
            raise

    def rnd_sleep(self, _sec):
        """Sleep ``_sec`` seconds plus a random 0.00-0.49 s, to de-synchronise retries."""
        r = random.randrange(0, 50)
        sec = _sec + r / 100
        time.sleep(sec)

    def add(self, _table, _values):
        """Insert a row using the ``<_table>_add`` command.

        :param _table: table name used to build the command key.
        :param _values: bindings for the insert statement.
        :return: the cursor's ``lastrowid``, or None if all 10 attempts failed
            with ``sqlite3.OperationalError``.
        """
        self.logger.trace('DB add() called {}'.format(threading.get_ident()))
        sqlcmd = self.sqlcmds[_table + SQL_ADD_ROW]
        i = 10
        while i > 0:
            i -= 1
            cur = None
            try:
                self.check_connection()
                cur = self._conn().cursor()
                self.sql_exec(sqlcmd, _values, cur)
                self._conn().commit()
                lastrow = cur.lastrowid
                cur.close()
                self.logger.trace('DB add() exit {}'.format(threading.get_ident()))
                return lastrow
            except sqlite3.OperationalError as e:
                self.logger.warning('{} Add request ignored, retrying {}, {}'
                                    .format(self.db_name, i, e))
                self._abort_attempt(cur)
                self.rnd_sleep(0.3)
        self.logger.trace('DB add() exit {}'.format(threading.get_ident()))
        return None

    def delete(self, _table, _values):
        """Delete rows using the ``<_table>_del`` command.

        :return: the cursor's ``rowcount`` (number of deleted rows), or 0 if
            all 10 attempts failed with ``sqlite3.OperationalError``.
        """
        self.logger.trace('DB delete() called {}'.format(threading.get_ident()))
        sqlcmd = self.sqlcmds[_table + SQL_DELETE]
        i = 10
        while i > 0:
            i -= 1
            cur = None
            try:
                self.check_connection()
                cur = self._conn().cursor()
                self.sql_exec(sqlcmd, _values, cur)
                num_deleted = cur.rowcount
                self._conn().commit()
                cur.close()
                self.logger.trace('DB delete() exit {}'.format(threading.get_ident()))
                return num_deleted
            except sqlite3.OperationalError as e:
                self.logger.warning('{} Delete request ignored, retrying {}, {}'
                                    .format(self.db_name, i, e))
                self._abort_attempt(cur)
                self.rnd_sleep(0.3)
        self.logger.trace('DB delete() exit {}'.format(threading.get_ident()))
        return 0

    def update(self, _table, _values=None):
        """Run the ``<_table>_update`` command while holding the module-wide LOCK.

        The lock is held only for the duration of each attempt. A ``with``
        block guarantees its release even if an unexpected exception
        propagates; otherwise every later update()/get_dict() would deadlock.

        :return: the cursor's ``lastrowid``, or None if all 10 attempts failed
            with ``sqlite3.OperationalError``.
        """
        self.logger.trace('DB update() called {}'.format(threading.get_ident()))
        sqlcmd = self.sqlcmds[_table + SQL_UPDATE]
        i = 10
        while i > 0:
            i -= 1
            cur = None
            with LOCK:
                try:
                    self.check_connection()
                    cur = self._conn().cursor()
                    self.sql_exec(sqlcmd, _values, cur)
                    self._conn().commit()
                    lastrow = cur.lastrowid
                    cur.close()
                    self.logger.trace('DB update() exit {}'.format(threading.get_ident()))
                    return lastrow
                except sqlite3.OperationalError as e:
                    self.logger.notice('{} Update request ignored, retrying {}, {}'
                                       .format(self.db_name, i, e))
                    self._abort_attempt(cur)
            # Back off outside the lock so other threads can proceed.
            self.rnd_sleep(0.3)
        self.logger.trace('DB update() exit {}'.format(threading.get_ident()))
        return None

    def commit(self):
        """Commit the calling thread's connection."""
        self._conn().commit()

    def get(self, _table, _where=None):
        """Run the ``<_table>_get`` query and return every row as a tuple.

        :param _where: optional bindings for the query.
        :return: list of row tuples, or None if all 10 attempts failed with
            ``sqlite3.OperationalError``.
        """
        sqlcmd = self.sqlcmds[_table + SQL_GET]
        i = 10
        while i > 0:
            i -= 1
            cur = None
            try:
                self.check_connection()
                cur = self._conn().cursor()
                self.sql_exec(sqlcmd, _where, cur)
                result = cur.fetchall()
                cur.close()
                return result
            except sqlite3.OperationalError as e:
                self.logger.warning('{} GET request ignored retrying {}, {}'
                                    .format(self.db_name, i, e))
                self._abort_attempt(cur)
                self.rnd_sleep(0.3)
        return None

    def get_dict(self, _table, _where=None, sql=None):
        """Run a query while holding LOCK and return rows as ``{column: value}`` dicts.

        :param _table: table name used to build the ``<_table>_get`` key.
            It is ignored when ``sql`` is given.
        :param _where: optional bindings for the query.
        :param sql: optional SQL text that overrides the looked-up command.
        :return: list of dicts, or None if all 10 attempts failed with
            ``sqlite3.OperationalError``.
        """
        sqlcmd = self.sqlcmds[_table + SQL_GET] if sql is None else sql
        i = 10
        while i > 0:
            i -= 1
            cur = None
            with LOCK:
                try:
                    self.check_connection()
                    cur = self._conn().cursor()
                    self.sql_exec(sqlcmd, _where, cur)
                    records = cur.fetchall()
                    # Column names are read once, and only when there are rows to map.
                    columns = [c[0] for c in cur.description] if records else []
                    rows = [dict(zip(columns, row)) for row in records]
                    cur.close()
                    return rows
                except sqlite3.OperationalError as e:
                    self.logger.warning('{} GET request ignored retrying {}, {}'
                                        .format(self.db_name, i, e))
                    self._abort_attempt(cur)
            self.rnd_sleep(0.3)
        return None

    def get_init(self, _table, _where=None):
        """Prepare row-by-row iteration with get_dict_next().

        Requires "LIMIT ? OFFSET ?" at the end of the ``<_table>_get`` SQL
        statement. get_dict_next() appends those two parameters to
        ``_where``, which may be None when the query has no other parameters.
        """
        self.sqlcmd = self.sqlcmds[_table + SQL_GET]
        self.where = [] if _where is None else list(_where)
        self.offset = 0

    def get_dict_next(self):
        """Return the next row of the get_init() query as a dict, or None when exhausted."""
        bindings = tuple(self.where) + (1, self.offset)
        self.cur = self.sql_exec(self.sqlcmd, bindings)
        records = self.cur.fetchall()
        self.offset += 1
        if not records:
            return None
        return dict(zip([c[0] for c in self.cur.description], records[0]))

    def save_file(self, _keys, _blob):
        """Store ``_blob`` in ``<db_dir>/<db_name>/<key1_key2_...>.txt``.

        :param _keys: the table's unique key values; joined with '_' to form
            the file name.
        :param _blob: str (written UTF-8 encoded) or bytes.
        :return: the path relative to ``db_dir``, or None if writing the file
            failed with PermissionError.
        """
        folder_path = pathlib.Path(self.config['paths']['db_dir']).joinpath(self.db_name)
        os.makedirs(folder_path, exist_ok=True)
        filename = '_'.join(str(x) for x in _keys) + '.txt'
        file_rel_path = pathlib.Path(self.db_name).joinpath(filename)
        filepath = folder_path.joinpath(filename)
        data = _blob.encode() if isinstance(_blob, str) else _blob
        try:
            with open(filepath, mode='wb') as f:
                f.write(data)
        except PermissionError:
            self.logger.warning('Unable to create linked database file {}'
                                .format(file_rel_path))
            return None
        return file_rel_path

    def delete_file(self, _filepath):
        """Delete a linked file; ``_filepath`` is relative to ``db_dir``.

        :return: True if the file was removed, False if it was missing or not
            deletable (a warning is logged).
        """
        fullpath = pathlib.Path(self.config['paths']['db_dir']).joinpath(_filepath)
        try:
            os.remove(fullpath)
            return True
        except PermissionError:
            self.logger.warning('Unable to delete linked database file {}'
                                .format(_filepath))
            return False
        except FileNotFoundError:
            self.logger.warning('File missing, unable to delete linked database file {}'
                                .format(_filepath))
            return False

    def get_file(self, _filepath):
        """Return the bytes of a linked file, or None if it is missing or unreadable.

        ``_filepath`` is relative to ``db_dir``.
        """
        fullpath = pathlib.Path(self.config['paths']['db_dir']).joinpath(_filepath)
        if not fullpath.exists():
            self.logger.warning('Linked database file Missing {}'.format(_filepath))
            return None
        try:
            with open(fullpath, mode='rb') as f:
                return f.read()
        except PermissionError:
            self.logger.warning('Unable to read linked database file {}'
                                .format(_filepath))
            return None

    def get_file_by_key(self, _keys):
        """Return the linked file that save_file() wrote for ``_keys`` (or None)."""
        filename = '_'.join(str(x) for x in _keys) + '.txt'
        file_rel_path = pathlib.Path(self.db_name).joinpath(filename)
        return self.get_file(file_rel_path)

    def reinitialize_tables(self):
        """Drop and recreate all tables."""
        self.drop_tables()
        self.create_tables()

    def create_tables(self):
        """Execute every statement under the SQL_CREATE_TABLES key, then commit."""
        for statement in self.sqlcmds[SQL_CREATE_TABLES]:
            self.sql_exec(statement)
        self._conn().commit()

    def drop_tables(self):
        """Execute every statement under the SQL_DROP_TABLES key, then commit."""
        for statement in self.sqlcmds[SQL_DROP_TABLES]:
            self.sql_exec(statement)
        self._conn().commit()

    def export_sql(self, backup_folder):
        """Back up this database into ``backup_folder``.

        Writes ``<db_name>.sql`` (a UTF-8 ``iterdump`` of the database). If the
        linked-file folder exists, also writes ``<db_name>_filelinks.zip``.
        PermissionError is logged and swallowed.
        """
        self.logger.debug('Running backup for {} database'.format(self.db_name))
        try:
            # makedirs also creates missing parents, which os.mkdir would not.
            os.makedirs(backup_folder, exist_ok=True)
            self.check_connection()

            db_linkfilepath = pathlib.Path(self.config['paths']['db_dir']) \
                .joinpath(self.db_name)
            if db_linkfilepath.exists():
                self.logger.debug('Linked file folder exists, backing up folder for db {}'.format(self.db_name))
                backup_filelink = pathlib.Path(backup_folder, self.db_name + FILE_LINK_ZIP)
                shutil.make_archive(backup_filelink, 'zip', db_linkfilepath)

            backup_file = pathlib.Path(backup_folder, self.db_name + BACKUP_EXT)
            with open(backup_file, 'w', encoding='utf-8') as export_f:
                for line in self._conn().iterdump():
                    export_f.write('%s\n' % line)
        except PermissionError as e:
            self.logger.warning(e)
            self.logger.warning('Unable to make backups')

    def import_sql(self, backup_folder):
        """Restore this database (and its linked-file folder) from ``backup_folder``.

        Expects the files written by export_sql(). Existing tables are dropped
        via drop_tables() before the SQL is replayed.

        :return: None on success, otherwise a message explaining why the
            restore was skipped.
        """
        self.logger.debug('Running restore for {} database'.format(self.db_name))
        if not os.path.isdir(backup_folder):
            msg = 'Backup folder does not exist: {}'.format(backup_folder)
            self.logger.warning(msg)
            return msg

        backup_filelink = pathlib.Path(backup_folder, self.db_name + FILE_LINK_ZIP + '.zip')
        db_linkfilepath = pathlib.Path(self.config['paths']['db_dir']) \
            .joinpath(self.db_name)
        if backup_filelink.exists():
            self.logger.debug('Linked file folder exists, restoring folder for db {}'.format(self.db_name))
            shutil.unpack_archive(backup_filelink, db_linkfilepath)

        backup_file = pathlib.Path(backup_folder, self.db_name + BACKUP_EXT)
        if not os.path.isfile(backup_file):
            msg = 'Backup file does not exist, skipping: {}'.format(backup_file)
            self.logger.info(msg)
            return msg
        self.check_connection()
        self.drop_tables()
        # export_sql() writes UTF-8; do not depend on the platform default encoding.
        with open(backup_file, 'r', encoding='utf-8') as import_f:
            cmd = ''
            for line in import_f:
                cmd += line
                # complete_statement() is aware of string literals and trigger
                # bodies, so a ';' at the end of a line inside a text value no
                # longer splits a statement in two.
                if sqlite3.complete_statement(cmd):
                    self._conn().execute(cmd)
                    cmd = ''
        return None

    def close(self):
        """Close and forget the calling thread's connection. It is a no-op if none is open."""
        thread_id = threading.get_ident()
        conn = DB.conn.get(self.db_name, {}).pop(thread_id, None)
        if conn is None:
            return
        conn.close()
        self.logger.debug('{} database closed for thread:{}'.format(self.db_name, thread_id))

    def check_connection(self):
        """Ensure the calling thread has an open connection, opening or reopening it as needed."""
        thread_id = threading.get_ident()
        db_conns = DB.conn.setdefault(self.db_name, {})
        if thread_id not in db_conns:
            db_conns[thread_id] = self._connect()
            return
        try:
            # Reading an attribute of a closed connection raises ProgrammingError.
            db_conns[thread_id].total_changes
        except sqlite3.ProgrammingError:
            self.logger.debug('Reopening {} database for thread:{}'.format(self.db_name, thread_id))
            db_conns[thread_id] = self._connect()
|
|