""" Used environment variables: DB_PATH: The location of the SQLite file. `website.db` is always appended to the end. """ import threading import sqlite3 import logging import atexit import os # pylint: disable=locally-disabled, global-statement class DB: """ Basically just a thread-safe version of sqlite3. """ def __init__(self): self._log = logging.getLogger('DB') self._dbFile = os.path.join(os.getenv('DB_PATH', './'), 'website.db') self._conn = sqlite3.connect(self._dbFile, check_same_thread=False) self._conn.row_factory = sqlite3.Row self._connLock = threading.Lock() atexit.register(self.close) self._log.info('Database ready.') def execute(self, command: str, args: list[any]=None) -> sqlite3.Cursor: """ Executes a SQL command directly. """ args = [] if args is None else args with self._connLock: return self._conn.execute(command, args) def commit(self): """ Commits any uncommited changes. """ with self._connLock: self._conn.commit() def createTable(self, tableName: str): """ Creates a table in the database with a default ID column. **WARNING: `tableName` is assumed to be safe. DO NOT PASS USER INPUT TO THIS VARIABLE.** Args: tableName (str): The name of the table to create. """ self.execute(f'CREATE TABLE IF NOT EXISTS {tableName}'+ '(id INTEGER PRIMARY KEY AUTOINCREMENT)') def createColumn(self, tableName: str, columnDef: str): """ Creates a column in the specified table if it doesn't already exist. **WARNING: This function call is intended to be INTERNAL ONLY.** **NO USER INPUT SHOULD EVER BE PASSED TO THIS FUNCTION.** Args: tableName (str): The name of the table to alter. columnDef (str): The definition of the column. For example 'name STRING DEFAULT NULL'. """ columnName = columnDef.split(' ')[0] res = self.execute(f'SELECT 1 FROM pragma_table_info(\'{tableName}\')'+ f'WHERE name=\'{columnName}\'') columnExempt = res.fetchone() is None if columnExempt: self.execute(f'ALTER TABLE {tableName} ADD COLUMN {columnDef}') self.commit() def addRow(self, table: str, data: dict): """ Adds a row into the database. **WARNING: The arguments `table` and all keys of `data` are assumed to be internal.** **ONLY THE VALUES OF `data` CAN BE USER INPUT. ANYTHING ELSE IS UNSAFE.** Args: table (str): The table to insert into. data (dict): The keys and values to insert. **The keys are internal only and should NEVER contain user input.** """ self.execute(f'INSERT INTO {table} ({",".join(data.keys())}) '+ f'VALUES ({",".join("?"*len(data))})', list(data.values())) self.commit() def updateColumns(self, table: str, column: str, value: any, selectorColumn: str, selectorValue: any): """ Allows for altering cells based on the slelector. """ self.execute(f'ALTER TABLE {table} SET ?=? WHERE ?=?', (column, value, selectorColumn, selectorValue)) self.commit() def getRow(self, table: str, whereKey: str, whereVal: any) -> dict: """ Gets a row from the database where `whereKey` == `whereVal`. """ res = self.execute(f'SELECT * FROM {table} WHERE {whereKey}=?', (whereVal,)) return res.fetchone() def close(self): """ Safely closes the database connection. """ self._log.info('Shutting down database...') self._conn.close() GLOBAL: DB = None def init(*args, **kwargs) -> DB: """ Initialises the global database. """ global GLOBAL GLOBAL = DB(*args, **kwargs) return GLOBAL