ai: now supports image editing and generation.

This commit is contained in:
thedragonsinn
2025-03-19 15:25:02 +05:30
parent aa74e5304d
commit 469b4d02ad
3 changed files with 175 additions and 112 deletions

View File

@@ -1,4 +1,5 @@
import asyncio
import io
import logging
import shutil
import time
@@ -23,8 +24,8 @@ logging.getLogger("google_genai.models").setLevel(logging.WARNING)
DB_SETTINGS = CustomDB["COMMON_SETTINGS"]
# Initialise the Gemini clients once at import time; fall back to None so the
# module still imports when no API key is configured (run_basic_check
# presumably guards usage downstream — confirm against the full file).
try:
    client: Client | None = Client(api_key=extra_config.GEMINI_API_KEY)
    async_client: AsyncClient | None = client.aio
except Exception:  # narrow from bare `except:` so SystemExit/KeyboardInterrupt propagate
    client = async_client = None
@@ -32,7 +33,9 @@ except:
async def init_task():
    """Load persisted model overrides from the DB and apply them to Settings.

    Reads the ``gemini_model_info`` document and, when present, overrides the
    text and/or image model names; class defaults are kept otherwise.
    """
    model_info = await DB_SETTINGS.find_one({"_id": "gemini_model_info"}) or {}
    if model_name := model_info.get("model_name"):
        Settings.TEXT_MODEL = model_name
    if image_model := model_info.get("image_model_name"):
        Settings.IMAGE_MODEL = image_model
