From 5fb4e5aaf9724773aba943e50db647667ec2ca19 Mon Sep 17 00:00:00 2001 From: thedragonsinn <98635854+thedragonsinn@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:42:19 +0530 Subject: [PATCH] update ai models and use a decorator for checks --- app/plugins/ai/media_query.py | 14 ++++---- app/plugins/ai/models.py | 64 ++++++++++++++++++++--------------- app/plugins/ai/text_query.py | 15 +++----- 3 files changed, 47 insertions(+), 46 deletions(-) diff --git a/app/plugins/ai/media_query.py b/app/plugins/ai/media_query.py index 577f805..1def923 100644 --- a/app/plugins/ai/media_query.py +++ b/app/plugins/ai/media_query.py @@ -7,17 +7,11 @@ import time from io import BytesIO import google.generativeai as genai +from app import BOT, Message, bot +from app.plugins.ai.models import get_response_text, run_basic_check from google.ai import generativelanguage as glm from ub_core.utils import run_shell_cmd -from app import BOT, Message, bot -from app.plugins.ai.models import ( - IMAGE_MODEL, - MEDIA_MODEL, - TEXT_MODEL, - get_response_text, -) - CODE_EXTS = { ".txt", ".java", @@ -45,6 +39,7 @@ AUDIO_EXTS = {".aac", ".mp3", ".opus", ".m4a", ".ogg"} @bot.add_cmd(cmd="ocr") +@run_basic_check async def photo_query(bot: BOT, message: Message): """ CMD: OCR @@ -64,6 +59,7 @@ async def photo_query(bot: BOT, message: Message): @bot.add_cmd(cmd="stt") +@run_basic_check async def audio_to_text(bot: BOT, message: Message): """ CMD: STT (Speech To Text) @@ -84,6 +80,7 @@ async def audio_to_text(bot: BOT, message: Message): @bot.add_cmd(cmd="ocrv") +@run_basic_check async def video_to_text(bot: BOT, message: Message): """ CMD: OCRV @@ -103,6 +100,7 @@ async def video_to_text(bot: BOT, message: Message): @bot.add_cmd(cmd="aim") +@run_basic_check async def handle_document(bot: BOT, message: Message): """ CMD: AIM diff --git a/app/plugins/ai/models.py b/app/plugins/ai/models.py index fe6497f..d0eb4b7 100644 --- a/app/plugins/ai/models.py +++ b/app/plugins/ai/models.py @@ -1,6 +1,7 @@ -import google.generativeai as genai +from functools import wraps -from app import Message, extra_config +import google.generativeai as genai +from app import BOT, Message, extra_config async def init_task(): @@ -18,37 +19,44 @@ SAFETY_SETTINGS = [ ] -TEXT_MODEL = genai.GenerativeModel( - model_name="gemini-pro", - generation_config=GENERATION_CONFIG, - safety_settings=SAFETY_SETTINGS, -) - -IMAGE_MODEL = genai.GenerativeModel( - model_name="gemini-pro-vision", - generation_config=GENERATION_CONFIG, - safety_settings=SAFETY_SETTINGS, -) - -MEDIA_MODEL = genai.GenerativeModel( - model_name="models/gemini-1.5-pro-latest", +MODEL = genai.GenerativeModel( + model_name="models/gemini-1.5-flash", generation_config=GENERATION_CONFIG, safety_settings=SAFETY_SETTINGS, ) -async def basic_check(message: Message): - if not extra_config.GEMINI_API_KEY: - await message.reply( - "Gemini API KEY not found." - "\nGet it HERE " - "and set GEMINI_API_KEY var." - ) - return - if not message.input: - await message.reply("Ask a Question.") - return - return 1 +async def run_basic_check(func): + + @wraps(func) + async def wrapper(bot: BOT, message: Message): + + if not extra_config.GEMINI_API_KEY: + await message.reply( + "Gemini API KEY not found." + "\nGet it HERE " + "and set GEMINI_API_KEY var." + ) + return + + if not message.input: + await message.reply("Ask a Question.") + return + + try: + await func(bot, message) + except Exception as e: + + if "User location is not supported for the API use" in str(e): + await message.reply( + "Your server location doesn't allow gemini yet." + "\nIf you are on koyeb change your app region to Washington DC." + ) + return + + raise + + return wrapper def get_response_text(response): diff --git a/app/plugins/ai/text_query.py b/app/plugins/ai/text_query.py index 265e37d..1dd37c1 100644 --- a/app/plugins/ai/text_query.py +++ b/app/plugins/ai/text_query.py @@ -1,16 +1,16 @@ import pickle from io import BytesIO +from app import BOT, Convo, Message, bot +from app.plugins.ai.models import get_response_text, run_basic_check from pyrogram import filters from pyrogram.enums import ParseMode -from app import BOT, Convo, Message, bot -from app.plugins.ai.models import TEXT_MODEL, basic_check, get_response_text - CONVO_CACHE: dict[str, Convo] = {} @bot.add_cmd(cmd="ai") +@run_basic_check async def question(bot: BOT, message: Message): """ CMD: AI @@ -18,9 +18,6 @@ async def question(bot: BOT, message: Message): USAGE: .ai what is the meaning of life. """ - if not await basic_check(message): - return - prompt = message.input response = await TEXT_MODEL.generate_content_async(prompt) @@ -42,6 +39,7 @@ async def question(bot: BOT, message: Message): @bot.add_cmd(cmd="aichat") +@run_basic_check async def ai_chat(bot: BOT, message: Message): """ CMD: AICHAT @@ -52,13 +50,12 @@ async def ai_chat(bot: BOT, message: Message): After 5 mins of Idle bot will export history and stop chat. use .load_history to continue """ - if not await basic_check(message): - return chat = TEXT_MODEL.start_chat(history=[]) await do_convo(chat=chat, message=message) @bot.add_cmd(cmd="load_history") +@run_basic_check async def history_chat(bot: BOT, message: Message): """ CMD: LOAD_HISTORY @@ -66,8 +63,6 @@ async def history_chat(bot: BOT, message: Message): USAGE: .load_history {question} [reply to history document] """ - if not await basic_check(message): - return reply = message.replied if (