Files
Moon-UB-overver/utils/db.py
deepsource-autofix[bot] 23341d79ae refactor: remove unnecessary whitespace
Blank lines should not contain any tabs or spaces.
2024-05-01 22:50:34 +05:30

227 lines
7.9 KiB
Python

# Moon-Userbot - telegram userbot
# Copyright (C) 2020-present Moon Userbot Organization
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import re
import json
import threading
import sqlite3
from dns import resolver
import pymongo
from utils import config
# Replace dnspython's process-wide default resolver with one constructed with
# configure=False (i.e. it does not load the operating system's resolver
# configuration) and pointed at the public DNS server 1.1.1.1.
# NOTE(review): presumably this exists so DNS lookups made on pymongo's behalf
# (e.g. for mongodb+srv:// URLs) work on hosts without a usable system
# resolver config — confirm before changing.
resolver.default_resolver = resolver.Resolver(configure=False)
resolver.default_resolver.nameservers = ["1.1.1.1"]
class Database:
    """Storage interface shared by every database backend.

    Values are addressed by a ``module`` namespace plus a ``variable`` key.
    Backends override each method below; the base implementations only
    raise ``NotImplementedError``.
    """

    def get(self, module: str, variable: str, default=None):
        """Get value from database (``default`` when the key is absent)."""
        raise NotImplementedError()

    def set(self, module: str, variable: str, value):
        """Set key in database, storing ``value`` under ``module``/``variable``."""
        raise NotImplementedError()

    def remove(self, module: str, variable: str):
        """Remove key ``variable`` from ``module``."""
        raise NotImplementedError()

    def get_collection(self, module: str) -> dict:
        """Get every stored ``{variable: value}`` pair of ``module``."""
        raise NotImplementedError()

    def close(self):
        """Close the database and release its resources."""
        raise NotImplementedError()
class MongoDatabase(Database):
    """MongoDB-backed :class:`Database`.

    Each ``module`` maps to a collection of the database ``name``; each
    variable is stored as one document ``{"var": <variable>, "val": <value>}``.
    """

    # Private sentinel so ``get`` can tell whether the legacy
    # ``expected_value`` keyword was supplied at all (``None`` is a valid value).
    _MISSING = object()

    def __init__(self, url, name):
        self._client = pymongo.MongoClient(url)
        self._database = self._client[name]

    def set(self, module: str, variable: str, value):
        """Store ``value`` under ``variable`` in ``module`` (insert or replace).

        Raises:
            ValueError: if ``module`` or ``variable`` is not a string.
        """
        if not isinstance(module, str) or not isinstance(variable, str):
            raise ValueError("Module and variable must be strings")
        self._database[module].replace_one(
            {"var": variable}, {"var": variable, "val": value}, upsert=True
        )

    def get(self, module: str, variable: str, default=None, *, expected_value=_MISSING):
        """Return the value stored under ``variable`` in ``module``.

        Args:
            module: collection name.
            variable: key inside the collection.
            default: returned when the key is absent. Named ``default`` to match
                :class:`Database` and the SQLite backend, so keyword callers
                work with either backend.
            expected_value: legacy keyword alias for ``default``, kept for
                backward compatibility; it wins when supplied.

        Raises:
            ValueError: if ``module`` or ``variable`` is not a string.
        """
        if expected_value is not self._MISSING:
            default = expected_value
        if not isinstance(module, str) or not isinstance(variable, str):
            raise ValueError("Module and variable must be strings")
        doc = self._database[module].find_one({"var": variable})
        return default if doc is None else doc["val"]

    def get_collection(self, module: str) -> dict:
        """Return every ``{variable: value}`` pair stored in ``module``.

        Raises:
            ValueError: if ``module`` is not a string.
        """
        if not isinstance(module, str):
            raise ValueError("Module must be a string")
        return {item["var"]: item["val"] for item in self._database[module].find()}

    def remove(self, module: str, variable: str):
        """Delete ``variable`` from ``module`` (no-op if it is absent).

        Raises:
            ValueError: if ``module`` or ``variable`` is not a string.
        """
        if not isinstance(module, str) or not isinstance(variable, str):
            raise ValueError("Module and variable must be strings")
        self._database[module].delete_one({"var": variable})

    def close(self):
        """Close the underlying MongoDB client."""
        self._client.close()

    def add_chat_history(self, user_id, message):
        """Append ``message`` to the chat history stored for ``user_id``."""
        chat_history = self.get_chat_history(user_id)
        chat_history.append(message)
        self.set(f"core.cohere.user_{user_id}", "chat_history", chat_history)

    def get_chat_history(self, user_id, default=None):
        """Return the stored chat history for ``user_id``.

        ``default`` is returned when nothing is stored yet; when omitted a
        fresh empty list is used (never a shared mutable default).
        """
        if default is None:
            default = []
        return self.get(f"core.cohere.user_{user_id}", "chat_history", default)

    def addaiuser(self, user_id):
        """Add ``user_id`` to the stored chatbot user list if not already present."""
        chatai_users = self.get("core.chatbot", "chatai_users", [])
        if user_id not in chatai_users:
            chatai_users.append(user_id)
            self.set("core.chatbot", "chatai_users", chatai_users)

    def remaiuser(self, user_id):
        """Remove ``user_id`` from the stored chatbot user list if present."""
        chatai_users = self.get("core.chatbot", "chatai_users", [])
        if user_id in chatai_users:
            chatai_users.remove(user_id)
            self.set("core.chatbot", "chatai_users", chatai_users)

    def getaiusers(self):
        """Return the stored chatbot user list (empty list if none)."""
        return self.get("core.chatbot", "chatai_users", [])
