Files
overub/core/loader.py
2025-12-21 17:12:32 +01:00

477 lines
20 KiB
Python

from __future__ import annotations
import asyncio
import importlib
import inspect
import shutil
import subprocess
import sys
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Type
from core.logger import get_logger
from core.audit import log as audit_log
from core.events import EventHandler
from core.versioning import is_compatible
from core.module import Module
from core.plugin import Plugin, PluginContext
# Module-level logger used by both ModuleLoader and PluginManager below.
logger = get_logger("core.loader")
class ModuleLoader:
    """Import, instantiate and track ``Module`` subclasses by dotted import path.

    Loaded instances are cached in ``_loaded``.  Dependencies declared on a
    module class are resolved before the class is instantiated: entries
    prefixed with ``lib:`` are installed with pip, every other entry is
    loaded as another module through :meth:`load`.
    """

    def __init__(self, app: "OverUB", modules_path: Path) -> None:
        self.app = app
        self.modules_path = modules_path
        self._loaded: Dict[str, Module] = {}
        # Import paths whose load() has started but not finished.  A module
        # only enters ``_loaded`` after its dependencies are resolved, so
        # without this set an A -> B -> A chain would recurse without bound.
        self._loading: set[str] = set()

    async def load(self, module_path: str) -> Optional[Module]:
        """Import *module_path* and return its instantiated ``Module``.

        Returns the cached instance when the module is already loaded and
        ``None`` when the imported module exposes no ``Module`` subclass.

        Raises:
            RuntimeError: if *module_path* is already in the middle of being
                loaded (typically a circular dependency).
            Exception: import errors and errors raised by the module's own
                ``on_load`` propagate unchanged.
        """
        if module_path in self._loaded:
            return self._loaded[module_path]
        if module_path in self._loading:
            raise RuntimeError(
                f"Module {module_path} is already being loaded (circular dependency?)"
            )

        self._loading.add(module_path)
        try:
            module = importlib.import_module(module_path)
            module_class = self._find_module_class(module)
            if module_class is None:
                logger.warning("No Module class found in %s", module_path)
                return None
            await self._load_dependencies(module_class)
            instance = module_class(self.app)
            await instance.on_load()
            self._loaded[module_path] = instance
        finally:
            self._loading.discard(module_path)
        audit_log("module_load", module_path)
        return instance

    async def unload(self, module_path: str) -> None:
        """Forget *module_path* and run its ``on_unload`` hook, if loaded."""
        instance = self._loaded.pop(module_path, None)
        if instance is None:
            return
        await instance.on_unload()
        audit_log("module_unload", module_path)

    async def reload(self, module_path: str) -> Optional[Module]:
        """Unload *module_path*, re-execute its Python module and load it again."""
        await self.unload(module_path)
        if module_path in sys.modules:
            importlib.reload(sys.modules[module_path])
        return await self.load(module_path)

    def list(self) -> List[str]:
        """Return the import paths of all loaded modules, sorted."""
        return sorted(self._loaded)

    def list_installed(self) -> List[str]:
        """Return sorted names of package directories found in ``modules_path``.

        Only directories that contain an ``__init__.py`` are reported; a
        missing ``modules_path`` yields an empty list.
        """
        if not self.modules_path.exists():
            return []
        return sorted(
            entry.name
            for entry in self.modules_path.iterdir()
            if entry.is_dir() and (entry / "__init__.py").exists()
        )

    def _find_module_class(self, module: ModuleType) -> Optional[Type[Module]]:
        """Pick the ``Module`` subclass that *module* provides, if any.

        Classes defined in *module* itself (or in one of its submodules) win
        over subclasses that were merely imported into its namespace; when no
        such class exists the first subclass by attribute name is used.
        """
        candidates = [
            obj
            for _, obj in inspect.getmembers(module, inspect.isclass)
            if issubclass(obj, Module) and obj is not Module
        ]
        package = module.__name__
        for obj in candidates:
            owner = getattr(obj, "__module__", "") or ""
            if owner == package or owner.startswith(package + "."):
                return obj
        return candidates[0] if candidates else None

    async def _run(self, cmd: List[str]) -> str:
        """Run *cmd* in the default executor and return its stripped stdout."""
        # get_running_loop() is the documented way to reach the loop from a
        # coroutine; get_event_loop() has deprecated fallback behaviour.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_run, cmd)

    def _sync_run(self, cmd: List[str]) -> str:
        """Blocking helper behind :meth:`_run`.

        Raises:
            RuntimeError: when the command exits with a non-zero status; the
                message is the command's stderr, or ``"Command failed"``.
        """
        logger.debug("Running command: %s", " ".join(cmd))
        # subprocess.run raises FileNotFoundError for a missing cwd, which
        # would make a plain pip install fail just because the modules
        # directory has not been created yet.
        cwd = self.modules_path if self.modules_path.is_dir() else None
        result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(result.stderr.strip() or "Command failed")
        return result.stdout.strip()

    async def _load_dependency(self, dep: str) -> None:
        """Resolve one dependency entry: pip-install ``lib:`` items, load the rest."""
        if dep.startswith("lib:"):
            requirement = dep.split(":", 1)[1]
            await self._run([sys.executable, "-m", "pip", "install", requirement])
        elif dep not in self._loaded:
            await self.load(dep)

    async def _load_dependencies(self, module_class: Type[Module]) -> None:
        """Resolve ``dependencies`` and ``optional_dependencies`` of *module_class*.

        A failure in a required dependency propagates to the caller; a failure
        in an optional one is logged at debug level and skipped.
        """
        for dep in getattr(module_class, "dependencies", []) or []:
            await self._load_dependency(dep)

        for dep in getattr(module_class, "optional_dependencies", []) or []:
            is_library = dep.startswith("lib:")
            try:
                await self._load_dependency(dep)
            except Exception:
                if is_library:
                    logger.debug("Optional dependency failed: %s", dep.split(":", 1)[1])
                else:
                    logger.debug("Optional module dependency failed: %s", dep)
