ai: .aic now supports media
@@ -1,14 +1,20 @@
import asyncio
import logging
import shutil
import time
from functools import wraps
from mimetypes import guess_type

from google.genai.client import AsyncClient, Client
from google.genai.types import (
    DynamicRetrievalConfig,
    File,
    GenerateContentConfig,
    GoogleSearchRetrieval,
    SafetySetting,
    Tool,
)
from ub_core.utils import get_tg_media_details

from app import BOT, CustomDB, Message, extra_config

@@ -23,61 +29,9 @@ except:
    client = async_client = None


class Settings:
    MODEL = "gemini-2.0-flash"

    # fmt:off
    CONFIG = GenerateContentConfig(

        candidate_count=1,

        system_instruction=(
            "Answer precisely and in short unless specifically instructed otherwise."
            "\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
        ),

        temperature=0.69,

        max_output_tokens=1024,

        safety_settings=[
            # SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
            SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
            SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
            SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
            SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
            SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
        ],
        # fmt:on

        tools=[],
    )

    SEARCH_TOOL = Tool(
        google_search=GoogleSearchRetrieval(
            dynamic_retrieval_config=DynamicRetrievalConfig(
                dynamic_threshold=0.3
            )
        )
    )

    @staticmethod
    def get_kwargs(use_search: bool = False) -> dict:
        tools = Settings.CONFIG.tools

        if not use_search and Settings.SEARCH_TOOL in tools:
            tools.remove(Settings.SEARCH_TOOL)

        if use_search and Settings.SEARCH_TOOL not in tools:
            tools.append(Settings.SEARCH_TOOL)

        return {"model": Settings.MODEL, "config": Settings.CONFIG}


async def init_task():
    model_info = await DB_SETTINGS.find_one({"_id": "gemini_model_info"}) or {}
    model_name = model_info.get("model_name")
    if model_name:
    if model_name := model_info.get("model_name"):
        Settings.MODEL = model_name

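A minimal usage sketch (assumed caller, not part of the commit) of how get_kwargs is meant to be used; note that it mutates the shared CONFIG.tools list in place, so a search-enabled call leaves SEARCH_TOOL attached until a later plain call removes it:

# Hypothetical caller built on the Settings class above.
kwargs = Settings.get_kwargs(use_search=True)   # attaches SEARCH_TOOL to CONFIG.tools
response = await async_client.models.generate_content(contents=["query"], **kwargs)
kwargs = Settings.get_kwargs()                  # a later call without search detaches it
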
@@ -129,6 +83,72 @@ def get_response_text(response, quoted: bool = False, add_sources: bool = True):
    return f"**>\n{final_text}<**" if quoted and "```" not in final_text else final_text


async def save_file(message: Message, check_size: bool = True) -> File | None:
    media = get_tg_media_details(message)

    if check_size:
        assert getattr(media, "file_size", 0) <= 1048576 * 25, "File size exceeds 25mb."

    download_dir = f"downloads/{time.time()}/"
    try:
        downloaded_file: str = await message.download(download_dir)
        uploaded_file = await async_client.files.upload(
            file=downloaded_file,
            config={
                "mime_type": getattr(media, "mime_type", None) or guess_type(downloaded_file)[0]
            },
        )
        while uploaded_file.state.name == "PROCESSING":
            await asyncio.sleep(5)
            uploaded_file = await async_client.files.get(name=uploaded_file.name)

        return uploaded_file

    finally:
        shutil.rmtree(download_dir, ignore_errors=True)

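A caller sketch for save_file (assumptions: an incoming Message carrying media, plus the async_client and Settings defined above); the 25 MB check surfaces as an AssertionError that callers are expected to catch:

# Hypothetical usage of save_file() from a handler.
try:
    uploaded = await save_file(message=message)  # downloads, uploads, polls past PROCESSING
except AssertionError as e:
    await message.reply(str(e))  # "File size exceeds 25mb."
else:
    response = await async_client.models.generate_content(
        contents=[uploaded, "Describe this file"], **Settings.get_kwargs()
    )
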
PROMPT_MAP = {
    "Video": "Summarize video and audio from the file",
    "Photo": "Summarize the image file",
    "Voice": (
        "\nDo not summarise."
        "\nTranscribe the audio file to english alphabets AS IS."
        "\nTranslate it only if the audio is not english."
        "\nIf the audio is in hindi: Transcribe it to hinglish without translating."
    ),
}
PROMPT_MAP["Audio"] = PROMPT_MAP["Voice"]

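The string keys are assumed to match message.media.value for the handled media types; lookups fall back to the generic prompt in create_prompts below:

# Hypothetical lookup; an unmapped type (e.g. "Document") returns None here.
prompt = PROMPT_MAP.get("Document") or "Analyse the file and explain."
assert PROMPT_MAP["Audio"] is PROMPT_MAP["Voice"]  # Audio reuses the Voice prompt
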
async def create_prompts(
    message: Message, is_chat: bool = False, check_size: bool = True
) -> list[File | str] | list[str]:

    default_media_prompt = "Analyse the file and explain."
    input_prompt = message.filtered_input or "answer"

    # Conversational
    if is_chat:
        if message.media:
            prompt = message.caption or PROMPT_MAP.get(message.media.value) or default_media_prompt
            return [await save_file(message=message, check_size=check_size), prompt]
        else:
            return [message.text]

    # Single Use
    if reply := message.replied:
        if reply.media:
            prompt = (
                message.filtered_input or PROMPT_MAP.get(reply.media.value) or default_media_prompt
            )
            return [await save_file(message=reply, check_size=check_size), prompt]
        else:
            return [str(reply.text), input_prompt]

    return [input_prompt]

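A sketch of the return shapes (assumed examples, not output from the commit); each list is passed straight through as generate_content's contents argument:

# Hypothetical call site, mirroring the .ai handler below.
prompts = await create_prompts(message=message)
# replied media -> [File, "<media prompt>"]
# replied text  -> ["<replied text>", "<filtered input or 'answer'>"]
# no reply      -> ["<filtered input or 'answer'>"]
response = await async_client.models.generate_content(
    contents=prompts, **Settings.get_kwargs()
)
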
@BOT.add_cmd(cmd="llms")
async def list_ai_models(bot: BOT, message: Message):
    """
@@ -166,14 +186,59 @@ async def list_ai_models(bot: BOT, message: Message):
        )
        return

    await DB_SETTINGS.add_data(
        {"_id": "gemini_model_info", "model_name": response.text}
    await DB_SETTINGS.add_data({"_id": "gemini_model_info", "model_name": response.text})
    resp_str = f"{response.text} saved as model."
    await model_reply.edit(resp_str)
    await bot.log_text(text=resp_str, type="ai")
    Settings.MODEL = response.text


class Settings:
    MODEL = "gemini-2.0-flash"

    # fmt:off
    CONFIG = GenerateContentConfig(

        candidate_count=1,

        system_instruction=(
            "Answer precisely and in short unless specifically instructed otherwise."
            "\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
        ),

        temperature=0.69,

        max_output_tokens=1024,

        safety_settings=[
            # SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
            SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
            SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
            SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
            SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
            SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
        ],
        # fmt:on

        tools=[],
    )

    SEARCH_TOOL = Tool(
        google_search=GoogleSearchRetrieval(
            dynamic_retrieval_config=DynamicRetrievalConfig(
                dynamic_threshold=0.3
            )
        )
    )

    @staticmethod
    def get_kwargs(use_search: bool = False) -> dict:
        tools = Settings.CONFIG.tools

        if not use_search and Settings.SEARCH_TOOL in tools:
            tools.remove(Settings.SEARCH_TOOL)

        if use_search and Settings.SEARCH_TOOL not in tools:
            tools.append(Settings.SEARCH_TOOL)

        return {"model": Settings.MODEL, "config": Settings.CONFIG}

@@ -1,58 +0,0 @@
import asyncio
import os
import shutil
import time
from mimetypes import guess_type

from pyrogram.types.messages_and_media import Audio, Photo, Video, Voice
from ub_core.utils import get_tg_media_details

from app import Message
from app.plugins.ai.models import async_client, get_response_text

PROMPT_MAP = {
    Video: "Summarize video and audio from the file",
    Photo: "Summarize the image file",
    Voice: (
        "\nDo not summarise."
        "\nTranscribe the audio file to english alphabets AS IS."
        "\nTranslate it only if the audio is not english."
        "\nIf the audio is in hindi: Transcribe it to hinglish without translating."
    ),
}
PROMPT_MAP[Audio] = PROMPT_MAP[Voice]


async def handle_media(prompt: str, media_message: Message, **kwargs) -> str:
    media = get_tg_media_details(media_message)

    if getattr(media, "file_size", 0) >= 1048576 * 25:
        return "Error: File Size exceeds 25mb."

    prompt = prompt.strip() or PROMPT_MAP.get(
        type(media), "Analyse the file and explain."
    )

    download_dir = os.path.join("downloads", str(time.time())) + "/"
    downloaded_file: str = await media_message.download(download_dir)

    try:
        uploaded_file = await async_client.files.upload(
            file=downloaded_file,
            config={
                "mime_type": getattr(media, "mime_type", guess_type(downloaded_file)[0])
            },
        )

        while uploaded_file.state.name == "PROCESSING":
            await asyncio.sleep(5)
            uploaded_file = await async_client.files.get(name=uploaded_file.name)

        response = await async_client.models.generate_content(
            contents=[uploaded_file, prompt],
            **kwargs,
        )
        return get_response_text(response, quoted=True)

    finally:
        shutil.rmtree(download_dir, ignore_errors=True)
@@ -7,7 +7,7 @@ from pyrogram.enums import ParseMode
from pyrogram.types import InputMediaPhoto

from app import BOT, Message
from app.plugins.ai.models import Settings
from app.plugins.ai.gemini_core import Settings

OPENAI_CLIENT = environ.get("OPENAI_CLIENT", "")
OPENAI_MODEL = environ.get("OPENAI_MODEL", "gpt-4o")
@@ -94,8 +94,7 @@ async def chat_gpt(bot: BOT, message: Message):

    response = chat_completion.choices[0].message.content
    await message.reply(
        text=f"**>\n••>{prompt}<**\n**>{response}\n<**",
        parse_mode=ParseMode.MARKDOWN,
        text=f"**>\n••>{prompt}<**\n**>{response}\n<**", parse_mode=ParseMode.MARKDOWN
    )

@@ -158,8 +157,6 @@ async def chat_gpt(bot: BOT, message: Message):

    await response.edit_media(
        InputMediaPhoto(
            media=image_io,
            caption=f"**>\n{prompt}\n<**",
            has_spoiler="-s" in message.flags,
            media=image_io, caption=f"**>\n{prompt}\n<**", has_spoiler="-s" in message.flags
        )
    )

@@ -2,19 +2,17 @@
from io import BytesIO

from google.genai.chats import AsyncChat
from pyrogram.enums import ParseMode
from pyrogram.enums import ChatType, ParseMode

from app import BOT, Convo, Message, bot
from app.plugins.ai.media_query import handle_media
from app.plugins.ai.models import (
from app.plugins.ai.gemini_core import (
    Settings,
    async_client,
    create_prompts,
    get_response_text,
    run_basic_check,
)

CONVO_CACHE: dict[str, Convo] = {}


@bot.add_cmd(cmd="ai")
@run_basic_check
@@ -36,28 +34,23 @@ async def question(bot: BOT, message: Message):
    prompt = message.filtered_input

    if reply and reply.media:
        message_response = await message.reply(
            "<code>Processing... this may take a while.</code>"
        )
        response_text = await handle_media(
            prompt=prompt,
            media_message=reply,
            **Settings.get_kwargs(use_search="-s" in message.flags),
        )
    message_response = await message.reply("<code>Processing... this may take a while.</code>")
    else:
        message_response = await message.reply(
            "<code>Input received... generating response.</code>"
        )
        if reply and reply.text:
            prompts = [str(reply.text), message.input or "answer"]
        else:
            prompts = [message.input]

        response = await async_client.models.generate_content(
            contents=prompts,
            **Settings.get_kwargs(use_search="-s" in message.flags),
        )
        response_text = get_response_text(response, quoted=True)
    try:
        prompts = await create_prompts(message=message)
    except AssertionError as e:
        await message_response.edit(e)
        return

    response = await async_client.models.generate_content(
        contents=prompts, **Settings.get_kwargs(use_search="-s" in message.flags)
    )

    response_text = get_response_text(response, quoted=True)

    await message_response.edit(
        text=f"**>\n•> {prompt}<**\n{response_text}",
@@ -72,13 +65,17 @@ async def ai_chat(bot: BOT, message: Message):
    """
    CMD: AICHAT
    INFO: Have a Conversation with Gemini AI.
    FLAGS:
        "-s": use search

    USAGE:
        .aic hello
        keep replying to AI responses
        keep replying to AI responses with text | media [no need to reply in DM]
        After 5 mins of Idle bot will export history and stop chat.
        use .load_history to continue

    """
    chat = async_client.chats.create(**Settings.get_kwargs())
    chat = async_client.chats.create(**Settings.get_kwargs("-s" in message.flags))
    await do_convo(chat=chat, message=message)

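A sketch of the chat flow this sets up (assuming google-genai's AsyncChat, which carries history between turns); do_convo below loops each user reply back into the same chat object:

# Hypothetical two-turn exchange on the chat created above.
first = await chat.send_message(["hello"])
follow_up = await chat.send_message(["and shorter?"])  # prior turns are retained
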
@@ -109,49 +106,70 @@ async def history_chat(bot: BOT, message: Message):

    history = pickle.load(doc)
    await resp.edit("__History Loaded... Resuming chat__")
    chat = async_client.chats.create(**Settings.get_kwargs(), history=history)
    chat = async_client.chats.create(
        **Settings.get_kwargs(use_search="-s" in message.flags), history=history
    )
    await do_convo(chat=chat, message=message)


CONVO_CACHE: dict[str, Convo] = {}


async def do_convo(chat: AsyncChat, message: Message):
    prompt = message.input
    reply_to_id = message.id
    chat_id = message.chat.id
    old_convo = CONVO_CACHE.get(message.unique_chat_user_id)

    if old_convo in Convo.CONVO_DICT[chat_id]:
        Convo.CONVO_DICT[chat_id].remove(old_convo)

    if message.chat.type in (ChatType.PRIVATE, ChatType.BOT):
        reply_to_user_id = None
    else:
        reply_to_user_id = message._client.me.id

    convo_obj = Convo(
        client=message._client,
        chat_id=chat_id,
        timeout=300,
        check_for_duplicates=False,
        from_user=message.from_user.id,
        reply_to_user_id=message._client.me.id,
        reply_to_user_id=reply_to_user_id,
    )

    CONVO_CACHE[message.unique_chat_user_id] = convo_obj

    try:
        async with convo_obj:
            prompt = [message.input]
            reply_to_id = message.id
            while True:
                ai_response = await chat.send_message(prompt)
                ai_response_text = get_response_text(ai_response, quoted=True)
                text = f"**GENAI:**\n{ai_response_text}"
                ai_response_text = get_response_text(ai_response, add_sources=True, quoted=True)

                _, prompt_message = await convo_obj.send_message(
                    text=text,
                    text=f"**>\n.><**\n{ai_response_text}",
                    reply_to_id=reply_to_id,
                    parse_mode=ParseMode.MARKDOWN,
                    get_response=True,
                    disable_preview=True,
                )
                prompt, reply_to_id = prompt_message.text, prompt_message.id

    except TimeoutError:
                try:
                    prompt = await create_prompts(message=prompt_message, is_chat=True)
                except Exception as e:
                    _, prompt_message = await convo_obj.send_message(
                        text=str(e),
                        reply_to_id=reply_to_id,
                        parse_mode=ParseMode.MARKDOWN,
                        get_response=True,
                        disable_preview=True,
                    )

                reply_to_id = prompt_message.id

    finally:
        await export_history(chat, message)

        CONVO_CACHE.pop(message.unique_chat_user_id, 0)
        CONVO_CACHE.pop(message.unique_chat_user_id, 0)


async def export_history(chat: AsyncChat, message: Message):