135 lines
4.0 KiB
Python
135 lines
4.0 KiB
Python
"""
|
|
|
|
Used environment variables:
|
|
DB_PATH: The directory that holds the SQLite file; `website.db` is always joined onto it. Defaults to `./` when unset.
|
|
"""
|
|
|
|
import threading
|
|
import sqlite3
|
|
import logging
|
|
import atexit
|
|
import os
|
|
|
|
# pylint: disable=locally-disabled, global-statement
|
|
|
|
class DB:
    """
    Basically just a thread-safe version of sqlite3.

    A single sqlite3 connection, opened with ``check_same_thread=False``, is
    shared by every thread. All access to it is serialised through one
    re-entrant lock, so compound operations (e.g. an INSERT followed by its
    COMMIT) cannot interleave with another thread's statements.

    **WARNING: Table names, column names and column definitions are pasted
    straight into the SQL text; only values are bound as parameters.
    NEVER pass user input as an identifier or column definition.**
    """

    def __init__(self):
        self._log = logging.getLogger('DB')

        # DB_PATH names a directory; the file name itself is fixed.
        self._dbFile = os.path.join(os.getenv('DB_PATH', './'), 'website.db')
        self._conn = sqlite3.connect(self._dbFile, check_same_thread=False)
        # Rows can be indexed by position or by column name (row['id']).
        self._conn.row_factory = sqlite3.Row
        # Re-entrant so a method that already holds the lock can still call
        # execute()/commit(), which acquire it again.
        self._connLock = threading.RLock()

        atexit.register(self.close)
        self._log.info('Database ready.')

    def execute(self, command: str,
                args: 'list | tuple | None' = None) -> sqlite3.Cursor:
        """
        Executes a SQL command directly.

        Args:
            command (str): The SQL statement. Use ``?`` placeholders for values.
            args (list | tuple | None): Values bound to the placeholders, in order.

        Returns:
            sqlite3.Cursor: The statement's cursor. The lock is released when
            this method returns, so anything fetched from the cursor afterwards
            is read without the lock held.
        """
        with self._connLock:
            return self._conn.execute(command, [] if args is None else args)

    def commit(self):
        """
        Commits any uncommitted changes.
        """
        with self._connLock:
            self._conn.commit()

    def _fetchOne(self, command: str, args: 'list | tuple | None' = None):
        """
        Executes `command` and fetches its first row without releasing the lock.

        Returns:
            sqlite3.Row | None: The first result row, or None if there is none.
        """
        with self._connLock:
            return self.execute(command, args).fetchone()

    def createTable(self, tableName: str):
        """
        Creates a table in the database with a default ID column.

        Does nothing if a table with that name already exists.

        **WARNING: `tableName` is assumed to be safe. DO NOT PASS USER INPUT TO THIS VARIABLE.**

        Args:
            tableName (str): The name of the table to create.
        """
        with self._connLock:
            self.execute(f'CREATE TABLE IF NOT EXISTS {tableName} '
                         '(id INTEGER PRIMARY KEY AUTOINCREMENT)')
            # Commit like createColumn() does, so the table is persisted even
            # if a transaction was already open on the shared connection.
            self.commit()

    def createColumn(self, tableName: str, columnDef: str):
        """
        Creates a column in the specified table if it doesn't already exist.

        **WARNING: This function call is intended to be INTERNAL ONLY.**
        **NO USER INPUT SHOULD EVER BE PASSED TO THIS FUNCTION.**

        Args:
            tableName (str): The name of the table to alter.
            columnDef (str): The definition of the column. For example
                'name STRING DEFAULT NULL'. The first whitespace-separated
                word is taken as the column name, so quoted names containing
                spaces are not supported.

        Raises:
            ValueError: If `columnDef` is empty or only whitespace.
        """
        words = columnDef.split()
        if not words:
            raise ValueError('columnDef must not be empty')
        columnName = words[0]

        # Hold the lock across the check and the ALTER so two threads cannot
        # both see the column as missing and both try to add it.
        with self._connLock:
            # SQLite matches column names case-insensitively, hence NOCASE.
            existing = self._fetchOne(
                'SELECT 1 FROM pragma_table_info(?) WHERE name = ? COLLATE NOCASE',
                (tableName, columnName))
            if existing is None:
                self.execute(f'ALTER TABLE {tableName} ADD COLUMN {columnDef}')
                self.commit()

    def addRow(self, table: str, data: dict) -> int:
        """
        Adds a row into the database and commits it.

        **WARNING: The arguments `table` and all keys of `data` are assumed to be internal.**
        **ONLY THE VALUES OF `data` CAN BE USER INPUT. ANYTHING ELSE IS UNSAFE.**

        Args:
            table (str): The table to insert into.
            data (dict): The keys (column names) and values to insert. An empty
                dict inserts a row made entirely of column defaults.
                **The keys are internal only and should NEVER contain user input.**

        Returns:
            int: The rowid of the inserted row (the `id` column for tables made
            by createTable()).
        """
        if data:
            columns = ','.join(data.keys())
            placeholders = ','.join('?' * len(data))
            command = f'INSERT INTO {table} ({columns}) VALUES ({placeholders})'
        else:
            # "INSERT INTO t () VALUES ()" is not valid SQLite syntax.
            command = f'INSERT INTO {table} DEFAULT VALUES'

        with self._connLock:
            cursor = self.execute(command, list(data.values()))
            self.commit()
            return cursor.lastrowid

    def updateColumns(self, table: str, column: str, value: object,
                      selectorColumn: str, selectorValue: object) -> int:
        """
        Sets `column` to `value` in every row where `selectorColumn` == `selectorValue`.

        **WARNING: `table`, `column` and `selectorColumn` are pasted into the SQL
        and must be internal. Only `value` and `selectorValue` may be user input.**

        Args:
            table (str): The table to update.
            column (str): The column to overwrite.
            value (object): The new value for `column`.
            selectorColumn (str): The column used to pick which rows change.
            selectorValue (object): Rows whose `selectorColumn` equals this are updated.

        Returns:
            int: The number of rows updated.
        """
        # Identifiers cannot be bound as ``?`` parameters; only values are.
        with self._connLock:
            cursor = self.execute(
                f'UPDATE {table} SET {column}=? WHERE {selectorColumn}=?',
                (value, selectorValue))
            self.commit()
            return cursor.rowcount

    def getRow(self, table: str, whereKey: str,
               whereVal: object) -> 'sqlite3.Row | None':
        """
        Gets a row from the database where `whereKey` == `whereVal`.

        **WARNING: `table` and `whereKey` are pasted into the SQL and must be
        internal. Only `whereVal` may be user input.**

        Returns:
            sqlite3.Row | None: The first matching row, or None if nothing matches.
        """
        return self._fetchOne(f'SELECT * FROM {table} WHERE {whereKey}=?',
                              (whereVal,))

    def close(self):
        """
        Safely closes the database connection.

        Takes the lock first so the connection is not closed underneath a
        statement another thread is still running.
        """
        with self._connLock:
            self._log.info('Shutting down database...')
            self._conn.close()
|
|
|
|
# Process-wide database instance; stays None until init() assigns it.
GLOBAL: 'DB | None' = None
|
|
|
|
def init(*args, **kwargs) -> DB:
    """
    Creates the global database and stores it in the module-level ``GLOBAL``.

    All positional and keyword arguments are forwarded unchanged to ``DB``.

    Returns:
        DB: The newly created instance, which is also the new value of ``GLOBAL``.
    """
    global GLOBAL
    instance = DB(*args, **kwargs)
    GLOBAL = instance
    return instance
|