PMGuard and AioHttp session, site changes.

This commit is contained in:
thedragonsinn
2024-01-12 17:19:42 +05:30
parent e03fe03d95
commit da8a697acc
11 changed files with 241 additions and 71 deletions

View File

@@ -27,6 +27,8 @@ class _Config:
self.LOG_CHAT: int = int(os.environ.get("LOG_CHAT"))
self.PM_GUARD: bool = False
self.REPO = Repo(".")
self.SUDO: bool = False

View File

@@ -12,7 +12,7 @@ from pyrogram.types import Message as Msg
from app import DB_CLIENT, LOGGER, Config, Message
from app.core.decorators.add_cmd import AddCmd
from app.utils import aiohttp_tools
from app.utils.aiohttp_tools import aio
def import_modules():
@@ -67,7 +67,7 @@ class BOT(Client, AddCmd):
await self.log_text(text="<i>Started</i>")
LOGGER.info("Idling...")
await idle()
await aiohttp_tools.init_task()
await aio.session.close()
LOGGER.info("DB Closed.")
DB_CLIENT.close()
@@ -92,29 +92,39 @@ class BOT(Client, AddCmd):
parse_mode=parse_mode,
)) # fmt:skip
async def log_message(self, message: Message | Msg):
@staticmethod
async def log_message(message: Message | Msg):
return (await message.copy(chat_id=Config.LOG_CHAT)) # fmt: skip
async def restart(self, hard=False) -> None:
await aiohttp_tools.init_task()
await aio.session.close()
await super().stop(block=False)
LOGGER.info("Closing DB....")
LOGGER.info("Closing DB...")
DB_CLIENT.close()
if hard:
os.remove("logs/app_logs.txt")
os.execl("/bin/bash", "/bin/bash", "run")
LOGGER.info("Restarting......")
LOGGER.info("Restarting...")
os.execl(sys.executable, sys.executable, "-m", "app")
async def send_message(
self, chat_id: int | str, text, name: str = "output.txt", **kwargs
self,
chat_id: int | str,
text,
name: str = "output.txt",
disable_web_page_preview: bool = False,
**kwargs,
) -> Message | Msg:
text = str(text)
if len(text) < 4096:
return Message.parse_message(
(await super().send_message(chat_id=chat_id, text=text, **kwargs))
message = await super().send_message(
chat_id=chat_id,
text=text,
disable_web_page_preview=disable_web_page_preview,
**kwargs,
)
return Message.parse_message(message=message)
doc = BytesIO(bytes(text, encoding="utf-8"))
doc.name = name
kwargs.pop("disable_web_page_preview", "")
# fmt:skip
return await super().send_document(chat_id=chat_id, document=doc, **kwargs)

View File

@@ -17,13 +17,21 @@ class CustomDB(AsyncIOMotorCollection):
super().__init__(database=DB, name=collection_name)
async def add_data(self, data: dict) -> None:
found = await self.find_one(data)
"""
:param data: {"_id": db_id, rest of the data}
entry is added or updated if exists.
"""
found = await self.find_one({"_id": data["_id"]})
if not found:
await self.insert_one(data)
else:
await self.update_one({"_id": data.pop("_id")}, {"$set": data})
async def delete_data(self, id: int | str) -> bool | None:
"""
:param id: the db id key to delete.
:return: True if entry was deleted.
"""
found = await self.find_one({"_id": id})
if found:
await self.delete_one({"_id": id})

View File

@@ -24,3 +24,4 @@ basicConfig(
getLogger("pyrogram").setLevel(WARNING)
getLogger("httpx").setLevel(WARNING)
getLogger("aiohttp.access").setLevel(WARNING)

View File

@@ -7,7 +7,8 @@ from io import StringIO
from pyrogram.enums import ParseMode
from app import Config, bot, BOT, Message, DB, DB_CLIENT, try_ # isort:skip
from app.utils import shell, aiohttp_tools as aio # isort:skip
from app.utils import shell # isort:skip
from app.utils.aiohttp_tools import aio # isort:skip
async def executor(bot: BOT, message: Message) -> Message | None:

View File

@@ -3,7 +3,7 @@
import os
from app import BOT, Message, bot
from app.utils import aiohttp_tools
from app.utils.aiohttp_tools import aio
from app.utils.helpers import post_to_telegraph as post_tgh
@@ -13,7 +13,7 @@ async def get_json(endpoint: str, query: dict, key=os.environ.get("DEBRID_TOKEN"
return "API key not found."
api = "https://api.alldebrid.com/v4" + endpoint
params = {"agent": "bot", "apikey": key, **query}
async with aiohttp_tools.SESSION.get(url=api, params=params) as ses:
async with aio.session.get(url=api, params=params) as ses:
try:
json = await ses.json()
return json

View File

@@ -7,7 +7,7 @@ from urllib.parse import urlparse
import yt_dlp
from app import Message, bot
from app.utils.aiohttp_tools import in_memory_dl
from app.utils.aiohttp_tools import aio
domains = [
"www.youtube.com",
@@ -63,7 +63,7 @@ async def song_dl(bot: bot, message: Message) -> None | Message:
yt_info: str = yt_info["entries"][0]
duration: int = yt_info["duration"]
artist: str = yt_info["channel"]
thumb = await in_memory_dl(yt_info["thumbnail"])
thumb = await aio.in_memory_dl(yt_info["thumbnail"])
down_path: list = glob.glob(dl_path + "*")
if not down_path:
return await response.edit("Not found")

View File

@@ -1,3 +1,5 @@
import asyncio
from pyrogram.types import User
from app import BOT, Config, CustomDB, Message, bot
@@ -29,8 +31,10 @@ async def sudo(bot: BOT, message: Message):
return
value = not Config.SUDO
Config.SUDO = value
await SUDO.add_data({"_id": "sudo_switch", "value": value})
await message.reply(text=f"Sudo is enabled: <b>{value}</b>!", del_in=8)
await asyncio.gather(
SUDO.add_data({"_id": "sudo_switch", "value": value}),
message.reply(text=f"Sudo is enabled: <b>{value}</b>!", del_in=8),
)
@bot.add_cmd(cmd="addsudo")

View File

@@ -0,0 +1,135 @@
import asyncio
from pyrogram import filters
from pyrogram.enums import ChatType
from app import BOT, Config, CustomDB, Message, bot
from app.utils.helpers import get_name
PM_USERS = CustomDB("PM_USERS")
PM_GUARD = CustomDB("PM_GUARD")
ALLOWED_USERS: list[int] = []
allowed_filter = filters.create(lambda _, __, m: m.from_user.id in ALLOWED_USERS)
guard_check = filters.create(lambda _, __, ___: Config.PM_GUARD)
RECENT_USERS: dict = {}
async def init_task():
guard = await PM_GUARD.find_one({"_id": "guard_switch"})
if not guard:
return
global ALLOWED_USERS
ALLOWED_USERS = [user_id["_id"] async for user_id in PM_USERS.find()]
Config.PM_GUARD = guard["value"]
@bot.on_message(
(guard_check & filters.private & filters.incoming) & ~allowed_filter, group=0
)
async def handle_new_pm(bot: BOT, message: Message):
user_id = message.from_user.id
RECENT_USERS[user_id] = RECENT_USERS.get(user_id, 0)
if RECENT_USERS[user_id] == 0:
await bot.log_text(
text=f"#PMGUARD\n{message.from_user.mention} [{user_id}] has messaged you.",
type="info",
)
RECENT_USERS[user_id] += 1
if RECENT_USERS[user_id] >= 5:
await message.reply("You've been blocked for spamming.")
await bot.block_user(user_id)
RECENT_USERS.pop(user_id)
await bot.log_text(
text=f"#PMGUARD\n{message.from_user.mention} [{user_id}] has been blocked for spamming.",
type="info",
)
return
if RECENT_USERS[user_id] % 2:
await message.reply(
"You are not authorised to PM.\nWait until you get authorised."
)
@bot.on_message(guard_check & filters.private & filters.outgoing, group=2)
async def auto_approve(bot: BOT, message: Message):
if message.chat.id in ALLOWED_USERS:
return
message = Message.parse_message(message=message)
await message.reply("Auto-Approved to PM.", del_in=5)
await PM_USERS.insert_one({"_id": message.chat.id})
ALLOWED_USERS.append(message.chat.id)
@bot.add_cmd(cmd="pmguard")
async def pmguard(bot: BOT, message: Message):
"""
CMD: PMGUARD
INFO: Enable/Disable PM GUARD.
FLAGS: -c to check guard status.
USAGE:
.pmguard | .pmguard -c
"""
if "-c" in message.flags:
await message.reply(
text=f"PM Guard is enabled: <b>{Config.PM_GUARD}</b> .", del_in=8
)
return
value = not Config.PM_GUARD
Config.PM_GUARD = value
await asyncio.gather(
PM_GUARD.add_data({"_id": "guard_switch", "value": value}),
message.reply(text=f"PM Guard is enabled: <b>{value}</b>!", del_in=8),
)
@bot.add_cmd(cmd=["a", "allow"])
async def allow_pm(bot: BOT, message: Message):
user_id, name = get_user_name(message)
if not user_id:
await message.reply(
"Unable to extract User to allow.\n<code>Give user id | Reply to a user | use in PM.</code>"
)
return
if user_id in ALLOWED_USERS:
await message.reply(f"{name} is already approved.")
return
ALLOWED_USERS.append(user_id)
await asyncio.gather(
message.reply(text=f"{name} allowed to PM.", del_in=8),
PM_USERS.insert_one({"_id": user_id}),
)
@bot.add_cmd(cmd="nopm")
async def no_pm(bot: BOT, message: Message):
user_id, name = get_user_name(message)
if not user_id:
await message.reply(
"Unable to extract User to Dis-allow.\n<code>Give user id | Reply to a user | use in PM.</code>"
)
return
if user_id not in ALLOWED_USERS:
await message.reply(f"{name} is not approved to PM.")
return
ALLOWED_USERS.remove(user_id)
await asyncio.gather(
message.reply(text=f"{name} Dis-allowed to PM.", del_in=8),
PM_USERS.delete_data(user_id),
)
def get_user_name(message: Message) -> tuple:
if message.flt_input and message.flt_input.isdigit():
user_id = int(message.flt_input)
return user_id, user_id
elif message.replied:
return message.replied.from_user.id, get_name(message.replied.from_user)
elif message.chat.type == ChatType.PRIVATE:
return message.chat.id, get_name(message.chat)
else:
return 0, 0

View File

@@ -1,49 +1,69 @@
import json
import os
from io import BytesIO
import aiohttp
from aiohttp import ClientSession, web
from app import LOGGER, Config
from app.utils.media_helper import get_filename
SESSION: aiohttp.ClientSession | None = None
class Aio:
def __init__(self):
self.session: ClientSession | None = None
self.app = None
self.site = None
self.port = os.environ.get("API_PORT", 0)
self.runner = None
if self.port:
Config.INIT_TASKS.append(self.set_site())
Config.INIT_TASKS.append(self.set_session())
async def set_session(self):
self.session = ClientSession()
async def set_site(self):
LOGGER.info("Starting Static WebSite.")
self.app = web.Application()
self.app.router.add_get(path="/", handler=self.handle_request)
self.runner = web.AppRunner(self.app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, "0.0.0.0", self.port)
await self.site.start()
async def handle_request(self, _):
return web.Response(text="Web Server Running...")
async def get_json(
self,
url: str,
headers: dict = None,
params: dict | str = None,
json_: bool = False,
timeout: int = 10,
) -> dict | None:
try:
async with self.session.get(
url=url, headers=headers, params=params, timeout=timeout
) as ses:
if json_:
return await ses.json()
else:
return json.loads(await ses.text())
except BaseException:
return
async def in_memory_dl(self, url: str) -> BytesIO:
async with self.session.get(url) as remote_file:
bytes_data = await remote_file.read()
file = BytesIO(bytes_data)
file.name = get_filename(url)
return file
async def thumb_dl(self, thumb) -> BytesIO | str | None:
if not thumb or not thumb.startswith("http"):
return thumb
return await in_memory_dl(thumb) # NOQA
async def init_task() -> None:
if not SESSION:
globals().update({"SESSION": aiohttp.ClientSession()})
else:
await SESSION.close()
async def get_json(
url: str,
headers: dict = None,
params: dict = None,
json_: bool = False,
timeout: int = 10,
) -> dict | None:
try:
async with SESSION.get(
url=url, headers=headers, params=params, timeout=timeout
) as ses:
if json_:
ret_json = await ses.json()
else:
ret_json = json.loads(await ses.text())
return ret_json
except BaseException:
return
async def in_memory_dl(url: str) -> BytesIO:
async with SESSION.get(url) as remote_file:
bytes_data = await remote_file.read()
file = BytesIO(bytes_data)
file.name = get_filename(url)
return file
async def thumb_dl(thumb) -> BytesIO | str | None:
if not thumb or not thumb.startswith("http"):
return thumb
return await in_memory_dl(thumb) # NOQA
aio = Aio()

13
run
View File

@@ -1,18 +1,7 @@
#!/bin/sh
if [ "$API_PORT" ] ; then
py_code="
from aiohttp import web
app = web.Application()
app.router.add_get('/', lambda _: web.Response(text='Web Server Running...'))
web.run_app(app, host='0.0.0.0', port=$API_PORT, reuse_port=True, print=None)
"
python3 -q -c "$py_code" & echo "Dummy Web Server Started..."
fi
if ! [ -d ".git" ] ; then
git init
fi
python3 -m app
python3 -m app