110 lines
4.1 KiB
Python
110 lines
4.1 KiB
Python
from __future__ import annotations

import importlib.util
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional

from core.logger import get_logger
|
|
|
|
|
|
# Module logger; MigrationManager.apply() reports completion through it.
logger = get_logger("core.migrations")
|
|
|
|
|
|
class MigrationManager:
    """Discover, apply and roll back file-based schema migrations.

    Migrations are ``*.py`` files in ``<root>/migrations``. Each file may define
    ``upgrade(app)`` and ``downgrade(app)``; either may be a plain function or
    return an awaitable. Applied migrations are recorded in the
    ``schema_migrations`` table (``name`` primary key, ``applied_at`` text
    holding a naive-UTC ISO-8601 timestamp).
    """

    def __init__(self, root: Path) -> None:
        """Create the manager and make sure ``<root>/migrations`` exists."""
        self.root = root
        self.migrations_path = root / "migrations"
        self.migrations_path.mkdir(parents=True, exist_ok=True)

    def list_migrations(self) -> List[str]:
        """Return migration file names in ``migrations_path``, sorted by name.

        Dunder files such as ``__init__.py`` and non-file entries are skipped:
        they are not migrations and would otherwise be loaded and marked.
        """
        return sorted(
            item.name
            for item in self.migrations_path.glob("*.py")
            if item.is_file() and not item.name.startswith("__")
        )

    async def apply(self, app: Optional["OverUB"] = None) -> None:
        """Run ``upgrade()`` for every migration not yet recorded as applied.

        Migrations run in file-name order. Each one is recorded right after it
        runs, so an exception stops the run but keeps earlier migrations
        marked. Does nothing when ``app`` is None.
        """
        if app is None:
            return
        await self._ensure_table(app)
        applied = await self._applied(app)
        for name in self.list_migrations():
            if name in applied:
                continue
            module = self._load_module(name)
            if hasattr(module, "upgrade"):
                await self._run_hook(module.upgrade(app))
            else:
                # Still marked (as before) so it is not retried every run.
                logger.warning(f"Migration {name} has no upgrade(); marking as applied")
            await self._mark(applied, app, name)
        logger.info("Migrations applied")

    async def rollback(
        self, app: "OverUB", name: Optional[str] = None, steps: int = 1
    ) -> List[str]:
        """Undo applied migrations and remove their records.

        If ``name`` is given, only that migration is rolled back (when it is
        recorded as applied). Otherwise the ``steps`` most recently applied
        migrations are rolled back, newest first; ``steps <= 0`` rolls back
        nothing. A migration without ``downgrade()`` is still unmarked.

        Returns the rolled-back names in the order they were processed.
        """
        await self._ensure_table(app)
        applied = await self._applied_with_time(app)
        if name:
            targets = [item for item in applied if item["name"] == name]
        else:
            # Clamp: a negative slice bound would select almost every row.
            targets = applied[: max(steps, 0)]
        rolled: List[str] = []
        for item in targets:
            module = self._load_module(item["name"])
            if hasattr(module, "downgrade"):
                await self._run_hook(module.downgrade(app))
            await self._unmark(app, item["name"])
            rolled.append(item["name"])
        return rolled

    def validate(self) -> List[str]:
        """Return a list of problems found in the migration files.

        A file that fails to import is reported rather than aborting the
        check, so every problem is listed in one pass.
        """
        errors: List[str] = []
        for name in self.list_migrations():
            try:
                module = self._load_module(name)
            except Exception as exc:  # validation boundary: report, don't abort
                errors.append(f"{name}: failed to load ({exc!r})")
                continue
            if not hasattr(module, "upgrade"):
                errors.append(f"{name}: missing upgrade()")
        return errors

    async def _ensure_table(self, app: "OverUB") -> None:
        """Create ``schema_migrations`` if needed and add a missing ``applied_at``.

        Column introspection is only done for ``sqlite`` and ``postgres``;
        for other ``db_type`` values the ALTER step is skipped.
        """
        await app.database.execute(
            "CREATE TABLE IF NOT EXISTS schema_migrations (name TEXT PRIMARY KEY, applied_at TEXT)"
        )
        columns: List[str] = []
        if app.database.db_type == "sqlite":
            rows = await app.database.fetchall("PRAGMA table_info(schema_migrations)")
            columns = [row["name"] for row in rows]
        elif app.database.db_type == "postgres":
            # Scope to the current schema so a same-named table elsewhere
            # does not mask a missing column here.
            rows = await app.database.fetchall(
                "SELECT column_name AS name FROM information_schema.columns "
                "WHERE table_name='schema_migrations' AND table_schema=current_schema()"
            )
            columns = [row["name"] for row in rows]
        if columns and "applied_at" not in columns:
            await app.database.execute("ALTER TABLE schema_migrations ADD COLUMN applied_at TEXT")

    async def _applied(self, app: "OverUB") -> List[str]:
        """Return the names of all migrations recorded as applied."""
        rows = await app.database.fetchall("SELECT name FROM schema_migrations")
        return [row["name"] for row in rows]

    async def _applied_with_time(self, app: "OverUB") -> List[dict]:
        """Return applied rows (``name``, ``applied_at``), newest first.

        Rows without a timestamp sort last on every backend (Postgres would
        otherwise put NULLs first under DESC), and ``name DESC`` breaks ties so
        the order is deterministic.
        """
        rows = await app.database.fetchall(
            "SELECT name, applied_at FROM schema_migrations "
            "ORDER BY applied_at IS NULL, applied_at DESC, name DESC"
        )
        return rows

    async def _mark(self, applied: List[str], app: "OverUB", name: str) -> None:
        """Record ``name`` as applied now, unless it is already in ``applied``."""
        if name in applied:
            return
        # Same naive-UTC ISO format datetime.utcnow() produced (now deprecated),
        # keeping new rows text-sortable alongside existing ones.
        applied_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        # Standard upsert (SQLite >= 3.24 and PostgreSQL) instead of the
        # SQLite-only INSERT OR REPLACE. The "?" placeholders are the same style
        # _unmark uses; translation is assumed to happen in app.database.
        await app.database.execute(
            "INSERT INTO schema_migrations (name, applied_at) VALUES (?, ?) "
            "ON CONFLICT (name) DO UPDATE SET applied_at = excluded.applied_at",
            (name, applied_at),
        )

    async def _unmark(self, app: "OverUB", name: str) -> None:
        """Delete the applied-record for ``name``."""
        await app.database.execute("DELETE FROM schema_migrations WHERE name=?", (name,))

    @staticmethod
    async def _run_hook(result: object) -> None:
        """Await ``result`` if an upgrade/downgrade hook returned an awaitable."""
        if hasattr(result, "__await__"):
            await result

    def _load_module(self, name: str):
        """Import and execute the migration file ``name`` and return the module.

        Raises ImportError when no loader exists for the file (e.g. a non-.py
        suffix). Exceptions raised while executing the file propagate.
        """
        path = self.migrations_path / name
        spec = importlib.util.spec_from_file_location(name, path)
        # Check before module_from_spec (which fails obscurely on None) and
        # raise rather than assert, since asserts are stripped under -O.
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load migration {name!r} from {path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
|