ai: upgrade to new models and package.
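The diff below swaps the deprecated google.generativeai SDK for the new google-genai client package across the AI plugins. As a rough before/after sketch of the client usage being adopted (illustrative only; the API key variable and model names here are placeholders, not part of this commit):

    # old SDK: module-level configuration plus a GenerativeModel instance
    import google.generativeai as genai
    genai.configure(api_key=GEMINI_API_KEY)
    model = genai.GenerativeModel("gemini-1.5-flash")
    response = await model.generate_content_async("hello")

    # new SDK: an explicit Client, with async calls routed through client.aio
    from google import genai
    client = genai.Client(api_key=GEMINI_API_KEY)
    response = await client.aio.models.generate_content(
        model="gemini-2.0-flash", contents="hello"
    )
    print(response.text)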
@@ -3,15 +3,14 @@ import os
 import shutil
 import time
 
-import google.generativeai as genai
 from pyrogram.types.messages_and_media import Audio, Photo, Video, Voice
 from ub_core.utils import get_tg_media_details
 
 from app import Message
-from app.plugins.ai.models import MODEL, get_response_text
+from app.plugins.ai.models import async_client, get_response_text
 
 PROMPT_MAP = {
-    Video: "Summarize the video file",
+    Video: "Summarize video and audio from the file",
     Photo: "Summarize the image file",
     Voice: (
         "\nDo not summarise."
@@ -23,7 +22,7 @@ PROMPT_MAP = {
 PROMPT_MAP[Audio] = PROMPT_MAP[Voice]
 
 
-async def handle_media(prompt: str, media_message: Message, model=MODEL) -> str:
+async def handle_media(prompt: str, media_message: Message, **kwargs) -> str:
     media = get_tg_media_details(media_message)
 
     if getattr(media, "file_size", 0) >= 1048576 * 25:
@@ -34,16 +33,20 @@ async def handle_media(prompt: str, media_message: Message, model=MODEL) -> str:
         )
 
     download_dir = os.path.join("downloads", str(time.time())) + "/"
-    downloaded_file = await media_message.download(download_dir)
+    downloaded_file: str = await media_message.download(download_dir)
 
-    uploaded_file = await asyncio.to_thread(genai.upload_file, downloaded_file)
+    uploaded_file = await async_client.files.upload(
+        file=downloaded_file, config={"mime_type": media.mime_type}
+    )
 
     while uploaded_file.state.name == "PROCESSING":
         await asyncio.sleep(5)
-        uploaded_file = await asyncio.to_thread(genai.get_file, uploaded_file.name)
+        uploaded_file = await async_client.files.get(name=uploaded_file.name)
 
-    response = await model.generate_content_async([prompt, uploaded_file])
-    response_text = get_response_text(response)
+    response = await async_client.models.generate_content(
+        **kwargs, contents=[uploaded_file, prompt]
+    )
+    response_text = get_response_text(response, quoted=True)
 
     shutil.rmtree(download_dir, ignore_errors=True)
     return response_text
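Taken out of diff form, the new upload path in handle_media boils down to the following sketch (assuming a configured async_client; path, mime_type, model_name and config stand in for the real values passed via **kwargs):

    # upload the downloaded media, wait for server-side processing, then prompt against it
    uploaded = await async_client.files.upload(file=path, config={"mime_type": mime_type})
    while uploaded.state.name == "PROCESSING":
        await asyncio.sleep(5)
        uploaded = await async_client.files.get(name=uploaded.name)
    response = await async_client.models.generate_content(
        model=model_name, config=config, contents=[uploaded, prompt]
    )
    text = get_response_text(response, quoted=True)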
@@ -1,41 +1,70 @@
 from functools import wraps
 
-import google.generativeai as genai
+from google.genai.client import AsyncClient, Client
+from google.genai.types import (
+    DynamicRetrievalConfig,
+    GenerateContentConfig,
+    GoogleSearchRetrieval,
+    SafetySetting,
+    Tool,
+)
 from pyrogram import filters
 
 from app import BOT, CustomDB, Message, extra_config
 
-SETTINGS = CustomDB("COMMON_SETTINGS")
+DB_SETTINGS = CustomDB("COMMON_SETTINGS")
 
-GENERATION_CONFIG = {"temperature": 0.69, "max_output_tokens": 2048}
+try:
+    client: Client = Client(api_key=extra_config.GEMINI_API_KEY)
+    async_client: AsyncClient = client.aio
+except:
+    client = async_client = None
 
-SAFETY_SETTINGS = [
-    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
-    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
-    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
-    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
-]
 
-SYSTEM_INSTRUCTION = (
-    "Answer precisely and in short unless specifically instructed otherwise."
-    "\nWhen asked related to code, do not comment the code and do not explain the code unless instructed."
-)
+class Settings:
+    MODEL = "gemini-2.0-flash"
 
-MODEL = genai.GenerativeModel(
-    generation_config=GENERATION_CONFIG,
-    safety_settings=SAFETY_SETTINGS,
-    system_instruction=SYSTEM_INSTRUCTION,
-)
+    # fmt:off
+    CONFIG = GenerateContentConfig(
+
+        system_instruction=(
+            "Answer precisely and in short unless specifically instructed otherwise."
+            "\nWhen asked related to code, do not comment the code and do not explain the code unless instructed."
+        ),
+
+        temperature=0.69,
+
+        max_output_tokens=4000,
+
+        safety_settings=[
+            SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
+            SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
+            SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
+            SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
+        ],
+        # fmt:on
+
+        tools=[
+            Tool(
+                google_search=GoogleSearchRetrieval(
+                    dynamic_retrieval_config=DynamicRetrievalConfig(
+                        dynamic_threshold=0.3
+                    )
+                )
+            )
+        ],
+    )
+
+    @staticmethod
+    def get_kwargs() -> dict:
+        return {"model": Settings.MODEL, "config": Settings.CONFIG}
 
 
 async def init_task():
-    if extra_config.GEMINI_API_KEY:
-        genai.configure(api_key=extra_config.GEMINI_API_KEY)
-
-    model_info = await SETTINGS.find_one({"_id": "gemini_model_info"}) or {}
+    model_info = await DB_SETTINGS.find_one({"_id": "gemini_model_info"}) or {}
     model_name = model_info.get("model_name")
     if model_name:
-        MODEL._model_name = model_name
+        Settings.MODEL = model_name
 
 
 @BOT.add_cmd(cmd="llms")
@@ -46,15 +75,15 @@ async def list_ai_models(bot: BOT, message: Message):
     USAGE: .llms
     """
     model_list = [
-        model.name
-        for model in genai.list_models()
-        if "generateContent" in model.supported_generation_methods
+        model.name.lstrip("models/")
+        async for model in await async_client.models.list(config={"query_base": True})
+        if "generateContent" in model.supported_actions
     ]
 
     model_str = "\n\n".join(model_list)
 
     update_str = (
-        f"\n\nCurrent Model: {MODEL._model_name}"
+        f"\n\nCurrent Model: {Settings.MODEL}"
         "\n\nTo change to a different model,"
         "Reply to this message with the model name."
     )
@@ -63,7 +92,7 @@ async def list_ai_models(bot: BOT, message: Message):
         f"<blockquote expandable=True><pre language=text>{model_str}</pre></blockquote>{update_str}"
     )
 
-    async def resp_filters(_, c, m):
+    async def resp_filters(_, __, m):
         return m.reply_id == model_reply.id
 
     response = await model_reply.get_response(
@@ -80,10 +109,12 @@ async def list_ai_models(bot: BOT, message: Message):
         )
         return
 
-    await SETTINGS.add_data({"_id": "gemini_model_info", "model_name": response.text})
+    await DB_SETTINGS.add_data(
+        {"_id": "gemini_model_info", "model_name": response.text}
+    )
     await model_reply.edit(f"{response.text} saved as model.")
     await model_reply.log()
-    MODEL._model_name = response.text
+    Settings.MODEL = response.text
 
 
 def run_basic_check(function):
@@ -115,5 +146,16 @@ def run_basic_check(function):
     return wrapper
 
 
-def get_response_text(response):
-    return "\n".join([part.text for part in response.parts])
+def get_response_text(response, quoted: bool = False):
+    candidate = response.candidates[0]
+    sources = ""
+
+    if grounding_chunks := candidate.grounding_metadata.grounding_chunks:
+        hrefs = [f"[{chunk.web.title}]({chunk.web.uri})" for chunk in grounding_chunks]
+        sources = "\n\nSources: " + " | ".join(hrefs)
+
+    text = "\n".join([part.text for part in candidate.content.parts])
+
+    final_text = (text.strip() + sources).strip()
+
+    return f"**>\n{final_text}<**" if quoted else final_text
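For context, the Settings class above is consumed through get_kwargs() wherever a generation call is made; a minimal sketch of that pattern (prompt is a placeholder):

    kwargs = Settings.get_kwargs()  # {"model": Settings.MODEL, "config": Settings.CONFIG}
    response = await async_client.models.generate_content(contents=[prompt], **kwargs)
    text = get_response_text(response, quoted=True)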
@@ -6,7 +6,7 @@ import openai
 from pyrogram.enums import ParseMode
 from pyrogram.types import InputMediaPhoto
 
-from app import BOT, LOGGER, Message
+from app import BOT, Message
 from app.plugins.ai.models import SYSTEM_INSTRUCTION
 
 OPENAI_CLIENT = environ.get("OPENAI_CLIENT", "")
@@ -37,14 +37,12 @@ else:
 
 try:
     TEXT_CLIENT = AI_CLIENT(**text_init_kwargs)
-except Exception as e:
-    LOGGER.error(e)
+except:
     TEXT_CLIENT = None
 
 try:
     DALL_E_CLIENT = AI_CLIENT(**image_init_kwargs)
-except Exception as e:
-    LOGGER.error(e)
+except:
     DALL_E_CLIENT = None
 
 
@@ -68,7 +66,7 @@ async def chat_gpt(bot: BOT, message: Message):
         OPENAI_MODEL = your azure model
         AZURE_OPENAI_API_KEY = your api key
         AZURE_OPENAI_ENDPOINT = your azure endpoint
-        AZURE_DEPLOYMENT = your azure deployment
+        AZURE_DEPLOYMENT = your azure deployment
 
     USAGE:
         .gpt hi
@@ -96,7 +94,8 @@ async def chat_gpt(bot: BOT, message: Message):
 
     response = chat_completion.choices[0].message.content
     await message.reply(
-        text=f"```\n{prompt}```**GPT**:\n{response}", parse_mode=ParseMode.MARKDOWN
+        text=f"**>\n{prompt}\n<**\n**GPT**:**>\n{response}\n<**",
+        parse_mode=ParseMode.MARKDOWN,
     )
 
 
@@ -118,7 +117,7 @@ async def chat_gpt(bot: BOT, message: Message):
         DALL_E_API_KEY = your api key
         DALL_E_API_VERSION = your version
         DALL_E_ENDPOINT = your azure endpoint
-        DALL_E_DEPLOYMENT = your azure deployment
+        DALL_E_DEPLOYMENT = your azure deployment
 
     FLAGS:
         -v: for vivid style images (default)
@@ -160,7 +159,7 @@ async def chat_gpt(bot: BOT, message: Message):
     await response.edit_media(
         InputMediaPhoto(
             media=image_io,
-            caption=prompt,
+            caption=f"**>\n{prompt}\n<**",
             has_spoiler="-s" in message.flags,
         )
     )
@@ -1,12 +1,18 @@
 import pickle
 from io import BytesIO
 
+from google.genai.chats import AsyncChat
 from pyrogram import filters
 from pyrogram.enums import ParseMode
 
 from app import BOT, Convo, Message, bot
 from app.plugins.ai.media_query import handle_media
-from app.plugins.ai.models import MODEL, get_response_text, run_basic_check
+from app.plugins.ai.models import (
+    Settings,
+    async_client,
+    get_response_text,
+    run_basic_check,
+)
 
 CONVO_CACHE: dict[str, Convo] = {}
 
@@ -26,27 +32,33 @@ async def question(bot: BOT, message: Message):
         .ai [reply to image | video | gif] [custom prompt]
     """
     reply = message.replied
-    reply_text = reply.text if reply else ""
-    prompt = message.input
 
     if reply and reply.media:
         message_response = await message.reply(
             "<code>Processing... this may take a while.</code>"
         )
+        prompt = message.input
         response_text = await handle_media(
-            prompt=prompt, media_message=reply, model=MODEL
+            prompt=prompt, media_message=reply, **Settings.get_kwargs()
         )
     else:
         message_response = await message.reply(
             "<code>Input received... generating response.</code>"
         )
-        prompt = f"{reply_text}\n\n\n{message.input}".strip()
-        response = await MODEL.generate_content_async(prompt)
-        response_text = get_response_text(response)
+        if reply and reply.text:
+            prompts = [reply.text, message.input]
+        else:
+            prompts = [message.input]
+
+        response = await async_client.models.generate_content(
+            contents=prompts, **Settings.get_kwargs()
+        )
+        response_text = get_response_text(response, quoted=True)
 
     await message_response.edit(
-        text=f"```\n{prompt}```**GEMINI AI**:\n{response_text.strip()}",
+        text=f"**>\n{prompt}<**\n**GEMINI AI**:{response_text}",
         parse_mode=ParseMode.MARKDOWN,
         disable_preview=True,
     )
 
 
@@ -62,7 +74,7 @@ async def ai_chat(bot: BOT, message: Message):
     After 5 mins of Idle bot will export history and stop chat.
     use .load_history to continue
     """
-    chat = MODEL.start_chat(history=[])
+    chat = async_client.chats.create(**Settings.get_kwargs())
    await do_convo(chat=chat, message=message)
 
 
@@ -77,23 +89,27 @@ async def history_chat(bot: BOT, message: Message):
     """
     reply = message.replied
 
+    if not message.input:
+        await message.reply(f"Ask a question along with {message.trigger}{message.cmd}")
+        return
+
     try:
         assert reply.document.file_name == "AI_Chat_History.pkl"
     except (AssertionError, AttributeError):
         await message.reply("Reply to a Valid History file.")
         return
 
-    resp = await message.reply("<i>Loading History...</i>")
+    resp = await message.reply("`Loading History...`")
     doc = await reply.download(in_memory=True)
     doc.seek(0)
 
     history = pickle.load(doc)
-    await resp.edit("<i>History Loaded... Resuming chat</i>")
-    chat = MODEL.start_chat(history=history)
+    await resp.edit("__History Loaded... Resuming chat__")
+    chat = async_client.chats.create(**Settings.get_kwargs(), history=history)
     await do_convo(chat=chat, message=message)
 
 
-async def do_convo(chat, message: Message):
+async def do_convo(chat: AsyncChat, message: Message):
     prompt = message.input
     reply_to_id = message.id
     chat_id = message.chat.id
@@ -115,14 +131,15 @@ async def do_convo(chat, message: Message):
     try:
         async with convo_obj:
             while True:
-                ai_response = await chat.send_message_async(prompt)
-                ai_response_text = get_response_text(ai_response)
-                text = f"**GEMINI AI**:\n\n{ai_response_text}"
+                ai_response = await chat.send_message(prompt)
+                ai_response_text = get_response_text(ai_response, quoted=True)
+                text = f"**GEMINI AI**:\n{ai_response_text}"
                 _, prompt_message = await convo_obj.send_message(
                     text=text,
                     reply_to_id=reply_to_id,
                     parse_mode=ParseMode.MARKDOWN,
                     get_response=True,
+                    disable_preview=True,
                 )
                 prompt, reply_to_id = prompt_message.text, prompt_message.id
 
@@ -147,10 +164,10 @@ def generate_filter(message: Message):
     return filters.create(_filter)
 
 
-async def export_history(chat, message: Message):
-    doc = BytesIO(pickle.dumps(chat.history))
+async def export_history(chat: AsyncChat, message: Message):
+    doc = BytesIO(pickle.dumps(chat._curated_history))
     doc.name = "AI_Chat_History.pkl"
     caption = get_response_text(
-        await chat.send_message_async("Summarize our Conversation into one line.")
+        await chat.send_message("Summarize our Conversation into one line.")
     )
     await bot.send_document(chat_id=message.from_user.id, document=doc, caption=caption)
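The history export/import round trip now goes through the new chat object's curated history, roughly as sketched here (chat is an AsyncChat; _curated_history is the private attribute used above):

    # export: pickle the running chat's history, then rebuild a chat from it later
    blob = pickle.dumps(chat._curated_history)
    restored = async_client.chats.create(**Settings.get_kwargs(), history=pickle.loads(blob))
    reply = await restored.send_message("Continue where we left off.")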