adjust according to core.

This commit is contained in:
thedragonsinn
2025-03-03 17:17:00 +05:30
parent ce4b61f2a3
commit 5a8a67e270
5 changed files with 24 additions and 40 deletions

View File

@@ -1 +1,11 @@
from ub_core import BOT, DB, DB_CLIENT, LOGGER, Config, Convo, CustomDB, Message, bot
from ub_core import (
BOT,
DATABASE,
DATABASE_CLIENT,
LOGGER,
Config,
Convo,
CustomDB,
Message,
bot,
)

View File

@@ -1,7 +1,6 @@
import asyncio
from datetime import UTC, datetime, timedelta
from pyrogram import filters
from pyrogram.types import User
from app import BOT, Message
@@ -43,8 +42,7 @@ async def kick_inactive_members(bot: BOT, message: Message):
chat_id = message.chat.id
async with bot.Convo(
client=bot,
chat_id=chat_id,
client=bot, chat_id=chat_id, from_user=message.from_user.id
) as convo:
async for member in bot.get_chat_members(chat_id):
@@ -65,14 +63,7 @@ async def kick_inactive_members(bot: BOT, message: Message):
f"\nreply with y to continue"
)
async def user_filter(_, __, m: Message):
return (
m.from_user
and m.from_user.id == message.from_user.id
and m.reply_to_message_id == prompt.id
)
convo.filters = filters.create(user_filter)
convo.reply_to_message_id = prompt.id
confirmation = await convo.get_response()

View File

@@ -49,7 +49,8 @@ async def handle_media(prompt: str, media_message: Message, **kwargs) -> str:
uploaded_file = await async_client.files.get(name=uploaded_file.name)
response = await async_client.models.generate_content(
**kwargs, contents=[uploaded_file, prompt]
contents=[uploaded_file, prompt],
**kwargs,
)
return get_response_text(response, quoted=True)

View File

@@ -63,7 +63,7 @@ class Settings:
)
@staticmethod
def get_kwargs(use_search:bool=True) -> dict:
def get_kwargs(use_search:bool=False) -> dict:
tools = Settings.CONFIG.tools
if not use_search and Settings.SEARCH_TOOL in tools:
@@ -130,10 +130,6 @@ def get_response_text(response, quoted: bool = False, add_sources: bool = True):
return f"**>\n{final_text}<**" if quoted and "```" not in final_text else final_text
async def resp_filters(flt, __, m):
return m.reply_id == flt.message_id
@BOT.add_cmd(cmd="llms")
async def list_ai_models(bot: BOT, message: Message):
"""
@@ -148,6 +144,7 @@ async def list_ai_models(bot: BOT, message: Message):
]
model_str = "\n\n".join(model_list)
update_str = (
f"<b>Current Model</b>: <code>{Settings.MODEL}</code>\n\n"
f"<blockquote expandable=True><pre language=text>{model_str}</pre></blockquote>"
@@ -157,7 +154,7 @@ async def list_ai_models(bot: BOT, message: Message):
model_reply = await message.reply(update_str)
response = await model_reply.get_response(
filters=filters.create(resp_filters, message_id=model_reply.id), timeout=60
timeout=60, reply_to_message_id=model_reply.id, from_user=message.from_user.id
)
if not response:

View File

@@ -2,7 +2,6 @@
from io import BytesIO
from google.genai.chats import AsyncChat
from pyrogram import filters
from pyrogram.enums import ParseMode
from app import BOT, Convo, Message, bot
@@ -24,7 +23,7 @@ async def question(bot: BOT, message: Message):
CMD: AI
INFO: Ask a question to Gemini AI or get info about replied message / media.
FLAGS:
-ns: Not use Search
-s: to use Search
USAGE:
.ai what is the meaning of life.
.ai [reply to a message] (sends replied text as query)
@@ -43,7 +42,7 @@ async def question(bot: BOT, message: Message):
response_text = await handle_media(
prompt=prompt,
media_message=reply,
**Settings.get_kwargs(use_search="-ns" not in message.flags),
**Settings.get_kwargs(use_search="-s" in message.flags),
)
else:
message_response = await message.reply(
@@ -56,7 +55,7 @@ async def question(bot: BOT, message: Message):
response = await async_client.models.generate_content(
contents=prompts,
**Settings.get_kwargs(use_search="-ns" not in message.flags),
**Settings.get_kwargs(use_search="-s" in message.flags),
)
response_text = get_response_text(response, quoted=True)
@@ -126,9 +125,10 @@ async def do_convo(chat: AsyncChat, message: Message):
convo_obj = Convo(
client=message._client,
chat_id=chat_id,
filters=generate_filter(message),
timeout=300,
check_for_duplicates=False,
from_user=message.from_user.id,
reply_to_user_id=message._client.me.id,
)
CONVO_CACHE[message.unique_chat_user_id] = convo_obj
@@ -138,7 +138,7 @@ async def do_convo(chat: AsyncChat, message: Message):
while True:
ai_response = await chat.send_message(prompt)
ai_response_text = get_response_text(ai_response, quoted=True)
text = f"**GEMINI AI**:\n{ai_response_text}"
text = f"**GEMINI AI**:{ai_response_text}"
_, prompt_message = await convo_obj.send_message(
text=text,
reply_to_id=reply_to_id,
@@ -154,21 +154,6 @@ async def do_convo(chat: AsyncChat, message: Message):
CONVO_CACHE.pop(message.unique_chat_user_id, 0)
def generate_filter(message: Message):
async def _filter(_, __, msg: Message):
try:
assert (
msg.text
and msg.from_user.id == message.from_user.id
and msg.reply_to_message.from_user.id == message._client.me.id
)
return True
except (AssertionError, AttributeError):
return False
return filters.create(_filter)
async def export_history(chat: AsyncChat, message: Message):
doc = BytesIO(pickle.dumps(chat._curated_history))
doc.name = "AI_Chat_History.pkl"