Files
overub/core/migrations.py
2025-12-21 17:12:32 +01:00

110 lines
4.1 KiB
Python

from __future__ import annotations

import importlib.util
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional

from core.logger import get_logger
# Module logger; MigrationManager.apply() reports completion through it.
logger = get_logger("core.migrations")
class MigrationManager:
    """Discover, apply, and roll back file-based schema migrations.

    Migrations are ``*.py`` files in ``<root>/migrations`` and are applied in
    filename order. Each file may define ``upgrade(app)`` and
    ``downgrade(app)``; either may be a plain function or return an
    awaitable. Applied migrations are recorded in the ``schema_migrations``
    table.
    """

    def __init__(self, root: Path) -> None:
        """Remember *root* and create ``<root>/migrations`` if it is missing."""
        self.root = root
        self.migrations_path = root / "migrations"
        self.migrations_path.mkdir(parents=True, exist_ok=True)

    def list_migrations(self) -> List[str]:
        """Return the file names of all ``*.py`` migrations, sorted by name."""
        return sorted(item.name for item in self.migrations_path.glob("*.py"))

    async def apply(self, app: Optional["OverUB"] = None) -> None:
        """Run ``upgrade(app)`` for every migration not yet recorded.

        Does nothing when *app* is ``None``. Migrations run in filename
        order; each one is recorded only after its ``upgrade`` completes, so
        an exception from ``upgrade`` propagates and leaves that migration
        unrecorded.
        """
        if app is None:
            return
        await self._ensure_table(app)
        applied = await self._applied(app)
        for name in self.list_migrations():
            if name in applied:
                continue
            module = self._load_module(name)
            # NOTE: a file without upgrade() is still recorded as applied;
            # validate() is what reports such files.
            if hasattr(module, "upgrade"):
                await self._call_maybe_async(module.upgrade, app)
            await self._mark(applied, app, name)
        logger.info("Migrations applied")

    async def rollback(
        self, app: "OverUB", name: Optional[str] = None, steps: int = 1
    ) -> List[str]:
        """Undo applied migrations and return the names that were rolled back.

        Args:
            app: Application object whose ``database`` is used.
            name: If given, roll back only this migration (if recorded).
            steps: Otherwise, roll back this many of the most recently
                applied migrations, newest first. ``steps <= 0`` rolls back
                nothing.

        ``downgrade(app)`` is called when the module defines it; the record
        is removed in either case.
        """
        await self._ensure_table(app)
        applied = await self._applied_with_time(app)
        if name:
            targets = [item for item in applied if item["name"] == name]
        else:
            # Clamp: a negative slice bound would select "all but the last N".
            targets = applied[: max(steps, 0)]
        rolled: List[str] = []
        for item in targets:
            migration = item["name"]
            module = self._load_module(migration)
            if hasattr(module, "downgrade"):
                await self._call_maybe_async(module.downgrade, app)
            await self._unmark(app, migration)
            rolled.append(migration)
        return rolled

    def validate(self) -> List[str]:
        """Check that every migration file loads and defines ``upgrade()``.

        Loading executes each module's top-level code. Returns one
        human-readable error string per problem; an empty list means no
        problems were found.
        """
        errors: List[str] = []
        for name in self.list_migrations():
            try:
                module = self._load_module(name)
            except Exception as exc:  # report the broken file, keep checking the rest
                errors.append(f"{name}: failed to load: {exc}")
                continue
            if not hasattr(module, "upgrade"):
                errors.append(f"{name}: missing upgrade()")
        return errors

    @staticmethod
    async def _call_maybe_async(func, app: "OverUB") -> None:
        """Call ``func(app)`` and await the result if it is awaitable."""
        result = func(app)
        if hasattr(result, "__await__"):
            await result

    async def _ensure_table(self, app: "OverUB") -> None:
        """Create ``schema_migrations`` if needed; add ``applied_at`` if an existing table lacks it.

        The column check only runs for the ``sqlite`` and ``postgres``
        backends; for any other ``db_type`` no column is added.
        """
        await app.database.execute(
            "CREATE TABLE IF NOT EXISTS schema_migrations (name TEXT PRIMARY KEY, applied_at TEXT)"
        )
        db_type = app.database.db_type
        columns: List[str] = []
        if db_type == "sqlite":
            rows = await app.database.fetchall("PRAGMA table_info(schema_migrations)")
            columns = [row["name"] for row in rows]
        elif db_type == "postgres":
            # Restrict to the current schema so a same-named table in another
            # schema cannot mask a missing column in ours.
            rows = await app.database.fetchall(
                "SELECT column_name AS name FROM information_schema.columns "
                "WHERE table_name='schema_migrations' AND table_schema=current_schema()"
            )
            columns = [row["name"] for row in rows]
        if columns and "applied_at" not in columns:
            await app.database.execute("ALTER TABLE schema_migrations ADD COLUMN applied_at TEXT")

    async def _applied(self, app: "OverUB") -> List[str]:
        """Return the migration names recorded in ``schema_migrations``."""
        rows = await app.database.fetchall("SELECT name FROM schema_migrations")
        return [row["name"] for row in rows]

    async def _applied_with_time(self, app: "OverUB") -> List[dict]:
        """Return recorded ``(name, applied_at)`` rows, most recently applied first.

        Rows with a NULL ``applied_at`` sort last (PostgreSQL would otherwise
        place NULLs first under ``DESC``). ``name DESC`` breaks ties, which
        mirrors the filename order that ``apply`` uses.
        """
        rows = await app.database.fetchall(
            "SELECT name, applied_at FROM schema_migrations "
            "ORDER BY applied_at IS NULL, applied_at DESC, name DESC"
        )
        return rows

    async def _mark(self, applied: List[str], app: "OverUB", name: str) -> None:
        """Record *name* as applied now, unless it is already in *applied*.

        ``applied_at`` is stored as a naive UTC ISO-8601 string, the same
        format ``datetime.utcnow().isoformat()`` produced, so new rows sort
        consistently with existing ones.
        """
        if name in applied:
            return
        applied_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        if app.database.db_type == "postgres":
            # INSERT OR REPLACE is SQLite-only syntax; use the standard upsert.
            # NOTE(review): assumes the database wrapper translates "?"
            # placeholders for postgres, as the other queries here rely on.
            sql = (
                "INSERT INTO schema_migrations (name, applied_at) VALUES (?, ?) "
                "ON CONFLICT (name) DO UPDATE SET applied_at = EXCLUDED.applied_at"
            )
        else:
            sql = "INSERT OR REPLACE INTO schema_migrations (name, applied_at) VALUES (?, ?)"
        await app.database.execute(sql, (name, applied_at))

    async def _unmark(self, app: "OverUB", name: str) -> None:
        """Delete the ``schema_migrations`` record for *name*."""
        await app.database.execute("DELETE FROM schema_migrations WHERE name=?", (name,))

    def _load_module(self, name: str):
        """Import and execute the migration file *name* and return the module.

        Raises:
            ImportError: if no spec or loader can be created for the file.
            Any exception raised while reading or executing the file
            propagates unchanged.
        """
        path = self.migrations_path / name
        spec = importlib.util.spec_from_file_location(name, path)
        # Check before module_from_spec(), which fails on a None spec, and
        # raise instead of assert so the check survives ``python -O``.
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load migration {name!r} from {path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module