From 11776ddcd278f773765d5014211ea727c6ce1a09 Mon Sep 17 00:00:00 2001 From: thedragonsinn <98635854+thedragonsinn@users.noreply.github.com> Date: Fri, 7 Feb 2025 21:12:22 +0530 Subject: [PATCH] `ai`: upgrade to new models and package. --- app/plugins/ai/media_query.py | 21 ++++--- app/plugins/ai/models.py | 106 ++++++++++++++++++++++++---------- app/plugins/ai/openai.py | 17 +++--- app/plugins/ai/text_query.py | 55 ++++++++++++------ req.txt | 2 +- 5 files changed, 131 insertions(+), 70 deletions(-) diff --git a/app/plugins/ai/media_query.py b/app/plugins/ai/media_query.py index be8727c..0e448fc 100644 --- a/app/plugins/ai/media_query.py +++ b/app/plugins/ai/media_query.py @@ -3,15 +3,14 @@ import os import shutil import time -import google.generativeai as genai from pyrogram.types.messages_and_media import Audio, Photo, Video, Voice from ub_core.utils import get_tg_media_details from app import Message -from app.plugins.ai.models import MODEL, get_response_text +from app.plugins.ai.models import async_client, get_response_text PROMPT_MAP = { - Video: "Summarize the video file", + Video: "Summarize video and audio from the file", Photo: "Summarize the image file", Voice: ( "\nDo not summarise." @@ -23,7 +22,7 @@ PROMPT_MAP = { PROMPT_MAP[Audio] = PROMPT_MAP[Voice] -async def handle_media(prompt: str, media_message: Message, model=MODEL) -> str: +async def handle_media(prompt: str, media_message: Message, **kwargs) -> str: media = get_tg_media_details(media_message) if getattr(media, "file_size", 0) >= 1048576 * 25: @@ -34,16 +33,20 @@ async def handle_media(prompt: str, media_message: Message, model=MODEL) -> str: ) download_dir = os.path.join("downloads", str(time.time())) + "/" - downloaded_file = await media_message.download(download_dir) + downloaded_file: str = await media_message.download(download_dir) - uploaded_file = await asyncio.to_thread(genai.upload_file, downloaded_file) + uploaded_file = await async_client.files.upload( + file=downloaded_file, config={"mime_type": media.mime_type} + ) while uploaded_file.state.name == "PROCESSING": await asyncio.sleep(5) - uploaded_file = await asyncio.to_thread(genai.get_file, uploaded_file.name) + uploaded_file = await async_client.files.get(name=uploaded_file.name) - response = await model.generate_content_async([prompt, uploaded_file]) - response_text = get_response_text(response) + response = await async_client.models.generate_content( + **kwargs, contents=[uploaded_file, prompt] + ) + response_text = get_response_text(response, quoted=True) shutil.rmtree(download_dir, ignore_errors=True) return response_text diff --git a/app/plugins/ai/models.py b/app/plugins/ai/models.py index 0fed03b..dfdb6f0 100644 --- a/app/plugins/ai/models.py +++ b/app/plugins/ai/models.py @@ -1,41 +1,70 @@ from functools import wraps -import google.generativeai as genai +from google.genai.client import AsyncClient, Client +from google.genai.types import ( + DynamicRetrievalConfig, + GenerateContentConfig, + GoogleSearchRetrieval, + SafetySetting, + Tool, +) from pyrogram import filters from app import BOT, CustomDB, Message, extra_config -SETTINGS = CustomDB("COMMON_SETTINGS") +DB_SETTINGS = CustomDB("COMMON_SETTINGS") -GENERATION_CONFIG = {"temperature": 0.69, "max_output_tokens": 2048} +try: + client: Client = Client(api_key=extra_config.GEMINI_API_KEY) + async_client: AsyncClient = client.aio +except: + client = async_client = None -SAFETY_SETTINGS = [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, -] -SYSTEM_INSTRUCTION = ( - "Answer precisely and in short unless specifically instructed otherwise." - "\nWhen asked related to code, do not comment the code and do not explain the code unless instructed." -) +class Settings: + MODEL = "gemini-2.0-flash" -MODEL = genai.GenerativeModel( - generation_config=GENERATION_CONFIG, - safety_settings=SAFETY_SETTINGS, - system_instruction=SYSTEM_INSTRUCTION, -) + # fmt:off + CONFIG = GenerateContentConfig( + + system_instruction=( + "Answer precisely and in short unless specifically instructed otherwise." + "\nWhen asked related to code, do not comment the code and do not explain the code unless instructed." + ), + + temperature=0.69, + + max_output_tokens=4000, + + safety_settings=[ + SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"), + SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"), + SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"), + SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"), + ], + # fmt:on + + tools=[ + Tool( + google_search=GoogleSearchRetrieval( + dynamic_retrieval_config=DynamicRetrievalConfig( + dynamic_threshold=0.3 + ) + ) + ) + ], + ) + + @staticmethod + def get_kwargs() -> dict: + return {"model": Settings.MODEL, "config": Settings.CONFIG} async def init_task(): - if extra_config.GEMINI_API_KEY: - genai.configure(api_key=extra_config.GEMINI_API_KEY) - - model_info = await SETTINGS.find_one({"_id": "gemini_model_info"}) or {} + model_info = await DB_SETTINGS.find_one({"_id": "gemini_model_info"}) or {} model_name = model_info.get("model_name") if model_name: - MODEL._model_name = model_name + Settings.MODEL = model_name @BOT.add_cmd(cmd="llms") @@ -46,15 +75,15 @@ async def list_ai_models(bot: BOT, message: Message): USAGE: .llms """ model_list = [ - model.name - for model in genai.list_models() - if "generateContent" in model.supported_generation_methods + model.name.lstrip("models/") + async for model in await async_client.models.list(config={"query_base": True}) + if "generateContent" in model.supported_actions ] model_str = "\n\n".join(model_list) update_str = ( - f"\n\nCurrent Model: {MODEL._model_name}" + f"\n\nCurrent Model: {Settings.MODEL}" "\n\nTo change to a different model," "Reply to this message with the model name." ) @@ -63,7 +92,7 @@ async def list_ai_models(bot: BOT, message: Message): f"
{model_str}{update_str}"
)
- async def resp_filters(_, c, m):
+ async def resp_filters(_, __, m):
return m.reply_id == model_reply.id
response = await model_reply.get_response(
@@ -80,10 +109,12 @@ async def list_ai_models(bot: BOT, message: Message):
)
return
- await SETTINGS.add_data({"_id": "gemini_model_info", "model_name": response.text})
+ await DB_SETTINGS.add_data(
+ {"_id": "gemini_model_info", "model_name": response.text}
+ )
await model_reply.edit(f"{response.text} saved as model.")
await model_reply.log()
- MODEL._model_name = response.text
+ Settings.MODEL = response.text
def run_basic_check(function):
@@ -115,5 +146,16 @@ def run_basic_check(function):
return wrapper
-def get_response_text(response):
- return "\n".join([part.text for part in response.parts])
+def get_response_text(response, quoted: bool = False):
+ candidate = response.candidates[0]
+ sources = ""
+
+ if grounding_chunks := candidate.grounding_metadata.grounding_chunks:
+ hrefs = [f"[{chunk.web.title}]({chunk.web.uri})" for chunk in grounding_chunks]
+ sources = "\n\nSources: " + " | ".join(hrefs)
+
+ text = "\n".join([part.text for part in candidate.content.parts])
+
+ final_text = (text.strip() + sources).strip()
+
+ return f"**>\n{final_text}<**" if quoted else final_text
diff --git a/app/plugins/ai/openai.py b/app/plugins/ai/openai.py
index e6eb6d0..4ea85b9 100644
--- a/app/plugins/ai/openai.py
+++ b/app/plugins/ai/openai.py
@@ -6,7 +6,7 @@ import openai
from pyrogram.enums import ParseMode
from pyrogram.types import InputMediaPhoto
-from app import BOT, LOGGER, Message
+from app import BOT, Message
from app.plugins.ai.models import SYSTEM_INSTRUCTION
OPENAI_CLIENT = environ.get("OPENAI_CLIENT", "")
@@ -37,14 +37,12 @@ else:
try:
TEXT_CLIENT = AI_CLIENT(**text_init_kwargs)
-except Exception as e:
- LOGGER.error(e)
+except:
TEXT_CLIENT = None
try:
DALL_E_CLIENT = AI_CLIENT(**image_init_kwargs)
-except Exception as e:
- LOGGER.error(e)
+except:
DALL_E_CLIENT = None
@@ -68,7 +66,7 @@ async def chat_gpt(bot: BOT, message: Message):
OPENAI_MODEL = your azure model
AZURE_OPENAI_API_KEY = your api key
AZURE_OPENAI_ENDPOINT = your azure endpoint
- AZURE_DEPLOYMENT = your azure deployment
+ AZURE_DEPLOYMENT = your azure deployment
USAGE:
.gpt hi
@@ -96,7 +94,8 @@ async def chat_gpt(bot: BOT, message: Message):
response = chat_completion.choices[0].message.content
await message.reply(
- text=f"```\n{prompt}```**GPT**:\n{response}", parse_mode=ParseMode.MARKDOWN
+ text=f"**>\n{prompt}\n<**\n**GPT**:**>\n{response}\n<**",
+ parse_mode=ParseMode.MARKDOWN,
)
@@ -118,7 +117,7 @@ async def chat_gpt(bot: BOT, message: Message):
DALL_E_API_KEY = your api key
DALL_E_API_VERSION = your version
DALL_E_ENDPOINT = your azure endpoint
- DALL_E_DEPLOYMENT = your azure deployment
+ DALL_E_DEPLOYMENT = your azure deployment
FLAGS:
-v: for vivid style images (default)
@@ -160,7 +159,7 @@ async def chat_gpt(bot: BOT, message: Message):
await response.edit_media(
InputMediaPhoto(
media=image_io,
- caption=prompt,
+ caption=f"**>\n{prompt}\n<**",
has_spoiler="-s" in message.flags,
)
)
diff --git a/app/plugins/ai/text_query.py b/app/plugins/ai/text_query.py
index db44511..f03f587 100644
--- a/app/plugins/ai/text_query.py
+++ b/app/plugins/ai/text_query.py
@@ -1,12 +1,18 @@
import pickle
from io import BytesIO
+from google.genai.chats import AsyncChat
from pyrogram import filters
from pyrogram.enums import ParseMode
from app import BOT, Convo, Message, bot
from app.plugins.ai.media_query import handle_media
-from app.plugins.ai.models import MODEL, get_response_text, run_basic_check
+from app.plugins.ai.models import (
+ Settings,
+ async_client,
+ get_response_text,
+ run_basic_check,
+)
CONVO_CACHE: dict[str, Convo] = {}
@@ -26,27 +32,33 @@ async def question(bot: BOT, message: Message):
.ai [reply to image | video | gif] [custom prompt]
"""
reply = message.replied
- reply_text = reply.text if reply else ""
+ prompt = message.input
if reply and reply.media:
message_response = await message.reply(
"Processing... this may take a while."
)
- prompt = message.input
response_text = await handle_media(
- prompt=prompt, media_message=reply, model=MODEL
+ prompt=prompt, media_message=reply, **Settings.get_kwargs()
)
else:
message_response = await message.reply(
"Input received... generating response."
)
- prompt = f"{reply_text}\n\n\n{message.input}".strip()
- response = await MODEL.generate_content_async(prompt)
- response_text = get_response_text(response)
+ if reply and reply.text:
+ prompts = [reply.text, message.input]
+ else:
+ prompts = [message.input]
+
+ response = await async_client.models.generate_content(
+ contents=prompts, **Settings.get_kwargs()
+ )
+ response_text = get_response_text(response, quoted=True)
await message_response.edit(
- text=f"```\n{prompt}```**GEMINI AI**:\n{response_text.strip()}",
+ text=f"**>\n{prompt}<**\n**GEMINI AI**:{response_text}",
parse_mode=ParseMode.MARKDOWN,
+ disable_preview=True,
)
@@ -62,7 +74,7 @@ async def ai_chat(bot: BOT, message: Message):
After 5 mins of Idle bot will export history and stop chat.
use .load_history to continue
"""
- chat = MODEL.start_chat(history=[])
+ chat = async_client.chats.create(**Settings.get_kwargs())
await do_convo(chat=chat, message=message)
@@ -77,23 +89,27 @@ async def history_chat(bot: BOT, message: Message):
"""
reply = message.replied
+ if not message.input:
+ await message.reply(f"Ask a question along with {message.trigger}{message.cmd}")
+ return
+
try:
assert reply.document.file_name == "AI_Chat_History.pkl"
except (AssertionError, AttributeError):
await message.reply("Reply to a Valid History file.")
return
- resp = await message.reply("Loading History...")
+ resp = await message.reply("`Loading History...`")
doc = await reply.download(in_memory=True)
doc.seek(0)
history = pickle.load(doc)
- await resp.edit("History Loaded... Resuming chat")
- chat = MODEL.start_chat(history=history)
+ await resp.edit("__History Loaded... Resuming chat__")
+ chat = async_client.chats.create(**Settings.get_kwargs(), history=history)
await do_convo(chat=chat, message=message)
-async def do_convo(chat, message: Message):
+async def do_convo(chat: AsyncChat, message: Message):
prompt = message.input
reply_to_id = message.id
chat_id = message.chat.id
@@ -115,14 +131,15 @@ async def do_convo(chat, message: Message):
try:
async with convo_obj:
while True:
- ai_response = await chat.send_message_async(prompt)
- ai_response_text = get_response_text(ai_response)
- text = f"**GEMINI AI**:\n\n{ai_response_text}"
+ ai_response = await chat.send_message(prompt)
+ ai_response_text = get_response_text(ai_response, quoted=True)
+ text = f"**GEMINI AI**:\n{ai_response_text}"
_, prompt_message = await convo_obj.send_message(
text=text,
reply_to_id=reply_to_id,
parse_mode=ParseMode.MARKDOWN,
get_response=True,
+ disable_preview=True,
)
prompt, reply_to_id = prompt_message.text, prompt_message.id
@@ -147,10 +164,10 @@ def generate_filter(message: Message):
return filters.create(_filter)
-async def export_history(chat, message: Message):
- doc = BytesIO(pickle.dumps(chat.history))
+async def export_history(chat: AsyncChat, message: Message):
+ doc = BytesIO(pickle.dumps(chat._curated_history))
doc.name = "AI_Chat_History.pkl"
caption = get_response_text(
- await chat.send_message_async("Summarize our Conversation into one line.")
+ await chat.send_message("Summarize our Conversation into one line.")
)
await bot.send_document(chat_id=message.from_user.id, document=doc, caption=caption)
diff --git a/req.txt b/req.txt
index 12200fa..003146e 100644
--- a/req.txt
+++ b/req.txt
@@ -4,6 +4,6 @@
yt-dlp>=2024.5.27
pillow
-google-generativeai
+google-genai
openai