309 lines
12 KiB
Python
309 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Iterable, Optional
|
|
|
|
from core.logger import get_logger
|
|
|
|
|
|
# Module-level logger for this module, registered under the name "core.database".
logger = get_logger("core.database")
|
|
|
|
|
|
class DatabaseManager:
    """Async facade over the SQLite, PostgreSQL, MongoDB and Redis backends.

    The backend is selected by ``db_type`` (``"sqlite"``, ``"postgres"``,
    ``"mongodb"`` or ``"redis"``).  Raw SQL helpers (``execute``,
    ``fetchone``, ``fetchall``) are only available on the two SQL backends;
    the ``kv_*`` helpers implement a per-plugin key-value store on all four.

    SQL passed to the helpers uses ``?`` placeholders; they are rewritten to
    ``$1, $2, ...`` for PostgreSQL.
    """

    def __init__(self, db_type: str, options: Dict[str, Any], db_path: Path) -> None:
        """Store the configuration; no connection is opened until ``connect``.

        Args:
            db_type: Backend selector.
            options: Backend options (``dsn`` for postgres, ``uri`` and
                ``database`` for mongodb, ``url`` for redis).
            db_path: Location of the SQLite database file.
        """
        self.db_type = db_type
        self.options = options
        self.db_path = db_path
        self._conn = None
        self._pg_pool = None
        self._mongo = None
        self._redis = None
        # Set once the plugin_kv table is known to exist, so the DDL is not
        # re-issued (and re-committed) on every key-value call.
        self._kv_ready = False

    async def connect(self) -> None:
        """Open the connection, pool or client for the configured backend.

        Raises:
            RuntimeError: If the backend driver is not installed or
                ``db_type`` is not one of the supported values.
        """
        if self.db_type == "sqlite":
            import aiosqlite

            self.db_path.parent.mkdir(parents=True, exist_ok=True)
            self._conn = await aiosqlite.connect(self.db_path)
            # Configure the row type once here instead of on every fetch.
            self._conn.row_factory = aiosqlite.Row
            await self._conn.execute("PRAGMA journal_mode=WAL")
            return
        if self.db_type == "postgres":
            try:
                import asyncpg
            except ImportError as exc:
                raise RuntimeError("asyncpg not installed") from exc
            dsn = self.options.get("dsn")
            self._pg_pool = await asyncpg.create_pool(dsn=dsn)
            return
        if self.db_type == "mongodb":
            try:
                import motor.motor_asyncio
            except ImportError as exc:
                raise RuntimeError("motor not installed") from exc
            uri = self.options.get("uri", "mongodb://localhost:27017")
            database = self.options.get("database", "overub")
            client = motor.motor_asyncio.AsyncIOMotorClient(uri)
            self._mongo = client[database]
            return
        if self.db_type == "redis":
            try:
                import redis.asyncio as aioredis
            except ImportError as exc:
                raise RuntimeError("redis not installed") from exc
            url = self.options.get("url", "redis://localhost:6379/0")
            self._redis = aioredis.from_url(url)
            return
        raise RuntimeError("Unsupported database type")

    async def close(self) -> None:
        """Release every open handle and forget it.

        Safe to call more than once: handles are reset to ``None`` so a
        second call is a no-op and later use raises "Database not connected".
        """
        if self._conn is not None:
            await self._conn.close()
            self._conn = None
        if self._pg_pool is not None:
            await self._pg_pool.close()
            self._pg_pool = None
        if self._mongo is not None:
            # Only the database object is stored; its client owns the sockets.
            self._mongo.client.close()
            self._mongo = None
        if self._redis is not None:
            # redis-py 5 renamed the coroutine to aclose(); older releases
            # only provide close().
            closer = getattr(self._redis, "aclose", None) or self._redis.close
            await closer()
            self._redis = None
        self._kv_ready = False

    @staticmethod
    def _require(handle: Any) -> Any:
        """Return *handle*, raising ``RuntimeError`` if it is ``None``."""
        if handle is None:
            raise RuntimeError("Database not connected")
        return handle

    @staticmethod
    def _is_expired(expires_at: Any) -> bool:
        """Tell whether a stored expiry timestamp (epoch seconds) has passed.

        ``None`` and ``0`` both mean "never expires".
        """
        return bool(expires_at) and expires_at <= int(time.time())

    @staticmethod
    def _like_to_regex(pattern: str) -> "re.Pattern[str]":
        """Compile a SQL ``LIKE`` pattern into an equivalent regex.

        ``%`` matches any run of characters and ``_`` exactly one; everything
        else is literal.  Matching is case-sensitive.  Use ``fullmatch`` on
        the result.
        """
        import re

        pieces = []
        for char in pattern:
            if char == "%":
                pieces.append(".*")
            elif char == "_":
                pieces.append(".")
            else:
                pieces.append(re.escape(char))
        return re.compile("".join(pieces), re.DOTALL)

    async def execute(self, query: str, params: Iterable[Any] = ()) -> None:
        """Run a statement that returns no rows (SQLite commits afterwards).

        Raises:
            RuntimeError: If not connected, or the backend is not SQL.
        """
        if self.db_type == "sqlite":
            conn = self._require(self._conn)
            # The context manager closes the cursor the statement produces.
            async with conn.execute(query, params):
                pass
            await conn.commit()
            return
        if self.db_type == "postgres":
            pool = self._require(self._pg_pool)
            args = tuple(params)
            sql = self._convert_query(query, args)
            async with pool.acquire() as conn:
                await conn.execute(sql, *args)
            return
        raise RuntimeError("Execute not supported for this backend")

    async def fetchone(self, query: str, params: Iterable[Any] = ()) -> Optional[Dict[str, Any]]:
        """Return the first result row as a dict, or ``None`` if there is none.

        Raises:
            RuntimeError: If not connected, or the backend is not SQL.
        """
        if self.db_type == "sqlite":
            conn = self._require(self._conn)
            async with conn.execute(query, params) as cursor:
                row = await cursor.fetchone()
            return dict(row) if row is not None else None
        if self.db_type == "postgres":
            pool = self._require(self._pg_pool)
            args = tuple(params)
            sql = self._convert_query(query, args)
            async with pool.acquire() as conn:
                row = await conn.fetchrow(sql, *args)
            return dict(row) if row is not None else None
        raise RuntimeError("Fetch not supported for this backend")

    async def fetchall(self, query: str, params: Iterable[Any] = ()) -> list[Dict[str, Any]]:
        """Return every result row as a list of dicts.

        Raises:
            RuntimeError: If not connected, or the backend is not SQL.
        """
        if self.db_type == "sqlite":
            conn = self._require(self._conn)
            async with conn.execute(query, params) as cursor:
                rows = await cursor.fetchall()
            return [dict(row) for row in rows]
        if self.db_type == "postgres":
            pool = self._require(self._pg_pool)
            args = tuple(params)
            sql = self._convert_query(query, args)
            async with pool.acquire() as conn:
                rows = await conn.fetch(sql, *args)
            return [dict(row) for row in rows]
        raise RuntimeError("Fetch not supported for this backend")

    async def ensure_plugin_kv(self) -> None:
        """Create the ``plugin_kv`` table on SQL backends if it is missing.

        Does nothing on other backends, and nothing after the first
        successful call on a given connection.
        """
        if self.db_type not in {"sqlite", "postgres"} or self._kv_ready:
            return
        # Only the integer type of the expiry column differs between the two.
        expires_type = "BIGINT" if self.db_type == "postgres" else "INTEGER"
        await self.execute(
            f"""
            CREATE TABLE IF NOT EXISTS plugin_kv (
                plugin TEXT NOT NULL,
                key TEXT NOT NULL,
                value TEXT,
                expires_at {expires_type},
                PRIMARY KEY (plugin, key)
            )
            """
        )
        self._kv_ready = True

    def plugin_db(self, name: str) -> "PluginDatabase":
        """Return a key-value view scoped to the plugin called *name*."""
        return PluginDatabase(self, name)

    async def kv_get(self, plugin: str, key: str) -> Optional[Any]:
        """Return the JSON-decoded value for *key*, or ``None`` if absent.

        On the SQL and MongoDB backends an expired entry is deleted and
        reported as absent.  Unknown backends return ``None``.
        """
        if self.db_type in {"sqlite", "postgres"}:
            await self.ensure_plugin_kv()
            record = await self.fetchone(
                "SELECT value, expires_at FROM plugin_kv WHERE plugin=? AND key=?",
                (plugin, key),
            )
        elif self.db_type == "mongodb":
            store = self._require(self._mongo)
            record = await store.plugin_kv.find_one({"plugin": plugin, "key": key})
        elif self.db_type == "redis":
            raw = await self._require(self._redis).get(f"plugin:{plugin}:{key}")
            return None if raw is None else json.loads(raw)
        else:
            return None

        if not record:
            return None
        if self._is_expired(record.get("expires_at")):
            await self.kv_delete(plugin, key)
            return None
        return json.loads(record.get("value") or "null")

    async def kv_set(self, plugin: str, key: str, value: Any) -> None:
        """Store *value* (JSON-encoded) under *key*, clearing any expiry."""
        payload = json.dumps(value)
        if self.db_type in {"sqlite", "postgres"}:
            await self.ensure_plugin_kv()
            if self.db_type == "postgres":
                await self.execute(
                    "INSERT INTO plugin_kv (plugin, key, value, expires_at) VALUES (?, ?, ?, NULL) "
                    "ON CONFLICT (plugin, key) DO UPDATE SET value=EXCLUDED.value, expires_at=NULL",
                    (plugin, key, payload),
                )
            else:
                await self.execute(
                    "INSERT OR REPLACE INTO plugin_kv (plugin, key, value, expires_at) VALUES (?, ?, ?, NULL)",
                    (plugin, key, payload),
                )
            return
        if self.db_type == "mongodb":
            await self._require(self._mongo).plugin_kv.update_one(
                {"plugin": plugin, "key": key},
                {"$set": {"value": payload, "expires_at": None}},
                upsert=True,
            )
            return
        if self.db_type == "redis":
            await self._require(self._redis).set(f"plugin:{plugin}:{key}", payload)

    async def kv_delete(self, plugin: str, key: str) -> None:
        """Remove *key* for *plugin*; a missing key is not an error."""
        if self.db_type in {"sqlite", "postgres"}:
            await self.ensure_plugin_kv()
            await self.execute(
                "DELETE FROM plugin_kv WHERE plugin=? AND key=?",
                (plugin, key),
            )
            return
        if self.db_type == "mongodb":
            await self._require(self._mongo).plugin_kv.delete_one({"plugin": plugin, "key": key})
            return
        if self.db_type == "redis":
            await self._require(self._redis).delete(f"plugin:{plugin}:{key}")

    async def kv_list(self, plugin: str, pattern: str = "%") -> list[str]:
        """List the plugin's live keys matching the SQL ``LIKE`` *pattern*.

        The pattern is honoured on every backend.  On MongoDB and Redis it is
        applied client-side and is case-sensitive; on the SQL backends the
        database's own ``LIKE`` rules apply.  Expired entries are skipped.
        """
        if self.db_type in {"sqlite", "postgres"}:
            await self.ensure_plugin_kv()
            rows = await self.fetchall(
                "SELECT key, expires_at FROM plugin_kv WHERE plugin=? AND key LIKE ?",
                (plugin, pattern),
            )
            return [row["key"] for row in rows if not self._is_expired(row.get("expires_at"))]
        if self.db_type == "mongodb":
            matcher = self._like_to_regex(pattern)
            cursor = self._require(self._mongo).plugin_kv.find({"plugin": plugin})
            return [
                doc["key"]
                async for doc in cursor
                if matcher.fullmatch(doc["key"]) and not self._is_expired(doc.get("expires_at"))
            ]
        if self.db_type == "redis":
            matcher = self._like_to_regex(pattern)
            prefix = f"plugin:{plugin}:"
            names = []
            for raw in await self._require(self._redis).keys(f"{prefix}*"):
                # Keys are bytes unless the client uses decode_responses=True.
                text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
                # Glob characters in the plugin name could match other
                # prefixes, so confirm the prefix before slicing it off.
                if not text.startswith(prefix):
                    continue
                name = text[len(prefix):]
                if matcher.fullmatch(name):
                    names.append(name)
            return names
        return []

    async def kv_exists(self, plugin: str, key: str) -> bool:
        """Tell whether a live (non-expired) entry exists for *key*."""
        if self.db_type in {"sqlite", "postgres"}:
            await self.ensure_plugin_kv()
            row = await self.fetchone(
                "SELECT expires_at FROM plugin_kv WHERE plugin=? AND key=?",
                (plugin, key),
            )
            return row is not None and not self._is_expired(row.get("expires_at"))
        if self.db_type == "mongodb":
            doc = await self._require(self._mongo).plugin_kv.find_one({"plugin": plugin, "key": key})
            return doc is not None and not self._is_expired(doc.get("expires_at"))
        if self.db_type == "redis":
            return bool(await self._require(self._redis).exists(f"plugin:{plugin}:{key}"))
        return False

    async def kv_expire(self, plugin: str, key: str, seconds: int) -> None:
        """Make *key* expire *seconds* from now; a missing key is ignored."""
        if self.db_type in {"sqlite", "postgres"}:
            await self.ensure_plugin_kv()
            expires_at = int(time.time()) + seconds
            await self.execute(
                "UPDATE plugin_kv SET expires_at=? WHERE plugin=? AND key=?",
                (expires_at, plugin, key),
            )
            return
        if self.db_type == "mongodb":
            expires_at = int(time.time()) + seconds
            await self._require(self._mongo).plugin_kv.update_one(
                {"plugin": plugin, "key": key},
                {"$set": {"expires_at": expires_at}},
            )
            return
        if self.db_type == "redis":
            await self._require(self._redis).expire(f"plugin:{plugin}:{key}", seconds)

    def _convert_query(self, query: str, params: Iterable[Any]) -> str:
        """Rewrite ``?`` placeholders to ``$1, $2, ...`` for PostgreSQL.

        Question marks inside single-quoted literals or double-quoted
        identifiers are left alone.  Other backends get the query back
        unchanged.  ``params`` is accepted for interface compatibility and
        is not inspected.

        NOTE: dollar-quoted strings, SQL comments and PostgreSQL's ``?``
        JSON operators are not recognised.
        """
        if self.db_type != "postgres" or "?" not in query:
            return query
        converted = []
        index = 1
        quote = None  # the quote character we are currently inside, if any
        for char in query:
            if quote is not None:
                # A doubled quote ('') closes and immediately reopens the
                # literal, so it needs no special handling.
                if char == quote:
                    quote = None
            elif char in ("'", '"'):
                quote = char
            elif char == "?":
                converted.append(f"${index}")
                index += 1
                continue
            converted.append(char)
        return "".join(converted)
