diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index 95571c2435..e84dfe9aa7 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -160,3 +160,13 @@ def get_session(): get_db = contextmanager(get_session) + + +@contextmanager +def get_db_context(db: Optional[Session] = None): + if db: + yield db + else: + with get_db() as session: + yield session + diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index cb6c057c88..bb74fb7ded 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -2,7 +2,8 @@ import logging import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.users import UserModel, UserProfileImageResponse, Users from pydantic import BaseModel from sqlalchemy import Boolean, Column, String, Text @@ -87,8 +88,9 @@ class AuthsTable: profile_image_url: str = "/user.png", role: str = "pending", oauth: Optional[dict] = None, + db: Optional[Session] = None, ) -> Optional[UserModel]: - with get_db() as db: + with get_db_context(db) as db: log.info("insert_new_auth") id = str(uuid.uuid4()) @@ -100,7 +102,7 @@ class AuthsTable: db.add(result) user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth=oauth + id, name, email, profile_image_url, role, oauth=oauth, db=db ) db.commit() @@ -112,16 +114,16 @@ class AuthsTable: return None def authenticate_user( - self, email: str, verify_password: callable + self, email: str, verify_password: callable, db: Optional[Session] = None ) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") - user = Users.get_user_by_email(email) + user = Users.get_user_by_email(email, db=db) if not user: return None try: - with get_db() as db: + with get_db_context(db) as db: auth = db.query(Auth).filter_by(id=user.id, active=True).first() if auth: if verify_password(auth.password): @@ -133,32 +135,32 @@ class AuthsTable: except Exception: return None - def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + def authenticate_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]: log.info(f"authenticate_user_by_api_key: {api_key}") # if no api_key, return None if not api_key: return None try: - user = Users.get_user_by_api_key(api_key) + user = Users.get_user_by_api_key(api_key, db=db) return user if user else None except Exception: return False - def authenticate_user_by_email(self, email: str) -> Optional[UserModel]: + def authenticate_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]: log.info(f"authenticate_user_by_email: {email}") try: - with get_db() as db: + with get_db_context(db) as db: auth = db.query(Auth).filter_by(email=email, active=True).first() if auth: - user = Users.get_user_by_id(auth.id) + user = Users.get_user_by_id(auth.id, db=db) return user except Exception: return None - def update_user_password_by_id(self, id: str, new_password: str) -> bool: + def update_user_password_by_id(self, id: str, new_password: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: result = ( db.query(Auth).filter_by(id=id).update({"password": new_password}) ) @@ -167,20 +169,20 @@ class AuthsTable: except Exception: return False - def update_email_by_id(self, id: str, email: str) -> bool: + def update_email_by_id(self, id: str, email: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: result = db.query(Auth).filter_by(id=id).update({"email": email}) db.commit() return True if result == 1 else False except Exception: return False - def delete_auth_by_id(self, id: str) -> bool: + def delete_auth_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: # Delete User - result = Users.delete_user_by_id(id) + result = Users.delete_user_by_id(id, db=db) if result: db.query(Auth).filter_by(id=id).delete() diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 362222a284..3aba302ff0 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -1,9 +1,11 @@ + import json import time import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.groups import Groups from pydantic import BaseModel, ConfigDict @@ -337,8 +339,8 @@ class ChannelTable: db.commit() return channel - def get_channels(self) -> list[ChannelModel]: - with get_db() as db: + def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]: + with get_db_context(db) as db: channels = db.query(Channel).all() return [ChannelModel.model_validate(channel) for channel in channels] @@ -384,8 +386,8 @@ class ChannelTable: return query - def get_channels_by_user_id(self, user_id: str) -> list[ChannelModel]: - with get_db() as db: + def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]: + with get_db_context(db) as db: user_group_ids = [ group.id for group in Groups.get_groups_by_member_id(user_id) ] @@ -683,10 +685,13 @@ class ChannelTable: ) return membership is not None - def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: - with get_db() as db: - channel = db.query(Channel).filter(Channel.id == id).first() - return ChannelModel.model_validate(channel) if channel else None + def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]: + try: + with get_db_context(db) as db: + channel = db.query(Channel).filter(Channel.id == id).first() + return ChannelModel.model_validate(channel) if channel else None + except Exception: + return None def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]: with get_db() as db: diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index d7153a2330..a8e3d27b99 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -4,7 +4,8 @@ import time import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.tags import TagModel, Tag, Tags from open_webui.models.folders import Folders from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db @@ -280,8 +281,8 @@ class ChatTable: return changed - def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: - with get_db() as db: + def insert_new_chat(self, user_id: str, form_data: ChatForm, db: Optional[Session] = None) -> Optional[ChatModel]: + with get_db_context(db) as db: id = str(uuid.uuid4()) chat = ChatModel( **{ @@ -331,9 +332,9 @@ class ChatTable: return chat def import_chats( - self, user_id: str, chat_import_forms: list[ChatImportForm] + self, user_id: str, chat_import_forms: list[ChatImportForm], db: Optional[Session] = None ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: chats = [] for form_data in chat_import_forms: @@ -344,9 +345,9 @@ class ChatTable: db.commit() return [ChatModel.model_validate(chat) for chat in chats] - def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: + def update_chat_by_id(self, id: str, chat: dict, db: Optional[Session] = None) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat_item = db.get(Chat, id) chat_item.chat = self._clean_null_bytes(chat) chat_item.title = ( @@ -483,13 +484,13 @@ class ChatTable: self.update_chat_by_id(id, chat) return message_files - def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: - with get_db() as db: + def insert_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> Optional[ChatModel]: + with get_db_context(db) as db: # Get the existing chat to share chat = db.get(Chat, chat_id) # Check if the chat is already shared if chat.share_id: - return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + return self.get_chat_by_id_and_user_id(chat.share_id, "shared", db=db) # Create a new chat with the same data, but with a new ID shared_chat = ChatModel( **{ @@ -518,16 +519,16 @@ class ChatTable: db.commit() return shared_chat if (shared_result and result) else None - def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: + def update_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, chat_id) shared_chat = ( db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first() ) if shared_chat is None: - return self.insert_shared_chat_by_chat_id(chat_id) + return self.insert_shared_chat_by_chat_id(chat_id, db=db) shared_chat.title = chat.title shared_chat.chat = chat.chat @@ -542,9 +543,9 @@ class ChatTable: except Exception: return None - def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: + def delete_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() db.commit() @@ -552,9 +553,9 @@ class ChatTable: except Exception: return False - def unarchive_all_chats_by_user_id(self, user_id: str) -> bool: + def unarchive_all_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(user_id=user_id).update({"archived": False}) db.commit() return True @@ -562,10 +563,10 @@ class ChatTable: return False def update_chat_share_id_by_id( - self, id: str, share_id: Optional[str] + self, id: str, share_id: Optional[str], db: Optional[Session] = None ) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) chat.share_id = share_id db.commit() @@ -574,9 +575,9 @@ class ChatTable: except Exception: return None - def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]: + def toggle_chat_pinned_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) chat.pinned = not chat.pinned chat.updated_at = int(time.time()) @@ -586,9 +587,9 @@ class ChatTable: except Exception: return None - def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: + def toggle_chat_archive_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) chat.archived = not chat.archived chat.folder_id = None @@ -599,9 +600,9 @@ class ChatTable: except Exception: return None - def archive_all_chats_by_user_id(self, user_id: str) -> bool: + def archive_all_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) db.commit() return True @@ -614,9 +615,10 @@ class ChatTable: filter: Optional[dict] = None, skip: int = 0, limit: int = 50, + db: Optional[Session] = None, ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id, archived=True) if filter: @@ -655,8 +657,9 @@ class ChatTable: filter: Optional[dict] = None, skip: int = 0, limit: int = 50, + db: Optional[Session] = None, ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id) if not include_archived: query = query.filter_by(archived=False) @@ -695,8 +698,9 @@ class ChatTable: include_pinned: bool = False, skip: Optional[int] = None, limit: Optional[int] = None, + db: Optional[Session] = None, ) -> list[ChatTitleIdResponse]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id) if not include_folders: @@ -733,9 +737,9 @@ class ChatTable: ] def get_chat_list_by_chat_ids( - self, chat_ids: list[str], skip: int = 0, limit: int = 50 + self, chat_ids: list[str], skip: int = 0, limit: int = 50, db: Optional[Session] = None ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: all_chats = ( db.query(Chat) .filter(Chat.id.in_(chat_ids)) @@ -745,9 +749,9 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chat_by_id(self, id: str) -> Optional[ChatModel]: + def get_chat_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat_item = db.get(Chat, id) if chat_item is None: return None @@ -760,30 +764,30 @@ class ChatTable: except Exception: return None - def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: + def get_chat_by_share_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: # it is possible that the shared link was deleted. hence, # we check if the chat is still shared by checking if a chat with the share_id exists chat = db.query(Chat).filter_by(share_id=id).first() if chat: - return self.get_chat_by_id(id) + return self.get_chat_by_id(id, db=db) else: return None except Exception: return None - def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: + def get_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() return ChatModel.model_validate(chat) except Exception: return None - def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]: - with get_db() as db: + def get_chats(self, skip: int = 0, limit: int = 50, db: Optional[Session] = None) -> list[ChatModel]: + with get_db_context(db) as db: all_chats = ( db.query(Chat) # .limit(limit).offset(skip) @@ -797,8 +801,9 @@ class ChatTable: filter: Optional[dict] = None, skip: Optional[int] = None, limit: Optional[int] = None, + db: Optional[Session] = None, ) -> ChatListResponse: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id) if filter: @@ -838,8 +843,8 @@ class ChatTable: } ) - def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]: - with get_db() as db: + def get_pinned_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChatModel]: + with get_db_context(db) as db: all_chats = ( db.query(Chat) .filter_by(user_id=user_id, pinned=True, archived=False) @@ -847,8 +852,8 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: - with get_db() as db: + def get_archived_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChatModel]: + with get_db_context(db) as db: all_chats = ( db.query(Chat) .filter_by(user_id=user_id, archived=True) @@ -863,6 +868,7 @@ class ChatTable: include_archived: bool = False, skip: int = 0, limit: int = 60, + db: Optional[Session] = None, ) -> list[ChatModel]: """ Filters chats based on a search query using Python, allowing pagination using skip and limit. @@ -871,7 +877,7 @@ class ChatTable: if not search_text: return self.get_chat_list_by_user_id( - user_id, include_archived, filter={}, skip=skip, limit=limit + user_id, include_archived, filter={}, skip=skip, limit=limit, db=db ) search_text_words = search_text.split(" ") @@ -926,7 +932,7 @@ class ChatTable: search_text = " ".join(search_text_words) - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter(Chat.user_id == user_id) if is_archived is not None: @@ -1067,9 +1073,9 @@ class ChatTable: return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_folder_id_and_user_id( - self, folder_id: str, user_id: str, skip: int = 0, limit: int = 60 + self, folder_id: str, user_id: str, skip: int = 0, limit: int = 60, db: Optional[Session] = None ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id) query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) query = query.filter_by(archived=False) @@ -1085,9 +1091,9 @@ class ChatTable: return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_folder_ids_and_user_id( - self, folder_ids: list[str], user_id: str + self, folder_ids: list[str], user_id: str, db: Optional[Session] = None ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter( Chat.folder_id.in_(folder_ids), Chat.user_id == user_id ) @@ -1100,10 +1106,10 @@ class ChatTable: return [ChatModel.model_validate(chat) for chat in all_chats] def update_chat_folder_id_by_id_and_user_id( - self, id: str, user_id: str, folder_id: str + self, id: str, user_id: str, folder_id: str, db: Optional[Session] = None ) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) chat.folder_id = folder_id chat.updated_at = int(time.time()) @@ -1114,16 +1120,16 @@ class ChatTable: except Exception: return None - def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]: - with get_db() as db: + def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[TagModel]: + with get_db_context(db) as db: chat = db.get(Chat, id) tags = chat.meta.get("tags", []) return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags] def get_chat_list_by_user_id_and_tag_name( - self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50 + self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id) tag_id = tag_name.replace(" ", "_").lower() @@ -1152,13 +1158,13 @@ class ChatTable: return [ChatModel.model_validate(chat) for chat in all_chats] def add_chat_tag_by_id_and_user_id_and_tag_name( - self, id: str, user_id: str, tag_name: str + self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None ) -> Optional[ChatModel]: tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id) if tag is None: tag = Tags.insert_new_tag(tag_name, user_id) try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) tag_id = tag.id @@ -1174,8 +1180,8 @@ class ChatTable: except Exception: return None - def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int: - with get_db() as db: # Assuming `get_db()` returns a session object + def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str, db: Optional[Session] = None) -> int: + with get_db_context(db) as db: # Assuming `get_db()` returns a session object query = db.query(Chat).filter_by(user_id=user_id, archived=False) # Normalize the tag_name for consistency @@ -1210,8 +1216,8 @@ class ChatTable: return count - def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str) -> int: - with get_db() as db: + def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str, db: Optional[Session] = None) -> int: + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id) query = query.filter_by(folder_id=folder_id) @@ -1221,10 +1227,10 @@ class ChatTable: return count def delete_tag_by_id_and_user_id_and_tag_name( - self, id: str, user_id: str, tag_name: str + self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) tags = chat.meta.get("tags", []) tag_id = tag_name.replace(" ", "_").lower() @@ -1239,9 +1245,9 @@ class ChatTable: except Exception: return False - def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool: + def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) chat.meta = { **chat.meta, @@ -1253,30 +1259,30 @@ class ChatTable: except Exception: return False - def delete_chat_by_id(self, id: str) -> bool: + def delete_chat_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(id=id).delete() db.commit() - return True and self.delete_shared_chat_by_chat_id(id) + return True and self.delete_shared_chat_by_chat_id(id, db=db) except Exception: return False - def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: + def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(id=id, user_id=user_id).delete() db.commit() - return True and self.delete_shared_chat_by_chat_id(id) + return True and self.delete_shared_chat_by_chat_id(id, db=db) except Exception: return False - def delete_chats_by_user_id(self, user_id: str) -> bool: + def delete_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: - self.delete_shared_chats_by_user_id(user_id) + with get_db_context(db) as db: + self.delete_shared_chats_by_user_id(user_id, db=db) db.query(Chat).filter_by(user_id=user_id).delete() db.commit() @@ -1286,10 +1292,10 @@ class ChatTable: return False def delete_chats_by_user_id_and_folder_id( - self, user_id: str, folder_id: str + self, user_id: str, folder_id: str, db: Optional[Session] = None ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete() db.commit() @@ -1298,10 +1304,10 @@ class ChatTable: return False def move_chats_by_user_id_and_folder_id( - self, user_id: str, folder_id: str, new_folder_id: Optional[str] + self, user_id: str, folder_id: str, new_folder_id: Optional[str], db: Optional[Session] = None ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update( {"folder_id": new_folder_id} ) @@ -1311,9 +1317,9 @@ class ChatTable: except Exception: return False - def delete_shared_chats_by_user_id(self, user_id: str) -> bool: + def delete_shared_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] @@ -1325,7 +1331,7 @@ class ChatTable: return False def insert_chat_files( - self, chat_id: str, message_id: str, file_ids: list[str], user_id: str + self, chat_id: str, message_id: str, file_ids: list[str], user_id: str, db: Optional[Session] = None ) -> Optional[list[ChatFileModel]]: if not file_ids: return None @@ -1333,7 +1339,7 @@ class ChatTable: chat_message_file_ids = [ item.id for item in self.get_chat_files_by_chat_id_and_message_id( - chat_id, message_id + chat_id, message_id, db=db ) ] # Remove duplicates and existing file_ids @@ -1350,7 +1356,7 @@ class ChatTable: return None try: - with get_db() as db: + with get_db_context(db) as db: now = int(time.time()) chat_files = [ @@ -1378,9 +1384,9 @@ class ChatTable: return None def get_chat_files_by_chat_id_and_message_id( - self, chat_id: str, message_id: str + self, chat_id: str, message_id: str, db: Optional[Session] = None ) -> list[ChatFileModel]: - with get_db() as db: + with get_db_context(db) as db: all_chat_files = ( db.query(ChatFile) .filter_by(chat_id=chat_id, message_id=message_id) @@ -1391,17 +1397,17 @@ class ChatTable: ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files ] - def delete_chat_file(self, chat_id: str, file_id: str) -> bool: + def delete_chat_file(self, chat_id: str, file_id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(ChatFile).filter_by(chat_id=chat_id, file_id=file_id).delete() db.commit() return True except Exception: return False - def get_shared_chats_by_file_id(self, file_id: str) -> list[ChatModel]: - with get_db() as db: + def get_shared_chats_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChatModel]: + with get_db_context(db) as db: # Join Chat and ChatFile tables to get shared chats associated with the file_id all_chats = ( db.query(Chat) diff --git a/backend/open_webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py index 39e22ff2d9..fa2629a684 100644 --- a/backend/open_webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -3,7 +3,8 @@ import time import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.users import User from pydantic import BaseModel, ConfigDict @@ -121,9 +122,9 @@ class FeedbackListResponse(BaseModel): class FeedbackTable: def insert_new_feedback( - self, user_id: str, form_data: FeedbackForm + self, user_id: str, form_data: FeedbackForm, db: Optional[Session] = None ) -> Optional[FeedbackModel]: - with get_db() as db: + with get_db_context(db) as db: id = str(uuid.uuid4()) feedback = FeedbackModel( **{ @@ -148,9 +149,9 @@ class FeedbackTable: log.exception(f"Error creating a new feedback: {e}") return None - def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]: + def get_feedback_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FeedbackModel]: try: - with get_db() as db: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id).first() if not feedback: return None @@ -159,10 +160,10 @@ class FeedbackTable: return None def get_feedback_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[FeedbackModel]: try: - with get_db() as db: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() if not feedback: return None @@ -171,9 +172,9 @@ class FeedbackTable: return None def get_feedback_items( - self, filter: dict = {}, skip: int = 0, limit: int = 30 + self, filter: dict = {}, skip: int = 0, limit: int = 30, db: Optional[Session] = None ) -> FeedbackListResponse: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Feedback, User).join(User, Feedback.user_id == User.id) if filter: @@ -234,8 +235,8 @@ class FeedbackTable: return FeedbackListResponse(items=feedbacks, total=total) - def get_all_feedbacks(self) -> list[FeedbackModel]: - with get_db() as db: + def get_all_feedbacks(self, db: Optional[Session] = None) -> list[FeedbackModel]: + with get_db_context(db) as db: return [ FeedbackModel.model_validate(feedback) for feedback in db.query(Feedback) @@ -243,8 +244,8 @@ class FeedbackTable: .all() ] - def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]: - with get_db() as db: + def get_feedbacks_by_type(self, type: str, db: Optional[Session] = None) -> list[FeedbackModel]: + with get_db_context(db) as db: return [ FeedbackModel.model_validate(feedback) for feedback in db.query(Feedback) @@ -253,8 +254,8 @@ class FeedbackTable: .all() ] - def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]: - with get_db() as db: + def get_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FeedbackModel]: + with get_db_context(db) as db: return [ FeedbackModel.model_validate(feedback) for feedback in db.query(Feedback) @@ -264,9 +265,9 @@ class FeedbackTable: ] def update_feedback_by_id( - self, id: str, form_data: FeedbackForm + self, id: str, form_data: FeedbackForm, db: Optional[Session] = None ) -> Optional[FeedbackModel]: - with get_db() as db: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id).first() if not feedback: return None @@ -284,9 +285,9 @@ class FeedbackTable: return FeedbackModel.model_validate(feedback) def update_feedback_by_id_and_user_id( - self, id: str, user_id: str, form_data: FeedbackForm + self, id: str, user_id: str, form_data: FeedbackForm, db: Optional[Session] = None ) -> Optional[FeedbackModel]: - with get_db() as db: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() if not feedback: return None @@ -303,8 +304,8 @@ class FeedbackTable: db.commit() return FeedbackModel.model_validate(feedback) - def delete_feedback_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_feedback_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id).first() if not feedback: return False @@ -312,8 +313,8 @@ class FeedbackTable: db.commit() return True - def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool: - with get_db() as db: + def delete_feedback_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() if not feedback: return False @@ -321,8 +322,8 @@ class FeedbackTable: db.commit() return True - def delete_feedbacks_by_user_id(self, user_id: str) -> bool: - with get_db() as db: + def delete_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: feedbacks = db.query(Feedback).filter_by(user_id=user_id).all() if not feedbacks: return False @@ -331,8 +332,8 @@ class FeedbackTable: db.commit() return True - def delete_all_feedbacks(self) -> bool: - with get_db() as db: + def delete_all_feedbacks(self, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: feedbacks = db.query(Feedback).all() if not feedbacks: return False diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index 9d4e8fb054..387dd5b28c 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -2,7 +2,8 @@ import logging import time from typing import Optional -from open_webui.internal.db import Base, JSONField, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON @@ -108,8 +109,10 @@ class FileListResponse(BaseModel): class FilesTable: - def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: - with get_db() as db: + def insert_new_file( + self, user_id: str, form_data: FileForm, db: Optional[Session] = None + ) -> Optional[FileModel]: + with get_db_context(db) as db: file = FileModel( **{ **form_data.model_dump(), @@ -132,13 +135,16 @@ class FilesTable: log.exception(f"Error inserting a new file: {e}") return None - def get_file_by_id(self, id: str) -> Optional[FileModel]: - with get_db() as db: - try: - file = db.get(File, id) - return FileModel.model_validate(file) - except Exception: - return None + def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]: + try: + with get_db_context(db) as db: + try: + file = db.get(File, id) + return FileModel.model_validate(file) + except Exception: + return None + except Exception: + return None def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]: with get_db() as db: @@ -165,8 +171,8 @@ class FilesTable: except Exception: return None - def get_files(self) -> list[FileModel]: - with get_db() as db: + def get_files(self, db: Optional[Session] = None) -> list[FileModel]: + with get_db_context(db) as db: return [FileModel.model_validate(file) for file in db.query(File).all()] def check_access_by_user_id(self, id, user_id, permission="write") -> bool: @@ -206,8 +212,8 @@ class FilesTable: .all() ] - def get_files_by_user_id(self, user_id: str) -> list[FileModel]: - with get_db() as db: + def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]: + with get_db_context(db) as db: return [ FileModel.model_validate(file) for file in db.query(File).filter_by(user_id=user_id).all() @@ -271,24 +277,6 @@ class FilesTable: except Exception: return None - def delete_file_by_id(self, id: str) -> bool: - with get_db() as db: - try: - db.query(File).filter_by(id=id).delete() - db.commit() - - return True - except Exception: - return False - - def delete_all_files(self) -> bool: - with get_db() as db: - try: - db.query(File).delete() - db.commit() - - return True - except Exception: return False diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py index 0043dd3644..1744c29454 100644 --- a/backend/open_webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -7,8 +7,9 @@ import re from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func +from sqlalchemy.orm import Session -from open_webui.internal.db import Base, get_db +from open_webui.internal.db import Base, JSONField, get_db, get_db_context log = logging.getLogger(__name__) @@ -83,9 +84,9 @@ class FolderUpdateForm(BaseModel): class FolderTable: def insert_new_folder( - self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None + self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None, db: Optional[Session] = None ) -> Optional[FolderModel]: - with get_db() as db: + with get_db_context(db) as db: id = str(uuid.uuid4()) folder = FolderModel( **{ @@ -111,10 +112,10 @@ class FolderTable: return None def get_folder_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[FolderModel]: try: - with get_db() as db: + with get_db_context(db) as db: folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() if not folder: @@ -125,15 +126,15 @@ class FolderTable: return None def get_children_folders_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[list[FolderModel]]: try: - with get_db() as db: + with get_db_context(db) as db: folders = [] def get_children(folder): children = self.get_folders_by_parent_id_and_user_id( - folder.id, user_id + folder.id, user_id, db=db ) for child in children: get_children(child) @@ -148,18 +149,18 @@ class FolderTable: except Exception: return None - def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]: - with get_db() as db: + def get_folders_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FolderModel]: + with get_db_context(db) as db: return [ FolderModel.model_validate(folder) for folder in db.query(Folder).filter_by(user_id=user_id).all() ] def get_folder_by_parent_id_and_user_id_and_name( - self, parent_id: Optional[str], user_id: str, name: str + self, parent_id: Optional[str], user_id: str, name: str, db: Optional[Session] = None ) -> Optional[FolderModel]: try: - with get_db() as db: + with get_db_context(db) as db: # Check if folder exists folder = ( db.query(Folder) @@ -177,9 +178,9 @@ class FolderTable: return None def get_folders_by_parent_id_and_user_id( - self, parent_id: Optional[str], user_id: str + self, parent_id: Optional[str], user_id: str, db: Optional[Session] = None ) -> list[FolderModel]: - with get_db() as db: + with get_db_context(db) as db: return [ FolderModel.model_validate(folder) for folder in db.query(Folder) @@ -192,9 +193,10 @@ class FolderTable: id: str, user_id: str, parent_id: str, + db: Optional[Session] = None, ) -> Optional[FolderModel]: try: - with get_db() as db: + with get_db_context(db) as db: folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() if not folder: @@ -211,10 +213,10 @@ class FolderTable: return def update_folder_by_id_and_user_id( - self, id: str, user_id: str, form_data: FolderUpdateForm + self, id: str, user_id: str, form_data: FolderUpdateForm, db: Optional[Session] = None ) -> Optional[FolderModel]: try: - with get_db() as db: + with get_db_context(db) as db: folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() if not folder: @@ -257,10 +259,10 @@ class FolderTable: return def update_folder_is_expanded_by_id_and_user_id( - self, id: str, user_id: str, is_expanded: bool + self, id: str, user_id: str, is_expanded: bool, db: Optional[Session] = None ) -> Optional[FolderModel]: try: - with get_db() as db: + with get_db_context(db) as db: folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() if not folder: @@ -276,10 +278,10 @@ class FolderTable: log.error(f"update_folder: {e}") return - def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]: + def delete_folder_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[str]: try: folder_ids = [] - with get_db() as db: + with get_db_context(db) as db: folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() if not folder: return folder_ids @@ -289,7 +291,7 @@ class FolderTable: # Delete all children folders def delete_children(folder): folder_children = self.get_folders_by_parent_id_and_user_id( - folder.id, user_id + folder.id, user_id, db=db ) for folder_child in folder_children: @@ -314,7 +316,7 @@ class FolderTable: return name.strip().lower() def search_folders_by_names( - self, user_id: str, queries: list[str] + self, user_id: str, queries: list[str], db: Optional[Session] = None ) -> list[FolderModel]: """ Search for folders for a user where the name matches any of the queries, treating _ and space as equivalent, case-insensitive. @@ -324,7 +326,7 @@ class FolderTable: return [] results = {} - with get_db() as db: + with get_db_context(db) as db: folders = db.query(Folder).filter_by(user_id=user_id).all() for folder in folders: if self.normalize_folder_name(folder.name) in normalized_queries: @@ -332,7 +334,7 @@ class FolderTable: # get children folders children = self.get_children_folders_by_id_and_user_id( - folder.id, user_id + folder.id, user_id, db=db ) for child in children: results[child.id] = child @@ -345,14 +347,14 @@ class FolderTable: return results def search_folders_by_name_contains( - self, user_id: str, query: str + self, user_id: str, query: str, db: Optional[Session] = None ) -> list[FolderModel]: """ Partial match: normalized name contains (as substring) the normalized query. """ normalized_query = self.normalize_folder_name(query) results = [] - with get_db() as db: + with get_db_context(db) as db: folders = db.query(Folder).filter_by(user_id=user_id).all() for folder in folders: norm_name = self.normalize_folder_name(folder.name) diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index 19ad985d0c..f6d2aa3375 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -2,7 +2,8 @@ import logging import time from typing import Optional -from open_webui.internal.db import Base, JSONField, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.users import Users, UserModel from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index @@ -103,7 +104,7 @@ class FunctionValves(BaseModel): class FunctionsTable: def insert_new_function( - self, user_id: str, type: str, form_data: FunctionForm + self, user_id: str, type: str, form_data: FunctionForm, db: Optional[Session] = None ) -> Optional[FunctionModel]: function = FunctionModel( **{ @@ -116,7 +117,7 @@ class FunctionsTable: ) try: - with get_db() as db: + with get_db_context(db) as db: result = Function(**function.model_dump()) db.add(result) db.commit() @@ -130,11 +131,11 @@ class FunctionsTable: return None def sync_functions( - self, user_id: str, functions: list[FunctionWithValvesModel] + self, user_id: str, functions: list[FunctionWithValvesModel], db: Optional[Session] = None ) -> list[FunctionWithValvesModel]: # Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present. try: - with get_db() as db: + with get_db_context(db) as db: # Get existing functions existing_functions = db.query(Function).all() existing_ids = {func.id for func in existing_functions} @@ -177,18 +178,18 @@ class FunctionsTable: log.exception(f"Error syncing functions for user {user_id}: {e}") return [] - def get_function_by_id(self, id: str) -> Optional[FunctionModel]: + def get_function_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FunctionModel]: try: - with get_db() as db: + with get_db_context(db) as db: function = db.get(Function, id) return FunctionModel.model_validate(function) except Exception: return None def get_functions( - self, active_only=False, include_valves=False + self, active_only=False, include_valves=False, db: Optional[Session] = None ) -> list[FunctionModel | FunctionWithValvesModel]: - with get_db() as db: + with get_db_context(db) as db: if active_only: functions = db.query(Function).filter_by(is_active=True).all() @@ -205,12 +206,12 @@ class FunctionsTable: FunctionModel.model_validate(function) for function in functions ] - def get_function_list(self) -> list[FunctionUserResponse]: - with get_db() as db: + def get_function_list(self, db: Optional[Session] = None) -> list[FunctionUserResponse]: + with get_db_context(db) as db: functions = db.query(Function).order_by(Function.updated_at.desc()).all() user_ids = list(set(func.user_id for func in functions)) - users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} return [ @@ -228,9 +229,9 @@ class FunctionsTable: ] def get_functions_by_type( - self, type: str, active_only=False + self, type: str, active_only=False, db: Optional[Session] = None ) -> list[FunctionModel]: - with get_db() as db: + with get_db_context(db) as db: if active_only: return [ FunctionModel.model_validate(function) @@ -244,8 +245,8 @@ class FunctionsTable: for function in db.query(Function).filter_by(type=type).all() ] - def get_global_filter_functions(self) -> list[FunctionModel]: - with get_db() as db: + def get_global_filter_functions(self, db: Optional[Session] = None) -> list[FunctionModel]: + with get_db_context(db) as db: return [ FunctionModel.model_validate(function) for function in db.query(Function) @@ -253,8 +254,8 @@ class FunctionsTable: .all() ] - def get_global_action_functions(self) -> list[FunctionModel]: - with get_db() as db: + def get_global_action_functions(self, db: Optional[Session] = None) -> list[FunctionModel]: + with get_db_context(db) as db: return [ FunctionModel.model_validate(function) for function in db.query(Function) @@ -262,8 +263,8 @@ class FunctionsTable: .all() ] - def get_function_valves_by_id(self, id: str) -> Optional[dict]: - with get_db() as db: + def get_function_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]: + with get_db_context(db) as db: try: function = db.get(Function, id) return function.valves if function.valves else {} @@ -272,23 +273,23 @@ class FunctionsTable: return None def update_function_valves_by_id( - self, id: str, valves: dict + self, id: str, valves: dict, db: Optional[Session] = None ) -> Optional[FunctionValves]: - with get_db() as db: + with get_db_context(db) as db: try: function = db.get(Function, id) function.valves = valves function.updated_at = int(time.time()) db.commit() db.refresh(function) - return self.get_function_by_id(id) + return self.get_function_by_id(id, db=db) except Exception: return None def update_function_metadata_by_id( - self, id: str, metadata: dict + self, id: str, metadata: dict, db: Optional[Session] = None ) -> Optional[FunctionModel]: - with get_db() as db: + with get_db_context(db) as db: try: function = db.get(Function, id) @@ -301,7 +302,7 @@ class FunctionsTable: function.updated_at = int(time.time()) db.commit() db.refresh(function) - return self.get_function_by_id(id) + return self.get_function_by_id(id, db=db) else: return None except Exception as e: @@ -309,10 +310,10 @@ class FunctionsTable: return None def get_user_valves_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings @@ -327,10 +328,10 @@ class FunctionsTable: return None def update_user_valves_by_id_and_user_id( - self, id: str, user_id: str, valves: dict + self, id: str, user_id: str, valves: dict, db: Optional[Session] = None ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings @@ -342,7 +343,7 @@ class FunctionsTable: user_settings["functions"]["valves"][id] = valves # Update the user settings in the database - Users.update_user_by_id(user_id, {"settings": user_settings}) + Users.update_user_by_id(user_id, {"settings": user_settings}, db=db) return user_settings["functions"]["valves"][id] except Exception as e: @@ -351,8 +352,8 @@ class FunctionsTable: ) return None - def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: - with get_db() as db: + def update_function_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[FunctionModel]: + with get_db_context(db) as db: try: db.query(Function).filter_by(id=id).update( { @@ -361,12 +362,12 @@ class FunctionsTable: } ) db.commit() - return self.get_function_by_id(id) + return self.get_function_by_id(id, db=db) except Exception: return None - def deactivate_all_functions(self) -> Optional[bool]: - with get_db() as db: + def deactivate_all_functions(self, db: Optional[Session] = None) -> Optional[bool]: + with get_db_context(db) as db: try: db.query(Function).update( { @@ -379,8 +380,8 @@ class FunctionsTable: except Exception: return None - def delete_function_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_function_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: db.query(Function).filter_by(id=id).delete() db.commit() diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index da94287111..1ceb8f1ec0 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -4,7 +4,8 @@ import time from typing import Optional import uuid -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.files import FileMetadataResponse @@ -120,9 +121,9 @@ class GroupListResponse(BaseModel): class GroupTable: def insert_new_group( - self, user_id: str, form_data: GroupForm + self, user_id: str, form_data: GroupForm, db: Optional[Session] = None ) -> Optional[GroupModel]: - with get_db() as db: + with get_db_context(db) as db: group = GroupModel( **{ **form_data.model_dump(exclude_none=True), @@ -146,13 +147,13 @@ class GroupTable: except Exception: return None - def get_all_groups(self) -> list[GroupModel]: - with get_db() as db: + def get_all_groups(self, db: Optional[Session] = None) -> list[GroupModel]: + with get_db_context(db) as db: groups = db.query(Group).order_by(Group.updated_at.desc()).all() return [GroupModel.model_validate(group) for group in groups] - def get_groups(self, filter) -> list[GroupResponse]: - with get_db() as db: + def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse]: + with get_db_context(db) as db: query = db.query(Group) if filter: @@ -184,16 +185,16 @@ class GroupTable: GroupResponse.model_validate( { **GroupModel.model_validate(group).model_dump(), - "member_count": self.get_group_member_count_by_id(group.id), + "member_count": self.get_group_member_count_by_id(group.id, db=db), } ) for group in groups ] def search_groups( - self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30 + self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30, db: Optional[Session] = None ) -> GroupListResponse: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Group) if filter: @@ -220,15 +221,15 @@ class GroupTable: "items": [ GroupResponse.model_validate( **GroupModel.model_validate(group).model_dump(), - member_count=self.get_group_member_count_by_id(group.id), + member_count=self.get_group_member_count_by_id(group.id, db=db), ) for group in groups ], "total": total, } - def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: - with get_db() as db: + def get_groups_by_member_id(self, user_id: str, db: Optional[Session] = None) -> list[GroupModel]: + with get_db_context(db) as db: return [ GroupModel.model_validate(group) for group in db.query(Group) @@ -238,16 +239,16 @@ class GroupTable: .all() ] - def get_group_by_id(self, id: str) -> Optional[GroupModel]: + def get_group_by_id(self, id: str, db: Optional[Session] = None) -> Optional[GroupModel]: try: - with get_db() as db: + with get_db_context(db) as db: group = db.query(Group).filter_by(id=id).first() return GroupModel.model_validate(group) if group else None except Exception: return None - def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]: - with get_db() as db: + def get_group_user_ids_by_id(self, id: str, db: Optional[Session] = None) -> Optional[list[str]]: + with get_db_context(db) as db: members = ( db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all() ) @@ -257,8 +258,8 @@ class GroupTable: return [m[0] for m in members] - def get_group_user_ids_by_ids(self, group_ids: list[str]) -> dict[str, list[str]]: - with get_db() as db: + def get_group_user_ids_by_ids(self, group_ids: list[str], db: Optional[Session] = None) -> dict[str, list[str]]: + with get_db_context(db) as db: members = ( db.query(GroupMember.group_id, GroupMember.user_id) .filter(GroupMember.group_id.in_(group_ids)) @@ -274,8 +275,8 @@ class GroupTable: return group_user_ids - def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None: - with get_db() as db: + def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str], db: Optional[Session] = None) -> None: + with get_db_context(db) as db: # Delete existing members db.query(GroupMember).filter(GroupMember.group_id == group_id).delete() @@ -295,8 +296,8 @@ class GroupTable: db.add_all(new_members) db.commit() - def get_group_member_count_by_id(self, id: str) -> int: - with get_db() as db: + def get_group_member_count_by_id(self, id: str, db: Optional[Session] = None) -> int: + with get_db_context(db) as db: count = ( db.query(func.count(GroupMember.user_id)) .filter(GroupMember.group_id == id) @@ -305,10 +306,10 @@ class GroupTable: return count if count else 0 def update_group_by_id( - self, id: str, form_data: GroupUpdateForm, overwrite: bool = False + self, id: str, form_data: GroupUpdateForm, overwrite: bool = False, db: Optional[Session] = None ) -> Optional[GroupModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Group).filter_by(id=id).update( { **form_data.model_dump(exclude_none=True), @@ -316,22 +317,22 @@ class GroupTable: } ) db.commit() - return self.get_group_by_id(id=id) + return self.get_group_by_id(id=id, db=db) except Exception as e: log.exception(e) return None - def delete_group_by_id(self, id: str) -> bool: + def delete_group_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Group).filter_by(id=id).delete() db.commit() return True except Exception: return False - def delete_all_groups(self) -> bool: - with get_db() as db: + def delete_all_groups(self, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: db.query(Group).delete() db.commit() @@ -340,8 +341,8 @@ class GroupTable: except Exception: return False - def remove_user_from_all_groups(self, user_id: str) -> bool: - with get_db() as db: + def remove_user_from_all_groups(self, user_id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: # Find all groups the user belongs to groups = ( @@ -369,16 +370,16 @@ class GroupTable: return False def create_groups_by_group_names( - self, user_id: str, group_names: list[str] + self, user_id: str, group_names: list[str], db: Optional[Session] = None ) -> list[GroupModel]: # check for existing groups - existing_groups = self.get_all_groups() + existing_groups = self.get_all_groups(db=db) existing_group_names = {group.name for group in existing_groups} new_groups = [] - with get_db() as db: + with get_db_context(db) as db: for group_name in group_names: if group_name not in existing_group_names: new_group = GroupModel( @@ -400,8 +401,8 @@ class GroupTable: continue return new_groups - def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: - with get_db() as db: + def sync_groups_by_group_names(self, user_id: str, group_names: list[str], db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: now = int(time.time()) @@ -461,10 +462,10 @@ class GroupTable: return False def add_users_to_group( - self, id: str, user_ids: Optional[list[str]] = None + self, id: str, user_ids: Optional[list[str]] = None, db: Optional[Session] = None ) -> Optional[GroupModel]: try: - with get_db() as db: + with get_db_context(db) as db: group = db.query(Group).filter_by(id=id).first() if not group: return None @@ -499,10 +500,10 @@ class GroupTable: return None def remove_users_from_group( - self, id: str, user_ids: Optional[list[str]] = None + self, id: str, user_ids: Optional[list[str]] = None, db: Optional[Session] = None ) -> Optional[GroupModel]: try: - with get_db() as db: + with get_db_context(db) as db: group = db.query(Group).filter_by(id=id).first() if not group: return None diff --git a/backend/open_webui/models/knowledge.py b/backend/open_webui/models/knowledge.py index d8af004338..348f0de64b 100644 --- a/backend/open_webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -1,10 +1,12 @@ + import json import logging import time from typing import Optional import uuid -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.files import ( File, @@ -157,9 +159,9 @@ class KnowledgeFileListResponse(BaseModel): class KnowledgeTable: def insert_new_knowledge( - self, user_id: str, form_data: KnowledgeForm + self, user_id: str, form_data: KnowledgeForm, db: Optional[Session] = None ) -> Optional[KnowledgeModel]: - with get_db() as db: + with get_db_context(db) as db: knowledge = KnowledgeModel( **{ **form_data.model_dump(), @@ -183,15 +185,15 @@ class KnowledgeTable: return None def get_knowledge_bases( - self, skip: int = 0, limit: int = 30 + self, skip: int = 0, limit: int = 30, db: Optional[Session] = None ) -> list[KnowledgeUserModel]: - with get_db() as db: + with get_db_context(db) as db: all_knowledge = ( db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() ) user_ids = list(set(knowledge.user_id for knowledge in all_knowledge)) - users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} knowledge_bases = [] @@ -208,10 +210,10 @@ class KnowledgeTable: return knowledge_bases def search_knowledge_bases( - self, user_id: str, filter: dict, skip: int = 0, limit: int = 30 + self, user_id: str, filter: dict, skip: int = 0, limit: int = 30, db: Optional[Session] = None ) -> KnowledgeListResponse: try: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Knowledge, User).outerjoin( User, User.id == Knowledge.user_id ) @@ -267,14 +269,14 @@ class KnowledgeTable: return KnowledgeListResponse(items=[], total=0) def search_knowledge_files( - self, filter: dict, skip: int = 0, limit: int = 30 + self, filter: dict, skip: int = 0, limit: int = 30, db: Optional[Session] = None ) -> KnowledgeFileListResponse: """ Scalable version: search files across all knowledge bases the user has READ access to, without loading all KBs or using large IN() lists. """ try: - with get_db() as db: + with get_db_context(db) as db: # Base query: join Knowledge → KnowledgeFile → File query = ( db.query(File, User) @@ -327,20 +329,20 @@ class KnowledgeTable: print("search_knowledge_files error:", e) return KnowledgeFileListResponse(items=[], total=0) - def check_access_by_user_id(self, id, user_id, permission="write") -> bool: - knowledge = self.get_knowledge_by_id(id) + def check_access_by_user_id(self, id, user_id, permission="write", db: Optional[Session] = None) -> bool: + knowledge = self.get_knowledge_by_id(id, db=db) if not knowledge: return False if knowledge.user_id == user_id: return True - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return has_access(user_id, permission, knowledge.access_control, user_group_ids) def get_knowledge_bases_by_user_id( - self, user_id: str, permission: str = "write" + self, user_id: str, permission: str = "write", db: Optional[Session] = None ) -> list[KnowledgeUserModel]: - knowledge_bases = self.get_knowledge_bases() - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + knowledge_bases = self.get_knowledge_bases(db=db) + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ knowledge_base for knowledge_base in knowledge_bases @@ -350,32 +352,32 @@ class KnowledgeTable: ) ] - def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: + def get_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]: try: - with get_db() as db: + with get_db_context(db) as db: knowledge = db.query(Knowledge).filter_by(id=id).first() return KnowledgeModel.model_validate(knowledge) if knowledge else None except Exception: return None def get_knowledge_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[KnowledgeModel]: - knowledge = self.get_knowledge_by_id(id) + knowledge = self.get_knowledge_by_id(id, db=db) if not knowledge: return None if knowledge.user_id == user_id: return knowledge - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} if has_access(user_id, "write", knowledge.access_control, user_group_ids): return knowledge return None - def get_knowledges_by_file_id(self, file_id: str) -> list[KnowledgeModel]: + def get_knowledges_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[KnowledgeModel]: try: - with get_db() as db: + with get_db_context(db) as db: knowledges = ( db.query(Knowledge) .join(KnowledgeFile, Knowledge.id == KnowledgeFile.knowledge_id) @@ -395,9 +397,10 @@ class KnowledgeTable: filter: dict, skip: int = 0, limit: int = 30, + db: Optional[Session] = None, ) -> KnowledgeFileListResponse: try: - with get_db() as db: + with get_db_context(db) as db: query = ( db.query(File, User) .join(KnowledgeFile, File.id == KnowledgeFile.file_id) @@ -470,9 +473,9 @@ class KnowledgeTable: print(e) return KnowledgeFileListResponse(items=[], total=0) - def get_files_by_id(self, knowledge_id: str) -> list[FileModel]: + def get_files_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileModel]: try: - with get_db() as db: + with get_db_context(db) as db: files = ( db.query(File) .join(KnowledgeFile, File.id == KnowledgeFile.file_id) @@ -483,18 +486,18 @@ class KnowledgeTable: except Exception: return [] - def get_file_metadatas_by_id(self, knowledge_id: str) -> list[FileMetadataResponse]: + def get_file_metadatas_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileMetadataResponse]: try: - with get_db() as db: - files = self.get_files_by_id(knowledge_id) + with get_db_context(db) as db: + files = self.get_files_by_id(knowledge_id, db=db) return [FileMetadataResponse(**file.model_dump()) for file in files] except Exception: return [] def add_file_to_knowledge_by_id( - self, knowledge_id: str, file_id: str, user_id: str + self, knowledge_id: str, file_id: str, user_id: str, db: Optional[Session] = None ) -> Optional[KnowledgeFileModel]: - with get_db() as db: + with get_db_context(db) as db: knowledge_file = KnowledgeFileModel( **{ "id": str(uuid.uuid4()), @@ -518,9 +521,9 @@ class KnowledgeTable: except Exception: return None - def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str) -> bool: + def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(KnowledgeFile).filter_by( knowledge_id=knowledge_id, file_id=file_id ).delete() @@ -529,9 +532,9 @@ class KnowledgeTable: except Exception: return False - def reset_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: + def reset_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]: try: - with get_db() as db: + with get_db_context(db) as db: # Delete all knowledge_file entries for this knowledge_id db.query(KnowledgeFile).filter_by(knowledge_id=id).delete() db.commit() @@ -544,17 +547,17 @@ class KnowledgeTable: ) db.commit() - return self.get_knowledge_by_id(id=id) + return self.get_knowledge_by_id(id=id, db=db) except Exception as e: log.exception(e) return None def update_knowledge_by_id( - self, id: str, form_data: KnowledgeForm, overwrite: bool = False + self, id: str, form_data: KnowledgeForm, overwrite: bool = False, db: Optional[Session] = None ) -> Optional[KnowledgeModel]: try: - with get_db() as db: - knowledge = self.get_knowledge_by_id(id=id) + with get_db_context(db) as db: + knowledge = self.get_knowledge_by_id(id=id, db=db) db.query(Knowledge).filter_by(id=id).update( { **form_data.model_dump(), @@ -562,17 +565,17 @@ class KnowledgeTable: } ) db.commit() - return self.get_knowledge_by_id(id=id) + return self.get_knowledge_by_id(id=id, db=db) except Exception as e: log.exception(e) return None def update_knowledge_data_by_id( - self, id: str, data: dict + self, id: str, data: dict, db: Optional[Session] = None ) -> Optional[KnowledgeModel]: try: - with get_db() as db: - knowledge = self.get_knowledge_by_id(id=id) + with get_db_context(db) as db: + knowledge = self.get_knowledge_by_id(id=id, db=db) db.query(Knowledge).filter_by(id=id).update( { "data": data, @@ -580,22 +583,22 @@ class KnowledgeTable: } ) db.commit() - return self.get_knowledge_by_id(id=id) + return self.get_knowledge_by_id(id=id, db=db) except Exception as e: log.exception(e) return None - def delete_knowledge_by_id(self, id: str) -> bool: + def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Knowledge).filter_by(id=id).delete() db.commit() return True except Exception: return False - def delete_all_knowledge(self) -> bool: - with get_db() as db: + def delete_all_knowledge(self, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: db.query(Knowledge).delete() db.commit() diff --git a/backend/open_webui/models/memories.py b/backend/open_webui/models/memories.py index f5f2492b99..c0faed2c62 100644 --- a/backend/open_webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -2,7 +2,8 @@ import time import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, get_db, get_db_context from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text @@ -41,8 +42,9 @@ class MemoriesTable: self, user_id: str, content: str, + db: Optional[Session] = None, ) -> Optional[MemoryModel]: - with get_db() as db: + with get_db_context(db) as db: id = str(uuid.uuid4()) memory = MemoryModel( @@ -68,8 +70,9 @@ class MemoriesTable: id: str, user_id: str, content: str, + db: Optional[Session] = None, ) -> Optional[MemoryModel]: - with get_db() as db: + with get_db_context(db) as db: try: memory = db.get(Memory, id) if not memory or memory.user_id != user_id: @@ -83,32 +86,32 @@ class MemoriesTable: except Exception: return None - def get_memories(self) -> list[MemoryModel]: - with get_db() as db: + def get_memories(self, db: Optional[Session] = None) -> list[MemoryModel]: + with get_db_context(db) as db: try: memories = db.query(Memory).all() return [MemoryModel.model_validate(memory) for memory in memories] except Exception: return None - def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]: - with get_db() as db: + def get_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[MemoryModel]: + with get_db_context(db) as db: try: memories = db.query(Memory).filter_by(user_id=user_id).all() return [MemoryModel.model_validate(memory) for memory in memories] except Exception: return None - def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: - with get_db() as db: + def get_memory_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MemoryModel]: + with get_db_context(db) as db: try: memory = db.get(Memory, id) return MemoryModel.model_validate(memory) except Exception: return None - def delete_memory_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_memory_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: db.query(Memory).filter_by(id=id).delete() db.commit() @@ -118,8 +121,8 @@ class MemoriesTable: except Exception: return False - def delete_memories_by_user_id(self, user_id: str) -> bool: - with get_db() as db: + def delete_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: db.query(Memory).filter_by(user_id=user_id).delete() db.commit() @@ -128,8 +131,8 @@ class MemoriesTable: except Exception: return False - def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: - with get_db() as db: + def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: memory = db.get(Memory, id) if not memory or memory.user_id != user_id: diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index 5b068b6449..8fe37e58ea 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -3,7 +3,8 @@ import time import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.tags import TagModel, Tag, Tags from open_webui.models.users import Users, User, UserNameResponse from open_webui.models.channels import Channels, ChannelMember @@ -137,9 +138,9 @@ class MessageResponse(MessageReplyToResponse): class MessageTable: def insert_new_message( - self, form_data: MessageForm, channel_id: str, user_id: str + self, form_data: MessageForm, channel_id: str, user_id: str, db: Optional[Session] = None ) -> Optional[MessageModel]: - with get_db() as db: + with get_db_context(db) as db: channel_member = Channels.join_channel(channel_id, user_id) id = str(uuid.uuid4()) @@ -169,22 +170,22 @@ class MessageTable: db.refresh(result) return MessageModel.model_validate(result) if result else None - def get_message_by_id(self, id: str) -> Optional[MessageResponse]: - with get_db() as db: + def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MessageResponse]: + with get_db_context(db) as db: message = db.get(Message, id) if not message: return None reply_to_message = ( - self.get_message_by_id(message.reply_to_id) + self.get_message_by_id(message.reply_to_id, db=db) if message.reply_to_id else None ) - reactions = self.get_reactions_by_message_id(id) - thread_replies = self.get_thread_replies_by_message_id(id) + reactions = self.get_reactions_by_message_id(id, db=db) + thread_replies = self.get_thread_replies_by_message_id(id, db=db) - user = Users.get_user_by_id(message.user_id) + user = Users.get_user_by_id(message.user_id, db=db) return MessageResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), @@ -200,8 +201,8 @@ class MessageTable: } ) - def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]: - with get_db() as db: + def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]: + with get_db_context(db) as db: all_messages = ( db.query(Message) .filter_by(parent_id=id) @@ -212,7 +213,7 @@ class MessageTable: messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id(message.reply_to_id) + self.get_message_by_id(message.reply_to_id, db=db) if message.reply_to_id else None ) @@ -230,17 +231,17 @@ class MessageTable: ) return messages - def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: - with get_db() as db: + def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]: + with get_db_context(db) as db: return [ message.user_id for message in db.query(Message).filter_by(parent_id=id).all() ] def get_messages_by_channel_id( - self, channel_id: str, skip: int = 0, limit: int = 50 + self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None ) -> list[MessageReplyToResponse]: - with get_db() as db: + with get_db_context(db) as db: all_messages = ( db.query(Message) .filter_by(channel_id=channel_id, parent_id=None) @@ -253,7 +254,7 @@ class MessageTable: messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id(message.reply_to_id) + self.get_message_by_id(message.reply_to_id, db=db) if message.reply_to_id else None ) @@ -272,9 +273,9 @@ class MessageTable: return messages def get_messages_by_parent_id( - self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 + self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None ) -> list[MessageReplyToResponse]: - with get_db() as db: + with get_db_context(db) as db: message = db.get(Message, parent_id) if not message: @@ -296,7 +297,7 @@ class MessageTable: messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id(message.reply_to_id) + self.get_message_by_id(message.reply_to_id, db=db) if message.reply_to_id else None ) @@ -314,8 +315,8 @@ class MessageTable: ) return messages - def get_last_message_by_channel_id(self, channel_id: str) -> Optional[MessageModel]: - with get_db() as db: + def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]: + with get_db_context(db) as db: message = ( db.query(Message) .filter_by(channel_id=channel_id) @@ -325,9 +326,9 @@ class MessageTable: return MessageModel.model_validate(message) if message else None def get_pinned_messages_by_channel_id( - self, channel_id: str, skip: int = 0, limit: int = 50 + self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None ) -> list[MessageModel]: - with get_db() as db: + with get_db_context(db) as db: all_messages = ( db.query(Message) .filter_by(channel_id=channel_id, is_pinned=True) @@ -339,9 +340,9 @@ class MessageTable: return [MessageModel.model_validate(message) for message in all_messages] def update_message_by_id( - self, id: str, form_data: MessageForm + self, id: str, form_data: MessageForm, db: Optional[Session] = None ) -> Optional[MessageModel]: - with get_db() as db: + with get_db_context(db) as db: message = db.get(Message, id) message.content = form_data.content message.data = { @@ -358,9 +359,9 @@ class MessageTable: return MessageModel.model_validate(message) if message else None def update_is_pinned_by_id( - self, id: str, is_pinned: bool, pinned_by: Optional[str] = None + self, id: str, is_pinned: bool, pinned_by: Optional[str] = None, db: Optional[Session] = None ) -> Optional[MessageModel]: - with get_db() as db: + with get_db_context(db) as db: message = db.get(Message, id) message.is_pinned = is_pinned message.pinned_at = int(time.time_ns()) if is_pinned else None @@ -370,9 +371,9 @@ class MessageTable: return MessageModel.model_validate(message) if message else None def get_unread_message_count( - self, channel_id: str, user_id: str, last_read_at: Optional[int] = None + self, channel_id: str, user_id: str, last_read_at: Optional[int] = None, db: Optional[Session] = None ) -> int: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Message).filter( Message.channel_id == channel_id, Message.parent_id == None, # only count top-level messages @@ -383,9 +384,9 @@ class MessageTable: return query.count() def add_reaction_to_message( - self, id: str, user_id: str, name: str + self, id: str, user_id: str, name: str, db: Optional[Session] = None ) -> Optional[MessageReactionModel]: - with get_db() as db: + with get_db_context(db) as db: # check for existing reaction existing_reaction = ( db.query(MessageReaction) @@ -409,8 +410,8 @@ class MessageTable: db.refresh(result) return MessageReactionModel.model_validate(result) if result else None - def get_reactions_by_message_id(self, id: str) -> list[Reactions]: - with get_db() as db: + def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]: + with get_db_context(db) as db: # JOIN User so all user info is fetched in one query results = ( db.query(MessageReaction, User) @@ -440,29 +441,29 @@ class MessageTable: return [Reactions(**reaction) for reaction in reactions.values()] def remove_reaction_by_id_and_user_id_and_name( - self, id: str, user_id: str, name: str + self, id: str, user_id: str, name: str, db: Optional[Session] = None ) -> bool: - with get_db() as db: + with get_db_context(db) as db: db.query(MessageReaction).filter_by( message_id=id, user_id=user_id, name=name ).delete() db.commit() return True - def delete_reactions_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_reactions_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: db.query(MessageReaction).filter_by(message_id=id).delete() db.commit() return True - def delete_replies_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_replies_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: db.query(Message).filter_by(parent_id=id).delete() db.commit() return True - def delete_message_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_message_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: db.query(Message).filter_by(id=id).delete() # Delete all reactions to this message diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 08dcfe87a1..adb82755d5 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -2,7 +2,8 @@ import logging import time from typing import Optional -from open_webui.internal.db import Base, JSONField, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.groups import Groups from open_webui.models.users import User, UserModel, Users, UserResponse @@ -150,7 +151,7 @@ class ModelForm(BaseModel): class ModelsTable: def insert_new_model( - self, form_data: ModelForm, user_id: str + self, form_data: ModelForm, user_id: str, db: Optional[Session] = None ) -> Optional[ModelModel]: model = ModelModel( **{ @@ -161,7 +162,7 @@ class ModelsTable: } ) try: - with get_db() as db: + with get_db_context(db) as db: result = Model(**model.model_dump()) db.add(result) db.commit() @@ -175,17 +176,17 @@ class ModelsTable: log.exception(f"Failed to insert a new model: {e}") return None - def get_all_models(self) -> list[ModelModel]: - with get_db() as db: + def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]: + with get_db_context(db) as db: return [ModelModel.model_validate(model) for model in db.query(Model).all()] - def get_models(self) -> list[ModelUserResponse]: - with get_db() as db: + def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]: + with get_db_context(db) as db: all_models = db.query(Model).filter(Model.base_model_id != None).all() user_ids = list(set(model.user_id for model in all_models)) - users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} models = [] @@ -201,18 +202,18 @@ class ModelsTable: ) return models - def get_base_models(self) -> list[ModelModel]: - with get_db() as db: + def get_base_models(self, db: Optional[Session] = None) -> list[ModelModel]: + with get_db_context(db) as db: return [ ModelModel.model_validate(model) for model in db.query(Model).filter(Model.base_model_id == None).all() ] def get_models_by_user_id( - self, user_id: str, permission: str = "write" + self, user_id: str, permission: str = "write", db: Optional[Session] = None ) -> list[ModelUserResponse]: - models = self.get_models() - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + models = self.get_models(db=db) + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ model for model in models @@ -263,9 +264,9 @@ class ModelsTable: return query def search_models( - self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30 + self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30, db: Optional[Session] = None ) -> ModelListResponse: - with get_db() as db: + with get_db_context(db) as db: # Join GroupMember so we can order by group_id when requested query = db.query(Model, User).outerjoin(User, User.id == Model.user_id) query = query.filter(Model.base_model_id != None) @@ -349,24 +350,24 @@ class ModelsTable: return ModelListResponse(items=models, total=total) - def get_model_by_id(self, id: str) -> Optional[ModelModel]: + def get_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]: try: - with get_db() as db: + with get_db_context(db) as db: model = db.get(Model, id) return ModelModel.model_validate(model) except Exception: return None - def get_models_by_ids(self, ids: list[str]) -> list[ModelModel]: + def get_models_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[ModelModel]: try: - with get_db() as db: + with get_db_context(db) as db: models = db.query(Model).filter(Model.id.in_(ids)).all() return [ModelModel.model_validate(model) for model in models] except Exception: return [] - def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: - with get_db() as db: + def toggle_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]: + with get_db_context(db) as db: try: is_active = db.query(Model).filter_by(id=id).first().is_active @@ -378,13 +379,13 @@ class ModelsTable: ) db.commit() - return self.get_model_by_id(id) + return self.get_model_by_id(id, db=db) except Exception: return None - def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: + def update_model_by_id(self, id: str, model: ModelForm, db: Optional[Session] = None) -> Optional[ModelModel]: try: - with get_db() as db: + with get_db_context(db) as db: # update only the fields that are present in the model data = model.model_dump(exclude={"id"}) result = db.query(Model).filter_by(id=id).update(data) @@ -398,9 +399,9 @@ class ModelsTable: log.exception(f"Failed to update the model by id {id}: {e}") return None - def delete_model_by_id(self, id: str) -> bool: + def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Model).filter_by(id=id).delete() db.commit() @@ -408,9 +409,9 @@ class ModelsTable: except Exception: return False - def delete_all_models(self) -> bool: + def delete_all_models(self, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Model).delete() db.commit() @@ -418,9 +419,9 @@ class ModelsTable: except Exception: return False - def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]: + def sync_models(self, user_id: str, models: list[ModelModel], db: Optional[Session] = None) -> list[ModelModel]: try: - with get_db() as db: + with get_db_context(db) as db: # Get existing models existing_models = db.query(Model).all() existing_ids = {model.id for model in existing_models} diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index cfeddf4a8c..fef9db7bf4 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -4,7 +4,8 @@ import uuid from typing import Optional from functools import lru_cache -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, get_db, get_db_context from open_webui.models.groups import Groups from open_webui.utils.access_control import has_access from open_webui.models.users import User, UserModel, Users, UserResponse @@ -211,11 +212,9 @@ class NoteTable: return query def insert_new_note( - self, - form_data: NoteForm, - user_id: str, + self, user_id: str, form_data: NoteForm, db: Optional[Session] = None ) -> Optional[NoteModel]: - with get_db() as db: + with get_db_context(db) as db: note = NoteModel( **{ "id": str(uuid.uuid4()), @@ -233,9 +232,9 @@ class NoteTable: return note def get_notes( - self, skip: Optional[int] = None, limit: Optional[int] = None + self, skip: int = 0, limit: int = 50, db: Optional[Session] = None ) -> list[NoteModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Note).order_by(Note.updated_at.desc()) if skip is not None: query = query.offset(skip) @@ -333,10 +332,11 @@ class NoteTable: self, user_id: str, permission: str = "read", - skip: Optional[int] = None, - limit: Optional[int] = None, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, ) -> list[NoteModel]: - with get_db() as db: + with get_db_context(db) as db: user_group_ids = [ group.id for group in Groups.get_groups_by_member_id(user_id) ] @@ -354,15 +354,17 @@ class NoteTable: notes = query.all() return [NoteModel.model_validate(note) for note in notes] - def get_note_by_id(self, id: str) -> Optional[NoteModel]: - with get_db() as db: + def get_note_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[NoteModel]: + with get_db_context(db) as db: note = db.query(Note).filter(Note.id == id).first() return NoteModel.model_validate(note) if note else None def update_note_by_id( - self, id: str, form_data: NoteUpdateForm + self, id: str, form_data: NoteUpdateForm, db: Optional[Session] = None ) -> Optional[NoteModel]: - with get_db() as db: + with get_db_context(db) as db: note = db.query(Note).filter(Note.id == id).first() if not note: return None @@ -384,11 +386,14 @@ class NoteTable: db.commit() return NoteModel.model_validate(note) if note else None - def delete_note_by_id(self, id: str): - with get_db() as db: - db.query(Note).filter(Note.id == id).delete() - db.commit() - return True + def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool: + try: + with get_db_context(db) as db: + db.query(Note).filter(Note.id == id).delete() + db.commit() + return True + except Exception: + return False Notes = NoteTable() diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py index 8b0334ed19..4221987faa 100644 --- a/backend/open_webui/models/oauth_sessions.py +++ b/backend/open_webui/models/oauth_sessions.py @@ -8,7 +8,8 @@ import json from cryptography.fernet import Fernet -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, get_db, get_db_context from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY from pydantic import BaseModel, ConfigDict @@ -109,10 +110,11 @@ class OAuthSessionTable: user_id: str, provider: str, token: dict, + db: Optional[Session] = None, ) -> Optional[OAuthSessionModel]: """Create a new OAuth session""" try: - with get_db() as db: + with get_db_context(db) as db: current_time = int(time.time()) id = str(uuid.uuid4()) @@ -141,10 +143,10 @@ class OAuthSessionTable: log.error(f"Error creating OAuth session: {e}") return None - def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]: + def get_session_by_id(self, session_id: str, db: Optional[Session] = None) -> Optional[OAuthSessionModel]: """Get OAuth session by ID""" try: - with get_db() as db: + with get_db_context(db) as db: session = db.query(OAuthSession).filter_by(id=session_id).first() if session: session.token = self._decrypt_token(session.token) @@ -156,11 +158,11 @@ class OAuthSessionTable: return None def get_session_by_id_and_user_id( - self, session_id: str, user_id: str + self, session_id: str, user_id: str, db: Optional[Session] = None ) -> Optional[OAuthSessionModel]: """Get OAuth session by ID and user ID""" try: - with get_db() as db: + with get_db_context(db) as db: session = ( db.query(OAuthSession) .filter_by(id=session_id, user_id=user_id) @@ -176,11 +178,11 @@ class OAuthSessionTable: return None def get_session_by_provider_and_user_id( - self, provider: str, user_id: str + self, provider: str, user_id: str, db: Optional[Session] = None ) -> Optional[OAuthSessionModel]: """Get OAuth session by provider and user ID""" try: - with get_db() as db: + with get_db_context(db) as db: session = ( db.query(OAuthSession) .filter_by(provider=provider, user_id=user_id) @@ -195,10 +197,10 @@ class OAuthSessionTable: log.error(f"Error getting OAuth session by provider and user ID: {e}") return None - def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]: + def get_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> List[OAuthSessionModel]: """Get all OAuth sessions for a user""" try: - with get_db() as db: + with get_db_context(db) as db: sessions = db.query(OAuthSession).filter_by(user_id=user_id).all() results = [] @@ -213,11 +215,11 @@ class OAuthSessionTable: return [] def update_session_by_id( - self, session_id: str, token: dict + self, session_id: str, token: dict, db: Optional[Session] = None ) -> Optional[OAuthSessionModel]: """Update OAuth session tokens""" try: - with get_db() as db: + with get_db_context(db) as db: current_time = int(time.time()) db.query(OAuthSession).filter_by(id=session_id).update( @@ -239,10 +241,10 @@ class OAuthSessionTable: log.error(f"Error updating OAuth session tokens: {e}") return None - def delete_session_by_id(self, session_id: str) -> bool: + def delete_session_by_id(self, session_id: str, db: Optional[Session] = None) -> bool: """Delete an OAuth session""" try: - with get_db() as db: + with get_db_context(db) as db: result = db.query(OAuthSession).filter_by(id=session_id).delete() db.commit() return result > 0 @@ -250,10 +252,10 @@ class OAuthSessionTable: log.error(f"Error deleting OAuth session: {e}") return False - def delete_sessions_by_user_id(self, user_id: str) -> bool: + def delete_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool: """Delete all OAuth sessions for a user""" try: - with get_db() as db: + with get_db_context(db) as db: result = db.query(OAuthSession).filter_by(user_id=user_id).delete() db.commit() return True @@ -261,10 +263,10 @@ class OAuthSessionTable: log.error(f"Error deleting OAuth sessions by user ID: {e}") return False - def delete_sessions_by_provider(self, provider: str) -> bool: + def delete_sessions_by_provider(self, provider: str, db: Optional[Session] = None) -> bool: """Delete all OAuth sessions for a provider""" try: - with get_db() as db: + with get_db_context(db) as db: db.query(OAuthSession).filter_by(provider=provider).delete() db.commit() return True diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index 7502f34ccd..eefb74de53 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -1,7 +1,9 @@ + import time from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.groups import Groups from open_webui.models.users import Users, UserResponse @@ -71,7 +73,7 @@ class PromptForm(BaseModel): class PromptsTable: def insert_new_prompt( - self, user_id: str, form_data: PromptForm + self, user_id: str, form_data: PromptForm, db: Optional[Session] = None ) -> Optional[PromptModel]: prompt = PromptModel( **{ @@ -82,7 +84,7 @@ class PromptsTable: ) try: - with get_db() as db: + with get_db_context(db) as db: result = Prompt(**prompt.model_dump()) db.add(result) db.commit() @@ -94,21 +96,21 @@ class PromptsTable: except Exception: return None - def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: + def get_prompt_by_command(self, command: str, db: Optional[Session] = None) -> Optional[PromptModel]: try: - with get_db() as db: + with get_db_context(db) as db: prompt = db.query(Prompt).filter_by(command=command).first() return PromptModel.model_validate(prompt) except Exception: return None - def get_prompts(self) -> list[PromptUserResponse]: - with get_db() as db: + def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]: + with get_db_context(db) as db: all_prompts = db.query(Prompt).order_by(Prompt.timestamp.desc()).all() user_ids = list(set(prompt.user_id for prompt in all_prompts)) - users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} prompts = [] @@ -126,10 +128,10 @@ class PromptsTable: return prompts def get_prompts_by_user_id( - self, user_id: str, permission: str = "write" + self, user_id: str, permission: str = "write", db: Optional[Session] = None ) -> list[PromptUserResponse]: - prompts = self.get_prompts() - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + prompts = self.get_prompts(db=db) + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ prompt @@ -139,10 +141,10 @@ class PromptsTable: ] def update_prompt_by_command( - self, command: str, form_data: PromptForm + self, command: str, form_data: PromptForm, db: Optional[Session] = None ) -> Optional[PromptModel]: try: - with get_db() as db: + with get_db_context(db) as db: prompt = db.query(Prompt).filter_by(command=command).first() prompt.title = form_data.title prompt.content = form_data.content @@ -153,9 +155,9 @@ class PromptsTable: except Exception: return None - def delete_prompt_by_command(self, command: str) -> bool: + def delete_prompt_by_command(self, command: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Prompt).filter_by(command=command).delete() db.commit() diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index 499f3859dc..950c475ba3 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -3,7 +3,8 @@ import time import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from pydantic import BaseModel, ConfigDict @@ -50,8 +51,8 @@ class TagChatIdForm(BaseModel): class TagTable: - def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: - with get_db() as db: + def insert_new_tag(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]: + with get_db_context(db) as db: id = name.replace(" ", "_").lower() tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: @@ -68,27 +69,27 @@ class TagTable: return None def get_tag_by_name_and_user_id( - self, name: str, user_id: str + self, name: str, user_id: str, db: Optional[Session] = None ) -> Optional[TagModel]: try: id = name.replace(" ", "_").lower() - with get_db() as db: + with get_db_context(db) as db: tag = db.query(Tag).filter_by(id=id, user_id=user_id).first() return TagModel.model_validate(tag) except Exception: return None - def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: - with get_db() as db: + def get_tags_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[TagModel]: + with get_db_context(db) as db: return [ TagModel.model_validate(tag) for tag in (db.query(Tag).filter_by(user_id=user_id).all()) ] def get_tags_by_ids_and_user_id( - self, ids: list[str], user_id: str + self, ids: list[str], user_id: str, db: Optional[Session] = None ) -> list[TagModel]: - with get_db() as db: + with get_db_context(db) as db: return [ TagModel.model_validate(tag) for tag in ( @@ -96,9 +97,9 @@ class TagTable: ) ] - def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool: + def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: id = name.replace(" ", "_").lower() res = db.query(Tag).filter_by(id=id, user_id=user_id).delete() log.debug(f"res: {res}") diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index fff53a7e94..cbd17cf5a7 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -2,7 +2,8 @@ import logging import time from typing import Optional -from open_webui.internal.db import Base, JSONField, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.users import Users, UserResponse from open_webui.models.groups import Groups @@ -110,9 +111,9 @@ class ToolValves(BaseModel): class ToolsTable: def insert_new_tool( - self, user_id: str, form_data: ToolForm, specs: list[dict] + self, user_id: str, form_data: ToolForm, specs: list[dict], db: Optional[Session] = None ) -> Optional[ToolModel]: - with get_db() as db: + with get_db_context(db) as db: tool = ToolModel( **{ **form_data.model_dump(), @@ -136,21 +137,21 @@ class ToolsTable: log.exception(f"Error creating a new tool: {e}") return None - def get_tool_by_id(self, id: str) -> Optional[ToolModel]: + def get_tool_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ToolModel]: try: - with get_db() as db: + with get_db_context(db) as db: tool = db.get(Tool, id) return ToolModel.model_validate(tool) except Exception: return None - def get_tools(self) -> list[ToolUserModel]: - with get_db() as db: + def get_tools(self, db: Optional[Session] = None) -> list[ToolUserModel]: + with get_db_context(db) as db: all_tools = db.query(Tool).order_by(Tool.updated_at.desc()).all() user_ids = list(set(tool.user_id for tool in all_tools)) - users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} tools = [] @@ -167,10 +168,10 @@ class ToolsTable: return tools def get_tools_by_user_id( - self, user_id: str, permission: str = "write" + self, user_id: str, permission: str = "write", db: Optional[Session] = None ) -> list[ToolUserModel]: - tools = self.get_tools() - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + tools = self.get_tools(db=db) + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)} return [ tool @@ -179,31 +180,31 @@ class ToolsTable: or has_access(user_id, permission, tool.access_control, user_group_ids) ] - def get_tool_valves_by_id(self, id: str) -> Optional[dict]: + def get_tool_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]: try: - with get_db() as db: + with get_db_context(db) as db: tool = db.get(Tool, id) return tool.valves if tool.valves else {} except Exception as e: log.exception(f"Error getting tool valves by id {id}") return None - def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: + def update_tool_valves_by_id(self, id: str, valves: dict, db: Optional[Session] = None) -> Optional[ToolValves]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Tool).filter_by(id=id).update( {"valves": valves, "updated_at": int(time.time())} ) db.commit() - return self.get_tool_by_id(id) + return self.get_tool_by_id(id, db=db) except Exception: return None def get_user_valves_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings @@ -220,10 +221,10 @@ class ToolsTable: return None def update_user_valves_by_id_and_user_id( - self, id: str, user_id: str, valves: dict + self, id: str, user_id: str, valves: dict, db: Optional[Session] = None ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings @@ -235,7 +236,7 @@ class ToolsTable: user_settings["tools"]["valves"][id] = valves # Update the user settings in the database - Users.update_user_by_id(user_id, {"settings": user_settings}) + Users.update_user_by_id(user_id, {"settings": user_settings}, db=db) return user_settings["tools"]["valves"][id] except Exception as e: @@ -244,9 +245,9 @@ class ToolsTable: ) return None - def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: + def update_tool_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[ToolModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Tool).filter_by(id=id).update( {**updated, "updated_at": int(time.time())} ) @@ -258,9 +259,9 @@ class ToolsTable: except Exception: return None - def delete_tool_by_id(self, id: str) -> bool: + def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Tool).filter_by(id=id).delete() db.commit() diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 5807603a89..b48e7931ba 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -1,7 +1,8 @@ import time from typing import Optional -from open_webui.internal.db import Base, JSONField, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL @@ -243,8 +244,9 @@ class UsersTable: profile_image_url: str = "/user.png", role: str = "pending", oauth: Optional[dict] = None, + db: Optional[Session] = None, ) -> Optional[UserModel]: - with get_db() as db: + with get_db_context(db) as db: user = UserModel( **{ "id": id, @@ -267,17 +269,17 @@ class UsersTable: else: return None - def get_user_by_id(self, id: str) -> Optional[UserModel]: + def get_user_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) except Exception: return None - def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + def get_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: user = ( db.query(User) .join(ApiKey, User.id == ApiKey.user_id) @@ -288,17 +290,17 @@ class UsersTable: except Exception: return None - def get_user_by_email(self, email: str) -> Optional[UserModel]: + def get_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: user = db.query(User).filter_by(email=email).first() return UserModel.model_validate(user) except Exception: return None - def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]: + def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[Session] = None) -> Optional[UserModel]: try: - with get_db() as db: # type: Session + with get_db_context(db) as db: # type: Session dialect_name = db.bind.dialect.name query = db.query(User) @@ -320,8 +322,9 @@ class UsersTable: filter: Optional[dict] = None, skip: Optional[int] = None, limit: Optional[int] = None, + db: Optional[Session] = None, ) -> dict: - with get_db() as db: + with get_db_context(db) as db: # Join GroupMember so we can order by group_id when requested query = db.query(User) @@ -452,8 +455,8 @@ class UsersTable: "total": total, } - def get_users_by_group_id(self, group_id: str) -> list[UserModel]: - with get_db() as db: + def get_users_by_group_id(self, group_id: str, db: Optional[Session] = None) -> list[UserModel]: + with get_db_context(db) as db: users = ( db.query(User) .join(GroupMember, User.id == GroupMember.user_id) @@ -462,30 +465,30 @@ class UsersTable: ) return [UserModel.model_validate(user) for user in users] - def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserStatusModel]: - with get_db() as db: + def get_users_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[UserStatusModel]: + with get_db_context(db) as db: users = db.query(User).filter(User.id.in_(user_ids)).all() return [UserModel.model_validate(user) for user in users] - def get_num_users(self) -> Optional[int]: - with get_db() as db: + def get_num_users(self, db: Optional[Session] = None) -> Optional[int]: + with get_db_context(db) as db: return db.query(User).count() - def has_users(self) -> bool: - with get_db() as db: + def has_users(self, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: return db.query(db.query(User).exists()).scalar() - def get_first_user(self) -> UserModel: + def get_first_user(self, db: Optional[Session] = None) -> UserModel: try: - with get_db() as db: + with get_db_context(db) as db: user = db.query(User).order_by(User.created_at).first() return UserModel.model_validate(user) except Exception: return None - def get_user_webhook_url_by_id(self, id: str) -> Optional[str]: + def get_user_webhook_url_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]: try: - with get_db() as db: + with get_db_context(db) as db: user = db.query(User).filter_by(id=id).first() if user.settings is None: @@ -499,8 +502,8 @@ class UsersTable: except Exception: return None - def get_num_users_active_today(self) -> Optional[int]: - with get_db() as db: + def get_num_users_active_today(self, db: Optional[Session] = None) -> Optional[int]: + with get_db_context(db) as db: current_timestamp = int(datetime.datetime.now().timestamp()) today_midnight_timestamp = current_timestamp - (current_timestamp % 86400) query = db.query(User).filter( @@ -508,9 +511,9 @@ class UsersTable: ) return query.count() - def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: + def update_user_role_by_id(self, id: str, role: str, db: Optional[Session] = None) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(User).filter_by(id=id).update({"role": role}) db.commit() user = db.query(User).filter_by(id=id).first() @@ -519,10 +522,10 @@ class UsersTable: return None def update_user_status_by_id( - self, id: str, form_data: UserStatus + self, id: str, form_data: UserStatus, db: Optional[Session] = None ) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(User).filter_by(id=id).update( {**form_data.model_dump(exclude_none=True)} ) @@ -534,10 +537,10 @@ class UsersTable: return None def update_user_profile_image_url_by_id( - self, id: str, profile_image_url: str + self, id: str, profile_image_url: str, db: Optional[Session] = None ) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(User).filter_by(id=id).update( {"profile_image_url": profile_image_url} ) @@ -549,9 +552,9 @@ class UsersTable: return None @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) - def update_last_active_by_id(self, id: str) -> Optional[UserModel]: + def update_last_active_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(User).filter_by(id=id).update( {"last_active_at": int(time.time())} ) @@ -563,7 +566,7 @@ class UsersTable: return None def update_user_oauth_by_id( - self, id: str, provider: str, sub: str + self, id: str, provider: str, sub: str, db: Optional[Session] = None ) -> Optional[UserModel]: """ Update or insert an OAuth provider/sub pair into the user's oauth JSON field. @@ -574,7 +577,7 @@ class UsersTable: } """ try: - with get_db() as db: + with get_db_context(db) as db: user = db.query(User).filter_by(id=id).first() if not user: return None @@ -594,9 +597,9 @@ class UsersTable: except Exception: return None - def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: + def update_user_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(User).filter_by(id=id).update(updated) db.commit() @@ -607,9 +610,9 @@ class UsersTable: print(e) return None - def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]: + def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: user_settings = db.query(User).filter_by(id=id).first().settings if user_settings is None: @@ -625,15 +628,15 @@ class UsersTable: except Exception: return None - def delete_user_by_id(self, id: str) -> bool: + def delete_user_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: # Remove User from Groups Groups.remove_user_from_all_groups(id) # Delete User Chats - result = Chats.delete_chats_by_user_id(id) + result = Chats.delete_chats_by_user_id(id, db=db) if result: - with get_db() as db: + with get_db_context(db) as db: # Delete User db.query(User).filter_by(id=id).delete() db.commit() @@ -644,17 +647,17 @@ class UsersTable: except Exception: return False - def get_user_api_key_by_id(self, id: str) -> Optional[str]: + def get_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]: try: - with get_db() as db: + with get_db_context(db) as db: api_key = db.query(ApiKey).filter_by(user_id=id).first() return api_key.key if api_key else None except Exception: return None - def update_user_api_key_by_id(self, id: str, api_key: str) -> bool: + def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(ApiKey).filter_by(user_id=id).delete() db.commit() @@ -674,30 +677,30 @@ class UsersTable: except Exception: return False - def delete_user_api_key_by_id(self, id: str) -> bool: + def delete_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(ApiKey).filter_by(user_id=id).delete() db.commit() return True except Exception: return False - def get_valid_user_ids(self, user_ids: list[str]) -> list[str]: - with get_db() as db: + def get_valid_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[str]: + with get_db_context(db) as db: users = db.query(User).filter(User.id.in_(user_ids)).all() return [user.id for user in users] - def get_super_admin_user(self) -> Optional[UserModel]: - with get_db() as db: + def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]: + with get_db_context(db) as db: user = db.query(User).filter_by(role="admin").first() if user: return UserModel.model_validate(user) else: return None - def get_active_user_count(self) -> int: - with get_db() as db: + def get_active_user_count(self, db: Optional[Session] = None) -> int: + with get_db_context(db) as db: # Consider user active if last_active_at within the last 3 minutes three_minutes_ago = int(time.time()) - 180 count = ( @@ -705,8 +708,8 @@ class UsersTable: ) return count - def is_user_active(self, user_id: str) -> bool: - with get_db() as db: + def is_user_active(self, user_id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: user = db.query(User).filter_by(id=user_id).first() if user and user.last_active_at: # Consider user active if last_active_at within the last 3 minutes diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index ee0e46da29..853e7848cc 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -142,7 +142,7 @@ async def create_new_note( ) try: - note = Notes.insert_new_note(form_data, user.id) + note = Notes.insert_new_note(user.id, form_data) return note except Exception as e: log.exception(e)