ai: .aic now supports media

thedragonsinn
2025-03-18 17:39:46 +05:30
parent eabd0ba454
commit 7903385588
4 changed files with 180 additions and 158 deletions

View File

@@ -1,14 +1,20 @@
+import asyncio
 import logging
+import shutil
+import time
 from functools import wraps
+from mimetypes import guess_type

 from google.genai.client import AsyncClient, Client
 from google.genai.types import (
     DynamicRetrievalConfig,
+    File,
     GenerateContentConfig,
     GoogleSearchRetrieval,
     SafetySetting,
     Tool,
 )
+from ub_core.utils import get_tg_media_details

 from app import BOT, CustomDB, Message, extra_config
@@ -23,61 +29,9 @@ except:
     client = async_client = None


-class Settings:
-    MODEL = "gemini-2.0-flash"
-
-    # fmt:off
-    CONFIG = GenerateContentConfig(
-        candidate_count=1,
-        system_instruction=(
-            "Answer precisely and in short unless specifically instructed otherwise."
-            "\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
-        ),
-        temperature=0.69,
-        max_output_tokens=1024,
-        safety_settings=[
-            # SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
-        ],
-        # fmt:on
-        tools=[],
-    )
-
-    SEARCH_TOOL = Tool(
-        google_search=GoogleSearchRetrieval(
-            dynamic_retrieval_config=DynamicRetrievalConfig(
-                dynamic_threshold=0.3
-            )
-        )
-    )
-
-    @staticmethod
-    def get_kwargs(use_search:bool=False) -> dict:
-        tools = Settings.CONFIG.tools
-        if not use_search and Settings.SEARCH_TOOL in tools:
-            tools.remove(Settings.SEARCH_TOOL)
-        if use_search and Settings.SEARCH_TOOL not in tools:
-            tools.append(Settings.SEARCH_TOOL)
-        return {"model": Settings.MODEL, "config": Settings.CONFIG}
-
-
 async def init_task():
     model_info = await DB_SETTINGS.find_one({"_id": "gemini_model_info"}) or {}
-    model_name = model_info.get("model_name")
-    if model_name:
+    if model_name := model_info.get("model_name"):
         Settings.MODEL = model_name
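
A note on the search toggle in get_kwargs (deleted here, re-added at the bottom of this file): it mutates Settings.CONFIG.tools in place, so enabling search on one request leaves SEARCH_TOOL attached until a later call passes use_search=False. A minimal call-site sketch, assuming the module-level async_client and get_response_text from this file (the demo wrapper itself is hypothetical):

import asyncio

async def demo() -> None:
    # Enables Google Search grounding for this request; with
    # dynamic_threshold=0.3, dynamic retrieval only fires when the
    # model's predicted relevance score crosses 0.3.
    response = await async_client.models.generate_content(
        contents=["Summarize today's top F1 headline."],
        **Settings.get_kwargs(use_search=True),
    )
    print(get_response_text(response))

asyncio.run(demo())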
@@ -129,6 +83,72 @@ def get_response_text(response, quoted: bool = False, add_sources: bool = True):
     return f"**>\n{final_text}<**" if quoted and "```" not in final_text else final_text


+async def save_file(message: Message, check_size: bool = True) -> File | None:
+    media = get_tg_media_details(message)
+
+    if check_size:
+        assert getattr(media, "file_size", 0) <= 1048576 * 25, "File size exceeds 25mb."
+
+    download_dir = f"downloads/{time.time()}/"
+
+    try:
+        downloaded_file: str = await message.download(download_dir)
+
+        uploaded_file = await async_client.files.upload(
+            file=downloaded_file,
+            config={
+                "mime_type": getattr(media, "mime_type", None) or guess_type(downloaded_file)[0]
+            },
+        )
+
+        while uploaded_file.state.name == "PROCESSING":
+            await asyncio.sleep(5)
+            uploaded_file = await async_client.files.get(name=uploaded_file.name)
+
+        return uploaded_file
+    finally:
+        shutil.rmtree(download_dir, ignore_errors=True)
+
+
+PROMPT_MAP = {
+    "Video": "Summarize video and audio from the file",
+    "Photo": "Summarize the image file",
+    "Voice": (
+        "\nDo not summarise."
+        "\nTranscribe the audio file to english alphabets AS IS."
+        "\nTranslate it only if the audio is not english."
+        "\nIf the audio is in hindi: Transcribe it to hinglish without translating."
+    ),
+}
+PROMPT_MAP["Audio"] = PROMPT_MAP["Voice"]
+
+
+async def create_prompts(
+    message: Message, is_chat: bool = False, check_size: bool = True
+) -> list[File, str] | list[str]:
+    default_media_prompt = "Analyse the file and explain."
+    input_prompt = message.filtered_input or "answer"
+
+    # Conversational
+    if is_chat:
+        if message.media:
+            prompt = message.caption or PROMPT_MAP.get(message.media.value) or default_media_prompt
+            return [await save_file(message=message, check_size=check_size), prompt]
+        else:
+            return [message.text]
+
+    # Single Use
+    if reply := message.replied:
+        if reply.media:
+            prompt = (
+                message.filtered_input or PROMPT_MAP.get(reply.media.value) or default_media_prompt
+            )
+            return [await save_file(message=reply, check_size=check_size), prompt]
+        else:
+            return [str(reply.text), input_prompt]
+
+    return [input_prompt]


 @BOT.add_cmd(cmd="llms")
 async def list_ai_models(bot: BOT, message: Message):
     """
@@ -166,14 +186,59 @@ async def list_ai_models(bot: BOT, message: Message):
         )
         return

-    await DB_SETTINGS.add_data(
-        {"_id": "gemini_model_info", "model_name": response.text}
-    )
+    await DB_SETTINGS.add_data({"_id": "gemini_model_info", "model_name": response.text})

     resp_str = f"{response.text} saved as model."
     await model_reply.edit(resp_str)
     await bot.log_text(text=resp_str, type="ai")

     Settings.MODEL = response.text
+
+
+class Settings:
+    MODEL = "gemini-2.0-flash"
+
+    # fmt:off
+    CONFIG = GenerateContentConfig(
+        candidate_count=1,
+        system_instruction=(
+            "Answer precisely and in short unless specifically instructed otherwise."
+            "\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
+        ),
+        temperature=0.69,
+        max_output_tokens=1024,
+        safety_settings=[
+            # SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
+            SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
+            SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
+            SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
+            SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
+            SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
+        ],
+        # fmt:on
+        tools=[],
+    )
+
+    SEARCH_TOOL = Tool(
+        google_search=GoogleSearchRetrieval(
+            dynamic_retrieval_config=DynamicRetrievalConfig(
+                dynamic_threshold=0.3
+            )
+        )
+    )
+
+    @staticmethod
+    def get_kwargs(use_search:bool=False) -> dict:
+        tools = Settings.CONFIG.tools
+        if not use_search and Settings.SEARCH_TOOL in tools:
+            tools.remove(Settings.SEARCH_TOOL)
+        if use_search and Settings.SEARCH_TOOL not in tools:
+            tools.append(Settings.SEARCH_TOOL)
+        return {"model": Settings.MODEL, "config": Settings.CONFIG}
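
save_file and create_prompts above work as a pair: create_prompts decides whether a message becomes plain text prompts or an [uploaded File, prompt] pair, and save_file does the Files API upload, polling every five seconds until the file leaves the PROCESSING state. A hedged sketch of the single-use path these enable (the wrapper name is hypothetical; it mirrors how the .ai handler further below consumes them):

async def summarize_reply(message: Message) -> str:
    # create_prompts raises AssertionError when the replied-to file
    # exceeds the 25mb cap asserted in save_file.
    prompts = await create_prompts(message=message)
    response = await async_client.models.generate_content(
        contents=prompts, **Settings.get_kwargs()
    )
    return get_response_text(response, quoted=True)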

View File

@@ -1,58 +0,0 @@
-import asyncio
-import os
-import shutil
-import time
-from mimetypes import guess_type
-
-from pyrogram.types.messages_and_media import Audio, Photo, Video, Voice
-from ub_core.utils import get_tg_media_details
-
-from app import Message
-from app.plugins.ai.models import async_client, get_response_text
-
-PROMPT_MAP = {
-    Video: "Summarize video and audio from the file",
-    Photo: "Summarize the image file",
-    Voice: (
-        "\nDo not summarise."
-        "\nTranscribe the audio file to english alphabets AS IS."
-        "\nTranslate it only if the audio is not english."
-        "\nIf the audio is in hindi: Transcribe it to hinglish without translating."
-    ),
-}
-PROMPT_MAP[Audio] = PROMPT_MAP[Voice]
-
-
-async def handle_media(prompt: str, media_message: Message, **kwargs) -> str:
-    media = get_tg_media_details(media_message)
-
-    if getattr(media, "file_size", 0) >= 1048576 * 25:
-        return "Error: File Size exceeds 25mb."
-
-    prompt = prompt.strip() or PROMPT_MAP.get(
-        type(media), "Analyse the file and explain."
-    )
-
-    download_dir = os.path.join("downloads", str(time.time())) + "/"
-    downloaded_file: str = await media_message.download(download_dir)
-
-    try:
-        uploaded_file = await async_client.files.upload(
-            file=downloaded_file,
-            config={
-                "mime_type": getattr(media, "mime_type", guess_type(downloaded_file)[0])
-            },
-        )
-
-        while uploaded_file.state.name == "PROCESSING":
-            await asyncio.sleep(5)
-            uploaded_file = await async_client.files.get(name=uploaded_file.name)
-
-        response = await async_client.models.generate_content(
-            contents=[uploaded_file, prompt],
-            **kwargs,
-        )
-        return get_response_text(response, quoted=True)
-    finally:
-        shutil.rmtree(download_dir, ignore_errors=True)
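
One behavioral difference against the replacement save_file in gemini_core is the mime-type fallback: getattr(media, "mime_type", fallback) only uses the fallback when the attribute is missing, not when it is None, while getattr(media, "mime_type", None) or guess_type(...)[0] also covers the None case. A standalone illustration (the Media stub is hypothetical):

from mimetypes import guess_type

class Media:
    mime_type = None  # attribute present but unset

media, path = Media(), "clip.mp4"

# Old pattern: the attribute exists, so the fallback never runs -> None
print(getattr(media, "mime_type", guess_type(path)[0]))

# New pattern: a falsy value falls through to guess_type -> "video/mp4"
print(getattr(media, "mime_type", None) or guess_type(path)[0])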

View File

@@ -7,7 +7,7 @@ from pyrogram.enums import ParseMode
 from pyrogram.types import InputMediaPhoto

 from app import BOT, Message
-from app.plugins.ai.models import Settings
+from app.plugins.ai.gemini_core import Settings

 OPENAI_CLIENT = environ.get("OPENAI_CLIENT", "")
 OPENAI_MODEL = environ.get("OPENAI_MODEL", "gpt-4o")
@@ -94,8 +94,7 @@ async def chat_gpt(bot: BOT, message: Message):
     response = chat_completion.choices[0].message.content

     await message.reply(
-        text=f"**>\n••>{prompt}<**\n**>{response}\n<**",
-        parse_mode=ParseMode.MARKDOWN,
+        text=f"**>\n••>{prompt}<**\n**>{response}\n<**", parse_mode=ParseMode.MARKDOWN
     )
@@ -158,8 +157,6 @@ async def chat_gpt(bot: BOT, message: Message):
     await response.edit_media(
         InputMediaPhoto(
-            media=image_io,
-            caption=f"**>\n{prompt}\n<**",
-            has_spoiler="-s" in message.flags,
+            media=image_io, caption=f"**>\n{prompt}\n<**", has_spoiler="-s" in message.flags
         )
     )

View File

@@ -2,19 +2,17 @@
 from io import BytesIO

 from google.genai.chats import AsyncChat
-from pyrogram.enums import ParseMode
+from pyrogram.enums import ChatType, ParseMode

 from app import BOT, Convo, Message, bot
-from app.plugins.ai.media_query import handle_media
-from app.plugins.ai.models import (
+from app.plugins.ai.gemini_core import (
     Settings,
     async_client,
+    create_prompts,
     get_response_text,
     run_basic_check,
 )

-CONVO_CACHE: dict[str, Convo] = {}


 @bot.add_cmd(cmd="ai")
 @run_basic_check
@@ -36,28 +34,23 @@ async def question(bot: BOT, message: Message):
     prompt = message.filtered_input

     if reply and reply.media:
-        message_response = await message.reply(
-            "<code>Processing... this may take a while.</code>"
-        )
-        response_text = await handle_media(
-            prompt=prompt,
-            media_message=reply,
-            **Settings.get_kwargs(use_search="-s" in message.flags),
-        )
+        message_response = await message.reply("<code>Processing... this may take a while.</code>")
     else:
         message_response = await message.reply(
             "<code>Input received... generating response.</code>"
         )
-        if reply and reply.text:
-            prompts = [str(reply.text), message.input or "answer"]
-        else:
-            prompts = [message.input]
-        response = await async_client.models.generate_content(
-            contents=prompts,
-            **Settings.get_kwargs(use_search="-s" in message.flags),
-        )
-        response_text = get_response_text(response, quoted=True)
+
+    try:
+        prompts = await create_prompts(message=message)
+    except AssertionError as e:
+        await message_response.edit(e)
+        return
+
+    response = await async_client.models.generate_content(
+        contents=prompts, **Settings.get_kwargs(use_search="-s" in message.flags)
+    )
+    response_text = get_response_text(response, quoted=True)

     await message_response.edit(
         text=f"**>\n•> {prompt}<**\n{response_text}",
@@ -72,13 +65,17 @@ async def ai_chat(bot: BOT, message: Message):
     """
     CMD: AICHAT
     INFO: Have a Conversation with Gemini AI.
+    FLAGS:
+        "-s": use search
     USAGE:
         .aic hello
-        keep replying to AI responses
+        keep replying to AI responses with text | media [no need to reply in DM]
         After 5 mins of Idle bot will export history and stop chat.
         use .load_history to continue
     """
-    chat = async_client.chats.create(**Settings.get_kwargs())
+    chat = async_client.chats.create(**Settings.get_kwargs("-s" in message.flags))
     await do_convo(chat=chat, message=message)
@@ -109,49 +106,70 @@ async def history_chat(bot: BOT, message: Message):
         history = pickle.load(doc)

     await resp.edit("__History Loaded... Resuming chat__")
-    chat = async_client.chats.create(**Settings.get_kwargs(), history=history)
+    chat = async_client.chats.create(
+        **Settings.get_kwargs(use_search="-s" in message.flags), history=history
+    )
     await do_convo(chat=chat, message=message)


+CONVO_CACHE: dict[str, Convo] = {}
+
+
 async def do_convo(chat: AsyncChat, message: Message):
-    prompt = message.input
-    reply_to_id = message.id
     chat_id = message.chat.id

     old_convo = CONVO_CACHE.get(message.unique_chat_user_id)

     if old_convo in Convo.CONVO_DICT[chat_id]:
         Convo.CONVO_DICT[chat_id].remove(old_convo)

+    if message.chat.type in (ChatType.PRIVATE, ChatType.BOT):
+        reply_to_user_id = None
+    else:
+        reply_to_user_id = message._client.me.id
+
     convo_obj = Convo(
         client=message._client,
         chat_id=chat_id,
         timeout=300,
         check_for_duplicates=False,
         from_user=message.from_user.id,
-        reply_to_user_id=message._client.me.id,
+        reply_to_user_id=reply_to_user_id,
     )

     CONVO_CACHE[message.unique_chat_user_id] = convo_obj

     try:
         async with convo_obj:
+            prompt = [message.input]
+            reply_to_id = message.id
+
             while True:
                 ai_response = await chat.send_message(prompt)
-                ai_response_text = get_response_text(ai_response, quoted=True)
-                text = f"**GENAI:**\n{ai_response_text}"
+                ai_response_text = get_response_text(ai_response, add_sources=True, quoted=True)
+
                 _, prompt_message = await convo_obj.send_message(
-                    text=text,
+                    text=f"**>\n.><**\n{ai_response_text}",
                     reply_to_id=reply_to_id,
                     parse_mode=ParseMode.MARKDOWN,
                     get_response=True,
                     disable_preview=True,
                 )
-                prompt, reply_to_id = prompt_message.text, prompt_message.id
-    except TimeoutError:
+
+                try:
+                    prompt = await create_prompts(message=prompt_message, is_chat=True)
+                except Exception as e:
+                    _, prompt_message = await convo_obj.send_message(
+                        text=str(e),
+                        reply_to_id=reply_to_id,
+                        parse_mode=ParseMode.MARKDOWN,
+                        get_response=True,
+                        disable_preview=True,
+                    )
+
+                reply_to_id = prompt_message.id
+    finally:
         await export_history(chat, message)
         CONVO_CACHE.pop(message.unique_chat_user_id, 0)


 async def export_history(chat: AsyncChat, message: Message):
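
The loop above routes every user reply through create_prompts with is_chat=True, so a media reply mid-conversation is uploaded and paired with its caption (or a PROMPT_MAP default) before reaching chat.send_message. A condensed sketch of a single turn, with the Convo plumbing elided (the one_turn helper is hypothetical):

async def one_turn(chat: AsyncChat, user_message: Message) -> str:
    # Text replies pass through as [text]; media replies become
    # [uploaded File, prompt], per create_prompts(is_chat=True).
    prompt = await create_prompts(message=user_message, is_chat=True)
    ai_response = await chat.send_message(prompt)
    return get_response_text(ai_response, add_sources=True, quoted=True)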