214 lines
7.6 KiB
Python
214 lines
7.6 KiB
Python
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List
|
|
|
|
import yaml
|
|
|
|
|
|
def _deep_update(base: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
|
|
for key, value in updates.items():
|
|
if isinstance(value, dict) and isinstance(base.get(key), dict):
|
|
base[key] = _deep_update(base[key], value)
|
|
else:
|
|
base[key] = value
|
|
return base
|
|
|
|
|
|
class ConfigManager:
    """Load, query, override and persist the bot's YAML configuration.

    Two YAML documents are managed: the main config at ``config_path`` and the
    module registry at ``modules_path``. :meth:`load` layers ``OVERUB_*``
    environment variables on top of the main config.
    """

    # Environment variable -> key path inside the main config. Values are
    # copied verbatim as strings (no type conversion).
    _ENV_MAP: Dict[str, tuple] = {
        "OVERUB_API_ID": ("bot", "api_id"),
        "OVERUB_API_HASH": ("bot", "api_hash"),
        "OVERUB_SESSION": ("bot", "session_name"),
        "OVERUB_PREFIX": ("bot", "command_prefix"),
        "OVERUB_LOG_LEVEL": ("logging", "level"),
        "OVERUB_GIT_REMOTE": ("updates", "git", "remote"),
        "OVERUB_GITEA_TOKEN": ("updates", "git", "token"),
        "OVERUB_GITEA_API": ("updates", "gitea", "api_url"),
        "OVERUB_GIT_BRANCH": ("updates", "branch"),
        "OVERUB_WEBHOOK_SECRET": ("updates", "gitea", "webhook_secret"),
    }

    # Prefix of per-plugin overrides (see _apply_plugin_env_overrides).
    _PLUGIN_ENV_PREFIX = "OVERUB_PLUGIN_"

    # Marker prepended to values produced by encrypt_value.
    _ENC_PREFIX = "ENC:"

    def __init__(self, config_path: Path, modules_path: Path):
        """Remember the file locations; nothing is read until :meth:`load`."""
        self.config_path = config_path
        self.modules_path = modules_path
        self._config: Dict[str, Any] = {}
        self._modules: Dict[str, Any] = {}
        # Plugin name -> {config key: expected type (or falsy for "any")}.
        self._plugin_schemas: Dict[str, Dict[str, Any]] = {}

    def load(self) -> None:
        """(Re)read both YAML files and apply environment overrides.

        Unsaved in-memory changes are discarded. ``config_version`` defaults
        to ``"1.0.0"`` when the main config does not set it.
        """
        self._config = self._read_yaml(self.config_path)
        self._modules = self._read_yaml(self.modules_path)
        self._config.setdefault("config_version", "1.0.0")
        self._apply_env_overrides()

    def _read_yaml(self, path: Path) -> Dict[str, Any]:
        """Parse *path* as YAML and return its top-level mapping.

        A missing file, or a document whose top level is empty/falsy, yields
        ``{}``.

        Raises:
            ValueError: if the top level is a non-empty non-mapping (e.g. a
                list or a scalar), which the rest of the manager cannot use.
        """
        if not path.exists():
            return {}
        # safe_load never constructs arbitrary Python objects from the file.
        data = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
        if not isinstance(data, dict):
            raise ValueError(
                f"{path}: expected a YAML mapping at the top level, "
                f"got {type(data).__name__}"
            )
        return data

    def _apply_env_overrides(self) -> None:
        """Copy recognised ``OVERUB_*`` environment variables into the config.

        Only variables that are set are applied; their string values are
        stored unchanged. Plugin-specific overrides are applied afterwards.
        """
        for env_key, path in self._ENV_MAP.items():
            value = os.getenv(env_key)
            if value is not None:
                # NOTE(review): e.g. bot.api_id is stored as str here;
                # confirm consumers convert it if they need an int.
                self._set_path(self._config, path, value)
        self._apply_plugin_env_overrides()

    def _apply_plugin_env_overrides(self) -> None:
        """Apply ``OVERUB_PLUGIN_<NAME>_...`` environment variables.

        * ``OVERUB_PLUGIN_<NAME>_ENABLED`` sets ``enabled``; only the
          case-insensitive string ``"true"`` enables, anything else disables.
        * ``OVERUB_PLUGIN_<NAME>_SETTING_<KEY>`` stores the raw string under
          ``settings[<key>]``.

        Plugin names and setting keys are lower-cased. Because ``_`` separates
        the parts, a plugin whose name contains ``_`` cannot be targeted.
        Variables with an empty plugin name or setting key are ignored.
        """
        prefix = self._PLUGIN_ENV_PREFIX
        for key, value in os.environ.items():
            if not key.startswith(prefix):
                continue
            parts = key[len(prefix):].split("_")
            if len(parts) < 2:
                continue
            plugin = parts[0].lower()
            scope = parts[1].lower()
            if not plugin:
                continue
            if scope == "enabled":
                cfg = self.get_plugin_config(plugin)
                cfg["enabled"] = value.lower() == "true"
                self.set_plugin_config(plugin, cfg)
            elif scope == "setting" and len(parts) >= 3:
                setting_key = "_".join(parts[2:]).lower()
                if not setting_key:
                    continue
                cfg = self.get_plugin_config(plugin)
                settings = cfg.get("settings")
                if not isinstance(settings, dict):
                    # Missing, or a blank ``settings:`` that YAML loads as None.
                    settings = cfg["settings"] = {}
                settings[setting_key] = value
                self.set_plugin_config(plugin, cfg)

    def _set_path(self, data: Dict[str, Any], path: tuple, value: Any) -> None:
        """Assign *value* at the nested key *path* inside *data*.

        Intermediate dicts are created as needed. An intermediate that exists
        but is not a dict (typically ``None`` from a blank YAML section such
        as ``updates:``) is replaced by an empty dict so the value can be set.
        """
        cursor = data
        for key in path[:-1]:
            child = cursor.get(key)
            if not isinstance(child, dict):
                child = cursor[key] = {}
            cursor = child
        cursor[path[-1]] = value

    def get(self) -> Dict[str, Any]:
        """Return the live main-config dict (mutations affect the manager)."""
        return self._config

    def get_modules(self) -> Dict[str, Any]:
        """Return the live module-registry dict."""
        return self._modules

    def get_module_config(self, name: str) -> Dict[str, Any]:
        """Return ``modules[name]`` from the module registry.

        Returns ``{}`` when the ``modules`` section or the entry is missing
        or blank (``None``).
        """
        modules = self._modules.get("modules")
        if modules is None:
            return {}
        value = modules.get(name, {})
        return {} if value is None else value

    @staticmethod
    def _default_plugin_config() -> Dict[str, Any]:
        """Return a fresh default plugin config (new containers on each call)."""
        return {
            "enabled": True,
            "settings": {},
            "secrets": {},
            "permissions": [],
            "command_prefix": None,
            "cooldown": 0,
            "timeout": None,
            "permission_level": "user",
            "auto_update": True,
            "max_memory_mb": None,
            "max_cpu_percent": None,
        }

    def _plugin_settings(self) -> Dict[str, Any]:
        """Return the ``plugin_settings`` mapping, creating it if absent or blank."""
        plugins = self._config.get("plugin_settings")
        if plugins is None:
            plugins = self._config["plugin_settings"] = {}
        return plugins

    def get_plugin_config(self, name: str) -> Dict[str, Any]:
        """Return the live config dict for plugin *name*.

        A plugin with no entry (or a blank ``None`` entry) gets a fresh
        default config stored first, so the returned dict is always the
        one held in ``plugin_settings``.
        """
        plugins = self._plugin_settings()
        cfg = plugins.get(name)
        if cfg is None:
            cfg = plugins[name] = self._default_plugin_config()
        return cfg

    def set_plugin_config(self, name: str, data: Dict[str, Any]) -> None:
        """Replace plugin *name*'s config with *data* (stored by reference)."""
        self._plugin_settings()[name] = data

    def register_plugin_schema(self, name: str, schema: Dict[str, Any]) -> None:
        """Register *schema* (config key -> expected type) for plugin *name*."""
        self._plugin_schemas[name] = schema

    def validate_plugin_config(self, name: str) -> List[str]:
        """Check plugin *name*'s config against its registered schema.

        Each schema key must be present. When its expected type is truthy,
        the value must also pass ``isinstance``, so a tuple of types works
        too. Returns a list of error messages, or ``[]`` when no schema is
        registered or everything matches.

        Note that validating creates a default config for an unknown plugin
        that has a schema.
        """
        schema = self._plugin_schemas.get(name)
        if not schema:
            return []
        cfg = self.get_plugin_config(name)
        errors: List[str] = []
        for key, expected_type in schema.items():
            if key not in cfg:
                errors.append(f"Missing key: {key}")
                continue
            if expected_type and not isinstance(cfg[key], expected_type):
                errors.append(f"Invalid type for {key}")
        return errors

    def migrate_plugin_config(self, name: str, new_config: Dict[str, Any]) -> None:
        """Overlay *new_config* onto plugin *name*'s config.

        This is a shallow update: a top-level key in *new_config* (e.g.
        ``settings``) replaces the existing value wholesale.
        """
        current = self.get_plugin_config(name)
        current.update(new_config)
        self.set_plugin_config(name, current)

    def reload(self) -> None:
        """Discard in-memory state and re-read everything from disk."""
        self.load()

    def migrate(self, target_version: str) -> None:
        """Stamp ``config_version`` with *target_version*; no data is transformed."""
        self._config["config_version"] = target_version

    def _fernet(self) -> Any:
        """Return a Fernet for ``security.secret_key``, or None if unavailable.

        Encryption is unavailable when no key is configured or the optional
        ``cryptography`` package is not installed.
        """
        security = self._config.get("security") or {}
        key = security.get("secret_key")
        if not key:
            return None
        try:
            from cryptography.fernet import Fernet
        except ImportError:
            return None
        return Fernet(key.encode("utf-8"))

    def encrypt_value(self, value: str) -> str:
        """Encrypt *value* into an ``ENC:``-prefixed Fernet token.

        This is best-effort: when encryption is unavailable (see
        :meth:`_fernet`), *value* is returned unchanged as plaintext.
        """
        fernet = self._fernet()
        if fernet is None:
            return value
        token = fernet.encrypt(value.encode("utf-8"))
        return f"{self._ENC_PREFIX}{token.decode('utf-8')}"

    def decrypt_value(self, value: str) -> str:
        """Decrypt an ``ENC:``-prefixed token produced by :meth:`encrypt_value`.

        Non-strings, values without the prefix, and every value when
        encryption is unavailable are returned unchanged. A token not made
        with the configured key makes Fernet raise ``InvalidToken``.
        """
        if not isinstance(value, str) or not value.startswith(self._ENC_PREFIX):
            return value
        fernet = self._fernet()
        if fernet is None:
            return value
        token = value[len(self._ENC_PREFIX):].encode("utf-8")
        return fernet.decrypt(token).decode("utf-8")

    def merge(self, updates: Dict[str, Any]) -> None:
        """Deep-merge *updates* into the main config in place."""
        _deep_update(self._config, updates)

    @staticmethod
    def _write_yaml(path: Path, data: Dict[str, Any]) -> None:
        """Dump *data* to *path* as UTF-8 YAML, preserving key order."""
        path.write_text(yaml.safe_dump(data, sort_keys=False), encoding="utf-8")

    def save(self) -> None:
        """Write the main config back to ``config_path``.

        NOTE(review): values injected by the environment overrides (including
        OVERUB_GITEA_TOKEN / OVERUB_WEBHOOK_SECRET) live in the same dict and
        are therefore written to disk in plain text; confirm this is intended.
        """
        self._write_yaml(self.config_path, self._config)

    def save_modules(self) -> None:
        """Write the module registry back to ``modules_path``."""
        self._write_yaml(self.modules_path, self._modules)
|
|
|
|
|
|
class PluginConfigProxy:
    """A ConfigManager facade that only lets one plugin touch its own entry.

    Methods that take a plugin name raise ``PermissionError`` unless that
    name equals the plugin this proxy was built for. The encrypt/decrypt
    helpers take no plugin name and are forwarded without any check.
    """

    def __init__(self, manager: ConfigManager, plugin: str) -> None:
        self._manager = manager
        self._plugin = plugin

    def _require_own(self, name: str) -> None:
        """Raise PermissionError when *name* is not this proxy's plugin."""
        if name == self._plugin:
            return
        raise PermissionError("Access denied")

    def get_plugin_config(self, name: str) -> Dict[str, Any]:
        """Return the owning plugin's config (see ConfigManager.get_plugin_config)."""
        self._require_own(name)
        return self._manager.get_plugin_config(name)

    def set_plugin_config(self, name: str, data: Dict[str, Any]) -> None:
        """Replace the owning plugin's config with *data*."""
        self._require_own(name)
        self._manager.set_plugin_config(name, data)

    def encrypt_value(self, value: str) -> str:
        """Forward to the manager's encrypt_value (no access check)."""
        manager = self._manager
        return manager.encrypt_value(value)

    def decrypt_value(self, value: str) -> str:
        """Forward to the manager's decrypt_value (no access check)."""
        manager = self._manager
        return manager.decrypt_value(value)

    def register_plugin_schema(self, name: str, schema: Dict[str, Any]) -> None:
        """Register a validation schema for the owning plugin."""
        self._require_own(name)
        self._manager.register_plugin_schema(name, schema)

    def validate_plugin_config(self, name: str) -> List[str]:
        """Validate the owning plugin's config against its registered schema."""
        self._require_own(name)
        return self._manager.validate_plugin_config(name)
|