class SqliteDatabase(Database):
    """SQLite-backed :class:`Database`.

    Each ``module`` gets its own table (created on first use) with columns
    ``var`` (unique key), ``val`` (text-encoded value) and ``type`` (one of
    ``bool``/``int``/``str``/``json``) so values round-trip with their type.

    The connection is opened with ``check_same_thread=False``; every statement
    together with its fetch/commit runs while holding ``self._lock`` so
    threads cannot interleave on the shared cursor.
    """

    def __init__(self, file):
        self._conn = sqlite3.connect(file, check_same_thread=False)
        self._conn.row_factory = sqlite3.Row
        self._cursor = self._conn.cursor()
        # Re-entrant: public methods hold the lock across _execute() *and* the
        # following fetch/commit, while _execute() acquires it again itself.
        self._lock = threading.RLock()

    @staticmethod
    def _quote(module) -> str:
        """Return ``module`` as a double-quoted SQL identifier with quotes escaped."""
        # Refers to the same tables the old '...'-quoted SQL created, but a
        # module name containing a quote can no longer break out of the name.
        name = str(module).replace('"', '""')
        return f'"{name}"'

    @staticmethod
    def _parse_row(row: sqlite3.Row):
        """Decode a stored ``val`` back to a Python value using its ``type`` tag."""
        if row["type"] == "bool":
            return row["val"] == "1"
        if row["type"] == "int":
            return int(row["val"])
        if row["type"] == "str":
            return row["val"]
        return json.loads(row["val"])

    def _execute(self, module: str, *args, **kwargs) -> sqlite3.Cursor:
        """Run one statement on the shared cursor, creating ``module``'s table if missing.

        Raises:
            ValueError: if ``module`` does not start with ``core`` or ``custom``.
            sqlite3.OperationalError: for errors other than a missing table.
        """
        pattern = r"^(core|custom)"
        if not re.match(pattern, module):
            raise ValueError(f"Invalid module name format: {module}")
        with self._lock:
            try:
                return self._cursor.execute(*args, **kwargs)
            except sqlite3.OperationalError as e:
                if not str(e).startswith("no such table"):
                    raise e from None
                sql = f"""
                CREATE TABLE IF NOT EXISTS {self._quote(module)} (
                    var TEXT UNIQUE NOT NULL,
                    val TEXT NOT NULL,
                    type TEXT NOT NULL
                )
                """
                self._cursor.execute(sql)
                self._conn.commit()
                # Retry the original statement now that the table exists.
                return self._cursor.execute(*args, **kwargs)

    def get(self, module: str, variable: str, default=None):
        """Return the value stored under ``variable`` in ``module``, or ``default``."""
        sql = f"SELECT * FROM {self._quote(module)} WHERE var=:var"
        # Fetch while still holding the lock: the cursor is shared.
        with self._lock:
            row = self._execute(module, sql, {"var": variable}).fetchone()
        if row is None:
            return default
        return self._parse_row(row)

    def set(self, module: str, variable: str, value) -> bool:
        """Insert or update ``variable`` in ``module``; always returns ``True``."""
        sql = f"""
        INSERT INTO {self._quote(module)} VALUES ( :var, :val, :type )
        ON CONFLICT (var) DO
        UPDATE SET val=:val, type=:type WHERE var=:var
        """
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            val = "1" if value else "0"
            typ = "bool"
        elif isinstance(value, str):
            val = value
            typ = "str"
        elif isinstance(value, int):
            val = str(value)
            typ = "int"
        else:
            val = json.dumps(value)
            typ = "json"
        with self._lock:
            self._execute(module, sql, {"var": variable, "val": val, "type": typ})
            self._conn.commit()
        return True

    def remove(self, module: str, variable: str):
        """Delete ``variable`` from ``module`` (no-op if it is absent)."""
        sql = f"DELETE FROM {self._quote(module)} WHERE var=:var"
        with self._lock:
            self._execute(module, sql, {"var": variable})
            self._conn.commit()

    def get_collection(self, module: str) -> dict:
        """Return every ``{variable: value}`` pair stored in ``module``."""
        sql = f"SELECT * FROM {self._quote(module)}"
        # Materialise the rows under the lock instead of iterating the shared
        # cursor after it has been released.
        with self._lock:
            rows = self._execute(module, sql).fetchall()
        return {row["var"]: self._parse_row(row) for row in rows}

    def close(self):
        """Commit pending changes and close the connection."""
        with self._lock:
            self._conn.commit()
            self._conn.close()

    def add_chat_history(self, user_id, message):
        """Append ``message`` to the chat history stored for ``user_id``."""
        chat_history = self.get_chat_history(user_id)
        chat_history.append(message)
        self.set(f"core.cohere.user_{user_id}", "chat_history", chat_history)

    def get_chat_history(self, user_id, default=None):
        """Return the stored chat history for ``user_id``.

        ``default`` is returned when nothing is stored yet; when omitted a
        fresh empty list is used (never a shared mutable default).
        """
        if default is None:
            default = []
        return self.get(f"core.cohere.user_{user_id}", "chat_history", default=default)

    def addaiuser(self, user_id):
        """Add ``user_id`` to the stored chatbot user list if not already present."""
        chatai_users = self.get("core.chatbot", "chatai_users", default=[])
        if user_id not in chatai_users:
            chatai_users.append(user_id)
            self.set("core.chatbot", "chatai_users", chatai_users)

    def remaiuser(self, user_id):
        """Remove ``user_id`` from the stored chatbot user list if present."""
        chatai_users = self.get("core.chatbot", "chatai_users", default=[])
        if user_id in chatai_users:
            chatai_users.remove(user_id)
            self.set("core.chatbot", "chatai_users", chatai_users)

    def getaiusers(self):
        """Return the stored chatbot user list (empty list if none)."""
        return self.get("core.chatbot", "chatai_users", default=[])
# Module-wide database handle: MongoDB when config.db_type names it,
# otherwise a local SQLite file whose path is config.db_name.
db = (
    MongoDatabase(config.db_url, config.db_name)
    if config.db_type in ("mongo", "mongodb")
    else SqliteDatabase(config.db_name)
)