Files
overub/core/database.py
2025-12-21 17:12:32 +01:00

309 lines
12 KiB
Python

from __future__ import annotations

import json
import re
import time
from pathlib import Path
from typing import Any, Dict, Iterable, Optional

from core.logger import get_logger
# Module-level logger, obtained from core.logger.get_logger under the name "core.database".
logger = get_logger("core.database")
class DatabaseManager:
    """Async facade over the SQLite, PostgreSQL, MongoDB and Redis backends.

    ``execute`` / ``fetchone`` / ``fetchall`` are only available on the SQL
    backends (``sqlite`` and ``postgres``).  The ``kv_*`` methods implement a
    per-plugin key/value store on all four backends; values are stored as
    JSON text.

    SQL passed to this class uses ``?`` placeholders; for PostgreSQL they are
    rewritten to ``$1, $2, ...`` by ``_convert_query``.
    """

    # Backends that store the key/value data in the ``plugin_kv`` SQL table.
    _SQL_BACKENDS = frozenset({"sqlite", "postgres"})

    # Column type used for ``plugin_kv.expires_at`` on each SQL backend.
    _KV_EXPIRES_COLUMN = {"sqlite": "INTEGER", "postgres": "BIGINT"}

    def __init__(self, db_type: str, options: Dict[str, Any], db_path: Path) -> None:
        """Store the configuration; no connection is opened until ``connect``.

        Args:
            db_type: One of ``"sqlite"``, ``"postgres"``, ``"mongodb"``,
                ``"redis"``.
            options: Backend options (``dsn`` for postgres, ``uri`` and
                ``database`` for mongodb, ``url`` for redis).
            db_path: Location of the SQLite database file.
        """
        self.db_type = db_type
        self.options = options
        self.db_path = db_path
        self._conn = None  # aiosqlite connection
        self._pg_pool = None  # asyncpg pool
        self._mongo = None  # motor database handle
        self._redis = None  # redis.asyncio client
        # True once CREATE TABLE IF NOT EXISTS has run on this connection, so
        # the DDL (and its commit) is not repeated on every kv_* call.
        self._kv_ready = False

    async def connect(self) -> None:
        """Open the connection for the configured backend.

        Raises:
            RuntimeError: If the driver for postgres/mongodb/redis is not
                installed, or if ``db_type`` is not supported.
        """
        if self.db_type == "sqlite":
            import aiosqlite

            self.db_path.parent.mkdir(parents=True, exist_ok=True)
            self._conn = await aiosqlite.connect(self.db_path)
            # Set once here so every fetch returns name-addressable rows.
            self._conn.row_factory = aiosqlite.Row
            await self._conn.execute("PRAGMA journal_mode=WAL")
            return
        if self.db_type == "postgres":
            try:
                import asyncpg
            except ImportError as exc:
                raise RuntimeError("asyncpg not installed") from exc
            dsn = self.options.get("dsn")
            self._pg_pool = await asyncpg.create_pool(dsn=dsn)
            return
        if self.db_type == "mongodb":
            try:
                import motor.motor_asyncio
            except ImportError as exc:
                raise RuntimeError("motor not installed") from exc
            uri = self.options.get("uri", "mongodb://localhost:27017")
            database = self.options.get("database", "overub")
            client = motor.motor_asyncio.AsyncIOMotorClient(uri)
            self._mongo = client[database]
            return
        if self.db_type == "redis":
            try:
                import redis.asyncio as aioredis
            except ImportError as exc:
                raise RuntimeError("redis not installed") from exc
            url = self.options.get("url", "redis://localhost:6379/0")
            self._redis = aioredis.from_url(url)
            return
        raise RuntimeError("Unsupported database type")

    async def close(self) -> None:
        """Close whichever backend handle is open and forget it.

        Handles are cleared before being closed so that later calls report
        "Database not connected" instead of operating on a closed handle.
        Calling ``close`` more than once is harmless.
        """
        if self._conn is not None:
            conn, self._conn = self._conn, None
            await conn.close()
        if self._pg_pool is not None:
            pool, self._pg_pool = self._pg_pool, None
            await pool.close()
        if self._mongo is not None:
            database, self._mongo = self._mongo, None
            # Only the database handle is kept; the client that owns the
            # sockets is reachable through its ``client`` attribute.
            database.client.close()
        if self._redis is not None:
            client, self._redis = self._redis, None
            # Newer redis-py versions expose ``aclose``; fall back to
            # ``close`` on versions that do not have it.
            closer = getattr(client, "aclose", None) or client.close
            await closer()
        self._kv_ready = False

    @staticmethod
    def _require(handle: Any) -> Any:
        """Return ``handle``, raising ``RuntimeError`` if it is ``None``."""
        if handle is None:
            raise RuntimeError("Database not connected")
        return handle

    async def execute(self, query: str, params: Iterable[Any] = ()) -> None:
        """Run a statement that returns no rows (SQL backends only).

        On SQLite the statement is committed immediately.

        Raises:
            RuntimeError: If not connected, or the backend is not SQL.
        """
        if self.db_type == "sqlite":
            conn = self._require(self._conn)
            cursor = await conn.execute(query, params)
            await cursor.close()
            await conn.commit()
            return
        if self.db_type == "postgres":
            pool = self._require(self._pg_pool)
            query = self._convert_query(query, params)
            async with pool.acquire() as conn:
                await conn.execute(query, *params)
            return
        raise RuntimeError("Execute not supported for this backend")

    async def fetchone(self, query: str, params: Iterable[Any] = ()) -> Optional[Dict[str, Any]]:
        """Return the first result row as a dict, or ``None`` if there is none.

        Raises:
            RuntimeError: If not connected, or the backend is not SQL.
        """
        if self.db_type == "sqlite":
            conn = self._require(self._conn)
            async with conn.execute(query, params) as cursor:
                row = await cursor.fetchone()
            return dict(row) if row is not None else None
        if self.db_type == "postgres":
            pool = self._require(self._pg_pool)
            query = self._convert_query(query, params)
            async with pool.acquire() as conn:
                row = await conn.fetchrow(query, *params)
            return dict(row) if row is not None else None
        raise RuntimeError("Fetch not supported for this backend")

    async def fetchall(self, query: str, params: Iterable[Any] = ()) -> list[Dict[str, Any]]:
        """Return every result row as a list of dicts.

        Raises:
            RuntimeError: If not connected, or the backend is not SQL.
        """
        if self.db_type == "sqlite":
            conn = self._require(self._conn)
            async with conn.execute(query, params) as cursor:
                rows = await cursor.fetchall()
            return [dict(row) for row in rows]
        if self.db_type == "postgres":
            pool = self._require(self._pg_pool)
            query = self._convert_query(query, params)
            async with pool.acquire() as conn:
                rows = await conn.fetch(query, *params)
            return [dict(row) for row in rows]
        raise RuntimeError("Fetch not supported for this backend")

    async def ensure_plugin_kv(self) -> None:
        """Create the ``plugin_kv`` table on SQL backends if it is missing.

        The DDL runs at most once per connection; on MongoDB and Redis this
        is a no-op.
        """
        if self._kv_ready:
            return
        column_type = self._KV_EXPIRES_COLUMN.get(self.db_type)
        if column_type is None:
            return
        await self.execute(
            "CREATE TABLE IF NOT EXISTS plugin_kv ("
            "plugin TEXT NOT NULL, "
            "key TEXT NOT NULL, "
            "value TEXT, "
            f"expires_at {column_type}, "
            "PRIMARY KEY (plugin, key))"
        )
        self._kv_ready = True

    def plugin_db(self, name: str) -> "PluginDatabase":
        """Return a key/value view bound to the plugin called ``name``."""
        return PluginDatabase(self, name)

    @staticmethod
    def _is_expired(expires_at: Any) -> bool:
        """Tell whether an ``expires_at`` epoch timestamp lies in the past.

        ``None`` and ``0`` both mean "never expires".
        """
        return bool(expires_at) and expires_at <= int(time.time())

    @staticmethod
    def _like_to_regex(pattern: str) -> "re.Pattern[str]":
        """Compile a SQL ``LIKE`` pattern into an equivalent regex.

        ``%`` matches any run of characters and ``_`` exactly one; every
        other character matches itself.  Matching is case-sensitive.
        """
        pieces = []
        for char in pattern:
            if char == "%":
                pieces.append(".*")
            elif char == "_":
                pieces.append(".")
            else:
                pieces.append(re.escape(char))
        return re.compile("".join(pieces), re.DOTALL)

    async def _decode_entry(self, plugin: str, key: str, entry: Optional[Dict[str, Any]]) -> Optional[Any]:
        """Decode a stored row/document, deleting it first if it has expired."""
        if not entry:
            return None
        if self._is_expired(entry.get("expires_at")):
            await self.kv_delete(plugin, key)
            return None
        return json.loads(entry.get("value") or "null")

    async def kv_get(self, plugin: str, key: str) -> Optional[Any]:
        """Return the decoded value stored under ``key``, or ``None``.

        Expired entries on the SQL and MongoDB backends are deleted lazily
        when they are read.
        """
        if self.db_type in self._SQL_BACKENDS:
            await self.ensure_plugin_kv()
            row = await self.fetchone(
                "SELECT value, expires_at FROM plugin_kv WHERE plugin=? AND key=?",
                (plugin, key),
            )
            return await self._decode_entry(plugin, key, row)
        if self.db_type == "mongodb":
            collection = self._require(self._mongo).plugin_kv
            doc = await collection.find_one({"plugin": plugin, "key": key})
            return await self._decode_entry(plugin, key, doc)
        if self.db_type == "redis":
            value = await self._require(self._redis).get(f"plugin:{plugin}:{key}")
            if value is None:
                return None
            return json.loads(value)
        return None

    async def kv_set(self, plugin: str, key: str, value: Any) -> None:
        """Store ``value`` (JSON-encoded) under ``key``, clearing any expiry."""
        payload = json.dumps(value)
        if self.db_type in self._SQL_BACKENDS:
            await self.ensure_plugin_kv()
            if self.db_type == "postgres":
                await self.execute(
                    "INSERT INTO plugin_kv (plugin, key, value, expires_at) VALUES (?, ?, ?, NULL) "
                    "ON CONFLICT (plugin, key) DO UPDATE SET value=EXCLUDED.value, expires_at=NULL",
                    (plugin, key, payload),
                )
            else:
                await self.execute(
                    "INSERT OR REPLACE INTO plugin_kv (plugin, key, value, expires_at) VALUES (?, ?, ?, NULL)",
                    (plugin, key, payload),
                )
            return
        if self.db_type == "mongodb":
            await self._require(self._mongo).plugin_kv.update_one(
                {"plugin": plugin, "key": key},
                {"$set": {"value": payload, "expires_at": None}},
                upsert=True,
            )
            return
        if self.db_type == "redis":
            await self._require(self._redis).set(f"plugin:{plugin}:{key}", payload)

    async def kv_delete(self, plugin: str, key: str) -> None:
        """Remove ``key`` for ``plugin``; a missing key is not an error."""
        if self.db_type in self._SQL_BACKENDS:
            await self.ensure_plugin_kv()
            await self.execute(
                "DELETE FROM plugin_kv WHERE plugin=? AND key=?",
                (plugin, key),
            )
            return
        if self.db_type == "mongodb":
            await self._require(self._mongo).plugin_kv.delete_one({"plugin": plugin, "key": key})
            return
        if self.db_type == "redis":
            await self._require(self._redis).delete(f"plugin:{plugin}:{key}")

    async def kv_list(self, plugin: str, pattern: str = "%") -> list[str]:
        """List the non-expired keys of ``plugin`` matching a LIKE ``pattern``.

        On the SQL backends the pattern is evaluated by the database's own
        ``LIKE`` operator.  On MongoDB and Redis all keys of the plugin are
        fetched and filtered client-side with the same ``%`` / ``_`` rules.
        """
        if self.db_type in self._SQL_BACKENDS:
            await self.ensure_plugin_kv()
            rows = await self.fetchall(
                "SELECT key, expires_at FROM plugin_kv WHERE plugin=? AND key LIKE ?",
                (plugin, pattern),
            )
            return [row["key"] for row in rows if not self._is_expired(row["expires_at"])]
        if self.db_type == "mongodb":
            matcher = self._like_to_regex(pattern)
            cursor = self._require(self._mongo).plugin_kv.find({"plugin": plugin})
            names = []
            async for doc in cursor:
                if self._is_expired(doc.get("expires_at")):
                    continue
                if matcher.fullmatch(doc["key"]):
                    names.append(doc["key"])
            return names
        if self.db_type == "redis":
            matcher = self._like_to_regex(pattern)
            prefix = f"plugin:{plugin}:"
            raw_keys = await self._require(self._redis).keys(f"{prefix}*")
            names = []
            for raw in raw_keys:
                # Keys are bytes unless the client was created with
                # decode_responses enabled, in which case they are str.
                text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
                # Strip the literal prefix rather than splitting on ":" so a
                # plugin name containing ":" still yields the right key.
                if not text.startswith(prefix):
                    continue
                name = text[len(prefix):]
                if matcher.fullmatch(name):
                    names.append(name)
            return names
        return []

    async def kv_exists(self, plugin: str, key: str) -> bool:
        """Tell whether ``key`` is present for ``plugin`` and has not expired."""
        if self.db_type in self._SQL_BACKENDS:
            await self.ensure_plugin_kv()
            row = await self.fetchone(
                "SELECT expires_at FROM plugin_kv WHERE plugin=? AND key=?",
                (plugin, key),
            )
            return row is not None and not self._is_expired(row.get("expires_at"))
        if self.db_type == "mongodb":
            collection = self._require(self._mongo).plugin_kv
            doc = await collection.find_one({"plugin": plugin, "key": key})
            return doc is not None and not self._is_expired(doc.get("expires_at"))
        if self.db_type == "redis":
            return bool(await self._require(self._redis).exists(f"plugin:{plugin}:{key}"))
        return False

    async def kv_expire(self, plugin: str, key: str, seconds: int) -> None:
        """Make ``key`` expire ``seconds`` from now.

        On SQL and MongoDB the absolute epoch time is stored in
        ``expires_at``; on Redis the server-side TTL is used.
        """
        if self.db_type in self._SQL_BACKENDS:
            await self.ensure_plugin_kv()
            expires_at = int(time.time()) + seconds
            await self.execute(
                "UPDATE plugin_kv SET expires_at=? WHERE plugin=? AND key=?",
                (expires_at, plugin, key),
            )
            return
        if self.db_type == "mongodb":
            expires_at = int(time.time()) + seconds
            await self._require(self._mongo).plugin_kv.update_one(
                {"plugin": plugin, "key": key},
                {"$set": {"expires_at": expires_at}},
            )
            return
        if self.db_type == "redis":
            await self._require(self._redis).expire(f"plugin:{plugin}:{key}", seconds)

    def _convert_query(self, query: str, params: Iterable[Any]) -> str:
        """Rewrite ``?`` placeholders as ``$1, $2, ...`` for PostgreSQL.

        Question marks inside single-quoted literals or double-quoted
        identifiers are left untouched.  SQL comments and dollar-quoted
        strings are not recognised.  ``params`` is accepted for interface
        compatibility and is not inspected.
        """
        if self.db_type != "postgres" or "?" not in query:
            return query
        converted = []
        index = 1
        open_quote = None  # the quote character we are currently inside, if any
        for char in query:
            if open_quote is not None:
                # A doubled quote ('' or "") closes and immediately reopens,
                # which keeps the state correct for escaped quotes.
                if char == open_quote:
                    open_quote = None
            elif char in ("'", '"'):
                open_quote = char
            elif char == "?":
                converted.append(f"${index}")
                index += 1
                continue
            converted.append(char)
        return "".join(converted)
