diff --git a/app/__init__.py b/app/__init__.py index 99dcd42..a37cf65 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,11 +1,12 @@ import os from dotenv import load_dotenv + load_dotenv("config.env") -from app.config import Config -from app.core.db import DB -from app.core.client.client import BOT +from app.config import Config # NOQA +from app.core.db import DB # NOQA +from app.core.client.client import BOT # NOQA if "com.termux" not in os.environ.get("PATH", ""): import uvloop @@ -13,3 +14,5 @@ if "com.termux" not in os.environ.get("PATH", ""): uvloop.install() bot = BOT() + +from app.core.client.conversation import Conversation as Convo # NOQA diff --git a/app/config.py b/app/config.py index 6c41289..441af19 100644 --- a/app/config.py +++ b/app/config.py @@ -1,14 +1,18 @@ import json import os -from typing import Coroutine +from typing import Callable + +from pyrogram.filters import Filter +from pyrogram.types import Message -class Config: - CMD_DICT: dict["str", Coroutine] = {} +class _Config: + def __str__(self): + return json.dumps(self.__dict__, indent=4, ensure_ascii=False) - CALLBACK_DICT: dict["str", Coroutine] = {} + CMD_DICT: dict["str", Callable] = {} - CONVO_DICT = {} + CONVO_DICT: dict[int, dict[str | int, Message | Filter | None]] = {} DEV_MODE: int = int(os.environ.get("DEV_MODE", 0)) @@ -23,3 +27,6 @@ class Config: UPSTREAM_REPO: str = os.environ.get( "UPSTREAM_REPO", "https://github.com/thedragonsinn/plain-ub" ) + + +Config = _Config() diff --git a/app/core/client/conversation.py b/app/core/client/conversation.py new file mode 100644 index 0000000..bb96a66 --- /dev/null +++ b/app/core/client/conversation.py @@ -0,0 +1,48 @@ +import asyncio +import json + +from pyrogram.filters import Filter +from pyrogram.types import Message + +from app import Config, bot + + +class Conversation: + class DuplicateConvo(Exception): + def __init__(self, chat: str | int | None = None): + text = "Conversation already started" + if chat: + text += f" with {chat}" + super().__init__(text) + + class TimeOutError(Exception): + def __init__(self): + super().__init__("Conversation Timeout") + + def __init__(self, chat_id: int, filters: Filter | None = None, timeout: int = 10): + self._client = bot + self.chat_id = chat_id + self.filters = filters + self.timeout = timeout + + def __str__(self): + return json.dumps(self.__dict__, indent=4, ensure_ascii=False) + + async def get_response(self, timeout: int | None = None) -> Message | None: + try: + async with asyncio.timeout(timeout or self.timeout): + while not Config.CONVO_DICT[self.chat_id]["response"]: + await asyncio.sleep(0) + return Config.CONVO_DICT[self.chat_id]["response"] + except asyncio.TimeoutError: + raise self.TimeOutError + + async def __aenter__(self) -> "Conversation": + if self.chat_id in Config.CONVO_DICT: + raise self.DuplicateConvo(self.chat_id) + convo_dict = {"filters": self.filters, "response": None} + Config.CONVO_DICT[self.chat_id] = convo_dict + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + Config.CONVO_DICT.pop(self.chat_id, "") diff --git a/app/core/client/filters.py b/app/core/client/filters.py index e05e1e5..6ffe947 100644 --- a/app/core/client/filters.py +++ b/app/core/client/filters.py @@ -20,3 +20,7 @@ def dynamic_cmd_filter(_, __, message) -> bool: cmd_filter = _filters.create(dynamic_cmd_filter) +convo_filter = _filters.create( + lambda _, __, message: (message.chat.id in Config.CONVO_DICT) + and (not message.reactions) +) diff --git a/app/core/client/handler.py b/app/core/client/handler.py index 6eeab4e..0d67637 100644 --- a/app/core/client/handler.py +++ b/app/core/client/handler.py @@ -1,15 +1,14 @@ import asyncio import traceback -from datetime import datetime -from pyrogram.enums import ChatType +from pyrogram.types import Message as Msg -from app import DB, Config, bot -from app.core import CallbackQuery, Message, filters +from app import Config, bot +from app.core import Message, filters -@bot.on_message(filters.cmd_filter) -@bot.on_edited_message(filters.cmd_filter) +@bot.on_message(filters.cmd_filter, group=1) +@bot.on_edited_message(filters.cmd_filter, group=1) async def cmd_dispatcher(bot, message) -> None: message = Message.parse_message(message) func = Config.CMD_DICT[message.cmd] @@ -17,22 +16,19 @@ async def cmd_dispatcher(bot, message) -> None: await run_coro(coro, message) -@bot.on_callback_query() -async def callback_handler(bot: bot, cb): - if ( - cb.message.chat.type == ChatType.PRIVATE - and (datetime.now() - cb.message.date).total_seconds() > 30 - ): - return await cb.edit_message_text(f"Query Expired. Try again.") - banned = await DB.BANNED.find_one({"_id": cb.from_user.id}) - if banned: - return - cb = CallbackQuery.parse_cb(cb) - func = Config.CALLBACK_DICT.get(cb.cmd) - if not func: - return - coro = func(bot, cb) - await run_coro(coro, Message.parse_message(cb.message)) +@bot.on_message(filters.convo_filter, group=0) +@bot.on_edited_message(filters.convo_filter, group=0) +async def convo_handler(bot: bot, message: Msg): + conv_dict: dict = Config.CONVO_DICT[message.chat.id] + conv_filters = conv_dict.get("filters") + if conv_filters: + check = await conv_filters(bot, message) + if not check: + message.continue_propagation() + conv_dict["response"] = message + message.continue_propagation() + conv_dict["response"] = message + message.continue_propagation() async def run_coro(coro, message) -> None: diff --git a/app/plugins/tg_utils.py b/app/plugins/tg_utils.py index 553a440..84416c4 100644 --- a/app/plugins/tg_utils.py +++ b/app/plugins/tg_utils.py @@ -57,7 +57,7 @@ async def join_chat(bot: bot, message: Message) -> None: await bot.join_chat(os.path.basename(chat).strip()) except Exception as e: await message.reply(str(e)) - return + return await message.reply("Joined")