Files
overub/core/app.py
2025-12-21 17:12:32 +01:00

332 lines
14 KiB
Python

from __future__ import annotations

import asyncio
import math
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional

from core.backup import BackupManager
from core.backup_service import BackupService
from core.bus import MessageBus
from core.cache import Cache
from core.rate_limiter import RateLimit, RateLimiter
from core.sandbox import Sandbox
from core.gitea import GiteaClient
from core.http import SessionManager
from core.migrations import MigrationManager
from core.webhook_server import WebhookServer
from core.module_updates import ModuleUpdateManager
from core.update_service import UpdateService
from core.version import VersionManager
from core.client import ClientWrapper
from core.commands import CommandBuilder, CommandRegistry
from core.config import ConfigManager
from core.database import DatabaseManager
from core.events import EventDispatcher
from core.loader import ModuleLoader, PluginManager
from core.logger import get_logger, setup_logging
from core.permissions import PermissionManager
from core.updater import UpdateManager

# Module-level logger for this file, registered under the name "core.app".
logger = get_logger("core.app")
class OverUB:
    """Application object that wires configuration, services, loaders and the client together.

    Each instance owns its own ``ConfigManager`` and subsystem objects.
    ``main()`` may build several instances, one per entry in ``bot.instances``,
    each with a different ``bot`` config override.
    """

    def __init__(self, root: Path, bot_override: Optional[Dict[str, Any]] = None) -> None:
        """Load configuration and construct every subsystem.

        Connecting the database and client, loading modules/plugins and starting
        background services happens later, in :meth:`start`.

        Args:
            root: Project root. Config is read from ``root/config/config.yml`` and
                ``root/config/modules.yml``. Relative data paths are resolved under it.
            bot_override: Optional mapping merged into the ``bot`` config section
                before anything else reads the configuration.
        """
        self.root = root
        self.config = ConfigManager(root / "config" / "config.yml", root / "config" / "modules.yml")
        self.config.load()
        if bot_override:
            # Per-instance bot settings must be in place before any subsystem
            # below reads the config.
            self.config.merge({"bot": bot_override})
        cfg = self.config.get()

        log_cfg = cfg.get("logging", {})
        setup_logging(
            log_cfg.get("level", "INFO"),
            json_logs=bool(log_cfg.get("json", False)),
            module_logs=bool(log_cfg.get("module_logs", False)),
            remote_url=str(log_cfg.get("remote_url", "")),
        )

        self.events = EventDispatcher()
        self.bus = MessageBus()
        self.permissions = PermissionManager()
        self.commands = CommandRegistry(prefix=cfg.get("bot", {}).get("command_prefix", "."))
        self.command_builder = CommandBuilder(self.commands)
        security_cfg = cfg.get("security", {})
        self.permissions.load_from_config(security_cfg)
        self.cache = Cache(max_size=int(cfg.get("performance", {}).get("cache_size", 100)))
        self.backups = BackupManager(root)
        self.rate_limiter = RateLimiter()
        self.versions = VersionManager(self)
        self.sandbox = Sandbox(
            root / "plugins",
            allow_network=bool(security_cfg.get("plugin_network", False)),
        )
        self.update_service = UpdateService(self)
        self.backup_service = BackupService(self)
        self.http = SessionManager()
        self.migrations = MigrationManager(root)

        db_cfg = cfg.get("database", {})
        # pathlib keeps an absolute "path" as-is; a relative one ends up under root.
        db_path = Path(db_cfg.get("path", "data/database.db"))
        self.database = DatabaseManager(db_cfg.get("type", "sqlite"), db_cfg, root / db_path)

        updates = cfg.get("updates", {})
        default_branch = updates.get("branch", "main")
        git_cfg = updates.get("git", {})
        self.updater = UpdateManager(root, git_cfg.get("remote", "origin"), default_branch)
        modules_cfg = updates.get("modules", {})
        self.module_updates = ModuleUpdateManager(
            root / modules_cfg.get("path", "modules"),
            modules_cfg.get("remote", "origin"),
            modules_cfg.get("branch", default_branch),
        )
        gitea_cfg = updates.get("gitea", {})
        self.gitea = GiteaClient(gitea_cfg.get("api_url", ""), gitea_cfg.get("token", ""))
        self.updater.gitea = self.gitea
        webhook_cfg = updates.get("webhook", {})
        self.webhook_server = WebhookServer(
            self,
            webhook_cfg.get("host", "0.0.0.0"),
            int(webhook_cfg.get("port", 8080)),
        )
        self.webhook_enabled = bool(webhook_cfg.get("enabled", False))

        bot_cfg = cfg.get("bot", {})
        self.client = ClientWrapper(
            api_id=int(bot_cfg.get("api_id", 0)),
            api_hash=str(bot_cfg.get("api_hash", "")),
            session_name=str(bot_cfg.get("session_name", "overub")),
        )
        self.modules = ModuleLoader(self, root / "modules")
        # NOTE(review): the plugin manager is always rooted at root/plugins/external,
        # independent of the "plugins.plugin_path" setting scanned by _load_plugins.
        self.plugins = PluginManager(self, root / "plugins" / "external")
        # Set to datetime.utcnow() (naive UTC) on every handled message/edit/delete.
        self.last_activity: Optional[datetime] = None
        # (plugin owner, user id) -> time.monotonic() of the last accepted run.
        # Used to enforce the per-plugin "cooldown" setting.
        self._plugin_last_used: Dict[tuple[str, int], float] = {}

    async def start(self) -> None:
        """Connect storage and client, load modules/plugins, then start background services.

        Emits ``on_startup`` first and ``on_ready`` when done. The update service
        starts only when ``updates.enabled`` is true. The webhook server starts
        only when ``updates.webhook.enabled`` is true. The backup service always starts.
        """
        await self.events.emit("on_startup")
        await self.database.connect()
        await self.migrations.apply(self)
        await self.client.connect()
        await self.client.attach_handlers(self)
        await self._load_builtin_modules()
        await self._load_plugins()
        if self.config.get().get("updates", {}).get("enabled", False):
            self.update_service.start()
        self.backup_service.start()
        if self.webhook_enabled:
            await self.webhook_server.start()
        await self.events.emit("on_ready")
        logger.info("OverUB ready")

    async def shutdown(self) -> None:
        """Emit ``on_shutdown``, stop services and release connections.

        ``main()`` also calls this after a failed :meth:`start`, so some components
        may never have been started. Each step is therefore attempted
        independently. A failure is logged and the remaining steps still run, so
        the client and database are released even if an earlier step raises.
        """
        steps = [
            ("on_shutdown handlers", lambda: self.events.emit("on_shutdown")),
            ("update service", self.update_service.stop),
            ("backup service", self.backup_service.stop),
        ]
        if self.webhook_enabled:
            steps.append(("webhook server", self.webhook_server.stop))
        steps.extend(
            [
                ("HTTP sessions", self.http.close),
                ("client", self.client.disconnect),
                ("database", self.database.close),
            ]
        )
        for label, step in steps:
            try:
                await step()
            except Exception:
                logger.exception("Failed to stop %s", label)
        logger.info("OverUB shutdown")

    async def handle_message_event(self, event: Any) -> None:
        """Handle an incoming or outgoing message: dispatch a command if present, then emit events.

        ``on_message_new`` is emitted for every message, whether or not it was a
        command and whether or not the command was rejected. ``on_message_sent``
        is emitted additionally when ``event.out`` is truthy. Exceptions raised
        by a command handler propagate to the caller.
        """
        message = getattr(event, "message", event)
        text = getattr(message, "raw_text", "") or ""
        self.last_activity = datetime.utcnow()
        parsed = self.commands.parse(text)
        if parsed:
            await self._dispatch_command(event, message, parsed)
        await self.events.emit("on_message_new", message=message, event=event)
        if getattr(event, "out", False):
            await self.events.emit("on_message_sent", message=message, event=event)

    async def _dispatch_command(self, event: Any, message: Any, parsed: Any) -> None:
        """Run the command described by ``parsed``, or reply with the reason it was rejected.

        Checks run in order:
        1. ``on_command`` cancellation
        2. command resolution
        3. argument specs
        4. chat type
        5. permission level
        6. plugin permissions
        7. command and plugin cooldowns
        8. rate limit

        The parsed command data is attached to ``event`` (``command_flags``,
        ``command_args``, ``command_raw``, ``command_prefix`` and, when the
        command declares arguments, ``command_parsed_args``).
        """
        user_id = getattr(event, "sender_id", None) or getattr(message, "sender_id", 0) or 0
        chat_id = getattr(event, "chat_id", None) or getattr(message, "chat_id", 0) or 0
        chat_type = self._get_chat_type(event)

        command_event = await self.events.emit("on_command", message=message, event=event, parsed=parsed)
        if command_event.cancelled:
            return
        command, args = self.commands.resolve(parsed.name, parsed.args)
        if not command:
            return

        event.command_flags = self._apply_flag_specs(command, parsed.flags)
        event.command_args = args
        event.command_raw = parsed.raw
        event.command_prefix = parsed.prefix
        if command.arguments:
            try:
                event.command_parsed_args = self._apply_argument_specs(command, args)
            except ValueError as exc:
                await self._reply(message, str(exc))
                return

        if command.chat_types and chat_type not in command.chat_types:
            await self._reply(message, "Command not available here")
            return
        if not self.permissions.is_allowed(command.permission, user_id, chat_id):
            await self._reply(message, "Permission denied")
            return
        plugin_cfg = self._plugin_config_for_command(command)
        if plugin_cfg and not self._plugin_allowed(plugin_cfg, user_id, chat_id):
            await self._reply(message, "Plugin permission denied")
            return

        remaining = max(
            self.commands.cooldown_remaining(command, user_id),
            self._plugin_cooldown_remaining(command, plugin_cfg, user_id),
        )
        if remaining > 0:
            await self._reply(message, f"Cooldown: {remaining}s")
            return
        if not self._rate_limit(command, user_id):
            await self._reply(message, "Rate limit exceeded")
            return

        self._mark_plugin_used(command, plugin_cfg, user_id)
        await command.handler(event, args)

    async def handle_edit_event(self, event: Any) -> None:
        """Record activity and emit ``on_message_edit`` for an edited message."""
        message = getattr(event, "message", event)
        self.last_activity = datetime.utcnow()
        await self.events.emit("on_message_edit", message=message, event=event)

    async def handle_delete_event(self, event: Any) -> None:
        """Record activity and emit ``on_message_delete``."""
        self.last_activity = datetime.utcnow()
        await self.events.emit("on_message_delete", event=event)

    async def _load_builtin_modules(self) -> None:
        """Load every enabled entry of ``modules`` from modules.yml, plus its enabled sub-modules.

        An entry is enabled unless it sets ``enabled: false``. An entry with no
        settings at all (``None``) is treated as enabled with no sub-modules.
        Modules load as ``modules.<name>`` and sub-modules as
        ``modules.<name>.<sub_name>``.
        """
        module_config = self.config.get_modules().get("modules", {}) or {}
        for name, cfg in module_config.items():
            if cfg is None:
                cfg = {}
            if not cfg.get("enabled", True):
                continue
            await self.modules.load(f"modules.{name}")
            sub_modules = cfg.get("sub_modules", {}) or {}
            for sub_name, enabled in sub_modules.items():
                if enabled:
                    await self.modules.load(f"modules.{name}.{sub_name}")

    async def _load_plugins(self) -> None:
        """Load every package directory found in the configured plugin path.

        ``plugins.plugin_path`` (default ``plugins/external``) is resolved under
        :attr:`root`, like the other data paths in this class, so discovery no
        longer depends on the process working directory. Only sub-directories
        containing ``__init__.py`` are loaded, in sorted name order.
        """
        plugin_cfg = self.config.get().get("plugins", {})
        if not plugin_cfg.get("enabled", True):
            return
        plugin_dir = self.root / Path(plugin_cfg.get("plugin_path", "plugins/external"))
        if not plugin_dir.is_dir():
            return
        for item in sorted(plugin_dir.iterdir()):
            if item.is_dir() and (item / "__init__.py").exists():
                await self.plugins.load(item.name)

    def register_permission_profile(self, name: str, users: list[int], chats: list[int]) -> None:
        """Add a named permission profile covering the given user and chat ids."""
        from core.permissions import PermissionProfile

        self.permissions.add_profile(PermissionProfile(name=name, users=users, chats=chats))

    def _get_chat_type(self, event: Any) -> str:
        """Classify the event's chat as "private", "group", "channel" or "unknown"."""
        if getattr(event, "is_private", False):
            return "private"
        if getattr(event, "is_group", False):
            return "group"
        if getattr(event, "is_channel", False):
            return "channel"
        return "unknown"

    async def _reply(self, message: Any, text: str) -> None:
        """Reply to ``message`` with ``text`` if the object supports ``reply``; otherwise do nothing."""
        if hasattr(message, "reply"):
            await message.reply(text)

    def _plugin_config_for_command(self, command: Any) -> Dict[str, Any]:
        """Return the owning plugin's config for plugin commands, or ``{}`` for anything else."""
        if getattr(command, "owner_type", "") != "plugin":
            return {}
        return self.config.get_plugin_config(getattr(command, "owner", ""))

    def _plugin_allowed(self, plugin_cfg: Dict[str, Any], user_id: int, chat_id: int) -> bool:
        """Check a plugin's ``permission_level`` and optional ``permissions`` allow-list.

        The allow-list may mix user ids and chat ids. An empty list allows
        everyone who passes the permission-level check.
        """
        allowed = plugin_cfg.get("permissions", []) or []
        level = plugin_cfg.get("permission_level", "user")
        if not self.permissions.is_allowed(level, user_id, chat_id):
            return False
        if not allowed:
            return True
        return user_id in allowed or chat_id in allowed

    def _plugin_cooldown_remaining(self, command: Any, plugin_cfg: Dict[str, Any], user_id: int) -> int:
        """Return whole seconds left on the plugin's per-user cooldown (0 when none applies).

        ``plugin_cfg["cooldown"]`` is the cooldown *duration* in seconds. The
        remaining time is measured from this user's last accepted run of any
        command owned by the same plugin.
        """
        cooldown = int(plugin_cfg.get("cooldown", 0)) if plugin_cfg else 0
        if cooldown <= 0:
            return 0
        key = (str(getattr(command, "owner", "")), user_id)
        last_used = self._plugin_last_used.get(key)
        if last_used is None:
            return 0
        elapsed = time.monotonic() - last_used
        return max(0, math.ceil(cooldown - elapsed))

    def _mark_plugin_used(self, command: Any, plugin_cfg: Dict[str, Any], user_id: int) -> None:
        """Start the plugin cooldown for ``user_id`` if the plugin configures one."""
        if plugin_cfg and int(plugin_cfg.get("cooldown", 0)) > 0:
            key = (str(getattr(command, "owner", "")), user_id)
            self._plugin_last_used[key] = time.monotonic()

    def _rate_limit(self, command: Any, user_id: int) -> bool:
        """Return False when ``user_id`` exceeded the per-command rate limit.

        Reads ``performance.rate_limits.command.{limit,window}``. Each missing
        value defaults to 5.
        """
        limits = self.config.get().get("performance", {}).get("rate_limits", {}) or {}
        command_limit = limits.get("command") or {}
        limit = RateLimit(
            limit=int(command_limit.get("limit", 5)),
            window=int(command_limit.get("window", 5)),
        )
        return self.rate_limiter.check(f"command:{command.name}", user_id, limit)

    def _apply_argument_specs(self, command: Any, args: list[str]) -> Dict[str, Any]:
        """Map positional ``args`` onto ``command.arguments`` specs, converting with each ``arg_type``.

        A variadic spec swallows all remaining args as a list and ends
        processing. A missing optional arg gets ``spec.default``.

        Raises:
            ValueError: a required argument is missing. Exceptions raised by
                ``arg_type`` (e.g. ``ValueError`` from ``int``) also propagate.
        """
        parsed: Dict[str, Any] = {}
        idx = 0
        for spec in command.arguments:
            if spec.variadic:
                parsed[spec.name] = [spec.arg_type(item) for item in args[idx:]]
                break
            if idx >= len(args):
                if spec.required:
                    raise ValueError(f"Missing argument: {spec.name}")
                parsed[spec.name] = spec.default
                continue
            parsed[spec.name] = spec.arg_type(args[idx])
            idx += 1
        return parsed

    def _apply_flag_specs(self, command: Any, flags: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize parsed flags against ``command.flags`` specs.

        For each spec, the value comes from the spec name or the first matching
        alias and is converted with ``flag_type``. An absent or unconvertible
        value yields ``spec.default``. Unknown flags and alias keys are kept as
        given. Without specs, ``flags`` is returned unchanged.
        """
        if not command.flags:
            return flags
        result = dict(flags)
        for spec in command.flags:
            value = None
            if spec.name in result:
                value = result[spec.name]
            else:
                for alias in spec.aliases:
                    if alias in result:
                        value = result[alias]
                        break
            if value is None:
                result[spec.name] = spec.default
                continue
            try:
                result[spec.name] = spec.flag_type(value)
            except Exception:
                # Deliberate best-effort: a bad flag value falls back to the default.
                result[spec.name] = spec.default
        return result
async def main() -> None:
    """Build one OverUB app per configured bot and keep each running until cancelled.

    If ``bot.instances`` in config.yml is a non-empty list, each entry is used
    as a ``bot`` config override for its own app. Otherwise a single app runs
    with the plain config. Each app is restarted after a fixed delay whenever it
    fails or its client disconnects. On exit, all instance tasks are cancelled
    and awaited so their ``shutdown()`` can finish.
    """
    root = Path(__file__).resolve().parents[1]
    cfg = ConfigManager(root / "config" / "config.yml", root / "config" / "modules.yml")
    cfg.load()
    instances = cfg.get().get("bot", {}).get("instances", [])
    if instances:
        apps = [OverUB(root, bot_override=instance_cfg) for instance_cfg in instances]
    else:
        apps = [OverUB(root)]

    async def run_instance(app: OverUB) -> None:
        # Fixed pause (seconds) between a shutdown and the next start attempt.
        restart_delay = 3
        # NOTE(review): restarts reuse the same OverUB object, so start() (module
        # loading, handler attachment) runs again on it; confirm that is idempotent.
        while True:
            try:
                await app.start()
                await app.client.wait_until_disconnected()
            except Exception:
                logger.exception("Instance failed, restarting")
            finally:
                # A failing shutdown must not end this instance's restart loop.
                # Cancellation (a BaseException) still propagates.
                try:
                    await app.shutdown()
                except Exception:
                    logger.exception("Instance shutdown failed")
            await asyncio.sleep(restart_delay)

    tasks = [asyncio.create_task(run_instance(app)) for app in apps]
    try:
        await asyncio.gather(*tasks)
    except KeyboardInterrupt:
        logger.info("Interrupted")
    finally:
        for task in tasks:
            task.cancel()
        # Wait for the cancelled instances so their shutdown() runs to completion.
        await asyncio.gather(*tasks, return_exceptions=True)


if __name__ == "__main__":
    asyncio.run(main())