class PluginDatabase:
    """Key/value and query access to a database manager for a single plugin.

    Every key/value method forwards to the matching ``kv_*`` method of the
    manager with this plugin's name as the first argument.

    Annotations that mention the builtin ``list`` are written as strings
    because this class defines a method named ``list``, which would
    otherwise shadow the builtin while the class body is evaluated.
    """

    def __init__(self, manager: "DatabaseManager", plugin: str) -> None:
        """Bind the view to ``manager`` and to the plugin named ``plugin``."""
        self.manager = manager
        self.plugin = plugin

    async def get(self, key: str) -> Optional[Any]:
        """Return the value stored under ``key`` via ``manager.kv_get``."""
        return await self.manager.kv_get(self.plugin, key)

    async def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` via ``manager.kv_set``."""
        await self.manager.kv_set(self.plugin, key, value)

    async def delete(self, key: str) -> None:
        """Remove ``key`` via ``manager.kv_delete``."""
        await self.manager.kv_delete(self.plugin, key)

    async def list(self, pattern: str = "%") -> "list[str]":
        """Return this plugin's keys matching ``pattern`` via ``manager.kv_list``."""
        return await self.manager.kv_list(self.plugin, pattern)

    async def exists(self, key: str) -> bool:
        """Tell whether ``key`` is present via ``manager.kv_exists``."""
        return await self.manager.kv_exists(self.plugin, key)

    async def expire(self, key: str, seconds: int) -> None:
        """Set an expiry of ``seconds`` on ``key`` via ``manager.kv_expire``."""
        await self.manager.kv_expire(self.plugin, key, seconds)

    async def query(self, sql: str, params: Iterable[Any] = ()) -> "list[Dict[str, Any]]":
        """Run ``sql`` through ``manager.fetchall`` and return its rows.

        Args:
            sql: Statement text using ``?`` placeholders.
            params: Values bound to the placeholders.  Pass values here
                rather than formatting them into ``sql``.
        """
        return await self.manager.fetchall(sql, params)