From 44e9ae243de2b95b4c7338c8454066522112d28f Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 14 Aug 2025 15:46:18 +0400 Subject: [PATCH] init --- backend/open_webui/config.py | 58 ++-- backend/open_webui/internal/db.py | 44 +-- backend/open_webui/main.py | 14 +- backend/open_webui/models/auths.py | 58 ++-- backend/open_webui/models/channels.py | 40 +-- backend/open_webui/models/chats.py | 367 +++++++++++----------- backend/open_webui/models/feedbacks.py | 102 +++--- backend/open_webui/models/files.py | 94 +++--- backend/open_webui/models/folders.py | 126 ++++---- backend/open_webui/models/functions.py | 118 +++---- backend/open_webui/models/groups.py | 24 +- backend/open_webui/models/knowledge.py | 18 +- backend/open_webui/models/memories.py | 16 +- backend/open_webui/models/messages.py | 26 +- backend/open_webui/models/models.py | 24 +- backend/open_webui/models/notes.py | 10 +- backend/open_webui/models/prompts.py | 14 +- backend/open_webui/models/tags.py | 10 +- backend/open_webui/models/tools.py | 26 +- backend/open_webui/models/users.py | 178 ++++++----- backend/open_webui/routers/auths.py | 22 +- backend/open_webui/routers/channels.py | 60 ++-- backend/open_webui/routers/chats.py | 148 ++++----- backend/open_webui/routers/evaluations.py | 30 +- backend/open_webui/routers/files.py | 39 ++- backend/open_webui/routers/folders.py | 4 +- backend/open_webui/routers/notes.py | 4 +- backend/open_webui/routers/scim.py | 12 +- backend/open_webui/routers/users.py | 16 +- backend/open_webui/socket/main.py | 12 +- backend/open_webui/utils/middleware.py | 38 +-- backend/open_webui/utils/oauth.py | 2 +- 32 files changed, 927 insertions(+), 827 deletions(-) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 43c864ef22..3f8fb139a5 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -4,6 +4,7 @@ import os import shutil import base64 import redis +import asyncio from datetime import datetime from pathlib import Path @@ -85,23 +86,28 @@ def load_json_config(): return json.load(file) -def save_to_db(data): - with get_db() as db: - existing_config = db.query(Config).first() +async def asave_to_db(data): + async with get_db() as db: + existing_config = await db.query(Config).first() if not existing_config: new_config = Config(data=data, version=0) - db.add(new_config) + await db.add(new_config) else: existing_config.data = data existing_config.updated_at = datetime.now() - db.add(existing_config) - db.commit() + await db.add(existing_config) + await db.commit() -def reset_config(): - with get_db() as db: - db.query(Config).delete() - db.commit() +def save_to_db(data): + loop = asyncio.get_event_loop() + result = loop.run_until_complete(asave_to_db(data)) + + +async def reset_config(): + async with get_db() as db: + await db.query(Config).delete() + await db.commit() # When initializing, check if config.json exists and migrate it to the database @@ -116,25 +122,14 @@ DEFAULT_CONFIG = { } -def get_config(): - with get_db() as db: - config_entry = db.query(Config).order_by(Config.id.desc()).first() +async def get_config(): + async with get_db() as db: + config_entry = await db.query(Config).order_by(Config.id.desc()).first() return config_entry.data if config_entry else DEFAULT_CONFIG -CONFIG_DATA = get_config() - - -def get_config_value(config_path: str): - path_parts = config_path.split(".") - cur_config = CONFIG_DATA - for key in path_parts: - if key in cur_config: - cur_config = cur_config[key] - else: - return None - return cur_config - +loop = asyncio.get_event_loop() +CONFIG_DATA = loop.run_until_complete(get_config()) PERSISTENT_CONFIG_REGISTRY = [] @@ -162,6 +157,17 @@ ENABLE_PERSISTENT_CONFIG = ( ) +def get_config_value(config_path: str): + path_parts = config_path.split(".") + cur_config = CONFIG_DATA + for key in path_parts: + if key in cur_config: + cur_config = cur_config[key] + else: + return None + return cur_config + + class PersistentConfig(Generic[T]): def __init__(self, env_name: str, config_path: str, env_value: T): self.env_name = env_name diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index d7a200ff20..1b94c138da 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -1,7 +1,6 @@ import os import json import logging -from contextlib import contextmanager from typing import Any, Optional from open_webui.internal.wrappers import register_connection @@ -21,6 +20,13 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import QueuePool, NullPool from sqlalchemy.sql.type_api import _T +from sqlalchemy.ext.asyncio import ( + create_async_engine, + AsyncSession, + async_sessionmaker, + AsyncAttrs, +) +from contextlib import asynccontextmanager from typing_extensions import Self log = logging.getLogger(__name__) @@ -102,7 +108,7 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"): conn.execute(f"PRAGMA key = '{database_password}'") return conn - engine = create_engine( + engine = create_async_engine( "sqlite://", # Dummy URL since we're using creator creator=create_sqlcipher_connection, echo=False, @@ -111,13 +117,13 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"): log.info("Connected to encrypted SQLite database using SQLCipher") elif "sqlite" in SQLALCHEMY_DATABASE_URL: - engine = create_engine( + engine = create_async_engine( SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} ) else: if isinstance(DATABASE_POOL_SIZE, int): if DATABASE_POOL_SIZE > 0: - engine = create_engine( + engine = create_async_engine( SQLALCHEMY_DATABASE_URL, pool_size=DATABASE_POOL_SIZE, max_overflow=DATABASE_POOL_MAX_OVERFLOW, @@ -127,27 +133,27 @@ else: poolclass=QueuePool, ) else: - engine = create_engine( + engine = create_async_engine( SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool ) else: - engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) + engine = create_async_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) - -SessionLocal = sessionmaker( - autocommit=False, autoflush=False, bind=engine, expire_on_commit=False +AsyncSessionLocal = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, ) metadata_obj = MetaData(schema=DATABASE_SCHEMA) Base = declarative_base(metadata=metadata_obj) -Session = scoped_session(SessionLocal) -def get_session(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - -get_db = contextmanager(get_session) +@asynccontextmanager +async def get_db(): + async with AsyncSessionLocal() as session: + try: + yield session + finally: + await session.close() diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 071929ed42..4e867103a8 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1459,7 +1459,9 @@ async def chat_completion( if metadata.get("chat_id") and (user and user.role != "admin"): if metadata["chat_id"] != "local": - chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id) + chat = await Chats.get_chat_by_id_and_user_id( + metadata["chat_id"], user.id + ) if chat is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -1477,7 +1479,7 @@ async def chat_completion( if metadata.get("chat_id") and metadata.get("message_id"): # Update the chat message with the error try: - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { @@ -1496,7 +1498,7 @@ async def chat_completion( response = await chat_completion_handler(request, form_data, user) if metadata.get("chat_id") and metadata.get("message_id"): try: - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { @@ -1514,7 +1516,7 @@ async def chat_completion( if metadata.get("chat_id") and metadata.get("message_id"): # Update the chat message with the error try: - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { @@ -1593,7 +1595,7 @@ async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)) async def list_tasks_by_chat_id_endpoint( request: Request, chat_id: str, user=Depends(get_verified_user) ): - chat = Chats.get_chat_by_id(chat_id) + chat = await Chats.get_chat_by_id(chat_id) if chat is None or chat.user_id != user.id: return {"task_ids": []} @@ -1624,7 +1626,7 @@ async def get_app_config(request: Request): detail="Invalid token", ) if data is not None and "id" in data: - user = Users.get_user_by_id(data["id"]) + user = await Users.get_user_by_id(data["id"]) user_count = Users.get_num_users() onboarding = False diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index 3ad88bc119..72f314eb5c 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -95,7 +95,7 @@ class AddUserForm(SignupForm): class AuthsTable: - def insert_new_auth( + async def insert_new_auth( self, email: str, password: str, @@ -104,7 +104,7 @@ class AuthsTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - with get_db() as db: + async with get_db() as db: log.info("insert_new_auth") id = str(uuid.uuid4()) @@ -115,28 +115,28 @@ class AuthsTable: result = Auth(**auth.model_dump()) db.add(result) - user = Users.insert_new_user( + user = await Users.insert_new_user( id, name, email, profile_image_url, role, oauth_sub ) - db.commit() - db.refresh(result) + await db.commit() + await db.refresh(result) if result and user: return user else: return None - def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: + async def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") - user = Users.get_user_by_email(email) + user = await Users.get_user_by_email(email) if not user: return None try: - with get_db() as db: - auth = db.query(Auth).filter_by(id=user.id, active=True).first() + async with get_db() as db: + auth = await db.query(Auth).filter_by(id=user.id, active=True).first() if auth: if verify_password(password, auth.password): return user @@ -147,58 +147,60 @@ class AuthsTable: except Exception: return None - def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + async def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_api_key: {api_key}") # if no api_key, return None if not api_key: return None try: - user = Users.get_user_by_api_key(api_key) + user = await Users.get_user_by_api_key(api_key) return user if user else None except Exception: return False - def authenticate_user_by_email(self, email: str) -> Optional[UserModel]: + async def authenticate_user_by_email(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_email: {email}") try: - with get_db() as db: - auth = db.query(Auth).filter_by(email=email, active=True).first() + async with get_db() as db: + auth = await db.query(Auth).filter_by(email=email, active=True).first() if auth: - user = Users.get_user_by_id(auth.id) + user = await Users.get_user_by_id(auth.id) return user except Exception: return None - def update_user_password_by_id(self, id: str, new_password: str) -> bool: + async def update_user_password_by_id(self, id: str, new_password: str) -> bool: try: - with get_db() as db: + async with get_db() as db: result = ( - db.query(Auth).filter_by(id=id).update({"password": new_password}) + await db.query(Auth) + .filter_by(id=id) + .update({"password": new_password}) ) - db.commit() + await db.commit() return True if result == 1 else False except Exception: return False - def update_email_by_id(self, id: str, email: str) -> bool: + async def update_email_by_id(self, id: str, email: str) -> bool: try: - with get_db() as db: - result = db.query(Auth).filter_by(id=id).update({"email": email}) - db.commit() + async with get_db() as db: + result = await db.query(Auth).filter_by(id=id).update({"email": email}) + await db.commit() return True if result == 1 else False except Exception: return False - def delete_auth_by_id(self, id: str) -> bool: + async def delete_auth_by_id(self, id: str) -> bool: try: - with get_db() as db: + async with get_db() as db: # Delete User - result = Users.delete_user_by_id(id) + result = await Users.delete_user_by_id(id) if result: - db.query(Auth).filter_by(id=id).delete() - db.commit() + await db.query(Auth).filter_by(id=id).delete() + await db.commit() return True else: diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 92f238c3a0..ac018f5a7a 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -66,10 +66,10 @@ class ChannelForm(BaseModel): class ChannelTable: - def insert_new_channel( + async def insert_new_channel( self, type: Optional[str], form_data: ChannelForm, user_id: str ) -> Optional[ChannelModel]: - with get_db() as db: + async with get_db() as db: channel = ChannelModel( **{ **form_data.model_dump(), @@ -84,19 +84,19 @@ class ChannelTable: new_channel = Channel(**channel.model_dump()) - db.add(new_channel) - db.commit() + await db.add(new_channel) + await db.commit() return channel - def get_channels(self) -> list[ChannelModel]: - with get_db() as db: - channels = db.query(Channel).all() + async def get_channels(self) -> list[ChannelModel]: + async with get_db() as db: + channels = await db.query(Channel).all() return [ChannelModel.model_validate(channel) for channel in channels] - def get_channels_by_user_id( + async def get_channels_by_user_id( self, user_id: str, permission: str = "read" ) -> list[ChannelModel]: - channels = self.get_channels() + channels = await self.get_channels() return [ channel for channel in channels @@ -104,16 +104,16 @@ class ChannelTable: or has_access(user_id, permission, channel.access_control) ] - def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: - with get_db() as db: - channel = db.query(Channel).filter(Channel.id == id).first() + async def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: + async with get_db() as db: + channel = await db.query(Channel).filter(Channel.id == id).first() return ChannelModel.model_validate(channel) if channel else None - def update_channel_by_id( + async def update_channel_by_id( self, id: str, form_data: ChannelForm ) -> Optional[ChannelModel]: - with get_db() as db: - channel = db.query(Channel).filter(Channel.id == id).first() + async with get_db() as db: + channel = await db.query(Channel).filter(Channel.id == id).first() if not channel: return None @@ -123,13 +123,13 @@ class ChannelTable: channel.access_control = form_data.access_control channel.updated_at = int(time.time_ns()) - db.commit() + await db.commit() return ChannelModel.model_validate(channel) if channel else None - def delete_channel_by_id(self, id: str): - with get_db() as db: - db.query(Channel).filter(Channel.id == id).delete() - db.commit() + async def delete_channel_by_id(self, id: str): + async with get_db() as db: + await db.query(Channel).filter(Channel.id == id).delete() + await db.commit() return True diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index a70af898d4..8026b0d101 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -109,8 +109,10 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: - def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: - with get_db() as db: + async def insert_new_chat( + self, user_id: str, form_data: ChatForm + ) -> Optional[ChatModel]: + async with get_db() as db: id = str(uuid.uuid4()) chat = ChatModel( **{ @@ -129,15 +131,15 @@ class ChatTable: ) result = Chat(**chat.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) return ChatModel.model_validate(result) if result else None - def import_chat( + async def import_chat( self, user_id: str, form_data: ChatImportForm ) -> Optional[ChatModel]: - with get_db() as db: + async with get_db() as db: id = str(uuid.uuid4()) chat = ChatModel( **{ @@ -166,82 +168,84 @@ class ChatTable: ) result = Chat(**chat.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) return ChatModel.model_validate(result) if result else None - def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: + async def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: try: - with get_db() as db: - chat_item = db.get(Chat, id) + async with get_db() as db: + chat_item = await db.get(Chat, id) chat_item.chat = chat chat_item.title = chat["title"] if "title" in chat else "New Chat" chat_item.updated_at = int(time.time()) - db.commit() - db.refresh(chat_item) + await db.commit() + await db.refresh(chat_item) return ChatModel.model_validate(chat_item) except Exception: return None - def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]: - chat = self.get_chat_by_id(id) + async def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]: + chat = await self.get_chat_by_id(id) if chat is None: return None chat = chat.chat chat["title"] = title - return self.update_chat_by_id(id, chat) + return await self.update_chat_by_id(id, chat) - def update_chat_tags_by_id( + async def update_chat_tags_by_id( self, id: str, tags: list[str], user ) -> Optional[ChatModel]: - chat = self.get_chat_by_id(id) + chat = await self.get_chat_by_id(id) if chat is None: return None - self.delete_all_tags_by_id_and_user_id(id, user.id) + await self.delete_all_tags_by_id_and_user_id(id, user.id) for tag in chat.meta.get("tags", []): - if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: - Tags.delete_tag_by_name_and_user_id(tag, user.id) + if await self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: + await Tags.delete_tag_by_name_and_user_id(tag, user.id) for tag_name in tags: if tag_name.lower() == "none": continue - self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name) - return self.get_chat_by_id(id) + await self.add_chat_tag_by_id_and_user_id_and_tag_name( + id, user.id, tag_name + ) + return await self.get_chat_by_id(id) - def get_chat_title_by_id(self, id: str) -> Optional[str]: - chat = self.get_chat_by_id(id) + async def get_chat_title_by_id(self, id: str) -> Optional[str]: + chat = await self.get_chat_by_id(id) if chat is None: return None return chat.chat.get("title", "New Chat") - def get_messages_by_chat_id(self, id: str) -> Optional[dict]: - chat = self.get_chat_by_id(id) + async def get_messages_by_chat_id(self, id: str) -> Optional[dict]: + chat = await self.get_chat_by_id(id) if chat is None: return None return chat.chat.get("history", {}).get("messages", {}) or {} - def get_message_by_id_and_message_id( + async def get_message_by_id_and_message_id( self, id: str, message_id: str ) -> Optional[dict]: - chat = self.get_chat_by_id(id) + chat = await self.get_chat_by_id(id) if chat is None: return None return chat.chat.get("history", {}).get("messages", {}).get(message_id, {}) - def upsert_message_to_chat_by_id_and_message_id( + async def upsert_message_to_chat_by_id_and_message_id( self, id: str, message_id: str, message: dict ) -> Optional[ChatModel]: - chat = self.get_chat_by_id(id) + chat = await self.get_chat_by_id(id) if chat is None: return None @@ -263,12 +267,12 @@ class ChatTable: history["currentId"] = message_id chat["history"] = history - return self.update_chat_by_id(id, chat) + return await self.update_chat_by_id(id, chat) - def add_message_status_to_chat_by_id_and_message_id( + async def add_message_status_to_chat_by_id_and_message_id( self, id: str, message_id: str, status: dict ) -> Optional[ChatModel]: - chat = self.get_chat_by_id(id) + chat = await self.get_chat_by_id(id) if chat is None: return None @@ -281,15 +285,15 @@ class ChatTable: history["messages"][message_id]["statusHistory"] = status_history chat["history"] = history - return self.update_chat_by_id(id, chat) + return await self.update_chat_by_id(id, chat) - def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: - with get_db() as db: + async def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: + async with get_db() as db: # Get the existing chat to share - chat = db.get(Chat, chat_id) + chat = await db.get(Chat, chat_id) # Check if the chat is already shared if chat.share_id: - return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + return await self.get_chat_by_id_and_user_id(chat.share_id, "shared") # Create a new chat with the same data, but with a new ID shared_chat = ChatModel( **{ @@ -305,29 +309,30 @@ class ChatTable: } ) shared_result = Chat(**shared_chat.model_dump()) - db.add(shared_result) - db.commit() - db.refresh(shared_result) + await db.add(shared_result) + await db.commit() + await db.refresh(shared_result) # Update the original chat with the share_id result = ( - db.query(Chat) + await db.query(Chat) .filter_by(id=chat_id) .update({"share_id": shared_chat.id}) ) - db.commit() + await db.commit() + return shared_chat if (shared_result and result) else None - def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: + async def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: try: - with get_db() as db: - chat = db.get(Chat, chat_id) + async with get_db() as db: + chat = await db.get(Chat, chat_id) shared_chat = ( - db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first() + await db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first() ) if shared_chat is None: - return self.insert_shared_chat_by_chat_id(chat_id) + return await self.insert_shared_chat_by_chat_id(chat_id) shared_chat.title = chat.title shared_chat.chat = chat.chat @@ -335,70 +340,72 @@ class ChatTable: shared_chat.pinned = chat.pinned shared_chat.folder_id = chat.folder_id shared_chat.updated_at = int(time.time()) - db.commit() - db.refresh(shared_chat) + await db.commit() + await db.refresh(shared_chat) return ChatModel.model_validate(shared_chat) except Exception: return None - def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: + async def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: try: - with get_db() as db: - db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() - db.commit() + async with get_db() as db: + await db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() + await db.commit() return True except Exception: return False - def update_chat_share_id_by_id( + async def update_chat_share_id_by_id( self, id: str, share_id: Optional[str] ) -> Optional[ChatModel]: try: - with get_db() as db: - chat = db.get(Chat, id) + async with get_db() as db: + chat = await db.get(Chat, id) chat.share_id = share_id - db.commit() - db.refresh(chat) + await db.commit() + await db.refresh(chat) return ChatModel.model_validate(chat) except Exception: return None - def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]: + async def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]: try: - with get_db() as db: - chat = db.get(Chat, id) + async with get_db() as db: + chat = await db.get(Chat, id) chat.pinned = not chat.pinned chat.updated_at = int(time.time()) - db.commit() - db.refresh(chat) + await db.commit() + await db.refresh(chat) return ChatModel.model_validate(chat) except Exception: return None - def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: + async def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: - with get_db() as db: - chat = db.get(Chat, id) + async with get_db() as db: + chat = await db.get(Chat, id) chat.archived = not chat.archived chat.updated_at = int(time.time()) - db.commit() - db.refresh(chat) + await db.commit() + await db.refresh(chat) return ChatModel.model_validate(chat) except Exception: return None - def archive_all_chats_by_user_id(self, user_id: str) -> bool: + async def archive_all_chats_by_user_id(self, user_id: str) -> bool: try: - with get_db() as db: - db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) - db.commit() + async with get_db() as db: + await db.query(Chat).filter_by(user_id=user_id).update( + {"archived": True} + ) + await db.commit() return True except Exception: return False - def get_archived_chat_list_by_user_id( + async def get_archived_chat_list_by_user_id( self, user_id: str, filter: Optional[dict] = None, @@ -406,7 +413,7 @@ class ChatTable: limit: int = 50, ) -> list[ChatModel]: - with get_db() as db: + async with get_db() as db: query = db.query(Chat).filter_by(user_id=user_id, archived=True) if filter: @@ -432,10 +439,10 @@ class ChatTable: if limit: query = query.limit(limit) - all_chats = query.all() + all_chats = await query.all() return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chat_list_by_user_id( + async def get_chat_list_by_user_id( self, user_id: str, include_archived: bool = False, @@ -443,7 +450,7 @@ class ChatTable: skip: int = 0, limit: int = 50, ) -> list[ChatModel]: - with get_db() as db: + async with get_db() as db: query = db.query(Chat).filter_by(user_id=user_id) if not include_archived: query = query.filter_by(archived=False) @@ -471,17 +478,17 @@ class ChatTable: if limit: query = query.limit(limit) - all_chats = query.all() + all_chats = await query.all() return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chat_title_id_list_by_user_id( + async def get_chat_title_id_list_by_user_id( self, user_id: str, include_archived: bool = False, skip: Optional[int] = None, limit: Optional[int] = None, ) -> list[ChatTitleIdResponse]: - with get_db() as db: + async with get_db() as db: query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None) query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) @@ -497,7 +504,7 @@ class ChatTable: if limit: query = query.limit(limit) - all_chats = query.all() + all_chats = await query.all() # result has to be destructured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass. return [ @@ -512,12 +519,12 @@ class ChatTable: for chat in all_chats ] - def get_chat_list_by_chat_ids( + async def get_chat_list_by_chat_ids( self, chat_ids: list[str], skip: int = 0, limit: int = 50 ) -> list[ChatModel]: - with get_db() as db: + async with get_db() as db: all_chats = ( - db.query(Chat) + await db.query(Chat) .filter(Chat.id.in_(chat_ids)) .filter_by(archived=False) .order_by(Chat.updated_at.desc()) @@ -525,73 +532,75 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chat_by_id(self, id: str) -> Optional[ChatModel]: + async def get_chat_by_id(self, id: str) -> Optional[ChatModel]: try: - with get_db() as db: - chat = db.get(Chat, id) + async with get_db() as db: + chat = await db.get(Chat, id) return ChatModel.model_validate(chat) except Exception: return None - def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: + async def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: try: - with get_db() as db: + async with get_db() as db: # it is possible that the shared link was deleted. hence, # we check if the chat is still shared by checking if a chat with the share_id exists - chat = db.query(Chat).filter_by(share_id=id).first() + chat = await db.query(Chat).filter_by(share_id=id).first() if chat: - return self.get_chat_by_id(id) + return await self.get_chat_by_id(id) else: return None except Exception: return None - def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: + async def get_chat_by_id_and_user_id( + self, id: str, user_id: str + ) -> Optional[ChatModel]: try: - with get_db() as db: - chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() + async with get_db() as db: + chat = await db.query(Chat).filter_by(id=id, user_id=user_id).first() return ChatModel.model_validate(chat) except Exception: return None - def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]: - with get_db() as db: + async def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]: + async with get_db() as db: all_chats = ( - db.query(Chat) + await db.query(Chat) # .limit(limit).offset(skip) .order_by(Chat.updated_at.desc()) ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]: - with get_db() as db: + async def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]: + async with get_db() as db: all_chats = ( - db.query(Chat) + await db.query(Chat) .filter_by(user_id=user_id) .order_by(Chat.updated_at.desc()) ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]: - with get_db() as db: + async def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]: + async with get_db() as db: all_chats = ( - db.query(Chat) + await db.query(Chat) .filter_by(user_id=user_id, pinned=True, archived=False) .order_by(Chat.updated_at.desc()) ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: - with get_db() as db: + async def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: + async with get_db() as db: all_chats = ( - db.query(Chat) + await db.query(Chat) .filter_by(user_id=user_id, archived=True) .order_by(Chat.updated_at.desc()) ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chats_by_user_id_and_search_text( + async def get_chats_by_user_id_and_search_text( self, user_id: str, search_text: str, @@ -605,7 +614,7 @@ class ChatTable: search_text = search_text.replace("\u0000", "").lower().strip() if not search_text: - return self.get_chat_list_by_user_id( + return await self.get_chat_list_by_user_id( user_id, include_archived, filter={}, skip=skip, limit=limit ) @@ -619,7 +628,7 @@ class ChatTable: ] # Extract folder names - handle spaces and case insensitivity - folders = Folders.search_folders_by_names( + folders = await Folders.search_folders_by_names( user_id, [ word.replace("folder:", "") @@ -661,7 +670,7 @@ class ChatTable: search_text = " ".join(search_text_words) - with get_db() as db: + async with get_db() as db: query = db.query(Chat).filter(Chat.user_id == user_id) if is_archived is not None: @@ -783,30 +792,30 @@ class ChatTable: ) # Perform pagination at the SQL level - all_chats = query.offset(skip).limit(limit).all() + all_chats = await query.offset(skip).limit(limit).all() log.info(f"The number of chats: {len(all_chats)}") # Validate and return chats return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chats_by_folder_id_and_user_id( + async def get_chats_by_folder_id_and_user_id( self, folder_id: str, user_id: str ) -> list[ChatModel]: - with get_db() as db: + async with get_db() as db: query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id) query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) query = query.filter_by(archived=False) query = query.order_by(Chat.updated_at.desc()) - all_chats = query.all() + all_chats = await query.all() return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chats_by_folder_ids_and_user_id( + async def get_chats_by_folder_ids_and_user_id( self, folder_ids: list[str], user_id: str ) -> list[ChatModel]: - with get_db() as db: + async with get_db() as db: query = db.query(Chat).filter( Chat.folder_id.in_(folder_ids), Chat.user_id == user_id ) @@ -815,34 +824,38 @@ class ChatTable: query = query.order_by(Chat.updated_at.desc()) - all_chats = query.all() + all_chats = await query.all() return [ChatModel.model_validate(chat) for chat in all_chats] - def update_chat_folder_id_by_id_and_user_id( + async def update_chat_folder_id_by_id_and_user_id( self, id: str, user_id: str, folder_id: str ) -> Optional[ChatModel]: try: - with get_db() as db: - chat = db.get(Chat, id) + async with get_db() as db: + chat = await db.get(Chat, id) chat.folder_id = folder_id chat.updated_at = int(time.time()) chat.pinned = False - db.commit() - db.refresh(chat) + await db.commit() + await db.refresh(chat) return ChatModel.model_validate(chat) except Exception: return None - def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]: - with get_db() as db: - chat = db.get(Chat, id) + async def get_chat_tags_by_id_and_user_id( + self, id: str, user_id: str + ) -> list[TagModel]: + async with get_db() as db: + chat = await db.get(Chat, id) tags = chat.meta.get("tags", []) - return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags] + return [ + await Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags + ] - def get_chat_list_by_user_id_and_tag_name( + async def get_chat_list_by_user_id_and_tag_name( self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50 ) -> list[ChatModel]: - with get_db() as db: + async with get_db() as db: query = db.query(Chat).filter_by(user_id=user_id) tag_id = tag_name.replace(" ", "_").lower() @@ -866,19 +879,19 @@ class ChatTable: f"Unsupported dialect: {db.bind.dialect.name}" ) - all_chats = query.all() + all_chats = await query.all() log.debug(f"all_chats: {all_chats}") return [ChatModel.model_validate(chat) for chat in all_chats] - def add_chat_tag_by_id_and_user_id_and_tag_name( + async def add_chat_tag_by_id_and_user_id_and_tag_name( self, id: str, user_id: str, tag_name: str ) -> Optional[ChatModel]: - tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id) + tag = await Tags.get_tag_by_name_and_user_id(tag_name, user_id) if tag is None: - tag = Tags.insert_new_tag(tag_name, user_id) + tag = await Tags.insert_new_tag(tag_name, user_id) try: - with get_db() as db: - chat = db.get(Chat, id) + async with get_db() as db: + chat = await db.get(Chat, id) tag_id = tag.id if tag_id not in chat.meta.get("tags", []): @@ -887,14 +900,16 @@ class ChatTable: "tags": list(set(chat.meta.get("tags", []) + [tag_id])), } - db.commit() - db.refresh(chat) + await db.commit() + await db.refresh(chat) return ChatModel.model_validate(chat) except Exception: return None - def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int: - with get_db() as db: # Assuming `get_db()` returns a session object + async def count_chats_by_tag_name_and_user_id( + self, tag_name: str, user_id: str + ) -> int: + async with get_db() as db: # Assuming `get_db()` returns a session object query = db.query(Chat).filter_by(user_id=user_id, archived=False) # Normalize the tag_name for consistency @@ -922,19 +937,19 @@ class ChatTable: ) # Get the count of matching records - count = query.count() + count = await query.count() # Debugging output for inspection log.info(f"Count of chats for tag '{tag_name}': {count}") return count - def delete_tag_by_id_and_user_id_and_tag_name( + async def delete_tag_by_id_and_user_id_and_tag_name( self, id: str, user_id: str, tag_name: str ) -> bool: try: - with get_db() as db: - chat = db.get(Chat, id) + async with get_db() as db: + chat = await db.get(Chat, id) tags = chat.meta.get("tags", []) tag_id = tag_name.replace(" ", "_").lower() @@ -943,77 +958,79 @@ class ChatTable: **chat.meta, "tags": list(set(tags)), } - db.commit() + await db.commit() return True except Exception: return False - def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool: + async def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: - with get_db() as db: - chat = db.get(Chat, id) + async with get_db() as db: + chat = await db.get(Chat, id) chat.meta = { **chat.meta, "tags": [], } - db.commit() + await db.commit() return True except Exception: return False - def delete_chat_by_id(self, id: str) -> bool: + async def delete_chat_by_id(self, id: str) -> bool: try: - with get_db() as db: - db.query(Chat).filter_by(id=id).delete() - db.commit() + async with get_db() as db: + await db.query(Chat).filter_by(id=id).delete() + await db.commit() - return True and self.delete_shared_chat_by_chat_id(id) + return True and await self.delete_shared_chat_by_chat_id(id) except Exception: return False - def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: + async def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: - with get_db() as db: - db.query(Chat).filter_by(id=id, user_id=user_id).delete() - db.commit() + async with get_db() as db: + await db.query(Chat).filter_by(id=id, user_id=user_id).delete() + await db.commit() - return True and self.delete_shared_chat_by_chat_id(id) + return True and await self.delete_shared_chat_by_chat_id(id) except Exception: return False - def delete_chats_by_user_id(self, user_id: str) -> bool: + async def delete_chats_by_user_id(self, user_id: str) -> bool: try: - with get_db() as db: - self.delete_shared_chats_by_user_id(user_id) + async with get_db() as db: + await self.delete_shared_chats_by_user_id(user_id) - db.query(Chat).filter_by(user_id=user_id).delete() - db.commit() + await db.query(Chat).filter_by(user_id=user_id).delete() + await db.commit() return True except Exception: return False - def delete_chats_by_user_id_and_folder_id( + async def delete_chats_by_user_id_and_folder_id( self, user_id: str, folder_id: str ) -> bool: try: - with get_db() as db: - db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete() - db.commit() + async with get_db() as db: + await db.query(Chat).filter_by( + user_id=user_id, folder_id=folder_id + ).delete() + await db.commit() return True except Exception: return False - def delete_shared_chats_by_user_id(self, user_id: str) -> bool: + async def delete_shared_chats_by_user_id(self, user_id: str) -> bool: try: - with get_db() as db: - chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() + async with get_db() as db: + chats_by_user = await db.query(Chat).filter_by(user_id=user_id).all() shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] - db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() - db.commit() + await db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + await db.commit() return True except Exception: diff --git a/backend/open_webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py index 215e36aa24..daef09e883 100644 --- a/backend/open_webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -93,10 +93,10 @@ class FeedbackForm(BaseModel): class FeedbackTable: - def insert_new_feedback( + async def insert_new_feedback( self, user_id: str, form_data: FeedbackForm ) -> Optional[FeedbackModel]: - with get_db() as db: + async with get_db() as db: id = str(uuid.uuid4()) feedback = FeedbackModel( **{ @@ -110,9 +110,9 @@ class FeedbackTable: ) try: result = Feedback(**feedback.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) if result: return FeedbackModel.model_validate(result) else: @@ -121,62 +121,64 @@ class FeedbackTable: log.exception(f"Error creating a new feedback: {e}") return None - def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]: + async def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]: try: - with get_db() as db: - feedback = db.query(Feedback).filter_by(id=id).first() + async with get_db() as db: + feedback = await db.query(Feedback).filter_by(id=id).first() if not feedback: return None return FeedbackModel.model_validate(feedback) except Exception: return None - def get_feedback_by_id_and_user_id( + async def get_feedback_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[FeedbackModel]: try: - with get_db() as db: - feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() + async with get_db() as db: + feedback = ( + await db.query(Feedback).filter_by(id=id, user_id=user_id).first() + ) if not feedback: return None return FeedbackModel.model_validate(feedback) except Exception: return None - def get_all_feedbacks(self) -> list[FeedbackModel]: - with get_db() as db: + async def get_all_feedbacks(self) -> list[FeedbackModel]: + async with get_db() as db: return [ FeedbackModel.model_validate(feedback) - for feedback in db.query(Feedback) + for feedback in await db.query(Feedback) .order_by(Feedback.updated_at.desc()) .all() ] - def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]: - with get_db() as db: + async def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]: + async with get_db() as db: return [ FeedbackModel.model_validate(feedback) - for feedback in db.query(Feedback) + for feedback in await db.query(Feedback) .filter_by(type=type) .order_by(Feedback.updated_at.desc()) .all() ] - def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]: - with get_db() as db: + async def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]: + async with get_db() as db: return [ FeedbackModel.model_validate(feedback) - for feedback in db.query(Feedback) + for feedback in await db.query(Feedback) .filter_by(user_id=user_id) .order_by(Feedback.updated_at.desc()) .all() ] - def update_feedback_by_id( + async def update_feedback_by_id( self, id: str, form_data: FeedbackForm ) -> Optional[FeedbackModel]: - with get_db() as db: - feedback = db.query(Feedback).filter_by(id=id).first() + async with get_db() as db: + feedback = await db.query(Feedback).filter_by(id=id).first() if not feedback: return None @@ -189,14 +191,16 @@ class FeedbackTable: feedback.updated_at = int(time.time()) - db.commit() + await db.commit() return FeedbackModel.model_validate(feedback) - def update_feedback_by_id_and_user_id( + async def update_feedback_by_id_and_user_id( self, id: str, user_id: str, form_data: FeedbackForm ) -> Optional[FeedbackModel]: - with get_db() as db: - feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() + async with get_db() as db: + feedback = ( + await db.query(Feedback).filter_by(id=id, user_id=user_id).first() + ) if not feedback: return None @@ -209,45 +213,47 @@ class FeedbackTable: feedback.updated_at = int(time.time()) - db.commit() + await db.commit() return FeedbackModel.model_validate(feedback) - def delete_feedback_by_id(self, id: str) -> bool: - with get_db() as db: - feedback = db.query(Feedback).filter_by(id=id).first() + async def delete_feedback_by_id(self, id: str) -> bool: + async with get_db() as db: + feedback = await db.query(Feedback).filter_by(id=id).first() if not feedback: return False - db.delete(feedback) - db.commit() + await db.delete(feedback) + await db.commit() return True - def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool: - with get_db() as db: - feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() + async def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool: + async with get_db() as db: + feedback = ( + await db.query(Feedback).filter_by(id=id, user_id=user_id).first() + ) if not feedback: return False - db.delete(feedback) - db.commit() + await db.delete(feedback) + await db.commit() return True - def delete_feedbacks_by_user_id(self, user_id: str) -> bool: - with get_db() as db: - feedbacks = db.query(Feedback).filter_by(user_id=user_id).all() + async def delete_feedbacks_by_user_id(self, user_id: str) -> bool: + async with get_db() as db: + feedbacks = await db.query(Feedback).filter_by(user_id=user_id).all() if not feedbacks: return False for feedback in feedbacks: - db.delete(feedback) - db.commit() + await db.delete(feedback) + await db.commit() return True - def delete_all_feedbacks(self) -> bool: - with get_db() as db: - feedbacks = db.query(Feedback).all() + async def delete_all_feedbacks(self) -> bool: + async with get_db() as db: + feedbacks = await db.query(Feedback).all() if not feedbacks: return False for feedback in feedbacks: - db.delete(feedback) - db.commit() + await db.delete(feedback) + await db.commit() return True diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index 6f1511cd13..c3651759ce 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -98,8 +98,10 @@ class FileForm(BaseModel): class FilesTable: - def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: - with get_db() as db: + async def insert_new_file( + self, user_id: str, form_data: FileForm + ) -> Optional[FileModel]: + async with get_db() as db: file = FileModel( **{ **form_data.model_dump(), @@ -111,9 +113,9 @@ class FilesTable: try: result = File(**file.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) if result: return FileModel.model_validate(result) else: @@ -122,18 +124,18 @@ class FilesTable: log.exception(f"Error inserting a new file: {e}") return None - def get_file_by_id(self, id: str) -> Optional[FileModel]: - with get_db() as db: + async def get_file_by_id(self, id: str) -> Optional[FileModel]: + async with get_db() as db: try: - file = db.get(File, id) + file = await db.get(File, id) return FileModel.model_validate(file) except Exception: return None - def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: - with get_db() as db: + async def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: + async with get_db() as db: try: - file = db.get(File, id) + file = await db.get(File, id) return FileMetadataResponse( id=file.id, meta=file.meta, @@ -143,22 +145,26 @@ class FilesTable: except Exception: return None - def get_files(self) -> list[FileModel]: - with get_db() as db: - return [FileModel.model_validate(file) for file in db.query(File).all()] + async def get_files(self) -> list[FileModel]: + async with get_db() as db: + return [ + FileModel.model_validate(file) for file in await db.query(File).all() + ] - def get_files_by_ids(self, ids: list[str]) -> list[FileModel]: - with get_db() as db: + async def get_files_by_ids(self, ids: list[str]) -> list[FileModel]: + async with get_db() as db: return [ FileModel.model_validate(file) - for file in db.query(File) + for file in await db.query(File) .filter(File.id.in_(ids)) .order_by(File.updated_at.desc()) .all() ] - def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]: - with get_db() as db: + async def get_file_metadatas_by_ids( + self, ids: list[str] + ) -> list[FileMetadataResponse]: + async with get_db() as db: return [ FileMetadataResponse( id=file.id, @@ -166,66 +172,68 @@ class FilesTable: created_at=file.created_at, updated_at=file.updated_at, ) - for file in db.query(File) + for file in await db.query(File) .filter(File.id.in_(ids)) .order_by(File.updated_at.desc()) .all() ] - def get_files_by_user_id(self, user_id: str) -> list[FileModel]: - with get_db() as db: + async def get_files_by_user_id(self, user_id: str) -> list[FileModel]: + async with get_db() as db: return [ FileModel.model_validate(file) - for file in db.query(File).filter_by(user_id=user_id).all() + for file in await db.query(File).filter_by(user_id=user_id).all() ] - def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]: - with get_db() as db: + async def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]: + async with get_db() as db: try: - file = db.query(File).filter_by(id=id).first() + file = await db.query(File).filter_by(id=id).first() file.hash = hash - db.commit() + await db.commit() return FileModel.model_validate(file) except Exception: return None - def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]: - with get_db() as db: + async def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]: + async with get_db() as db: try: - file = db.query(File).filter_by(id=id).first() + file = await db.query(File).filter_by(id=id).first() file.data = {**(file.data if file.data else {}), **data} - db.commit() + await db.commit() return FileModel.model_validate(file) except Exception as e: return None - def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]: - with get_db() as db: + async def update_file_metadata_by_id( + self, id: str, meta: dict + ) -> Optional[FileModel]: + async with get_db() as db: try: - file = db.query(File).filter_by(id=id).first() + file = await db.query(File).filter_by(id=id).first() file.meta = {**(file.meta if file.meta else {}), **meta} - db.commit() + await db.commit() return FileModel.model_validate(file) except Exception: return None - def delete_file_by_id(self, id: str) -> bool: - with get_db() as db: + async def delete_file_by_id(self, id: str) -> bool: + async with get_db() as db: try: - db.query(File).filter_by(id=id).delete() - db.commit() + await db.query(File).filter_by(id=id).delete() + await db.commit() return True except Exception: return False - def delete_all_files(self) -> bool: - with get_db() as db: + async def delete_all_files(self) -> bool: + async with get_db() as db: try: - db.query(File).delete() - db.commit() + await db.query(File).delete() + await db.commit() return True except Exception: diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py index 15deecbf42..918342fd59 100644 --- a/backend/open_webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -62,10 +62,10 @@ class FolderForm(BaseModel): class FolderTable: - def insert_new_folder( + async def insert_new_folder( self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None ) -> Optional[FolderModel]: - with get_db() as db: + async with get_db() as db: id = str(uuid.uuid4()) folder = FolderModel( **{ @@ -79,9 +79,9 @@ class FolderTable: ) try: result = Folder(**folder.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) if result: return FolderModel.model_validate(result) else: @@ -90,12 +90,14 @@ class FolderTable: log.exception(f"Error inserting a new folder: {e}") return None - def get_folder_by_id_and_user_id( + async def get_folder_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[FolderModel]: try: - with get_db() as db: - folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + async with get_db() as db: + folder = ( + await db.query(Folder).filter_by(id=id, user_id=user_id).first() + ) if not folder: return None @@ -104,45 +106,47 @@ class FolderTable: except Exception: return None - def get_children_folders_by_id_and_user_id( + async def get_children_folders_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[list[FolderModel]]: try: - with get_db() as db: + async with get_db() as db: folders = [] - def get_children(folder): - children = self.get_folders_by_parent_id_and_user_id( + async def get_children(folder): + children = await self.get_folders_by_parent_id_and_user_id( folder.id, user_id ) for child in children: - get_children(child) + await get_children(child) folders.append(child) - folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + folder = ( + await db.query(Folder).filter_by(id=id, user_id=user_id).first() + ) if not folder: return None - get_children(folder) + await get_children(folder) return folders except Exception: return None - def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]: - with get_db() as db: + async def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]: + async with get_db() as db: return [ FolderModel.model_validate(folder) - for folder in db.query(Folder).filter_by(user_id=user_id).all() + for folder in await db.query(Folder).filter_by(user_id=user_id).all() ] - def get_folder_by_parent_id_and_user_id_and_name( + async def get_folder_by_parent_id_and_user_id_and_name( self, parent_id: Optional[str], user_id: str, name: str ) -> Optional[FolderModel]: try: - with get_db() as db: + async with get_db() as db: # Check if folder exists folder = ( - db.query(Folder) + await db.query(Folder) .filter_by(parent_id=parent_id, user_id=user_id) .filter(Folder.name.ilike(name)) .first() @@ -156,26 +160,28 @@ class FolderTable: log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}") return None - def get_folders_by_parent_id_and_user_id( + async def get_folders_by_parent_id_and_user_id( self, parent_id: Optional[str], user_id: str ) -> list[FolderModel]: - with get_db() as db: + async with get_db() as db: return [ FolderModel.model_validate(folder) - for folder in db.query(Folder) + for folder in await db.query(Folder) .filter_by(parent_id=parent_id, user_id=user_id) .all() ] - def update_folder_parent_id_by_id_and_user_id( + async def update_folder_parent_id_by_id_and_user_id( self, id: str, user_id: str, parent_id: str, ) -> Optional[FolderModel]: try: - with get_db() as db: - folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + async with get_db() as db: + folder = ( + await db.query(Folder).filter_by(id=id, user_id=user_id).first() + ) if not folder: return None @@ -183,19 +189,21 @@ class FolderTable: folder.parent_id = parent_id folder.updated_at = int(time.time()) - db.commit() + await db.commit() return FolderModel.model_validate(folder) except Exception as e: log.error(f"update_folder: {e}") return - def update_folder_by_id_and_user_id( + async def update_folder_by_id_and_user_id( self, id: str, user_id: str, form_data: FolderForm ) -> Optional[FolderModel]: try: - with get_db() as db: - folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + async with get_db() as db: + folder = ( + await db.query(Folder).filter_by(id=id, user_id=user_id).first() + ) if not folder: return None @@ -203,7 +211,7 @@ class FolderTable: form_data = form_data.model_dump(exclude_unset=True) existing_folder = ( - db.query(Folder) + await db.query(Folder) .filter_by( name=form_data.get("name"), parent_id=folder.parent_id, @@ -224,19 +232,21 @@ class FolderTable: folder.updated_at = int(time.time()) - db.commit() + await db.commit() return FolderModel.model_validate(folder) except Exception as e: log.error(f"update_folder: {e}") return - def update_folder_is_expanded_by_id_and_user_id( + async def update_folder_is_expanded_by_id_and_user_id( self, id: str, user_id: str, is_expanded: bool ) -> Optional[FolderModel]: try: - with get_db() as db: - folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + async with get_db() as db: + folder = ( + await db.query(Folder).filter_by(id=id, user_id=user_id).first() + ) if not folder: return None @@ -244,40 +254,44 @@ class FolderTable: folder.is_expanded = is_expanded folder.updated_at = int(time.time()) - db.commit() + await db.commit() return FolderModel.model_validate(folder) except Exception as e: log.error(f"update_folder: {e}") return - def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]: + async def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]: try: folder_ids = [] - with get_db() as db: - folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() + async with get_db() as db: + folder = ( + await db.query(Folder).filter_by(id=id, user_id=user_id).first() + ) if not folder: return folder_ids folder_ids.append(folder.id) # Delete all children folders - def delete_children(folder): - folder_children = self.get_folders_by_parent_id_and_user_id( + async def delete_children(folder): + folder_children = await self.get_folders_by_parent_id_and_user_id( folder.id, user_id ) for folder_child in folder_children: - delete_children(folder_child) + await delete_children(folder_child) folder_ids.append(folder_child.id) - folder = db.query(Folder).filter_by(id=folder_child.id).first() - db.delete(folder) - db.commit() + folder = ( + await db.query(Folder).filter_by(id=folder_child.id).first() + ) + await db.delete(folder) + await db.commit() - delete_children(folder) - db.delete(folder) - db.commit() + await delete_children(folder) + await db.delete(folder) + await db.commit() return folder_ids except Exception as e: log.error(f"delete_folder: {e}") @@ -288,7 +302,7 @@ class FolderTable: name = re.sub(r"[\s_]+", " ", name) return name.strip().lower() - def search_folders_by_names( + async def search_folders_by_names( self, user_id: str, queries: list[str] ) -> list[FolderModel]: """ @@ -299,14 +313,14 @@ class FolderTable: return [] results = {} - with get_db() as db: - folders = db.query(Folder).filter_by(user_id=user_id).all() + async with get_db() as db: + folders = await db.query(Folder).filter_by(user_id=user_id).all() for folder in folders: if self.normalize_folder_name(folder.name) in normalized_queries: results[folder.id] = FolderModel.model_validate(folder) # get children folders - children = self.get_children_folders_by_id_and_user_id( + children = await self.get_children_folders_by_id_and_user_id( folder.id, user_id ) for child in children: @@ -319,7 +333,7 @@ class FolderTable: results = list(results.values()) return results - def search_folders_by_name_contains( + async def search_folders_by_name_contains( self, user_id: str, query: str ) -> list[FolderModel]: """ @@ -327,8 +341,8 @@ class FolderTable: """ normalized_query = self.normalize_folder_name(query) results = [] - with get_db() as db: - folders = db.query(Folder).filter_by(user_id=user_id).all() + async with get_db() as db: + folders = await db.query(Folder).filter_by(user_id=user_id).all() for folder in folders: norm_name = self.normalize_folder_name(folder.name) if normalized_query in norm_name: diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index e98771fa02..ea397f0b4d 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -81,7 +81,7 @@ class FunctionValves(BaseModel): class FunctionsTable: - def insert_new_function( + async def insert_new_function( self, user_id: str, type: str, form_data: FunctionForm ) -> Optional[FunctionModel]: function = FunctionModel( @@ -95,11 +95,11 @@ class FunctionsTable: ) try: - with get_db() as db: + async with get_db() as db: result = Function(**function.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) if result: return FunctionModel.model_validate(result) else: @@ -108,14 +108,14 @@ class FunctionsTable: log.exception(f"Error creating a new function: {e}") return None - def sync_functions( + async def sync_functions( self, user_id: str, functions: list[FunctionModel] ) -> list[FunctionModel]: # Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present. try: - with get_db() as db: + async with get_db() as db: # Get existing functions - existing_functions = db.query(Function).all() + existing_functions = await db.query(Function).all() existing_ids = {func.id for func in existing_functions} # Prepare a set of new function IDs @@ -124,7 +124,7 @@ class FunctionsTable: # Update or insert functions for func in functions: if func.id in existing_ids: - db.query(Function).filter_by(id=func.id).update( + await db.query(Function).filter_by(id=func.id).update( { **func.model_dump(), "user_id": user_id, @@ -139,107 +139,109 @@ class FunctionsTable: "updated_at": int(time.time()), } ) - db.add(new_func) + await db.add(new_func) # Remove functions that are no longer present for func in existing_functions: if func.id not in new_function_ids: - db.delete(func) + await db.delete(func) - db.commit() + await db.commit() return [ FunctionModel.model_validate(func) - for func in db.query(Function).all() + for func in await db.query(Function).all() ] except Exception as e: log.exception(f"Error syncing functions for user {user_id}: {e}") return [] - def get_function_by_id(self, id: str) -> Optional[FunctionModel]: + async def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: - with get_db() as db: - function = db.get(Function, id) + async with get_db() as db: + function = await db.get(Function, id) return FunctionModel.model_validate(function) except Exception: return None - def get_functions(self, active_only=False) -> list[FunctionModel]: - with get_db() as db: + async def get_functions(self, active_only=False) -> list[FunctionModel]: + async with get_db() as db: if active_only: return [ FunctionModel.model_validate(function) - for function in db.query(Function).filter_by(is_active=True).all() + for function in await db.query(Function) + .filter_by(is_active=True) + .all() ] else: return [ FunctionModel.model_validate(function) - for function in db.query(Function).all() + for function in await db.query(Function).all() ] - def get_functions_by_type( + async def get_functions_by_type( self, type: str, active_only=False ) -> list[FunctionModel]: - with get_db() as db: + async with get_db() as db: if active_only: return [ FunctionModel.model_validate(function) - for function in db.query(Function) + for function in await db.query(Function) .filter_by(type=type, is_active=True) .all() ] else: return [ FunctionModel.model_validate(function) - for function in db.query(Function).filter_by(type=type).all() + for function in await db.query(Function).filter_by(type=type).all() ] - def get_global_filter_functions(self) -> list[FunctionModel]: - with get_db() as db: + async def get_global_filter_functions(self) -> list[FunctionModel]: + async with get_db() as db: return [ FunctionModel.model_validate(function) - for function in db.query(Function) + for function in await db.query(Function) .filter_by(type="filter", is_active=True, is_global=True) .all() ] - def get_global_action_functions(self) -> list[FunctionModel]: - with get_db() as db: + async def get_global_action_functions(self) -> list[FunctionModel]: + async with get_db() as db: return [ FunctionModel.model_validate(function) - for function in db.query(Function) + for function in await db.query(Function) .filter_by(type="action", is_active=True, is_global=True) .all() ] - def get_function_valves_by_id(self, id: str) -> Optional[dict]: - with get_db() as db: + async def get_function_valves_by_id(self, id: str) -> Optional[dict]: + async with get_db() as db: try: - function = db.get(Function, id) + function = await db.get(Function, id) return function.valves if function.valves else {} except Exception as e: log.exception(f"Error getting function valves by id {id}: {e}") return None - def update_function_valves_by_id( + async def update_function_valves_by_id( self, id: str, valves: dict ) -> Optional[FunctionValves]: - with get_db() as db: + async with get_db() as db: try: - function = db.get(Function, id) + function = await db.get(Function, id) function.valves = valves function.updated_at = int(time.time()) - db.commit() - db.refresh(function) - return self.get_function_by_id(id) + await db.commit() + await db.refresh(function) + return await self.get_function_by_id(id) except Exception: return None - def get_user_valves_by_id_and_user_id( + async def get_user_valves_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings @@ -255,11 +257,11 @@ class FunctionsTable: ) return None - def update_user_valves_by_id_and_user_id( + async def update_user_valves_by_id_and_user_id( self, id: str, user_id: str, valves: dict ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings @@ -271,7 +273,7 @@ class FunctionsTable: user_settings["functions"]["valves"][id] = valves # Update the user settings in the database - Users.update_user_by_id(user_id, {"settings": user_settings}) + await Users.update_user_by_id(user_id, {"settings": user_settings}) return user_settings["functions"]["valves"][id] except Exception as e: @@ -280,39 +282,41 @@ class FunctionsTable: ) return None - def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: - with get_db() as db: + async def update_function_by_id( + self, id: str, updated: dict + ) -> Optional[FunctionModel]: + async with get_db() as db: try: - db.query(Function).filter_by(id=id).update( + await db.query(Function).filter_by(id=id).update( { **updated, "updated_at": int(time.time()), } ) - db.commit() - return self.get_function_by_id(id) + await db.commit() + return await self.get_function_by_id(id) except Exception: return None - def deactivate_all_functions(self) -> Optional[bool]: - with get_db() as db: + async def deactivate_all_functions(self) -> Optional[bool]: + async with get_db() as db: try: - db.query(Function).update( + await db.query(Function).update( { "is_active": False, "updated_at": int(time.time()), } ) - db.commit() + await db.commit() return True except Exception: return None - def delete_function_by_id(self, id: str) -> bool: - with get_db() as db: + async def delete_function_by_id(self, id: str) -> bool: + async with get_db() as db: try: - db.query(Function).filter_by(id=id).delete() - db.commit() + await db.query(Function).filter_by(id=id).delete() + await db.commit() return True except Exception: diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index 6615f95142..4097bab531 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -95,7 +95,7 @@ class GroupTable: def insert_new_group( self, user_id: str, form_data: GroupForm ) -> Optional[GroupModel]: - with get_db() as db: + async with get_db() as db: group = GroupModel( **{ **form_data.model_dump(exclude_none=True), @@ -120,14 +120,14 @@ class GroupTable: return None def get_groups(self) -> list[GroupModel]: - with get_db() as db: + async with get_db() as db: return [ GroupModel.model_validate(group) for group in db.query(Group).order_by(Group.updated_at.desc()).all() ] def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: - with get_db() as db: + async with get_db() as db: return [ GroupModel.model_validate(group) for group in db.query(Group) @@ -143,7 +143,7 @@ class GroupTable: def get_group_by_id(self, id: str) -> Optional[GroupModel]: try: - with get_db() as db: + async with get_db() as db: group = db.query(Group).filter_by(id=id).first() return GroupModel.model_validate(group) if group else None except Exception: @@ -160,7 +160,7 @@ class GroupTable: self, id: str, form_data: GroupUpdateForm, overwrite: bool = False ) -> Optional[GroupModel]: try: - with get_db() as db: + async with get_db() as db: db.query(Group).filter_by(id=id).update( { **form_data.model_dump(exclude_none=True), @@ -175,7 +175,7 @@ class GroupTable: def delete_group_by_id(self, id: str) -> bool: try: - with get_db() as db: + async with get_db() as db: db.query(Group).filter_by(id=id).delete() db.commit() return True @@ -183,7 +183,7 @@ class GroupTable: return False def delete_all_groups(self) -> bool: - with get_db() as db: + async with get_db() as db: try: db.query(Group).delete() db.commit() @@ -193,7 +193,7 @@ class GroupTable: return False def remove_user_from_all_groups(self, user_id: str) -> bool: - with get_db() as db: + async with get_db() as db: try: groups = self.get_groups_by_member_id(user_id) @@ -221,7 +221,7 @@ class GroupTable: new_groups = [] - with get_db() as db: + async with get_db() as db: for group_name in group_names: if group_name not in existing_group_names: new_group = GroupModel( @@ -244,7 +244,7 @@ class GroupTable: return new_groups def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: - with get_db() as db: + async with get_db() as db: try: groups = db.query(Group).filter(Group.name.in_(group_names)).all() group_ids = [group.id for group in groups] @@ -283,7 +283,7 @@ class GroupTable: self, id: str, user_ids: Optional[list[str]] = None ) -> Optional[GroupModel]: try: - with get_db() as db: + async with get_db() as db: group = db.query(Group).filter_by(id=id).first() if not group: return None @@ -307,7 +307,7 @@ class GroupTable: self, id: str, user_ids: Optional[list[str]] = None ) -> Optional[GroupModel]: try: - with get_db() as db: + async with get_db() as db: group = db.query(Group).filter_by(id=id).first() if not group: return None diff --git a/backend/open_webui/models/knowledge.py b/backend/open_webui/models/knowledge.py index bed3d5542e..6dff07ae4e 100644 --- a/backend/open_webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -103,7 +103,7 @@ class KnowledgeTable: def insert_new_knowledge( self, user_id: str, form_data: KnowledgeForm ) -> Optional[KnowledgeModel]: - with get_db() as db: + async with get_db() as db: knowledge = KnowledgeModel( **{ **form_data.model_dump(), @@ -126,13 +126,13 @@ class KnowledgeTable: except Exception: return None - def get_knowledge_bases(self) -> list[KnowledgeUserModel]: - with get_db() as db: + async def get_knowledge_bases(self) -> list[KnowledgeUserModel]: + async with get_db() as db: knowledge_bases = [] for knowledge in ( db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() ): - user = Users.get_user_by_id(knowledge.user_id) + user = await Users.get_user_by_id(knowledge.user_id) knowledge_bases.append( KnowledgeUserModel.model_validate( { @@ -156,7 +156,7 @@ class KnowledgeTable: def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: try: - with get_db() as db: + async with get_db() as db: knowledge = db.query(Knowledge).filter_by(id=id).first() return KnowledgeModel.model_validate(knowledge) if knowledge else None except Exception: @@ -166,7 +166,7 @@ class KnowledgeTable: self, id: str, form_data: KnowledgeForm, overwrite: bool = False ) -> Optional[KnowledgeModel]: try: - with get_db() as db: + async with get_db() as db: knowledge = self.get_knowledge_by_id(id=id) db.query(Knowledge).filter_by(id=id).update( { @@ -184,7 +184,7 @@ class KnowledgeTable: self, id: str, data: dict ) -> Optional[KnowledgeModel]: try: - with get_db() as db: + async with get_db() as db: knowledge = self.get_knowledge_by_id(id=id) db.query(Knowledge).filter_by(id=id).update( { @@ -200,7 +200,7 @@ class KnowledgeTable: def delete_knowledge_by_id(self, id: str) -> bool: try: - with get_db() as db: + async with get_db() as db: db.query(Knowledge).filter_by(id=id).delete() db.commit() return True @@ -208,7 +208,7 @@ class KnowledgeTable: return False def delete_all_knowledge(self) -> bool: - with get_db() as db: + async with get_db() as db: try: db.query(Knowledge).delete() db.commit() diff --git a/backend/open_webui/models/memories.py b/backend/open_webui/models/memories.py index 253371c680..5602f56419 100644 --- a/backend/open_webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -42,7 +42,7 @@ class MemoriesTable: user_id: str, content: str, ) -> Optional[MemoryModel]: - with get_db() as db: + async with get_db() as db: id = str(uuid.uuid4()) memory = MemoryModel( @@ -69,7 +69,7 @@ class MemoriesTable: user_id: str, content: str, ) -> Optional[MemoryModel]: - with get_db() as db: + async with get_db() as db: try: memory = db.get(Memory, id) if not memory or memory.user_id != user_id: @@ -84,7 +84,7 @@ class MemoriesTable: return None def get_memories(self) -> list[MemoryModel]: - with get_db() as db: + async with get_db() as db: try: memories = db.query(Memory).all() return [MemoryModel.model_validate(memory) for memory in memories] @@ -92,7 +92,7 @@ class MemoriesTable: return None def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]: - with get_db() as db: + async with get_db() as db: try: memories = db.query(Memory).filter_by(user_id=user_id).all() return [MemoryModel.model_validate(memory) for memory in memories] @@ -100,7 +100,7 @@ class MemoriesTable: return None def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: - with get_db() as db: + async with get_db() as db: try: memory = db.get(Memory, id) return MemoryModel.model_validate(memory) @@ -108,7 +108,7 @@ class MemoriesTable: return None def delete_memory_by_id(self, id: str) -> bool: - with get_db() as db: + async with get_db() as db: try: db.query(Memory).filter_by(id=id).delete() db.commit() @@ -119,7 +119,7 @@ class MemoriesTable: return False def delete_memories_by_user_id(self, user_id: str) -> bool: - with get_db() as db: + async with get_db() as db: try: db.query(Memory).filter_by(user_id=user_id).delete() db.commit() @@ -129,7 +129,7 @@ class MemoriesTable: return False def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: - with get_db() as db: + async with get_db() as db: try: memory = db.get(Memory, id) if not memory or memory.user_id != user_id: diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index a27ae52519..d9c0ff3a04 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -98,7 +98,7 @@ class MessageTable: def insert_new_message( self, form_data: MessageForm, channel_id: str, user_id: str ) -> Optional[MessageModel]: - with get_db() as db: + async with get_db() as db: id = str(uuid.uuid4()) ts = int(time.time_ns()) @@ -123,7 +123,7 @@ class MessageTable: return MessageModel.model_validate(result) if result else None def get_message_by_id(self, id: str) -> Optional[MessageResponse]: - with get_db() as db: + async with get_db() as db: message = db.get(Message, id) if not message: return None @@ -141,7 +141,7 @@ class MessageTable: ) def get_replies_by_message_id(self, id: str) -> list[MessageModel]: - with get_db() as db: + async with get_db() as db: all_messages = ( db.query(Message) .filter_by(parent_id=id) @@ -151,7 +151,7 @@ class MessageTable: return [MessageModel.model_validate(message) for message in all_messages] def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: - with get_db() as db: + async with get_db() as db: return [ message.user_id for message in db.query(Message).filter_by(parent_id=id).all() @@ -160,7 +160,7 @@ class MessageTable: def get_messages_by_channel_id( self, channel_id: str, skip: int = 0, limit: int = 50 ) -> list[MessageModel]: - with get_db() as db: + async with get_db() as db: all_messages = ( db.query(Message) .filter_by(channel_id=channel_id, parent_id=None) @@ -174,7 +174,7 @@ class MessageTable: def get_messages_by_parent_id( self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 ) -> list[MessageModel]: - with get_db() as db: + async with get_db() as db: message = db.get(Message, parent_id) if not message: @@ -198,7 +198,7 @@ class MessageTable: def update_message_by_id( self, id: str, form_data: MessageForm ) -> Optional[MessageModel]: - with get_db() as db: + async with get_db() as db: message = db.get(Message, id) message.content = form_data.content message.data = form_data.data @@ -211,7 +211,7 @@ class MessageTable: def add_reaction_to_message( self, id: str, user_id: str, name: str ) -> Optional[MessageReactionModel]: - with get_db() as db: + async with get_db() as db: reaction_id = str(uuid.uuid4()) reaction = MessageReactionModel( id=reaction_id, @@ -227,7 +227,7 @@ class MessageTable: return MessageReactionModel.model_validate(result) if result else None def get_reactions_by_message_id(self, id: str) -> list[Reactions]: - with get_db() as db: + async with get_db() as db: all_reactions = db.query(MessageReaction).filter_by(message_id=id).all() reactions = {} @@ -246,7 +246,7 @@ class MessageTable: def remove_reaction_by_id_and_user_id_and_name( self, id: str, user_id: str, name: str ) -> bool: - with get_db() as db: + async with get_db() as db: db.query(MessageReaction).filter_by( message_id=id, user_id=user_id, name=name ).delete() @@ -254,19 +254,19 @@ class MessageTable: return True def delete_reactions_by_id(self, id: str) -> bool: - with get_db() as db: + async with get_db() as db: db.query(MessageReaction).filter_by(message_id=id).delete() db.commit() return True def delete_replies_by_id(self, id: str) -> bool: - with get_db() as db: + async with get_db() as db: db.query(Message).filter_by(parent_id=id).delete() db.commit() return True def delete_message_by_id(self, id: str) -> bool: - with get_db() as db: + async with get_db() as db: db.query(Message).filter_by(id=id).delete() # Delete all reactions to this message diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 1a29b86eae..202dd9ac5f 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -155,7 +155,7 @@ class ModelsTable: } ) try: - with get_db() as db: + async with get_db() as db: result = Model(**model.model_dump()) db.add(result) db.commit() @@ -170,14 +170,14 @@ class ModelsTable: return None def get_all_models(self) -> list[ModelModel]: - with get_db() as db: + async with get_db() as db: return [ModelModel.model_validate(model) for model in db.query(Model).all()] - def get_models(self) -> list[ModelUserResponse]: - with get_db() as db: + async def get_models(self) -> list[ModelUserResponse]: + async with get_db() as db: models = [] for model in db.query(Model).filter(Model.base_model_id != None).all(): - user = Users.get_user_by_id(model.user_id) + user = await Users.get_user_by_id(model.user_id) models.append( ModelUserResponse.model_validate( { @@ -189,7 +189,7 @@ class ModelsTable: return models def get_base_models(self) -> list[ModelModel]: - with get_db() as db: + async with get_db() as db: return [ ModelModel.model_validate(model) for model in db.query(Model).filter(Model.base_model_id == None).all() @@ -208,14 +208,14 @@ class ModelsTable: def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: - with get_db() as db: + async with get_db() as db: model = db.get(Model, id) return ModelModel.model_validate(model) except Exception: return None def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: - with get_db() as db: + async with get_db() as db: try: is_active = db.query(Model).filter_by(id=id).first().is_active @@ -233,7 +233,7 @@ class ModelsTable: def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: - with get_db() as db: + async with get_db() as db: # update only the fields that are present in the model result = ( db.query(Model) @@ -251,7 +251,7 @@ class ModelsTable: def delete_model_by_id(self, id: str) -> bool: try: - with get_db() as db: + async with get_db() as db: db.query(Model).filter_by(id=id).delete() db.commit() @@ -261,7 +261,7 @@ class ModelsTable: def delete_all_models(self) -> bool: try: - with get_db() as db: + async with get_db() as db: db.query(Model).delete() db.commit() @@ -271,7 +271,7 @@ class ModelsTable: def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]: try: - with get_db() as db: + async with get_db() as db: # Get existing models existing_models = db.query(Model).all() existing_ids = {model.id for model in existing_models} diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index ce3b9f2e20..ec24a7de8b 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -79,7 +79,7 @@ class NoteTable: form_data: NoteForm, user_id: str, ) -> Optional[NoteModel]: - with get_db() as db: + async with get_db() as db: note = NoteModel( **{ "id": str(uuid.uuid4()), @@ -97,7 +97,7 @@ class NoteTable: return note def get_notes(self) -> list[NoteModel]: - with get_db() as db: + async with get_db() as db: notes = db.query(Note).order_by(Note.updated_at.desc()).all() return [NoteModel.model_validate(note) for note in notes] @@ -113,14 +113,14 @@ class NoteTable: ] def get_note_by_id(self, id: str) -> Optional[NoteModel]: - with get_db() as db: + async with get_db() as db: note = db.query(Note).filter(Note.id == id).first() return NoteModel.model_validate(note) if note else None def update_note_by_id( self, id: str, form_data: NoteUpdateForm ) -> Optional[NoteModel]: - with get_db() as db: + async with get_db() as db: note = db.query(Note).filter(Note.id == id).first() if not note: return None @@ -143,7 +143,7 @@ class NoteTable: return NoteModel.model_validate(note) if note else None def delete_note_by_id(self, id: str): - with get_db() as db: + async with get_db() as db: db.query(Note).filter(Note.id == id).delete() db.commit() return True diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index 8ef4cd2bec..faeb0b85a1 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -81,7 +81,7 @@ class PromptsTable: ) try: - with get_db() as db: + async with get_db() as db: result = Prompt(**prompt.model_dump()) db.add(result) db.commit() @@ -95,18 +95,18 @@ class PromptsTable: def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: try: - with get_db() as db: + async with get_db() as db: prompt = db.query(Prompt).filter_by(command=command).first() return PromptModel.model_validate(prompt) except Exception: return None - def get_prompts(self) -> list[PromptUserResponse]: - with get_db() as db: + async def get_prompts(self) -> list[PromptUserResponse]: + async with get_db() as db: prompts = [] for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all(): - user = Users.get_user_by_id(prompt.user_id) + user = await Users.get_user_by_id(prompt.user_id) prompts.append( PromptUserResponse.model_validate( { @@ -134,7 +134,7 @@ class PromptsTable: self, command: str, form_data: PromptForm ) -> Optional[PromptModel]: try: - with get_db() as db: + async with get_db() as db: prompt = db.query(Prompt).filter_by(command=command).first() prompt.title = form_data.title prompt.content = form_data.content @@ -147,7 +147,7 @@ class PromptsTable: def delete_prompt_by_command(self, command: str) -> bool: try: - with get_db() as db: + async with get_db() as db: db.query(Prompt).filter_by(command=command).delete() db.commit() diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index 279dc624d5..d71c889233 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -48,7 +48,7 @@ class TagChatIdForm(BaseModel): class TagTable: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: - with get_db() as db: + async with get_db() as db: id = name.replace(" ", "_").lower() tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: @@ -69,14 +69,14 @@ class TagTable: ) -> Optional[TagModel]: try: id = name.replace(" ", "_").lower() - with get_db() as db: + async with get_db() as db: tag = db.query(Tag).filter_by(id=id, user_id=user_id).first() return TagModel.model_validate(tag) except Exception: return None def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: - with get_db() as db: + async with get_db() as db: return [ TagModel.model_validate(tag) for tag in (db.query(Tag).filter_by(user_id=user_id).all()) @@ -85,7 +85,7 @@ class TagTable: def get_tags_by_ids_and_user_id( self, ids: list[str], user_id: str ) -> list[TagModel]: - with get_db() as db: + async with get_db() as db: return [ TagModel.model_validate(tag) for tag in ( @@ -95,7 +95,7 @@ class TagTable: def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool: try: - with get_db() as db: + async with get_db() as db: id = name.replace(" ", "_").lower() res = db.query(Tag).filter_by(id=id, user_id=user_id).delete() log.debug(f"res: {res}") diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index 68a83ea42c..f9a5d56829 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -110,7 +110,7 @@ class ToolsTable: def insert_new_tool( self, user_id: str, form_data: ToolForm, specs: list[dict] ) -> Optional[ToolModel]: - with get_db() as db: + async with get_db() as db: tool = ToolModel( **{ **form_data.model_dump(), @@ -136,17 +136,17 @@ class ToolsTable: def get_tool_by_id(self, id: str) -> Optional[ToolModel]: try: - with get_db() as db: + async with get_db() as db: tool = db.get(Tool, id) return ToolModel.model_validate(tool) except Exception: return None - def get_tools(self) -> list[ToolUserModel]: - with get_db() as db: + async def get_tools(self) -> list[ToolUserModel]: + async with get_db() as db: tools = [] for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all(): - user = Users.get_user_by_id(tool.user_id) + user = await Users.get_user_by_id(tool.user_id) tools.append( ToolUserModel.model_validate( { @@ -171,7 +171,7 @@ class ToolsTable: def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: - with get_db() as db: + async with get_db() as db: tool = db.get(Tool, id) return tool.valves if tool.valves else {} except Exception as e: @@ -180,7 +180,7 @@ class ToolsTable: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: try: - with get_db() as db: + async with get_db() as db: db.query(Tool).filter_by(id=id).update( {"valves": valves, "updated_at": int(time.time())} ) @@ -189,11 +189,11 @@ class ToolsTable: except Exception: return None - def get_user_valves_by_id_and_user_id( + async def get_user_valves_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings @@ -209,11 +209,11 @@ class ToolsTable: ) return None - def update_user_valves_by_id_and_user_id( + async def update_user_valves_by_id_and_user_id( self, id: str, user_id: str, valves: dict ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings @@ -236,7 +236,7 @@ class ToolsTable: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: try: - with get_db() as db: + async with get_db() as db: db.query(Tool).filter_by(id=id).update( {**updated, "updated_at": int(time.time())} ) @@ -250,7 +250,7 @@ class ToolsTable: def delete_tool_by_id(self, id: str) -> bool: try: - with get_db() as db: + async with get_db() as db: db.query(Tool).filter_by(id=id).delete() db.commit() diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 60b6ad0c10..12b67cf22f 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -115,7 +115,7 @@ class UserUpdateForm(BaseModel): class UsersTable: - def insert_new_user( + async def insert_new_user( self, id: str, name: str, @@ -124,7 +124,7 @@ class UsersTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - with get_db() as db: + async with get_db() as db: user = UserModel( **{ "id": id, @@ -139,53 +139,53 @@ class UsersTable: } ) result = User(**user.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) if result: return user else: return None - def get_user_by_id(self, id: str) -> Optional[UserModel]: + async def get_user_by_id(self, id: str) -> Optional[UserModel]: try: - with get_db() as db: - user = db.query(User).filter_by(id=id).first() + async with get_db() as db: + user = await db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) except Exception: return None - def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + async def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: try: - with get_db() as db: - user = db.query(User).filter_by(api_key=api_key).first() + async with get_db() as db: + user = await db.query(User).filter_by(api_key=api_key).first() return UserModel.model_validate(user) except Exception: return None - def get_user_by_email(self, email: str) -> Optional[UserModel]: + async def get_user_by_email(self, email: str) -> Optional[UserModel]: try: - with get_db() as db: - user = db.query(User).filter_by(email=email).first() + async with get_db() as db: + user = await db.query(User).filter_by(email=email).first() return UserModel.model_validate(user) except Exception: return None - def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: + async def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: try: - with get_db() as db: - user = db.query(User).filter_by(oauth_sub=sub).first() + async with get_db() as db: + user = await db.query(User).filter_by(oauth_sub=sub).first() return UserModel.model_validate(user) except Exception: return None - def get_users( + async def get_users( self, filter: Optional[dict] = None, skip: Optional[int] = None, limit: Optional[int] = None, ) -> UserListResponse: - with get_db() as db: + async with get_db() as db: query = db.query(User) if filter: @@ -243,37 +243,37 @@ class UsersTable: if limit: query = query.limit(limit) - users = query.all() + users = await query.all() return { "users": [UserModel.model_validate(user) for user in users], - "total": db.query(User).count(), + "total": await db.query(User).count(), } - def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]: - with get_db() as db: - users = db.query(User).filter(User.id.in_(user_ids)).all() + async def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]: + async with get_db() as db: + users = await db.query(User).filter(User.id.in_(user_ids)).all() return [UserModel.model_validate(user) for user in users] - def get_num_users(self) -> Optional[int]: - with get_db() as db: - return db.query(User).count() + async def get_num_users(self) -> Optional[int]: + async with get_db() as db: + return await db.query(User).count() - def has_users(self) -> bool: - with get_db() as db: - return db.query(db.query(User).exists()).scalar() + async def has_users(self) -> bool: + async with get_db() as db: + return await db.query(db.query(User).exists()).scalar() - def get_first_user(self) -> UserModel: + async def get_first_user(self) -> UserModel: try: - with get_db() as db: - user = db.query(User).order_by(User.created_at).first() + async with get_db() as db: + user = await db.query(User).order_by(User.created_at).first() return UserModel.model_validate(user) except Exception: return None - def get_user_webhook_url_by_id(self, id: str) -> Optional[str]: + async def get_user_webhook_url_by_id(self, id: str) -> Optional[str]: try: - with get_db() as db: - user = db.query(User).filter_by(id=id).first() + async with get_db() as db: + user = await db.query(User).filter_by(id=id).first() if user.settings is None: return None @@ -286,99 +286,103 @@ class UsersTable: except Exception: return None - def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: + async def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: try: - with get_db() as db: - db.query(User).filter_by(id=id).update({"role": role}) - db.commit() - user = db.query(User).filter_by(id=id).first() + async with get_db() as db: + await db.query(User).filter_by(id=id).update({"role": role}) + await db.commit() + user = await db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) except Exception: return None - def update_user_profile_image_url_by_id( + async def update_user_profile_image_url_by_id( self, id: str, profile_image_url: str ) -> Optional[UserModel]: try: - with get_db() as db: - db.query(User).filter_by(id=id).update( + async with get_db() as db: + await db.query(User).filter_by(id=id).update( {"profile_image_url": profile_image_url} ) - db.commit() + await db.commit() - user = db.query(User).filter_by(id=id).first() + user = await db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) except Exception: return None - def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: + async def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: try: - with get_db() as db: - db.query(User).filter_by(id=id).update( + async with get_db() as db: + await db.query(User).filter_by(id=id).update( {"last_active_at": int(time.time())} ) - db.commit() + await db.commit() - user = db.query(User).filter_by(id=id).first() + user = await db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) except Exception: return None - def update_user_oauth_sub_by_id( + async def update_user_oauth_sub_by_id( self, id: str, oauth_sub: str ) -> Optional[UserModel]: try: - with get_db() as db: - db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) - db.commit() + async with get_db() as db: + await db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) + await db.commit() - user = db.query(User).filter_by(id=id).first() + user = await db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) except Exception: return None - def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: + async def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: try: - with get_db() as db: - db.query(User).filter_by(id=id).update(updated) - db.commit() + async with get_db() as db: + await db.query(User).filter_by(id=id).update(updated) + await db.commit() - user = db.query(User).filter_by(id=id).first() + user = await db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) # return UserModel(**user.dict()) except Exception: return None - def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]: + async def update_user_settings_by_id( + self, id: str, updated: dict + ) -> Optional[UserModel]: try: - with get_db() as db: - user_settings = db.query(User).filter_by(id=id).first().settings + async with get_db() as db: + user_settings = await db.query(User).filter_by(id=id).first().settings if user_settings is None: user_settings = {} user_settings.update(updated) - db.query(User).filter_by(id=id).update({"settings": user_settings}) - db.commit() + await db.query(User).filter_by(id=id).update( + {"settings": user_settings} + ) + await db.commit() - user = db.query(User).filter_by(id=id).first() + user = await db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) except Exception: return None - def delete_user_by_id(self, id: str) -> bool: + async def delete_user_by_id(self, id: str) -> bool: try: # Remove User from Groups - Groups.remove_user_from_all_groups(id) + await Groups.remove_user_from_all_groups(id) # Delete User Chats - result = Chats.delete_chats_by_user_id(id) + result = await Chats.delete_chats_by_user_id(id) if result: - with get_db() as db: + async with get_db() as db: # Delete User - db.query(User).filter_by(id=id).delete() - db.commit() + await db.query(User).filter_by(id=id).delete() + await db.commit() return True else: @@ -386,31 +390,33 @@ class UsersTable: except Exception: return False - def update_user_api_key_by_id(self, id: str, api_key: str) -> bool: + async def update_user_api_key_by_id(self, id: str, api_key: str) -> bool: try: - with get_db() as db: - result = db.query(User).filter_by(id=id).update({"api_key": api_key}) - db.commit() + async with get_db() as db: + result = ( + await db.query(User).filter_by(id=id).update({"api_key": api_key}) + ) + await db.commit() return True if result == 1 else False except Exception: return False - def get_user_api_key_by_id(self, id: str) -> Optional[str]: + async def get_user_api_key_by_id(self, id: str) -> Optional[str]: try: - with get_db() as db: - user = db.query(User).filter_by(id=id).first() + async with get_db() as db: + user = await db.query(User).filter_by(id=id).first() return user.api_key except Exception: return None - def get_valid_user_ids(self, user_ids: list[str]) -> list[str]: - with get_db() as db: - users = db.query(User).filter(User.id.in_(user_ids)).all() + async def get_valid_user_ids(self, user_ids: list[str]) -> list[str]: + async with get_db() as db: + users = await db.query(User).filter(User.id.in_(user_ids)).all() return [user.id for user in users] - def get_super_admin_user(self) -> Optional[UserModel]: - with get_db() as db: - user = db.query(User).filter_by(role="admin").first() + async def get_super_admin_user(self) -> Optional[UserModel]: + async with get_db() as db: + user = await db.query(User).filter_by(role="admin").first() if user: return UserModel.model_validate(user) else: diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index e157e5527d..cab71d076b 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -159,11 +159,11 @@ async def update_password( if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) if session_user: - user = Auths.authenticate_user(session_user.email, form_data.password) + user = await Auths.authenticate_user(session_user.email, form_data.password) if user: hashed = get_password_hash(form_data.new_password) - return Auths.update_user_password_by_id(user.id, hashed) + return await Auths.update_user_password_by_id(user.id, hashed) else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD) else: @@ -357,7 +357,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): else request.app.state.config.DEFAULT_USER_ROLE ) - user = Auths.insert_new_auth( + user = await Auths.insert_new_auth( email=email, password=str(uuid.uuid4()), name=cn, @@ -377,7 +377,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): 500, detail="Internal error occurred during LDAP user creation." ) - user = Auths.authenticate_user_by_email(email) + user = await Auths.authenticate_user_by_email(email) if user: expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) @@ -470,7 +470,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): SignupForm(email=email, password=str(uuid.uuid4()), name=name), ) - user = Auths.authenticate_user_by_email(email) + user = await Auths.authenticate_user_by_email(email) if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin": group_names = request.headers.get( WEBUI_AUTH_TRUSTED_GROUPS_HEADER, "" @@ -485,7 +485,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): admin_password = "admin" if Users.get_user_by_email(admin_email.lower()): - user = Auths.authenticate_user(admin_email.lower(), admin_password) + user = await Auths.authenticate_user(admin_email.lower(), admin_password) else: if Users.has_users(): raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) @@ -496,9 +496,11 @@ async def signin(request: Request, response: Response, form_data: SigninForm): SignupForm(email=admin_email, password=admin_password, name="User"), ) - user = Auths.authenticate_user(admin_email.lower(), admin_password) + user = await Auths.authenticate_user(admin_email.lower(), admin_password) else: - user = Auths.authenticate_user(form_data.email.lower(), form_data.password) + user = await Auths.authenticate_user( + form_data.email.lower(), form_data.password + ) if user: @@ -589,7 +591,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): ) hashed = get_password_hash(form_data.password) - user = Auths.insert_new_auth( + user = await Auths.insert_new_auth( form_data.email.lower(), hashed, form_data.name, @@ -736,7 +738,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): try: hashed = get_password_hash(form_data.password) - user = Auths.insert_new_auth( + user = await Auths.insert_new_auth( form_data.email.lower(), hashed, form_data.name, diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index e4390e23f6..df606d61cc 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -40,14 +40,14 @@ router = APIRouter() @router.get("/", response_model=list[ChannelModel]) async def get_channels(user=Depends(get_verified_user)): - return Channels.get_channels_by_user_id(user.id) + return await Channels.get_channels_by_user_id(user.id) @router.get("/list", response_model=list[ChannelModel]) async def get_all_channels(user=Depends(get_verified_user)): if user.role == "admin": - return Channels.get_channels() - return Channels.get_channels_by_user_id(user.id) + return await Channels.get_channels() + return await Channels.get_channels_by_user_id(user.id) ############################ @@ -58,7 +58,7 @@ async def get_all_channels(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[ChannelModel]) async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)): try: - channel = Channels.insert_new_channel(None, form_data, user.id) + channel = await Channels.insert_new_channel(None, form_data, user.id) return ChannelModel(**channel.model_dump()) except Exception as e: log.exception(e) @@ -74,7 +74,7 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user @router.get("/{id}", response_model=Optional[ChannelModel]) async def get_channel_by_id(id: str, user=Depends(get_verified_user)): - channel = Channels.get_channel_by_id(id) + channel = await Channels.get_channel_by_id(id) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -99,14 +99,14 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)): async def update_channel_by_id( id: str, form_data: ChannelForm, user=Depends(get_admin_user) ): - channel = Channels.get_channel_by_id(id) + channel = await Channels.get_channel_by_id(id) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) try: - channel = Channels.update_channel_by_id(id, form_data) + channel = await Channels.update_channel_by_id(id, form_data) return ChannelModel(**channel.model_dump()) except Exception as e: log.exception(e) @@ -122,14 +122,14 @@ async def update_channel_by_id( @router.delete("/{id}/delete", response_model=bool) async def delete_channel_by_id(id: str, user=Depends(get_admin_user)): - channel = Channels.get_channel_by_id(id) + channel = await Channels.get_channel_by_id(id) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) try: - Channels.delete_channel_by_id(id) + await Channels.delete_channel_by_id(id) return True except Exception as e: log.exception(e) @@ -151,7 +151,7 @@ class MessageUserResponse(MessageResponse): async def get_channel_messages( id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user) ): - channel = Channels.get_channel_by_id(id) + channel = await Channels.get_channel_by_id(id) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -170,7 +170,7 @@ async def get_channel_messages( messages = [] for message in message_list: if message.user_id not in users: - user = Users.get_user_by_id(message.user_id) + user = await Users.get_user_by_id(message.user_id) users[message.user_id] = user replies = Messages.get_replies_by_message_id(message.id) @@ -230,7 +230,7 @@ async def post_new_message( background_tasks: BackgroundTasks, user=Depends(get_verified_user), ): - channel = Channels.get_channel_by_id(id) + channel = await Channels.get_channel_by_id(id) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -290,9 +290,13 @@ async def post_new_message( **{ **parent_message.model_dump(), "user": UserNameResponse( - **Users.get_user_by_id( - parent_message.user_id - ).model_dump() + **( + ( + await Users.get_user_by_id( + parent_message.user_id + ) + ).model_dump() + ) ), } ).model_dump(), @@ -331,7 +335,7 @@ async def post_new_message( async def get_channel_message( id: str, message_id: str, user=Depends(get_verified_user) ): - channel = Channels.get_channel_by_id(id) + channel = await Channels.get_channel_by_id(id) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -359,7 +363,7 @@ async def get_channel_message( **{ **message.model_dump(), "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() + **((await Users.get_user_by_id(message.user_id)).model_dump()) ), } ) @@ -380,7 +384,7 @@ async def get_channel_thread_messages( limit: int = 50, user=Depends(get_verified_user), ): - channel = Channels.get_channel_by_id(id) + channel = await Channels.get_channel_by_id(id) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -399,7 +403,7 @@ async def get_channel_thread_messages( messages = [] for message in message_list: if message.user_id not in users: - user = Users.get_user_by_id(message.user_id) + user = await Users.get_user_by_id(message.user_id) users[message.user_id] = user messages.append( @@ -428,7 +432,7 @@ async def get_channel_thread_messages( async def update_message_by_id( id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user) ): - channel = Channels.get_channel_by_id(id) + channel = await Channels.get_channel_by_id(id) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -502,7 +506,7 @@ class ReactionForm(BaseModel): async def add_reaction_to_message( id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user) ): - channel = Channels.get_channel_by_id(id) + channel = await Channels.get_channel_by_id(id) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -540,7 +544,7 @@ async def add_reaction_to_message( "data": { **message.model_dump(), "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() + **(await Users.get_user_by_id(message.user_id)).model_dump() ).model_dump(), "name": form_data.name, }, @@ -568,7 +572,7 @@ async def add_reaction_to_message( async def remove_reaction_by_id_and_user_id_and_name( id: str, message_id: str, form_data: ReactionForm, user=Depends(get_verified_user) ): - channel = Channels.get_channel_by_id(id) + channel = await Channels.get_channel_by_id(id) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -609,7 +613,7 @@ async def remove_reaction_by_id_and_user_id_and_name( "data": { **message.model_dump(), "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() + **(await Users.get_user_by_id(message.user_id)).model_dump() ).model_dump(), "name": form_data.name, }, @@ -637,7 +641,7 @@ async def remove_reaction_by_id_and_user_id_and_name( async def delete_message_by_id( id: str, message_id: str, user=Depends(get_verified_user) ): - channel = Channels.get_channel_by_id(id) + channel = await Channels.get_channel_by_id(id) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -699,8 +703,10 @@ async def delete_message_by_id( **{ **parent_message.model_dump(), "user": UserNameResponse( - **Users.get_user_by_id( - parent_message.user_id + **( + await Users.get_user_by_id( + parent_message.user_id + ) ).model_dump() ), } diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index ba16b506f7..88f0ad625f 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -44,11 +44,11 @@ async def get_session_user_chat_list( limit = 60 skip = (page - 1) * limit - return Chats.get_chat_title_id_list_by_user_id( + return await Chats.get_chat_title_id_list_by_user_id( user.id, skip=skip, limit=limit ) else: - return Chats.get_chat_title_id_list_by_user_id(user.id) + return await Chats.get_chat_title_id_list_by_user_id(user.id) except Exception as e: log.exception(e) raise HTTPException( @@ -72,7 +72,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - result = Chats.delete_chats_by_user_id(user.id) + result = await Chats.delete_chats_by_user_id(user.id) return result @@ -110,7 +110,7 @@ async def get_user_chat_list_by_user_id( if direction: filter["direction"] = direction - return Chats.get_chat_list_by_user_id( + return await Chats.get_chat_list_by_user_id( user_id, include_archived=True, filter=filter, skip=skip, limit=limit ) @@ -123,7 +123,7 @@ async def get_user_chat_list_by_user_id( @router.post("/new", response_model=Optional[ChatResponse]) async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): try: - chat = Chats.insert_new_chat(user.id, form_data) + chat = await Chats.insert_new_chat(user.id, form_data) return ChatResponse(**chat.model_dump()) except Exception as e: log.exception(e) @@ -140,7 +140,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): @router.post("/import", response_model=Optional[ChatResponse]) async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)): try: - chat = Chats.import_chat(user.id, form_data) + chat = await Chats.import_chat(user.id, form_data) if chat: tags = chat.meta.get("tags", []) for tag_id in tags: @@ -177,7 +177,7 @@ async def search_user_chats( chat_list = [ ChatTitleIdResponse(**chat.model_dump()) - for chat in Chats.get_chats_by_user_id_and_search_text( + for chat in await Chats.get_chats_by_user_id_and_search_text( user.id, text, skip=skip, limit=limit ) ] @@ -187,9 +187,9 @@ async def search_user_chats( if page == 1 and len(words) == 1 and words[0].startswith("tag:"): tag_id = words[0].replace("tag:", "") if len(chat_list) == 0: - if Tags.get_tag_by_name_and_user_id(tag_id, user.id): + if await Tags.get_tag_by_name_and_user_id(tag_id, user.id): log.debug(f"deleting tag: {tag_id}") - Tags.delete_tag_by_name_and_user_id(tag_id, user.id) + await Tags.delete_tag_by_name_and_user_id(tag_id, user.id) return chat_list @@ -210,7 +210,7 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user) return [ ChatResponse(**chat.model_dump()) - for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id) + for chat in await Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id) ] @@ -223,7 +223,7 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user) async def get_user_pinned_chats(user=Depends(get_verified_user)): return [ ChatTitleIdResponse(**chat.model_dump()) - for chat in Chats.get_pinned_chats_by_user_id(user.id) + for chat in await Chats.get_pinned_chats_by_user_id(user.id) ] @@ -236,7 +236,7 @@ async def get_user_pinned_chats(user=Depends(get_verified_user)): async def get_user_chats(user=Depends(get_verified_user)): return [ ChatResponse(**chat.model_dump()) - for chat in Chats.get_chats_by_user_id(user.id) + for chat in await Chats.get_chats_by_user_id(user.id) ] @@ -249,7 +249,7 @@ async def get_user_chats(user=Depends(get_verified_user)): async def get_user_archived_chats(user=Depends(get_verified_user)): return [ ChatResponse(**chat.model_dump()) - for chat in Chats.get_archived_chats_by_user_id(user.id) + for chat in await Chats.get_archived_chats_by_user_id(user.id) ] @@ -282,7 +282,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()] + return [ChatResponse(**chat.model_dump()) for chat in await Chats.get_chats()] ############################ @@ -314,7 +314,7 @@ async def get_archived_session_user_chat_list( chat_list = [ ChatTitleIdResponse(**chat.model_dump()) - for chat in Chats.get_archived_chat_list_by_user_id( + for chat in await Chats.get_archived_chat_list_by_user_id( user.id, filter=filter, skip=skip, @@ -332,7 +332,7 @@ async def get_archived_session_user_chat_list( @router.post("/archive/all", response_model=bool) async def archive_all_chats(user=Depends(get_verified_user)): - return Chats.archive_all_chats_by_user_id(user.id) + return await Chats.archive_all_chats_by_user_id(user.id) ############################ @@ -348,9 +348,9 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)): ) if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS): - chat = Chats.get_chat_by_share_id(share_id) + chat = await Chats.get_chat_by_share_id(share_id) elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS: - chat = Chats.get_chat_by_id(share_id) + chat = await Chats.get_chat_by_id(share_id) if chat: return ChatResponse(**chat.model_dump()) @@ -379,11 +379,11 @@ class TagFilterForm(TagForm): async def get_user_chat_list_by_tag_name( form_data: TagFilterForm, user=Depends(get_verified_user) ): - chats = Chats.get_chat_list_by_user_id_and_tag_name( + chats = await Chats.get_chat_list_by_user_id_and_tag_name( user.id, form_data.name, form_data.skip, form_data.limit ) if len(chats) == 0: - Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) + await Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) return chats @@ -395,7 +395,7 @@ async def get_user_chat_list_by_tag_name( @router.get("/{id}", response_model=Optional[ChatResponse]) async def get_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: return ChatResponse(**chat.model_dump()) @@ -415,10 +415,10 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)): async def update_chat_by_id( id: str, form_data: ChatForm, user=Depends(get_verified_user) ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: updated_chat = {**chat.chat, **form_data.chat} - chat = Chats.update_chat_by_id(id, updated_chat) + chat = await Chats.update_chat_by_id(id, updated_chat) return ChatResponse(**chat.model_dump()) else: raise HTTPException( @@ -438,7 +438,7 @@ class MessageForm(BaseModel): async def update_chat_message_by_id( id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user) ): - chat = Chats.get_chat_by_id(id) + chat = await Chats.get_chat_by_id(id) if not chat: raise HTTPException( @@ -452,7 +452,7 @@ async def update_chat_message_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - chat = Chats.upsert_message_to_chat_by_id_and_message_id( + chat = await Chats.upsert_message_to_chat_by_id_and_message_id( id, message_id, { @@ -496,7 +496,7 @@ class EventForm(BaseModel): async def send_chat_message_event_by_id( id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user) ): - chat = Chats.get_chat_by_id(id) + chat = await Chats.get_chat_by_id(id) if not chat: raise HTTPException( @@ -536,12 +536,12 @@ async def send_chat_message_event_by_id( @router.delete("/{id}", response_model=bool) async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): if user.role == "admin": - chat = Chats.get_chat_by_id(id) + chat = await Chats.get_chat_by_id(id) for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: - Tags.delete_tag_by_name_and_user_id(tag, user.id) + if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: + await Tags.delete_tag_by_name_and_user_id(tag, user.id) - result = Chats.delete_chat_by_id(id) + result = await Chats.delete_chat_by_id(id) return result else: @@ -553,12 +553,12 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - chat = Chats.get_chat_by_id(id) + chat = await Chats.get_chat_by_id(id) for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: + if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: Tags.delete_tag_by_name_and_user_id(tag, user.id) - result = Chats.delete_chat_by_id_and_user_id(id, user.id) + result = await Chats.delete_chat_by_id_and_user_id(id, user.id) return result @@ -569,7 +569,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified @router.get("/{id}/pinned", response_model=Optional[bool]) async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: return chat.pinned else: @@ -585,9 +585,9 @@ async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/pin", response_model=Optional[ChatResponse]) async def pin_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - chat = Chats.toggle_chat_pinned_by_id(id) + chat = await Chats.toggle_chat_pinned_by_id(id) return chat else: raise HTTPException( @@ -608,7 +608,7 @@ class CloneForm(BaseModel): async def clone_chat_by_id( form_data: CloneForm, id: str, user=Depends(get_verified_user) ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: updated_chat = { **chat.chat, @@ -617,7 +617,7 @@ async def clone_chat_by_id( "title": form_data.title if form_data.title else f"Clone of {chat.title}", } - chat = Chats.import_chat( + chat = await Chats.import_chat( user.id, ChatImportForm( **{ @@ -645,9 +645,9 @@ async def clone_chat_by_id( async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): if user.role == "admin": - chat = Chats.get_chat_by_id(id) + chat = await Chats.get_chat_by_id(id) else: - chat = Chats.get_chat_by_share_id(id) + chat = await Chats.get_chat_by_share_id(id) if chat: updated_chat = { @@ -657,7 +657,7 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): "title": f"Clone of {chat.title}", } - chat = Chats.import_chat( + chat = await Chats.import_chat( user.id, ChatImportForm( **{ @@ -682,22 +682,25 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/archive", response_model=Optional[ChatResponse]) async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - chat = Chats.toggle_chat_archive_by_id(id) + chat = await Chats.toggle_chat_archive_by_id(id) # Delete tags if chat is archived if chat.archived: for tag_id in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0: + if ( + await Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) + == 0 + ): log.debug(f"deleting tag: {tag_id}") Tags.delete_tag_by_name_and_user_id(tag_id, user.id) else: for tag_id in chat.meta.get("tags", []): - tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id) + tag = await Tags.get_tag_by_name_and_user_id(tag_id, user.id) if tag is None: log.debug(f"inserting tag: {tag_id}") - tag = Tags.insert_new_tag(tag_id, user.id) + tag = await Tags.insert_new_tag(tag_id, user.id) return ChatResponse(**chat.model_dump()) else: @@ -723,14 +726,14 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_ detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: if chat.share_id: - shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) + shared_chat = await Chats.update_shared_chat_by_chat_id(chat.id) return ChatResponse(**shared_chat.model_dump()) - shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id) + shared_chat = await Chats.insert_shared_chat_by_chat_id(chat.id) if not shared_chat: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -752,13 +755,13 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_ @router.delete("/{id}/share", response_model=Optional[bool]) async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: if not chat.share_id: return False - result = Chats.delete_shared_chat_by_chat_id(id) - update_result = Chats.update_chat_share_id_by_id(id, None) + result = await Chats.delete_shared_chat_by_chat_id(id) + update_result = await Chats.update_chat_share_id_by_id(id, None) return result and update_result != None else: @@ -781,9 +784,9 @@ class ChatFolderIdForm(BaseModel): async def update_chat_folder_id_by_id( id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user) ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - chat = Chats.update_chat_folder_id_by_id_and_user_id( + chat = await Chats.update_chat_folder_id_by_id_and_user_id( id, user.id, form_data.folder_id ) return ChatResponse(**chat.model_dump()) @@ -800,10 +803,10 @@ async def update_chat_folder_id_by_id( @router.get("/{id}/tags", response_model=list[TagModel]) async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids_and_user_id(tags, user.id) + return await Tags.get_tags_by_ids_and_user_id(tags, user.id) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -819,7 +822,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): async def add_tag_by_id_and_tag_name( id: str, form_data: TagForm, user=Depends(get_verified_user) ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: tags = chat.meta.get("tags", []) tag_id = form_data.name.replace(" ", "_").lower() @@ -831,13 +834,13 @@ async def add_tag_by_id_and_tag_name( ) if tag_id not in tags: - Chats.add_chat_tag_by_id_and_user_id_and_tag_name( + await Chats.add_chat_tag_by_id_and_user_id_and_tag_name( id, user.id, form_data.name ) - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids_and_user_id(tags, user.id) + return await Tags.get_tags_by_ids_and_user_id(tags, user.id) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -853,16 +856,21 @@ async def add_tag_by_id_and_tag_name( async def delete_tag_by_id_and_tag_name( id: str, form_data: TagForm, user=Depends(get_verified_user) ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name) + await Chats.delete_tag_by_id_and_user_id_and_tag_name( + id, user.id, form_data.name + ) - if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0: - Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) + if ( + await Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) + == 0 + ): + await Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids_and_user_id(tags, user.id) + return await Tags.get_tags_by_ids_and_user_id(tags, user.id) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -876,13 +884,13 @@ async def delete_tag_by_id_and_tag_name( @router.delete("/{id}/tags/all", response_model=Optional[bool]) async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - Chats.delete_all_tags_by_id_and_user_id(id, user.id) + await Chats.delete_all_tags_by_id_and_user_id(id, user.id) for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: - Tags.delete_tag_by_name_and_user_id(tag, user.id) + if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: + await Tags.delete_tag_by_name_and_user_id(tag, user.id) return True else: diff --git a/backend/open_webui/routers/evaluations.py b/backend/open_webui/routers/evaluations.py index c76a1f6915..555db9322c 100644 --- a/backend/open_webui/routers/evaluations.py +++ b/backend/open_webui/routers/evaluations.py @@ -73,11 +73,11 @@ class FeedbackUserResponse(FeedbackResponse): @router.get("/feedbacks/all", response_model=list[FeedbackUserResponse]) async def get_all_feedbacks(user=Depends(get_admin_user)): - feedbacks = Feedbacks.get_all_feedbacks() + feedbacks = await Feedbacks.get_all_feedbacks() feedback_list = [] for feedback in feedbacks: - user = Users.get_user_by_id(feedback.user_id) + user = await Users.get_user_by_id(feedback.user_id) feedback_list.append( FeedbackUserResponse( **feedback.model_dump(), @@ -89,25 +89,25 @@ async def get_all_feedbacks(user=Depends(get_admin_user)): @router.delete("/feedbacks/all") async def delete_all_feedbacks(user=Depends(get_admin_user)): - success = Feedbacks.delete_all_feedbacks() + success = await Feedbacks.delete_all_feedbacks() return success @router.get("/feedbacks/all/export", response_model=list[FeedbackModel]) async def get_all_feedbacks(user=Depends(get_admin_user)): - feedbacks = Feedbacks.get_all_feedbacks() + feedbacks = await Feedbacks.get_all_feedbacks() return feedbacks @router.get("/feedbacks/user", response_model=list[FeedbackUserResponse]) async def get_feedbacks(user=Depends(get_verified_user)): - feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id) + feedbacks = await Feedbacks.get_feedbacks_by_user_id(user.id) return feedbacks @router.delete("/feedbacks", response_model=bool) async def delete_feedbacks(user=Depends(get_verified_user)): - success = Feedbacks.delete_feedbacks_by_user_id(user.id) + success = await Feedbacks.delete_feedbacks_by_user_id(user.id) return success @@ -117,7 +117,7 @@ async def create_feedback( form_data: FeedbackForm, user=Depends(get_verified_user), ): - feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data) + feedback = await Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data) if not feedback: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -130,9 +130,11 @@ async def create_feedback( @router.get("/feedback/{id}", response_model=FeedbackModel) async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): if user.role == "admin": - feedback = Feedbacks.get_feedback_by_id(id=id) + feedback = await Feedbacks.get_feedback_by_id(id=id) else: - feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id) + feedback = await Feedbacks.get_feedback_by_id_and_user_id( + id=id, user_id=user.id + ) if not feedback: raise HTTPException( @@ -147,9 +149,9 @@ async def update_feedback_by_id( id: str, form_data: FeedbackForm, user=Depends(get_verified_user) ): if user.role == "admin": - feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data) + feedback = await Feedbacks.update_feedback_by_id(id=id, form_data=form_data) else: - feedback = Feedbacks.update_feedback_by_id_and_user_id( + feedback = await Feedbacks.update_feedback_by_id_and_user_id( id=id, user_id=user.id, form_data=form_data ) @@ -164,9 +166,11 @@ async def update_feedback_by_id( @router.delete("/feedback/{id}") async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)): if user.role == "admin": - success = Feedbacks.delete_feedback_by_id(id=id) + success = await Feedbacks.delete_feedback_by_id(id=id) else: - success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id) + success = await Feedbacks.delete_feedback_by_id_and_user_id( + id=id, user_id=user.id + ) if not success: raise HTTPException( diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 0a2b4ac97f..2a83442bf1 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -6,6 +6,7 @@ from fnmatch import fnmatch from pathlib import Path from typing import Optional from urllib.parse import quote +import asyncio from fastapi import ( APIRouter, @@ -137,22 +138,26 @@ def upload_file( } contents, file_path = Storage.upload_file(file.file, filename, tags) - file_item = Files.insert_new_file( - user.id, - FileForm( - **{ - "id": id, - "filename": name, - "path": file_path, - "meta": { - "name": name, - "content_type": file.content_type, - "size": len(contents), - "data": file_metadata, - }, - } - ), + loop = asyncio.get_event_loop() + file_item = loop.run_until_complete( + Files.insert_new_file( + user.id, + FileForm( + **{ + "id": id, + "filename": name, + "path": file_path, + "meta": { + "name": name, + "content_type": file.content_type, + "size": len(contents), + "data": file_metadata, + }, + } + ), + ) ) + if process: try: if file.content_type: @@ -187,7 +192,7 @@ def upload_file( ) process_file(request, ProcessFileForm(file_id=id), user=user) - file_item = Files.get_file_by_id(id=id) + file_item = loop.run_until_complete(Files.get_file_by_id(id=id)) except Exception as e: log.exception(e) log.error(f"Error processing file: {file_item.id}") @@ -489,7 +494,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): detail=ERROR_MESSAGES.NOT_FOUND, ) - file_user = Users.get_user_by_id(file.user_id) + file_user = await Users.get_user_by_id(file.user_id) if not file_user.role == "admin": raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py index e419989e46..e245c7f477 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -50,7 +50,7 @@ async def get_folders(user=Depends(get_verified_user)): "items": { "chats": [ {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at} - for chat in Chats.get_chats_by_folder_id_and_user_id( + for chat in await Chats.get_chats_by_folder_id_and_user_id( folder.id, user.id ) ] @@ -246,7 +246,7 @@ async def delete_folder_by_id( try: folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id) for folder_id in folder_ids: - Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id) + await Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id) return True except Exception as e: diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index 375f59ff6c..e9e63788be 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -45,7 +45,9 @@ async def get_notes(request: Request, user=Depends(get_verified_user)): NoteUserResponse( **{ **note.model_dump(), - "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()), + "user": UserResponse( + **((await Users.get_user_by_id(note.user_id)).model_dump()) + ), } ) for note in Notes.get_notes_by_user_id(user.id, "write") diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index de1b979c86..77a9f035b1 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -350,7 +350,7 @@ def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup: """Convert internal Group model to SCIM Group""" members = [] for user_id in group.user_ids: - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) if user: members.append( SCIMGroupMember( @@ -523,7 +523,7 @@ async def get_user( _: bool = Depends(get_scim_auth), ): """Get SCIM User by ID""" - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) if not user: return scim_error( status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found" @@ -565,7 +565,7 @@ async def create_user( profile_image = user_data.photos[0].value # Create user - new_user = Users.insert_new_user( + new_user = await Users.insert_new_user( id=user_id, name=name, email=email, @@ -590,7 +590,7 @@ async def update_user( _: bool = Depends(get_scim_auth), ): """Update SCIM User (full update)""" - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -641,7 +641,7 @@ async def patch_user( _: bool = Depends(get_scim_auth), ): """Update SCIM User (partial update)""" - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -688,7 +688,7 @@ async def delete_user( _: bool = Depends(get_scim_auth), ): """Delete SCIM User""" - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 7b27b45b9d..cad29c475d 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -310,7 +310,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): # If it is, get the user_id from the chat if user_id.startswith("shared-"): chat_id = user_id.replace("shared-", "") - chat = Chats.get_chat_by_id(chat_id) + chat = await Chats.get_chat_by_id(chat_id) if chat: user_id = chat.user_id else: @@ -319,7 +319,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): detail=ERROR_MESSAGES.USER_NOT_FOUND, ) - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) if user: return UserResponse( @@ -422,11 +422,11 @@ async def update_user_by_id( detail="Could not verify primary admin status.", ) - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) if user: if form_data.email.lower() != user.email: - email_user = Users.get_user_by_email(form_data.email.lower()) + email_user = await Users.get_user_by_email(form_data.email.lower()) if email_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -436,10 +436,10 @@ async def update_user_by_id( if form_data.password: hashed = get_password_hash(form_data.password) log.debug(f"hashed: {hashed}") - Auths.update_user_password_by_id(user_id, hashed) + await Auths.update_user_password_by_id(user_id, hashed) - Auths.update_email_by_id(user_id, form_data.email.lower()) - updated_user = Users.update_user_by_id( + await Auths.update_email_by_id(user_id, form_data.email.lower()) + updated_user = await Users.update_user_by_id( user_id, { "role": form_data.role, @@ -486,7 +486,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): ) if user.id != user_id: - result = Auths.delete_auth_by_id(user_id) + result = await Auths.delete_auth_by_id(user_id) if result: return True diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 49323db975..63dda91502 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -295,7 +295,7 @@ async def user_join(sid, data): USER_POOL[user.id] = [sid] # Join all the channels - channels = Channels.get_channels_by_user_id(user.id) + channels = await Channels.get_channels_by_user_id(user.id) log.debug(f"{channels=}") for channel in channels: await sio.enter_room(sid, f"channel:{channel.id}") @@ -317,7 +317,7 @@ async def join_channel(sid, data): return # Join all the channels - channels = Channels.get_channels_by_user_id(user.id) + channels = await Channels.get_channels_by_user_id(user.id) log.debug(f"{channels=}") for channel in channels: await sio.enter_room(sid, f"channel:{channel.id}") @@ -668,14 +668,14 @@ def get_event_emitter(request_info, update_db=True): if update_db: if "type" in event_data and event_data["type"] == "status": - Chats.add_message_status_to_chat_by_id_and_message_id( + await Chats.add_message_status_to_chat_by_id_and_message_id( request_info["chat_id"], request_info["message_id"], event_data.get("data", {}), ) if "type" in event_data and event_data["type"] == "message": - message = Chats.get_message_by_id_and_message_id( + message = await Chats.get_message_by_id_and_message_id( request_info["chat_id"], request_info["message_id"], ) @@ -684,7 +684,7 @@ def get_event_emitter(request_info, update_db=True): content = message.get("content", "") content += event_data.get("data", {}).get("content", "") - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( request_info["chat_id"], request_info["message_id"], { @@ -695,7 +695,7 @@ def get_event_emitter(request_info, update_db=True): if "type" in event_data and event_data["type"] == "replace": content = event_data.get("data", {}).get("content", "") - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( request_info["chat_id"], request_info["message_id"], { diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index e0c38c4a03..72c2707846 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -771,9 +771,9 @@ async def process_chat_payload(request, form_data, user, metadata, model): # Check if the request has chat_id and is inside of a folder chat_id = metadata.get("chat_id", None) if chat_id and user: - chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id) + chat = await Chats.get_chat_by_id_and_user_id(chat_id, user.id) if chat and chat.folder_id: - folder = Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id) + folder = await Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id) if folder and folder.data: if "system_prompt" in folder.data: @@ -1042,7 +1042,7 @@ async def process_chat_response( request, response, form_data, user, metadata, model, events, tasks ): async def background_tasks_handler(): - message_map = Chats.get_messages_by_chat_id(metadata["chat_id"]) + message_map = await Chats.get_messages_by_chat_id(metadata["chat_id"]) message = message_map.get(metadata["message_id"]) if message_map else None if message: @@ -1115,7 +1115,7 @@ async def process_chat_response( "follow_ups", [] ) - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { @@ -1177,7 +1177,9 @@ async def process_chat_response( if not title: title = messages[0].get("content", user_message) - Chats.update_chat_title_by_id(metadata["chat_id"], title) + await Chats.update_chat_title_by_id( + metadata["chat_id"], title + ) await event_emitter( { @@ -1188,7 +1190,7 @@ async def process_chat_response( elif len(messages) == 2: title = messages[0].get("content", user_message) - Chats.update_chat_title_by_id(metadata["chat_id"], title) + await Chats.update_chat_title_by_id(metadata["chat_id"], title) await event_emitter( { @@ -1224,7 +1226,7 @@ async def process_chat_response( try: tags = json.loads(tags_string).get("tags", []) - Chats.update_chat_tags_by_id( + await Chats.update_chat_tags_by_id( metadata["chat_id"], tags, user ) @@ -1267,7 +1269,7 @@ async def process_chat_response( if "error" in response_data: error = response_data["error"].get("detail", response_data["error"]) - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { @@ -1276,7 +1278,7 @@ async def process_chat_response( ) if "selected_model_id" in response_data: - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { @@ -1296,7 +1298,7 @@ async def process_chat_response( } ) - title = Chats.get_chat_title_by_id(metadata["chat_id"]) + title = await Chats.get_chat_title_by_id(metadata["chat_id"]) await event_emitter( { @@ -1310,7 +1312,7 @@ async def process_chat_response( ) # Save message in the database - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { @@ -1767,7 +1769,7 @@ async def process_chat_response( return content, content_blocks, end_flag - message = Chats.get_message_by_id_and_message_id( + message = await Chats.get_message_by_id_and_message_id( metadata["chat_id"], metadata["message_id"] ) @@ -1827,7 +1829,7 @@ async def process_chat_response( ) # Save message in the database - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { @@ -1882,7 +1884,7 @@ async def process_chat_response( if "selected_model_id" in data: model_id = data["selected_model_id"] - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { @@ -2081,7 +2083,7 @@ async def process_chat_response( if ENABLE_REALTIME_CHAT_SAVE: # Save message in the database - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { @@ -2502,7 +2504,7 @@ async def process_chat_response( log.debug(e) break - title = Chats.get_chat_title_by_id(metadata["chat_id"]) + title = await Chats.get_chat_title_by_id(metadata["chat_id"]) data = { "done": True, "content": serialize_content_blocks(content_blocks), @@ -2511,7 +2513,7 @@ async def process_chat_response( if not ENABLE_REALTIME_CHAT_SAVE: # Save message in the database - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { @@ -2549,7 +2551,7 @@ async def process_chat_response( if not ENABLE_REALTIME_CHAT_SAVE: # Save message in the database - Chats.upsert_message_to_chat_by_id_and_message_id( + await Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], metadata["message_id"], { diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 67ae13bc9b..9ad067eeb1 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -490,7 +490,7 @@ class OAuthManager: role = self.get_user_role(None, user_data) - user = Auths.insert_new_auth( + user = await Auths.insert_new_auth( email=email, password=get_password_hash( str(uuid.uuid4())