diff --git a/app/plugins/ai/media_query.py b/app/plugins/ai/media_query.py
index 577f805..1def923 100644
--- a/app/plugins/ai/media_query.py
+++ b/app/plugins/ai/media_query.py
@@ -7,17 +7,11 @@ import time
from io import BytesIO
import google.generativeai as genai
+from app import BOT, Message, bot
+from app.plugins.ai.models import get_response_text, run_basic_check
from google.ai import generativelanguage as glm
from ub_core.utils import run_shell_cmd
-from app import BOT, Message, bot
-from app.plugins.ai.models import (
- IMAGE_MODEL,
- MEDIA_MODEL,
- TEXT_MODEL,
- get_response_text,
-)
-
CODE_EXTS = {
".txt",
".java",
@@ -45,6 +39,7 @@ AUDIO_EXTS = {".aac", ".mp3", ".opus", ".m4a", ".ogg"}
@bot.add_cmd(cmd="ocr")
+@run_basic_check
async def photo_query(bot: BOT, message: Message):
"""
CMD: OCR
@@ -64,6 +59,7 @@ async def photo_query(bot: BOT, message: Message):
@bot.add_cmd(cmd="stt")
+@run_basic_check
async def audio_to_text(bot: BOT, message: Message):
"""
CMD: STT (Speech To Text)
@@ -84,6 +80,7 @@ async def audio_to_text(bot: BOT, message: Message):
@bot.add_cmd(cmd="ocrv")
+@run_basic_check
async def video_to_text(bot: BOT, message: Message):
"""
CMD: OCRV
@@ -103,6 +100,7 @@ async def video_to_text(bot: BOT, message: Message):
@bot.add_cmd(cmd="aim")
+@run_basic_check
async def handle_document(bot: BOT, message: Message):
"""
CMD: AIM
diff --git a/app/plugins/ai/models.py b/app/plugins/ai/models.py
index fe6497f..d0eb4b7 100644
--- a/app/plugins/ai/models.py
+++ b/app/plugins/ai/models.py
@@ -1,6 +1,7 @@
-import google.generativeai as genai
+from functools import wraps
-from app import Message, extra_config
+import google.generativeai as genai
+from app import BOT, Message, extra_config
async def init_task():
@@ -18,37 +19,44 @@ SAFETY_SETTINGS = [
]
-TEXT_MODEL = genai.GenerativeModel(
- model_name="gemini-pro",
- generation_config=GENERATION_CONFIG,
- safety_settings=SAFETY_SETTINGS,
-)
-
-IMAGE_MODEL = genai.GenerativeModel(
- model_name="gemini-pro-vision",
- generation_config=GENERATION_CONFIG,
- safety_settings=SAFETY_SETTINGS,
-)
-
-MEDIA_MODEL = genai.GenerativeModel(
- model_name="models/gemini-1.5-pro-latest",
+MODEL = genai.GenerativeModel(
+ model_name="models/gemini-1.5-flash",
generation_config=GENERATION_CONFIG,
safety_settings=SAFETY_SETTINGS,
)
-async def basic_check(message: Message):
- if not extra_config.GEMINI_API_KEY:
- await message.reply(
- "Gemini API KEY not found."
- "\nGet it HERE "
- "and set GEMINI_API_KEY var."
- )
- return
- if not message.input:
- await message.reply("Ask a Question.")
- return
- return 1
+async def run_basic_check(func):
+
+ @wraps(func)
+ async def wrapper(bot: BOT, message: Message):
+
+ if not extra_config.GEMINI_API_KEY:
+ await message.reply(
+ "Gemini API KEY not found."
+ "\nGet it HERE "
+ "and set GEMINI_API_KEY var."
+ )
+ return
+
+ if not message.input:
+ await message.reply("Ask a Question.")
+ return
+
+ try:
+ await func(bot, message)
+ except Exception as e:
+
+ if "User location is not supported for the API use" in str(e):
+ await message.reply(
+ "Your server location doesn't allow gemini yet."
+ "\nIf you are on koyeb change your app region to Washington DC."
+ )
+ return
+
+ raise
+
+ return wrapper
def get_response_text(response):
diff --git a/app/plugins/ai/text_query.py b/app/plugins/ai/text_query.py
index 265e37d..1dd37c1 100644
--- a/app/plugins/ai/text_query.py
+++ b/app/plugins/ai/text_query.py
@@ -1,16 +1,16 @@
import pickle
from io import BytesIO
+from app import BOT, Convo, Message, bot
+from app.plugins.ai.models import get_response_text, run_basic_check
from pyrogram import filters
from pyrogram.enums import ParseMode
-from app import BOT, Convo, Message, bot
-from app.plugins.ai.models import TEXT_MODEL, basic_check, get_response_text
-
CONVO_CACHE: dict[str, Convo] = {}
@bot.add_cmd(cmd="ai")
+@run_basic_check
async def question(bot: BOT, message: Message):
"""
CMD: AI
@@ -18,9 +18,6 @@ async def question(bot: BOT, message: Message):
USAGE: .ai what is the meaning of life.
"""
- if not await basic_check(message):
- return
-
prompt = message.input
response = await TEXT_MODEL.generate_content_async(prompt)
@@ -42,6 +39,7 @@ async def question(bot: BOT, message: Message):
@bot.add_cmd(cmd="aichat")
+@run_basic_check
async def ai_chat(bot: BOT, message: Message):
"""
CMD: AICHAT
@@ -52,13 +50,12 @@ async def ai_chat(bot: BOT, message: Message):
After 5 mins of Idle bot will export history and stop chat.
use .load_history to continue
"""
- if not await basic_check(message):
- return
chat = TEXT_MODEL.start_chat(history=[])
await do_convo(chat=chat, message=message)
@bot.add_cmd(cmd="load_history")
+@run_basic_check
async def history_chat(bot: BOT, message: Message):
"""
CMD: LOAD_HISTORY
@@ -66,8 +63,6 @@ async def history_chat(bot: BOT, message: Message):
USAGE:
.load_history {question} [reply to history document]
"""
- if not await basic_check(message):
- return
reply = message.replied
if (