diff --git a/app/plugins/ai/media_query.py b/app/plugins/ai/media_query.py index ff09619..389d1c6 100644 --- a/app/plugins/ai/media_query.py +++ b/app/plugins/ai/media_query.py @@ -11,7 +11,7 @@ from google.ai import generativelanguage as glm from ub_core.utils import run_shell_cmd from app import BOT, Message, bot -from app.plugins.ai.models import get_response_text, run_basic_check +from app.plugins.ai.models import MODEL, get_response_text, run_basic_check CODE_EXTS = { ".txt", @@ -153,7 +153,7 @@ async def handle_audio(prompt: str, message: Message): file_path, download_dir = await download_file(file_name, message) file_response = genai.upload_file(path=file_path) - response = await MEDIA_MODEL.generate_content_async([prompt, file_response]) + response = await MODEL.generate_content_async([prompt, file_response]) response_text = get_response_text(response) genai.delete_file(name=file_response.name) @@ -166,7 +166,7 @@ async def handle_code(prompt: str, message: Message): file: BytesIO = await message.download(in_memory=True) text = file.getvalue().decode("utf-8") final_prompt = f"{text}\n\n{prompt}" - response = await TEXT_MODEL.generate_content_async(final_prompt) + response = await MODEL.generate_content_async(final_prompt) return get_response_text(response) @@ -178,7 +178,7 @@ async def handle_photo(prompt: str, message: Message): mime_type = "image/unknown" image_blob = glm.Blob(mime_type=mime_type, data=file.getvalue()) - response = await IMAGE_MODEL.generate_content_async([prompt, image_blob]) + response = await MODEL.generate_content_async([prompt, image_blob]) return get_response_text(response) @@ -201,7 +201,7 @@ async def handle_video(prompt: str, message: Message): uploaded_frame = await asyncio.to_thread(genai.upload_file, frame) uploaded_frames.append(uploaded_frame) - response = await MEDIA_MODEL.generate_content_async([prompt, *uploaded_frames]) + response = await MODEL.generate_content_async([prompt, *uploaded_frames]) response_text = get_response_text(response) for uploaded_frame in uploaded_frames: diff --git a/app/plugins/ai/text_query.py b/app/plugins/ai/text_query.py index 31da9f8..ab88fb8 100644 --- a/app/plugins/ai/text_query.py +++ b/app/plugins/ai/text_query.py @@ -5,7 +5,7 @@ from pyrogram import filters from pyrogram.enums import ParseMode from app import BOT, Convo, Message, bot -from app.plugins.ai.models import get_response_text, run_basic_check +from app.plugins.ai.models import MODEL, get_response_text, run_basic_check CONVO_CACHE: dict[str, Convo] = {} @@ -21,7 +21,7 @@ async def question(bot: BOT, message: Message): prompt = message.input - response = await TEXT_MODEL.generate_content_async(prompt) + response = await MODEL.generate_content_async(prompt) response_text = get_response_text(response) @@ -51,7 +51,7 @@ async def ai_chat(bot: BOT, message: Message): After 5 mins of Idle bot will export history and stop chat. use .load_history to continue """ - chat = TEXT_MODEL.start_chat(history=[]) + chat = MODEL.start_chat(history=[]) await do_convo(chat=chat, message=message) @@ -79,7 +79,7 @@ async def history_chat(bot: BOT, message: Message): doc: BytesIO = (await reply.download(in_memory=True)).getbuffer() # NOQA history = pickle.loads(doc) await resp.edit("History Loaded... Resuming chat") - chat = TEXT_MODEL.start_chat(history=history) + chat = MODEL.start_chat(history=history) await do_convo(chat=chat, message=message)