ai: now supports image editing and generation.
@@ -1,4 +1,5 @@
 import asyncio
+import io
 import logging
 import shutil
 import time
@@ -23,8 +24,8 @@ logging.getLogger("google_genai.models").setLevel(logging.WARNING)
 DB_SETTINGS = CustomDB["COMMON_SETTINGS"]
 
 try:
-    client: Client = Client(api_key=extra_config.GEMINI_API_KEY)
-    async_client: AsyncClient = client.aio
+    client: Client | None = Client(api_key=extra_config.GEMINI_API_KEY)
+    async_client: AsyncClient | None = client.aio
 except:
     client = async_client = None
 
@@ -32,7 +33,9 @@ except:
 async def init_task():
     model_info = await DB_SETTINGS.find_one({"_id": "gemini_model_info"}) or {}
     if model_name := model_info.get("model_name"):
-        Settings.MODEL = model_name
+        Settings.TEXT_MODEL = model_name
+    if image_model := model_info.get("image_model_name"):
+        Settings.IMAGE_MODEL = image_model
 
 
 def run_basic_check(function):
@@ -64,9 +67,24 @@ def run_basic_check(function):
     return wrapper
 
 
-def get_response_text(response, quoted: bool = False, add_sources: bool = True):
-    candidate = response.candidates[0]
-    text = "\n".join([part.text for part in candidate.content.parts])
+def get_response_content(
+    response, quoted: bool = False, add_sources: bool = True
+) -> tuple[str, io.BytesIO | None]:
+
+    try:
+        candidate = response.candidates
+        parts = candidate[0].content.parts
+        parts[0]
+    except (AttributeError, IndexError):
+        return "Query failed... Try again", None
+
+    try:
+        image_data = io.BytesIO(parts[0].inline_data.data)
+        image_data.name = "photo.jpg"
+    except (AttributeError, IndexError):
+        image_data = None
+
+    text = "\n".join([part.text for part in parts if part.text])
     sources = ""
 
     if add_sources:
@@ -80,7 +98,11 @@ def get_response_text(response, quoted: bool = False, add_sources: bool = True):
         sources = ""
 
     final_text = (text.strip() + sources).strip()
-    return f"**>\n{final_text}<**" if quoted and "```" not in final_text else final_text
+
+    if final_text and quoted and "```" not in final_text:
+        final_text = f"**>\n{final_text}<**"
+
+    return final_text, image_data
 
 
 async def save_file(message: Message, check_size: bool = True) -> File | None:
@@ -132,7 +154,7 @@ async def create_prompts(
     if is_chat:
         if message.media:
             prompt = message.caption or PROMPT_MAP.get(message.media.value) or default_media_prompt
-            return [await save_file(message=message, check_size=check_size), prompt]
+            return [prompt, await save_file(message=message, check_size=check_size)]
         else:
             return [message.text]
 
@@ -142,9 +164,9 @@ async def create_prompts(
             prompt = (
                 message.filtered_input or PROMPT_MAP.get(reply.media.value) or default_media_prompt
             )
-            return [await save_file(message=reply, check_size=check_size), prompt]
+            return [prompt, await save_file(message=reply, check_size=check_size)]
         else:
-            return [str(reply.text), input_prompt]
+            return [input_prompt, str(reply.text)]
 
     return [input_prompt]
 
@@ -165,80 +187,93 @@ async def list_ai_models(bot: BOT, message: Message):
     model_str = "\n\n".join(model_list)
 
     update_str = (
-        f"<b>Current Model</b>: <code>{Settings.MODEL}</code>\n\n"
-        f"<blockquote expandable=True><pre language=text>{model_str}</pre></blockquote>"
+        f"<b>Current Model</b>: <code>"
+        f"{Settings.TEXT_MODEL if "-i" not in message.flags else Settings.IMAGE_MODEL}</code>"
+        f"\n\n<blockquote expandable=True><pre language=text>{model_str}</pre></blockquote>"
         "\n\nReply to this message with the <code>model name</code> to change to a different model."
     )
 
-    model_reply = await message.reply(update_str)
+    model_info_response = await message.reply(update_str)
 
-    response = await model_reply.get_response(
-        timeout=60, reply_to_message_id=model_reply.id, from_user=message.from_user.id
+    model_response = await model_info_response.get_response(
+        timeout=60, reply_to_message_id=model_info_response.id, from_user=message.from_user.id
     )
 
-    if not response:
-        await model_reply.delete()
+    if not model_response:
+        await model_info_response.delete()
         return
 
-    if response.text not in model_list:
-        await model_reply.edit(
-            f"Invalid Model... run <code>{message.trigger}{message.cmd}</code> again"
-        )
+    if model_response.text not in model_list:
+        await model_info_response.edit(f"<code>Invalid Model... Try again</code>")
        return
 
-    await DB_SETTINGS.add_data({"_id": "gemini_model_info", "model_name": response.text})
-    resp_str = f"{response.text} saved as model."
-    await model_reply.edit(resp_str)
-    await bot.log_text(text=resp_str, type="ai")
-    Settings.MODEL = response.text
+    if "-i" in message.flags:
+        data_key = "image_model_name"
+        Settings.IMAGE_MODEL = model_response.text
+    else:
+        data_key = "model_name"
+        Settings.TEXT_MODEL = model_response.text
+
+    await DB_SETTINGS.add_data({"_id": "gemini_model_info", data_key: model_response.text})
+    resp_str = f"{model_response.text} saved as model."
+    await model_info_response.edit(resp_str)
+    await bot.log_text(text=resp_str, type=f"ai_{data_key}")
 
 
+SAFETY_SETTINGS = [
+    # SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
+    SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
+    SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
+    SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
+    SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
+    SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
+]
+
+SEARCH_TOOL = Tool(
+    google_search=GoogleSearchRetrieval(
+        dynamic_retrieval_config=DynamicRetrievalConfig(dynamic_threshold=0.3)
+    )
+)
+
+SYSTEM_INSTRUCTION = (
+    "Answer precisely and in short unless specifically instructed otherwise."
+    "\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
+)
+
+
 class Settings:
-    MODEL = "gemini-2.0-flash"
-
-    # fmt:off
-    CONFIG = GenerateContentConfig(
+    TEXT_MODEL = "gemini-2.0-flash"
+
+    TEXT_CONFIG = GenerateContentConfig(
         candidate_count=1,
-
-        system_instruction=(
-            "Answer precisely and in short unless specifically instructed otherwise."
-            "\nIF asked related to code, do not comment the code and do not explain the code unless instructed."
-        ),
-
-        temperature=0.69,
-
         max_output_tokens=1024,
-
-        safety_settings=[
-            # SafetySetting(category="HARM_CATEGORY_UNSPECIFIED", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
-            SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
-        ],
-        # fmt:on
+        response_modalities=["Text"],
+        system_instruction=SYSTEM_INSTRUCTION,
+        temperature=0.69,
         tools=[],
     )
 
-    SEARCH_TOOL = Tool(
-        google_search=GoogleSearchRetrieval(
-            dynamic_retrieval_config=DynamicRetrievalConfig(
-                dynamic_threshold=0.3
-            )
-        )
-    )
+    IMAGE_MODEL = "gemini-2.0-flash-exp"
+
+    IMAGE_CONFIG = GenerateContentConfig(
+        candidate_count=1,
+        max_output_tokens=1024,
+        response_modalities=["Text", "Image"],
+        # system_instruction=SYSTEM_INSTRUCTION,
+        temperature=0.99,
+    )
 
     @staticmethod
-    def get_kwargs(use_search:bool=False) -> dict:
-        tools = Settings.CONFIG.tools
-
-        if not use_search and Settings.SEARCH_TOOL in tools:
-            tools.remove(Settings.SEARCH_TOOL)
-
-        if use_search and Settings.SEARCH_TOOL not in tools:
-            tools.append(Settings.SEARCH_TOOL)
-
-        return {"model": Settings.MODEL, "config": Settings.CONFIG}
+    def get_kwargs(use_search: bool = False, image_mode: bool = False) -> dict:
+        if image_mode:
+            return {"model": Settings.IMAGE_MODEL, "config": Settings.IMAGE_CONFIG}
+
+        tools = Settings.TEXT_CONFIG.tools
+
+        if not use_search and SEARCH_TOOL in tools:
+            tools.remove(SEARCH_TOOL)
+
+        if use_search and SEARCH_TOOL not in tools:
+            tools.append(SEARCH_TOOL)
+
+        return {"model": Settings.TEXT_MODEL, "config": Settings.TEXT_CONFIG}
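
Illustratively, the reworked Settings.get_kwargs is now the single switch between the text and image pipelines. A minimal usage sketch (not part of the commit; the demo coroutine and prompt string are hypothetical, everything else uses names defined in this file):

    async def demo():
        # image_mode short-circuits before any search-tool wiring
        kwargs = Settings.get_kwargs(image_mode=True)
        # kwargs == {"model": Settings.IMAGE_MODEL, "config": Settings.IMAGE_CONFIG}
        response = await async_client.models.generate_content(
            contents=["generate a picture of a sunset"], **kwargs
        )
        text, image = get_response_content(response)
        # image is an io.BytesIO named "photo.jpg" when the model returned
        # inline image data, otherwise None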
@@ -3,13 +3,14 @@ from io import BytesIO
 
 from google.genai.chats import AsyncChat
 from pyrogram.enums import ChatType, ParseMode
+from pyrogram.types import InputMediaPhoto
 
 from app import BOT, Convo, Message, bot
 from app.plugins.ai.gemini_core import (
     Settings,
     async_client,
     create_prompts,
-    get_response_text,
+    get_response_content,
     run_basic_check,
 )
 
@@ -22,6 +23,7 @@ async def question(bot: BOT, message: Message):
     INFO: Ask a question to Gemini AI or get info about replied message / media.
     FLAGS:
         -s: to use Search
+        -i: to edit/generate images
     USAGE:
         .ai what is the meaning of life.
         .ai [reply to a message] (sends replied text as query)
@@ -34,11 +36,11 @@ async def question(bot: BOT, message: Message):
     prompt = message.filtered_input
 
     if reply and reply.media:
-        message_response = await message.reply("<code>Processing... this may take a while.</code>")
+        resp_str = "<code>Processing... this may take a while.</code>"
     else:
-        message_response = await message.reply(
-            "<code>Input received... generating response.</code>"
-        )
+        resp_str = "<code>Input received... generating response.</code>"
+
+    message_response = await message.reply(resp_str)
 
     try:
         prompts = await create_prompts(message=message)
@@ -46,17 +48,24 @@
         await message_response.edit(e)
         return
 
-    response = await async_client.models.generate_content(
-        contents=prompts, **Settings.get_kwargs(use_search="-s" in message.flags)
-    )
+    kwargs = Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags)
 
-    response_text = get_response_text(response, quoted=True)
+    response = await async_client.models.generate_content(contents=prompts, **kwargs)
 
-    await message_response.edit(
-        text=f"**>\n•> {prompt}<**\n{response_text}",
-        parse_mode=ParseMode.MARKDOWN,
-        disable_preview=True,
-    )
+    response_text, response_image = get_response_content(response, quoted=True)
+
+    if response_image:
+        await message_response.edit_media(
+            media=InputMediaPhoto(media=response_image, caption=f"**>\n•> {prompt}<**")
+        )
+        if response_text and isinstance(message, Message):
+            await message_response.reply(response_text)
+    else:
+        await message_response.edit(
+            text=f"**>\n•> {prompt}<**\n{response_text}",
+            parse_mode=ParseMode.MARKDOWN,
+            disable_preview=True,
+        )
 
 
 @bot.add_cmd(cmd="aic")
@@ -67,7 +76,7 @@ async def ai_chat(bot: BOT, message: Message):
     INFO: Have a Conversation with Gemini AI.
     FLAGS:
         "-s": use search
-
+        "-i": use image gen/edit mode
     USAGE:
         .aic hello
         keep replying to AI responses with text | media [no need to reply in DM]
@@ -75,7 +84,9 @@ async def ai_chat(bot: BOT, message: Message):
     use .load_history to continue
 
     """
-    chat = async_client.chats.create(**Settings.get_kwargs("-s" in message.flags))
+    chat = async_client.chats.create(
+        **Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags)
+    )
     await do_convo(chat=chat, message=message)
 
 
@@ -101,33 +112,36 @@ async def history_chat(bot: BOT, message: Message):
         return
 
     resp = await message.reply("`Loading History...`")
+
     doc = await reply.download(in_memory=True)
     doc.seek(0)
-    pickle.load(doc)
+    history = pickle.load(doc)
 
     await resp.edit("__History Loaded... Resuming chat__")
+
     chat = async_client.chats.create(
-        **Settings.get_kwargs(use_search="-s" in message.flags), history=history
+        **Settings.get_kwargs(use_search="-s" in message.flags, image_mode="-i" in message.flags),
+        history=history,
     )
-    await do_convo(chat=chat, message=message)
+    await do_convo(chat=chat, message=message, is_reloaded=True)
 
 
 CONVO_CACHE: dict[str, Convo] = {}
 
 
-async def do_convo(chat: AsyncChat, message: Message):
+async def do_convo(chat: AsyncChat, message: Message, is_reloaded: bool = False):
     chat_id = message.chat.id
-    old_convo = CONVO_CACHE.get(message.unique_chat_user_id)
-
-    if old_convo in Convo.CONVO_DICT[chat_id]:
-        Convo.CONVO_DICT[chat_id].remove(old_convo)
+    old_conversation = CONVO_CACHE.get(message.unique_chat_user_id)
+
+    if old_conversation in Convo.CONVO_DICT[chat_id]:
+        Convo.CONVO_DICT[chat_id].remove(old_conversation)
 
     if message.chat.type in (ChatType.PRIVATE, ChatType.BOT):
         reply_to_user_id = None
     else:
         reply_to_user_id = message._client.me.id
 
-    convo_obj = Convo(
+    conversation_object = Convo(
         client=message._client,
         chat_id=chat_id,
         timeout=300,
@@ -136,51 +150,65 @@ async def do_convo(chat: AsyncChat, message: Message):
         reply_to_user_id=reply_to_user_id,
     )
 
-    CONVO_CACHE[message.unique_chat_user_id] = convo_obj
+    CONVO_CACHE[message.unique_chat_user_id] = conversation_object
 
     try:
-        async with convo_obj:
-            prompt = [message.input]
+        async with conversation_object:
+            prompt = await create_prompts(message, is_chat=is_reloaded)
             reply_to_id = message.id
 
             while True:
                 ai_response = await chat.send_message(prompt)
-                ai_response_text = get_response_text(ai_response, add_sources=True, quoted=True)
-
-                _, prompt_message = await convo_obj.send_message(
-                    text=f"**>\n•><**\n{ai_response_text}",
+                response_text, response_image = get_response_content(ai_response, quoted=True)
+                prompt_message = await send_and_get_resp(
+                    convo_obj=conversation_object,
+                    response_text=response_text,
+                    response_image=response_image,
                     reply_to_id=reply_to_id,
-                    parse_mode=ParseMode.MARKDOWN,
-                    get_response=True,
-                    disable_preview=True,
                 )
 
                 try:
-                    prompt = await create_prompts(
-                        message=prompt_message, is_chat=True, check_size=False
-                    )
+                    prompt = await create_prompts(prompt_message, is_chat=True, check_size=False)
                 except Exception as e:
-                    _, prompt_message = await convo_obj.send_message(
-                        text=str(e),
-                        reply_to_id=reply_to_id,
-                        parse_mode=ParseMode.MARKDOWN,
-                        get_response=True,
-                        disable_preview=True,
-                    )
-                    prompt = await create_prompts(
-                        message=prompt_message, is_chat=True, check_size=False
-                    )
+                    prompt_message = await send_and_get_resp(
+                        conversation_object, str(e), reply_to_id=reply_to_id
+                    )
+                    prompt = await create_prompts(prompt_message, is_chat=True, check_size=False)
 
                 reply_to_id = prompt_message.id
 
     except TimeoutError:
         pass
 
     finally:
         await export_history(chat, message)
         CONVO_CACHE.pop(message.unique_chat_user_id, 0)
 
 
+async def send_and_get_resp(
+    convo_obj: Convo,
+    response_text: str | None = None,
+    response_image: BytesIO | None = None,
+    reply_to_id: int | None = None,
+) -> Message:
+
+    if response_image:
+        await convo_obj.send_photo(photo=response_image, reply_to_id=reply_to_id)
+
+    if response_text:
+        await convo_obj.send_message(
+            text=f"**>\n•><**\n{response_text}",
+            reply_to_id=reply_to_id,
+            parse_mode=ParseMode.MARKDOWN,
+            disable_preview=True,
+        )
+    return await convo_obj.get_response()
+
+
 async def export_history(chat: AsyncChat, message: Message):
     doc = BytesIO(pickle.dumps(chat._curated_history))
     doc.name = "AI_Chat_History.pkl"
-    caption = get_response_text(
+    caption, _ = get_response_content(
         await chat.send_message("Summarize our Conversation into one line.")
     )
     await bot.send_document(chat_id=message.from_user.id, document=doc, caption=caption)
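
For reference, example invocations the new flag enables (prompts are illustrative; the flags come from the updated docstrings above):

    .ai -i generate a sunset over the ocean
    .ai -i [as a reply to a photo] make the sky purple
    .aic -i    (start a conversation with image gen/edit mode)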
@@ -7,7 +7,7 @@ from pyrogram.enums import ParseMode
 from pyrogram.types import InputMediaPhoto
 
 from app import BOT, Message
-from app.plugins.ai.gemini_core import Settings
+from app.plugins.ai.gemini_core import SYSTEM_INSTRUCTION
 
 OPENAI_CLIENT = environ.get("OPENAI_CLIENT", "")
 OPENAI_MODEL = environ.get("OPENAI_MODEL", "gpt-4o")
@@ -86,7 +86,7 @@ async def chat_gpt(bot: BOT, message: Message):
 
     chat_completion = await TEXT_CLIENT.chat.completions.create(
         messages=[
-            {"role": "system", "content": Settings.CONFIG.system_instruction},
+            {"role": "system", "content": SYSTEM_INSTRUCTION},
             {"role": "user", "content": prompt},
         ],
         model=OPENAI_MODEL,