"""File-based schema migrations tracked in a ``schema_migrations`` table."""
from __future__ import annotations

import importlib.util
from datetime import datetime, timezone
from pathlib import Path
from types import ModuleType
from typing import Collection, List, Optional

from core.logger import get_logger

logger = get_logger("core.migrations")


class MigrationManager:
    """Discover, apply, and roll back Python migration scripts.

    Migrations are ``*.py`` files in ``<root>/migrations``. Each may define
    ``upgrade(app)`` and optionally ``downgrade(app)``; either may return an
    awaitable, which is awaited. Applied migrations are recorded by file name
    in the ``schema_migrations`` table together with an ``applied_at`` text
    timestamp.
    """

    def __init__(self, root: Path) -> None:
        self.root = root
        self.migrations_path = root / "migrations"
        # Created eagerly so listing works even before any migration exists.
        self.migrations_path.mkdir(parents=True, exist_ok=True)

    def list_migrations(self) -> List[str]:
        """Return the file names of all ``*.py`` migrations, sorted by name."""
        return sorted(item.name for item in self.migrations_path.glob("*.py"))

    async def apply(self, app: "OverUB" = None) -> None:
        """Run ``upgrade()`` for every migration not yet recorded as applied.

        Migrations run in file-name order. A migration file without
        ``upgrade`` is still recorded as applied. Does nothing when *app*
        is ``None``.
        """
        if app is None:
            return
        await self._ensure_table(app)
        # Set for O(1) membership checks against each migration file.
        applied = set(await self._applied(app))
        for name in self.list_migrations():
            if name in applied:
                continue
            module = self._load_module(name)
            if hasattr(module, "upgrade"):
                result = module.upgrade(app)
                # Support both sync and async upgrade() implementations.
                if hasattr(result, "__await__"):
                    await result
            await self._mark(applied, app, name)
        logger.info("Migrations applied")

    async def rollback(
        self, app: "OverUB", name: Optional[str] = None, steps: int = 1
    ) -> List[str]:
        """Undo applied migrations and return the names that were rolled back.

        When *name* is given, only that migration is targeted and *steps* is
        ignored. Otherwise the *steps* most recently applied migrations are
        rolled back, newest first; ``steps <= 0`` rolls back nothing.
        ``downgrade()`` is called when the module defines it, and the record is
        removed from ``schema_migrations`` either way.
        """
        await self._ensure_table(app)
        applied = await self._applied_with_time(app)
        if name:
            targets = [item for item in applied if item["name"] == name]
        else:
            # Clamp: a negative slice bound would select all but the oldest
            # |steps| migrations instead of none.
            targets = applied[: max(steps, 0)]
        rolled: List[str] = []
        for item in targets:
            module = self._load_module(item["name"])
            if hasattr(module, "downgrade"):
                result = module.downgrade(app)
                if hasattr(result, "__await__"):
                    await result
            await self._unmark(app, item["name"])
            rolled.append(item["name"])
        return rolled

    def validate(self) -> List[str]:
        """Return human-readable problems found in the migration files.

        Each file is imported (executing its top-level code). A file that
        fails to import, or that lacks ``upgrade()``, yields one error entry
        instead of aborting the whole validation.
        """
        errors: List[str] = []
        for name in self.list_migrations():
            try:
                module = self._load_module(name)
            except Exception as exc:  # validation boundary: report, don't abort
                errors.append(f"{name}: failed to load ({exc!r})")
                continue
            if not hasattr(module, "upgrade"):
                errors.append(f"{name}: missing upgrade()")
        return errors

    async def _ensure_table(self, app: "OverUB") -> None:
        """Create ``schema_migrations`` if needed and add a missing ``applied_at`` column."""
        await app.database.execute(
            "CREATE TABLE IF NOT EXISTS schema_migrations (name TEXT PRIMARY KEY, applied_at TEXT)"
        )
        columns: List[str] = []
        if app.database.db_type == "sqlite":
            rows = await app.database.fetchall("PRAGMA table_info(schema_migrations)")
            columns = [row["name"] for row in rows]
        elif app.database.db_type == "postgres":
            rows = await app.database.fetchall(
                "SELECT column_name AS name FROM information_schema.columns "
                "WHERE table_name='schema_migrations'"
            )
            columns = [row["name"] for row in rows]
        # A pre-existing table may lack applied_at. For other backends columns
        # stays empty, so no ALTER is attempted there.
        if columns and "applied_at" not in columns:
            await app.database.execute("ALTER TABLE schema_migrations ADD COLUMN applied_at TEXT")

    async def _applied(self, app: "OverUB") -> List[str]:
        """Return the names of all migrations recorded as applied."""
        rows = await app.database.fetchall("SELECT name FROM schema_migrations")
        return [row["name"] for row in rows]

    async def _applied_with_time(self, app: "OverUB") -> List[dict]:
        """Return applied-migration rows (``name``, ``applied_at``), newest first.

        Rows with a NULL ``applied_at`` (recorded before the column existed)
        are placed last. SQLite and PostgreSQL disagree on where NULLs go in
        a DESC sort, so the explicit ``IS NULL`` key keeps the order the same
        on both. ``name DESC`` breaks timestamp ties in reverse of the
        name order ``apply`` uses.
        """
        rows = await app.database.fetchall(
            "SELECT name, applied_at FROM schema_migrations "
            "ORDER BY applied_at IS NULL, applied_at DESC, name DESC"
        )
        return rows

    async def _mark(self, applied: Collection[str], app: "OverUB", name: str) -> None:
        """Record *name* as applied now, unless it is already in *applied*."""
        if name in applied:
            return
        # Naive-UTC ISO string, the same format earlier rows used
        # (utcnow().isoformat()). The fixed-width microseconds keep text
        # ordering chronological.
        applied_at = (
            datetime.now(timezone.utc)
            .replace(tzinfo=None)
            .isoformat(timespec="microseconds")
        )
        # Standard upsert accepted by both SQLite (>= 3.24) and PostgreSQL,
        # unlike the SQLite-only "INSERT OR REPLACE".
        await app.database.execute(
            "INSERT INTO schema_migrations (name, applied_at) VALUES (?, ?) "
            "ON CONFLICT (name) DO UPDATE SET applied_at = excluded.applied_at",
            (name, applied_at),
        )

    async def _unmark(self, app: "OverUB", name: str) -> None:
        """Remove the applied record for *name*."""
        await app.database.execute("DELETE FROM schema_migrations WHERE name=?", (name,))

    def _load_module(self, name: str) -> ModuleType:
        """Import and execute the migration file *name* as a fresh module.

        Raises:
            ImportError: if no import spec/loader can be built for the file.
        """
        path = self.migrations_path / name
        spec = importlib.util.spec_from_file_location(name, path)
        # Validate before module_from_spec, which would otherwise fail with an
        # unrelated AttributeError on a None spec. An explicit raise, unlike
        # assert, is not stripped under -O.
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load migration {name!r} from {path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module