diff --git a/app/plugins/ai/media_query.py b/app/plugins/ai/media_query.py
index be8727c..0e448fc 100644
--- a/app/plugins/ai/media_query.py
+++ b/app/plugins/ai/media_query.py
@@ -3,15 +3,14 @@
 import os
 import shutil
 import time
-import google.generativeai as genai
 from pyrogram.types.messages_and_media import Audio, Photo, Video, Voice
 from ub_core.utils import get_tg_media_details
 
 from app import Message
-from app.plugins.ai.models import MODEL, get_response_text
+from app.plugins.ai.models import async_client, get_response_text
 
 PROMPT_MAP = {
-    Video: "Summarize the video file",
+    Video: "Summarize video and audio from the file",
     Photo: "Summarize the image file",
     Voice: (
         "\nDo not summarise."
@@ -23,7 +22,7 @@ PROMPT_MAP = {
 PROMPT_MAP[Audio] = PROMPT_MAP[Voice]
 
 
-async def handle_media(prompt: str, media_message: Message, model=MODEL) -> str:
+async def handle_media(prompt: str, media_message: Message, **kwargs) -> str:
     media = get_tg_media_details(media_message)
 
     if getattr(media, "file_size", 0) >= 1048576 * 25:
@@ -34,16 +33,20 @@ async def handle_media(prompt: str, media_message: Message, model=MODEL) -> str:
         )
 
     download_dir = os.path.join("downloads", str(time.time())) + "/"
-    downloaded_file = await media_message.download(download_dir)
+    downloaded_file: str = await media_message.download(download_dir)
 
-    uploaded_file = await asyncio.to_thread(genai.upload_file, downloaded_file)
+    uploaded_file = await async_client.files.upload(
+        file=downloaded_file, config={"mime_type": media.mime_type}
+    )
 
     while uploaded_file.state.name == "PROCESSING":
         await asyncio.sleep(5)
-        uploaded_file = await asyncio.to_thread(genai.get_file, uploaded_file.name)
+        uploaded_file = await async_client.files.get(name=uploaded_file.name)
 
-    response = await model.generate_content_async([prompt, uploaded_file])
-    response_text = get_response_text(response)
+    response = await async_client.models.generate_content(
+        **kwargs, contents=[uploaded_file, prompt]
+    )
+    response_text = get_response_text(response, quoted=True)
 
     shutil.rmtree(download_dir, ignore_errors=True)
     return response_text
diff --git a/app/plugins/ai/models.py b/app/plugins/ai/models.py
index 0fed03b..dfdb6f0 100644
--- a/app/plugins/ai/models.py
+++ b/app/plugins/ai/models.py
@@ -1,41 +1,70 @@
 from functools import wraps
 
-import google.generativeai as genai
+from google.genai.client import AsyncClient, Client
+from google.genai.types import (
+    DynamicRetrievalConfig,
+    GenerateContentConfig,
+    GoogleSearchRetrieval,
+    SafetySetting,
+    Tool,
+)
 from pyrogram import filters
 
 from app import BOT, CustomDB, Message, extra_config
 
-SETTINGS = CustomDB("COMMON_SETTINGS")
+DB_SETTINGS = CustomDB("COMMON_SETTINGS")
 
-GENERATION_CONFIG = {"temperature": 0.69, "max_output_tokens": 2048}
+try:
+    client: Client = Client(api_key=extra_config.GEMINI_API_KEY)
+    async_client: AsyncClient = client.aio
+except:
+    client = async_client = None
 
-SAFETY_SETTINGS = [
-    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
-    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
-    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
-    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
-]
 
-SYSTEM_INSTRUCTION = (
-    "Answer precisely and in short unless specifically instructed otherwise."
-    "\nWhen asked related to code, do not comment the code and do not explain the code unless instructed."
-)
+class Settings:
+    MODEL = "gemini-2.0-flash"
 
-MODEL = genai.GenerativeModel(
-    generation_config=GENERATION_CONFIG,
-    safety_settings=SAFETY_SETTINGS,
-    system_instruction=SYSTEM_INSTRUCTION,
-)
+    # fmt:off
+    CONFIG = GenerateContentConfig(
+
+        system_instruction=(
+            "Answer precisely and in short unless specifically instructed otherwise."
+            "\nWhen asked related to code, do not comment the code and do not explain the code unless instructed."
+        ),
+
+        temperature=0.69,
+
+        max_output_tokens=4000,
+
+        safety_settings=[
+            SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
+            SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
+            SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
+            SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
+        ],
+        # fmt:on
+
+        tools=[
+            Tool(
+                google_search=GoogleSearchRetrieval(
+                    dynamic_retrieval_config=DynamicRetrievalConfig(
+                        dynamic_threshold=0.3
+                    )
+                )
+            )
+        ],
+    )
+
+    @staticmethod
+    def get_kwargs() -> dict:
+        return {"model": Settings.MODEL, "config": Settings.CONFIG}
 
 
 async def init_task():
-    if extra_config.GEMINI_API_KEY:
-        genai.configure(api_key=extra_config.GEMINI_API_KEY)
-
-    model_info = await SETTINGS.find_one({"_id": "gemini_model_info"}) or {}
+    model_info = await DB_SETTINGS.find_one({"_id": "gemini_model_info"}) or {}
     model_name = model_info.get("model_name")
     if model_name:
-        MODEL._model_name = model_name
+        Settings.MODEL = model_name
 
 
 @BOT.add_cmd(cmd="llms")
@@ -46,15 +75,15 @@ async def list_ai_models(bot: BOT, message: Message):
     USAGE: .llms
     """
     model_list = [
-        model.name
-        for model in genai.list_models()
-        if "generateContent" in model.supported_generation_methods
+        model.name.removeprefix("models/")
+        async for model in await async_client.models.list(config={"query_base": True})
+        if "generateContent" in model.supported_actions
     ]
 
     model_str = "\n\n".join(model_list)
 
     update_str = (
-        f"\n\nCurrent Model: {MODEL._model_name}"
+        f"\n\nCurrent Model: {Settings.MODEL}"
         "\n\nTo change to a different model,"
         "Reply to this message with the model name."
     )
 
@@ -63,7 +92,7 @@ async def list_ai_models(bot: BOT, message: Message):
         f"
{model_str}
{update_str}" ) - async def resp_filters(_, c, m): + async def resp_filters(_, __, m): return m.reply_id == model_reply.id response = await model_reply.get_response( @@ -80,10 +109,12 @@ async def list_ai_models(bot: BOT, message: Message): ) return - await SETTINGS.add_data({"_id": "gemini_model_info", "model_name": response.text}) + await DB_SETTINGS.add_data( + {"_id": "gemini_model_info", "model_name": response.text} + ) await model_reply.edit(f"{response.text} saved as model.") await model_reply.log() - MODEL._model_name = response.text + Settings.MODEL = response.text def run_basic_check(function): @@ -115,5 +146,16 @@ def run_basic_check(function): return wrapper -def get_response_text(response): - return "\n".join([part.text for part in response.parts]) +def get_response_text(response, quoted: bool = False): + candidate = response.candidates[0] + sources = "" + + if grounding_chunks := candidate.grounding_metadata.grounding_chunks: + hrefs = [f"[{chunk.web.title}]({chunk.web.uri})" for chunk in grounding_chunks] + sources = "\n\nSources: " + " | ".join(hrefs) + + text = "\n".join([part.text for part in candidate.content.parts]) + + final_text = (text.strip() + sources).strip() + + return f"**>\n{final_text}<**" if quoted else final_text diff --git a/app/plugins/ai/openai.py b/app/plugins/ai/openai.py index e6eb6d0..4ea85b9 100644 --- a/app/plugins/ai/openai.py +++ b/app/plugins/ai/openai.py @@ -6,7 +6,7 @@ import openai from pyrogram.enums import ParseMode from pyrogram.types import InputMediaPhoto -from app import BOT, LOGGER, Message +from app import BOT, Message from app.plugins.ai.models import SYSTEM_INSTRUCTION OPENAI_CLIENT = environ.get("OPENAI_CLIENT", "") @@ -37,14 +37,12 @@ else: try: TEXT_CLIENT = AI_CLIENT(**text_init_kwargs) -except Exception as e: - LOGGER.error(e) +except: TEXT_CLIENT = None try: DALL_E_CLIENT = AI_CLIENT(**image_init_kwargs) -except Exception as e: - LOGGER.error(e) +except: DALL_E_CLIENT = None @@ -68,7 +66,7 @@ async def chat_gpt(bot: BOT, message: Message): OPENAI_MODEL = your azure model AZURE_OPENAI_API_KEY = your api key AZURE_OPENAI_ENDPOINT = your azure endpoint - AZURE_DEPLOYMENT = your azure deployment + AZURE_DEPLOYMENT = your azure deployment USAGE: .gpt hi @@ -96,7 +94,8 @@ async def chat_gpt(bot: BOT, message: Message): response = chat_completion.choices[0].message.content await message.reply( - text=f"```\n{prompt}```**GPT**:\n{response}", parse_mode=ParseMode.MARKDOWN + text=f"**>\n{prompt}\n<**\n**GPT**:**>\n{response}\n<**", + parse_mode=ParseMode.MARKDOWN, ) @@ -118,7 +117,7 @@ async def chat_gpt(bot: BOT, message: Message): DALL_E_API_KEY = your api key DALL_E_API_VERSION = your version DALL_E_ENDPOINT = your azure endpoint - DALL_E_DEPLOYMENT = your azure deployment + DALL_E_DEPLOYMENT = your azure deployment FLAGS: -v: for vivid style images (default) @@ -160,7 +159,7 @@ async def chat_gpt(bot: BOT, message: Message): await response.edit_media( InputMediaPhoto( media=image_io, - caption=prompt, + caption=f"**>\n{prompt}\n<**", has_spoiler="-s" in message.flags, ) ) diff --git a/app/plugins/ai/text_query.py b/app/plugins/ai/text_query.py index db44511..f03f587 100644 --- a/app/plugins/ai/text_query.py +++ b/app/plugins/ai/text_query.py @@ -1,12 +1,18 @@ import pickle from io import BytesIO +from google.genai.chats import AsyncChat from pyrogram import filters from pyrogram.enums import ParseMode from app import BOT, Convo, Message, bot from app.plugins.ai.media_query import handle_media -from app.plugins.ai.models import MODEL, 
+from app.plugins.ai.models import (
+    Settings,
+    async_client,
+    get_response_text,
+    run_basic_check,
+)
 
 CONVO_CACHE: dict[str, Convo] = {}
 
@@ -26,27 +32,33 @@ async def question(bot: BOT, message: Message):
         .ai [reply to image | video | gif] [custom prompt]
     """
     reply = message.replied
-    reply_text = reply.text if reply else ""
+    prompt = message.input
 
     if reply and reply.media:
         message_response = await message.reply(
             "Processing... this may take a while."
         )
-        prompt = message.input
         response_text = await handle_media(
-            prompt=prompt, media_message=reply, model=MODEL
+            prompt=prompt, media_message=reply, **Settings.get_kwargs()
        )
     else:
         message_response = await message.reply(
             "Input received... generating response."
         )
-        prompt = f"{reply_text}\n\n\n{message.input}".strip()
-        response = await MODEL.generate_content_async(prompt)
-        response_text = get_response_text(response)
+        if reply and reply.text:
+            prompts = [reply.text, message.input]
+        else:
+            prompts = [message.input]
+
+        response = await async_client.models.generate_content(
+            contents=prompts, **Settings.get_kwargs()
+        )
+        response_text = get_response_text(response, quoted=True)
 
     await message_response.edit(
-        text=f"```\n{prompt}```**GEMINI AI**:\n{response_text.strip()}",
+        text=f"**>\n{prompt}<**\n**GEMINI AI**:{response_text}",
         parse_mode=ParseMode.MARKDOWN,
+        disable_preview=True,
     )
 
 
@@ -62,7 +74,7 @@ async def ai_chat(bot: BOT, message: Message):
     After 5 mins of Idle bot will export history and stop chat.
     use .load_history to continue
     """
-    chat = MODEL.start_chat(history=[])
+    chat = async_client.chats.create(**Settings.get_kwargs())
     await do_convo(chat=chat, message=message)
 
 
@@ -77,23 +89,27 @@ async def history_chat(bot: BOT, message: Message):
     """
     reply = message.replied
 
+    if not message.input:
+        await message.reply(f"Ask a question along with {message.trigger}{message.cmd}")
+        return
+
     try:
         assert reply.document.file_name == "AI_Chat_History.pkl"
     except (AssertionError, AttributeError):
         await message.reply("Reply to a Valid History file.")
         return
 
-    resp = await message.reply("Loading History...")
+    resp = await message.reply("`Loading History...`")
 
     doc = await reply.download(in_memory=True)
     doc.seek(0)
 
     history = pickle.load(doc)
 
-    await resp.edit("History Loaded... Resuming chat")
-    chat = MODEL.start_chat(history=history)
+    await resp.edit("__History Loaded... Resuming chat__")
+    chat = async_client.chats.create(**Settings.get_kwargs(), history=history)
     await do_convo(chat=chat, message=message)
 
 
-async def do_convo(chat, message: Message):
+async def do_convo(chat: AsyncChat, message: Message):
     prompt = message.input
     reply_to_id = message.id
     chat_id = message.chat.id
@@ -115,14 +131,15 @@ async def do_convo(chat, message: Message):
     try:
         async with convo_obj:
             while True:
-                ai_response = await chat.send_message_async(prompt)
-                ai_response_text = get_response_text(ai_response)
-                text = f"**GEMINI AI**:\n\n{ai_response_text}"
+                ai_response = await chat.send_message(prompt)
+                ai_response_text = get_response_text(ai_response, quoted=True)
+                text = f"**GEMINI AI**:\n{ai_response_text}"
                 _, prompt_message = await convo_obj.send_message(
                     text=text,
                     reply_to_id=reply_to_id,
                     parse_mode=ParseMode.MARKDOWN,
                     get_response=True,
+                    disable_preview=True,
                 )
                 prompt, reply_to_id = prompt_message.text, prompt_message.id
@@ -147,10 +164,10 @@ def generate_filter(message: Message):
     return filters.create(_filter)
 
 
-async def export_history(chat, message: Message):
-    doc = BytesIO(pickle.dumps(chat.history))
+async def export_history(chat: AsyncChat, message: Message):
+    doc = BytesIO(pickle.dumps(chat._curated_history))
     doc.name = "AI_Chat_History.pkl"
     caption = get_response_text(
-        await chat.send_message_async("Summarize our Conversation into one line.")
+        await chat.send_message("Summarize our Conversation into one line.")
     )
     await bot.send_document(chat_id=message.from_user.id, document=doc, caption=caption)
diff --git a/req.txt b/req.txt
index 12200fa..003146e 100644
--- a/req.txt
+++ b/req.txt
@@ -4,6 +4,6 @@
 yt-dlp>=2024.5.27
 pillow
 
-google-generativeai
+google-genai
 openai
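
Reviewer note (not part of the patch): a minimal, self-contained sketch of the google-genai call pattern this diff migrates to, for anyone unfamiliar with the new SDK. It mirrors the calls introduced above (Client(...).aio, models.generate_content, chats.create); the API key, model name, and prompts are placeholder assumptions, not values taken from this repository.

import asyncio

from google import genai
from google.genai.types import GenerateContentConfig

API_KEY = "YOUR_GEMINI_API_KEY"  # placeholder; the bot reads extra_config.GEMINI_API_KEY


async def main():
    # The synchronous Client exposes an async surface via .aio,
    # which models.py stores as async_client.
    client = genai.Client(api_key=API_KEY)

    config = GenerateContentConfig(
        system_instruction="Answer precisely and in short.",
        temperature=0.69,
        max_output_tokens=4000,
    )

    # One-shot request: the path used by the .ai command in text_query.py.
    response = await client.aio.models.generate_content(
        model="gemini-2.0-flash",
        contents=["Explain unified diffs in one line."],
        config=config,
    )
    print(response.text)

    # Multi-turn chat: the path used by ai_chat() and do_convo().
    chat = client.aio.chats.create(model="gemini-2.0-flash", config=config)
    reply = await chat.send_message("And why are they useful?")
    print(reply.text)


asyncio.run(main())

In the patch itself, both paths go through Settings.get_kwargs(), so the model name and the shared GenerateContentConfig live in one place and the .llms command can swap models at runtime.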