477 lines
20 KiB
Python
477 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import importlib
|
|
import inspect
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
from types import ModuleType
|
|
from typing import Dict, List, Optional, Type
|
|
|
|
from core.logger import get_logger
|
|
from core.audit import log as audit_log
|
|
from core.events import EventHandler
|
|
from core.versioning import is_compatible
|
|
from core.module import Module
|
|
from core.plugin import Plugin, PluginContext
|
|
|
|
|
|
# Module-wide logger used by both ModuleLoader and PluginManager below.
logger = get_logger("core.loader")
|
|
|
|
|
|
class ModuleLoader:
    """Load, unload and reload built-in ``Module`` subclasses by import path.

    Loaded instances are cached in ``_loaded``, keyed by the dotted module
    path that was passed to :meth:`load`.
    """

    def __init__(self, app: "OverUB", modules_path: Path) -> None:
        self.app = app
        self.modules_path = modules_path
        self._loaded: Dict[str, Module] = {}
        # Module paths whose load() is currently running. It lets load()
        # detect dependency cycles instead of recursing forever.
        self._loading: set[str] = set()

    async def load(self, module_path: str) -> Optional[Module]:
        """Import *module_path*, instantiate its Module class and run ``on_load``.

        Returns the cached instance when already loaded, or ``None`` when the
        imported module exposes no ``Module`` subclass.

        Raises:
            RuntimeError: if *module_path* is requested again while its own
                dependencies are still loading (a dependency cycle).
            Any exception raised while importing the module, installing or
            loading a required dependency, or running ``on_load``.
        """
        if module_path in self._loaded:
            return self._loaded[module_path]
        if module_path in self._loading:
            raise RuntimeError(f"Circular module dependency involving {module_path}")
        self._loading.add(module_path)
        try:
            module = importlib.import_module(module_path)
            module_class = self._find_module_class(module)
            if module_class is None:
                logger.warning("No Module class found in %s", module_path)
                return None
            await self._load_dependencies(module_class)
            instance = module_class(self.app)
            await instance.on_load()
        finally:
            self._loading.discard(module_path)
        self._loaded[module_path] = instance
        audit_log("module_load", module_path)
        return instance

    async def unload(self, module_path: str) -> None:
        """Run ``on_unload`` for a loaded module and forget it (no-op if not loaded)."""
        instance = self._loaded.pop(module_path, None)
        if instance is None:
            return
        await instance.on_unload()
        audit_log("module_unload", module_path)

    async def reload(self, module_path: str) -> Optional[Module]:
        """Unload *module_path*, re-execute its code (submodules included) and load it again."""
        await self.unload(module_path)
        self._reload_tree(module_path)
        return await self.load(module_path)

    def list(self) -> List[str]:
        """Return the sorted import paths of all currently loaded modules."""
        return sorted(self._loaded.keys())

    def list_installed(self) -> List[str]:
        """Return the sorted names of package directories (with ``__init__.py``) in ``modules_path``."""
        if not self.modules_path.exists():
            return []
        return sorted(
            item.name
            for item in self.modules_path.iterdir()
            if item.is_dir() and (item / "__init__.py").exists()
        )

    def _find_module_class(self, module: ModuleType) -> Optional[Type[Module]]:
        """Return the ``Module`` subclass that *module* defines.

        Classes defined in *module* itself or in one of its submodules win over
        ``Module`` subclasses it merely imported. ``inspect.getmembers`` sorts by
        name, so the fallback is the alphabetically first subclass.
        """
        candidates = [
            obj
            for _, obj in inspect.getmembers(module, inspect.isclass)
            if issubclass(obj, Module) and obj is not Module
        ]
        own_prefix = module.__name__ + "."
        for obj in candidates:
            origin = getattr(obj, "__module__", "") or ""
            if origin == module.__name__ or origin.startswith(own_prefix):
                return obj
        return candidates[0] if candidates else None

    @staticmethod
    def _reload_tree(package: str) -> None:
        """Re-execute *package* and its already-imported submodules, deepest first.

        ``importlib.reload`` on a package does not reload its submodules, so a
        package ``__init__`` that re-exports from a submodule would otherwise
        keep handing out the stale objects. Submodules whose source has been
        removed are dropped from ``sys.modules`` rather than failing the reload.
        """
        prefix = package + "."
        submodules = sorted(
            (mod_name for mod_name in list(sys.modules) if mod_name.startswith(prefix)),
            key=lambda mod_name: mod_name.count("."),
            reverse=True,
        )
        for mod_name in submodules:
            submodule = sys.modules.get(mod_name)
            if submodule is None:
                continue
            try:
                importlib.reload(submodule)
            except ModuleNotFoundError:
                sys.modules.pop(mod_name, None)
        if package in sys.modules:
            importlib.reload(sys.modules[package])

    async def _run(self, cmd: List[str]) -> str:
        """Run *cmd* in a worker thread so the event loop is not blocked."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_run, cmd)

    def _sync_run(self, cmd: List[str]) -> str:
        """Run *cmd*; return its stripped stdout or raise RuntimeError with stderr."""
        logger.debug("Running command: %s", " ".join(cmd))
        # subprocess.run fails outright when cwd does not exist, so only use
        # modules_path as the working directory once it has been created.
        cwd = self.modules_path if self.modules_path.is_dir() else None
        result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(result.stderr.strip() or "Command failed")
        return result.stdout.strip()

    async def _install_library(self, requirement: str) -> None:
        """pip-install *requirement* using the running interpreter (``sys.executable``)."""
        await self._run([sys.executable, "-m", "pip", "install", requirement])

    async def _load_dependencies(self, module_class: Type[Module]) -> None:
        """Satisfy ``dependencies`` (failures propagate) and ``optional_dependencies`` (failures logged).

        Entries prefixed with ``lib:`` are pip requirements. Any other entry
        is a module import path that is loaded through :meth:`load`.
        """
        for dep in getattr(module_class, "dependencies", []) or []:
            if dep.startswith("lib:"):
                await self._install_library(dep.split(":", 1)[1])
            elif dep not in self._loaded:
                await self.load(dep)
        for dep in getattr(module_class, "optional_dependencies", []) or []:
            if dep.startswith("lib:"):
                requirement = dep.split(":", 1)[1]
                try:
                    await self._install_library(requirement)
                except Exception:
                    logger.debug("Optional dependency failed: %s", requirement)
                continue
            try:
                if dep not in self._loaded:
                    await self.load(dep)
            except Exception:
                logger.debug("Optional module dependency failed: %s", dep)
|
|
|
|
|
|
class PluginManager:
    """Install, load, enable and update third-party plugins.

    Plugins are git checkouts stored as packages under ``plugin_path`` and
    imported by their directory name. That directory name is the key used by
    ``list()``, ``load()``, ``unload()``, ``enable()``, ``disable()`` and the
    git helpers. A plugin class's own ``name`` attribute may differ from it;
    that attribute is used for the plugin's config section and log messages.
    """

    # Plugin method name -> event it is subscribed to on ``app.events``.
    _HOOK_EVENTS: Dict[str, str] = {
        "on_startup": "on_startup",
        "on_shutdown": "on_shutdown",
        "on_ready": "on_ready",
        "on_reconnect": "on_reconnect",
        "on_disconnect": "on_disconnect",
        "on_message": "on_message_new",
        "on_message_new": "on_message_new",
        "on_edit": "on_message_edit",
        "on_message_edit": "on_message_edit",
        "on_delete": "on_message_delete",
        "on_message_delete": "on_message_delete",
        "on_command": "on_command",
        "on_message_read": "on_message_read",
        "on_message_sent": "on_message_sent",
        "on_inline_query": "on_inline_query",
        "on_callback_query": "on_callback_query",
        "on_chat_action": "on_chat_action",
        "on_chat_update": "on_chat_update",
        "on_typing": "on_typing",
        "on_recording": "on_recording",
        "on_user_update": "on_user_update",
        "on_contact_update": "on_contact_update",
        "on_status_update": "on_status_update",
    }

    def __init__(self, app: "OverUB", plugin_path: Path) -> None:
        self.app = app
        self.plugin_path = plugin_path
        self._loaded: Dict[str, Plugin] = {}
        # Enabled plugins only, keyed like ``_loaded``. Each value lists the
        # (event name, wrapper) pairs registered on app.events for that plugin.
        self._handlers: Dict[str, List[tuple[str, EventHandler]]] = {}
        # Handler failure counts per plugin ``name``; reset by enable().
        self._error_counts: Dict[str, int] = {}
        # Names whose load() is in progress; used to break dependency cycles.
        self._loading: set[str] = set()

    def list_installed(self) -> List[str]:
        """Return the sorted names of all directories under ``plugin_path``."""
        if not self.plugin_path.exists():
            return []
        return sorted(item.name for item in self.plugin_path.iterdir() if item.is_dir())

    async def load(self, name: str) -> Optional[Plugin]:
        """Import plugin package *name*, instantiate it and enable it if configured.

        Returns the plugin instance (the cached one if already loaded), or
        ``None`` if loading failed. Failures are logged, never raised.
        """
        if name in self._loaded:
            return self._loaded[name]
        if name in self._loading:
            logger.warning("Circular plugin dependency detected for %s", name)
            return None
        # Make plugin packages importable by directory name, without adding
        # another sys.path entry on every load() call.
        plugin_dir = str(self.plugin_path)
        if plugin_dir not in sys.path:
            sys.path.insert(0, plugin_dir)
        self._loading.add(name)
        try:
            module = importlib.import_module(name)
            plugin_class = self._find_plugin_class(module)
            if plugin_class is None:
                logger.warning("No Plugin class found in %s", name)
                return None
            # Reject incompatible plugins before installing anything for them.
            if not await self._check_core_compat(plugin_class):
                logger.warning("Plugin %s is not compatible with core", name)
                return None
            await self._load_dependencies(plugin_class)
            # Checked after dependencies so that conflicts with plugins pulled
            # in transitively are also caught.
            if self._has_conflicts(plugin_class):
                logger.warning("Plugin %s conflicts with loaded plugins", name)
                return None
            plugin_name = getattr(plugin_class, "name", "") or name
            if not getattr(plugin_class, "name", ""):
                plugin_class.name = plugin_name
            if getattr(plugin_class, "config_schema", None):
                self.app.config.register_plugin_schema(plugin_name, plugin_class.config_schema)
            errors = self.app.config.validate_plugin_config(plugin_name)
            if errors:
                logger.warning("Plugin %s config issues: %s", plugin_name, ", ".join(errors))
            plugin = plugin_class(PluginContext(self.app, plugin_name))
            await plugin.on_load()
            self._loaded[name] = plugin
            await self._enable_if_configured(plugin)
            audit_log("plugin_load", name)
            return plugin
        except Exception:
            logger.exception("Failed to load plugin %s", name)
            return None
        finally:
            self._loading.discard(name)

    async def unload(self, name: str) -> None:
        """Disable (if enabled), run ``on_unload`` and forget plugin *name*; no-op if not loaded."""
        plugin = self._loaded.get(name)
        if plugin is None:
            return
        # Disable while the plugin is still registered in _loaded. Otherwise
        # disable() finds nothing, and the plugin's hooks stay subscribed.
        await self.disable(name)
        self._loaded.pop(name, None)
        await plugin.on_unload()
        audit_log("plugin_unload", name)

    async def enable(self, name: str) -> None:
        """Subscribe the loaded plugin's hooks and call ``on_enable``.

        No-op if *name* is not loaded or is already enabled. Resets the
        plugin's handler error count so a re-enabled plugin starts fresh.
        """
        plugin = self._loaded.get(name)
        if plugin is None or name in self._handlers:
            return
        self._error_counts.pop(plugin.name, None)
        self._register_hooks(plugin, name)
        await plugin.on_enable()
        audit_log("plugin_enable", name)

    async def disable(self, name: str) -> None:
        """Unsubscribe the plugin's hooks and call ``on_disable``; no-op unless enabled."""
        plugin = self._loaded.get(name)
        if plugin is None or name not in self._handlers:
            return
        self._unregister_hooks(name)
        await plugin.on_disable()
        audit_log("plugin_disable", name)

    async def reload(self, name: str) -> Optional[Plugin]:
        """Unload *name*, re-execute its code (submodules included) and load it again."""
        await self.unload(name)
        self._reload_tree(name)
        return await self.load(name)

    def list(self) -> List[str]:
        """Return the sorted keys of all currently loaded plugins."""
        return sorted(self._loaded.keys())

    async def install(self, repo: str) -> str:
        """Clone *repo* (``url[@ref]`` or ``owner/repo[@ref]``) into ``plugin_path``.

        Returns the directory name, which is also the name to pass to load().
        If any step after cloning fails (checkout, verification, requirements),
        the checkout is deleted so that no unverified code is left behind.

        Raises:
            RuntimeError: for disallowed sources, invalid names or refs, an
                existing destination, or any failing git/pip command.
        """
        repo_url, ref = self._parse_repo(repo)
        allowed = self.app.config.get().get("updates", {}).get("allowed_sources", [])
        if allowed and not any(repo_url.startswith(src) for src in allowed):
            raise RuntimeError("Source not allowed")
        if ref:
            self._validate_ref(ref)
        self.plugin_path.mkdir(parents=True, exist_ok=True)
        dest_name = self._repo_dir_name(repo_url)
        dest_path = self._plugin_dir(dest_name)
        if dest_path.exists():
            raise RuntimeError(f"Plugin {dest_name} already exists")
        # "--" stops git from parsing a URL that starts with "-" as an option.
        await self._run(["git", "clone", "--", repo_url, str(dest_path)])
        try:
            if ref:
                await self._run(["git", "checkout", ref], cwd=dest_path)
            await self._verify_repo(dest_path)
            self._ensure_init(dest_path)
            await self._install_requirements(dest_path)
        except Exception:
            shutil.rmtree(dest_path, ignore_errors=True)
            raise
        audit_log("plugin_install", dest_name)
        return dest_name

    async def uninstall(self, name: str) -> None:
        """Unload plugin *name* if loaded and delete its directory."""
        plugin_path = self._plugin_dir(name)
        if name in self._loaded:
            await self.unload(name)
        if plugin_path.exists():
            shutil.rmtree(plugin_path)
        audit_log("plugin_uninstall", name)

    async def update(self, name: str) -> str:
        """``git pull`` plugin *name*, re-verify it, install requirements; return git's output.

        If verification fails, the checkout is reset to the commit it had
        before the pull and the error is re-raised.
        """
        plugin_path = self._existing_plugin_dir(name)
        previous = await self._run(["git", "rev-parse", "HEAD"], cwd=plugin_path)
        output = await self._run(["git", "pull"], cwd=plugin_path)
        try:
            await self._verify_repo(plugin_path)
        except Exception:
            await self._run(["git", "reset", "--hard", previous], cwd=plugin_path)
            raise
        await self._install_requirements(plugin_path)
        audit_log("plugin_update", name)
        return output

    async def rollback(self, name: str, ref: str = "HEAD~1") -> str:
        """Hard-reset plugin *name*'s checkout to *ref*; return git's output."""
        plugin_path = self._existing_plugin_dir(name)
        self._validate_ref(ref)
        output = await self._run(["git", "reset", "--hard", ref], cwd=plugin_path)
        audit_log("plugin_rollback", name)
        return output

    async def fetch(self, name: str) -> str:
        """Run ``git fetch`` in plugin *name*'s checkout; return git's output."""
        plugin_path = self._existing_plugin_dir(name)
        return await self._run(["git", "fetch"], cwd=plugin_path)

    async def remote(self, name: str) -> str:
        """Return the ``origin`` remote URL of plugin *name*'s checkout."""
        plugin_path = self._existing_plugin_dir(name)
        return await self._run(["git", "remote", "get-url", "origin"], cwd=plugin_path)

    async def info(self, name: str) -> Dict[str, str]:
        """Describe plugin *name*.

        Loaded plugins report their own metadata. Otherwise the result is a
        ``not_loaded`` stub, enriched with Gitea repo stats when the origin
        remote can be parsed and a Gitea API is configured.
        """
        plugin = self._loaded.get(name)
        if plugin is not None:
            return {
                "name": plugin.name,
                "version": plugin.version,
                "author": plugin.author,
                "description": plugin.description,
                "category": plugin.category,
            }
        info = {"name": name, "status": "not_loaded"}
        try:
            remote_url = await self.remote(name)
        except Exception:
            return info
        owner_repo = self._parse_owner_repo(remote_url)
        if owner_repo and self.app.gitea.api_url:
            owner, repo = owner_repo
            try:
                data = self.app.gitea.repo_info(owner, repo)
            except Exception:
                return info
            info.update(
                {
                    "full_name": data.get("full_name", ""),
                    "stars": str(data.get("stars_count", "")),
                    "downloads": str(data.get("downloads", "")),
                }
            )
        return info

    async def search(self, query: str) -> str:
        """Search plugin repositories.

        Uses the Gitea API when ``updates.gitea.api_url`` is configured, and
        the ``tea`` CLI otherwise. Returns one result per line.
        """
        api_url = self.app.config.get().get("updates", {}).get("gitea", {}).get("api_url")
        if api_url:
            try:
                results = self.app.gitea.search_repos(query)
            except Exception as exc:
                return f"Gitea API error: {exc}"
            lines = []
            for item in results:
                name = item.get("full_name", "")
                if not name:
                    continue
                stars = item.get("stars_count", 0)
                lines.append(f"{name} ⭐{stars}")
            return "\n".join(lines)
        return await self._run(["tea", "repos", "search", query])

    def _find_plugin_class(self, module: ModuleType) -> Optional[Type[Plugin]]:
        """Return the ``Plugin`` subclass that *module* defines.

        Classes defined in *module* or its submodules (e.g. ``name.plugin``)
        win over Plugin subclasses it merely imported. ``inspect.getmembers``
        sorts by name, so the fallback is the alphabetically first subclass.
        """
        candidates = [
            obj
            for _, obj in inspect.getmembers(module, inspect.isclass)
            if issubclass(obj, Plugin) and obj is not Plugin
        ]
        own_prefix = module.__name__ + "."
        for obj in candidates:
            origin = getattr(obj, "__module__", "") or ""
            if origin == module.__name__ or origin.startswith(own_prefix):
                return obj
        return candidates[0] if candidates else None

    def _key_of(self, plugin: Plugin) -> str:
        """Return the ``_loaded`` key of *plugin*, falling back to ``plugin.name``."""
        for key, loaded in self._loaded.items():
            if loaded is plugin:
                return key
        return plugin.name

    async def _enable_if_configured(self, plugin: Plugin) -> None:
        """Enable *plugin* unless its config sets ``enabled`` to a falsy value."""
        config = self.app.config.get_plugin_config(plugin.name)
        if config.get("enabled", True):
            # enable() is keyed by the _loaded key, which may differ from plugin.name.
            await self.enable(self._key_of(plugin))

    def _register_hooks(self, plugin: Plugin, key: Optional[str] = None) -> None:
        """Subscribe every hook *plugin* overrides to its event; record them under *key*."""
        key = self._key_of(plugin) if key is None else key
        if key in self._handlers:
            return
        handlers: List[tuple[str, EventHandler]] = []
        for method_name, event_name in self._HOOK_EVENTS.items():
            handler = self._resolve_handler(plugin, method_name)
            if handler is None:
                continue
            wrapped = self._wrap_handler(plugin, handler, key)
            self.app.events.on(event_name, wrapped)
            handlers.append((event_name, wrapped))
        self._handlers[key] = handlers

    def _unregister_hooks(self, name: str) -> None:
        """Remove every event subscription recorded for *name*."""
        handlers = self._handlers.pop(name, [])
        for event_name, handler in handlers:
            self.app.events.off(event_name, handler)

    def _resolve_handler(self, plugin: Plugin, method_name: str) -> Optional[EventHandler]:
        """Return ``plugin.<method_name>`` if the plugin overrides the base ``Plugin`` version."""
        base_method = getattr(Plugin, method_name, None)
        handler = getattr(plugin, method_name, None)
        if handler is None:
            return None
        if base_method is not None and getattr(handler, "__func__", None) is base_method:
            return None
        return handler

    def _wrap_handler(
        self, plugin: Plugin, handler: EventHandler, key: Optional[str] = None
    ) -> EventHandler:
        """Wrap *handler* with timeout, error counting and resource-limit enforcement.

        When the error limit or a memory/CPU limit is exceeded, the plugin is
        disabled by its ``_loaded`` key (*key*, or looked up at call time).
        """

        async def wrapped(event):
            target = key if key is not None else self._key_of(plugin)
            cfg = self.app.config.get_plugin_config(plugin.name)
            timeout = cfg.get("timeout")
            try:
                if timeout:
                    await asyncio.wait_for(handler(event), timeout=timeout)
                else:
                    await handler(event)
            except asyncio.TimeoutError:
                logger.warning("Plugin %s timed out", plugin.name)
            except Exception:
                logger.exception("Plugin %s handler failed", plugin.name)
                self._error_counts[plugin.name] = self._error_counts.get(plugin.name, 0) + 1
                limit = int(self.app.config.get().get("security", {}).get("plugin_error_limit", 3))
                if self._error_counts[plugin.name] >= limit:
                    logger.error("Disabling plugin %s after %s errors", plugin.name, limit)
                    await self.disable(target)
            max_mem = cfg.get("max_memory_mb") or self.app.config.get().get("performance", {}).get("max_memory")
            if max_mem:
                try:
                    from core.monitor import get_system_stats

                    stats = get_system_stats()
                    if stats.memory_mb and stats.memory_mb > float(max_mem):
                        logger.warning("Memory limit exceeded, disabling plugin %s", plugin.name)
                        await self.disable(target)
                except Exception:
                    logger.debug("Memory check skipped")
            max_cpu = cfg.get("max_cpu_percent") or self.app.config.get().get("performance", {}).get("max_cpu")
            if max_cpu:
                try:
                    from core.monitor import get_system_stats

                    stats = get_system_stats()
                    if stats.cpu_percent and stats.cpu_percent > float(max_cpu):
                        logger.warning("CPU limit exceeded, disabling plugin %s", plugin.name)
                        await self.disable(target)
                except Exception:
                    logger.debug("CPU check skipped")

        return wrapped

    async def _load_dependencies(self, plugin_class: Type[Plugin]) -> None:
        """pip-install ``lib:`` dependencies and load plugin dependencies.

        Raises:
            RuntimeError: if a required plugin dependency cannot be loaded.
                load() swallows its own errors and returns None, so this check
                is what stops a plugin from loading without its dependencies.
        """
        for dep in getattr(plugin_class, "dependencies", []) or []:
            if dep.startswith("lib:"):
                requirement = dep.split(":", 1)[1]
                await self._run([sys.executable, "-m", "pip", "install", requirement])
                continue
            if dep not in self._loaded and await self.load(dep) is None:
                raise RuntimeError(f"Required plugin dependency {dep} could not be loaded")

    def _has_conflicts(self, plugin_class: Type[Plugin]) -> bool:
        """Return True if any name in the class's ``conflicts`` is currently loaded."""
        conflicts = set(getattr(plugin_class, "conflicts", []) or [])
        return any(conflict in self._loaded for conflict in conflicts)

    async def _check_core_compat(self, plugin_class: Type[Plugin]) -> bool:
        """Check the class's ``min/max_core_version`` bounds against the running core version."""
        min_version = getattr(plugin_class, "min_core_version", "")
        max_version = getattr(plugin_class, "max_core_version", "")
        if not min_version and not max_version:
            return True
        info = await self.app.updater.get_version_info()
        return is_compatible(info.core, min_version, max_version)

    @staticmethod
    def _split_ref(repo: str) -> tuple[str, Optional[str]]:
        """Split a trailing ``@ref`` off *repo*.

        Only an ``@`` in the path part counts. The ``@`` in ``git@host:`` or in
        ``https://user@host/`` belongs to the authority, not to a ref.
        """
        if "://" in repo:
            authority_start = repo.index("://") + 3
            slash = repo.find("/", authority_start)
            path_start = slash if slash != -1 else len(repo)
        elif repo.startswith("git@"):
            colon = repo.find(":")
            path_start = colon + 1 if colon != -1 else len(repo)
        else:
            path_start = 0
        at = repo.find("@", path_start)
        if at == -1:
            return repo, None
        return repo[:at], repo[at + 1:]

    def _parse_repo(self, repo: str) -> tuple[str, Optional[str]]:
        """Turn ``url[@ref]`` or ``owner/repo[@ref]`` into ``(clone_url, ref)``.

        Shorthand ``owner/repo`` is resolved against the scheme and host of
        ``updates.git.remote``.
        """
        repo, ref = self._split_ref(repo)
        if "://" in repo or repo.startswith("git@"):
            return repo, ref
        base = self.app.config.get().get("updates", {}).get("git", {}).get("remote", "")
        if base.startswith("http"):
            root = "/".join(base.split("/")[:3])
            return f"{root}/{repo}", ref
        raise RuntimeError("Repo URL must be a full URL or configure updates.git.remote")

    @staticmethod
    def _strip_git_suffix(path: str) -> str:
        """Remove one trailing ``.git`` (``str.rstrip`` would strip characters, not a suffix)."""
        return path[: -len(".git")] if path.endswith(".git") else path

    def _parse_owner_repo(self, url: str) -> Optional[tuple[str, str]]:
        """Extract ``(owner, repo)`` from an http(s) or ``git@host:owner/repo`` URL."""
        if url.startswith("http"):
            path = url
        elif url.startswith("git@") and ":" in url:
            path = url.split(":", 1)[1]
        else:
            return None
        parts = self._strip_git_suffix(path.rstrip("/")).split("/")
        if len(parts) >= 2:
            return parts[-2], parts[-1]
        return None

    @staticmethod
    def _repo_dir_name(repo_url: str) -> str:
        """Return the checkout directory name for *repo_url*: its last path segment without ``.git``."""
        tail = repo_url.rstrip("/").rsplit("/", 1)[-1].rsplit(":", 1)[-1]
        tail = tail[: -len(".git")] if tail.endswith(".git") else tail
        if not tail:
            raise RuntimeError(f"Cannot derive a plugin name from {repo_url}")
        return tail

    def _plugin_dir(self, name: str) -> Path:
        """Return ``plugin_path / name``, rejecting names that escape ``plugin_path``."""
        candidate = self.plugin_path / name
        if name in ("", ".", "..") or candidate.parent != self.plugin_path:
            raise RuntimeError(f"Invalid plugin name: {name}")
        return candidate

    def _existing_plugin_dir(self, name: str) -> Path:
        """Like :meth:`_plugin_dir`, but raise if the directory does not exist."""
        path = self._plugin_dir(name)
        if not path.exists():
            raise RuntimeError(f"Plugin {name} not found")
        return path

    @staticmethod
    def _validate_ref(ref: str) -> None:
        """Reject git refs that git would parse as command-line options."""
        if ref.startswith("-"):
            raise RuntimeError(f"Invalid git ref: {ref}")

    @staticmethod
    def _reload_tree(package: str) -> None:
        """Re-execute *package* and its already-imported submodules, deepest first.

        ``importlib.reload`` on a package leaves submodules such as
        ``name.plugin`` (re-exported by the generated ``__init__``) stale.
        Submodules whose source was removed are dropped from ``sys.modules``.
        """
        prefix = package + "."
        submodules = sorted(
            (mod_name for mod_name in list(sys.modules) if mod_name.startswith(prefix)),
            key=lambda mod_name: mod_name.count("."),
            reverse=True,
        )
        for mod_name in submodules:
            submodule = sys.modules.get(mod_name)
            if submodule is None:
                continue
            try:
                importlib.reload(submodule)
            except ModuleNotFoundError:
                sys.modules.pop(mod_name, None)
        if package in sys.modules:
            importlib.reload(sys.modules[package])

    def _ensure_init(self, path: Path) -> None:
        """Create an ``__init__.py`` re-exporting ``plugin.py`` when the checkout has none."""
        init_path = path / "__init__.py"
        if init_path.exists():
            return
        plugin_file = path / "plugin.py"
        if plugin_file.exists():
            init_path.write_text("from .plugin import *\n", encoding="utf-8")

    async def _install_requirements(self, path: Path) -> None:
        """pip-install ``requirements.txt`` from *path* if it exists."""
        requirements = path / "requirements.txt"
        if requirements.exists():
            await self._run([sys.executable, "-m", "pip", "install", "-r", str(requirements)])

    async def _verify_repo(self, path: Path) -> None:
        """Verify HEAD's signature (and optionally its signer) when ``security.verify_plugin_commits`` is set."""
        security = self.app.config.get().get("security", {})
        if not security.get("verify_plugin_commits", False):
            return
        commit = (await self._run(["git", "rev-parse", "HEAD"], cwd=path)).strip()
        await self._run(["git", "verify-commit", commit], cwd=path)
        allowed = security.get("allowed_signers", [])
        if allowed:
            signer = (await self._run(["git", "log", "--format=%GF", "-n", "1", commit], cwd=path)).strip()
            if signer not in allowed:
                raise RuntimeError("Plugin signer not allowed")

    async def _run(self, cmd: List[str], cwd: Optional[Path] = None) -> str:
        """Run *cmd* in a worker thread so the event loop is not blocked."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_run, cmd, cwd)

    def _sync_run(self, cmd: List[str], cwd: Optional[Path]) -> str:
        """Run *cmd*; return its stripped stdout or raise RuntimeError with stderr."""
        logger.debug("Running command: %s", " ".join(cmd))
        result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(result.stderr.strip() or "Command failed")
        return result.stdout.strip()
|