|
|
|
|
|
|
class PluginDatabase:
    """Key-value view of a ``DatabaseManager`` scoped to one plugin.

    Every method forwards to the matching ``kv_*`` method of the manager
    with the plugin name filled in.
    """

    def __init__(self, manager: DatabaseManager, plugin: str) -> None:
        """Bind the view to *manager* and to the plugin called *plugin*."""
        self.manager = manager
        self.plugin = plugin

    async def get(self, key: str) -> Optional[Any]:
        """Return the value stored under *key* via ``manager.kv_get``."""
        return await self.manager.kv_get(self.plugin, key)

    async def set(self, key: str, value: Any) -> None:
        """Store *value* under *key* via ``manager.kv_set``."""
        await self.manager.kv_set(self.plugin, key, value)

    async def delete(self, key: str) -> None:
        """Remove *key* via ``manager.kv_delete``."""
        await self.manager.kv_delete(self.plugin, key)

    # NOTE: this method shadows the builtin ``list`` inside the class body.
    # The ``list[...]`` annotations below rely on
    # ``from __future__ import annotations`` keeping them unevaluated.
    async def list(self, pattern: str = "%") -> list[str]:
        """Return this plugin's keys matching *pattern* via ``manager.kv_list``."""
        return await self.manager.kv_list(self.plugin, pattern)

    async def exists(self, key: str) -> bool:
        """Tell whether *key* exists via ``manager.kv_exists``."""
        return await self.manager.kv_exists(self.plugin, key)

    async def expire(self, key: str, seconds: int) -> None:
        """Set a time-to-live of *seconds* on *key* via ``manager.kv_expire``."""
        await self.manager.kv_expire(self.plugin, key, seconds)

    async def query(self, sql: str, params: Iterable[Any] = ()) -> list[Dict[str, Any]]:
        """Run *sql* through ``manager.fetchall`` and return its rows.

        Args:
            sql: Query text; use ``?`` placeholders for values.
            params: Values bound to the placeholders.  Defaults to none, so
                existing single-argument calls behave as before.  Pass
                values here rather than formatting them into *sql*.
        """
        return await self.manager.fetchall(sql, params)
|