import os
import time
import logging
import psycopg2
from os import getenv
from glob import glob
from dotenv import load_dotenv
from typing import Optional, Any
from psycopg2 import OperationalError

load_dotenv()


class DatabaseConnector:
    """Thin psycopg2 wrapper holding one connection/cursor with reconnect-on-demand."""

    def __init__(
        self,
        dbname: str,
        user: str,
        password: str,
        host: str = "localhost",
        port: int = 5432,
        max_retries: int = 3,
        retry_delay: int = 5
    ):
        """
        Initialize database connector with connection parameters and retry settings.

        Args:
            dbname: Database name
            user: Database user
            password: Database password
            host: Database host
            port: Database port
            max_retries: Maximum number of reconnection attempts
            retry_delay: Delay between retry attempts in seconds
        """
        self.conn_params = {
            "dbname": dbname,
            "user": user,
            "password": password,
            "host": host,
            "port": port
        }
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.conn: Optional[psycopg2.extensions.connection] = None
        self.cur: Optional[psycopg2.extensions.cursor] = None

        # Set up logging
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def connect(self) -> bool:
        """
        Establish database connection.

        Returns:
            bool: True if connection successful, False otherwise
        """
        try:
            self.conn = psycopg2.connect(**self.conn_params)
            self.cur = self.conn.cursor()
            self.logger.info("Successfully connected to the database")
            return True
        except OperationalError as e:
            self.logger.error("Error connecting to the database: %s", e)
            return False

    def _discard_connection(self) -> None:
        """Best-effort close of the current cursor/connection so a reconnect doesn't leak them."""
        for handle in (self.cur, self.conn):
            if handle is not None:
                try:
                    handle.close()
                except psycopg2.Error:
                    # The handle is already broken; we only want to free resources.
                    pass
        self.cur = None
        self.conn = None

    def ensure_connection(self) -> bool:
        """
        Ensure database connection is active, attempt to reconnect if necessary.

        Returns:
            bool: True if connection is active or reconnection successful
        """
        if self.conn is not None and not self.conn.closed:
            try:
                # Test connection with simple query
                self.cur.execute("SELECT 1")
                return True
            except (psycopg2.Error, AttributeError):
                self.logger.warning("Database connection lost")

        # Release the stale handles before connect() overwrites them.
        self._discard_connection()

        for attempt in range(self.max_retries):
            self.logger.info(
                "Attempting to reconnect (attempt %d/%d)", attempt + 1, self.max_retries
            )
            if self.connect():
                return True
            # No point waiting after the final failed attempt.
            if attempt < self.max_retries - 1:
                time.sleep(self.retry_delay)

        self.logger.error("Failed to reconnect to database after multiple attempts")
        return False

    def execute_query(self, query: str, params: Optional[tuple] = None) -> Optional[Any]:
        """
        Execute a database query with automatic reconnection on failure.

        Args:
            query: SQL query string
            params: Query parameters (optional)

        Returns:
            A list of row tuples if the statement produced a result set (SELECT,
            WITH ... SELECT, INSERT/UPDATE/DELETE ... RETURNING), True for a
            successful statement without a result set, None on failure.
        """
        if not self.ensure_connection():
            return None

        try:
            self.cur.execute(query, params)
            # cursor.description is None exactly when the statement returned no
            # rows to fetch; this is more reliable than sniffing the SQL prefix.
            results = self.cur.fetchall() if self.cur.description is not None else True
            self.conn.commit()
            return results
        except psycopg2.Error as e:
            self.logger.error("Error executing query: %s", e)
            try:
                self.conn.rollback()
            except psycopg2.Error as rollback_error:
                # Connection is probably gone; ensure_connection() will
                # reconnect on the next call.
                self.logger.warning("Rollback failed: %s", rollback_error)
            return None

    def close(self):
        """Close database connection and cursor."""
        if self.cur:
            self.cur.close()
        if self.conn:
            self.conn.close()
            self.logger.info("Database connection closed")


