# cab/lib/db/db.py
"""
MIT License
Copyright (C) 2023 ROCKY4546
https://github.com/rocky4546
This file is part of Cabernet
Permission is hereby granted, free of charge, to any person obtaining a copy of this software
and associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or
substantial portions of the Software.
"""
import logging
import os
import pathlib
import random
import shutil
import sqlite3
import threading
import time
LOCK = threading.Lock()
DB_EXT = '.db'
BACKUP_EXT = '.sql'
# trailers used in sqlcmds.py
SQL_CREATE_TABLES = 'ct'
SQL_DROP_TABLES = 'dt'
SQL_ADD_ROW = '_add'
SQL_UPDATE = '_update'
SQL_GET = '_get'
SQL_DELETE = '_del'
FILE_LINK_ZIP = '_filelinks'
class DB:
    """Thread-aware SQLite wrapper used by Cabernet.

    sqlite3 connections must not be shared between threads, so a separate
    connection is kept for every (database name, thread id) pair in the
    class-level ``conn`` registry.  SQL text is looked up in the
    ``_sqlcmds`` dict passed to the constructor, keyed by table name plus
    one of the module-level SQL_* suffixes.  Write/read operations retry
    up to 10 times on ``sqlite3.OperationalError`` (e.g. database locked),
    sleeping a jittered back-off between attempts.
    """

    # Shared connection registry: {db_name: {thread_id: sqlite3.Connection}}
    conn = {}

    def __init__(self, _config, _db_name, _sqlcmds):
        """Open database ``_db_name``, creating the file and tables if missing.

        :param _config: app config dict; only ['paths']['db_dir'] is read here
        :param _db_name: database base name; DB_EXT is appended for the filename
        :param _sqlcmds: dict of SQL statements keyed by table + SQL_* suffix
        """
        # NOTE(review): logger name embeds the thread id; project logging
        # config presumably adds the custom trace()/notice() levels used below
        self.logger = logging.getLogger(__name__ + str(threading.get_ident()))
        self.config = _config
        self.db_name = _db_name
        self.sqlcmds = _sqlcmds
        # State for the get_init()/get_dict_next() row-at-a-time interface
        self.cur = None
        self.offset = -1
        self.where = None
        self.sqlcmd = None
        self.db_fullpath = pathlib.Path(self.config['paths']['db_dir']) \
            .joinpath(_db_name + DB_EXT)
        if not os.path.exists(self.db_fullpath):
            self.logger.debug('Creating new database: {} {}'.format(_db_name, self.db_fullpath))
            self.create_tables()
        self.check_connection()
        DB.conn[self.db_name][threading.get_ident()].commit()

    def sql_exec(self, _sqlcmd, _bindings=None, _cursor=None):
        """Execute ``_sqlcmd`` (with optional bindings) on this thread's connection.

        When ``_cursor`` is given the statement runs on it; otherwise it runs
        directly on the connection.  On IntegrityError the connection is
        discarded (closed and removed from the registry) so the next
        check_connection() starts fresh, then the error is re-raised.
        """
        try:
            self.check_connection()
            if _bindings:
                if _cursor:
                    return _cursor.execute(_sqlcmd, _bindings)
                else:
                    return DB.conn[self.db_name][threading.get_ident()].execute(_sqlcmd, _bindings)
            else:
                if _cursor:
                    return _cursor.execute(_sqlcmd)
                else:
                    return DB.conn[self.db_name][threading.get_ident()].execute(_sqlcmd)
        except sqlite3.IntegrityError as e:
            DB.conn[self.db_name][threading.get_ident()].close()
            del DB.conn[self.db_name][threading.get_ident()]
            raise e

    def rnd_sleep(self, _sec):
        """Sleep ``_sec`` seconds plus 0-0.49s random jitter (retry back-off)."""
        r = random.randrange(0, 50)
        time.sleep(_sec + r / 100)

    def add(self, _table, _values):
        """Insert a row using the <table>_add statement.

        :return: rowid of the inserted row, or None after 10 failed attempts
        """
        self.logger.trace('DB add() called {}'.format(threading.get_ident()))
        sqlcmd = self.sqlcmds[''.join([_table, SQL_ADD_ROW])]
        for i in range(9, -1, -1):
            cur = None
            try:
                self.check_connection()
                cur = DB.conn[self.db_name][threading.get_ident()].cursor()
                self.sql_exec(sqlcmd, _values, cur)
                DB.conn[self.db_name][threading.get_ident()].commit()
                lastrow = cur.lastrowid
                cur.close()
                self.logger.trace('DB add() exit {}'.format(threading.get_ident()))
                return lastrow
            except sqlite3.OperationalError as e:
                self.logger.warning('{} Add request ignored, retrying {}, {}'
                                    .format(self.db_name, i, e))
                DB.conn[self.db_name][threading.get_ident()].rollback()
                if cur is not None:
                    cur.close()
                self.rnd_sleep(0.3)
        self.logger.trace('DB add() exit {}'.format(threading.get_ident()))
        return None

    def delete(self, _table, _values):
        """Delete rows using the <table>_del statement.

        :return: number of rows deleted, or 0 after 10 failed attempts
        """
        self.logger.trace('DB delete() called {}'.format(threading.get_ident()))
        sqlcmd = self.sqlcmds[''.join([_table, SQL_DELETE])]
        for i in range(9, -1, -1):
            cur = None
            try:
                self.check_connection()
                cur = DB.conn[self.db_name][threading.get_ident()].cursor()
                self.sql_exec(sqlcmd, _values, cur)
                num_deleted = cur.rowcount
                DB.conn[self.db_name][threading.get_ident()].commit()
                cur.close()
                self.logger.trace('DB delete() exit {}'.format(threading.get_ident()))
                return num_deleted
            except sqlite3.OperationalError as e:
                self.logger.warning('{} Delete request ignored, retrying {}, {}'
                                    .format(self.db_name, i, e))
                DB.conn[self.db_name][threading.get_ident()].rollback()
                if cur is not None:
                    cur.close()
                self.rnd_sleep(0.3)
        self.logger.trace('DB delete() exit {}'.format(threading.get_ident()))
        return 0

    def update(self, _table, _values=None):
        """Run the <table>_update statement, serialized across threads by LOCK.

        :return: cursor.lastrowid, or None after 10 failed attempts
        """
        self.logger.trace('DB update() called {}'.format(threading.get_ident()))
        sqlcmd = self.sqlcmds[''.join([_table, SQL_UPDATE])]
        for i in range(9, -1, -1):
            cur = None
            # 'with' guarantees LOCK release on every exit path; the previous
            # explicit acquire/release leaked the lock when anything other
            # than OperationalError escaped (e.g. IntegrityError re-raised
            # by sql_exec, or a failing rollback), deadlocking all threads.
            with LOCK:
                try:
                    self.check_connection()
                    cur = DB.conn[self.db_name][threading.get_ident()].cursor()
                    self.sql_exec(sqlcmd, _values, cur)
                    DB.conn[self.db_name][threading.get_ident()].commit()
                    lastrow = cur.lastrowid
                    cur.close()
                    self.logger.trace('DB update() exit {}'.format(threading.get_ident()))
                    return lastrow
                except sqlite3.OperationalError as e:
                    self.logger.notice('{} Update request ignored, retrying {}, {}'
                                       .format(self.db_name, i, e))
                    DB.conn[self.db_name][threading.get_ident()].rollback()
                    if cur is not None:
                        cur.close()
            # back off outside the lock, as before, so other threads can run
            self.rnd_sleep(0.3)
        self.logger.trace('DB update() exit {}'.format(threading.get_ident()))
        return None

    def commit(self):
        """Commit the current transaction on this thread's connection."""
        DB.conn[self.db_name][threading.get_ident()].commit()

    def get(self, _table, _where=None):
        """Run the <table>_get statement.

        :return: list of row tuples, or None after 10 failed attempts
        """
        sqlcmd = self.sqlcmds[''.join([_table, SQL_GET])]
        for i in range(9, -1, -1):
            cur = None
            try:
                self.check_connection()
                cur = DB.conn[self.db_name][threading.get_ident()].cursor()
                self.sql_exec(sqlcmd, _where, cur)
                result = cur.fetchall()
                cur.close()
                return result
            except sqlite3.OperationalError as e:
                self.logger.warning('{} GET request ignored retrying {}, {}'
                                    .format(self.db_name, i, e))
                DB.conn[self.db_name][threading.get_ident()].rollback()
                if cur is not None:
                    cur.close()
                self.rnd_sleep(0.3)
        return None

    def get_dict(self, _table, _where=None, sql=None):
        """Like get() but returns each row as a dict keyed by column name.

        :param sql: optional SQL text overriding the <table>_get statement
        :return: list of dicts, or None after 10 failed attempts
        """
        if sql is None:
            sqlcmd = self.sqlcmds[''.join([_table, SQL_GET])]
        else:
            sqlcmd = sql
        for i in range(9, -1, -1):
            cur = None
            # 'with' guarantees LOCK release even when an unexpected
            # exception escapes (the explicit acquire/release leaked it)
            with LOCK:
                try:
                    self.check_connection()
                    cur = DB.conn[self.db_name][threading.get_ident()].cursor()
                    self.sql_exec(sqlcmd, _where, cur)
                    records = cur.fetchall()
                    columns = [c[0] for c in cur.description]
                    rows = [dict(zip(columns, row)) for row in records]
                    cur.close()
                    return rows
                except sqlite3.OperationalError as e:
                    self.logger.warning('{} GET request ignored retrying {}, {}'
                                        .format(self.db_name, i, e))
                    DB.conn[self.db_name][threading.get_ident()].rollback()
                    if cur is not None:
                        cur.close()
            # back off outside the lock, matching the original release order
            self.rnd_sleep(0.3)
        return None

    def get_init(self, _table, _where=None):
        """Prepare paged, one-row-at-a-time iteration for get_dict_next().

        Requires "LIMIT ? OFFSET ?" at the end of the sql statement.
        :param _where: iterable of bindings for the statement's WHERE clause
        """
        self.sqlcmd = self.sqlcmds[''.join([_table, SQL_GET])]
        self.where = list(_where)
        self.offset = 0

    def get_dict_next(self):
        """Fetch the next row (as a dict) after get_init(); None when exhausted."""
        bindings = self.where.copy()
        bindings.extend((1, self.offset))  # LIMIT 1 OFFSET <current row>
        self.cur = self.sql_exec(self.sqlcmd, tuple(bindings))
        records = self.cur.fetchall()
        self.offset += 1
        if not records:
            return None
        return dict(zip([c[0] for c in self.cur.description], records[0]))

    def save_file(self, _keys, _blob):
        """Store ``_blob`` in the folder named after this db.

        The filename is the concatenation of ``_keys`` (the table's unique
        keys) joined with '_' and a '.txt' extension.
        :return: filepath relative to the db directory, or None on failure
        """
        folder_path = pathlib.Path(self.config['paths']['db_dir']) \
            .joinpath(self.db_name)
        os.makedirs(folder_path, exist_ok=True)
        filename = '_'.join(str(x) for x in _keys) + '.txt'
        file_rel_path = pathlib.Path(self.db_name).joinpath(filename)
        filepath = folder_path.joinpath(filename)
        try:
            # the context manager flushes and closes; explicit calls were redundant
            with open(filepath, mode='wb') as f:
                f.write(_blob.encode() if isinstance(_blob, str) else _blob)
        except PermissionError:
            self.logger.warning('Unable to create linked database file {}'
                                .format(file_rel_path))
            return None
        return file_rel_path

    def delete_file(self, _filepath):
        """Delete a linked file.

        :param _filepath: path relative to the database path
        :return: True on success, False when missing or not permitted
        """
        fullpath = pathlib.Path(self.config['paths']['db_dir']) \
            .joinpath(_filepath)
        try:
            os.remove(fullpath)
            return True
        except PermissionError:
            self.logger.warning('Unable to delete linked database file {}'
                                .format(_filepath))
            return False
        except FileNotFoundError:
            self.logger.warning('File missing, unable to delete linked database file {}'
                                .format(_filepath))
            return False

    def get_file(self, _filepath):
        """Read a linked file.

        :param _filepath: path relative to the database path
        :return: the file contents as bytes, or None when missing/unreadable
        """
        fullpath = pathlib.Path(self.config['paths']['db_dir']) \
            .joinpath(_filepath)
        if not fullpath.exists():
            self.logger.warning('Linked database file Missing {}'.format(_filepath))
            return None
        try:
            with open(fullpath, mode='rb') as f:
                return f.read()
        except PermissionError:
            self.logger.warning('Unable to read linked database file {}'
                                .format(_filepath))
            return None

    def get_file_by_key(self, _keys):
        """Read the linked file whose name was built from ``_keys`` by save_file()."""
        filename = '_'.join(str(x) for x in _keys) + '.txt'
        file_rel_path = pathlib.Path(self.db_name).joinpath(filename)
        return self.get_file(file_rel_path)

    def reinitialize_tables(self):
        """Drop and recreate all tables, wiping the database contents."""
        self.drop_tables()
        self.create_tables()

    def create_tables(self):
        """Execute every CREATE statement under the SQL_CREATE_TABLES key."""
        for table in self.sqlcmds[SQL_CREATE_TABLES]:
            self.sql_exec(table)
            DB.conn[self.db_name][threading.get_ident()].commit()

    def drop_tables(self):
        """Execute every DROP statement under the SQL_DROP_TABLES key."""
        for table in self.sqlcmds[SQL_DROP_TABLES]:
            self.sql_exec(table)
            DB.conn[self.db_name][threading.get_ident()].commit()

    def export_sql(self, backup_folder):
        """Dump this database into ``backup_folder``.

        Writes <db_name>.sql via sqlite3 iterdump and, when the linked-file
        folder exists, a <db_name>_filelinks.zip archive of that folder.
        """
        self.logger.debug('Running backup for {} database'.format(self.db_name))
        try:
            if not os.path.isdir(backup_folder):
                os.mkdir(backup_folder)
            self.check_connection()
            # Check for linked file folder and zip up if present
            db_linkfilepath = pathlib.Path(self.config['paths']['db_dir']) \
                .joinpath(self.db_name)
            if db_linkfilepath.exists():
                self.logger.debug('Linked file folder exists, backing up folder for db {}'.format(self.db_name))
                backup_filelink = pathlib.Path(backup_folder, self.db_name + FILE_LINK_ZIP)
                shutil.make_archive(backup_filelink, 'zip', db_linkfilepath)
            backup_file = pathlib.Path(backup_folder, self.db_name + BACKUP_EXT)
            with open(backup_file, 'w', encoding='utf-8') as export_f:
                for line in DB.conn[self.db_name][threading.get_ident()].iterdump():
                    export_f.write('%s\n' % line)
        except PermissionError as e:
            self.logger.warning(e)
            self.logger.warning('Unable to make backups')

    def import_sql(self, backup_folder):
        """Restore this database (and any linked-file zip) from ``backup_folder``.

        Drops all tables first, then replays the .sql dump statement by
        statement.
        :return: an error-message string on failure, None on success
        """
        self.logger.debug('Running restore for {} database'.format(self.db_name))
        if not os.path.isdir(backup_folder):
            msg = 'Backup folder does not exist: {}'.format(backup_folder)
            self.logger.warning(msg)
            return msg
        # Restore the linked-file folder when its zip archive was backed up
        backup_filelink = pathlib.Path(backup_folder, self.db_name + FILE_LINK_ZIP + '.zip')
        db_linkfilepath = pathlib.Path(self.config['paths']['db_dir']) \
            .joinpath(self.db_name)
        if backup_filelink.exists():
            self.logger.debug('Linked file folder exists, restoring folder for db {}'.format(self.db_name))
            shutil.unpack_archive(backup_filelink, db_linkfilepath)
        backup_file = pathlib.Path(backup_folder, self.db_name + BACKUP_EXT)
        if not os.path.isfile(backup_file):
            msg = 'Backup file does not exist, skipping: {}'.format(backup_file)
            self.logger.info(msg)
            return msg
        self.check_connection()
        self.drop_tables()
        with open(backup_file, 'r') as import_f:
            cmd = ''
            for line in import_f:
                cmd += line
                # each dump statement ends with ';' (possibly before a newline)
                if ';' in line[-3:]:
                    DB.conn[self.db_name][threading.get_ident()].execute(cmd)
                    cmd = ''
        return None

    def close(self):
        """Close and forget this thread's connection to the database."""
        thread_id = threading.get_ident()
        DB.conn[self.db_name][thread_id].close()
        del DB.conn[self.db_name][thread_id]
        self.logger.debug('{} database closed for thread:{}'.format(self.db_name, thread_id))

    def check_connection(self):
        """Ensure a usable sqlite3 connection exists for this thread,
        (re)connecting when absent or previously closed."""
        if self.db_name not in DB.conn:
            DB.conn[self.db_name] = {}
        thread_conns = DB.conn[self.db_name]
        tid = threading.get_ident()
        if tid not in thread_conns:
            thread_conns[tid] = sqlite3.connect(
                self.db_fullpath, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES)
        else:
            try:
                # cheap probe: raises ProgrammingError if the connection was closed
                thread_conns[tid].total_changes
            except sqlite3.ProgrammingError:
                self.logger.debug('Reopening {} database for thread:{}'.format(self.db_name, tid))
                thread_conns[tid] = sqlite3.connect(
                    self.db_fullpath, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES)