From 469b4d02adf0f47835203dc63afa7ec760ea6d7f Mon Sep 17 00:00:00 2001
From: thedragonsinn <98635854+thedragonsinn@users.noreply.github.com>
Date: Wed, 19 Mar 2025 15:25:02 +0530
Subject: [PATCH] `ai`: now supports image editing and generation.
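
Adds an "-i" flag to `.ai`, `.aic` and `.load_history` that routes queries
through a separate image-capable model, so responses can now return generated
or edited images alongside text.

- Settings.MODEL is split into TEXT_MODEL/TEXT_CONFIG and IMAGE_MODEL/IMAGE_CONFIG,
  with the two model names persisted under separate DB keys.
- get_response_text() is replaced by get_response_content(), which returns a
  (text, image) tuple.
- text_query.py is renamed to gemini_query.py; openai.py now imports the shared
  SYSTEM_INSTRUCTION constant instead of reading Settings.CONFIG.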
---
 app/plugins/ai/gemini_core.py                 | 161 +++++++++++-------
 .../ai/{text_query.py => gemini_query.py}     | 122 ++++++++-----
 app/plugins/ai/openai.py                      |   4 +-
 3 files changed, 175 insertions(+), 112 deletions(-)
 rename app/plugins/ai/{text_query.py => gemini_query.py} (53%)

diff --git a/app/plugins/ai/gemini_core.py b/app/plugins/ai/gemini_core.py
index 84fe74c..55486d1 100644
--- a/app/plugins/ai/gemini_core.py
+++ b/app/plugins/ai/gemini_core.py
@@ -1,4 +1,5 @@
 import asyncio
+import io
 import logging
 import shutil
 import time
@@ -23,8 +24,8 @@ logging.getLogger("google_genai.models").setLevel(logging.WARNING)
 DB_SETTINGS = CustomDB["COMMON_SETTINGS"]
 
 try:
-    client: Client = Client(api_key=extra_config.GEMINI_API_KEY)
-    async_client: AsyncClient = client.aio
+    client: Client | None = Client(api_key=extra_config.GEMINI_API_KEY)
+    async_client: AsyncClient | None = client.aio
 except:
     client = async_client = None
 
@@ -32,7 +33,9 @@ except:
 async def init_task():
     model_info = await DB_SETTINGS.find_one({"_id": "gemini_model_info"}) or {}
     if model_name := model_info.get("model_name"):
-        Settings.MODEL = model_name
+        Settings.TEXT_MODEL = model_name
+    if image_model := model_info.get("image_model_name"):
+        Settings.IMAGE_MODEL = image_model
 
 
 def run_basic_check(function):
@@ -64,9 +67,24 @@ def run_basic_check(function):
     return wrapper
 
 
-def get_response_text(response, quoted: bool = False, add_sources: bool = True):
-    candidate = response.candidates[0]
-    text = "\n".join([part.text for part in candidate.content.parts])
+def get_response_content(
+    response, quoted: bool = False, add_sources: bool = True
+) -> tuple[str, io.BytesIO | None]:
+
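+    # Returns (text, image): `image` is an in-memory JPEG (io.BytesIO) when the
+    # model answered with an inline image part, otherwise None.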
+    try:
+        candidate = response.candidates
+        parts = candidate[0].content.parts
+        parts[0]  # raises IndexError when the response carries no parts
+    except (AttributeError, IndexError, TypeError):
+        return "Query failed... Try again", None
+
+    try:
+        image_data = io.BytesIO(parts[0].inline_data.data)
+        image_data.name = "photo.jpg"
+    except (AttributeError, IndexError):
+        image_data = None
+
+    text = "\n".join([part.text for part in parts if part.text])
 
     sources = ""
     if add_sources:
@@ -80,7 +98,11 @@ def get_response_text(response, quoted: bool = False, add_sources: bool = True):
             sources = ""
 
     final_text = (text.strip() + sources).strip()
-    return f"**>\n{final_text}<**" if quoted and "```" not in final_text else final_text
+
+    if final_text and quoted and "```" not in final_text:
+        final_text = f"**>\n{final_text}<**"
+
+    return final_text, image_data
 
 
 async def save_file(message: Message, check_size: bool = True) -> File | None:
@@ -132,7 +154,7 @@ async def create_prompts(
     if is_chat:
         if message.media:
             prompt = message.caption or PROMPT_MAP.get(message.media.value) or default_media_prompt
-            return [await save_file(message=message, check_size=check_size), prompt]
+            return [prompt, await save_file(message=message, check_size=check_size)]
         else:
             return [message.text]
 
@@ -142,9 +164,9 @@ async def create_prompts(
         prompt = (
             message.filtered_input or PROMPT_MAP.get(reply.media.value) or default_media_prompt
         )
-        return [await save_file(message=reply, check_size=check_size), prompt]
+        return [prompt, await save_file(message=reply, check_size=check_size)]
     else:
-        return [str(reply.text), input_prompt]
+        return [input_prompt, str(reply.text)]
 
     return [input_prompt]
@@ -165,80 +187,93 @@ async def list_ai_models(bot: BOT, message: Message):
     model_str = "\n\n".join(model_list)
 
     update_str = (
-        f"Current Model: {Settings.MODEL}\n\n"
-        f"<pre>{model_str}</pre>"
+        f"Current Model: "
+        f"{Settings.TEXT_MODEL if '-i' not in message.flags else Settings.IMAGE_MODEL}"
+        f"\n\n<pre>{model_str}</pre>"
         "\n\nReply to this message with the model name to change to a different model."
     )
 
-    model_reply = await message.reply(update_str)
+    model_info_response = await message.reply(update_str)
 
-    response = await model_reply.get_response(
-        timeout=60, reply_to_message_id=model_reply.id, from_user=message.from_user.id
+    model_response = await model_info_response.get_response(
+        timeout=60, reply_to_message_id=model_info_response.id, from_user=message.from_user.id
     )
 
-    if not response:
-        await model_reply.delete()
+    if not model_response:
+        await model_info_response.delete()
         return
 
-    if response.text not in model_list:
-        await model_reply.edit(
-            f"Invalid Model... run {message.trigger}{message.cmd} again"
-        )
+    if model_response.text not in model_list:
+        await model_info_response.edit("Invalid Model... Try again")
         return
 
+    if "-i" in message.flags:
+        data_key = "image_model_name"
+        Settings.IMAGE_MODEL = model_response.text
+    else:
+        data_key = "model_name"
+        Settings.TEXT_MODEL = model_response.text
+
+    await DB_SETTINGS.add_data({"_id": "gemini_model_info", data_key: model_response.text})
+    resp_str = f"{model_response.text} saved as model."
+    await model_info_response.edit(resp_str)
+    await bot.log_text(text=resp_str, type=f"ai_{data_key}")
+
+
+SAFETY_SETTINGS = [
+    # SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
+    SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
+    SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
+    SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
+    SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
+    SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
+]
+
+SEARCH_TOOL = Tool(
+    google_search=GoogleSearchRetrieval(
+        dynamic_retrieval_config=DynamicRetrievalConfig(dynamic_threshold=0.3)
+    )
+)
+
+SYSTEM_INSTRUCTION = (
+    "Answer precisely and in short unless specifically instructed otherwise."
+    "\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
+)
 
 
 class Settings:
-    MODEL = "gemini-2.0-flash"
-
-    # fmt:off
-    CONFIG = GenerateContentConfig(
+    TEXT_MODEL = "gemini-2.0-flash"
+    TEXT_CONFIG = GenerateContentConfig(
         candidate_count=1,
-
-        system_instruction=(
-            "Answer precisely and in short unless specifically instructed otherwise."
-            "\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
-        ),
-
-        temperature=0.69,
-        max_output_tokens=1024,
-
-        safety_settings=[
-            # SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
-        ],
-        # fmt:on
-
+        response_modalities=["Text"],
+        safety_settings=SAFETY_SETTINGS,
+        system_instruction=SYSTEM_INSTRUCTION,
+        temperature=0.69,
         tools=[],
     )
 
-    SEARCH_TOOL = Tool(
-        google_search=GoogleSearchRetrieval(
-            dynamic_retrieval_config=DynamicRetrievalConfig(
-                dynamic_threshold=0.3
-            )
-        )
-    )
+    IMAGE_MODEL = "gemini-2.0-flash-exp"
+
+    IMAGE_CONFIG = GenerateContentConfig(
+        candidate_count=1,
+        max_output_tokens=1024,
+        response_modalities=["Text", "Image"],
+        # system_instruction=SYSTEM_INSTRUCTION,
+        temperature=0.99,
+    )
 
     @staticmethod
-    def get_kwargs(use_search:bool=False) -> dict:
-        tools = Settings.CONFIG.tools
+    def get_kwargs(use_search: bool = False, image_mode: bool = False) -> dict:
+        if image_mode:
+            return {"model": Settings.IMAGE_MODEL, "config": Settings.IMAGE_CONFIG}
 
-        if not use_search and Settings.SEARCH_TOOL in tools:
-            tools.remove(Settings.SEARCH_TOOL)
+        tools = Settings.TEXT_CONFIG.tools
 
-        if use_search and Settings.SEARCH_TOOL not in tools:
-            tools.append(Settings.SEARCH_TOOL)
+        if not use_search and SEARCH_TOOL in tools:
+            tools.remove(SEARCH_TOOL)
 
-        return {"model": Settings.MODEL, "config": Settings.CONFIG}
+        if use_search and SEARCH_TOOL not in tools:
+            tools.append(SEARCH_TOOL)
+
+        return {"model": Settings.TEXT_MODEL, "config": Settings.TEXT_CONFIG}
diff --git a/app/plugins/ai/text_query.py b/app/plugins/ai/gemini_query.py
similarity index 53%
rename from app/plugins/ai/text_query.py
rename to app/plugins/ai/gemini_query.py
index 524cfd4..d08da11 100644
--- a/app/plugins/ai/text_query.py
+++ b/app/plugins/ai/gemini_query.py
@@ -3,13 +3,14 @@ from io import BytesIO
 
 from google.genai.chats import AsyncChat
 from pyrogram.enums import ChatType, ParseMode
+from pyrogram.types import InputMediaPhoto
 
 from app import BOT, Convo, Message, bot
 from app.plugins.ai.gemini_core import (
     Settings,
     async_client,
     create_prompts,
-    get_response_text,
+    get_response_content,
     run_basic_check,
 )
@@ -22,6 +23,7 @@ async def question(bot: BOT, message: Message):
     INFO: Ask a question to Gemini AI or get info about replied message / media.
     FLAGS:
         -s: to use Search
+        -i: to edit/generate images
     USAGE:
         .ai what is the meaning of life.
         .ai [reply to a message] (sends replied text as query)
@@ -34,11 +36,11 @@ async def question(bot: BOT, message: Message):
     prompt = message.filtered_input
 
     if reply and reply.media:
-        message_response = await message.reply("Processing... this may take a while.")
+        resp_str = "Processing... this may take a while."
     else:
-        message_response = await message.reply(
-            "Input received... generating response."
-        )
+        resp_str = "Input received... generating response."
+
+    message_response = await message.reply(resp_str)
 
     try:
         prompts = await create_prompts(message=message)
@@ -46,17 +48,24 @@ async def question(bot: BOT, message: Message):
         await message_response.edit(e)
         return
 
-    response = await async_client.models.generate_content(
-        contents=prompts, **Settings.get_kwargs(use_search="-s" in message.flags)
-    )
+    kwargs = Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags)
 
-    response_text = get_response_text(response, quoted=True)
+    response = await async_client.models.generate_content(contents=prompts, **kwargs)
 
-    await message_response.edit(
-        text=f"**>\n•> {prompt}<**\n{response_text}",
-        parse_mode=ParseMode.MARKDOWN,
-        disable_preview=True,
-    )
+    response_text, response_image = get_response_content(response, quoted=True)
+
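+    # An image answer replaces the status message via edit_media; any text that
+    # accompanies it is sent as a separate reply below the photo.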
+    if response_image:
+        await message_response.edit_media(
+            media=InputMediaPhoto(media=response_image, caption=f"**>\n•> {prompt}<**")
+        )
+        if response_text and isinstance(message, Message):
+            await message_response.reply(response_text)
+    else:
+        await message_response.edit(
+            text=f"**>\n•> {prompt}<**\n{response_text}",
+            parse_mode=ParseMode.MARKDOWN,
+            disable_preview=True,
+        )
@@ -67,7 +76,7 @@ async def ai_chat(bot: BOT, message: Message):
     INFO: Have a Conversation with Gemini AI.
     FLAGS:
         "-s": use search
-
+        "-i": use image gen/edit mode
     USAGE:
         .aic hello
         keep replying to AI responses with text | media [no need to reply in DM]
@@ -75,7 +84,9 @@ async def ai_chat(bot: BOT, message: Message):
         use .load_history to continue
     """
-    chat = async_client.chats.create(**Settings.get_kwargs("-s" in message.flags))
+    chat = async_client.chats.create(
+        **Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags)
+    )
     await do_convo(chat=chat, message=message)
@@ -101,33 +112,36 @@ async def history_chat(bot: BOT, message: Message):
         return
 
     resp = await message.reply("`Loading History...`")
+
     doc = await reply.download(in_memory=True)
     doc.seek(0)
+    history = pickle.load(doc)
 
-    history = pickle.load(doc)
     await resp.edit("__History Loaded... Resuming chat__")
+
     chat = async_client.chats.create(
-        **Settings.get_kwargs(use_search="-s" in message.flags), history=history
+        **Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags),
+        history=history,
     )
-    await do_convo(chat=chat, message=message)
+    await do_convo(chat=chat, message=message, is_reloaded=True)
 
 
 CONVO_CACHE: dict[str, Convo] = {}
 
 
-async def do_convo(chat: AsyncChat, message: Message):
+async def do_convo(chat: AsyncChat, message: Message, is_reloaded: bool = False):
     chat_id = message.chat.id
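+    # One cached conversation per chat+user; starting a new chat replaces the old one.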
-    old_convo = CONVO_CACHE.get(message.unique_chat_user_id)
-    if old_convo in Convo.CONVO_DICT[chat_id]:
-        Convo.CONVO_DICT[chat_id].remove(old_convo)
+    old_conversation = CONVO_CACHE.get(message.unique_chat_user_id)
+
+    if old_conversation in Convo.CONVO_DICT[chat_id]:
+        Convo.CONVO_DICT[chat_id].remove(old_conversation)
 
     if message.chat.type in (ChatType.PRIVATE, ChatType.BOT):
         reply_to_user_id = None
     else:
         reply_to_user_id = message._client.me.id
 
-    convo_obj = Convo(
+    conversation_object = Convo(
         client=message._client,
         chat_id=chat_id,
         timeout=300,
@@ -136,51 +150,65 @@ async def do_convo(chat: AsyncChat, message: Message):
         reply_to_user_id=reply_to_user_id,
     )
 
-    CONVO_CACHE[message.unique_chat_user_id] = convo_obj
+    CONVO_CACHE[message.unique_chat_user_id] = conversation_object
 
     try:
-        async with convo_obj:
-            prompt = [message.input]
+        async with conversation_object:
+            prompt = await create_prompts(message, is_chat=is_reloaded)
             reply_to_id = message.id
+
             while True:
                 ai_response = await chat.send_message(prompt)
-                ai_response_text = get_response_text(ai_response, add_sources=True, quoted=True)
-
-                _, prompt_message = await convo_obj.send_message(
-                    text=f"**>\n•><**\n{ai_response_text}",
+                response_text, response_image = get_response_content(ai_response, quoted=True)
+                prompt_message = await send_and_get_resp(
+                    convo_obj=conversation_object,
+                    response_text=response_text,
+                    response_image=response_image,
                     reply_to_id=reply_to_id,
-                    parse_mode=ParseMode.MARKDOWN,
-                    get_response=True,
-                    disable_preview=True,
                 )
                 try:
-                    prompt = await create_prompts(
-                        message=prompt_message, is_chat=True, check_size=False
-                    )
+                    prompt = await create_prompts(prompt_message, is_chat=True, check_size=False)
                 except Exception as e:
-                    _, prompt_message = await convo_obj.send_message(
-                        text=str(e),
-                        reply_to_id=reply_to_id,
-                        parse_mode=ParseMode.MARKDOWN,
-                        get_response=True,
-                        disable_preview=True,
-                    )
-                    prompt = await create_prompts(
-                        message=prompt_message, is_chat=True, check_size=False
-                    )
+                    prompt_message = await send_and_get_resp(
+                        conversation_object, str(e), reply_to_id=reply_to_id
+                    )
+                    prompt = await create_prompts(prompt_message, is_chat=True, check_size=False)
                 reply_to_id = prompt_message.id
+
+    except TimeoutError:
+        pass
+
     finally:
         await export_history(chat, message)
         CONVO_CACHE.pop(message.unique_chat_user_id, 0)
 
 
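+# Sends the AI's image and/or text to the chat, then waits for and returns the
+# user's next reply so the loop in do_convo() can feed it back as a prompt.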
+async def send_and_get_resp(
+    convo_obj: Convo,
+    response_text: str | None = None,
+    response_image: BytesIO | None = None,
+    reply_to_id: int | None = None,
+) -> Message:
+
+    if response_image:
+        await convo_obj.send_photo(photo=response_image, reply_to_id=reply_to_id)
+
+    if response_text:
+        await convo_obj.send_message(
+            text=f"**>\n•><**\n{response_text}",
+            reply_to_id=reply_to_id,
+            parse_mode=ParseMode.MARKDOWN,
+            disable_preview=True,
+        )
+    return await convo_obj.get_response()
+
+
 async def export_history(chat: AsyncChat, message: Message):
     doc = BytesIO(pickle.dumps(chat._curated_history))
     doc.name = "AI_Chat_History.pkl"
-    caption = get_response_text(
+    caption, _ = get_response_content(
         await chat.send_message("Summarize our Conversation into one line.")
     )
     await bot.send_document(chat_id=message.from_user.id, document=doc, caption=caption)
diff --git a/app/plugins/ai/openai.py b/app/plugins/ai/openai.py
index d6a49ee..84c0cbc 100644
--- a/app/plugins/ai/openai.py
+++ b/app/plugins/ai/openai.py
@@ -7,7 +7,7 @@ from pyrogram.enums import ParseMode
 from pyrogram.types import InputMediaPhoto
 
 from app import BOT, Message
-from app.plugins.ai.gemini_core import Settings
+from app.plugins.ai.gemini_core import SYSTEM_INSTRUCTION
 
 OPENAI_CLIENT = environ.get("OPENAI_CLIENT", "")
 OPENAI_MODEL = environ.get("OPENAI_MODEL", "gpt-4o")
@@ -86,7 +86,7 @@ async def chat_gpt(bot: BOT, message: Message):
 
     chat_completion = await TEXT_CLIENT.chat.completions.create(
         messages=[
-            {"role": "system", "content": Settings.CONFIG.system_instruction},
+            {"role": "system", "content": SYSTEM_INSTRUCTION},
             {"role": "user", "content": prompt},
         ],
         model=OPENAI_MODEL,