db = DatabaseConnector(
    dbname=getenv("db_name"),
    user=getenv("db_user"),
    password=getenv("db_password"),
    host=getenv("db_host"),
    port=getenv("db_port")
)


def register(email: str, password: str, name: str = "unnamed"):
    """Insert a new user row.

    NOTE(review): the password is stored exactly as given (plain text) and no
    token is set here, although login()/get_userinfo() rely on the token
    column — presumably it is populated elsewhere; confirm. Hashing passwords
    (e.g. hashlib.scrypt) would need a data migration.
    """
    db.execute_query(
        "INSERT INTO users (email,password,name) VALUES (%s,%s,%s)",
        (email, password, name),
    )


def get_userid_by_token(token: str):
    """Return the ``(uid,)`` row of the user owning *token*.

    The whole row tuple is returned (not the bare uid) to keep the existing
    return shape; ``_uid_for_token`` gives the scalar uid.

    Raises:
        IndexError: if no user has this token or the lookup query failed.
    """
    result = db.execute_query("SELECT uid FROM users WHERE token=%s LIMIT 1", (token,))
    if not result:
        raise IndexError("No user found for the given token")
    return result[0]


def _uid_for_token(token: str):
    """Return the scalar uid for *token* (raises IndexError if unknown)."""
    return get_userid_by_token(token)[0]


def login(email, password):
    """Return the user's name and token for matching credentials, or an error dict."""
    result = db.execute_query(
        "SELECT name,token FROM users WHERE email=%s AND password=%s LIMIT 1",
        (email, password),
    )
    # `not result` also covers None (query failed) instead of crashing on len(None).
    if not result:
        return {"error": "Invalid email or password"}
    name, token = result[0]
    return {
        "email": email,
        "license": "",
        "name": name,
        "token": token
    }


def get_userinfo(token: str):
    """Return the account-info payload for *token*, or an error dict if unknown."""
    result = db.execute_query("SELECT uid,name,email FROM users WHERE token=%s LIMIT 1", (token,))
    if not result:
        return {"error": "Not logged in"}
    uid, name, email = result[0]
    return {
        "credit": 0,
        "credit_received": 0,
        "discount": None,
        "email": email,
        "license": "",
        "mfa": False,
        "name": name,
        "payment": "",
        "uid": uid
    }


def list_vaults(token: str):
    """List the vaults owned by the user holding *token*.

    Raises:
        IndexError: if the token does not belong to any user.
    """
    uid = _uid_for_token(token)
    raw_data = db.execute_query("SELECT v.* FROM vaults v WHERE v.owner = %s", (uid,))
    if raw_data is None:
        return {"error": "Failed to list vaults"}

    # Computed once: the value does not depend on the vault.
    # NOTE(review): this sums every entry under data/, so every vault reports
    # the same size — per-vault sizing probably needs a vault-specific path or
    # a SUM over vaults_files; confirm the intended layout.
    size = sum(os.path.getsize(path) for path in glob("data/*"))
    host = getenv("SYNC_SERVER_URL")

    # Row layout follows the vaults table: (id, name, created_at, owner, shared).
    vaults = [
        {
            "created": int(vault[2].timestamp() * 1000),
            "encryption_version": 0,
            "host": host,
            "id": vault[0],
            "name": vault[1],
            "password": "",
            "region": "Home",
            "salt": "sugar",
            "size": size
        }
        for vault in raw_data
    ]
    return {"limit": len(raw_data) + 1, "shared": [], "vaults": vaults}


