This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 15:46:18 +04:00
parent 53f1caf91f
commit 44e9ae243d
32 changed files with 927 additions and 827 deletions

View file

@ -4,6 +4,7 @@ import os
import shutil import shutil
import base64 import base64
import redis import redis
import asyncio
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@ -85,23 +86,28 @@ def load_json_config():
return json.load(file) return json.load(file)
def save_to_db(data): async def asave_to_db(data):
with get_db() as db: async with get_db() as db:
existing_config = db.query(Config).first() existing_config = await db.query(Config).first()
if not existing_config: if not existing_config:
new_config = Config(data=data, version=0) new_config = Config(data=data, version=0)
db.add(new_config) await db.add(new_config)
else: else:
existing_config.data = data existing_config.data = data
existing_config.updated_at = datetime.now() existing_config.updated_at = datetime.now()
db.add(existing_config) await db.add(existing_config)
db.commit() await db.commit()
def reset_config(): def save_to_db(data):
with get_db() as db: loop = asyncio.get_event_loop()
db.query(Config).delete() result = loop.run_until_complete(asave_to_db(data))
db.commit()
async def reset_config():
async with get_db() as db:
await db.query(Config).delete()
await db.commit()
# When initializing, check if config.json exists and migrate it to the database # When initializing, check if config.json exists and migrate it to the database
@ -116,25 +122,14 @@ DEFAULT_CONFIG = {
} }
def get_config(): async def get_config():
with get_db() as db: async with get_db() as db:
config_entry = db.query(Config).order_by(Config.id.desc()).first() config_entry = await db.query(Config).order_by(Config.id.desc()).first()
return config_entry.data if config_entry else DEFAULT_CONFIG return config_entry.data if config_entry else DEFAULT_CONFIG
CONFIG_DATA = get_config() loop = asyncio.get_event_loop()
CONFIG_DATA = loop.run_until_complete(get_config())
def get_config_value(config_path: str):
path_parts = config_path.split(".")
cur_config = CONFIG_DATA
for key in path_parts:
if key in cur_config:
cur_config = cur_config[key]
else:
return None
return cur_config
PERSISTENT_CONFIG_REGISTRY = [] PERSISTENT_CONFIG_REGISTRY = []
@ -162,6 +157,17 @@ ENABLE_PERSISTENT_CONFIG = (
) )
def get_config_value(config_path: str):
path_parts = config_path.split(".")
cur_config = CONFIG_DATA
for key in path_parts:
if key in cur_config:
cur_config = cur_config[key]
else:
return None
return cur_config
class PersistentConfig(Generic[T]): class PersistentConfig(Generic[T]):
def __init__(self, env_name: str, config_path: str, env_value: T): def __init__(self, env_name: str, config_path: str, env_value: T):
self.env_name = env_name self.env_name = env_name

View file

@ -1,7 +1,6 @@
import os import os
import json import json
import logging import logging
from contextlib import contextmanager
from typing import Any, Optional from typing import Any, Optional
from open_webui.internal.wrappers import register_connection from open_webui.internal.wrappers import register_connection
@ -21,6 +20,13 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import QueuePool, NullPool from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.sql.type_api import _T from sqlalchemy.sql.type_api import _T
from sqlalchemy.ext.asyncio import (
create_async_engine,
AsyncSession,
async_sessionmaker,
AsyncAttrs,
)
from contextlib import asynccontextmanager
from typing_extensions import Self from typing_extensions import Self
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -102,7 +108,7 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
conn.execute(f"PRAGMA key = '{database_password}'") conn.execute(f"PRAGMA key = '{database_password}'")
return conn return conn
engine = create_engine( engine = create_async_engine(
"sqlite://", # Dummy URL since we're using creator "sqlite://", # Dummy URL since we're using creator
creator=create_sqlcipher_connection, creator=create_sqlcipher_connection,
echo=False, echo=False,
@ -111,13 +117,13 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
log.info("Connected to encrypted SQLite database using SQLCipher") log.info("Connected to encrypted SQLite database using SQLCipher")
elif "sqlite" in SQLALCHEMY_DATABASE_URL: elif "sqlite" in SQLALCHEMY_DATABASE_URL:
engine = create_engine( engine = create_async_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
) )
else: else:
if isinstance(DATABASE_POOL_SIZE, int): if isinstance(DATABASE_POOL_SIZE, int):
if DATABASE_POOL_SIZE > 0: if DATABASE_POOL_SIZE > 0:
engine = create_engine( engine = create_async_engine(
SQLALCHEMY_DATABASE_URL, SQLALCHEMY_DATABASE_URL,
pool_size=DATABASE_POOL_SIZE, pool_size=DATABASE_POOL_SIZE,
max_overflow=DATABASE_POOL_MAX_OVERFLOW, max_overflow=DATABASE_POOL_MAX_OVERFLOW,
@ -127,27 +133,27 @@ else:
poolclass=QueuePool, poolclass=QueuePool,
) )
else: else:
engine = create_engine( engine = create_async_engine(
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
) )
else: else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) engine = create_async_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
AsyncSessionLocal = async_sessionmaker(
SessionLocal = sessionmaker( bind=engine,
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
) )
metadata_obj = MetaData(schema=DATABASE_SCHEMA) metadata_obj = MetaData(schema=DATABASE_SCHEMA)
Base = declarative_base(metadata=metadata_obj) Base = declarative_base(metadata=metadata_obj)
Session = scoped_session(SessionLocal)
def get_session(): @asynccontextmanager
db = SessionLocal() async def get_db():
try: async with AsyncSessionLocal() as session:
yield db try:
finally: yield session
db.close() finally:
await session.close()
get_db = contextmanager(get_session)

View file

@ -1459,7 +1459,9 @@ async def chat_completion(
if metadata.get("chat_id") and (user and user.role != "admin"): if metadata.get("chat_id") and (user and user.role != "admin"):
if metadata["chat_id"] != "local": if metadata["chat_id"] != "local":
chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id) chat = await Chats.get_chat_by_id_and_user_id(
metadata["chat_id"], user.id
)
if chat is None: if chat is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -1477,7 +1479,7 @@ async def chat_completion(
if metadata.get("chat_id") and metadata.get("message_id"): if metadata.get("chat_id") and metadata.get("message_id"):
# Update the chat message with the error # Update the chat message with the error
try: try:
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
@ -1496,7 +1498,7 @@ async def chat_completion(
response = await chat_completion_handler(request, form_data, user) response = await chat_completion_handler(request, form_data, user)
if metadata.get("chat_id") and metadata.get("message_id"): if metadata.get("chat_id") and metadata.get("message_id"):
try: try:
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
@ -1514,7 +1516,7 @@ async def chat_completion(
if metadata.get("chat_id") and metadata.get("message_id"): if metadata.get("chat_id") and metadata.get("message_id"):
# Update the chat message with the error # Update the chat message with the error
try: try:
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
@ -1593,7 +1595,7 @@ async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user))
async def list_tasks_by_chat_id_endpoint( async def list_tasks_by_chat_id_endpoint(
request: Request, chat_id: str, user=Depends(get_verified_user) request: Request, chat_id: str, user=Depends(get_verified_user)
): ):
chat = Chats.get_chat_by_id(chat_id) chat = await Chats.get_chat_by_id(chat_id)
if chat is None or chat.user_id != user.id: if chat is None or chat.user_id != user.id:
return {"task_ids": []} return {"task_ids": []}
@ -1624,7 +1626,7 @@ async def get_app_config(request: Request):
detail="Invalid token", detail="Invalid token",
) )
if data is not None and "id" in data: if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"]) user = await Users.get_user_by_id(data["id"])
user_count = Users.get_num_users() user_count = Users.get_num_users()
onboarding = False onboarding = False

View file

