from __future__ import annotations import json import time from pathlib import Path from typing import Any, Dict, Iterable, Optional from core.logger import get_logger logger = get_logger("core.database") class DatabaseManager: def __init__(self, db_type: str, options: Dict[str, Any], db_path: Path) -> None: self.db_type = db_type self.options = options self.db_path = db_path self._conn = None self._pg_pool = None self._mongo = None self._redis = None async def connect(self) -> None: if self.db_type == "sqlite": import aiosqlite self.db_path.parent.mkdir(parents=True, exist_ok=True) self._conn = await aiosqlite.connect(self.db_path) await self._conn.execute("PRAGMA journal_mode=WAL") return if self.db_type == "postgres": try: import asyncpg except ImportError as exc: raise RuntimeError("asyncpg not installed") from exc dsn = self.options.get("dsn") self._pg_pool = await asyncpg.create_pool(dsn=dsn) return if self.db_type == "mongodb": try: import motor.motor_asyncio except ImportError as exc: raise RuntimeError("motor not installed") from exc uri = self.options.get("uri", "mongodb://localhost:27017") database = self.options.get("database", "overub") client = motor.motor_asyncio.AsyncIOMotorClient(uri) self._mongo = client[database] return if self.db_type == "redis": try: import redis.asyncio as aioredis except ImportError as exc: raise RuntimeError("redis not installed") from exc url = self.options.get("url", "redis://localhost:6379/0") self._redis = aioredis.from_url(url) return raise RuntimeError("Unsupported database type") async def close(self) -> None: if self.db_type == "sqlite" and self._conn is not None: await self._conn.close() if self.db_type == "postgres" and self._pg_pool is not None: await self._pg_pool.close() if self.db_type == "redis" and self._redis is not None: await self._redis.close() async def execute(self, query: str, params: Iterable[Any] = ()) -> None: if self.db_type == "sqlite": if self._conn is None: raise RuntimeError("Database not connected") await self._conn.execute(query, params) await self._conn.commit() return if self.db_type == "postgres": if self._pg_pool is None: raise RuntimeError("Database not connected") query = self._convert_query(query, params) async with self._pg_pool.acquire() as conn: await conn.execute(query, *params) return raise RuntimeError("Execute not supported for this backend") async def fetchone(self, query: str, params: Iterable[Any] = ()) -> Optional[Dict[str, Any]]: if self.db_type == "sqlite": if self._conn is None: raise RuntimeError("Database not connected") self._conn.row_factory = __import__("aiosqlite").Row cursor = await self._conn.execute(query, params) row = await cursor.fetchone() return dict(row) if row else None if self.db_type == "postgres": if self._pg_pool is None: raise RuntimeError("Database not connected") query = self._convert_query(query, params) async with self._pg_pool.acquire() as conn: row = await conn.fetchrow(query, *params) return dict(row) if row else None raise RuntimeError("Fetch not supported for this backend") async def fetchall(self, query: str, params: Iterable[Any] = ()) -> list[Dict[str, Any]]: if self.db_type == "sqlite": if self._conn is None: raise RuntimeError("Database not connected") self._conn.row_factory = __import__("aiosqlite").Row cursor = await self._conn.execute(query, params) rows = await cursor.fetchall() return [dict(row) for row in rows] if self.db_type == "postgres": if self._pg_pool is None: raise RuntimeError("Database not connected") query = self._convert_query(query, params) async with self._pg_pool.acquire() as conn: rows = await conn.fetch(query, *params) return [dict(row) for row in rows] raise RuntimeError("Fetch not supported for this backend") async def ensure_plugin_kv(self) -> None: if self.db_type == "sqlite": await self.execute( """ CREATE TABLE IF NOT EXISTS plugin_kv ( plugin TEXT NOT NULL, key TEXT NOT NULL, value TEXT, expires_at INTEGER, PRIMARY KEY (plugin, key) ) """ ) return if self.db_type == "postgres": await self.execute( """ CREATE TABLE IF NOT EXISTS plugin_kv ( plugin TEXT NOT NULL, key TEXT NOT NULL, value TEXT, expires_at BIGINT, PRIMARY KEY (plugin, key) ) """ ) def plugin_db(self, name: str) -> "PluginDatabase": return PluginDatabase(self, name) async def kv_get(self, plugin: str, key: str) -> Optional[Any]: if self.db_type in {"sqlite", "postgres"}: await self.ensure_plugin_kv() row = await self.fetchone( "SELECT value, expires_at FROM plugin_kv WHERE plugin=? AND key=?", (plugin, key), ) if not row: return None expires_at = row.get("expires_at") if expires_at and expires_at <= int(time.time()): await self.kv_delete(plugin, key) return None return json.loads(row.get("value") or "null") if self.db_type == "mongodb": doc = await self._mongo.plugin_kv.find_one({"plugin": plugin, "key": key}) if not doc: return None expires_at = doc.get("expires_at") if expires_at and expires_at <= int(time.time()): await self.kv_delete(plugin, key) return None return json.loads(doc.get("value") or "null") if self.db_type == "redis": value = await self._redis.get(f"plugin:{plugin}:{key}") if value is None: return None return json.loads(value) return None async def kv_set(self, plugin: str, key: str, value: Any) -> None: payload = json.dumps(value) if self.db_type in {"sqlite", "postgres"}: await self.ensure_plugin_kv() if self.db_type == "postgres": await self.execute( "INSERT INTO plugin_kv (plugin, key, value, expires_at) VALUES (?, ?, ?, NULL) " "ON CONFLICT (plugin, key) DO UPDATE SET value=EXCLUDED.value, expires_at=NULL", (plugin, key, payload), ) else: await self.execute( "INSERT OR REPLACE INTO plugin_kv (plugin, key, value, expires_at) VALUES (?, ?, ?, NULL)", (plugin, key, payload), ) return if self.db_type == "mongodb": await self._mongo.plugin_kv.update_one( {"plugin": plugin, "key": key}, {"$set": {"value": payload, "expires_at": None}}, upsert=True, ) return if self.db_type == "redis": await self._redis.set(f"plugin:{plugin}:{key}", payload) async def kv_delete(self, plugin: str, key: str) -> None: if self.db_type in {"sqlite", "postgres"}: await self.ensure_plugin_kv() await self.execute( "DELETE FROM plugin_kv WHERE plugin=? AND key=?", (plugin, key), ) return if self.db_type == "mongodb": await self._mongo.plugin_kv.delete_one({"plugin": plugin, "key": key}) return if self.db_type == "redis": await self._redis.delete(f"plugin:{plugin}:{key}") async def kv_list(self, plugin: str, pattern: str = "%") -> list[str]: if self.db_type in {"sqlite", "postgres"}: await self.ensure_plugin_kv() rows = await self.fetchall( "SELECT key FROM plugin_kv WHERE plugin=? AND key LIKE ?", (plugin, pattern), ) return [row["key"] for row in rows] if self.db_type == "mongodb": cursor = self._mongo.plugin_kv.find({"plugin": plugin}) return [doc["key"] async for doc in cursor] if self.db_type == "redis": keys = await self._redis.keys(f"plugin:{plugin}:*") return [key.decode("utf-8").split(":", 2)[2] for key in keys] return [] async def kv_exists(self, plugin: str, key: str) -> bool: if self.db_type in {"sqlite", "postgres"}: await self.ensure_plugin_kv() row = await self.fetchone( "SELECT 1 FROM plugin_kv WHERE plugin=? AND key=?", (plugin, key), ) return row is not None if self.db_type == "mongodb": doc = await self._mongo.plugin_kv.find_one({"plugin": plugin, "key": key}) return doc is not None if self.db_type == "redis": return bool(await self._redis.exists(f"plugin:{plugin}:{key}")) return False async def kv_expire(self, plugin: str, key: str, seconds: int) -> None: if self.db_type in {"sqlite", "postgres"}: await self.ensure_plugin_kv() expires_at = int(time.time()) + seconds await self.execute( "UPDATE plugin_kv SET expires_at=? WHERE plugin=? AND key=?", (expires_at, plugin, key), ) return if self.db_type == "mongodb": expires_at = int(time.time()) + seconds await self._mongo.plugin_kv.update_one( {"plugin": plugin, "key": key}, {"$set": {"expires_at": expires_at}}, ) return if self.db_type == "redis": await self._redis.expire(f"plugin:{plugin}:{key}", seconds) def _convert_query(self, query: str, params: Iterable[Any]) -> str: if self.db_type != "postgres": return query if "?" not in query: return query converted = [] index = 1 for char in query: if char == "?": converted.append(f"${index}") index += 1 else: converted.append(char) return "".join(converted) class PluginDatabase: def __init__(self, manager: DatabaseManager, plugin: str) -> None: self.manager = manager self.plugin = plugin async def get(self, key: str) -> Optional[Any]: return await self.manager.kv_get(self.plugin, key) async def set(self, key: str, value: Any) -> None: await self.manager.kv_set(self.plugin, key, value) async def delete(self, key: str) -> None: await self.manager.kv_delete(self.plugin, key) async def list(self, pattern: str = "%") -> list[str]: return await self.manager.kv_list(self.plugin, pattern) async def exists(self, key: str) -> bool: return await self.manager.kv_exists(self.plugin, key) async def expire(self, key: str, seconds: int) -> None: await self.manager.kv_expire(self.plugin, key, seconds) async def query(self, sql: str) -> list[Dict[str, Any]]: return await self.manager.fetchall(sql)