From 469b4d02adf0f47835203dc63afa7ec760ea6d7f Mon Sep 17 00:00:00 2001
From: thedragonsinn <98635854+thedragonsinn@users.noreply.github.com>
Date: Wed, 19 Mar 2025 15:25:02 +0530
Subject: [PATCH] `ai`: now supports image editing and generation.
---
app/plugins/ai/gemini_core.py | 161 +++++++++++-------
.../ai/{text_query.py => gemini_query.py} | 122 ++++++++-----
app/plugins/ai/openai.py | 4 +-
3 files changed, 175 insertions(+), 112 deletions(-)
rename app/plugins/ai/{text_query.py => gemini_query.py} (53%)
diff --git a/app/plugins/ai/gemini_core.py b/app/plugins/ai/gemini_core.py
index 84fe74c..55486d1 100644
--- a/app/plugins/ai/gemini_core.py
+++ b/app/plugins/ai/gemini_core.py
@@ -1,4 +1,5 @@
import asyncio
+import io
import logging
import shutil
import time
@@ -23,8 +24,8 @@ logging.getLogger("google_genai.models").setLevel(logging.WARNING)
DB_SETTINGS = CustomDB["COMMON_SETTINGS"]
try:
- client: Client = Client(api_key=extra_config.GEMINI_API_KEY)
- async_client: AsyncClient = client.aio
+ client: Client | None = Client(api_key=extra_config.GEMINI_API_KEY)
+ async_client: AsyncClient | None = client.aio
except:
client = async_client = None
@@ -32,7 +33,9 @@ except:
async def init_task():
model_info = await DB_SETTINGS.find_one({"_id": "gemini_model_info"}) or {}
if model_name := model_info.get("model_name"):
- Settings.MODEL = model_name
+ Settings.TEXT_MODEL = model_name
+ if image_model := model_info.get("image_model_name"):
+ Settings.IMAGE_MODEL = image_model
def run_basic_check(function):
@@ -64,9 +67,24 @@ def run_basic_check(function):
return wrapper
-def get_response_text(response, quoted: bool = False, add_sources: bool = True):
- candidate = response.candidates[0]
- text = "\n".join([part.text for part in candidate.content.parts])
+def get_response_content(
+ response, quoted: bool = False, add_sources: bool = True
+) -> tuple[str, io.BytesIO | None]:
+
+ try:
+ candidate = response.candidates
+ parts = candidate[0].content.parts
+ parts[0]
+ except (AttributeError, IndexError):
+ return "Query failed... Try again", None
+
+ try:
+ image_data = io.BytesIO(parts[0].inline_data.data)
+ image_data.name = "photo.jpg"
+ except (AttributeError, IndexError):
+ image_data = None
+
+ text = "\n".join([part.text for part in parts if part.text])
sources = ""
if add_sources:
@@ -80,7 +98,11 @@ def get_response_text(response, quoted: bool = False, add_sources: bool = True):
sources = ""
final_text = (text.strip() + sources).strip()
- return f"**>\n{final_text}<**" if quoted and "```" not in final_text else final_text
+
+ if final_text and quoted and "```" not in final_text:
+ final_text = f"**>\n{final_text}<**"
+
+ return final_text, image_data
async def save_file(message: Message, check_size: bool = True) -> File | None:
@@ -132,7 +154,7 @@ async def create_prompts(
if is_chat:
if message.media:
prompt = message.caption or PROMPT_MAP.get(message.media.value) or default_media_prompt
- return [await save_file(message=message, check_size=check_size), prompt]
+ return [prompt, await save_file(message=message, check_size=check_size)]
else:
return [message.text]
@@ -142,9 +164,9 @@ async def create_prompts(
prompt = (
message.filtered_input or PROMPT_MAP.get(reply.media.value) or default_media_prompt
)
- return [await save_file(message=reply, check_size=check_size), prompt]
+ return [prompt, await save_file(message=reply, check_size=check_size)]
else:
- return [str(reply.text), input_prompt]
+ return [input_prompt, str(reply.text)]
return [input_prompt]
@@ -165,80 +187,93 @@ async def list_ai_models(bot: BOT, message: Message):
model_str = "\n\n".join(model_list)
update_str = (
- f"Current Model: {Settings.MODEL}\n\n"
- f"
{model_str}"
+ f"Current Model: "
+ f"{Settings.TEXT_MODEL if "-i" not in message.flags else Settings.IMAGE_MODEL}"
+ f"\n\n{model_str}"
"\n\nReply to this message with the model name to change to a different model."
)
- model_reply = await message.reply(update_str)
+ model_info_response = await message.reply(update_str)
- response = await model_reply.get_response(
- timeout=60, reply_to_message_id=model_reply.id, from_user=message.from_user.id
+ model_response = await model_info_response.get_response(
+ timeout=60, reply_to_message_id=model_info_response.id, from_user=message.from_user.id
)
- if not response:
- await model_reply.delete()
+ if not model_response:
+ await model_info_response.delete()
return
- if response.text not in model_list:
- await model_reply.edit(
- f"Invalid Model... run {message.trigger}{message.cmd} again"
- )
+ if model_response.text not in model_list:
+ await model_info_response.edit(f"Invalid Model... Try again")
return
- await DB_SETTINGS.add_data({"_id": "gemini_model_info", "model_name": response.text})
- resp_str = f"{response.text} saved as model."
- await model_reply.edit(resp_str)
- await bot.log_text(text=resp_str, type="ai")
- Settings.MODEL = response.text
+ if "-i" in message.flags:
+ data_key = "image_model_name"
+ Settings.IMAGE_MODEL = model_response.text
+ else:
+ data_key = "model_name"
+ Settings.TEXT_MODEL = model_response.text
+
+ await DB_SETTINGS.add_data({"_id": "gemini_model_info", data_key: model_response.text})
+ resp_str = f"{model_response.text} saved as model."
+ await model_info_response.edit(resp_str)
+ await bot.log_text(text=resp_str, type=f"ai_{data_key}")
+
+
+SAFETY_SETTINGS = [
+ # SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
+ SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
+ SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
+ SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
+ SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
+ SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
+]
+
+SEARCH_TOOL = Tool(
+ google_search=GoogleSearchRetrieval(
+ dynamic_retrieval_config=DynamicRetrievalConfig(dynamic_threshold=0.3)
+ )
+)
+
+SYSTEM_INSTRUCTION = (
+ "Answer precisely and in short unless specifically instructed otherwise."
+ "\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
+)
class Settings:
- MODEL = "gemini-2.0-flash"
-
- # fmt:off
- CONFIG = GenerateContentConfig(
+ TEXT_MODEL = "gemini-2.0-flash"
+ TEXT_CONFIG = GenerateContentConfig(
candidate_count=1,
-
- system_instruction=(
- "Answer precisely and in short unless specifically instructed otherwise."
- "\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
- ),
-
- temperature=0.69,
-
max_output_tokens=1024,
-
- safety_settings=[
- # SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
- SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
- SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
- SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
- SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
- SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
- ],
- # fmt:on
-
+ response_modalities=["Text"],
+ system_instruction=SYSTEM_INSTRUCTION,
+ temperature=0.69,
tools=[],
)
- SEARCH_TOOL = Tool(
- google_search=GoogleSearchRetrieval(
- dynamic_retrieval_config=DynamicRetrievalConfig(
- dynamic_threshold=0.3
- )
- )
- )
+ IMAGE_MODEL = "gemini-2.0-flash-exp"
+
+ IMAGE_CONFIG = GenerateContentConfig(
+ candidate_count=1,
+ max_output_tokens=1024,
+ response_modalities=["Text", "Image"],
+ # system_instruction=SYSTEM_INSTRUCTION,
+ temperature=0.99,
+ )
@staticmethod
- def get_kwargs(use_search:bool=False) -> dict:
- tools = Settings.CONFIG.tools
+ def get_kwargs(use_search: bool = False, image_mode: bool = False) -> dict:
+ if image_mode:
+ return {"model": Settings.IMAGE_MODEL, "config": Settings.IMAGE_CONFIG}
- if not use_search and Settings.SEARCH_TOOL in tools:
- tools.remove(Settings.SEARCH_TOOL)
+ tools = Settings.TEXT_CONFIG.tools
- if use_search and Settings.SEARCH_TOOL not in tools:
- tools.append(Settings.SEARCH_TOOL)
+ if not use_search and SEARCH_TOOL in tools:
+ tools.remove(SEARCH_TOOL)
- return {"model": Settings.MODEL, "config": Settings.CONFIG}
+ if use_search and SEARCH_TOOL not in tools:
+ tools.append(SEARCH_TOOL)
+
+ return {"model": Settings.TEXT_MODEL, "config": Settings.TEXT_CONFIG}
diff --git a/app/plugins/ai/text_query.py b/app/plugins/ai/gemini_query.py
similarity index 53%
rename from app/plugins/ai/text_query.py
rename to app/plugins/ai/gemini_query.py
index 524cfd4..d08da11 100644
--- a/app/plugins/ai/text_query.py
+++ b/app/plugins/ai/gemini_query.py
@@ -3,13 +3,14 @@ from io import BytesIO
from google.genai.chats import AsyncChat
from pyrogram.enums import ChatType, ParseMode
+from pyrogram.types import InputMediaPhoto
from app import BOT, Convo, Message, bot
from app.plugins.ai.gemini_core import (
Settings,
async_client,
create_prompts,
- get_response_text,
+ get_response_content,
run_basic_check,
)
@@ -22,6 +23,7 @@ async def question(bot: BOT, message: Message):
INFO: Ask a question to Gemini AI or get info about replied message / media.
FLAGS:
-s: to use Search
+ -i: to edit/generate images
USAGE:
.ai what is the meaning of life.
.ai [reply to a message] (sends replied text as query)
@@ -34,11 +36,11 @@ async def question(bot: BOT, message: Message):
prompt = message.filtered_input
if reply and reply.media:
- message_response = await message.reply("Processing... this may take a while.")
+ resp_str = "Processing... this may take a while."
else:
- message_response = await message.reply(
- "Input received... generating response."
- )
+ resp_str = "Input received... generating response."
+
+ message_response = await message.reply(resp_str)
try:
prompts = await create_prompts(message=message)
@@ -46,17 +48,24 @@ async def question(bot: BOT, message: Message):
await message_response.edit(e)
return
- response = await async_client.models.generate_content(
- contents=prompts, **Settings.get_kwargs(use_search="-s" in message.flags)
- )
+ kwargs = Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags)
- response_text = get_response_text(response, quoted=True)
+ response = await async_client.models.generate_content(contents=prompts, **kwargs)
- await message_response.edit(
- text=f"**>\n•> {prompt}<**\n{response_text}",
- parse_mode=ParseMode.MARKDOWN,
- disable_preview=True,
- )
+ response_text, response_image = get_response_content(response, quoted=True)
+
+ if response_image:
+ await message_response.edit_media(
+ media=InputMediaPhoto(media=response_image, caption=f"**>\n•> {prompt}<**")
+ )
+ if response_text and isinstance(message, Message):
+ await message_response.reply(response_text)
+ else:
+ await message_response.edit(
+ text=f"**>\n•> {prompt}<**\n{response_text}",
+ parse_mode=ParseMode.MARKDOWN,
+ disable_preview=True,
+ )
@bot.add_cmd(cmd="aic")
@@ -67,7 +76,7 @@ async def ai_chat(bot: BOT, message: Message):
INFO: Have a Conversation with Gemini AI.
FLAGS:
"-s": use search
-
+ "-i": use image gen/edit mode
USAGE:
.aic hello
keep replying to AI responses with text | media [no need to reply in DM]
@@ -75,7 +84,9 @@ async def ai_chat(bot: BOT, message: Message):
use .load_history to continue
"""
- chat = async_client.chats.create(**Settings.get_kwargs("-s" in message.flags))
+ chat = async_client.chats.create(
+ **Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags)
+ )
await do_convo(chat=chat, message=message)
@@ -101,33 +112,36 @@ async def history_chat(bot: BOT, message: Message):
return
resp = await message.reply("`Loading History...`")
+
doc = await reply.download(in_memory=True)
doc.seek(0)
+ pickle.load(doc)
- history = pickle.load(doc)
await resp.edit("__History Loaded... Resuming chat__")
+
chat = async_client.chats.create(
- **Settings.get_kwargs(use_search="-s" in message.flags), history=history
+ **Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags)
)
- await do_convo(chat=chat, message=message)
+ await do_convo(chat=chat, message=message, is_reloaded=True)
CONVO_CACHE: dict[str, Convo] = {}
-async def do_convo(chat: AsyncChat, message: Message):
+async def do_convo(chat: AsyncChat, message: Message, is_reloaded: bool = False):
chat_id = message.chat.id
- old_convo = CONVO_CACHE.get(message.unique_chat_user_id)
- if old_convo in Convo.CONVO_DICT[chat_id]:
- Convo.CONVO_DICT[chat_id].remove(old_convo)
+ old_conversation = CONVO_CACHE.get(message.unique_chat_user_id)
+
+ if old_conversation in Convo.CONVO_DICT[chat_id]:
+ Convo.CONVO_DICT[chat_id].remove(old_conversation)
if message.chat.type in (ChatType.PRIVATE, ChatType.BOT):
reply_to_user_id = None
else:
reply_to_user_id = message._client.me.id
- convo_obj = Convo(
+ conversation_object = Convo(
client=message._client,
chat_id=chat_id,
timeout=300,
@@ -136,51 +150,65 @@ async def do_convo(chat: AsyncChat, message: Message):
reply_to_user_id=reply_to_user_id,
)
- CONVO_CACHE[message.unique_chat_user_id] = convo_obj
+ CONVO_CACHE[message.unique_chat_user_id] = conversation_object
try:
- async with convo_obj:
- prompt = [message.input]
+ async with conversation_object:
+ prompt = await create_prompts(message, is_chat=is_reloaded)
reply_to_id = message.id
+
while True:
ai_response = await chat.send_message(prompt)
- ai_response_text = get_response_text(ai_response, add_sources=True, quoted=True)
-
- _, prompt_message = await convo_obj.send_message(
- text=f"**>\n•><**\n{ai_response_text}",
+ response_text, response_image = get_response_content(ai_response, quoted=True)
+ prompt_message = await send_and_get_resp(
+ convo_obj=conversation_object,
+ response_text=response_text,
+ response_image=response_image,
reply_to_id=reply_to_id,
- parse_mode=ParseMode.MARKDOWN,
- get_response=True,
- disable_preview=True,
)
try:
- prompt = await create_prompts(
- message=prompt_message, is_chat=True, check_size=False
- )
+ prompt = await create_prompts(prompt_message, is_chat=True, check_size=False)
except Exception as e:
- _, prompt_message = await convo_obj.send_message(
- text=str(e),
- reply_to_id=reply_to_id,
- parse_mode=ParseMode.MARKDOWN,
- get_response=True,
- disable_preview=True,
- )
- prompt = await create_prompts(
- message=prompt_message, is_chat=True, check_size=False
+ prompt_message = await send_and_get_resp(
+ conversation_object, str(e), reply_to_id=reply_to_id
)
+ prompt = await create_prompts(prompt_message, is_chat=True, check_size=False)
reply_to_id = prompt_message.id
+ except TimeoutError:
+ pass
+
finally:
await export_history(chat, message)
CONVO_CACHE.pop(message.unique_chat_user_id, 0)
+async def send_and_get_resp(
+ convo_obj: Convo,
+ response_text: str | None = None,
+ response_image: BytesIO | None = None,
+ reply_to_id: int | None = None,
+) -> Message:
+
+ if response_image:
+ await convo_obj.send_photo(photo=response_image, reply_to_id=reply_to_id)
+
+ if response_text:
+ await convo_obj.send_message(
+ text=f"**>\n•><**\n{response_text}",
+ reply_to_id=reply_to_id,
+ parse_mode=ParseMode.MARKDOWN,
+ disable_preview=True,
+ )
+ return await convo_obj.get_response()
+
+
async def export_history(chat: AsyncChat, message: Message):
doc = BytesIO(pickle.dumps(chat._curated_history))
doc.name = "AI_Chat_History.pkl"
- caption = get_response_text(
+ caption, _ = get_response_content(
await chat.send_message("Summarize our Conversation into one line.")
)
await bot.send_document(chat_id=message.from_user.id, document=doc, caption=caption)
diff --git a/app/plugins/ai/openai.py b/app/plugins/ai/openai.py
index d6a49ee..84c0cbc 100644
--- a/app/plugins/ai/openai.py
+++ b/app/plugins/ai/openai.py
@@ -7,7 +7,7 @@ from pyrogram.enums import ParseMode
from pyrogram.types import InputMediaPhoto
from app import BOT, Message
-from app.plugins.ai.gemini_core import Settings
+from app.plugins.ai.gemini_core import SYSTEM_INSTRUCTION
OPENAI_CLIENT = environ.get("OPENAI_CLIENT", "")
OPENAI_MODEL = environ.get("OPENAI_MODEL", "gpt-4o")
@@ -86,7 +86,7 @@ async def chat_gpt(bot: BOT, message: Message):
chat_completion = await TEXT_CLIENT.chat.completions.create(
messages=[
- {"role": "system", "content": Settings.CONFIG.system_instruction},
+ {"role": "system", "content": SYSTEM_INSTRUCTION},
{"role": "user", "content": prompt},
],
model=OPENAI_MODEL,