def create_vault(name: str, token: str):
    """Create a vault named *name* for the token's user and return its description.

    Raises:
        IndexError: if the token does not belong to any user.
    """
    uid = _uid_for_token(token)
    # RETURNING yields exactly the inserted row. The previous insert-then-select
    # "latest with this name" could pick a concurrent or older vault, including
    # when the INSERT itself had failed.
    rows = db.execute_query(
        "INSERT INTO vaults (name,owner) VALUES (%s, %s) RETURNING id, name, created_at",
        (name, uid),
    )
    if not rows:
        return {"error": "Failed to create vault"}
    vault_id, vault_name, created_at = rows[0]
    return {
        "created": int(created_at.timestamp() * 1000),
        # Same key as list_vaults(); the misspelled variant is kept for any
        # client that already reads it.
        "encryption_version": 0,
        "encrypted_version": 0,
        "host": getenv("SYNC_SERVER_URL"),
        "id": vault_id,
        "name": vault_name,
        "password": "",
        "region": "Home",
        "salt": "sugar",
        "size": 0
    }


def rename_vault(name: str, id: str, token: str):
    """Rename vault *id* if it is owned by the token's user (no-op otherwise)."""
    uid = _uid_for_token(token)
    db.execute_query("UPDATE vaults SET name=%s WHERE id=%s AND owner=%s", (name, id, uid))
    return {}


def delete_database(vault_id: str, token: str):
    """Delete vault *vault_id* if it is owned by the token's user (no-op otherwise)."""
    uid = _uid_for_token(token)
    db.execute_query("DELETE FROM vaults WHERE id=%s AND owner=%s", (vault_id, uid))
    return {}


def get_file(vault_id, path):
    """Return a list with at most one vaults_files row for (*vault_id*, *path*), or None on error."""
    # Table is vaults_files (as created below and written by add_file).
    return db.execute_query(
        "SELECT * FROM vaults_files WHERE vault_id=%s AND path=%s LIMIT 1",
        (vault_id, path),
    )


def get_files(vault_id):
    """Return all vaults_files rows for *vault_id*, or None on error."""
    return db.execute_query("SELECT * FROM vaults_files WHERE vault_id=%s", (vault_id,))


def add_file(vault_id: str, path: str, hash: str, user_id: int, file_content: bytes):
    """Insert one file record (including its raw content) into vaults_files."""
    db.execute_query(
        "INSERT INTO vaults_files (vault_id, path, hash, user_id, file_content) VALUES (%s, %s, %s, %s, %s)",
        (vault_id, path, hash, user_id, file_content),
    )


if __name__ == "__main__":
    # Create tables in dependency order: vaults references users, and
    # vaults_files references vaults. IF NOT EXISTS makes re-runs idempotent.
    # NOTE(review): gen_random_uuid() is built in from PostgreSQL 13; older
    # servers need the pgcrypto extension — confirm target version.
    db.execute_query("""CREATE TABLE IF NOT EXISTS users (
        uid VARCHAR(255) PRIMARY KEY DEFAULT gen_random_uuid()::text,
        name VARCHAR(255) NOT NULL,
        email VARCHAR(255) NOT NULL,
        password VARCHAR(255) NOT NULL,
        token VARCHAR(255)
    );""")
    db.execute_query("""CREATE TABLE IF NOT EXISTS vaults (
        id VARCHAR(255) PRIMARY KEY DEFAULT gen_random_uuid()::text,
        name VARCHAR(255) NOT NULL DEFAULT 'Unnamed',
        created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
        owner VARCHAR(255) NOT NULL REFERENCES users(uid),
        shared VARCHAR(255)[] DEFAULT '{}'
    );""")
    # NOTE(review): user_id is INTEGER while users.uid is VARCHAR — looks
    # inconsistent; left unchanged to avoid breaking existing data.
    db.execute_query("""CREATE TABLE IF NOT EXISTS vaults_files (
        vault_id VARCHAR(255) NOT NULL REFERENCES vaults(id),
        uid SERIAL PRIMARY KEY,
        path TEXT,
        hash TEXT,
        user_id INTEGER,
        file_content BYTEA,
        created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
        modified_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
    );""")