def run_basic_check(function):
@@ -64,9 +67,24 @@ def run_basic_check(function):
return wrapper
def get_response_text(response, quoted: bool = False, add_sources: bool = True):
candidate = response.candidates[0]
text = "\n".join([part.text for part in candidate.content.parts])
def get_response_content(
response, quoted: bool = False, add_sources: bool = True
) -> tuple[str, io.BytesIO | None]:
try:
candidate = response.candidates
parts = candidate[0].content.parts
parts[0]
except (AttributeError, IndexError):
return "Query failed... Try again", None
try:
image_data = io.BytesIO(parts[0].inline_data.data)
image_data.name = "photo.jpg"
except (AttributeError, IndexError):
image_data = None
text = "\n".join([part.text for part in parts if part.text])
sources = ""
if add_sources:
@@ -80,7 +98,11 @@ def get_response_text(response, quoted: bool = False, add_sources: bool = True):
sources = ""
final_text = (text.strip() + sources).strip()
return f"**>\n{final_text}<**" if quoted and "```" not in final_text else final_text
if final_text and quoted and "```" not in final_text:
final_text = f"**>\n{final_text}<**"
return final_text, image_data
async def save_file(message: Message, check_size: bool = True) -> File | None:
@@ -132,7 +154,7 @@ async def create_prompts(
if is_chat:
if message.media:
prompt = message.caption or PROMPT_MAP.get(message.media.value) or default_media_prompt
return [await save_file(message=message, check_size=check_size), prompt]
return [prompt, await save_file(message=message, check_size=check_size)]
else:
return [message.text]
@@ -142,9 +164,9 @@ async def create_prompts(
prompt = (
message.filtered_input or PROMPT_MAP.get(reply.media.value) or default_media_prompt
)
return [await save_file(message=reply, check_size=check_size), prompt]
return [prompt, await save_file(message=reply, check_size=check_size)]
else:
return [str(reply.text), input_prompt]
return [input_prompt, str(reply.text)]
return [input_prompt]
@@ -165,80 +187,93 @@ async def list_ai_models(bot: BOT, message: Message):
model_str = "\n\n".join(model_list)
update_str = (
f"<b>Current Model</b>: <code>{Settings.MODEL}</code>\n\n"
f"<blockquote expandable=True><pre language=text>{model_str}</pre></blockquote>"
f"<b>Current Model</b>: <code>"
f"{Settings.TEXT_MODEL if "-i" not in message.flags else Settings.IMAGE_MODEL}</code>"
f"\n\n<blockquote expandable=True><pre language=text>{model_str}</pre></blockquote>"
"\n\nReply to this message with the <code>model name</code> to change to a different model."
)
model_reply = await message.reply(update_str)
model_info_response = await message.reply(update_str)
response = await model_reply.get_response(
timeout=60, reply_to_message_id=model_reply.id, from_user=message.from_user.id
model_response = await model_info_response.get_response(
timeout=60, reply_to_message_id=model_info_response.id, from_user=message.from_user.id
)
if not response:
await model_reply.delete()
if not model_response:
await model_info_response.delete()
return
if response.text not in model_list:
await model_reply.edit(
f"Invalid Model... run <code>{message.trigger}{message.cmd}</code> again"
)
if model_response.text not in model_list:
await model_info_response.edit(f"<code>Invalid Model... Try again</code>")
return
await DB_SETTINGS.add_data({"_id": "gemini_model_info", "model_name": response.text})
resp_str = f"{response.text} saved as model."
await model_reply.edit(resp_str)
await bot.log_text(text=resp_str, type="ai")
Settings.MODEL = response.text
if "-i" in message.flags:
data_key = "image_model_name"
Settings.IMAGE_MODEL = model_response.text
else:
data_key = "model_name"
Settings.TEXT_MODEL = model_response.text
await DB_SETTINGS.add_data({"_id": "gemini_model_info", data_key: model_response.text})
resp_str = f"{model_response.text} saved as model."
await model_info_response.edit(resp_str)
await bot.log_text(text=resp_str, type=f"ai_{data_key}")
# Module-level Gemini configuration shared by the text and image configs.
# All safety filters are relaxed to BLOCK_NONE; the UNSPECIFIED category is
# left disabled (commented out) — presumably not accepted by the API, confirm.
SAFETY_SETTINGS = [
# SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
]
# Google-search grounding tool; only attached to the text config when the
# user passes the -s flag (see Settings.get_kwargs).
SEARCH_TOOL = Tool(
google_search=GoogleSearchRetrieval(
dynamic_retrieval_config=DynamicRetrievalConfig(dynamic_threshold=0.3)
)
)
# System prompt shared with the OpenAI plugin (imported there as well).
SYSTEM_INSTRUCTION = (
"Answer precisely and in short unless specifically instructed otherwise."
"\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
)
class Settings:
MODEL = "gemini-2.0-flash"
# fmt:off
CONFIG = GenerateContentConfig(
TEXT_MODEL = "gemini-2.0-flash"
TEXT_CONFIG = GenerateContentConfig(
candidate_count=1,
system_instruction=(
"Answer precisely and in short unless specifically instructed otherwise."
"\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
),
temperature=0.69,
max_output_tokens=1024,
safety_settings=[
# SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
],
# fmt:on
response_modalities=["Text"],
system_instruction=SYSTEM_INSTRUCTION,
temperature=0.69,
tools=[],
)
SEARCH_TOOL = Tool(
google_search=GoogleSearchRetrieval(
dynamic_retrieval_config=DynamicRetrievalConfig(
dynamic_threshold=0.3
)
)
)
IMAGE_MODEL = "gemini-2.0-flash-exp"
IMAGE_CONFIG = GenerateContentConfig(
candidate_count=1,
max_output_tokens=1024,
response_modalities=["Text", "Image"],
# system_instruction=SYSTEM_INSTRUCTION,
temperature=0.99,
)
@staticmethod
def get_kwargs(use_search:bool=False) -> dict:
tools = Settings.CONFIG.tools
def get_kwargs(use_search: bool = False, image_mode: bool = False) -> dict:
if image_mode:
return {"model": Settings.IMAGE_MODEL, "config": Settings.IMAGE_CONFIG}
if not use_search and Settings.SEARCH_TOOL in tools:
tools.remove(Settings.SEARCH_TOOL)
tools = Settings.TEXT_CONFIG.tools
if use_search and Settings.SEARCH_TOOL not in tools:
tools.append(Settings.SEARCH_TOOL)
if not use_search and SEARCH_TOOL in tools:
tools.remove(SEARCH_TOOL)
return {"model": Settings.MODEL, "config": Settings.CONFIG}
if use_search and SEARCH_TOOL not in tools:
tools.append(SEARCH_TOOL)
return {"model": Settings.TEXT_MODEL, "config": Settings.TEXT_CONFIG}

View File

@@ -3,13 +3,14 @@ from io import BytesIO
from google.genai.chats import AsyncChat
from pyrogram.enums import ChatType, ParseMode
from pyrogram.types import InputMediaPhoto
from app import BOT, Convo, Message, bot
from app.plugins.ai.gemini_core import (
Settings,
async_client,
create_prompts,
get_response_text,
get_response_content,
run_basic_check,
)
@@ -22,6 +23,7 @@ async def question(bot: BOT, message: Message):
INFO: Ask a question to Gemini AI or get info about replied message / media.
FLAGS:
-s: to use Search
-i: to edit/generate images
USAGE:
.ai what is the meaning of life.
.ai [reply to a message] (sends replied text as query)
@@ -34,11 +36,11 @@ async def question(bot: BOT, message: Message):
prompt = message.filtered_input
if reply and reply.media:
message_response = await message.reply("<code>Processing... this may take a while.</code>")
resp_str = "<code>Processing... this may take a while.</code>"
else:
message_response = await message.reply(
"<code>Input received... generating response.</code>"
)
resp_str = "<code>Input received... generating response.</code>"
message_response = await message.reply(resp_str)
try:
prompts = await create_prompts(message=message)
@@ -46,17 +48,24 @@ async def question(bot: BOT, message: Message):
await message_response.edit(e)
return
response = await async_client.models.generate_content(
contents=prompts, **Settings.get_kwargs(use_search="-s" in message.flags)
)
kwargs = Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags)
response_text = get_response_text(response, quoted=True)
response = await async_client.models.generate_content(contents=prompts, **kwargs)
await message_response.edit(
text=f"**>\n•> {prompt}<**\n{response_text}",
parse_mode=ParseMode.MARKDOWN,
disable_preview=True,
)
response_text, response_image = get_response_content(response, quoted=True)
if response_image:
await message_response.edit_media(
media=InputMediaPhoto(media=response_image, caption=f"**>\n•> {prompt}<**")
)
if response_text and isinstance(message, Message):
await message_response.reply(response_text)
else:
await message_response.edit(
text=f"**>\n•> {prompt}<**\n{response_text}",
parse_mode=ParseMode.MARKDOWN,
disable_preview=True,
)
@bot.add_cmd(cmd="aic")
@@ -67,7 +76,7 @@ async def ai_chat(bot: BOT, message: Message):
INFO: Have a Conversation with Gemini AI.
FLAGS:
"-s": use search
"-i": use image gen/edit mode
USAGE:
.aic hello
keep replying to AI responses with text | media [no need to reply in DM]
@@ -75,7 +84,9 @@ async def ai_chat(bot: BOT, message: Message):
use .load_history to continue
"""
chat = async_client.chats.create(**Settings.get_kwargs("-s" in message.flags))
chat = async_client.chats.create(
**Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags)
)
await do_convo(chat=chat, message=message)
@@ -101,33 +112,36 @@ async def history_chat(bot: BOT, message: Message):
return
resp = await message.reply("`Loading History...`")
doc = await reply.download(in_memory=True)
doc.seek(0)
pickle.load(doc)
history = pickle.load(doc)
await resp.edit("__History Loaded... Resuming chat__")
chat = async_client.chats.create(
**Settings.get_kwargs(use_search="-s" in message.flags), history=history
**Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags)
)
await do_convo(chat=chat, message=message)
await do_convo(chat=chat, message=message, is_reloaded=True)
CONVO_CACHE: dict[str, Convo] = {}
async def do_convo(chat: AsyncChat, message: Message):
async def do_convo(chat: AsyncChat, message: Message, is_reloaded: bool = False):
chat_id = message.chat.id
old_convo = CONVO_CACHE.get(message.unique_chat_user_id)
if old_convo in Convo.CONVO_DICT[chat_id]:
Convo.CONVO_DICT[chat_id].remove(old_convo)
old_conversation = CONVO_CACHE.get(message.unique_chat_user_id)
if old_conversation in Convo.CONVO_DICT[chat_id]:
Convo.CONVO_DICT[chat_id].remove(old_conversation)
if message.chat.type in (ChatType.PRIVATE, ChatType.BOT):
reply_to_user_id = None
else:
reply_to_user_id = message._client.me.id
convo_obj = Convo(
conversation_object = Convo(
client=message._client,
chat_id=chat_id,
timeout=300,
@@ -136,51 +150,65 @@ async def do_convo(chat: AsyncChat, message: Message):
reply_to_user_id=reply_to_user_id,
)
CONVO_CACHE[message.unique_chat_user_id] = convo_obj
CONVO_CACHE[message.unique_chat_user_id] = conversation_object
try:
async with convo_obj:
prompt = [message.input]
async with conversation_object:
prompt = await create_prompts(message, is_chat=is_reloaded)
reply_to_id = message.id
while True:
ai_response = await chat.send_message(prompt)
ai_response_text = get_response_text(ai_response, add_sources=True, quoted=True)
_, prompt_message = await convo_obj.send_message(
text=f"**>\n•><**\n{ai_response_text}",
response_text, response_image = get_response_content(ai_response, quoted=True)
prompt_message = await send_and_get_resp(
convo_obj=conversation_object,
response_text=response_text,
response_image=response_image,
reply_to_id=reply_to_id,
parse_mode=ParseMode.MARKDOWN,
get_response=True,
disable_preview=True,
)
try:
prompt = await create_prompts(
message=prompt_message, is_chat=True, check_size=False
)
prompt = await create_prompts(prompt_message, is_chat=True, check_size=False)
except Exception as e:
_, prompt_message = await convo_obj.send_message(
text=str(e),
reply_to_id=reply_to_id,
parse_mode=ParseMode.MARKDOWN,
get_response=True,
disable_preview=True,
)
prompt = await create_prompts(
message=prompt_message, is_chat=True, check_size=False
prompt_message = await send_and_get_resp(
conversation_object, str(e), reply_to_id=reply_to_id
)
prompt = await create_prompts(prompt_message, is_chat=True, check_size=False)
reply_to_id = prompt_message.id
except TimeoutError:
pass
finally:
await export_history(chat, message)
CONVO_CACHE.pop(message.unique_chat_user_id, 0)
async def send_and_get_resp(
    convo_obj: Convo,
    response_text: str | None = None,
    response_image: BytesIO | None = None,
    reply_to_id: int | None = None,
) -> Message:
    """Deliver the AI reply into the conversation and wait for the user's next message.

    Sends the image (if any) first, then the text (if any) as a quoted
    markdown block, and finally blocks on the conversation for the reply.
    """
    if response_image:
        await convo_obj.send_photo(photo=response_image, reply_to_id=reply_to_id)
    if response_text:
        message_kwargs = {
            "text": f"**>\n•><**\n{response_text}",
            "reply_to_id": reply_to_id,
            "parse_mode": ParseMode.MARKDOWN,
            "disable_preview": True,
        }
        await convo_obj.send_message(**message_kwargs)
    return await convo_obj.get_response()
async def export_history(chat: AsyncChat, message: Message):
    """Pickle the chat's curated history and DM it to the user.

    The document is captioned with a one-line summary that the model itself
    is asked to produce. NOTE(review): the exported pickle is later loaded
    with pickle.load elsewhere — safe only because the user loads their own
    export; never load untrusted pickles.
    """
    doc = BytesIO(pickle.dumps(chat._curated_history))
    doc.name = "AI_Chat_History.pkl"
    caption, _ = get_response_content(
        await chat.send_message("Summarize our Conversation into one line.")
    )
    await bot.send_document(chat_id=message.from_user.id, document=doc, caption=caption)

View File

@@ -7,7 +7,7 @@ from pyrogram.enums import ParseMode
from pyrogram.types import InputMediaPhoto
from app import BOT, Message
from app.plugins.ai.gemini_core import Settings
from app.plugins.ai.gemini_core import SYSTEM_INSTRUCTION
OPENAI_CLIENT = environ.get("OPENAI_CLIENT", "")
OPENAI_MODEL = environ.get("OPENAI_MODEL", "gpt-4o")
@@ -86,7 +86,7 @@ async def chat_gpt(bot: BOT, message: Message):
chat_completion = await TEXT_CLIENT.chat.completions.create(
messages=[
{"role": "system", "content": Settings.CONFIG.system_instruction},
{"role": "system", "content": SYSTEM_INSTRUCTION},
{"role": "user", "content": prompt},
],
model=OPENAI_MODEL,