class PluginManager:
    """Install, load, enable and supervise plugins stored under ``plugin_path``.

    ``_loaded`` is keyed by the import name passed to :meth:`load`, while
    event hooks (``_handlers``) and error counters (``_error_counts``) are
    keyed by ``plugin.name``.  The two differ when a plugin class declares
    its own ``name``, so internal callers translate between them explicitly.
    """

    # Plugin method name -> name of the application event it subscribes to.
    _HOOK_EVENTS = {
        "on_startup": "on_startup",
        "on_shutdown": "on_shutdown",
        "on_ready": "on_ready",
        "on_reconnect": "on_reconnect",
        "on_disconnect": "on_disconnect",
        "on_message": "on_message_new",
        "on_message_new": "on_message_new",
        "on_edit": "on_message_edit",
        "on_message_edit": "on_message_edit",
        "on_delete": "on_message_delete",
        "on_message_delete": "on_message_delete",
        "on_command": "on_command",
        "on_message_read": "on_message_read",
        "on_message_sent": "on_message_sent",
        "on_inline_query": "on_inline_query",
        "on_callback_query": "on_callback_query",
        "on_chat_action": "on_chat_action",
        "on_chat_update": "on_chat_update",
        "on_typing": "on_typing",
        "on_recording": "on_recording",
        "on_user_update": "on_user_update",
        "on_contact_update": "on_contact_update",
        "on_status_update": "on_status_update",
    }

    def __init__(self, app: "OverUB", plugin_path: Path) -> None:
        self.app = app
        self.plugin_path = plugin_path
        self._loaded: Dict[str, Plugin] = {}
        self._handlers: Dict[str, List[tuple[str, EventHandler]]] = {}
        self._error_counts: Dict[str, int] = {}
        # Import names whose load() has started but not finished; guards
        # against unbounded recursion on circular plugin dependencies.
        self._loading: set[str] = set()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def list_installed(self) -> List[str]:
        """Return sorted plugin directory names found in ``plugin_path``.

        Hidden directories and ``__pycache__`` are skipped because they can
        never be plugins.
        """
        if not self.plugin_path.exists():
            return []
        return sorted(
            entry.name
            for entry in self.plugin_path.iterdir()
            if entry.is_dir()
            and not entry.name.startswith(".")
            and entry.name != "__pycache__"
        )

    async def load(self, name: str) -> Optional[Plugin]:
        """Import plugin *name*, instantiate it and enable it if configured.

        Returns the plugin instance (cached when already loaded).  Every
        failure is logged and reported as ``None``: no ``Plugin`` subclass,
        a conflict with a loaded plugin, an incompatible core version, a
        load already in progress for *name*, or any exception raised while
        importing or initialising the plugin.
        """
        if name in self._loaded:
            return self._loaded[name]
        if name in self._loading:
            logger.warning("Plugin %s is already being loaded (circular dependency?)", name)
            return None

        # Insert the search path once instead of growing sys.path per call.
        search_path = str(self.plugin_path)
        if search_path not in sys.path:
            sys.path.insert(0, search_path)

        self._loading.add(name)
        try:
            module = importlib.import_module(name)
            plugin_class = self._find_plugin_class(module)
            if plugin_class is None:
                logger.warning("No Plugin class found in %s", name)
                return None
            await self._load_dependencies(plugin_class)
            if self._has_conflicts(plugin_class):
                logger.warning("Plugin %s conflicts with loaded plugins", name)
                return None
            if not await self._check_core_compat(plugin_class):
                logger.warning("Plugin %s is not compatible with core", name)
                return None

            # getattr: a plugin class is not required to define ``name``.
            declared_name = getattr(plugin_class, "name", "")
            plugin_name = declared_name or name
            if not declared_name:
                plugin_class.name = plugin_name

            schema = getattr(plugin_class, "config_schema", None)
            if schema:
                self.app.config.register_plugin_schema(plugin_name, schema)
                errors = self.app.config.validate_plugin_config(plugin_name)
                if errors:
                    logger.warning("Plugin %s config issues: %s", plugin_name, ", ".join(errors))

            plugin = plugin_class(PluginContext(self.app, plugin_name))
            await plugin.on_load()
            self._loaded[name] = plugin
            await self._enable_if_configured(plugin)
            audit_log("plugin_load", name)
            return plugin
        except Exception:
            logger.exception("Failed to load plugin %s", name)
            return None
        finally:
            self._loading.discard(name)

    async def unload(self, name: str) -> None:
        """Disable plugin *name*, run ``on_unload`` and forget it."""
        plugin = self._loaded.get(name)
        if plugin is None:
            return
        try:
            # disable() looks the plugin up in ``_loaded``, so it has to run
            # before the entry is removed or the hooks stay registered.
            await self.disable(name)
            await plugin.on_unload()
        finally:
            self._loaded.pop(name, None)
            self._error_counts.pop(plugin.name, None)
        audit_log("plugin_unload", name)

    async def enable(self, name: str) -> None:
        """Register the event hooks of loaded plugin *name* and run ``on_enable``."""
        plugin = self._loaded.get(name)
        if plugin is None:
            return
        # A freshly enabled plugin starts with a clean error budget.
        self._error_counts.pop(plugin.name, None)
        self._register_hooks(plugin)
        await plugin.on_enable()
        audit_log("plugin_enable", name)

    async def disable(self, name: str) -> None:
        """Unregister the event hooks of loaded plugin *name* and run ``on_disable``."""
        plugin = self._loaded.get(name)
        if plugin is None:
            return
        # Hooks are stored under plugin.name, which may differ from *name*.
        self._unregister_hooks(plugin.name)
        await plugin.on_disable()
        audit_log("plugin_disable", name)

    async def reload(self, name: str) -> Optional[Plugin]:
        """Unload plugin *name*, reload its Python package and load it again."""
        await self.unload(name)
        self._reload_package(name)
        return await self.load(name)

    @staticmethod
    def _reload_package(name: str) -> None:
        """Reload package *name* together with its already-imported submodules.

        ``importlib.reload`` is not recursive, and ``_ensure_init`` generates
        ``from .plugin import *``; reloading only the package would therefore
        keep serving the stale ``<name>.plugin``.  Submodules are reloaded
        deepest first, then the package itself.
        """
        prefix = f"{name}."
        submodules = sorted(
            (mod for mod in list(sys.modules) if mod.startswith(prefix)),
            key=lambda mod: mod.count("."),
            reverse=True,
        )
        for sub in submodules:
            module = sys.modules.get(sub)
            if not isinstance(module, ModuleType):
                continue
            try:
                importlib.reload(module)
            except ImportError:
                # The source is gone (e.g. removed by an update): drop the
                # stale entry so the next import decides afresh.
                sys.modules.pop(sub, None)
        package = sys.modules.get(name)
        if isinstance(package, ModuleType):
            importlib.reload(package)

    def list(self) -> List[str]:
        """Return the import names of all loaded plugins, sorted."""
        return sorted(self._loaded)

    # ------------------------------------------------------------------
    # Repository management
    # ------------------------------------------------------------------

    async def install(self, repo: str) -> str:
        """Clone *repo* (``url`` or ``url@ref``) into ``plugin_path``.

        Returns the directory name the plugin was installed under.  If any
        step after the clone fails, the partial checkout is removed so that
        unverified code is not left where :meth:`load` could import it.

        Raises:
            RuntimeError: source not allowed, invalid repository/ref, target
                already exists, or a git/pip command failed.
        """
        repo_url, ref = self._parse_repo(repo)
        # A leading "-" would be parsed by git as an option.
        if repo_url.startswith("-") or (ref and ref.startswith("-")):
            raise RuntimeError("Invalid repository or ref")
        allowed = self.app.config.get().get("updates", {}).get("allowed_sources", [])
        if allowed and not self._source_allowed(repo_url, allowed):
            raise RuntimeError("Source not allowed")

        self.plugin_path.mkdir(parents=True, exist_ok=True)
        dest_name = repo_url.rstrip("/").split("/")[-1].rsplit(":", 1)[-1]
        if dest_name.endswith(".git"):
            dest_name = dest_name[: -len(".git")]
        dest_path = self._plugin_dir(dest_name)
        if dest_path.exists():
            raise RuntimeError(f"Plugin {dest_name} already exists")

        try:
            await self._run(["git", "clone", "--", repo_url, str(dest_path)])
            if ref:
                await self._run(["git", "checkout", ref], cwd=dest_path)
            await self._verify_repo(dest_path)
            self._ensure_init(dest_path)
            await self._install_requirements(dest_path)
        except Exception:
            shutil.rmtree(dest_path, ignore_errors=True)
            raise
        audit_log("plugin_install", dest_name)
        return dest_name

    async def uninstall(self, name: str) -> None:
        """Unload plugin *name* if needed and delete its directory.

        Raises:
            RuntimeError: if *name* is not a plain directory name.
        """
        plugin_path = self._plugin_dir(name)
        if name in self._loaded:
            await self.unload(name)
        if plugin_path.exists():
            shutil.rmtree(plugin_path)
            audit_log("plugin_uninstall", name)

    async def update(self, name: str) -> str:
        """``git pull`` plugin *name*, verify it and reinstall its requirements.

        When verification fails the checkout is reset to the commit that was
        current before the pull, then the error is re-raised.
        """
        plugin_path = self._plugin_dir(name, must_exist=True)
        previous = await self._run(["git", "rev-parse", "HEAD"], cwd=plugin_path)
        output = await self._run(["git", "pull"], cwd=plugin_path)
        try:
            await self._verify_repo(plugin_path)
        except Exception:
            await self._run(["git", "reset", "--hard", previous], cwd=plugin_path)
            raise
        await self._install_requirements(plugin_path)
        audit_log("plugin_update", name)
        return output

    async def rollback(self, name: str, ref: str = "HEAD~1") -> str:
        """Hard-reset the checkout of plugin *name* to *ref*."""
        plugin_path = self._plugin_dir(name, must_exist=True)
        if ref.startswith("-"):
            raise RuntimeError("Invalid ref")
        output = await self._run(["git", "reset", "--hard", ref], cwd=plugin_path)
        audit_log("plugin_rollback", name)
        return output

    async def fetch(self, name: str) -> str:
        """Run ``git fetch`` in the checkout of plugin *name*."""
        plugin_path = self._plugin_dir(name, must_exist=True)
        return await self._run(["git", "fetch"], cwd=plugin_path)

    async def remote(self, name: str) -> str:
        """Return the ``origin`` URL of the checkout of plugin *name*."""
        plugin_path = self._plugin_dir(name, must_exist=True)
        return await self._run(["git", "remote", "get-url", "origin"], cwd=plugin_path)

    async def info(self, name: str) -> Dict[str, str]:
        """Describe plugin *name*.

        Loaded plugins report their own metadata.  Otherwise the result is
        ``{"name", "status": "not_loaded"}``, extended with repository data
        when the origin URL can be parsed and a Gitea API URL is configured.
        """
        plugin = self._loaded.get(name)
        if plugin:
            return {
                "name": plugin.name,
                "version": plugin.version,
                "author": plugin.author,
                "description": plugin.description,
                "category": plugin.category,
            }

        info = {"name": name, "status": "not_loaded"}
        try:
            remote = await self.remote(name)
        except Exception:
            return info

        owner_repo = self._parse_owner_repo(remote)
        if owner_repo and self.app.gitea.api_url:
            owner, repo = owner_repo
            try:
                data = self.app.gitea.repo_info(owner, repo)
            except Exception:
                return info
            info.update(
                {
                    "full_name": data.get("full_name", ""),
                    "stars": str(data.get("stars_count", "")),
                    "downloads": str(data.get("downloads", "")),
                }
            )
        return info

    async def search(self, query: str) -> str:
        """Search repositories via the Gitea API, or the ``tea`` CLI as fallback.

        Returns one line per repository, or an error description when the
        API call fails.
        """
        api_url = self.app.config.get().get("updates", {}).get("gitea", {}).get("api_url")
        if not api_url:
            return await self._run(["tea", "repos", "search", query])
        try:
            results = self.app.gitea.search_repos(query)
        except Exception as exc:
            return f"Gitea API error: {exc}"
        lines = []
        for item in results:
            full_name = item.get("full_name", "")
            if not full_name:
                continue
            # Separate name and count; gluing them yields e.g. "owner/repo12".
            lines.append(f"{full_name} ({item.get('stars_count', 0)} stars)")
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _plugin_dir(self, name: str, must_exist: bool = False) -> Path:
        """Map plugin *name* to its directory below ``plugin_path``.

        Raises:
            RuntimeError: if *name* is not a single path component (this keeps
                names such as ``..`` from reaching ``rmtree`` or git), or if
                *must_exist* is set and the directory is missing.
        """
        if name in ("", ".", "..") or Path(name).name != name:
            raise RuntimeError(f"Invalid plugin name: {name}")
        path = self.plugin_path / name
        if must_exist and not path.exists():
            raise RuntimeError(f"Plugin {name} not found")
        return path

    @staticmethod
    def _source_allowed(repo_url: str, allowed: List[str]) -> bool:
        """Return True when *repo_url* lies under one of the *allowed* prefixes.

        A bare ``startswith`` would accept ``https://good.example.evil.io``
        for the prefix ``https://good.example``, so the match must end on a
        separator: ``/`` in general, also ``:`` for ``git@host`` prefixes.
        """
        for src in allowed:
            if not src or not repo_url.startswith(src):
                continue
            rest = repo_url[len(src):]
            if not rest or src.endswith(("/", ":")):
                return True
            if rest[0] == "/" or (rest[0] == ":" and src.startswith("git@")):
                return True
        return False

    def _loaded_key(self, plugin: Plugin) -> Optional[str]:
        """Return the ``_loaded`` key that holds *plugin*, or None if it is not loaded."""
        for key, loaded in self._loaded.items():
            if loaded is plugin:
                return key
        return None

    def _find_plugin_class(self, module: ModuleType) -> Optional[Type[Plugin]]:
        """Pick the ``Plugin`` subclass that *module* provides, if any.

        Classes defined in the package (including its submodules) win over
        subclasses that were only imported into it; otherwise the first
        subclass by attribute name is used.
        """
        candidates = [
            obj
            for _, obj in inspect.getmembers(module, inspect.isclass)
            if issubclass(obj, Plugin) and obj is not Plugin
        ]
        package = module.__name__
        for obj in candidates:
            owner = getattr(obj, "__module__", "") or ""
            if owner == package or owner.startswith(package + "."):
                return obj
        return candidates[0] if candidates else None

    async def _enable_if_configured(self, plugin: Plugin) -> None:
        """Enable *plugin* unless its config sets ``enabled`` to a false value."""
        config = self.app.config.get_plugin_config(plugin.name)
        if not config.get("enabled", True):
            return
        # enable() expects the ``_loaded`` key, not plugin.name.
        key = self._loaded_key(plugin)
        if key is not None:
            await self.enable(key)

    def _register_hooks(self, plugin: Plugin) -> None:
        """Subscribe every hook method that *plugin* overrides to its event."""
        if plugin.name in self._handlers:
            return
        handlers: List[tuple[str, EventHandler]] = []
        for method_name, event_name in self._HOOK_EVENTS.items():
            handler = self._resolve_handler(plugin, method_name)
            if handler:
                wrapped = self._wrap_handler(plugin, handler)
                self.app.events.on(event_name, wrapped)
                handlers.append((event_name, wrapped))
        self._handlers[plugin.name] = handlers

    def _unregister_hooks(self, name: str) -> None:
        """Remove all event subscriptions recorded under plugin name *name*."""
        for event_name, handler in self._handlers.pop(name, []):
            self.app.events.off(event_name, handler)

    def _resolve_handler(self, plugin: Plugin, method_name: str) -> Optional[EventHandler]:
        """Return ``plugin.<method_name>`` unless it is missing or inherited from ``Plugin``."""
        handler = getattr(plugin, method_name, None)
        if handler is None:
            return None
        base_method = getattr(Plugin, method_name, None)
        if base_method is not None and getattr(handler, "__func__", None) == base_method:
            return None
        return handler

    def _wrap_handler(self, plugin: Plugin, handler: EventHandler) -> EventHandler:
        """Wrap *handler* with timeout, error-limit and resource-limit supervision.

        The wrapper never lets a handler exception escape (timeouts and
        errors are logged).  The plugin is disabled, at most once per event,
        when its error count reaches ``security.plugin_error_limit`` or a
        configured memory/CPU limit is exceeded.
        """

        async def wrapped(event):
            cfg = self.app.config.get_plugin_config(plugin.name)
            timeout = cfg.get("timeout")
            should_disable = False
            try:
                if timeout:
                    await asyncio.wait_for(handler(event), timeout=timeout)
                else:
                    await handler(event)
            except asyncio.TimeoutError:
                logger.warning("Plugin %s timed out", plugin.name)
            except Exception:
                logger.exception("Plugin %s handler failed", plugin.name)
                errors = self._error_counts.get(plugin.name, 0) + 1
                self._error_counts[plugin.name] = errors
                limit = int(self.app.config.get().get("security", {}).get("plugin_error_limit", 3))
                if errors >= limit:
                    logger.error("Disabling plugin %s after %s errors", plugin.name, limit)
                    should_disable = True

            performance = self.app.config.get().get("performance", {})
            max_mem = cfg.get("max_memory_mb") or performance.get("max_memory")
            max_cpu = cfg.get("max_cpu_percent") or performance.get("max_cpu")
            if self._over_limit(max_mem, "memory_mb", "Memory", plugin.name):
                should_disable = True
            if self._over_limit(max_cpu, "cpu_percent", "CPU", plugin.name):
                should_disable = True

            if should_disable:
                await self._auto_disable(plugin)

        return wrapped

    @staticmethod
    def _over_limit(limit, stat_name: str, label: str, plugin_name: str) -> bool:
        """Return True when system stat *stat_name* exceeds *limit*.

        Best effort: a falsy *limit* disables the check, and any failure to
        obtain or compare the statistic is logged at debug level.
        """
        if not limit:
            return False
        try:
            from core.monitor import get_system_stats

            value = getattr(get_system_stats(), stat_name)
            if value and value > float(limit):
                logger.warning("%s limit exceeded, disabling plugin %s", label, plugin_name)
                return True
        except Exception:
            logger.debug("%s check skipped", label)
        return False

    async def _auto_disable(self, plugin: Plugin) -> None:
        """Disable *plugin* from inside an event handler without raising."""
        key = self._loaded_key(plugin)
        if key is None:
            # Already unloaded; its hooks were removed by unload().
            return
        try:
            await self.disable(key)
        except Exception:
            logger.exception("Failed to disable plugin %s", plugin.name)

    async def _load_dependencies(self, plugin_class: Type[Plugin]) -> None:
        """Resolve ``dependencies``: pip-install ``lib:`` entries, load other plugins."""
        for dep in getattr(plugin_class, "dependencies", []) or []:
            if dep.startswith("lib:"):
                requirement = dep.split(":", 1)[1]
                await self._run([sys.executable, "-m", "pip", "install", requirement])
            elif dep not in self._loaded:
                await self.load(dep)

    def _has_conflicts(self, plugin_class: Type[Plugin]) -> bool:
        """Return True when any name in ``conflicts`` is a currently loaded plugin."""
        conflicts = set(getattr(plugin_class, "conflicts", []) or [])
        return any(conflict in self._loaded for conflict in conflicts)

    async def _check_core_compat(self, plugin_class: Type[Plugin]) -> bool:
        """Check the declared ``min_core_version``/``max_core_version`` against the core."""
        min_version = getattr(plugin_class, "min_core_version", "")
        max_version = getattr(plugin_class, "max_core_version", "")
        if not min_version and not max_version:
            return True
        info = await self.app.updater.get_version_info()
        return is_compatible(info.core, min_version, max_version)

    def _parse_repo(self, repo: str) -> tuple[str, Optional[str]]:
        """Split ``repo[@ref]`` into a clone URL and an optional ref.

        An ``@`` that belongs to the URL itself (the ``git@`` prefix, or
        user info in ``scheme://user@host/...``) is not treated as the ref
        separator.  A short ``owner/repo`` form is expanded against the
        scheme and host of ``updates.git.remote``.

        Raises:
            RuntimeError: for a short form when no HTTP(S) remote is configured.
        """
        if repo.startswith("git@"):
            search_from = len("git@")
        elif "://" in repo:
            authority = repo.index("://") + len("://")
            path_start = repo.find("/", authority)
            search_from = path_start if path_start != -1 else len(repo)
        else:
            search_from = 0

        ref: Optional[str] = None
        at = repo.find("@", search_from)
        if at != -1:
            repo, ref = repo[:at], repo[at + 1:]

        if "://" in repo or repo.startswith("git@"):
            return repo, ref
        base = self.app.config.get().get("updates", {}).get("git", {}).get("remote", "")
        if base.startswith("http"):
            root = "/".join(base.split("/")[:3])
            return f"{root}/{repo}", ref
        raise RuntimeError("Repo URL must be a full URL or configure updates.git.remote")

    def _parse_owner_repo(self, url: str) -> Optional[tuple[str, str]]:
        """Extract ``(owner, repo)`` from an HTTP(S) or ``git@host:owner/repo`` URL."""
        if url.startswith("http"):
            path = url
        elif url.startswith("git@") and ":" in url:
            path = url.split(":", 1)[1]
        else:
            return None
        path = path.rstrip("/")
        # str.rstrip(".git") strips a *character set* and would also eat the
        # tail of names such as "digit"; remove the literal suffix instead.
        if path.endswith(".git"):
            path = path[: -len(".git")]
        parts = path.split("/")
        if len(parts) >= 2:
            return parts[-2], parts[-1]
        return None

    def _ensure_init(self, path: Path) -> None:
        """Create ``__init__.py`` re-exporting ``plugin.py`` when only the latter exists."""
        init_path = path / "__init__.py"
        if init_path.exists():
            return
        if (path / "plugin.py").exists():
            init_path.write_text("from .plugin import *\n", encoding="utf-8")

    async def _install_requirements(self, path: Path) -> None:
        """pip-install ``requirements.txt`` from *path* when the file exists."""
        requirements = path / "requirements.txt"
        if requirements.exists():
            await self._run([sys.executable, "-m", "pip", "install", "-r", str(requirements)])

    async def _verify_repo(self, path: Path) -> None:
        """Verify the signature of HEAD in *path* when the config requires it.

        No-op unless ``security.verify_plugin_commits`` is set.

        Raises:
            RuntimeError: ``git verify-commit`` failed, or the signing key
                fingerprint is not listed in ``security.allowed_signers``.
        """
        security = self.app.config.get().get("security", {})
        if not security.get("verify_plugin_commits", False):
            return
        commit = (await self._run(["git", "rev-parse", "HEAD"], cwd=path)).strip()
        await self._run(["git", "verify-commit", commit], cwd=path)
        allowed = security.get("allowed_signers", [])
        if allowed:
            signer = (
                await self._run(["git", "log", "--format=%GF", "-n", "1", commit], cwd=path)
            ).strip()
            if signer not in allowed:
                raise RuntimeError("Plugin signer not allowed")

    async def _run(self, cmd: List[str], cwd: Optional[Path] = None) -> str:
        """Run *cmd* in the default executor and return its stripped stdout."""
        # get_running_loop() is the documented way to reach the loop from a
        # coroutine; get_event_loop() has deprecated fallback behaviour.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_run, cmd, cwd)

    def _sync_run(self, cmd: List[str], cwd: Optional[Path]) -> str:
        """Blocking helper behind :meth:`_run`.

        Raises:
            RuntimeError: when the command exits with a non-zero status; the
                message is the command's stderr, or ``"Command failed"``.
        """
        logger.debug("Running command: %s", " ".join(cmd))
        result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(result.stderr.strip() or "Command failed")
        return result.stdout.strip()