339 lines
14 KiB
Python
339 lines
14 KiB
Python
from __future__ import annotations

import asyncio
import math
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional

from core.backup import BackupManager
from core.backup_service import BackupService
from core.bus import MessageBus
from core.cache import Cache
from core.client import ClientWrapper
from core.commands import CommandBuilder, CommandRegistry
from core.config import ConfigManager
from core.database import DatabaseManager
from core.events import EventDispatcher
from core.gitea import GiteaClient
from core.http import SessionManager
from core.loader import ModuleLoader, PluginManager
from core.logger import get_logger, setup_logging
from core.migrations import MigrationManager
from core.module_updates import ModuleUpdateManager
from core.permissions import PermissionManager
from core.rate_limiter import RateLimit, RateLimiter
from core.sandbox import Sandbox
from core.update_service import UpdateService
from core.updater import UpdateManager
from core.version import VersionManager
from core.webhook_server import WebhookServer
|
|
|
|
|
|
# Module-level logger named "core.app", used by OverUB.start/shutdown and main().
logger = get_logger("core.app")
|
|
|
|
|
|
class OverUB:
    """Application root for one OverUB bot instance.

    Builds every subsystem (configuration, logging, events, commands,
    permissions, storage, updaters, the Telegram client wrapper and the
    module/plugin loaders) from the files under ``<root>/config``. It then
    routes incoming client events to registered commands and event listeners.
    """

    def __init__(self, root: Path, bot_override: Optional[Dict[str, Any]] = None) -> None:
        """Load configuration and construct all subsystems.

        Connecting the database and logging the client in happen later, in
        :meth:`start`.

        Args:
            root: Project root directory. Config, data, module and plugin
                paths are resolved relative to it.
            bot_override: Optional mapping merged into the ``bot`` config
                section (used by :func:`main` to run several bot instances).
        """
        self.root = root
        self.config = ConfigManager(root / "config" / "config.yml", root / "config" / "modules.yml")
        self.config.load()
        if bot_override:
            self.config.merge({"bot": bot_override})
        settings = self.config.get()

        log_cfg = settings.get("logging", {})
        setup_logging(
            log_cfg.get("level", "INFO"),
            json_logs=bool(log_cfg.get("json", False)),
            module_logs=bool(log_cfg.get("module_logs", False)),
            remote_url=str(log_cfg.get("remote_url", "")),
        )

        # Core dispatch: event listeners, inter-module bus, permissions, commands.
        self.events = EventDispatcher()
        self.bus = MessageBus()
        self.permissions = PermissionManager()
        self.commands = CommandRegistry(prefix=settings.get("bot", {}).get("command_prefix", "."))
        self.command_builder = CommandBuilder(self.commands)
        self.permissions.load_from_config(settings.get("security", {}))

        # Supporting services. The construction order is kept as-is because
        # several constructors receive ``self`` and may read attributes set above.
        self.cache = Cache(max_size=int(settings.get("performance", {}).get("cache_size", 100)))
        self.backups = BackupManager(root)
        self.rate_limiter = RateLimiter()
        self.versions = VersionManager(self)
        security_cfg = settings.get("security", {})
        self.sandbox = Sandbox(
            root / "plugins",
            allow_network=bool(security_cfg.get("plugin_network", False)),
        )
        self.update_service = UpdateService(self)
        self.backup_service = BackupService(self)
        self.http = SessionManager()
        self.migrations = MigrationManager(root)

        db_cfg = settings.get("database", {})
        db_path = Path(db_cfg.get("path", "data/database.db"))
        self.database = DatabaseManager(db_cfg.get("type", "sqlite"), db_cfg, root / db_path)

        # Self-update machinery: core repo, modules repo, Gitea API, webhook.
        updates = settings.get("updates", {})
        default_branch = updates.get("branch", "main")
        git_cfg = updates.get("git", {})
        self.updater = UpdateManager(root, git_cfg.get("remote", "origin"), default_branch)
        modules_cfg = updates.get("modules", {})
        self.module_updates = ModuleUpdateManager(
            root / modules_cfg.get("path", "modules"),
            modules_cfg.get("remote", "origin"),
            modules_cfg.get("branch", default_branch),
        )
        gitea_cfg = updates.get("gitea", {})
        self.gitea = GiteaClient(gitea_cfg.get("api_url", ""), gitea_cfg.get("token", ""))
        self.updater.gitea = self.gitea
        webhook_cfg = updates.get("webhook", {})
        self.webhook_server = WebhookServer(
            self,
            webhook_cfg.get("host", "0.0.0.0"),
            int(webhook_cfg.get("port", 8080)),
        )
        self.webhook_enabled = bool(webhook_cfg.get("enabled", False))

        bot_cfg = settings.get("bot", {})
        self.client = ClientWrapper(
            api_id=int(bot_cfg.get("api_id", 0)),
            api_hash=str(bot_cfg.get("api_hash", "")),
            session_name=str(bot_cfg.get("session_name", "overub")),
        )

        self.modules = ModuleLoader(self, root / "modules")
        # Same directory that _load_plugins() scans, so a configured
        # ``plugins.plugin_path`` is honoured by the loader as well.
        self.plugins = PluginManager(self, self._plugin_dir())
        # time.monotonic() of the last dispatched command per (plugin owner, user id);
        # used to enforce the per-plugin ``cooldown`` setting.
        self._plugin_last_used: Dict[tuple[Any, int], float] = {}
        self.last_activity: Optional[datetime] = None

    async def start(self) -> None:
        """Connect everything and bring the bot online.

        Runs these steps in order:

        1. Emit ``on_startup``.
        2. Connect the database and apply migrations.
        3. Log the client in using the ``bot`` config section and attach the
           event handlers.
        4. Load enabled built-in modules and plugins.
        5. Start the background services. The update service starts only
           when ``updates.enabled`` is true.
        6. Start the webhook server if it is enabled.
        7. Emit ``on_ready``.
        """
        await self.events.emit("on_startup")
        await self.database.connect()
        await self.migrations.apply(self)

        bot_cfg = self.config.get().get("bot", {})
        await self.client.connect(
            login_mode=str(bot_cfg.get("login_mode", "phone")),
            phone=str(bot_cfg.get("phone", "")),
            qr_path=bot_cfg.get("login_qr_path"),
            qr_open=bool(bot_cfg.get("login_qr_open", False)),
            session_string=bot_cfg.get("session_string"),
        )
        await self.client.attach_handlers(self)

        await self._load_builtin_modules()
        await self._load_plugins()

        if self.config.get().get("updates", {}).get("enabled", False):
            self.update_service.start()
        self.backup_service.start()
        if self.webhook_enabled:
            await self.webhook_server.start()

        await self.events.emit("on_ready")
        logger.info("OverUB ready")

    async def shutdown(self) -> None:
        """Emit ``on_shutdown``, stop services and release connections.

        Every step is attempted even if an earlier one raises. A failing
        webhook stop or client disconnect therefore cannot leave the database
        or HTTP sessions open. Failures are logged with their traceback and
        are not re-raised.
        """
        steps = [
            ("on_shutdown event", lambda: self.events.emit("on_shutdown")),
            ("update service", self.update_service.stop),
            ("backup service", self.backup_service.stop),
        ]
        if self.webhook_enabled:
            steps.append(("webhook server", self.webhook_server.stop))
        steps += [
            ("http sessions", self.http.close),
            ("client", self.client.disconnect),
            ("database", self.database.close),
        ]
        for label, step in steps:
            try:
                await step()
            except Exception:
                logger.exception("Shutdown step failed: %s", label)
        logger.info("OverUB shutdown")

    async def handle_message_event(self, event: Any) -> None:
        """Handle a new incoming or outgoing message.

        If the text parses as a command and no ``on_command`` listener cancels
        it, the command is dispatched via :meth:`_dispatch_command`. Whatever
        the dispatch outcome (including rejections), ``on_message_new`` is then
        emitted. ``on_message_sent`` is also emitted when ``event.out`` is true.
        """
        message = getattr(event, "message", event)
        text = getattr(message, "raw_text", "") or ""
        user_id = getattr(event, "sender_id", None) or getattr(message, "sender_id", 0) or 0
        chat_id = getattr(event, "chat_id", None) or getattr(message, "chat_id", 0) or 0
        self.last_activity = self._utcnow()

        parsed = self.commands.parse(text)
        if parsed:
            command_event = await self.events.emit("on_command", message=message, event=event, parsed=parsed)
            if not command_event.cancelled:
                await self._dispatch_command(event, message, parsed, user_id, chat_id)

        await self.events.emit("on_message_new", message=message, event=event)
        if getattr(event, "out", False):
            await self.events.emit("on_message_sent", message=message, event=event)

    async def _dispatch_command(self, event: Any, message: Any, parsed: Any, user_id: int, chat_id: int) -> None:
        """Resolve *parsed* to a command, run every gate check, then call its handler.

        First the parsed flags, args, raw text and prefix are attached to
        *event*. The gates then run in order:

        1. Chat type.
        2. Core permission.
        3. Plugin permission.
        4. Argument parsing.
        5. Cooldown (command-level and plugin-level).
        6. Rate limit.

        The first gate that rejects replies to the user and stops dispatch.
        Access checks run before argument parsing, so unauthorized users only
        ever see a denial.
        """
        command, args = self.commands.resolve(parsed.name, parsed.args)
        if not command:
            return
        event.command_flags = self._apply_flag_specs(command, parsed.flags)
        event.command_args = args
        event.command_raw = parsed.raw
        event.command_prefix = parsed.prefix

        chat_type = self._get_chat_type(event)
        if command.chat_types and chat_type not in command.chat_types:
            await self._reply(message, "Command not available here")
            return
        if not self.permissions.is_allowed(command.permission, user_id, chat_id):
            await self._reply(message, "Permission denied")
            return
        plugin_cfg = self._plugin_config_for_command(command)
        if plugin_cfg and not self._plugin_allowed(plugin_cfg, user_id, chat_id):
            await self._reply(message, "Plugin permission denied")
            return
        if command.arguments:
            try:
                event.command_parsed_args = self._apply_argument_specs(command, args)
            except ValueError as exc:
                await self._reply(message, str(exc))
                return

        remaining = max(
            self.commands.cooldown_remaining(command, user_id),
            self._plugin_cooldown_remaining(command, plugin_cfg, user_id),
        )
        if remaining > 0:
            await self._reply(message, f"Cooldown: {remaining}s")
            return
        if not self._rate_limit(command, user_id):
            await self._reply(message, "Rate limit exceeded")
            return

        self._mark_plugin_used(command, plugin_cfg, user_id)
        await command.handler(event, args)

    async def handle_edit_event(self, event: Any) -> None:
        """Record activity and emit ``on_message_edit`` for an edited message."""
        message = getattr(event, "message", event)
        self.last_activity = self._utcnow()
        await self.events.emit("on_message_edit", message=message, event=event)

    async def handle_delete_event(self, event: Any) -> None:
        """Record activity and emit ``on_message_delete`` for deleted message(s)."""
        self.last_activity = self._utcnow()
        await self.events.emit("on_message_delete", event=event)

    async def _load_builtin_modules(self) -> None:
        """Load every enabled entry of ``modules`` in modules.yml, plus its enabled sub-modules.

        An entry (or its ``sub_modules``) left empty in YAML is read as ``None``.
        Such entries are treated as an empty mapping, i.e. enabled with no
        sub-modules.
        """
        module_config = self.config.get_modules().get("modules", {}) or {}
        for name, cfg in module_config.items():
            cfg = cfg or {}
            if not cfg.get("enabled", True):
                continue
            await self.modules.load(f"modules.{name}")
            for sub_name, enabled in (cfg.get("sub_modules") or {}).items():
                if enabled:
                    await self.modules.load(f"modules.{name}.{sub_name}")

    def _plugin_dir(self) -> Path:
        """Return the external-plugin directory.

        This is ``plugins.plugin_path`` resolved against :attr:`root`
        (default ``plugins/external``). An absolute configured path is
        returned unchanged by ``Path.__truediv__``.
        """
        plugin_cfg = self.config.get().get("plugins", {}) or {}
        return self.root / plugin_cfg.get("plugin_path", "plugins/external")

    async def _load_plugins(self) -> None:
        """Load each package (a directory with ``__init__.py``) in the plugin directory.

        Nothing is loaded when ``plugins.enabled`` is false or the directory
        does not exist. Plugins load in sorted name order so that startup is
        deterministic.
        """
        plugin_cfg = self.config.get().get("plugins", {})
        if not plugin_cfg.get("enabled", True):
            return
        plugin_dir = self._plugin_dir()
        if not plugin_dir.is_dir():
            return
        for item in sorted(plugin_dir.iterdir()):
            if item.is_dir() and (item / "__init__.py").exists():
                await self.plugins.load(item.name)

    def register_permission_profile(self, name: str, users: list[int], chats: list[int]) -> None:
        """Register a named permission profile covering *users* and *chats*."""
        from core.permissions import PermissionProfile

        self.permissions.add_profile(PermissionProfile(name=name, users=users, chats=chats))

    def _get_chat_type(self, event: Any) -> str:
        """Classify *event* as "private", "group", "channel" or "unknown" (checked in that order)."""
        if getattr(event, "is_private", False):
            return "private"
        if getattr(event, "is_group", False):
            return "group"
        if getattr(event, "is_channel", False):
            return "channel"
        return "unknown"

    async def _reply(self, message: Any, text: str) -> None:
        """Reply with *text* if *message* supports ``reply``; otherwise do nothing."""
        if hasattr(message, "reply"):
            await message.reply(text)

    def _plugin_config_for_command(self, command: Any) -> Dict[str, Any]:
        """Return the plugin config of *command*'s owner, or ``{}`` if no plugin owns it."""
        if getattr(command, "owner_type", "") != "plugin":
            return {}
        return self.config.get_plugin_config(getattr(command, "owner", ""))

    def _plugin_allowed(self, plugin_cfg: Dict[str, Any], user_id: int, chat_id: int) -> bool:
        """Check plugin-level access for *user_id* in *chat_id*.

        The caller must pass the ``permission_level`` check (default "user").
        When ``permissions`` lists ids, either the user id or the chat id must
        also appear in that list.
        """
        allowed = plugin_cfg.get("permissions", []) or []
        level = plugin_cfg.get("permission_level", "user")
        if not self.permissions.is_allowed(level, user_id, chat_id):
            return False
        if not allowed:
            return True
        return user_id in allowed or chat_id in allowed

    def _plugin_cooldown_remaining(self, command: Any, plugin_cfg: Optional[Dict[str, Any]], user_id: int) -> int:
        """Return whole seconds left on the plugin cooldown for *user_id*, or 0.

        ``plugin_cfg["cooldown"]`` is a duration. It is measured from the last
        dispatch recorded by :meth:`_mark_plugin_used` for the same plugin
        owner and user, so a user who has not yet used the plugin is never
        blocked.
        """
        cooldown = int(plugin_cfg.get("cooldown", 0)) if plugin_cfg else 0
        if cooldown <= 0:
            return 0
        last_used = self._plugin_last_used.get((getattr(command, "owner", ""), user_id))
        if last_used is None:
            return 0
        left = cooldown - (time.monotonic() - last_used)
        return max(0, math.ceil(left))

    def _mark_plugin_used(self, command: Any, plugin_cfg: Optional[Dict[str, Any]], user_id: int) -> None:
        """Start the plugin cooldown window for *user_id* (only if a cooldown is configured)."""
        if plugin_cfg and int(plugin_cfg.get("cooldown", 0)) > 0:
            self._plugin_last_used[(getattr(command, "owner", ""), user_id)] = time.monotonic()

    def _rate_limit(self, command: Any, user_id: int) -> bool:
        """Return True if *user_id* may run *command* now under the per-command rate limit.

        Reads ``performance.rate_limits.command`` on every call. Missing
        ``limit``/``window`` keys (or a missing/empty section) fall back to 5
        and 5 instead of raising ``KeyError``.
        """
        limits = self.config.get().get("performance", {}).get("rate_limits", {})
        command_limit = {"limit": 5, "window": 5, **(limits.get("command") or {})}
        limit = RateLimit(limit=int(command_limit["limit"]), window=int(command_limit["window"]))
        return self.rate_limiter.check(f"command:{command.name}", user_id, limit)

    def _apply_argument_specs(self, command: Any, args: list[str]) -> Dict[str, Any]:
        """Convert positional *args* according to ``command.arguments``.

        Specs are matched to args in order:

        - A variadic spec converts all remaining args into a list and ends
          matching; any specs after it are ignored.
        - A missing optional arg takes ``spec.default``.
        - Surplus args are ignored.

        Raises:
            ValueError: A required argument is missing. Exceptions raised by
                ``spec.arg_type`` during conversion propagate unchanged.
        """
        parsed: Dict[str, Any] = {}
        for position, spec in enumerate(command.arguments):
            if spec.variadic:
                parsed[spec.name] = [spec.arg_type(item) for item in args[position:]]
                break
            if position < len(args):
                parsed[spec.name] = spec.arg_type(args[position])
            elif spec.required:
                raise ValueError(f"Missing argument: {spec.name}")
            else:
                parsed[spec.name] = spec.default
        return parsed

    def _apply_flag_specs(self, command: Any, flags: Dict[str, Any]) -> Dict[str, Any]:
        """Normalise parsed *flags* against ``command.flags`` specs.

        Each spec reads its canonical name, or else the first alias present,
        and converts the value with ``spec.flag_type``. A missing value (or
        ``None``) or a failed conversion yields ``spec.default``. The result
        is a copy of *flags* with canonical keys added; alias keys are kept.
        Without flag specs, *flags* itself is returned.
        """
        if not command.flags:
            return flags
        result = dict(flags)
        for spec in command.flags:
            if spec.name in result:
                value = result[spec.name]
            else:
                value = next((result[alias] for alias in spec.aliases if alias in result), None)
            if value is None:
                result[spec.name] = spec.default
                continue
            try:
                result[spec.name] = spec.flag_type(value)
            except Exception:
                # Deliberate best-effort: a malformed flag value falls back to the default.
                result[spec.name] = spec.default
        return result

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime (same value as the deprecated ``datetime.utcnow()``)."""
        return datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
|
|
async def main() -> None:
    """Run every configured bot instance until cancelled, restarting each on exit.

    ``bot.instances`` in ``config/config.yml`` may list per-instance overrides
    for the ``bot`` section; each entry becomes its own :class:`OverUB`. With
    no instances configured, a single instance uses the config unchanged.

    Each instance is restarted whenever it disconnects or raises:

    - After a failure, the wait before the next attempt starts at 3 s and
      doubles up to 60 s, so a persistent error is not retried in a tight
      loop.
    - A clean disconnect resets the wait to 3 s.

    On exit, every instance task is cancelled and awaited, so each instance's
    ``shutdown()`` finishes before this coroutine returns.
    """
    root = Path(__file__).resolve().parents[1]
    cfg = ConfigManager(root / "config" / "config.yml", root / "config" / "modules.yml")
    cfg.load()
    instances = cfg.get().get("bot", {}).get("instances", [])
    if instances:
        apps = [OverUB(root, bot_override=instance_cfg) for instance_cfg in instances]
    else:
        apps = [OverUB(root)]

    base_delay, max_delay = 3, 60

    async def run_instance(app: OverUB) -> None:
        delay = base_delay
        while True:
            failed = False
            try:
                await app.start()
                await app.client.wait_until_disconnected()
            except Exception:
                failed = True
                logger.exception("Instance failed, restarting")
            finally:
                # A failing shutdown must not end this instance's restart loop for good.
                try:
                    await app.shutdown()
                except Exception:
                    logger.exception("Instance shutdown failed")
            if not failed:
                delay = base_delay
            await asyncio.sleep(delay)
            if failed:
                delay = min(delay * 2, max_delay)

    tasks = [asyncio.create_task(run_instance(app)) for app in apps]
    try:
        await asyncio.gather(*tasks)
    except KeyboardInterrupt:
        logger.info("Interrupted")
    finally:
        for task in tasks:
            task.cancel()
        # cancel() only requests cancellation; wait so that each instance's
        # shutdown() (run from run_instance's finally) completes before returning.
        await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
|
|
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # On Python 3.11+, asyncio.run() cancels main() on Ctrl+C and then
        # re-raises KeyboardInterrupt; exit quietly instead of printing a traceback.
        logger.info("Interrupted")
|