update ai models and use a decorator for checks
This commit is contained in:
@@ -7,17 +7,11 @@ import time
|
||||
from io import BytesIO
|
||||
|
||||
import google.generativeai as genai
|
||||
from app import BOT, Message, bot
|
||||
from app.plugins.ai.models import get_response_text, run_basic_check
|
||||
from google.ai import generativelanguage as glm
|
||||
from ub_core.utils import run_shell_cmd
|
||||
|
||||
from app import BOT, Message, bot
|
||||
from app.plugins.ai.models import (
|
||||
IMAGE_MODEL,
|
||||
MEDIA_MODEL,
|
||||
TEXT_MODEL,
|
||||
get_response_text,
|
||||
)
|
||||
|
||||
CODE_EXTS = {
|
||||
".txt",
|
||||
".java",
|
||||
@@ -45,6 +39,7 @@ AUDIO_EXTS = {".aac", ".mp3", ".opus", ".m4a", ".ogg"}
|
||||
|
||||
|
||||
@bot.add_cmd(cmd="ocr")
|
||||
@run_basic_check
|
||||
async def photo_query(bot: BOT, message: Message):
|
||||
"""
|
||||
CMD: OCR
|
||||
@@ -64,6 +59,7 @@ async def photo_query(bot: BOT, message: Message):
|
||||
|
||||
|
||||
@bot.add_cmd(cmd="stt")
|
||||
@run_basic_check
|
||||
async def audio_to_text(bot: BOT, message: Message):
|
||||
"""
|
||||
CMD: STT (Speech To Text)
|
||||
@@ -84,6 +80,7 @@ async def audio_to_text(bot: BOT, message: Message):
|
||||
|
||||
|
||||
@bot.add_cmd(cmd="ocrv")
|
||||
@run_basic_check
|
||||
async def video_to_text(bot: BOT, message: Message):
|
||||
"""
|
||||
CMD: OCRV
|
||||
@@ -103,6 +100,7 @@ async def video_to_text(bot: BOT, message: Message):
|
||||
|
||||
|
||||
@bot.add_cmd(cmd="aim")
|
||||
@run_basic_check
|
||||
async def handle_document(bot: BOT, message: Message):
|
||||
"""
|
||||
CMD: AIM
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import google.generativeai as genai
|
||||
from functools import wraps
|
||||
|
||||
from app import Message, extra_config
|
||||
import google.generativeai as genai
|
||||
from app import BOT, Message, extra_config
|
||||
|
||||
|
||||
async def init_task():
|
||||
@@ -18,37 +19,44 @@ SAFETY_SETTINGS = [
|
||||
]
|
||||
|
||||
|
||||
TEXT_MODEL = genai.GenerativeModel(
|
||||
model_name="gemini-pro",
|
||||
generation_config=GENERATION_CONFIG,
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
)
|
||||
|
||||
IMAGE_MODEL = genai.GenerativeModel(
|
||||
model_name="gemini-pro-vision",
|
||||
generation_config=GENERATION_CONFIG,
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
)
|
||||
|
||||
MEDIA_MODEL = genai.GenerativeModel(
|
||||
model_name="models/gemini-1.5-pro-latest",
|
||||
MODEL = genai.GenerativeModel(
|
||||
model_name="models/gemini-1.5-flash",
|
||||
generation_config=GENERATION_CONFIG,
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
)
|
||||
|
||||
|
||||
async def basic_check(message: Message):
|
||||
if not extra_config.GEMINI_API_KEY:
|
||||
await message.reply(
|
||||
"Gemini API KEY not found."
|
||||
"\nGet it <a href='https://makersuite.google.com/app/apikey'>HERE</a> "
|
||||
"and set GEMINI_API_KEY var."
|
||||
)
|
||||
return
|
||||
if not message.input:
|
||||
await message.reply("Ask a Question.")
|
||||
return
|
||||
return 1
|
||||
async def run_basic_check(func):
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(bot: BOT, message: Message):
|
||||
|
||||
if not extra_config.GEMINI_API_KEY:
|
||||
await message.reply(
|
||||
"Gemini API KEY not found."
|
||||
"\nGet it <a href='https://makersuite.google.com/app/apikey'>HERE</a> "
|
||||
"and set GEMINI_API_KEY var."
|
||||
)
|
||||
return
|
||||
|
||||
if not message.input:
|
||||
await message.reply("Ask a Question.")
|
||||
return
|
||||
|
||||
try:
|
||||
await func(bot, message)
|
||||
except Exception as e:
|
||||
|
||||
if "User location is not supported for the API use" in str(e):
|
||||
await message.reply(
|
||||
"Your server location doesn't allow gemini yet."
|
||||
"\nIf you are on koyeb change your app region to Washington DC."
|
||||
)
|
||||
return
|
||||
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_response_text(response):
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import pickle
|
||||
from io import BytesIO
|
||||
|
||||
from app import BOT, Convo, Message, bot
|
||||
from app.plugins.ai.models import get_response_text, run_basic_check
|
||||
from pyrogram import filters
|
||||
from pyrogram.enums import ParseMode
|
||||
|
||||
from app import BOT, Convo, Message, bot
|
||||
from app.plugins.ai.models import TEXT_MODEL, basic_check, get_response_text
|
||||
|
||||
CONVO_CACHE: dict[str, Convo] = {}
|
||||
|
||||
|
||||
@bot.add_cmd(cmd="ai")
|
||||
@run_basic_check
|
||||
async def question(bot: BOT, message: Message):
|
||||
"""
|
||||
CMD: AI
|
||||
@@ -18,9 +18,6 @@ async def question(bot: BOT, message: Message):
|
||||
USAGE: .ai what is the meaning of life.
|
||||
"""
|
||||
|
||||
if not await basic_check(message):
|
||||
return
|
||||
|
||||
prompt = message.input
|
||||
|
||||
response = await TEXT_MODEL.generate_content_async(prompt)
|
||||
@@ -42,6 +39,7 @@ async def question(bot: BOT, message: Message):
|
||||
|
||||
|
||||
@bot.add_cmd(cmd="aichat")
|
||||
@run_basic_check
|
||||
async def ai_chat(bot: BOT, message: Message):
|
||||
"""
|
||||
CMD: AICHAT
|
||||
@@ -52,13 +50,12 @@ async def ai_chat(bot: BOT, message: Message):
|
||||
After 5 mins of Idle bot will export history and stop chat.
|
||||
use .load_history to continue
|
||||
"""
|
||||
if not await basic_check(message):
|
||||
return
|
||||
chat = TEXT_MODEL.start_chat(history=[])
|
||||
await do_convo(chat=chat, message=message)
|
||||
|
||||
|
||||
@bot.add_cmd(cmd="load_history")
|
||||
@run_basic_check
|
||||
async def history_chat(bot: BOT, message: Message):
|
||||
"""
|
||||
CMD: LOAD_HISTORY
|
||||
@@ -66,8 +63,6 @@ async def history_chat(bot: BOT, message: Message):
|
||||
USAGE:
|
||||
.load_history {question} [reply to history document]
|
||||
"""
|
||||
if not await basic_check(message):
|
||||
return
|
||||
reply = message.replied
|
||||
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user