@ -95,7 +95,7 @@ class AddUserForm(SignupForm):
class AuthsTable: class AuthsTable:
def insert_new_auth( async def insert_new_auth(
self, self,
email: str, email: str,
password: str, password: str,
@ -104,7 +104,7 @@ class AuthsTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db: async with get_db() as db:
log.info("insert_new_auth") log.info("insert_new_auth")
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -115,28 +115,28 @@ class AuthsTable:
result = Auth(**auth.model_dump()) result = Auth(**auth.model_dump())
db.add(result) db.add(result)
user = Users.insert_new_user( user = await Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub id, name, email, profile_image_url, role, oauth_sub
) )
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
if result and user: if result and user:
return user return user
else: else:
return None return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: async def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}") log.info(f"authenticate_user: {email}")
user = Users.get_user_by_email(email) user = await Users.get_user_by_email(email)
if not user: if not user:
return None return None
try: try:
with get_db() as db: async with get_db() as db:
auth = db.query(Auth).filter_by(id=user.id, active=True).first() auth = await db.query(Auth).filter_by(id=user.id, active=True).first()
if auth: if auth:
if verify_password(password, auth.password): if verify_password(password, auth.password):
return user return user
@ -147,58 +147,60 @@ class AuthsTable:
except Exception: except Exception:
return None return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: async def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}") log.info(f"authenticate_user_by_api_key: {api_key}")
# if no api_key, return None # if no api_key, return None
if not api_key: if not api_key:
return None return None
try: try:
user = Users.get_user_by_api_key(api_key) user = await Users.get_user_by_api_key(api_key)
return user if user else None return user if user else None
except Exception: except Exception:
return False return False
def authenticate_user_by_email(self, email: str) -> Optional[UserModel]: async def authenticate_user_by_email(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_email: {email}") log.info(f"authenticate_user_by_email: {email}")
try: try:
with get_db() as db: async with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first() auth = await db.query(Auth).filter_by(email=email, active=True).first()
if auth: if auth:
user = Users.get_user_by_id(auth.id) user = await Users.get_user_by_id(auth.id)
return user return user
except Exception: except Exception:
return None return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool: async def update_user_password_by_id(self, id: str, new_password: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
result = ( result = (
db.query(Auth).filter_by(id=id).update({"password": new_password}) await db.query(Auth)
.filter_by(id=id)
.update({"password": new_password})
) )
db.commit() await db.commit()
return True if result == 1 else False return True if result == 1 else False
except Exception: except Exception:
return False return False
def update_email_by_id(self, id: str, email: str) -> bool: async def update_email_by_id(self, id: str, email: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
result = db.query(Auth).filter_by(id=id).update({"email": email}) result = await db.query(Auth).filter_by(id=id).update({"email": email})
db.commit() await db.commit()
return True if result == 1 else False return True if result == 1 else False
except Exception: except Exception:
return False return False
def delete_auth_by_id(self, id: str) -> bool: async def delete_auth_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
# Delete User # Delete User
result = Users.delete_user_by_id(id) result = await Users.delete_user_by_id(id)
if result: if result:
db.query(Auth).filter_by(id=id).delete() await db.query(Auth).filter_by(id=id).delete()
db.commit() await db.commit()
return True return True
else: else:

View file

@ -66,10 +66,10 @@ class ChannelForm(BaseModel):
class ChannelTable: class ChannelTable:
def insert_new_channel( async def insert_new_channel(
self, type: Optional[str], form_data: ChannelForm, user_id: str self, type: Optional[str], form_data: ChannelForm, user_id: str
) -> Optional[ChannelModel]: ) -> Optional[ChannelModel]:
with get_db() as db: async with get_db() as db:
channel = ChannelModel( channel = ChannelModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -84,19 +84,19 @@ class ChannelTable:
new_channel = Channel(**channel.model_dump()) new_channel = Channel(**channel.model_dump())
db.add(new_channel) await db.add(new_channel)
db.commit() await db.commit()
return channel return channel
def get_channels(self) -> list[ChannelModel]: async def get_channels(self) -> list[ChannelModel]:
with get_db() as db: async with get_db() as db:
channels = db.query(Channel).all() channels = await db.query(Channel).all()
return [ChannelModel.model_validate(channel) for channel in channels] return [ChannelModel.model_validate(channel) for channel in channels]
def get_channels_by_user_id( async def get_channels_by_user_id(
self, user_id: str, permission: str = "read" self, user_id: str, permission: str = "read"
) -> list[ChannelModel]: ) -> list[ChannelModel]:
channels = self.get_channels() channels = await self.get_channels()
return [ return [
channel channel
for channel in channels for channel in channels
@ -104,16 +104,16 @@ class ChannelTable:
or has_access(user_id, permission, channel.access_control) or has_access(user_id, permission, channel.access_control)
] ]
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: async def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
with get_db() as db: async with get_db() as db:
channel = db.query(Channel).filter(Channel.id == id).first() channel = await db.query(Channel).filter(Channel.id == id).first()
return ChannelModel.model_validate(channel) if channel else None return ChannelModel.model_validate(channel) if channel else None
def update_channel_by_id( async def update_channel_by_id(
self, id: str, form_data: ChannelForm self, id: str, form_data: ChannelForm
) -> Optional[ChannelModel]: ) -> Optional[ChannelModel]:
with get_db() as db: async with get_db() as db:
channel = db.query(Channel).filter(Channel.id == id).first() channel = await db.query(Channel).filter(Channel.id == id).first()
if not channel: if not channel:
return None return None
@ -123,13 +123,13 @@ class ChannelTable:
channel.access_control = form_data.access_control channel.access_control = form_data.access_control
channel.updated_at = int(time.time_ns()) channel.updated_at = int(time.time_ns())
db.commit() await db.commit()
return ChannelModel.model_validate(channel) if channel else None return ChannelModel.model_validate(channel) if channel else None
def delete_channel_by_id(self, id: str): async def delete_channel_by_id(self, id: str):
with get_db() as db: async with get_db() as db:
db.query(Channel).filter(Channel.id == id).delete() await db.query(Channel).filter(Channel.id == id).delete()
db.commit() await db.commit()
return True return True

View file

@ -109,8 +109,10 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: async def insert_new_chat(
with get_db() as db: self, user_id: str, form_data: ChatForm
) -> Optional[ChatModel]:
async with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chat = ChatModel( chat = ChatModel(
**{ **{
@ -129,15 +131,15 @@ class ChatTable:
) )
result = Chat(**chat.model_dump()) result = Chat(**chat.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
return ChatModel.model_validate(result) if result else None return ChatModel.model_validate(result) if result else None
def import_chat( async def import_chat(
self, user_id: str, form_data: ChatImportForm self, user_id: str, form_data: ChatImportForm
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
with get_db() as db: async with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chat = ChatModel( chat = ChatModel(
**{ **{
@ -166,82 +168,84 @@ class ChatTable:
) )
result = Chat(**chat.model_dump()) result = Chat(**chat.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
return ChatModel.model_validate(result) if result else None return ChatModel.model_validate(result) if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: async def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try: try:
with get_db() as db: async with get_db() as db:
chat_item = db.get(Chat, id) chat_item = await db.get(Chat, id)
chat_item.chat = chat chat_item.chat = chat
chat_item.title = chat["title"] if "title" in chat else "New Chat" chat_item.title = chat["title"] if "title" in chat else "New Chat"
chat_item.updated_at = int(time.time()) chat_item.updated_at = int(time.time())
db.commit() await db.commit()
db.refresh(chat_item) await db.refresh(chat_item)
return ChatModel.model_validate(chat_item) return ChatModel.model_validate(chat_item)
except Exception: except Exception:
return None return None
def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]: async def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id) chat = await self.get_chat_by_id(id)
if chat is None: if chat is None:
return None return None
chat = chat.chat chat = chat.chat
chat["title"] = title chat["title"] = title
return self.update_chat_by_id(id, chat) return await self.update_chat_by_id(id, chat)
def update_chat_tags_by_id( async def update_chat_tags_by_id(
self, id: str, tags: list[str], user self, id: str, tags: list[str], user
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id) chat = await self.get_chat_by_id(id)
if chat is None: if chat is None:
return None return None
self.delete_all_tags_by_id_and_user_id(id, user.id) await self.delete_all_tags_by_id_and_user_id(id, user.id)
for tag in chat.meta.get("tags", []): for tag in chat.meta.get("tags", []):
if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: if await self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id) await Tags.delete_tag_by_name_and_user_id(tag, user.id)
for tag_name in tags: for tag_name in tags:
if tag_name.lower() == "none": if tag_name.lower() == "none":
continue continue
self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name) await self.add_chat_tag_by_id_and_user_id_and_tag_name(
return self.get_chat_by_id(id) id, user.id, tag_name
)
return await self.get_chat_by_id(id)
def get_chat_title_by_id(self, id: str) -> Optional[str]: async def get_chat_title_by_id(self, id: str) -> Optional[str]:
chat = self.get_chat_by_id(id) chat = await self.get_chat_by_id(id)
if chat is None: if chat is None:
return None return None
return chat.chat.get("title", "New Chat") return chat.chat.get("title", "New Chat")
def get_messages_by_chat_id(self, id: str) -> Optional[dict]: async def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
chat = self.get_chat_by_id(id) chat = await self.get_chat_by_id(id)
if chat is None: if chat is None:
return None return None
return chat.chat.get("history", {}).get("messages", {}) or {} return chat.chat.get("history", {}).get("messages", {}) or {}
def get_message_by_id_and_message_id( async def get_message_by_id_and_message_id(
self, id: str, message_id: str self, id: str, message_id: str
) -> Optional[dict]: ) -> Optional[dict]:
chat = self.get_chat_by_id(id) chat = await self.get_chat_by_id(id)
if chat is None: if chat is None:
return None return None
return chat.chat.get("history", {}).get("messages", {}).get(message_id, {}) return chat.chat.get("history", {}).get("messages", {}).get(message_id, {})
def upsert_message_to_chat_by_id_and_message_id( async def upsert_message_to_chat_by_id_and_message_id(
self, id: str, message_id: str, message: dict self, id: str, message_id: str, message: dict
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id) chat = await self.get_chat_by_id(id)
if chat is None: if chat is None:
return None return None
@ -263,12 +267,12 @@ class ChatTable:
history["currentId"] = message_id history["currentId"] = message_id
chat["history"] = history chat["history"] = history
return self.update_chat_by_id(id, chat) return await self.update_chat_by_id(id, chat)
def add_message_status_to_chat_by_id_and_message_id( async def add_message_status_to_chat_by_id_and_message_id(
self, id: str, message_id: str, status: dict self, id: str, message_id: str, status: dict
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id) chat = await self.get_chat_by_id(id)
if chat is None: if chat is None:
return None return None
@ -281,15 +285,15 @@ class ChatTable:
history["messages"][message_id]["statusHistory"] = status_history history["messages"][message_id]["statusHistory"] = status_history
chat["history"] = history chat["history"] = history
return self.update_chat_by_id(id, chat) return await self.update_chat_by_id(id, chat)
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: async def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db: async with get_db() as db:
# Get the existing chat to share # Get the existing chat to share
chat = db.get(Chat, chat_id) chat = await db.get(Chat, chat_id)
# Check if the chat is already shared # Check if the chat is already shared
if chat.share_id: if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared") return await self.get_chat_by_id_and_user_id(chat.share_id, "shared")
# Create a new chat with the same data, but with a new ID # Create a new chat with the same data, but with a new ID
shared_chat = ChatModel( shared_chat = ChatModel(
**{ **{
@ -305,29 +309,30 @@ class ChatTable:
} }
) )
shared_result = Chat(**shared_chat.model_dump()) shared_result = Chat(**shared_chat.model_dump())
db.add(shared_result) await db.add(shared_result)
db.commit() await db.commit()
db.refresh(shared_result) await db.refresh(shared_result)
# Update the original chat with the share_id # Update the original chat with the share_id
result = ( result = (
db.query(Chat) await db.query(Chat)
.filter_by(id=chat_id) .filter_by(id=chat_id)
.update({"share_id": shared_chat.id}) .update({"share_id": shared_chat.id})
) )
db.commit() await db.commit()
return shared_chat if (shared_result and result) else None return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: async def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: async with get_db() as db:
chat = db.get(Chat, chat_id) chat = await db.get(Chat, chat_id)
shared_chat = ( shared_chat = (
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first() await db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first()
) )
if shared_chat is None: if shared_chat is None:
return self.insert_shared_chat_by_chat_id(chat_id) return await self.insert_shared_chat_by_chat_id(chat_id)
shared_chat.title = chat.title shared_chat.title = chat.title
shared_chat.chat = chat.chat shared_chat.chat = chat.chat
@ -335,70 +340,72 @@ class ChatTable:
shared_chat.pinned = chat.pinned shared_chat.pinned = chat.pinned
shared_chat.folder_id = chat.folder_id shared_chat.folder_id = chat.folder_id
shared_chat.updated_at = int(time.time()) shared_chat.updated_at = int(time.time())
db.commit() await db.commit()
db.refresh(shared_chat) await db.refresh(shared_chat)
return ChatModel.model_validate(shared_chat) return ChatModel.model_validate(shared_chat)
except Exception: except Exception:
return None return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: async def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() await db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def update_chat_share_id_by_id( async def update_chat_share_id_by_id(
self, id: str, share_id: Optional[str] self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
with get_db() as db: async with get_db() as db:
chat = db.get(Chat, id) chat = await db.get(Chat, id)
chat.share_id = share_id chat.share_id = share_id
db.commit() await db.commit()
db.refresh(chat) await db.refresh(chat)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except Exception: except Exception:
return None return None
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]: async def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: async with get_db() as db:
chat = db.get(Chat, id) chat = await db.get(Chat, id)
chat.pinned = not chat.pinned chat.pinned = not chat.pinned
chat.updated_at = int(time.time()) chat.updated_at = int(time.time())
db.commit() await db.commit()
db.refresh(chat) await db.refresh(chat)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except Exception: except Exception:
return None return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: async def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: async with get_db() as db:
chat = db.get(Chat, id) chat = await db.get(Chat, id)
chat.archived = not chat.archived chat.archived = not chat.archived
chat.updated_at = int(time.time()) chat.updated_at = int(time.time())
db.commit() await db.commit()
db.refresh(chat) await db.refresh(chat)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except Exception: except Exception:
return None return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool: async def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) await db.query(Chat).filter_by(user_id=user_id).update(
db.commit() {"archived": True}
)
await db.commit()
return True return True
except Exception: except Exception:
return False return False
def get_archived_chat_list_by_user_id( async def get_archived_chat_list_by_user_id(
self, self,
user_id: str, user_id: str,
filter: Optional[dict] = None, filter: Optional[dict] = None,
@ -406,7 +413,7 @@ class ChatTable:
limit: int = 50, limit: int = 50,
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: async with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id, archived=True) query = db.query(Chat).filter_by(user_id=user_id, archived=True)
if filter: if filter:
@ -432,10 +439,10 @@ class ChatTable:
if limit: if limit:
query = query.limit(limit) query = query.limit(limit)
all_chats = query.all() all_chats = await query.all()
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_user_id( async def get_chat_list_by_user_id(
self, self,
user_id: str, user_id: str,
include_archived: bool = False, include_archived: bool = False,
@ -443,7 +450,7 @@ class ChatTable:
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: async with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id)
if not include_archived: if not include_archived:
query = query.filter_by(archived=False) query = query.filter_by(archived=False)
@ -471,17 +478,17 @@ class ChatTable:
if limit: if limit:
query = query.limit(limit) query = query.limit(limit)
all_chats = query.all() all_chats = await query.all()
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_title_id_list_by_user_id( async def get_chat_title_id_list_by_user_id(
self, self,
user_id: str, user_id: str,
include_archived: bool = False, include_archived: bool = False,
skip: Optional[int] = None, skip: Optional[int] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]: ) -> list[ChatTitleIdResponse]:
with get_db() as db: async with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None) query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
@ -497,7 +504,7 @@ class ChatTable:
if limit: if limit:
query = query.limit(limit) query = query.limit(limit)
all_chats = query.all() all_chats = await query.all()
# result has to be destructured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass. # result has to be destructured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass.
return [ return [
@ -512,12 +519,12 @@ class ChatTable:
for chat in all_chats for chat in all_chats
] ]
def get_chat_list_by_chat_ids( async def get_chat_list_by_chat_ids(
self, chat_ids: list[str], skip: int = 0, limit: int = 50 self, chat_ids: list[str], skip: int = 0, limit: int = 50
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: async with get_db() as db:
all_chats = ( all_chats = (
db.query(Chat) await db.query(Chat)
.filter(Chat.id.in_(chat_ids)) .filter(Chat.id.in_(chat_ids))
.filter_by(archived=False) .filter_by(archived=False)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
@ -525,73 +532,75 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]: async def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: async with get_db() as db:
chat = db.get(Chat, id) chat = await db.get(Chat, id)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except Exception: except Exception:
return None return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: async def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: async with get_db() as db:
# it is possible that the shared link was deleted. hence, # it is possible that the shared link was deleted. hence,
# we check if the chat is still shared by checking if a chat with the share_id exists # we check if the chat is still shared by checking if a chat with the share_id exists
chat = db.query(Chat).filter_by(share_id=id).first() chat = await db.query(Chat).filter_by(share_id=id).first()
if chat: if chat:
return self.get_chat_by_id(id) return await self.get_chat_by_id(id)
else: else:
return None return None
except Exception: except Exception:
return None return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: async def get_chat_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[ChatModel]:
try: try:
with get_db() as db: async with get_db() as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() chat = await db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except Exception: except Exception:
return None return None
def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]: async def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
with get_db() as db: async with get_db() as db:
all_chats = ( all_chats = (
db.query(Chat) await db.query(Chat)
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]: async def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db: async with get_db() as db:
all_chats = ( all_chats = (
db.query(Chat) await db.query(Chat)
.filter_by(user_id=user_id) .filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]: async def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db: async with get_db() as db:
all_chats = ( all_chats = (
db.query(Chat) await db.query(Chat)
.filter_by(user_id=user_id, pinned=True, archived=False) .filter_by(user_id=user_id, pinned=True, archived=False)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: async def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db: async with get_db() as db:
all_chats = ( all_chats = (
db.query(Chat) await db.query(Chat)
.filter_by(user_id=user_id, archived=True) .filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id_and_search_text( async def get_chats_by_user_id_and_search_text(
self, self,
user_id: str, user_id: str,
search_text: str, search_text: str,
@ -605,7 +614,7 @@ class ChatTable:
search_text = search_text.replace("\u0000", "").lower().strip() search_text = search_text.replace("\u0000", "").lower().strip()
if not search_text: if not search_text:
return self.get_chat_list_by_user_id( return await self.get_chat_list_by_user_id(
user_id, include_archived, filter={}, skip=skip, limit=limit user_id, include_archived, filter={}, skip=skip, limit=limit
) )
@ -619,7 +628,7 @@ class ChatTable:
] ]
# Extract folder names - handle spaces and case insensitivity # Extract folder names - handle spaces and case insensitivity
folders = Folders.search_folders_by_names( folders = await Folders.search_folders_by_names(
user_id, user_id,
[ [
word.replace("folder:", "") word.replace("folder:", "")
@ -661,7 +670,7 @@ class ChatTable:
search_text = " ".join(search_text_words) search_text = " ".join(search_text_words)
with get_db() as db: async with get_db() as db:
query = db.query(Chat).filter(Chat.user_id == user_id) query = db.query(Chat).filter(Chat.user_id == user_id)
if is_archived is not None: if is_archived is not None:
@ -783,30 +792,30 @@ class ChatTable:
) )
# Perform pagination at the SQL level # Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all() all_chats = await query.offset(skip).limit(limit).all()
log.info(f"The number of chats: {len(all_chats)}") log.info(f"The number of chats: {len(all_chats)}")
# Validate and return chats # Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_id_and_user_id( async def get_chats_by_folder_id_and_user_id(
self, folder_id: str, user_id: str self, folder_id: str, user_id: str
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: async with get_db() as db:
query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id) query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
query = query.filter_by(archived=False) query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc()) query = query.order_by(Chat.updated_at.desc())
all_chats = query.all() all_chats = await query.all()
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_ids_and_user_id( async def get_chats_by_folder_ids_and_user_id(
self, folder_ids: list[str], user_id: str self, folder_ids: list[str], user_id: str
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: async with get_db() as db:
query = db.query(Chat).filter( query = db.query(Chat).filter(
Chat.folder_id.in_(folder_ids), Chat.user_id == user_id Chat.folder_id.in_(folder_ids), Chat.user_id == user_id
) )
@ -815,34 +824,38 @@ class ChatTable:
query = query.order_by(Chat.updated_at.desc()) query = query.order_by(Chat.updated_at.desc())
all_chats = query.all() all_chats = await query.all()
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def update_chat_folder_id_by_id_and_user_id( async def update_chat_folder_id_by_id_and_user_id(
self, id: str, user_id: str, folder_id: str self, id: str, user_id: str, folder_id: str
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
with get_db() as db: async with get_db() as db:
chat = db.get(Chat, id) chat = await db.get(Chat, id)
chat.folder_id = folder_id chat.folder_id = folder_id
chat.updated_at = int(time.time()) chat.updated_at = int(time.time())
chat.pinned = False chat.pinned = False
db.commit() await db.commit()
db.refresh(chat) await db.refresh(chat)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except Exception: except Exception:
return None return None
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]: async def get_chat_tags_by_id_and_user_id(
with get_db() as db: self, id: str, user_id: str
chat = db.get(Chat, id) ) -> list[TagModel]:
async with get_db() as db:
chat = await db.get(Chat, id)
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags] return [
await Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags
]
def get_chat_list_by_user_id_and_tag_name( async def get_chat_list_by_user_id_and_tag_name(
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50 self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: async with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id)
tag_id = tag_name.replace(" ", "_").lower() tag_id = tag_name.replace(" ", "_").lower()
@ -866,19 +879,19 @@ class ChatTable:
f"Unsupported dialect: {db.bind.dialect.name}" f"Unsupported dialect: {db.bind.dialect.name}"
) )
all_chats = query.all() all_chats = await query.all()
log.debug(f"all_chats: {all_chats}") log.debug(f"all_chats: {all_chats}")
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def add_chat_tag_by_id_and_user_id_and_tag_name( async def add_chat_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str self, id: str, user_id: str, tag_name: str
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id) tag = await Tags.get_tag_by_name_and_user_id(tag_name, user_id)
if tag is None: if tag is None:
tag = Tags.insert_new_tag(tag_name, user_id) tag = await Tags.insert_new_tag(tag_name, user_id)
try: try:
with get_db() as db: async with get_db() as db:
chat = db.get(Chat, id) chat = await db.get(Chat, id)
tag_id = tag.id tag_id = tag.id
if tag_id not in chat.meta.get("tags", []): if tag_id not in chat.meta.get("tags", []):
@ -887,14 +900,16 @@ class ChatTable:
"tags": list(set(chat.meta.get("tags", []) + [tag_id])), "tags": list(set(chat.meta.get("tags", []) + [tag_id])),
} }
db.commit() await db.commit()
db.refresh(chat) await db.refresh(chat)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except Exception: except Exception:
return None return None
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int: async def count_chats_by_tag_name_and_user_id(
with get_db() as db: # Assuming `get_db()` returns a session object self, tag_name: str, user_id: str
) -> int:
async with get_db() as db: # Assuming `get_db()` returns a session object
query = db.query(Chat).filter_by(user_id=user_id, archived=False) query = db.query(Chat).filter_by(user_id=user_id, archived=False)
# Normalize the tag_name for consistency # Normalize the tag_name for consistency
@ -922,19 +937,19 @@ class ChatTable:
) )
# Get the count of matching records # Get the count of matching records
count = query.count() count = await query.count()
# Debugging output for inspection # Debugging output for inspection
log.info(f"Count of chats for tag '{tag_name}': {count}") log.info(f"Count of chats for tag '{tag_name}': {count}")
return count return count
def delete_tag_by_id_and_user_id_and_tag_name( async def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str self, id: str, user_id: str, tag_name: str
) -> bool: ) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
chat = db.get(Chat, id) chat = await db.get(Chat, id)
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
tag_id = tag_name.replace(" ", "_").lower() tag_id = tag_name.replace(" ", "_").lower()
@ -943,77 +958,79 @@ class ChatTable:
**chat.meta, **chat.meta,
"tags": list(set(tags)), "tags": list(set(tags)),
} }
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool: async def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
chat = db.get(Chat, id) chat = await db.get(Chat, id)
chat.meta = { chat.meta = {
**chat.meta, **chat.meta,
"tags": [], "tags": [],
} }
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def delete_chat_by_id(self, id: str) -> bool: async def delete_chat_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Chat).filter_by(id=id).delete() await db.query(Chat).filter_by(id=id).delete()
db.commit() await db.commit()
return True and self.delete_shared_chat_by_chat_id(id) return True and await self.delete_shared_chat_by_chat_id(id)
except Exception: except Exception:
return False return False
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: async def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Chat).filter_by(id=id, user_id=user_id).delete() await db.query(Chat).filter_by(id=id, user_id=user_id).delete()
db.commit() await db.commit()
return True and self.delete_shared_chat_by_chat_id(id) return True and await self.delete_shared_chat_by_chat_id(id)
except Exception: except Exception:
return False return False
def delete_chats_by_user_id(self, user_id: str) -> bool: async def delete_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
self.delete_shared_chats_by_user_id(user_id) await self.delete_shared_chats_by_user_id(user_id)
db.query(Chat).filter_by(user_id=user_id).delete() await db.query(Chat).filter_by(user_id=user_id).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def delete_chats_by_user_id_and_folder_id( async def delete_chats_by_user_id_and_folder_id(
self, user_id: str, folder_id: str self, user_id: str, folder_id: str
) -> bool: ) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete() await db.query(Chat).filter_by(
db.commit() user_id=user_id, folder_id=folder_id
).delete()
await db.commit()
return True return True
except Exception: except Exception:
return False return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool: async def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() chats_by_user = await db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() await db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:

View file

@ -93,10 +93,10 @@ class FeedbackForm(BaseModel):
class FeedbackTable: class FeedbackTable:
def insert_new_feedback( async def insert_new_feedback(
self, user_id: str, form_data: FeedbackForm self, user_id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]: ) -> Optional[FeedbackModel]:
with get_db() as db: async with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
feedback = FeedbackModel( feedback = FeedbackModel(
**{ **{
@ -110,9 +110,9 @@ class FeedbackTable:
) )
try: try:
result = Feedback(**feedback.model_dump()) result = Feedback(**feedback.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
if result: if result:
return FeedbackModel.model_validate(result) return FeedbackModel.model_validate(result)
else: else:
@ -121,62 +121,64 @@ class FeedbackTable:
log.exception(f"Error creating a new feedback: {e}") log.exception(f"Error creating a new feedback: {e}")
return None return None
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]: async def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:
try: try:
with get_db() as db: async with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first() feedback = await db.query(Feedback).filter_by(id=id).first()
if not feedback: if not feedback:
return None return None
return FeedbackModel.model_validate(feedback) return FeedbackModel.model_validate(feedback)
except Exception: except Exception:
return None return None
def get_feedback_by_id_and_user_id( async def get_feedback_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str
) -> Optional[FeedbackModel]: ) -> Optional[FeedbackModel]:
try: try:
with get_db() as db: async with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() feedback = (
await db.query(Feedback).filter_by(id=id, user_id=user_id).first()
)
if not feedback: if not feedback:
return None return None
return FeedbackModel.model_validate(feedback) return FeedbackModel.model_validate(feedback)
except Exception: except Exception:
return None return None
def get_all_feedbacks(self) -> list[FeedbackModel]: async def get_all_feedbacks(self) -> list[FeedbackModel]:
with get_db() as db: async with get_db() as db:
return [ return [
FeedbackModel.model_validate(feedback) FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback) for feedback in await db.query(Feedback)
.order_by(Feedback.updated_at.desc()) .order_by(Feedback.updated_at.desc())
.all() .all()
] ]
def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]: async def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]:
with get_db() as db: async with get_db() as db:
return [ return [
FeedbackModel.model_validate(feedback) FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback) for feedback in await db.query(Feedback)
.filter_by(type=type) .filter_by(type=type)
.order_by(Feedback.updated_at.desc()) .order_by(Feedback.updated_at.desc())
.all() .all()
] ]
def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]: async def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]:
with get_db() as db: async with get_db() as db:
return [ return [
FeedbackModel.model_validate(feedback) FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback) for feedback in await db.query(Feedback)
.filter_by(user_id=user_id) .filter_by(user_id=user_id)
.order_by(Feedback.updated_at.desc()) .order_by(Feedback.updated_at.desc())
.all() .all()
] ]
def update_feedback_by_id( async def update_feedback_by_id(
self, id: str, form_data: FeedbackForm self, id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]: ) -> Optional[FeedbackModel]:
with get_db() as db: async with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first() feedback = await db.query(Feedback).filter_by(id=id).first()
if not feedback: if not feedback:
return None return None
@ -189,14 +191,16 @@ class FeedbackTable:
feedback.updated_at = int(time.time()) feedback.updated_at = int(time.time())
db.commit() await db.commit()
return FeedbackModel.model_validate(feedback) return FeedbackModel.model_validate(feedback)
def update_feedback_by_id_and_user_id( async def update_feedback_by_id_and_user_id(
self, id: str, user_id: str, form_data: FeedbackForm self, id: str, user_id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]: ) -> Optional[FeedbackModel]:
with get_db() as db: async with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() feedback = (
await db.query(Feedback).filter_by(id=id, user_id=user_id).first()
)
if not feedback: if not feedback:
return None return None
@ -209,45 +213,47 @@ class FeedbackTable:
feedback.updated_at = int(time.time()) feedback.updated_at = int(time.time())
db.commit() await db.commit()
return FeedbackModel.model_validate(feedback) return FeedbackModel.model_validate(feedback)
def delete_feedback_by_id(self, id: str) -> bool: async def delete_feedback_by_id(self, id: str) -> bool:
with get_db() as db: async with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first() feedback = await db.query(Feedback).filter_by(id=id).first()
if not feedback: if not feedback:
return False return False
db.delete(feedback) await db.delete(feedback)
db.commit() await db.commit()
return True return True
def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool: async def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db: async with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() feedback = (
await db.query(Feedback).filter_by(id=id, user_id=user_id).first()
)
if not feedback: if not feedback:
return False return False
db.delete(feedback) await db.delete(feedback)
db.commit() await db.commit()
return True return True
def delete_feedbacks_by_user_id(self, user_id: str) -> bool: async def delete_feedbacks_by_user_id(self, user_id: str) -> bool:
with get_db() as db: async with get_db() as db:
feedbacks = db.query(Feedback).filter_by(user_id=user_id).all() feedbacks = await db.query(Feedback).filter_by(user_id=user_id).all()
if not feedbacks: if not feedbacks:
return False return False
for feedback in feedbacks: for feedback in feedbacks:
db.delete(feedback) await db.delete(feedback)
db.commit() await db.commit()
return True return True
def delete_all_feedbacks(self) -> bool: async def delete_all_feedbacks(self) -> bool:
with get_db() as db: async with get_db() as db:
feedbacks = db.query(Feedback).all() feedbacks = await db.query(Feedback).all()
if not feedbacks: if not feedbacks:
return False return False
for feedback in feedbacks: for feedback in feedbacks:
db.delete(feedback) await db.delete(feedback)
db.commit() await db.commit()
return True return True

View file

@ -98,8 +98,10 @@ class FileForm(BaseModel):
class FilesTable: class FilesTable:
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: async def insert_new_file(
with get_db() as db: self, user_id: str, form_data: FileForm
) -> Optional[FileModel]:
async with get_db() as db:
file = FileModel( file = FileModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -111,9 +113,9 @@ class FilesTable:
try: try:
result = File(**file.model_dump()) result = File(**file.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
if result: if result:
return FileModel.model_validate(result) return FileModel.model_validate(result)
else: else:
@ -122,18 +124,18 @@ class FilesTable:
log.exception(f"Error inserting a new file: {e}") log.exception(f"Error inserting a new file: {e}")
return None return None
def get_file_by_id(self, id: str) -> Optional[FileModel]: async def get_file_by_id(self, id: str) -> Optional[FileModel]:
with get_db() as db: async with get_db() as db:
try: try:
file = db.get(File, id) file = await db.get(File, id)
return FileModel.model_validate(file) return FileModel.model_validate(file)
except Exception: except Exception:
return None return None
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: async def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
with get_db() as db: async with get_db() as db:
try: try:
file = db.get(File, id) file = await db.get(File, id)
return FileMetadataResponse( return FileMetadataResponse(
id=file.id, id=file.id,
meta=file.meta, meta=file.meta,
@ -143,22 +145,26 @@ class FilesTable:
except Exception: except Exception:
return None return None
def get_files(self) -> list[FileModel]: async def get_files(self) -> list[FileModel]:
with get_db() as db: async with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()] return [
FileModel.model_validate(file) for file in await db.query(File).all()
]
def get_files_by_ids(self, ids: list[str]) -> list[FileModel]: async def get_files_by_ids(self, ids: list[str]) -> list[FileModel]:
with get_db() as db: async with get_db() as db:
return [ return [
FileModel.model_validate(file) FileModel.model_validate(file)
for file in db.query(File) for file in await db.query(File)
.filter(File.id.in_(ids)) .filter(File.id.in_(ids))
.order_by(File.updated_at.desc()) .order_by(File.updated_at.desc())
.all() .all()
] ]
def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]: async def get_file_metadatas_by_ids(
with get_db() as db: self, ids: list[str]
) -> list[FileMetadataResponse]:
async with get_db() as db:
return [ return [
FileMetadataResponse( FileMetadataResponse(
id=file.id, id=file.id,
@ -166,66 +172,68 @@ class FilesTable:
created_at=file.created_at, created_at=file.created_at,
updated_at=file.updated_at, updated_at=file.updated_at,
) )
for file in db.query(File) for file in await db.query(File)
.filter(File.id.in_(ids)) .filter(File.id.in_(ids))
.order_by(File.updated_at.desc()) .order_by(File.updated_at.desc())
.all() .all()
] ]
def get_files_by_user_id(self, user_id: str) -> list[FileModel]: async def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
with get_db() as db: async with get_db() as db:
return [ return [
FileModel.model_validate(file) FileModel.model_validate(file)
for file in db.query(File).filter_by(user_id=user_id).all() for file in await db.query(File).filter_by(user_id=user_id).all()
] ]
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]: async def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
with get_db() as db: async with get_db() as db:
try: try:
file = db.query(File).filter_by(id=id).first() file = await db.query(File).filter_by(id=id).first()
file.hash = hash file.hash = hash
db.commit() await db.commit()
return FileModel.model_validate(file) return FileModel.model_validate(file)
except Exception: except Exception:
return None return None
def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]: async def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]:
with get_db() as db: async with get_db() as db:
try: try:
file = db.query(File).filter_by(id=id).first() file = await db.query(File).filter_by(id=id).first()
file.data = {**(file.data if file.data else {}), **data} file.data = {**(file.data if file.data else {}), **data}
db.commit() await db.commit()
return FileModel.model_validate(file) return FileModel.model_validate(file)
except Exception as e: except Exception as e:
return None return None
def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]: async def update_file_metadata_by_id(
with get_db() as db: self, id: str, meta: dict
) -> Optional[FileModel]:
async with get_db() as db:
try: try:
file = db.query(File).filter_by(id=id).first() file = await db.query(File).filter_by(id=id).first()
file.meta = {**(file.meta if file.meta else {}), **meta} file.meta = {**(file.meta if file.meta else {}), **meta}
db.commit() await db.commit()
return FileModel.model_validate(file) return FileModel.model_validate(file)
except Exception: except Exception:
return None return None
def delete_file_by_id(self, id: str) -> bool: async def delete_file_by_id(self, id: str) -> bool:
with get_db() as db: async with get_db() as db:
try: try:
db.query(File).filter_by(id=id).delete() await db.query(File).filter_by(id=id).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def delete_all_files(self) -> bool: async def delete_all_files(self) -> bool:
with get_db() as db: async with get_db() as db:
try: try:
db.query(File).delete() await db.query(File).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:

View file

@ -62,10 +62,10 @@ class FolderForm(BaseModel):
class FolderTable: class FolderTable:
def insert_new_folder( async def insert_new_folder(
self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
with get_db() as db: async with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
folder = FolderModel( folder = FolderModel(
**{ **{
@ -79,9 +79,9 @@ class FolderTable:
) )
try: try:
result = Folder(**folder.model_dump()) result = Folder(**folder.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
if result: if result:
return FolderModel.model_validate(result) return FolderModel.model_validate(result)
else: else:
@ -90,12 +90,14 @@ class FolderTable:
log.exception(f"Error inserting a new folder: {e}") log.exception(f"Error inserting a new folder: {e}")
return None return None
def get_folder_by_id_and_user_id( async def get_folder_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: async with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = (
await db.query(Folder).filter_by(id=id, user_id=user_id).first()
)
if not folder: if not folder:
return None return None
@ -104,45 +106,47 @@ class FolderTable:
except Exception: except Exception:
return None return None
def get_children_folders_by_id_and_user_id( async def get_children_folders_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str
) -> Optional[list[FolderModel]]: ) -> Optional[list[FolderModel]]:
try: try:
with get_db() as db: async with get_db() as db:
folders = [] folders = []
def get_children(folder): async def get_children(folder):
children = self.get_folders_by_parent_id_and_user_id( children = await self.get_folders_by_parent_id_and_user_id(
folder.id, user_id folder.id, user_id
) )
for child in children: for child in children:
get_children(child) await get_children(child)
folders.append(child) folders.append(child)
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = (
await db.query(Folder).filter_by(id=id, user_id=user_id).first()
)
if not folder: if not folder:
return None return None
get_children(folder) await get_children(folder)
return folders return folders
except Exception: except Exception:
return None return None
def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]: async def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]:
with get_db() as db: async with get_db() as db:
return [ return [
FolderModel.model_validate(folder) FolderModel.model_validate(folder)
for folder in db.query(Folder).filter_by(user_id=user_id).all() for folder in await db.query(Folder).filter_by(user_id=user_id).all()
] ]
def get_folder_by_parent_id_and_user_id_and_name( async def get_folder_by_parent_id_and_user_id_and_name(
self, parent_id: Optional[str], user_id: str, name: str self, parent_id: Optional[str], user_id: str, name: str
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: async with get_db() as db:
# Check if folder exists # Check if folder exists
folder = ( folder = (
db.query(Folder) await db.query(Folder)
.filter_by(parent_id=parent_id, user_id=user_id) .filter_by(parent_id=parent_id, user_id=user_id)
.filter(Folder.name.ilike(name)) .filter(Folder.name.ilike(name))
.first() .first()
@ -156,26 +160,28 @@ class FolderTable:
log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}") log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}")
return None return None
def get_folders_by_parent_id_and_user_id( async def get_folders_by_parent_id_and_user_id(
self, parent_id: Optional[str], user_id: str self, parent_id: Optional[str], user_id: str
) -> list[FolderModel]: ) -> list[FolderModel]:
with get_db() as db: async with get_db() as db:
return [ return [
FolderModel.model_validate(folder) FolderModel.model_validate(folder)
for folder in db.query(Folder) for folder in await db.query(Folder)
.filter_by(parent_id=parent_id, user_id=user_id) .filter_by(parent_id=parent_id, user_id=user_id)
.all() .all()
] ]
def update_folder_parent_id_by_id_and_user_id( async def update_folder_parent_id_by_id_and_user_id(
self, self,
id: str, id: str,
user_id: str, user_id: str,
parent_id: str, parent_id: str,
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: async with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = (
await db.query(Folder).filter_by(id=id, user_id=user_id).first()
)
if not folder: if not folder:
return None return None
@ -183,19 +189,21 @@ class FolderTable:
folder.parent_id = parent_id folder.parent_id = parent_id
folder.updated_at = int(time.time()) folder.updated_at = int(time.time())
db.commit() await db.commit()
return FolderModel.model_validate(folder) return FolderModel.model_validate(folder)
except Exception as e: except Exception as e:
log.error(f"update_folder: {e}") log.error(f"update_folder: {e}")
return return
def update_folder_by_id_and_user_id( async def update_folder_by_id_and_user_id(
self, id: str, user_id: str, form_data: FolderForm self, id: str, user_id: str, form_data: FolderForm
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: async with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = (
await db.query(Folder).filter_by(id=id, user_id=user_id).first()
)
if not folder: if not folder:
return None return None
@ -203,7 +211,7 @@ class FolderTable:
form_data = form_data.model_dump(exclude_unset=True) form_data = form_data.model_dump(exclude_unset=True)
existing_folder = ( existing_folder = (
db.query(Folder) await db.query(Folder)
.filter_by( .filter_by(
name=form_data.get("name"), name=form_data.get("name"),
parent_id=folder.parent_id, parent_id=folder.parent_id,
@ -224,19 +232,21 @@ class FolderTable:
folder.updated_at = int(time.time()) folder.updated_at = int(time.time())
db.commit() await db.commit()
return FolderModel.model_validate(folder) return FolderModel.model_validate(folder)
except Exception as e: except Exception as e:
log.error(f"update_folder: {e}") log.error(f"update_folder: {e}")
return return
def update_folder_is_expanded_by_id_and_user_id( async def update_folder_is_expanded_by_id_and_user_id(
self, id: str, user_id: str, is_expanded: bool self, id: str, user_id: str, is_expanded: bool
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: async with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = (
await db.query(Folder).filter_by(id=id, user_id=user_id).first()
)
if not folder: if not folder:
return None return None
@ -244,40 +254,44 @@ class FolderTable:
folder.is_expanded = is_expanded folder.is_expanded = is_expanded
folder.updated_at = int(time.time()) folder.updated_at = int(time.time())
db.commit() await db.commit()
return FolderModel.model_validate(folder) return FolderModel.model_validate(folder)
except Exception as e: except Exception as e:
log.error(f"update_folder: {e}") log.error(f"update_folder: {e}")
return return
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]: async def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]:
try: try:
folder_ids = [] folder_ids = []
with get_db() as db: async with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = (
await db.query(Folder).filter_by(id=id, user_id=user_id).first()
)
if not folder: if not folder:
return folder_ids return folder_ids
folder_ids.append(folder.id) folder_ids.append(folder.id)
# Delete all children folders # Delete all children folders
def delete_children(folder): async def delete_children(folder):
folder_children = self.get_folders_by_parent_id_and_user_id( folder_children = await self.get_folders_by_parent_id_and_user_id(
folder.id, user_id folder.id, user_id
) )
for folder_child in folder_children: for folder_child in folder_children:
delete_children(folder_child) await delete_children(folder_child)
folder_ids.append(folder_child.id) folder_ids.append(folder_child.id)
folder = db.query(Folder).filter_by(id=folder_child.id).first() folder = (
db.delete(folder) await db.query(Folder).filter_by(id=folder_child.id).first()
db.commit() )
await db.delete(folder)
await db.commit()
delete_children(folder) await delete_children(folder)
db.delete(folder) await db.delete(folder)
db.commit() await db.commit()
return folder_ids return folder_ids
except Exception as e: except Exception as e:
log.error(f"delete_folder: {e}") log.error(f"delete_folder: {e}")
@ -288,7 +302,7 @@ class FolderTable:
name = re.sub(r"[\s_]+", " ", name) name = re.sub(r"[\s_]+", " ", name)
return name.strip().lower() return name.strip().lower()
def search_folders_by_names( async def search_folders_by_names(
self, user_id: str, queries: list[str] self, user_id: str, queries: list[str]
) -> list[FolderModel]: ) -> list[FolderModel]:
""" """
@ -299,14 +313,14 @@ class FolderTable:
return [] return []
results = {} results = {}
with get_db() as db: async with get_db() as db:
folders = db.query(Folder).filter_by(user_id=user_id).all() folders = await db.query(Folder).filter_by(user_id=user_id).all()
for folder in folders: for folder in folders:
if self.normalize_folder_name(folder.name) in normalized_queries: if self.normalize_folder_name(folder.name) in normalized_queries:
results[folder.id] = FolderModel.model_validate(folder) results[folder.id] = FolderModel.model_validate(folder)
# get children folders # get children folders
children = self.get_children_folders_by_id_and_user_id( children = await self.get_children_folders_by_id_and_user_id(
folder.id, user_id folder.id, user_id
) )
for child in children: for child in children:
@ -319,7 +333,7 @@ class FolderTable:
results = list(results.values()) results = list(results.values())
return results return results
def search_folders_by_name_contains( async def search_folders_by_name_contains(
self, user_id: str, query: str self, user_id: str, query: str
) -> list[FolderModel]: ) -> list[FolderModel]:
""" """
@ -327,8 +341,8 @@ class FolderTable:
""" """
normalized_query = self.normalize_folder_name(query) normalized_query = self.normalize_folder_name(query)
results = [] results = []
with get_db() as db: async with get_db() as db:
folders = db.query(Folder).filter_by(user_id=user_id).all() folders = await db.query(Folder).filter_by(user_id=user_id).all()
for folder in folders: for folder in folders:
norm_name = self.normalize_folder_name(folder.name) norm_name = self.normalize_folder_name(folder.name)
if normalized_query in norm_name: if normalized_query in norm_name:

View file

@ -81,7 +81,7 @@ class FunctionValves(BaseModel):
class FunctionsTable: class FunctionsTable:
def insert_new_function( async def insert_new_function(
self, user_id: str, type: str, form_data: FunctionForm self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]: ) -> Optional[FunctionModel]:
function = FunctionModel( function = FunctionModel(
@ -95,11 +95,11 @@ class FunctionsTable:
) )
try: try:
with get_db() as db: async with get_db() as db:
result = Function(**function.model_dump()) result = Function(**function.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
if result: if result:
return FunctionModel.model_validate(result) return FunctionModel.model_validate(result)
else: else:
@ -108,14 +108,14 @@ class FunctionsTable:
log.exception(f"Error creating a new function: {e}") log.exception(f"Error creating a new function: {e}")
return None return None
def sync_functions( async def sync_functions(
self, user_id: str, functions: list[FunctionModel] self, user_id: str, functions: list[FunctionModel]
) -> list[FunctionModel]: ) -> list[FunctionModel]:
# Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present. # Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
try: try:
with get_db() as db: async with get_db() as db:
# Get existing functions # Get existing functions
existing_functions = db.query(Function).all() existing_functions = await db.query(Function).all()
existing_ids = {func.id for func in existing_functions} existing_ids = {func.id for func in existing_functions}
# Prepare a set of new function IDs # Prepare a set of new function IDs
@ -124,7 +124,7 @@ class FunctionsTable:
# Update or insert functions # Update or insert functions
for func in functions: for func in functions:
if func.id in existing_ids: if func.id in existing_ids:
db.query(Function).filter_by(id=func.id).update( await db.query(Function).filter_by(id=func.id).update(
{ {
**func.model_dump(), **func.model_dump(),
"user_id": user_id, "user_id": user_id,
@ -139,107 +139,109 @@ class FunctionsTable:
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
db.add(new_func) await db.add(new_func)
# Remove functions that are no longer present # Remove functions that are no longer present
for func in existing_functions: for func in existing_functions:
if func.id not in new_function_ids: if func.id not in new_function_ids:
db.delete(func) await db.delete(func)
db.commit() await db.commit()
return [ return [
FunctionModel.model_validate(func) FunctionModel.model_validate(func)
for func in db.query(Function).all() for func in await db.query(Function).all()
] ]
except Exception as e: except Exception as e:
log.exception(f"Error syncing functions for user {user_id}: {e}") log.exception(f"Error syncing functions for user {user_id}: {e}")
return [] return []
def get_function_by_id(self, id: str) -> Optional[FunctionModel]: async def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try: try:
with get_db() as db: async with get_db() as db:
function = db.get(Function, id) function = await db.get(Function, id)
return FunctionModel.model_validate(function) return FunctionModel.model_validate(function)
except Exception: except Exception:
return None return None
def get_functions(self, active_only=False) -> list[FunctionModel]: async def get_functions(self, active_only=False) -> list[FunctionModel]:
with get_db() as db: async with get_db() as db:
if active_only: if active_only:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(is_active=True).all() for function in await db.query(Function)
.filter_by(is_active=True)
.all()
] ]
else: else:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function).all() for function in await db.query(Function).all()
] ]
def get_functions_by_type( async def get_functions_by_type(
self, type: str, active_only=False self, type: str, active_only=False
) -> list[FunctionModel]: ) -> list[FunctionModel]:
with get_db() as db: async with get_db() as db:
if active_only: if active_only:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function) for function in await db.query(Function)
.filter_by(type=type, is_active=True) .filter_by(type=type, is_active=True)
.all() .all()
] ]
else: else:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(type=type).all() for function in await db.query(Function).filter_by(type=type).all()
] ]
def get_global_filter_functions(self) -> list[FunctionModel]: async def get_global_filter_functions(self) -> list[FunctionModel]:
with get_db() as db: async with get_db() as db:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function) for function in await db.query(Function)
.filter_by(type="filter", is_active=True, is_global=True) .filter_by(type="filter", is_active=True, is_global=True)
.all() .all()
] ]
def get_global_action_functions(self) -> list[FunctionModel]: async def get_global_action_functions(self) -> list[FunctionModel]:
with get_db() as db: async with get_db() as db:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function) for function in await db.query(Function)
.filter_by(type="action", is_active=True, is_global=True) .filter_by(type="action", is_active=True, is_global=True)
.all() .all()
] ]
def get_function_valves_by_id(self, id: str) -> Optional[dict]: async def get_function_valves_by_id(self, id: str) -> Optional[dict]:
with get_db() as db: async with get_db() as db:
try: try:
function = db.get(Function, id) function = await db.get(Function, id)
return function.valves if function.valves else {} return function.valves if function.valves else {}
except Exception as e: except Exception as e:
log.exception(f"Error getting function valves by id {id}: {e}") log.exception(f"Error getting function valves by id {id}: {e}")
return None return None
def update_function_valves_by_id( async def update_function_valves_by_id(
self, id: str, valves: dict self, id: str, valves: dict
) -> Optional[FunctionValves]: ) -> Optional[FunctionValves]:
with get_db() as db: async with get_db() as db:
try: try:
function = db.get(Function, id) function = await db.get(Function, id)
function.valves = valves function.valves = valves
function.updated_at = int(time.time()) function.updated_at = int(time.time())
db.commit() await db.commit()
db.refresh(function) await db.refresh(function)
return self.get_function_by_id(id) return await self.get_function_by_id(id)
except Exception: except Exception:
return None return None
def get_user_valves_by_id_and_user_id( async def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings # Check if user has "functions" and "valves" settings
@ -255,11 +257,11 @@ class FunctionsTable:
) )
return None return None
def update_user_valves_by_id_and_user_id( async def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict self, id: str, user_id: str, valves: dict
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings # Check if user has "functions" and "valves" settings
@ -271,7 +273,7 @@ class FunctionsTable:
user_settings["functions"]["valves"][id] = valves user_settings["functions"]["valves"][id] = valves
# Update the user settings in the database # Update the user settings in the database
Users.update_user_by_id(user_id, {"settings": user_settings}) await Users.update_user_by_id(user_id, {"settings": user_settings})
return user_settings["functions"]["valves"][id] return user_settings["functions"]["valves"][id]
except Exception as e: except Exception as e:
@ -280,39 +282,41 @@ class FunctionsTable:
) )
return None return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: async def update_function_by_id(
with get_db() as db: self, id: str, updated: dict
) -> Optional[FunctionModel]:
async with get_db() as db:
try: try:
db.query(Function).filter_by(id=id).update( await db.query(Function).filter_by(id=id).update(
{ {
**updated, **updated,
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
db.commit() await db.commit()
return self.get_function_by_id(id) return await self.get_function_by_id(id)
except Exception: except Exception:
return None return None
def deactivate_all_functions(self) -> Optional[bool]: async def deactivate_all_functions(self) -> Optional[bool]:
with get_db() as db: async with get_db() as db:
try: try:
db.query(Function).update( await db.query(Function).update(
{ {
"is_active": False, "is_active": False,
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return None return None
def delete_function_by_id(self, id: str) -> bool: async def delete_function_by_id(self, id: str) -> bool:
with get_db() as db: async with get_db() as db:
try: try:
db.query(Function).filter_by(id=id).delete() await db.query(Function).filter_by(id=id).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:

View file

@ -95,7 +95,7 @@ class GroupTable:
def insert_new_group( def insert_new_group(
self, user_id: str, form_data: GroupForm self, user_id: str, form_data: GroupForm
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
with get_db() as db: async with get_db() as db:
group = GroupModel( group = GroupModel(
**{ **{
**form_data.model_dump(exclude_none=True), **form_data.model_dump(exclude_none=True),
@ -120,14 +120,14 @@ class GroupTable:
return None return None
def get_groups(self) -> list[GroupModel]: def get_groups(self) -> list[GroupModel]:
with get_db() as db: async with get_db() as db:
return [ return [
GroupModel.model_validate(group) GroupModel.model_validate(group)
for group in db.query(Group).order_by(Group.updated_at.desc()).all() for group in db.query(Group).order_by(Group.updated_at.desc()).all()
] ]
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
with get_db() as db: async with get_db() as db:
return [ return [
GroupModel.model_validate(group) GroupModel.model_validate(group)
for group in db.query(Group) for group in db.query(Group)
@ -143,7 +143,7 @@ class GroupTable:
def get_group_by_id(self, id: str) -> Optional[GroupModel]: def get_group_by_id(self, id: str) -> Optional[GroupModel]:
try: try:
with get_db() as db: async with get_db() as db:
group = db.query(Group).filter_by(id=id).first() group = db.query(Group).filter_by(id=id).first()
return GroupModel.model_validate(group) if group else None return GroupModel.model_validate(group) if group else None
except Exception: except Exception:
@ -160,7 +160,7 @@ class GroupTable:
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Group).filter_by(id=id).update( db.query(Group).filter_by(id=id).update(
{ {
**form_data.model_dump(exclude_none=True), **form_data.model_dump(exclude_none=True),
@ -175,7 +175,7 @@ class GroupTable:
def delete_group_by_id(self, id: str) -> bool: def delete_group_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Group).filter_by(id=id).delete() db.query(Group).filter_by(id=id).delete()
db.commit() db.commit()
return True return True
@ -183,7 +183,7 @@ class GroupTable:
return False return False
def delete_all_groups(self) -> bool: def delete_all_groups(self) -> bool:
with get_db() as db: async with get_db() as db:
try: try:
db.query(Group).delete() db.query(Group).delete()
db.commit() db.commit()
@ -193,7 +193,7 @@ class GroupTable:
return False return False
def remove_user_from_all_groups(self, user_id: str) -> bool: def remove_user_from_all_groups(self, user_id: str) -> bool:
with get_db() as db: async with get_db() as db:
try: try:
groups = self.get_groups_by_member_id(user_id) groups = self.get_groups_by_member_id(user_id)
@ -221,7 +221,7 @@ class GroupTable:
new_groups = [] new_groups = []
with get_db() as db: async with get_db() as db:
for group_name in group_names: for group_name in group_names:
if group_name not in existing_group_names: if group_name not in existing_group_names:
new_group = GroupModel( new_group = GroupModel(
@ -244,7 +244,7 @@ class GroupTable:
return new_groups return new_groups
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
with get_db() as db: async with get_db() as db:
try: try:
groups = db.query(Group).filter(Group.name.in_(group_names)).all() groups = db.query(Group).filter(Group.name.in_(group_names)).all()
group_ids = [group.id for group in groups] group_ids = [group.id for group in groups]
@ -283,7 +283,7 @@ class GroupTable:
self, id: str, user_ids: Optional[list[str]] = None self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
try: try:
with get_db() as db: async with get_db() as db:
group = db.query(Group).filter_by(id=id).first() group = db.query(Group).filter_by(id=id).first()
if not group: if not group:
return None return None
@ -307,7 +307,7 @@ class GroupTable:
self, id: str, user_ids: Optional[list[str]] = None self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
try: try:
with get_db() as db: async with get_db() as db:
group = db.query(Group).filter_by(id=id).first() group = db.query(Group).filter_by(id=id).first()
if not group: if not group:
return None return None

View file

@ -103,7 +103,7 @@ class KnowledgeTable:
def insert_new_knowledge( def insert_new_knowledge(
self, user_id: str, form_data: KnowledgeForm self, user_id: str, form_data: KnowledgeForm
) -> Optional[KnowledgeModel]: ) -> Optional[KnowledgeModel]:
with get_db() as db: async with get_db() as db:
knowledge = KnowledgeModel( knowledge = KnowledgeModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -126,13 +126,13 @@ class KnowledgeTable:
except Exception: except Exception:
return None return None
def get_knowledge_bases(self) -> list[KnowledgeUserModel]: async def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
with get_db() as db: async with get_db() as db:
knowledge_bases = [] knowledge_bases = []
for knowledge in ( for knowledge in (
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
): ):
user = Users.get_user_by_id(knowledge.user_id) user = await Users.get_user_by_id(knowledge.user_id)
knowledge_bases.append( knowledge_bases.append(
KnowledgeUserModel.model_validate( KnowledgeUserModel.model_validate(
{ {
@ -156,7 +156,7 @@ class KnowledgeTable:
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
try: try:
with get_db() as db: async with get_db() as db:
knowledge = db.query(Knowledge).filter_by(id=id).first() knowledge = db.query(Knowledge).filter_by(id=id).first()
return KnowledgeModel.model_validate(knowledge) if knowledge else None return KnowledgeModel.model_validate(knowledge) if knowledge else None
except Exception: except Exception:
@ -166,7 +166,7 @@ class KnowledgeTable:
self, id: str, form_data: KnowledgeForm, overwrite: bool = False self, id: str, form_data: KnowledgeForm, overwrite: bool = False
) -> Optional[KnowledgeModel]: ) -> Optional[KnowledgeModel]:
try: try:
with get_db() as db: async with get_db() as db:
knowledge = self.get_knowledge_by_id(id=id) knowledge = self.get_knowledge_by_id(id=id)
db.query(Knowledge).filter_by(id=id).update( db.query(Knowledge).filter_by(id=id).update(
{ {
@ -184,7 +184,7 @@ class KnowledgeTable:
self, id: str, data: dict self, id: str, data: dict
) -> Optional[KnowledgeModel]: ) -> Optional[KnowledgeModel]:
try: try:
with get_db() as db: async with get_db() as db:
knowledge = self.get_knowledge_by_id(id=id) knowledge = self.get_knowledge_by_id(id=id)
db.query(Knowledge).filter_by(id=id).update( db.query(Knowledge).filter_by(id=id).update(
{ {
@ -200,7 +200,7 @@ class KnowledgeTable:
def delete_knowledge_by_id(self, id: str) -> bool: def delete_knowledge_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Knowledge).filter_by(id=id).delete() db.query(Knowledge).filter_by(id=id).delete()
db.commit() db.commit()
return True return True
@ -208,7 +208,7 @@ class KnowledgeTable:
return False return False
def delete_all_knowledge(self) -> bool: def delete_all_knowledge(self) -> bool:
with get_db() as db: async with get_db() as db:
try: try:
db.query(Knowledge).delete() db.query(Knowledge).delete()
db.commit() db.commit()

View file

@ -42,7 +42,7 @@ class MemoriesTable:
user_id: str, user_id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db: async with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
memory = MemoryModel( memory = MemoryModel(
@ -69,7 +69,7 @@ class MemoriesTable:
user_id: str, user_id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db: async with get_db() as db:
try: try:
memory = db.get(Memory, id) memory = db.get(Memory, id)
if not memory or memory.user_id != user_id: if not memory or memory.user_id != user_id:
@ -84,7 +84,7 @@ class MemoriesTable:
return None return None
def get_memories(self) -> list[MemoryModel]: def get_memories(self) -> list[MemoryModel]:
with get_db() as db: async with get_db() as db:
try: try:
memories = db.query(Memory).all() memories = db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
@ -92,7 +92,7 @@ class MemoriesTable:
return None return None
def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]: def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
with get_db() as db: async with get_db() as db:
try: try:
memories = db.query(Memory).filter_by(user_id=user_id).all() memories = db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
@ -100,7 +100,7 @@ class MemoriesTable:
return None return None
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
with get_db() as db: async with get_db() as db:
try: try:
memory = db.get(Memory, id) memory = db.get(Memory, id)
return MemoryModel.model_validate(memory) return MemoryModel.model_validate(memory)
@ -108,7 +108,7 @@ class MemoriesTable:
return None return None
def delete_memory_by_id(self, id: str) -> bool: def delete_memory_by_id(self, id: str) -> bool:
with get_db() as db: async with get_db() as db:
try: try:
db.query(Memory).filter_by(id=id).delete() db.query(Memory).filter_by(id=id).delete()
db.commit() db.commit()
@ -119,7 +119,7 @@ class MemoriesTable:
return False return False
def delete_memories_by_user_id(self, user_id: str) -> bool: def delete_memories_by_user_id(self, user_id: str) -> bool:
with get_db() as db: async with get_db() as db:
try: try:
db.query(Memory).filter_by(user_id=user_id).delete() db.query(Memory).filter_by(user_id=user_id).delete()
db.commit() db.commit()
@ -129,7 +129,7 @@ class MemoriesTable:
return False return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db: async with get_db() as db:
try: try:
memory = db.get(Memory, id) memory = db.get(Memory, id)
if not memory or memory.user_id != user_id: if not memory or memory.user_id != user_id:

View file

@ -98,7 +98,7 @@ class MessageTable:
def insert_new_message( def insert_new_message(
self, form_data: MessageForm, channel_id: str, user_id: str self, form_data: MessageForm, channel_id: str, user_id: str
) -> Optional[MessageModel]: ) -> Optional[MessageModel]:
with get_db() as db: async with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
ts = int(time.time_ns()) ts = int(time.time_ns())
@ -123,7 +123,7 @@ class MessageTable:
return MessageModel.model_validate(result) if result else None return MessageModel.model_validate(result) if result else None
def get_message_by_id(self, id: str) -> Optional[MessageResponse]: def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
with get_db() as db: async with get_db() as db:
message = db.get(Message, id) message = db.get(Message, id)
if not message: if not message:
return None return None
@ -141,7 +141,7 @@ class MessageTable:
) )
def get_replies_by_message_id(self, id: str) -> list[MessageModel]: def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
with get_db() as db: async with get_db() as db:
all_messages = ( all_messages = (
db.query(Message) db.query(Message)
.filter_by(parent_id=id) .filter_by(parent_id=id)
@ -151,7 +151,7 @@ class MessageTable:
return [MessageModel.model_validate(message) for message in all_messages] return [MessageModel.model_validate(message) for message in all_messages]
def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
with get_db() as db: async with get_db() as db:
return [ return [
message.user_id message.user_id
for message in db.query(Message).filter_by(parent_id=id).all() for message in db.query(Message).filter_by(parent_id=id).all()
@ -160,7 +160,7 @@ class MessageTable:
def get_messages_by_channel_id( def get_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50 self, channel_id: str, skip: int = 0, limit: int = 50
) -> list[MessageModel]: ) -> list[MessageModel]:
with get_db() as db: async with get_db() as db:
all_messages = ( all_messages = (
db.query(Message) db.query(Message)
.filter_by(channel_id=channel_id, parent_id=None) .filter_by(channel_id=channel_id, parent_id=None)
@ -174,7 +174,7 @@ class MessageTable:
def get_messages_by_parent_id( def get_messages_by_parent_id(
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
) -> list[MessageModel]: ) -> list[MessageModel]:
with get_db() as db: async with get_db() as db:
message = db.get(Message, parent_id) message = db.get(Message, parent_id)
if not message: if not message:
@ -198,7 +198,7 @@ class MessageTable:
def update_message_by_id( def update_message_by_id(
self, id: str, form_data: MessageForm self, id: str, form_data: MessageForm
) -> Optional[MessageModel]: ) -> Optional[MessageModel]:
with get_db() as db: async with get_db() as db:
message = db.get(Message, id) message = db.get(Message, id)
message.content = form_data.content message.content = form_data.content
message.data = form_data.data message.data = form_data.data
@ -211,7 +211,7 @@ class MessageTable:
def add_reaction_to_message( def add_reaction_to_message(
self, id: str, user_id: str, name: str self, id: str, user_id: str, name: str
) -> Optional[MessageReactionModel]: ) -> Optional[MessageReactionModel]:
with get_db() as db: async with get_db() as db:
reaction_id = str(uuid.uuid4()) reaction_id = str(uuid.uuid4())
reaction = MessageReactionModel( reaction = MessageReactionModel(
id=reaction_id, id=reaction_id,
@ -227,7 +227,7 @@ class MessageTable:
return MessageReactionModel.model_validate(result) if result else None return MessageReactionModel.model_validate(result) if result else None
def get_reactions_by_message_id(self, id: str) -> list[Reactions]: def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
with get_db() as db: async with get_db() as db:
all_reactions = db.query(MessageReaction).filter_by(message_id=id).all() all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
reactions = {} reactions = {}
@ -246,7 +246,7 @@ class MessageTable:
def remove_reaction_by_id_and_user_id_and_name( def remove_reaction_by_id_and_user_id_and_name(
self, id: str, user_id: str, name: str self, id: str, user_id: str, name: str
) -> bool: ) -> bool:
with get_db() as db: async with get_db() as db:
db.query(MessageReaction).filter_by( db.query(MessageReaction).filter_by(
message_id=id, user_id=user_id, name=name message_id=id, user_id=user_id, name=name
).delete() ).delete()
@ -254,19 +254,19 @@ class MessageTable:
return True return True
def delete_reactions_by_id(self, id: str) -> bool: def delete_reactions_by_id(self, id: str) -> bool:
with get_db() as db: async with get_db() as db:
db.query(MessageReaction).filter_by(message_id=id).delete() db.query(MessageReaction).filter_by(message_id=id).delete()
db.commit() db.commit()
return True return True
def delete_replies_by_id(self, id: str) -> bool: def delete_replies_by_id(self, id: str) -> bool:
with get_db() as db: async with get_db() as db:
db.query(Message).filter_by(parent_id=id).delete() db.query(Message).filter_by(parent_id=id).delete()
db.commit() db.commit()
return True return True
def delete_message_by_id(self, id: str) -> bool: def delete_message_by_id(self, id: str) -> bool:
with get_db() as db: async with get_db() as db:
db.query(Message).filter_by(id=id).delete() db.query(Message).filter_by(id=id).delete()
# Delete all reactions to this message # Delete all reactions to this message

View file

@ -155,7 +155,7 @@ class ModelsTable:
} }
) )
try: try:
with get_db() as db: async with get_db() as db:
result = Model(**model.model_dump()) result = Model(**model.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
@ -170,14 +170,14 @@ class ModelsTable:
return None return None
def get_all_models(self) -> list[ModelModel]: def get_all_models(self) -> list[ModelModel]:
with get_db() as db: async with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()] return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_models(self) -> list[ModelUserResponse]: async def get_models(self) -> list[ModelUserResponse]:
with get_db() as db: async with get_db() as db:
models = [] models = []
for model in db.query(Model).filter(Model.base_model_id != None).all(): for model in db.query(Model).filter(Model.base_model_id != None).all():
user = Users.get_user_by_id(model.user_id) user = await Users.get_user_by_id(model.user_id)
models.append( models.append(
ModelUserResponse.model_validate( ModelUserResponse.model_validate(
{ {
@ -189,7 +189,7 @@ class ModelsTable:
return models return models
def get_base_models(self) -> list[ModelModel]: def get_base_models(self) -> list[ModelModel]:
with get_db() as db: async with get_db() as db:
return [ return [
ModelModel.model_validate(model) ModelModel.model_validate(model)
for model in db.query(Model).filter(Model.base_model_id == None).all() for model in db.query(Model).filter(Model.base_model_id == None).all()
@ -208,14 +208,14 @@ class ModelsTable:
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
with get_db() as db: async with get_db() as db:
model = db.get(Model, id) model = db.get(Model, id)
return ModelModel.model_validate(model) return ModelModel.model_validate(model)
except Exception: except Exception:
return None return None
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
with get_db() as db: async with get_db() as db:
try: try:
is_active = db.query(Model).filter_by(id=id).first().is_active is_active = db.query(Model).filter_by(id=id).first().is_active
@ -233,7 +233,7 @@ class ModelsTable:
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try: try:
with get_db() as db: async with get_db() as db:
# update only the fields that are present in the model # update only the fields that are present in the model
result = ( result = (
db.query(Model) db.query(Model)
@ -251,7 +251,7 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool: def delete_model_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Model).filter_by(id=id).delete() db.query(Model).filter_by(id=id).delete()
db.commit() db.commit()
@ -261,7 +261,7 @@ class ModelsTable:
def delete_all_models(self) -> bool: def delete_all_models(self) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Model).delete() db.query(Model).delete()
db.commit() db.commit()
@ -271,7 +271,7 @@ class ModelsTable:
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]: def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
try: try:
with get_db() as db: async with get_db() as db:
# Get existing models # Get existing models
existing_models = db.query(Model).all() existing_models = db.query(Model).all()
existing_ids = {model.id for model in existing_models} existing_ids = {model.id for model in existing_models}

View file

@ -79,7 +79,7 @@ class NoteTable:
form_data: NoteForm, form_data: NoteForm,
user_id: str, user_id: str,
) -> Optional[NoteModel]: ) -> Optional[NoteModel]:
with get_db() as db: async with get_db() as db:
note = NoteModel( note = NoteModel(
**{ **{
"id": str(uuid.uuid4()), "id": str(uuid.uuid4()),
@ -97,7 +97,7 @@ class NoteTable:
return note return note
def get_notes(self) -> list[NoteModel]: def get_notes(self) -> list[NoteModel]:
with get_db() as db: async with get_db() as db:
notes = db.query(Note).order_by(Note.updated_at.desc()).all() notes = db.query(Note).order_by(Note.updated_at.desc()).all()
return [NoteModel.model_validate(note) for note in notes] return [NoteModel.model_validate(note) for note in notes]
@ -113,14 +113,14 @@ class NoteTable:
] ]
def get_note_by_id(self, id: str) -> Optional[NoteModel]: def get_note_by_id(self, id: str) -> Optional[NoteModel]:
with get_db() as db: async with get_db() as db:
note = db.query(Note).filter(Note.id == id).first() note = db.query(Note).filter(Note.id == id).first()
return NoteModel.model_validate(note) if note else None return NoteModel.model_validate(note) if note else None
def update_note_by_id( def update_note_by_id(
self, id: str, form_data: NoteUpdateForm self, id: str, form_data: NoteUpdateForm
) -> Optional[NoteModel]: ) -> Optional[NoteModel]:
with get_db() as db: async with get_db() as db:
note = db.query(Note).filter(Note.id == id).first() note = db.query(Note).filter(Note.id == id).first()
if not note: if not note:
return None return None
@ -143,7 +143,7 @@ class NoteTable:
return NoteModel.model_validate(note) if note else None return NoteModel.model_validate(note) if note else None
def delete_note_by_id(self, id: str): def delete_note_by_id(self, id: str):
with get_db() as db: async with get_db() as db:
db.query(Note).filter(Note.id == id).delete() db.query(Note).filter(Note.id == id).delete()
db.commit() db.commit()
return True return True

View file

@ -81,7 +81,7 @@ class PromptsTable:
) )
try: try:
with get_db() as db: async with get_db() as db:
result = Prompt(**prompt.model_dump()) result = Prompt(**prompt.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
@ -95,18 +95,18 @@ class PromptsTable:
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try: try:
with get_db() as db: async with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first() prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt) return PromptModel.model_validate(prompt)
except Exception: except Exception:
return None return None
def get_prompts(self) -> list[PromptUserResponse]: async def get_prompts(self) -> list[PromptUserResponse]:
with get_db() as db: async with get_db() as db:
prompts = [] prompts = []
for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all(): for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
user = Users.get_user_by_id(prompt.user_id) user = await Users.get_user_by_id(prompt.user_id)
prompts.append( prompts.append(
PromptUserResponse.model_validate( PromptUserResponse.model_validate(
{ {
@ -134,7 +134,7 @@ class PromptsTable:
self, command: str, form_data: PromptForm self, command: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
try: try:
with get_db() as db: async with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first() prompt = db.query(Prompt).filter_by(command=command).first()
prompt.title = form_data.title prompt.title = form_data.title
prompt.content = form_data.content prompt.content = form_data.content
@ -147,7 +147,7 @@ class PromptsTable:
def delete_prompt_by_command(self, command: str) -> bool: def delete_prompt_by_command(self, command: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Prompt).filter_by(command=command).delete() db.query(Prompt).filter_by(command=command).delete()
db.commit() db.commit()

View file

@ -48,7 +48,7 @@ class TagChatIdForm(BaseModel):
class TagTable: class TagTable:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
with get_db() as db: async with get_db() as db:
id = name.replace(" ", "_").lower() id = name.replace(" ", "_").lower()
tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: try:
@ -69,14 +69,14 @@ class TagTable:
) -> Optional[TagModel]: ) -> Optional[TagModel]:
try: try:
id = name.replace(" ", "_").lower() id = name.replace(" ", "_").lower()
with get_db() as db: async with get_db() as db:
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first() tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
return TagModel.model_validate(tag) return TagModel.model_validate(tag)
except Exception: except Exception:
return None return None
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
with get_db() as db: async with get_db() as db:
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in (db.query(Tag).filter_by(user_id=user_id).all()) for tag in (db.query(Tag).filter_by(user_id=user_id).all())
@ -85,7 +85,7 @@ class TagTable:
def get_tags_by_ids_and_user_id( def get_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str self, ids: list[str], user_id: str
) -> list[TagModel]: ) -> list[TagModel]:
with get_db() as db: async with get_db() as db:
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in ( for tag in (
@ -95,7 +95,7 @@ class TagTable:
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool: def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
id = name.replace(" ", "_").lower() id = name.replace(" ", "_").lower()
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete() res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
log.debug(f"res: {res}") log.debug(f"res: {res}")

View file

@ -110,7 +110,7 @@ class ToolsTable:
def insert_new_tool( def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: list[dict] self, user_id: str, form_data: ToolForm, specs: list[dict]
) -> Optional[ToolModel]: ) -> Optional[ToolModel]:
with get_db() as db: async with get_db() as db:
tool = ToolModel( tool = ToolModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -136,17 +136,17 @@ class ToolsTable:
def get_tool_by_id(self, id: str) -> Optional[ToolModel]: def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try: try:
with get_db() as db: async with get_db() as db:
tool = db.get(Tool, id) tool = db.get(Tool, id)
return ToolModel.model_validate(tool) return ToolModel.model_validate(tool)
except Exception: except Exception:
return None return None
def get_tools(self) -> list[ToolUserModel]: async def get_tools(self) -> list[ToolUserModel]:
with get_db() as db: async with get_db() as db:
tools = [] tools = []
for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all(): for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
user = Users.get_user_by_id(tool.user_id) user = await Users.get_user_by_id(tool.user_id)
tools.append( tools.append(
ToolUserModel.model_validate( ToolUserModel.model_validate(
{ {
@ -171,7 +171,7 @@ class ToolsTable:
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
with get_db() as db: async with get_db() as db:
tool = db.get(Tool, id) tool = db.get(Tool, id)
return tool.valves if tool.valves else {} return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
@ -180,7 +180,7 @@ class ToolsTable:
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Tool).filter_by(id=id).update( db.query(Tool).filter_by(id=id).update(
{"valves": valves, "updated_at": int(time.time())} {"valves": valves, "updated_at": int(time.time())}
) )
@ -189,11 +189,11 @@ class ToolsTable:
except Exception: except Exception:
return None return None
def get_user_valves_by_id_and_user_id( async def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings # Check if user has "tools" and "valves" settings
@ -209,11 +209,11 @@ class ToolsTable:
) )
return None return None
def update_user_valves_by_id_and_user_id( async def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict self, id: str, user_id: str, valves: dict
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings # Check if user has "tools" and "valves" settings
@ -236,7 +236,7 @@ class ToolsTable:
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Tool).filter_by(id=id).update( db.query(Tool).filter_by(id=id).update(
{**updated, "updated_at": int(time.time())} {**updated, "updated_at": int(time.time())}
) )
@ -250,7 +250,7 @@ class ToolsTable:
def delete_tool_by_id(self, id: str) -> bool: def delete_tool_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
db.query(Tool).filter_by(id=id).delete() db.query(Tool).filter_by(id=id).delete()
db.commit() db.commit()

View file

@ -115,7 +115,7 @@ class UserUpdateForm(BaseModel):
class UsersTable: class UsersTable:
def insert_new_user( async def insert_new_user(
self, self,
id: str, id: str,
name: str, name: str,
@ -124,7 +124,7 @@ class UsersTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db: async with get_db() as db:
user = UserModel( user = UserModel(
**{ **{
"id": id, "id": id,
@ -139,53 +139,53 @@ class UsersTable:
} }
) )
result = User(**user.model_dump()) result = User(**user.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
if result: if result:
return user return user
else: else:
return None return None
def get_user_by_id(self, id: str) -> Optional[UserModel]: async def get_user_by_id(self, id: str) -> Optional[UserModel]:
try: try:
with get_db() as db: async with get_db() as db:
user = db.query(User).filter_by(id=id).first() user = await db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: async def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try: try:
with get_db() as db: async with get_db() as db:
user = db.query(User).filter_by(api_key=api_key).first() user = await db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def get_user_by_email(self, email: str) -> Optional[UserModel]: async def get_user_by_email(self, email: str) -> Optional[UserModel]:
try: try:
with get_db() as db: async with get_db() as db:
user = db.query(User).filter_by(email=email).first() user = await db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: async def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try: try:
with get_db() as db: async with get_db() as db:
user = db.query(User).filter_by(oauth_sub=sub).first() user = await db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def get_users( async def get_users(
self, self,
filter: Optional[dict] = None, filter: Optional[dict] = None,
skip: Optional[int] = None, skip: Optional[int] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
) -> UserListResponse: ) -> UserListResponse:
with get_db() as db: async with get_db() as db:
query = db.query(User) query = db.query(User)
if filter: if filter:
@ -243,37 +243,37 @@ class UsersTable:
if limit: if limit:
query = query.limit(limit) query = query.limit(limit)
users = query.all() users = await query.all()
return { return {
"users": [UserModel.model_validate(user) for user in users], "users": [UserModel.model_validate(user) for user in users],
"total": db.query(User).count(), "total": await db.query(User).count(),
} }
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]: async def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
with get_db() as db: async with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all() users = await db.query(User).filter(User.id.in_(user_ids)).all()
return [UserModel.model_validate(user) for user in users] return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]: async def get_num_users(self) -> Optional[int]:
with get_db() as db: async with get_db() as db:
return db.query(User).count() return await db.query(User).count()
def has_users(self) -> bool: async def has_users(self) -> bool:
with get_db() as db: async with get_db() as db:
return db.query(db.query(User).exists()).scalar() return await db.query(db.query(User).exists()).scalar()
def get_first_user(self) -> UserModel: async def get_first_user(self) -> UserModel:
try: try:
with get_db() as db: async with get_db() as db:
user = db.query(User).order_by(User.created_at).first() user = await db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def get_user_webhook_url_by_id(self, id: str) -> Optional[str]: async def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
try: try:
with get_db() as db: async with get_db() as db:
user = db.query(User).filter_by(id=id).first() user = await db.query(User).filter_by(id=id).first()
if user.settings is None: if user.settings is None:
return None return None
@ -286,99 +286,103 @@ class UsersTable:
except Exception: except Exception:
return None return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: async def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try: try:
with get_db() as db: async with get_db() as db:
db.query(User).filter_by(id=id).update({"role": role}) await db.query(User).filter_by(id=id).update({"role": role})
db.commit() await db.commit()
user = db.query(User).filter_by(id=id).first() user = await db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def update_user_profile_image_url_by_id( async def update_user_profile_image_url_by_id(
self, id: str, profile_image_url: str self, id: str, profile_image_url: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
with get_db() as db: async with get_db() as db:
db.query(User).filter_by(id=id).update( await db.query(User).filter_by(id=id).update(
{"profile_image_url": profile_image_url} {"profile_image_url": profile_image_url}
) )
db.commit() await db.commit()
user = db.query(User).filter_by(id=id).first() user = await db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: async def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try: try:
with get_db() as db: async with get_db() as db:
db.query(User).filter_by(id=id).update( await db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())} {"last_active_at": int(time.time())}
) )
db.commit() await db.commit()
user = db.query(User).filter_by(id=id).first() user = await db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def update_user_oauth_sub_by_id( async def update_user_oauth_sub_by_id(
self, id: str, oauth_sub: str self, id: str, oauth_sub: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
with get_db() as db: async with get_db() as db:
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) await db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
db.commit() await db.commit()
user = db.query(User).filter_by(id=id).first() user = await db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: async def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try: try:
with get_db() as db: async with get_db() as db:
db.query(User).filter_by(id=id).update(updated) await db.query(User).filter_by(id=id).update(updated)
db.commit() await db.commit()
user = db.query(User).filter_by(id=id).first() user = await db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
# return UserModel(**user.dict()) # return UserModel(**user.dict())
except Exception: except Exception:
return None return None
def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]: async def update_user_settings_by_id(
self, id: str, updated: dict
) -> Optional[UserModel]:
try: try:
with get_db() as db: async with get_db() as db:
user_settings = db.query(User).filter_by(id=id).first().settings user_settings = await db.query(User).filter_by(id=id).first().settings
if user_settings is None: if user_settings is None:
user_settings = {} user_settings = {}
user_settings.update(updated) user_settings.update(updated)
db.query(User).filter_by(id=id).update({"settings": user_settings}) await db.query(User).filter_by(id=id).update(
db.commit() {"settings": user_settings}
)
await db.commit()
user = db.query(User).filter_by(id=id).first() user = await db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def delete_user_by_id(self, id: str) -> bool: async def delete_user_by_id(self, id: str) -> bool:
try: try:
# Remove User from Groups # Remove User from Groups
Groups.remove_user_from_all_groups(id) await Groups.remove_user_from_all_groups(id)
# Delete User Chats # Delete User Chats
result = Chats.delete_chats_by_user_id(id) result = await Chats.delete_chats_by_user_id(id)
if result: if result:
with get_db() as db: async with get_db() as db:
# Delete User # Delete User
db.query(User).filter_by(id=id).delete() await db.query(User).filter_by(id=id).delete()
db.commit() await db.commit()
return True return True
else: else:
@ -386,31 +390,33 @@ class UsersTable:
except Exception: except Exception:
return False return False
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool: async def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
try: try:
with get_db() as db: async with get_db() as db:
result = db.query(User).filter_by(id=id).update({"api_key": api_key}) result = (
db.commit() await db.query(User).filter_by(id=id).update({"api_key": api_key})
)
await db.commit()
return True if result == 1 else False return True if result == 1 else False
except Exception: except Exception:
return False return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]: async def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try: try:
with get_db() as db: async with get_db() as db:
user = db.query(User).filter_by(id=id).first() user = await db.query(User).filter_by(id=id).first()
return user.api_key return user.api_key
except Exception: except Exception:
return None return None
def get_valid_user_ids(self, user_ids: list[str]) -> list[str]: async def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
with get_db() as db: async with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all() users = await db.query(User).filter(User.id.in_(user_ids)).all()
return [user.id for user in users] return [user.id for user in users]
def get_super_admin_user(self) -> Optional[UserModel]: async def get_super_admin_user(self) -> Optional[UserModel]:
with get_db() as db: async with get_db() as db:
user = db.query(User).filter_by(role="admin").first() user = await db.query(User).filter_by(role="admin").first()
if user: if user:
return UserModel.model_validate(user) return UserModel.model_validate(user)
else: else:

View file

@ -159,11 +159,11 @@ async def update_password(
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
if session_user: if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password) user = await Auths.authenticate_user(session_user.email, form_data.password)
if user: if user:
hashed = get_password_hash(form_data.new_password) hashed = get_password_hash(form_data.new_password)
return Auths.update_user_password_by_id(user.id, hashed) return await Auths.update_user_password_by_id(user.id, hashed)
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
else: else:
@ -357,7 +357,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
else request.app.state.config.DEFAULT_USER_ROLE else request.app.state.config.DEFAULT_USER_ROLE
) )
user = Auths.insert_new_auth( user = await Auths.insert_new_auth(
email=email, email=email,
password=str(uuid.uuid4()), password=str(uuid.uuid4()),
name=cn, name=cn,
@ -377,7 +377,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
500, detail="Internal error occurred during LDAP user creation." 500, detail="Internal error occurred during LDAP user creation."
) )
user = Auths.authenticate_user_by_email(email) user = await Auths.authenticate_user_by_email(email)
if user: if user:
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
@ -470,7 +470,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
SignupForm(email=email, password=str(uuid.uuid4()), name=name), SignupForm(email=email, password=str(uuid.uuid4()), name=name),
) )
user = Auths.authenticate_user_by_email(email) user = await Auths.authenticate_user_by_email(email)
if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin": if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
group_names = request.headers.get( group_names = request.headers.get(
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, "" WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
@ -485,7 +485,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
admin_password = "admin" admin_password = "admin"
if Users.get_user_by_email(admin_email.lower()): if Users.get_user_by_email(admin_email.lower()):
user = Auths.authenticate_user(admin_email.lower(), admin_password) user = await Auths.authenticate_user(admin_email.lower(), admin_password)
else: else:
if Users.has_users(): if Users.has_users():
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
@ -496,9 +496,11 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
SignupForm(email=admin_email, password=admin_password, name="User"), SignupForm(email=admin_email, password=admin_password, name="User"),
) )
user = Auths.authenticate_user(admin_email.lower(), admin_password) user = await Auths.authenticate_user(admin_email.lower(), admin_password)
else: else:
user = Auths.authenticate_user(form_data.email.lower(), form_data.password) user = await Auths.authenticate_user(
form_data.email.lower(), form_data.password
)
if user: if user:
@ -589,7 +591,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
) )
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = await Auths.insert_new_auth(
form_data.email.lower(), form_data.email.lower(),
hashed, hashed,
form_data.name, form_data.name,
@ -736,7 +738,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
try: try:
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = await Auths.insert_new_auth(
form_data.email.lower(), form_data.email.lower(),
hashed, hashed,
form_data.name, form_data.name,

View file

@ -40,14 +40,14 @@ router = APIRouter()
@router.get("/", response_model=list[ChannelModel]) @router.get("/", response_model=list[ChannelModel])
async def get_channels(user=Depends(get_verified_user)): async def get_channels(user=Depends(get_verified_user)):
return Channels.get_channels_by_user_id(user.id) return await Channels.get_channels_by_user_id(user.id)
@router.get("/list", response_model=list[ChannelModel]) @router.get("/list", response_model=list[ChannelModel])
async def get_all_channels(user=Depends(get_verified_user)): async def get_all_channels(user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
return Channels.get_channels() return await Channels.get_channels()
return Channels.get_channels_by_user_id(user.id) return await Channels.get_channels_by_user_id(user.id)
############################ ############################
@ -58,7 +58,7 @@ async def get_all_channels(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[ChannelModel]) @router.post("/create", response_model=Optional[ChannelModel])
async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)): async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)):
try: try:
channel = Channels.insert_new_channel(None, form_data, user.id) channel = await Channels.insert_new_channel(None, form_data, user.id)
return ChannelModel(**channel.model_dump()) return ChannelModel(**channel.model_dump())
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -74,7 +74,7 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user
@router.get("/{id}", response_model=Optional[ChannelModel]) @router.get("/{id}", response_model=Optional[ChannelModel])
async def get_channel_by_id(id: str, user=Depends(get_verified_user)): async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
channel = Channels.get_channel_by_id(id) channel = await Channels.get_channel_by_id(id)
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -99,14 +99,14 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
async def update_channel_by_id( async def update_channel_by_id(
id: str, form_data: ChannelForm, user=Depends(get_admin_user) id: str, form_data: ChannelForm, user=Depends(get_admin_user)
): ):
channel = Channels.get_channel_by_id(id) channel = await Channels.get_channel_by_id(id)
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
try: try:
channel = Channels.update_channel_by_id(id, form_data) channel = await Channels.update_channel_by_id(id, form_data)
return ChannelModel(**channel.model_dump()) return ChannelModel(**channel.model_dump())
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -122,14 +122,14 @@ async def update_channel_by_id(
@router.delete("/{id}/delete", response_model=bool) @router.delete("/{id}/delete", response_model=bool)
async def delete_channel_by_id(id: str, user=Depends(get_admin_user)): async def delete_channel_by_id(id: str, user=Depends(get_admin_user)):
channel = Channels.get_channel_by_id(id) channel = await Channels.get_channel_by_id(id)
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
try: try:
Channels.delete_channel_by_id(id) await Channels.delete_channel_by_id(id)
return True return True
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -151,7 +151,7 @@ class MessageUserResponse(MessageResponse):
async def get_channel_messages( async def get_channel_messages(
id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user) id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user)
): ):
channel = Channels.get_channel_by_id(id) channel = await Channels.get_channel_by_id(id)
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -170,7 +170,7 @@ async def get_channel_messages(
messages = [] messages = []
for message in message_list: for message in message_list:
if message.user_id not in users: if message.user_id not in users:
user = Users.get_user_by_id(message.user_id) user = await Users.get_user_by_id(message.user_id)
users[message.user_id] = user users[message.user_id] = user
replies = Messages.get_replies_by_message_id(message.id) replies = Messages.get_replies_by_message_id(message.id)
@ -230,7 +230,7 @@ async def post_new_message(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
channel = Channels.get_channel_by_id(id) channel = await Channels.get_channel_by_id(id)
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -290,9 +290,13 @@ async def post_new_message(
**{ **{
**parent_message.model_dump(), **parent_message.model_dump(),
"user": UserNameResponse( "user": UserNameResponse(
**Users.get_user_by_id( **(
parent_message.user_id (
).model_dump() await Users.get_user_by_id(
parent_message.user_id
)
).model_dump()
)
), ),
} }
).model_dump(), ).model_dump(),
@ -331,7 +335,7 @@ async def post_new_message(
async def get_channel_message( async def get_channel_message(
id: str, message_id: str, user=Depends(get_verified_user) id: str, message_id: str, user=Depends(get_verified_user)
): ):
channel = Channels.get_channel_by_id(id) channel = await Channels.get_channel_by_id(id)
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -359,7 +363,7 @@ async def get_channel_message(
**{ **{
**message.model_dump(), **message.model_dump(),
"user": UserNameResponse( "user": UserNameResponse(
**Users.get_user_by_id(message.user_id).model_dump() **((await Users.get_user_by_id(message.user_id)).model_dump())
), ),
} }
) )
@ -380,7 +384,7 @@ async def get_channel_thread_messages(
limit: int = 50, limit: int = 50,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
channel = Channels.get_channel_by_id(id) channel = await Channels.get_channel_by_id(id)
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -399,7 +403,7 @@ async def get_channel_thread_messages(
messages = [] messages = []
for message in message_list: for message in message_list:
if message.user_id not in users: if message.user_id not in users:
user = Users.get_user_by_id(message.user_id) user = await Users.get_user_by_id(message.user_id)
users[message.user_id] = user users[message.user_id] = user
messages.append( messages.append(
@ -428,7 +432,7 @@ async def get_channel_thread_messages(
async def update_message_by_id( async def update_message_by_id(
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user) id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
): ):
channel = Channels.get_channel_by_id(id) channel = await Channels.get_channel_by_id(id)
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -502,7 +506,7 @@ class ReactionForm(BaseModel):
async def add_reaction_to_message( async def add_reaction_to_message(
id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user) id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user)
): ):
channel = Channels.get_channel_by_id(id) channel = await Channels.get_channel_by_id(id)
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -540,7 +544,7 @@ async def add_reaction_to_message(
"data": { "data": {
**message.model_dump(), **message.model_dump(),
"user": UserNameResponse( "user": UserNameResponse(
**Users.get_user_by_id(message.user_id).model_dump() **(await Users.get_user_by_id(message.user_id)).model_dump()
).model_dump(), ).model_dump(),
"name": form_data.name, "name": form_data.name,
}, },
@ -568,7 +572,7 @@ async def add_reaction_to_message(
async def remove_reaction_by_id_and_user_id_and_name( async def remove_reaction_by_id_and_user_id_and_name(
id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user) id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user)
): ):
channel = Channels.get_channel_by_id(id) channel = await Channels.get_channel_by_id(id)
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -609,7 +613,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
"data": { "data": {
**message.model_dump(), **message.model_dump(),
"user": UserNameResponse( "user": UserNameResponse(
**Users.get_user_by_id(message.user_id).model_dump() **(await Users.get_user_by_id(message.user_id)).model_dump()
).model_dump(), ).model_dump(),
"name": form_data.name, "name": form_data.name,
}, },
@ -637,7 +641,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
async def delete_message_by_id( async def delete_message_by_id(
id: str, message_id: str, user=Depends(get_verified_user) id: str, message_id: str, user=Depends(get_verified_user)
): ):
channel = Channels.get_channel_by_id(id) channel = await Channels.get_channel_by_id(id)
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -699,8 +703,10 @@ async def delete_message_by_id(
**{ **{
**parent_message.model_dump(), **parent_message.model_dump(),
"user": UserNameResponse( "user": UserNameResponse(
**Users.get_user_by_id( **(
parent_message.user_id await Users.get_user_by_id(
parent_message.user_id
)
).model_dump() ).model_dump()
), ),
} }

View file

@ -44,11 +44,11 @@ async def get_session_user_chat_list(
limit = 60 limit = 60
skip = (page - 1) * limit skip = (page - 1) * limit
return Chats.get_chat_title_id_list_by_user_id( return await Chats.get_chat_title_id_list_by_user_id(
user.id, skip=skip, limit=limit user.id, skip=skip, limit=limit
) )
else: else:
return Chats.get_chat_title_id_list_by_user_id(user.id) return await Chats.get_chat_title_id_list_by_user_id(user.id)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
@ -72,7 +72,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
result = Chats.delete_chats_by_user_id(user.id) result = await Chats.delete_chats_by_user_id(user.id)
return result return result
@ -110,7 +110,7 @@ async def get_user_chat_list_by_user_id(
if direction: if direction:
filter["direction"] = direction filter["direction"] = direction
return Chats.get_chat_list_by_user_id( return await Chats.get_chat_list_by_user_id(
user_id, include_archived=True, filter=filter, skip=skip, limit=limit user_id, include_archived=True, filter=filter, skip=skip, limit=limit
) )
@ -123,7 +123,7 @@ async def get_user_chat_list_by_user_id(
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
try: try:
chat = Chats.insert_new_chat(user.id, form_data) chat = await Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -140,7 +140,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
@router.post("/import", response_model=Optional[ChatResponse]) @router.post("/import", response_model=Optional[ChatResponse])
async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)): async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)):
try: try:
chat = Chats.import_chat(user.id, form_data) chat = await Chats.import_chat(user.id, form_data)
if chat: if chat:
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
for tag_id in tags: for tag_id in tags:
@ -177,7 +177,7 @@ async def search_user_chats(
chat_list = [ chat_list = [
ChatTitleIdResponse(**chat.model_dump()) ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id_and_search_text( for chat in await Chats.get_chats_by_user_id_and_search_text(
user.id, text, skip=skip, limit=limit user.id, text, skip=skip, limit=limit
) )
] ]
@ -187,9 +187,9 @@ async def search_user_chats(
if page == 1 and len(words) == 1 and words[0].startswith("tag:"): if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
tag_id = words[0].replace("tag:", "") tag_id = words[0].replace("tag:", "")
if len(chat_list) == 0: if len(chat_list) == 0:
if Tags.get_tag_by_name_and_user_id(tag_id, user.id): if await Tags.get_tag_by_name_and_user_id(tag_id, user.id):
log.debug(f"deleting tag: {tag_id}") log.debug(f"deleting tag: {tag_id}")
Tags.delete_tag_by_name_and_user_id(tag_id, user.id) await Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
return chat_list return chat_list
@ -210,7 +210,7 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)
return [ return [
ChatResponse(**chat.model_dump()) ChatResponse(**chat.model_dump())
for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id) for chat in await Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id)
] ]
@ -223,7 +223,7 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)
async def get_user_pinned_chats(user=Depends(get_verified_user)): async def get_user_pinned_chats(user=Depends(get_verified_user)):
return [ return [
ChatTitleIdResponse(**chat.model_dump()) ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_pinned_chats_by_user_id(user.id) for chat in await Chats.get_pinned_chats_by_user_id(user.id)
] ]
@ -236,7 +236,7 @@ async def get_user_pinned_chats(user=Depends(get_verified_user)):
async def get_user_chats(user=Depends(get_verified_user)): async def get_user_chats(user=Depends(get_verified_user)):
return [ return [
ChatResponse(**chat.model_dump()) ChatResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id(user.id) for chat in await Chats.get_chats_by_user_id(user.id)
] ]
@ -249,7 +249,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
async def get_user_archived_chats(user=Depends(get_verified_user)): async def get_user_archived_chats(user=Depends(get_verified_user)):
return [ return [
ChatResponse(**chat.model_dump()) ChatResponse(**chat.model_dump())
for chat in Chats.get_archived_chats_by_user_id(user.id) for chat in await Chats.get_archived_chats_by_user_id(user.id)
] ]
@ -282,7 +282,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()] return [ChatResponse(**chat.model_dump()) for chat in await Chats.get_chats()]
############################ ############################
@ -314,7 +314,7 @@ async def get_archived_session_user_chat_list(
chat_list = [ chat_list = [
ChatTitleIdResponse(**chat.model_dump()) ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_archived_chat_list_by_user_id( for chat in await Chats.get_archived_chat_list_by_user_id(
user.id, user.id,
filter=filter, filter=filter,
skip=skip, skip=skip,
@ -332,7 +332,7 @@ async def get_archived_session_user_chat_list(
@router.post("/archive/all", response_model=bool) @router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_verified_user)): async def archive_all_chats(user=Depends(get_verified_user)):
return Chats.archive_all_chats_by_user_id(user.id) return await Chats.archive_all_chats_by_user_id(user.id)
############################ ############################
@ -348,9 +348,9 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
) )
if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS): if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS):
chat = Chats.get_chat_by_share_id(share_id) chat = await Chats.get_chat_by_share_id(share_id)
elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS: elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
chat = Chats.get_chat_by_id(share_id) chat = await Chats.get_chat_by_id(share_id)
if chat: if chat:
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
@ -379,11 +379,11 @@ class TagFilterForm(TagForm):
async def get_user_chat_list_by_tag_name( async def get_user_chat_list_by_tag_name(
form_data: TagFilterForm, user=Depends(get_verified_user) form_data: TagFilterForm, user=Depends(get_verified_user)
): ):
chats = Chats.get_chat_list_by_user_id_and_tag_name( chats = await Chats.get_chat_list_by_user_id_and_tag_name(
user.id, form_data.name, form_data.skip, form_data.limit user.id, form_data.name, form_data.skip, form_data.limit
) )
if len(chats) == 0: if len(chats) == 0:
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) await Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
return chats return chats
@ -395,7 +395,7 @@ async def get_user_chat_list_by_tag_name(
@router.get("/{id}", response_model=Optional[ChatResponse]) @router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, user=Depends(get_verified_user)): async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
@ -415,10 +415,10 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
async def update_chat_by_id( async def update_chat_by_id(
id: str, form_data: ChatForm, user=Depends(get_verified_user) id: str, form_data: ChatForm, user=Depends(get_verified_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
updated_chat = {**chat.chat, **form_data.chat} updated_chat = {**chat.chat, **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat) chat = await Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
else: else:
raise HTTPException( raise HTTPException(
@ -438,7 +438,7 @@ class MessageForm(BaseModel):
async def update_chat_message_by_id( async def update_chat_message_by_id(
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user) id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
): ):
chat = Chats.get_chat_by_id(id) chat = await Chats.get_chat_by_id(id)
if not chat: if not chat:
raise HTTPException( raise HTTPException(
@ -452,7 +452,7 @@ async def update_chat_message_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
chat = Chats.upsert_message_to_chat_by_id_and_message_id( chat = await Chats.upsert_message_to_chat_by_id_and_message_id(
id, id,
message_id, message_id,
{ {
@ -496,7 +496,7 @@ class EventForm(BaseModel):
async def send_chat_message_event_by_id( async def send_chat_message_event_by_id(
id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user) id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
): ):
chat = Chats.get_chat_by_id(id) chat = await Chats.get_chat_by_id(id)
if not chat: if not chat:
raise HTTPException( raise HTTPException(
@ -536,12 +536,12 @@ async def send_chat_message_event_by_id(
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
chat = Chats.get_chat_by_id(id) chat = await Chats.get_chat_by_id(id)
for tag in chat.meta.get("tags", []): for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id) await Tags.delete_tag_by_name_and_user_id(tag, user.id)
result = Chats.delete_chat_by_id(id) result = await Chats.delete_chat_by_id(id)
return result return result
else: else:
@ -553,12 +553,12 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
chat = Chats.get_chat_by_id(id) chat = await Chats.get_chat_by_id(id)
for tag in chat.meta.get("tags", []): for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id) Tags.delete_tag_by_name_and_user_id(tag, user.id)
result = Chats.delete_chat_by_id_and_user_id(id, user.id) result = await Chats.delete_chat_by_id_and_user_id(id, user.id)
return result return result
@ -569,7 +569,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
@router.get("/{id}/pinned", response_model=Optional[bool]) @router.get("/{id}/pinned", response_model=Optional[bool])
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)): async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
return chat.pinned return chat.pinned
else: else:
@ -585,9 +585,9 @@ async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/pin", response_model=Optional[ChatResponse]) @router.post("/{id}/pin", response_model=Optional[ChatResponse])
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)): async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
chat = Chats.toggle_chat_pinned_by_id(id) chat = await Chats.toggle_chat_pinned_by_id(id)
return chat return chat
else: else:
raise HTTPException( raise HTTPException(
@ -608,7 +608,7 @@ class CloneForm(BaseModel):
async def clone_chat_by_id( async def clone_chat_by_id(
form_data: CloneForm, id: str, user=Depends(get_verified_user) form_data: CloneForm, id: str, user=Depends(get_verified_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
updated_chat = { updated_chat = {
**chat.chat, **chat.chat,
@ -617,7 +617,7 @@ async def clone_chat_by_id(
"title": form_data.title if form_data.title else f"Clone of {chat.title}", "title": form_data.title if form_data.title else f"Clone of {chat.title}",
} }
chat = Chats.import_chat( chat = await Chats.import_chat(
user.id, user.id,
ChatImportForm( ChatImportForm(
**{ **{
@ -645,9 +645,9 @@ async def clone_chat_by_id(
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
chat = Chats.get_chat_by_id(id) chat = await Chats.get_chat_by_id(id)
else: else:
chat = Chats.get_chat_by_share_id(id) chat = await Chats.get_chat_by_share_id(id)
if chat: if chat:
updated_chat = { updated_chat = {
@ -657,7 +657,7 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
"title": f"Clone of {chat.title}", "title": f"Clone of {chat.title}",
} }
chat = Chats.import_chat( chat = await Chats.import_chat(
user.id, user.id,
ChatImportForm( ChatImportForm(
**{ **{
@ -682,22 +682,25 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/archive", response_model=Optional[ChatResponse]) @router.post("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
chat = Chats.toggle_chat_archive_by_id(id) chat = await Chats.toggle_chat_archive_by_id(id)
# Delete tags if chat is archived # Delete tags if chat is archived
if chat.archived: if chat.archived:
for tag_id in chat.meta.get("tags", []): for tag_id in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0: if (
await Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id)
== 0
):
log.debug(f"deleting tag: {tag_id}") log.debug(f"deleting tag: {tag_id}")
Tags.delete_tag_by_name_and_user_id(tag_id, user.id) Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
else: else:
for tag_id in chat.meta.get("tags", []): for tag_id in chat.meta.get("tags", []):
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id) tag = await Tags.get_tag_by_name_and_user_id(tag_id, user.id)
if tag is None: if tag is None:
log.debug(f"inserting tag: {tag_id}") log.debug(f"inserting tag: {tag_id}")
tag = Tags.insert_new_tag(tag_id, user.id) tag = await Tags.insert_new_tag(tag_id, user.id)
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
else: else:
@ -723,14 +726,14 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
if chat.share_id: if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) shared_chat = await Chats.update_shared_chat_by_chat_id(chat.id)
return ChatResponse(**shared_chat.model_dump()) return ChatResponse(**shared_chat.model_dump())
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id) shared_chat = await Chats.insert_shared_chat_by_chat_id(chat.id)
if not shared_chat: if not shared_chat:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -752,13 +755,13 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_
@router.delete("/{id}/share", response_model=Optional[bool]) @router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)): async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
if not chat.share_id: if not chat.share_id:
return False return False
result = Chats.delete_shared_chat_by_chat_id(id) result = await Chats.delete_shared_chat_by_chat_id(id)
update_result = Chats.update_chat_share_id_by_id(id, None) update_result = await Chats.update_chat_share_id_by_id(id, None)
return result and update_result != None return result and update_result != None
else: else:
@ -781,9 +784,9 @@ class ChatFolderIdForm(BaseModel):
async def update_chat_folder_id_by_id( async def update_chat_folder_id_by_id(
id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user) id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
chat = Chats.update_chat_folder_id_by_id_and_user_id( chat = await Chats.update_chat_folder_id_by_id_and_user_id(
id, user.id, form_data.folder_id id, user.id, form_data.folder_id
) )
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
@ -800,10 +803,10 @@ async def update_chat_folder_id_by_id(
@router.get("/{id}/tags", response_model=list[TagModel]) @router.get("/{id}/tags", response_model=list[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id) return await Tags.get_tags_by_ids_and_user_id(tags, user.id)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -819,7 +822,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
async def add_tag_by_id_and_tag_name( async def add_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user) id: str, form_data: TagForm, user=Depends(get_verified_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
tag_id = form_data.name.replace(" ", "_").lower() tag_id = form_data.name.replace(" ", "_").lower()
@ -831,13 +834,13 @@ async def add_tag_by_id_and_tag_name(
) )
if tag_id not in tags: if tag_id not in tags:
Chats.add_chat_tag_by_id_and_user_id_and_tag_name( await Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
id, user.id, form_data.name id, user.id, form_data.name
) )
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id) return await Tags.get_tags_by_ids_and_user_id(tags, user.id)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@ -853,16 +856,21 @@ async def add_tag_by_id_and_tag_name(
async def delete_tag_by_id_and_tag_name( async def delete_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user) id: str, form_data: TagForm, user=Depends(get_verified_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name) await Chats.delete_tag_by_id_and_user_id_and_tag_name(
id, user.id, form_data.name
)
if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0: if (
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) await Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id)
== 0
):
await Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id) return await Tags.get_tags_by_ids_and_user_id(tags, user.id)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -876,13 +884,13 @@ async def delete_tag_by_id_and_tag_name(
@router.delete("/{id}/tags/all", response_model=Optional[bool]) @router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)): async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = await Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
Chats.delete_all_tags_by_id_and_user_id(id, user.id) await Chats.delete_all_tags_by_id_and_user_id(id, user.id)
for tag in chat.meta.get("tags", []): for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id) await Tags.delete_tag_by_name_and_user_id(tag, user.id)
return True return True
else: else:

View file

@ -73,11 +73,11 @@ class FeedbackUserResponse(FeedbackResponse):
@router.get("/feedbacks/all", response_model=list[FeedbackUserResponse]) @router.get("/feedbacks/all", response_model=list[FeedbackUserResponse])
async def get_all_feedbacks(user=Depends(get_admin_user)): async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks() feedbacks = await Feedbacks.get_all_feedbacks()
feedback_list = [] feedback_list = []
for feedback in feedbacks: for feedback in feedbacks:
user = Users.get_user_by_id(feedback.user_id) user = await Users.get_user_by_id(feedback.user_id)
feedback_list.append( feedback_list.append(
FeedbackUserResponse( FeedbackUserResponse(
**feedback.model_dump(), **feedback.model_dump(),
@ -89,25 +89,25 @@ async def get_all_feedbacks(user=Depends(get_admin_user)):
@router.delete("/feedbacks/all") @router.delete("/feedbacks/all")
async def delete_all_feedbacks(user=Depends(get_admin_user)): async def delete_all_feedbacks(user=Depends(get_admin_user)):
success = Feedbacks.delete_all_feedbacks() success = await Feedbacks.delete_all_feedbacks()
return success return success
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel]) @router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
async def get_all_feedbacks(user=Depends(get_admin_user)): async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks() feedbacks = await Feedbacks.get_all_feedbacks()
return feedbacks return feedbacks
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse]) @router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
async def get_feedbacks(user=Depends(get_verified_user)): async def get_feedbacks(user=Depends(get_verified_user)):
feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id) feedbacks = await Feedbacks.get_feedbacks_by_user_id(user.id)
return feedbacks return feedbacks
@router.delete("/feedbacks", response_model=bool) @router.delete("/feedbacks", response_model=bool)
async def delete_feedbacks(user=Depends(get_verified_user)): async def delete_feedbacks(user=Depends(get_verified_user)):
success = Feedbacks.delete_feedbacks_by_user_id(user.id) success = await Feedbacks.delete_feedbacks_by_user_id(user.id)
return success return success
@ -117,7 +117,7 @@ async def create_feedback(
form_data: FeedbackForm, form_data: FeedbackForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data) feedback = await Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data)
if not feedback: if not feedback:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -130,9 +130,11 @@ async def create_feedback(
@router.get("/feedback/{id}", response_model=FeedbackModel) @router.get("/feedback/{id}", response_model=FeedbackModel)
async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
feedback = Feedbacks.get_feedback_by_id(id=id) feedback = await Feedbacks.get_feedback_by_id(id=id)
else: else:
feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id) feedback = await Feedbacks.get_feedback_by_id_and_user_id(
id=id, user_id=user.id
)
if not feedback: if not feedback:
raise HTTPException( raise HTTPException(
@ -147,9 +149,9 @@ async def update_feedback_by_id(
id: str, form_data: FeedbackForm, user=Depends(get_verified_user) id: str, form_data: FeedbackForm, user=Depends(get_verified_user)
): ):
if user.role == "admin": if user.role == "admin":
feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data) feedback = await Feedbacks.update_feedback_by_id(id=id, form_data=form_data)
else: else:
feedback = Feedbacks.update_feedback_by_id_and_user_id( feedback = await Feedbacks.update_feedback_by_id_and_user_id(
id=id, user_id=user.id, form_data=form_data id=id, user_id=user.id, form_data=form_data
) )
@ -164,9 +166,11 @@ async def update_feedback_by_id(
@router.delete("/feedback/{id}") @router.delete("/feedback/{id}")
async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)): async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
success = Feedbacks.delete_feedback_by_id(id=id) success = await Feedbacks.delete_feedback_by_id(id=id)
else: else:
success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id) success = await Feedbacks.delete_feedback_by_id_and_user_id(
id=id, user_id=user.id
)
if not success: if not success:
raise HTTPException( raise HTTPException(

View file

@ -6,6 +6,7 @@ from fnmatch import fnmatch
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from urllib.parse import quote from urllib.parse import quote
import asyncio
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
@ -137,22 +138,26 @@ def upload_file(
} }
contents, file_path = Storage.upload_file(file.file, filename, tags) contents, file_path = Storage.upload_file(file.file, filename, tags)
file_item = Files.insert_new_file( loop = asyncio.get_event_loop()
user.id, file_item = loop.run_until_complete(
FileForm( Files.insert_new_file(
**{ user.id,
"id": id, FileForm(
"filename": name, **{
"path": file_path, "id": id,
"meta": { "filename": name,
"name": name, "path": file_path,
"content_type": file.content_type, "meta": {
"size": len(contents), "name": name,
"data": file_metadata, "content_type": file.content_type,
}, "size": len(contents),
} "data": file_metadata,
), },
}
),
)
) )
if process: if process:
try: try:
if file.content_type: if file.content_type:
@ -187,7 +192,7 @@ def upload_file(
) )
process_file(request, ProcessFileForm(file_id=id), user=user) process_file(request, ProcessFileForm(file_id=id), user=user)
file_item = Files.get_file_by_id(id=id) file_item = loop.run_until_complete(Files.get_file_by_id(id=id))
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
log.error(f"Error processing file: {file_item.id}") log.error(f"Error processing file: {file_item.id}")
@ -489,7 +494,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND,
) )
file_user = Users.get_user_by_id(file.user_id) file_user = await Users.get_user_by_id(file.user_id)
if not file_user.role == "admin": if not file_user.role == "admin":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,

View file

@ -50,7 +50,7 @@ async def get_folders(user=Depends(get_verified_user)):
"items": { "items": {
"chats": [ "chats": [
{"title": chat.title, "id": chat.id, "updated_at": chat.updated_at} {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at}
for chat in Chats.get_chats_by_folder_id_and_user_id( for chat in await Chats.get_chats_by_folder_id_and_user_id(
folder.id, user.id folder.id, user.id
) )
] ]
@ -246,7 +246,7 @@ async def delete_folder_by_id(
try: try:
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id) folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id)
for folder_id in folder_ids: for folder_id in folder_ids:
Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id) await Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id)
return True return True
except Exception as e: except Exception as e:

View file

@ -45,7 +45,9 @@ async def get_notes(request: Request, user=Depends(get_verified_user)):
NoteUserResponse( NoteUserResponse(
**{ **{
**note.model_dump(), **note.model_dump(),
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()), "user": UserResponse(
**((await Users.get_user_by_id(note.user_id)).model_dump())
),
} }
) )
for note in Notes.get_notes_by_user_id(user.id, "write") for note in Notes.get_notes_by_user_id(user.id, "write")

View file

@ -350,7 +350,7 @@ def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
"""Convert internal Group model to SCIM Group""" """Convert internal Group model to SCIM Group"""
members = [] members = []
for user_id in group.user_ids: for user_id in group.user_ids:
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
if user: if user:
members.append( members.append(
SCIMGroupMember( SCIMGroupMember(
@ -523,7 +523,7 @@ async def get_user(
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
): ):
"""Get SCIM User by ID""" """Get SCIM User by ID"""
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
if not user: if not user:
return scim_error( return scim_error(
status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found" status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found"
@ -565,7 +565,7 @@ async def create_user(
profile_image = user_data.photos[0].value profile_image = user_data.photos[0].value
# Create user # Create user
new_user = Users.insert_new_user( new_user = await Users.insert_new_user(
id=user_id, id=user_id,
name=name, name=name,
email=email, email=email,
@ -590,7 +590,7 @@ async def update_user(
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
): ):
"""Update SCIM User (full update)""" """Update SCIM User (full update)"""
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -641,7 +641,7 @@ async def patch_user(
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
): ):
"""Update SCIM User (partial update)""" """Update SCIM User (partial update)"""
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -688,7 +688,7 @@ async def delete_user(
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
): ):
"""Delete SCIM User""" """Delete SCIM User"""
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,

View file

@ -310,7 +310,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
# If it is, get the user_id from the chat # If it is, get the user_id from the chat
if user_id.startswith("shared-"): if user_id.startswith("shared-"):
chat_id = user_id.replace("shared-", "") chat_id = user_id.replace("shared-", "")
chat = Chats.get_chat_by_id(chat_id) chat = await Chats.get_chat_by_id(chat_id)
if chat: if chat:
user_id = chat.user_id user_id = chat.user_id
else: else:
@ -319,7 +319,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.USER_NOT_FOUND, detail=ERROR_MESSAGES.USER_NOT_FOUND,
) )
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
if user: if user:
return UserResponse( return UserResponse(
@ -422,11 +422,11 @@ async def update_user_by_id(
detail="Could not verify primary admin status.", detail="Could not verify primary admin status.",
) )
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
if user: if user:
if form_data.email.lower() != user.email: if form_data.email.lower() != user.email:
email_user = Users.get_user_by_email(form_data.email.lower()) email_user = await Users.get_user_by_email(form_data.email.lower())
if email_user: if email_user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -436,10 +436,10 @@ async def update_user_by_id(
if form_data.password: if form_data.password:
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
log.debug(f"hashed: {hashed}") log.debug(f"hashed: {hashed}")
Auths.update_user_password_by_id(user_id, hashed) await Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(user_id, form_data.email.lower()) await Auths.update_email_by_id(user_id, form_data.email.lower())
updated_user = Users.update_user_by_id( updated_user = await Users.update_user_by_id(
user_id, user_id,
{ {
"role": form_data.role, "role": form_data.role,
@ -486,7 +486,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
) )
if user.id != user_id: if user.id != user_id:
result = Auths.delete_auth_by_id(user_id) result = await Auths.delete_auth_by_id(user_id)
if result: if result:
return True return True

View file

@ -295,7 +295,7 @@ async def user_join(sid, data):
USER_POOL[user.id] = [sid] USER_POOL[user.id] = [sid]
# Join all the channels # Join all the channels
channels = Channels.get_channels_by_user_id(user.id) channels = await Channels.get_channels_by_user_id(user.id)
log.debug(f"{channels=}") log.debug(f"{channels=}")
for channel in channels: for channel in channels:
await sio.enter_room(sid, f"channel:{channel.id}") await sio.enter_room(sid, f"channel:{channel.id}")
@ -317,7 +317,7 @@ async def join_channel(sid, data):
return return
# Join all the channels # Join all the channels
channels = Channels.get_channels_by_user_id(user.id) channels = await Channels.get_channels_by_user_id(user.id)
log.debug(f"{channels=}") log.debug(f"{channels=}")
for channel in channels: for channel in channels:
await sio.enter_room(sid, f"channel:{channel.id}") await sio.enter_room(sid, f"channel:{channel.id}")
@ -668,14 +668,14 @@ def get_event_emitter(request_info, update_db=True):
if update_db: if update_db:
if "type" in event_data and event_data["type"] == "status": if "type" in event_data and event_data["type"] == "status":
Chats.add_message_status_to_chat_by_id_and_message_id( await Chats.add_message_status_to_chat_by_id_and_message_id(
request_info["chat_id"], request_info["chat_id"],
request_info["message_id"], request_info["message_id"],
event_data.get("data", {}), event_data.get("data", {}),
) )
if "type" in event_data and event_data["type"] == "message": if "type" in event_data and event_data["type"] == "message":
message = Chats.get_message_by_id_and_message_id( message = await Chats.get_message_by_id_and_message_id(
request_info["chat_id"], request_info["chat_id"],
request_info["message_id"], request_info["message_id"],
) )
@ -684,7 +684,7 @@ def get_event_emitter(request_info, update_db=True):
content = message.get("content", "") content = message.get("content", "")
content += event_data.get("data", {}).get("content", "") content += event_data.get("data", {}).get("content", "")
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"], request_info["chat_id"],
request_info["message_id"], request_info["message_id"],
{ {
@ -695,7 +695,7 @@ def get_event_emitter(request_info, update_db=True):
if "type" in event_data and event_data["type"] == "replace": if "type" in event_data and event_data["type"] == "replace":
content = event_data.get("data", {}).get("content", "") content = event_data.get("data", {}).get("content", "")
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"], request_info["chat_id"],
request_info["message_id"], request_info["message_id"],
{ {

View file

@ -771,9 +771,9 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# Check if the request has chat_id and is inside of a folder # Check if the request has chat_id and is inside of a folder
chat_id = metadata.get("chat_id", None) chat_id = metadata.get("chat_id", None)
if chat_id and user: if chat_id and user:
chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id) chat = await Chats.get_chat_by_id_and_user_id(chat_id, user.id)
if chat and chat.folder_id: if chat and chat.folder_id:
folder = Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id) folder = await Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id)
if folder and folder.data: if folder and folder.data:
if "system_prompt" in folder.data: if "system_prompt" in folder.data:
@ -1042,7 +1042,7 @@ async def process_chat_response(
request, response, form_data, user, metadata, model, events, tasks request, response, form_data, user, metadata, model, events, tasks
): ):
async def background_tasks_handler(): async def background_tasks_handler():
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"]) message_map = await Chats.get_messages_by_chat_id(metadata["chat_id"])
message = message_map.get(metadata["message_id"]) if message_map else None message = message_map.get(metadata["message_id"]) if message_map else None
if message: if message:
@ -1115,7 +1115,7 @@ async def process_chat_response(
"follow_ups", [] "follow_ups", []
) )
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
@ -1177,7 +1177,9 @@ async def process_chat_response(
if not title: if not title:
title = messages[0].get("content", user_message) title = messages[0].get("content", user_message)
Chats.update_chat_title_by_id(metadata["chat_id"], title) await Chats.update_chat_title_by_id(
metadata["chat_id"], title
)
await event_emitter( await event_emitter(
{ {
@ -1188,7 +1190,7 @@ async def process_chat_response(
elif len(messages) == 2: elif len(messages) == 2:
title = messages[0].get("content", user_message) title = messages[0].get("content", user_message)
Chats.update_chat_title_by_id(metadata["chat_id"], title) await Chats.update_chat_title_by_id(metadata["chat_id"], title)
await event_emitter( await event_emitter(
{ {
@ -1224,7 +1226,7 @@ async def process_chat_response(
try: try:
tags = json.loads(tags_string).get("tags", []) tags = json.loads(tags_string).get("tags", [])
Chats.update_chat_tags_by_id( await Chats.update_chat_tags_by_id(
metadata["chat_id"], tags, user metadata["chat_id"], tags, user
) )
@ -1267,7 +1269,7 @@ async def process_chat_response(
if "error" in response_data: if "error" in response_data:
error = response_data["error"].get("detail", response_data["error"]) error = response_data["error"].get("detail", response_data["error"])
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
@ -1276,7 +1278,7 @@ async def process_chat_response(
) )
if "selected_model_id" in response_data: if "selected_model_id" in response_data:
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
@ -1296,7 +1298,7 @@ async def process_chat_response(
} }
) )
title = Chats.get_chat_title_by_id(metadata["chat_id"]) title = await Chats.get_chat_title_by_id(metadata["chat_id"])
await event_emitter( await event_emitter(
{ {
@ -1310,7 +1312,7 @@ async def process_chat_response(
) )
# Save message in the database # Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
@ -1767,7 +1769,7 @@ async def process_chat_response(
return content, content_blocks, end_flag return content, content_blocks, end_flag
message = Chats.get_message_by_id_and_message_id( message = await Chats.get_message_by_id_and_message_id(
metadata["chat_id"], metadata["message_id"] metadata["chat_id"], metadata["message_id"]
) )
@ -1827,7 +1829,7 @@ async def process_chat_response(
) )
# Save message in the database # Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
@ -1882,7 +1884,7 @@ async def process_chat_response(
if "selected_model_id" in data: if "selected_model_id" in data:
model_id = data["selected_model_id"] model_id = data["selected_model_id"]
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
@ -2081,7 +2083,7 @@ async def process_chat_response(
if ENABLE_REALTIME_CHAT_SAVE: if ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database # Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
@ -2502,7 +2504,7 @@ async def process_chat_response(
log.debug(e) log.debug(e)
break break
title = Chats.get_chat_title_by_id(metadata["chat_id"]) title = await Chats.get_chat_title_by_id(metadata["chat_id"])
data = { data = {
"done": True, "done": True,
"content": serialize_content_blocks(content_blocks), "content": serialize_content_blocks(content_blocks),
@ -2511,7 +2513,7 @@ async def process_chat_response(
if not ENABLE_REALTIME_CHAT_SAVE: if not ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database # Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
@ -2549,7 +2551,7 @@ async def process_chat_response(
if not ENABLE_REALTIME_CHAT_SAVE: if not ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database # Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id( await Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {

View file

@ -490,7 +490,7 @@ class OAuthManager:
role = self.get_user_role(None, user_data) role = self.get_user_role(None, user_data)
user = Auths.insert_new_auth( user = await Auths.insert_new_auth(
email=email, email=email,
password=get_password_hash( password=get_password_hash(
str(uuid.uuid4()) str(uuid.uuid4())