refac/enh: db session sharing

This commit is contained in:
Timothy Jaeryang Baek 2025-12-28 22:00:44 +04:00
parent d4de26bd05
commit 2041ab483e
20 changed files with 600 additions and 562 deletions

View file

@ -160,3 +160,13 @@ def get_session():
get_db = contextmanager(get_session) get_db = contextmanager(get_session)
@contextmanager
def get_db_context(db: Optional[Session] = None):
if db:
yield db
else:
with get_db() as session:
yield session

View file

@ -2,7 +2,8 @@ import logging
import uuid import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.users import UserModel, UserProfileImageResponse, Users from open_webui.models.users import UserModel, UserProfileImageResponse, Users
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import Boolean, Column, String, Text from sqlalchemy import Boolean, Column, String, Text
@ -87,8 +88,9 @@ class AuthsTable:
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth: Optional[dict] = None, oauth: Optional[dict] = None,
db: Optional[Session] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db: with get_db_context(db) as db:
log.info("insert_new_auth") log.info("insert_new_auth")
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -100,7 +102,7 @@ class AuthsTable:
db.add(result) db.add(result)
user = Users.insert_new_user( user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth=oauth id, name, email, profile_image_url, role, oauth=oauth, db=db
) )
db.commit() db.commit()
@ -112,16 +114,16 @@ class AuthsTable:
return None return None
def authenticate_user( def authenticate_user(
self, email: str, verify_password: callable self, email: str, verify_password: callable, db: Optional[Session] = None
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}") log.info(f"authenticate_user: {email}")
user = Users.get_user_by_email(email) user = Users.get_user_by_email(email, db=db)
if not user: if not user:
return None return None
try: try:
with get_db() as db: with get_db_context(db) as db:
auth = db.query(Auth).filter_by(id=user.id, active=True).first() auth = db.query(Auth).filter_by(id=user.id, active=True).first()
if auth: if auth:
if verify_password(auth.password): if verify_password(auth.password):
@ -133,32 +135,32 @@ class AuthsTable:
except Exception: except Exception:
return None return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def authenticate_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}") log.info(f"authenticate_user_by_api_key: {api_key}")
# if no api_key, return None # if no api_key, return None
if not api_key: if not api_key:
return None return None
try: try:
user = Users.get_user_by_api_key(api_key) user = Users.get_user_by_api_key(api_key, db=db)
return user if user else None return user if user else None
except Exception: except Exception:
return False return False
def authenticate_user_by_email(self, email: str) -> Optional[UserModel]: def authenticate_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
log.info(f"authenticate_user_by_email: {email}") log.info(f"authenticate_user_by_email: {email}")
try: try:
with get_db() as db: with get_db_context(db) as db:
auth = db.query(Auth).filter_by(email=email, active=True).first() auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth: if auth:
user = Users.get_user_by_id(auth.id) user = Users.get_user_by_id(auth.id, db=db)
return user return user
except Exception: except Exception:
return None return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool: def update_user_password_by_id(self, id: str, new_password: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
result = ( result = (
db.query(Auth).filter_by(id=id).update({"password": new_password}) db.query(Auth).filter_by(id=id).update({"password": new_password})
) )
@ -167,20 +169,20 @@ class AuthsTable:
except Exception: except Exception:
return False return False
def update_email_by_id(self, id: str, email: str) -> bool: def update_email_by_id(self, id: str, email: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
result = db.query(Auth).filter_by(id=id).update({"email": email}) result = db.query(Auth).filter_by(id=id).update({"email": email})
db.commit() db.commit()
return True if result == 1 else False return True if result == 1 else False
except Exception: except Exception:
return False return False
def delete_auth_by_id(self, id: str) -> bool: def delete_auth_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
# Delete User # Delete User
result = Users.delete_user_by_id(id) result = Users.delete_user_by_id(id, db=db)
if result: if result:
db.query(Auth).filter_by(id=id).delete() db.query(Auth).filter_by(id=id).delete()

View file

@ -1,9 +1,11 @@
import json import json
import time import time
import uuid import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -337,8 +339,8 @@ class ChannelTable:
db.commit() db.commit()
return channel return channel
def get_channels(self) -> list[ChannelModel]: def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]:
with get_db() as db: with get_db_context(db) as db:
channels = db.query(Channel).all() channels = db.query(Channel).all()
return [ChannelModel.model_validate(channel) for channel in channels] return [ChannelModel.model_validate(channel) for channel in channels]
@ -384,8 +386,8 @@ class ChannelTable:
return query return query
def get_channels_by_user_id(self, user_id: str) -> list[ChannelModel]: def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
with get_db() as db: with get_db_context(db) as db:
user_group_ids = [ user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id) group.id for group in Groups.get_groups_by_member_id(user_id)
] ]
@ -683,10 +685,13 @@ class ChannelTable:
) )
return membership is not None return membership is not None
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]:
with get_db() as db: try:
channel = db.query(Channel).filter(Channel.id == id).first() with get_db_context(db) as db:
return ChannelModel.model_validate(channel) if channel else None channel = db.query(Channel).filter(Channel.id == id).first()
return ChannelModel.model_validate(channel) if channel else None
except Exception:
return None
def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]: def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]:
with get_db() as db: with get_db() as db:

View file

@ -4,7 +4,8 @@ import time
import uuid import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.tags import TagModel, Tag, Tags from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.models.folders import Folders from open_webui.models.folders import Folders
from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db
@ -280,8 +281,8 @@ class ChatTable:
return changed return changed
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: def insert_new_chat(self, user_id: str, form_data: ChatForm, db: Optional[Session] = None) -> Optional[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chat = ChatModel( chat = ChatModel(
**{ **{
@ -331,9 +332,9 @@ class ChatTable:
return chat return chat
def import_chats( def import_chats(
self, user_id: str, chat_import_forms: list[ChatImportForm] self, user_id: str, chat_import_forms: list[ChatImportForm], db: Optional[Session] = None
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
chats = [] chats = []
for form_data in chat_import_forms: for form_data in chat_import_forms:
@ -344,9 +345,9 @@ class ChatTable:
db.commit() db.commit()
return [ChatModel.model_validate(chat) for chat in chats] return [ChatModel.model_validate(chat) for chat in chats]
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: def update_chat_by_id(self, id: str, chat: dict, db: Optional[Session] = None) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
chat_item = db.get(Chat, id) chat_item = db.get(Chat, id)
chat_item.chat = self._clean_null_bytes(chat) chat_item.chat = self._clean_null_bytes(chat)
chat_item.title = ( chat_item.title = (
@ -483,13 +484,13 @@ class ChatTable:
self.update_chat_by_id(id, chat) self.update_chat_by_id(id, chat)
return message_files return message_files
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def insert_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
# Get the existing chat to share # Get the existing chat to share
chat = db.get(Chat, chat_id) chat = db.get(Chat, chat_id)
# Check if the chat is already shared # Check if the chat is already shared
if chat.share_id: if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared") return self.get_chat_by_id_and_user_id(chat.share_id, "shared", db=db)
# Create a new chat with the same data, but with a new ID # Create a new chat with the same data, but with a new ID
shared_chat = ChatModel( shared_chat = ChatModel(
**{ **{
@ -518,16 +519,16 @@ class ChatTable:
db.commit() db.commit()
return shared_chat if (shared_result and result) else None return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def update_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
chat = db.get(Chat, chat_id) chat = db.get(Chat, chat_id)
shared_chat = ( shared_chat = (
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first() db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first()
) )
if shared_chat is None: if shared_chat is None:
return self.insert_shared_chat_by_chat_id(chat_id) return self.insert_shared_chat_by_chat_id(chat_id, db=db)
shared_chat.title = chat.title shared_chat.title = chat.title
shared_chat.chat = chat.chat shared_chat.chat = chat.chat
@ -542,9 +543,9 @@ class ChatTable:
except Exception: except Exception:
return None return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: def delete_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
db.commit() db.commit()
@ -552,9 +553,9 @@ class ChatTable:
except Exception: except Exception:
return False return False
def unarchive_all_chats_by_user_id(self, user_id: str) -> bool: def unarchive_all_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": False}) db.query(Chat).filter_by(user_id=user_id).update({"archived": False})
db.commit() db.commit()
return True return True
@ -562,10 +563,10 @@ class ChatTable:
return False return False
def update_chat_share_id_by_id( def update_chat_share_id_by_id(
self, id: str, share_id: Optional[str] self, id: str, share_id: Optional[str], db: Optional[Session] = None
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
chat.share_id = share_id chat.share_id = share_id
db.commit() db.commit()
@ -574,9 +575,9 @@ class ChatTable:
except Exception: except Exception:
return None return None
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]: def toggle_chat_pinned_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
chat.pinned = not chat.pinned chat.pinned = not chat.pinned
chat.updated_at = int(time.time()) chat.updated_at = int(time.time())
@ -586,9 +587,9 @@ class ChatTable:
except Exception: except Exception:
return None return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: def toggle_chat_archive_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
chat.archived = not chat.archived chat.archived = not chat.archived
chat.folder_id = None chat.folder_id = None
@ -599,9 +600,9 @@ class ChatTable:
except Exception: except Exception:
return None return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool: def archive_all_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
db.commit() db.commit()
return True return True
@ -614,9 +615,10 @@ class ChatTable:
filter: Optional[dict] = None, filter: Optional[dict] = None,
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
db: Optional[Session] = None,
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Chat).filter_by(user_id=user_id, archived=True) query = db.query(Chat).filter_by(user_id=user_id, archived=True)
if filter: if filter:
@ -655,8 +657,9 @@ class ChatTable:
filter: Optional[dict] = None, filter: Optional[dict] = None,
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
db: Optional[Session] = None,
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id)
if not include_archived: if not include_archived:
query = query.filter_by(archived=False) query = query.filter_by(archived=False)
@ -695,8 +698,9 @@ class ChatTable:
include_pinned: bool = False, include_pinned: bool = False,
skip: Optional[int] = None, skip: Optional[int] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
db: Optional[Session] = None,
) -> list[ChatTitleIdResponse]: ) -> list[ChatTitleIdResponse]:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id)
if not include_folders: if not include_folders:
@ -733,9 +737,9 @@ class ChatTable:
] ]
def get_chat_list_by_chat_ids( def get_chat_list_by_chat_ids(
self, chat_ids: list[str], skip: int = 0, limit: int = 50 self, chat_ids: list[str], skip: int = 0, limit: int = 50, db: Optional[Session] = None
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
.filter(Chat.id.in_(chat_ids)) .filter(Chat.id.in_(chat_ids))
@ -745,9 +749,9 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
chat_item = db.get(Chat, id) chat_item = db.get(Chat, id)
if chat_item is None: if chat_item is None:
return None return None
@ -760,30 +764,30 @@ class ChatTable:
except Exception: except Exception:
return None return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_share_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
# it is possible that the shared link was deleted. hence, # it is possible that the shared link was deleted. hence,
# we check if the chat is still shared by checking if a chat with the share_id exists # we check if the chat is still shared by checking if a chat with the share_id exists
chat = db.query(Chat).filter_by(share_id=id).first() chat = db.query(Chat).filter_by(share_id=id).first()
if chat: if chat:
return self.get_chat_by_id(id) return self.get_chat_by_id(id, db=db)
else: else:
return None return None
except Exception: except Exception:
return None return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: def get_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except Exception: except Exception:
return None return None
def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]: def get_chats(self, skip: int = 0, limit: int = 50, db: Optional[Session] = None) -> list[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
@ -797,8 +801,9 @@ class ChatTable:
filter: Optional[dict] = None, filter: Optional[dict] = None,
skip: Optional[int] = None, skip: Optional[int] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
db: Optional[Session] = None,
) -> ChatListResponse: ) -> ChatListResponse:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id)
if filter: if filter:
@ -838,8 +843,8 @@ class ChatTable:
} }
) )
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]: def get_pinned_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
.filter_by(user_id=user_id, pinned=True, archived=False) .filter_by(user_id=user_id, pinned=True, archived=False)
@ -847,8 +852,8 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: def get_archived_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
.filter_by(user_id=user_id, archived=True) .filter_by(user_id=user_id, archived=True)
@ -863,6 +868,7 @@ class ChatTable:
include_archived: bool = False, include_archived: bool = False,
skip: int = 0, skip: int = 0,
limit: int = 60, limit: int = 60,
db: Optional[Session] = None,
) -> list[ChatModel]: ) -> list[ChatModel]:
""" """
Filters chats based on a search query using Python, allowing pagination using skip and limit. Filters chats based on a search query using Python, allowing pagination using skip and limit.
@ -871,7 +877,7 @@ class ChatTable:
if not search_text: if not search_text:
return self.get_chat_list_by_user_id( return self.get_chat_list_by_user_id(
user_id, include_archived, filter={}, skip=skip, limit=limit user_id, include_archived, filter={}, skip=skip, limit=limit, db=db
) )
search_text_words = search_text.split(" ") search_text_words = search_text.split(" ")
@ -926,7 +932,7 @@ class ChatTable:
search_text = " ".join(search_text_words) search_text = " ".join(search_text_words)
with get_db() as db: with get_db_context(db) as db:
query = db.query(Chat).filter(Chat.user_id == user_id) query = db.query(Chat).filter(Chat.user_id == user_id)
if is_archived is not None: if is_archived is not None:
@ -1067,9 +1073,9 @@ class ChatTable:
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_id_and_user_id( def get_chats_by_folder_id_and_user_id(
self, folder_id: str, user_id: str, skip: int = 0, limit: int = 60 self, folder_id: str, user_id: str, skip: int = 0, limit: int = 60, db: Optional[Session] = None
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id) query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
query = query.filter_by(archived=False) query = query.filter_by(archived=False)
@ -1085,9 +1091,9 @@ class ChatTable:
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_ids_and_user_id( def get_chats_by_folder_ids_and_user_id(
self, folder_ids: list[str], user_id: str self, folder_ids: list[str], user_id: str, db: Optional[Session] = None
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Chat).filter( query = db.query(Chat).filter(
Chat.folder_id.in_(folder_ids), Chat.user_id == user_id Chat.folder_id.in_(folder_ids), Chat.user_id == user_id
) )
@ -1100,10 +1106,10 @@ class ChatTable:
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def update_chat_folder_id_by_id_and_user_id( def update_chat_folder_id_by_id_and_user_id(
self, id: str, user_id: str, folder_id: str self, id: str, user_id: str, folder_id: str, db: Optional[Session] = None
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
chat.folder_id = folder_id chat.folder_id = folder_id
chat.updated_at = int(time.time()) chat.updated_at = int(time.time())
@ -1114,16 +1120,16 @@ class ChatTable:
except Exception: except Exception:
return None return None
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]: def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[TagModel]:
with get_db() as db: with get_db_context(db) as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags] return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
def get_chat_list_by_user_id_and_tag_name( def get_chat_list_by_user_id_and_tag_name(
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50 self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id)
tag_id = tag_name.replace(" ", "_").lower() tag_id = tag_name.replace(" ", "_").lower()
@ -1152,13 +1158,13 @@ class ChatTable:
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def add_chat_tag_by_id_and_user_id_and_tag_name( def add_chat_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id) tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
if tag is None: if tag is None:
tag = Tags.insert_new_tag(tag_name, user_id) tag = Tags.insert_new_tag(tag_name, user_id)
try: try:
with get_db() as db: with get_db_context(db) as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
tag_id = tag.id tag_id = tag.id
@ -1174,8 +1180,8 @@ class ChatTable:
except Exception: except Exception:
return None return None
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int: def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str, db: Optional[Session] = None) -> int:
with get_db() as db: # Assuming `get_db()` returns a session object with get_db_context(db) as db: # Assuming `get_db()` returns a session object
query = db.query(Chat).filter_by(user_id=user_id, archived=False) query = db.query(Chat).filter_by(user_id=user_id, archived=False)
# Normalize the tag_name for consistency # Normalize the tag_name for consistency
@ -1210,8 +1216,8 @@ class ChatTable:
return count return count
def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str) -> int: def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str, db: Optional[Session] = None) -> int:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id)
query = query.filter_by(folder_id=folder_id) query = query.filter_by(folder_id=folder_id)
@ -1221,10 +1227,10 @@ class ChatTable:
return count return count
def delete_tag_by_id_and_user_id_and_tag_name( def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None
) -> bool: ) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
tag_id = tag_name.replace(" ", "_").lower() tag_id = tag_name.replace(" ", "_").lower()
@ -1239,9 +1245,9 @@ class ChatTable:
except Exception: except Exception:
return False return False
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
chat.meta = { chat.meta = {
**chat.meta, **chat.meta,
@ -1253,30 +1259,30 @@ class ChatTable:
except Exception: except Exception:
return False return False
def delete_chat_by_id(self, id: str) -> bool: def delete_chat_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Chat).filter_by(id=id).delete() db.query(Chat).filter_by(id=id).delete()
db.commit() db.commit()
return True and self.delete_shared_chat_by_chat_id(id) return True and self.delete_shared_chat_by_chat_id(id, db=db)
except Exception: except Exception:
return False return False
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Chat).filter_by(id=id, user_id=user_id).delete() db.query(Chat).filter_by(id=id, user_id=user_id).delete()
db.commit() db.commit()
return True and self.delete_shared_chat_by_chat_id(id) return True and self.delete_shared_chat_by_chat_id(id, db=db)
except Exception: except Exception:
return False return False
def delete_chats_by_user_id(self, user_id: str) -> bool: def delete_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
self.delete_shared_chats_by_user_id(user_id) self.delete_shared_chats_by_user_id(user_id, db=db)
db.query(Chat).filter_by(user_id=user_id).delete() db.query(Chat).filter_by(user_id=user_id).delete()
db.commit() db.commit()
@ -1286,10 +1292,10 @@ class ChatTable:
return False return False
def delete_chats_by_user_id_and_folder_id( def delete_chats_by_user_id_and_folder_id(
self, user_id: str, folder_id: str self, user_id: str, folder_id: str, db: Optional[Session] = None
) -> bool: ) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete() db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete()
db.commit() db.commit()
@ -1298,10 +1304,10 @@ class ChatTable:
return False return False
def move_chats_by_user_id_and_folder_id( def move_chats_by_user_id_and_folder_id(
self, user_id: str, folder_id: str, new_folder_id: Optional[str] self, user_id: str, folder_id: str, new_folder_id: Optional[str], db: Optional[Session] = None
) -> bool: ) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update( db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update(
{"folder_id": new_folder_id} {"folder_id": new_folder_id}
) )
@ -1311,9 +1317,9 @@ class ChatTable:
except Exception: except Exception:
return False return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool: def delete_shared_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
@ -1325,7 +1331,7 @@ class ChatTable:
return False return False
def insert_chat_files( def insert_chat_files(
self, chat_id: str, message_id: str, file_ids: list[str], user_id: str self, chat_id: str, message_id: str, file_ids: list[str], user_id: str, db: Optional[Session] = None
) -> Optional[list[ChatFileModel]]: ) -> Optional[list[ChatFileModel]]:
if not file_ids: if not file_ids:
return None return None
@ -1333,7 +1339,7 @@ class ChatTable:
chat_message_file_ids = [ chat_message_file_ids = [
item.id item.id
for item in self.get_chat_files_by_chat_id_and_message_id( for item in self.get_chat_files_by_chat_id_and_message_id(
chat_id, message_id chat_id, message_id, db=db
) )
] ]
# Remove duplicates and existing file_ids # Remove duplicates and existing file_ids
@ -1350,7 +1356,7 @@ class ChatTable:
return None return None
try: try:
with get_db() as db: with get_db_context(db) as db:
now = int(time.time()) now = int(time.time())
chat_files = [ chat_files = [
@ -1378,9 +1384,9 @@ class ChatTable:
return None return None
def get_chat_files_by_chat_id_and_message_id( def get_chat_files_by_chat_id_and_message_id(
self, chat_id: str, message_id: str self, chat_id: str, message_id: str, db: Optional[Session] = None
) -> list[ChatFileModel]: ) -> list[ChatFileModel]:
with get_db() as db: with get_db_context(db) as db:
all_chat_files = ( all_chat_files = (
db.query(ChatFile) db.query(ChatFile)
.filter_by(chat_id=chat_id, message_id=message_id) .filter_by(chat_id=chat_id, message_id=message_id)
@ -1391,17 +1397,17 @@ class ChatTable:
ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files
] ]
def delete_chat_file(self, chat_id: str, file_id: str) -> bool: def delete_chat_file(self, chat_id: str, file_id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(ChatFile).filter_by(chat_id=chat_id, file_id=file_id).delete() db.query(ChatFile).filter_by(chat_id=chat_id, file_id=file_id).delete()
db.commit() db.commit()
return True return True
except Exception: except Exception:
return False return False
def get_shared_chats_by_file_id(self, file_id: str) -> list[ChatModel]: def get_shared_chats_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChatModel]:
with get_db() as db: with get_db_context(db) as db:
# Join Chat and ChatFile tables to get shared chats associated with the file_id # Join Chat and ChatFile tables to get shared chats associated with the file_id
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)

View file

@ -3,7 +3,8 @@ import time
import uuid import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.users import User from open_webui.models.users import User
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -121,9 +122,9 @@ class FeedbackListResponse(BaseModel):
class FeedbackTable: class FeedbackTable:
def insert_new_feedback( def insert_new_feedback(
self, user_id: str, form_data: FeedbackForm self, user_id: str, form_data: FeedbackForm, db: Optional[Session] = None
) -> Optional[FeedbackModel]: ) -> Optional[FeedbackModel]:
with get_db() as db: with get_db_context(db) as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
feedback = FeedbackModel( feedback = FeedbackModel(
**{ **{
@ -148,9 +149,9 @@ class FeedbackTable:
log.exception(f"Error creating a new feedback: {e}") log.exception(f"Error creating a new feedback: {e}")
return None return None
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]: def get_feedback_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FeedbackModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
feedback = db.query(Feedback).filter_by(id=id).first() feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback: if not feedback:
return None return None
@ -159,10 +160,10 @@ class FeedbackTable:
return None return None
def get_feedback_by_id_and_user_id( def get_feedback_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[FeedbackModel]: ) -> Optional[FeedbackModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback: if not feedback:
return None return None
@ -171,9 +172,9 @@ class FeedbackTable:
return None return None
def get_feedback_items( def get_feedback_items(
self, filter: dict = {}, skip: int = 0, limit: int = 30 self, filter: dict = {}, skip: int = 0, limit: int = 30, db: Optional[Session] = None
) -> FeedbackListResponse: ) -> FeedbackListResponse:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Feedback, User).join(User, Feedback.user_id == User.id) query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
if filter: if filter:
@ -234,8 +235,8 @@ class FeedbackTable:
return FeedbackListResponse(items=feedbacks, total=total) return FeedbackListResponse(items=feedbacks, total=total)
def get_all_feedbacks(self) -> list[FeedbackModel]: def get_all_feedbacks(self, db: Optional[Session] = None) -> list[FeedbackModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
FeedbackModel.model_validate(feedback) FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback) for feedback in db.query(Feedback)
@ -243,8 +244,8 @@ class FeedbackTable:
.all() .all()
] ]
def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]: def get_feedbacks_by_type(self, type: str, db: Optional[Session] = None) -> list[FeedbackModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
FeedbackModel.model_validate(feedback) FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback) for feedback in db.query(Feedback)
@ -253,8 +254,8 @@ class FeedbackTable:
.all() .all()
] ]
def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]: def get_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FeedbackModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
FeedbackModel.model_validate(feedback) FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback) for feedback in db.query(Feedback)
@ -264,9 +265,9 @@ class FeedbackTable:
] ]
def update_feedback_by_id( def update_feedback_by_id(
self, id: str, form_data: FeedbackForm self, id: str, form_data: FeedbackForm, db: Optional[Session] = None
) -> Optional[FeedbackModel]: ) -> Optional[FeedbackModel]:
with get_db() as db: with get_db_context(db) as db:
feedback = db.query(Feedback).filter_by(id=id).first() feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback: if not feedback:
return None return None
@ -284,9 +285,9 @@ class FeedbackTable:
return FeedbackModel.model_validate(feedback) return FeedbackModel.model_validate(feedback)
def update_feedback_by_id_and_user_id( def update_feedback_by_id_and_user_id(
self, id: str, user_id: str, form_data: FeedbackForm self, id: str, user_id: str, form_data: FeedbackForm, db: Optional[Session] = None
) -> Optional[FeedbackModel]: ) -> Optional[FeedbackModel]:
with get_db() as db: with get_db_context(db) as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback: if not feedback:
return None return None
@ -303,8 +304,8 @@ class FeedbackTable:
db.commit() db.commit()
return FeedbackModel.model_validate(feedback) return FeedbackModel.model_validate(feedback)
def delete_feedback_by_id(self, id: str) -> bool: def delete_feedback_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
feedback = db.query(Feedback).filter_by(id=id).first() feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback: if not feedback:
return False return False
@ -312,8 +313,8 @@ class FeedbackTable:
db.commit() db.commit()
return True return True
def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_feedback_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback: if not feedback:
return False return False
@ -321,8 +322,8 @@ class FeedbackTable:
db.commit() db.commit()
return True return True
def delete_feedbacks_by_user_id(self, user_id: str) -> bool: def delete_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
feedbacks = db.query(Feedback).filter_by(user_id=user_id).all() feedbacks = db.query(Feedback).filter_by(user_id=user_id).all()
if not feedbacks: if not feedbacks:
return False return False
@ -331,8 +332,8 @@ class FeedbackTable:
db.commit() db.commit()
return True return True
def delete_all_feedbacks(self) -> bool: def delete_all_feedbacks(self, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
feedbacks = db.query(Feedback).all() feedbacks = db.query(Feedback).all()
if not feedbacks: if not feedbacks:
return False return False

View file

@ -2,7 +2,8 @@ import logging
import time import time
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON from sqlalchemy import BigInteger, Column, String, Text, JSON
@ -108,8 +109,10 @@ class FileListResponse(BaseModel):
class FilesTable: class FilesTable:
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: def insert_new_file(
with get_db() as db: self, user_id: str, form_data: FileForm, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db:
file = FileModel( file = FileModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -132,13 +135,16 @@ class FilesTable:
log.exception(f"Error inserting a new file: {e}") log.exception(f"Error inserting a new file: {e}")
return None return None
def get_file_by_id(self, id: str) -> Optional[FileModel]: def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]:
with get_db() as db: try:
try: with get_db_context(db) as db:
file = db.get(File, id) try:
return FileModel.model_validate(file) file = db.get(File, id)
except Exception: return FileModel.model_validate(file)
return None except Exception:
return None
except Exception:
return None
def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]: def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]:
with get_db() as db: with get_db() as db:
@ -165,8 +171,8 @@ class FilesTable:
except Exception: except Exception:
return None return None
def get_files(self) -> list[FileModel]: def get_files(self, db: Optional[Session] = None) -> list[FileModel]:
with get_db() as db: with get_db_context(db) as db:
return [FileModel.model_validate(file) for file in db.query(File).all()] return [FileModel.model_validate(file) for file in db.query(File).all()]
def check_access_by_user_id(self, id, user_id, permission="write") -> bool: def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
@ -206,8 +212,8 @@ class FilesTable:
.all() .all()
] ]
def get_files_by_user_id(self, user_id: str) -> list[FileModel]: def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
FileModel.model_validate(file) FileModel.model_validate(file)
for file in db.query(File).filter_by(user_id=user_id).all() for file in db.query(File).filter_by(user_id=user_id).all()
@ -271,24 +277,6 @@ class FilesTable:
except Exception: except Exception:
return None return None
def delete_file_by_id(self, id: str) -> bool:
with get_db() as db:
try:
db.query(File).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
def delete_all_files(self) -> bool:
with get_db() as db:
try:
db.query(File).delete()
db.commit()
return True
except Exception:
return False return False

View file

@ -7,8 +7,9 @@ import re
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func
from sqlalchemy.orm import Session
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, JSONField, get_db, get_db_context
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -83,9 +84,9 @@ class FolderUpdateForm(BaseModel):
class FolderTable: class FolderTable:
def insert_new_folder( def insert_new_folder(
self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None, db: Optional[Session] = None
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
with get_db() as db: with get_db_context(db) as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
folder = FolderModel( folder = FolderModel(
**{ **{
@ -111,10 +112,10 @@ class FolderTable:
return None return None
def get_folder_by_id_and_user_id( def get_folder_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder: if not folder:
@ -125,15 +126,15 @@ class FolderTable:
return None return None
def get_children_folders_by_id_and_user_id( def get_children_folders_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[list[FolderModel]]: ) -> Optional[list[FolderModel]]:
try: try:
with get_db() as db: with get_db_context(db) as db:
folders = [] folders = []
def get_children(folder): def get_children(folder):
children = self.get_folders_by_parent_id_and_user_id( children = self.get_folders_by_parent_id_and_user_id(
folder.id, user_id folder.id, user_id, db=db
) )
for child in children: for child in children:
get_children(child) get_children(child)
@ -148,18 +149,18 @@ class FolderTable:
except Exception: except Exception:
return None return None
def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]: def get_folders_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FolderModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
FolderModel.model_validate(folder) FolderModel.model_validate(folder)
for folder in db.query(Folder).filter_by(user_id=user_id).all() for folder in db.query(Folder).filter_by(user_id=user_id).all()
] ]
def get_folder_by_parent_id_and_user_id_and_name( def get_folder_by_parent_id_and_user_id_and_name(
self, parent_id: Optional[str], user_id: str, name: str self, parent_id: Optional[str], user_id: str, name: str, db: Optional[Session] = None
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
# Check if folder exists # Check if folder exists
folder = ( folder = (
db.query(Folder) db.query(Folder)
@ -177,9 +178,9 @@ class FolderTable:
return None return None
def get_folders_by_parent_id_and_user_id( def get_folders_by_parent_id_and_user_id(
self, parent_id: Optional[str], user_id: str self, parent_id: Optional[str], user_id: str, db: Optional[Session] = None
) -> list[FolderModel]: ) -> list[FolderModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
FolderModel.model_validate(folder) FolderModel.model_validate(folder)
for folder in db.query(Folder) for folder in db.query(Folder)
@ -192,9 +193,10 @@ class FolderTable:
id: str, id: str,
user_id: str, user_id: str,
parent_id: str, parent_id: str,
db: Optional[Session] = None,
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder: if not folder:
@ -211,10 +213,10 @@ class FolderTable:
return return
def update_folder_by_id_and_user_id( def update_folder_by_id_and_user_id(
self, id: str, user_id: str, form_data: FolderUpdateForm self, id: str, user_id: str, form_data: FolderUpdateForm, db: Optional[Session] = None
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder: if not folder:
@ -257,10 +259,10 @@ class FolderTable:
return return
def update_folder_is_expanded_by_id_and_user_id( def update_folder_is_expanded_by_id_and_user_id(
self, id: str, user_id: str, is_expanded: bool self, id: str, user_id: str, is_expanded: bool, db: Optional[Session] = None
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder: if not folder:
@ -276,10 +278,10 @@ class FolderTable:
log.error(f"update_folder: {e}") log.error(f"update_folder: {e}")
return return
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]: def delete_folder_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[str]:
try: try:
folder_ids = [] folder_ids = []
with get_db() as db: with get_db_context(db) as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder: if not folder:
return folder_ids return folder_ids
@ -289,7 +291,7 @@ class FolderTable:
# Delete all children folders # Delete all children folders
def delete_children(folder): def delete_children(folder):
folder_children = self.get_folders_by_parent_id_and_user_id( folder_children = self.get_folders_by_parent_id_and_user_id(
folder.id, user_id folder.id, user_id, db=db
) )
for folder_child in folder_children: for folder_child in folder_children:
@ -314,7 +316,7 @@ class FolderTable:
return name.strip().lower() return name.strip().lower()
def search_folders_by_names( def search_folders_by_names(
self, user_id: str, queries: list[str] self, user_id: str, queries: list[str], db: Optional[Session] = None
) -> list[FolderModel]: ) -> list[FolderModel]:
""" """
Search for folders for a user where the name matches any of the queries, treating _ and space as equivalent, case-insensitive. Search for folders for a user where the name matches any of the queries, treating _ and space as equivalent, case-insensitive.
@ -324,7 +326,7 @@ class FolderTable:
return [] return []
results = {} results = {}
with get_db() as db: with get_db_context(db) as db:
folders = db.query(Folder).filter_by(user_id=user_id).all() folders = db.query(Folder).filter_by(user_id=user_id).all()
for folder in folders: for folder in folders:
if self.normalize_folder_name(folder.name) in normalized_queries: if self.normalize_folder_name(folder.name) in normalized_queries:
@ -332,7 +334,7 @@ class FolderTable:
# get children folders # get children folders
children = self.get_children_folders_by_id_and_user_id( children = self.get_children_folders_by_id_and_user_id(
folder.id, user_id folder.id, user_id, db=db
) )
for child in children: for child in children:
results[child.id] = child results[child.id] = child
@ -345,14 +347,14 @@ class FolderTable:
return results return results
def search_folders_by_name_contains( def search_folders_by_name_contains(
self, user_id: str, query: str self, user_id: str, query: str, db: Optional[Session] = None
) -> list[FolderModel]: ) -> list[FolderModel]:
""" """
Partial match: normalized name contains (as substring) the normalized query. Partial match: normalized name contains (as substring) the normalized query.
""" """
normalized_query = self.normalize_folder_name(query) normalized_query = self.normalize_folder_name(query)
results = [] results = []
with get_db() as db: with get_db_context(db) as db:
folders = db.query(Folder).filter_by(user_id=user_id).all() folders = db.query(Folder).filter_by(user_id=user_id).all()
for folder in folders: for folder in folders:
norm_name = self.normalize_folder_name(folder.name) norm_name = self.normalize_folder_name(folder.name)

View file

@ -2,7 +2,8 @@ import logging
import time import time
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.users import Users, UserModel from open_webui.models.users import Users, UserModel
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index
@ -103,7 +104,7 @@ class FunctionValves(BaseModel):
class FunctionsTable: class FunctionsTable:
def insert_new_function( def insert_new_function(
self, user_id: str, type: str, form_data: FunctionForm self, user_id: str, type: str, form_data: FunctionForm, db: Optional[Session] = None
) -> Optional[FunctionModel]: ) -> Optional[FunctionModel]:
function = FunctionModel( function = FunctionModel(
**{ **{
@ -116,7 +117,7 @@ class FunctionsTable:
) )
try: try:
with get_db() as db: with get_db_context(db) as db:
result = Function(**function.model_dump()) result = Function(**function.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
@ -130,11 +131,11 @@ class FunctionsTable:
return None return None
def sync_functions( def sync_functions(
self, user_id: str, functions: list[FunctionWithValvesModel] self, user_id: str, functions: list[FunctionWithValvesModel], db: Optional[Session] = None
) -> list[FunctionWithValvesModel]: ) -> list[FunctionWithValvesModel]:
# Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present. # Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
try: try:
with get_db() as db: with get_db_context(db) as db:
# Get existing functions # Get existing functions
existing_functions = db.query(Function).all() existing_functions = db.query(Function).all()
existing_ids = {func.id for func in existing_functions} existing_ids = {func.id for func in existing_functions}
@ -177,18 +178,18 @@ class FunctionsTable:
log.exception(f"Error syncing functions for user {user_id}: {e}") log.exception(f"Error syncing functions for user {user_id}: {e}")
return [] return []
def get_function_by_id(self, id: str) -> Optional[FunctionModel]: def get_function_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FunctionModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
function = db.get(Function, id) function = db.get(Function, id)
return FunctionModel.model_validate(function) return FunctionModel.model_validate(function)
except Exception: except Exception:
return None return None
def get_functions( def get_functions(
self, active_only=False, include_valves=False self, active_only=False, include_valves=False, db: Optional[Session] = None
) -> list[FunctionModel | FunctionWithValvesModel]: ) -> list[FunctionModel | FunctionWithValvesModel]:
with get_db() as db: with get_db_context(db) as db:
if active_only: if active_only:
functions = db.query(Function).filter_by(is_active=True).all() functions = db.query(Function).filter_by(is_active=True).all()
@ -205,12 +206,12 @@ class FunctionsTable:
FunctionModel.model_validate(function) for function in functions FunctionModel.model_validate(function) for function in functions
] ]
def get_function_list(self) -> list[FunctionUserResponse]: def get_function_list(self, db: Optional[Session] = None) -> list[FunctionUserResponse]:
with get_db() as db: with get_db_context(db) as db:
functions = db.query(Function).order_by(Function.updated_at.desc()).all() functions = db.query(Function).order_by(Function.updated_at.desc()).all()
user_ids = list(set(func.user_id for func in functions)) user_ids = list(set(func.user_id for func in functions))
users = Users.get_users_by_user_ids(user_ids) if user_ids else [] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users} users_dict = {user.id: user for user in users}
return [ return [
@ -228,9 +229,9 @@ class FunctionsTable:
] ]
def get_functions_by_type( def get_functions_by_type(
self, type: str, active_only=False self, type: str, active_only=False, db: Optional[Session] = None
) -> list[FunctionModel]: ) -> list[FunctionModel]:
with get_db() as db: with get_db_context(db) as db:
if active_only: if active_only:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
@ -244,8 +245,8 @@ class FunctionsTable:
for function in db.query(Function).filter_by(type=type).all() for function in db.query(Function).filter_by(type=type).all()
] ]
def get_global_filter_functions(self) -> list[FunctionModel]: def get_global_filter_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function) for function in db.query(Function)
@ -253,8 +254,8 @@ class FunctionsTable:
.all() .all()
] ]
def get_global_action_functions(self) -> list[FunctionModel]: def get_global_action_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function) for function in db.query(Function)
@ -262,8 +263,8 @@ class FunctionsTable:
.all() .all()
] ]
def get_function_valves_by_id(self, id: str) -> Optional[dict]: def get_function_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
with get_db() as db: with get_db_context(db) as db:
try: try:
function = db.get(Function, id) function = db.get(Function, id)
return function.valves if function.valves else {} return function.valves if function.valves else {}
@ -272,23 +273,23 @@ class FunctionsTable:
return None return None
def update_function_valves_by_id( def update_function_valves_by_id(
self, id: str, valves: dict self, id: str, valves: dict, db: Optional[Session] = None
) -> Optional[FunctionValves]: ) -> Optional[FunctionValves]:
with get_db() as db: with get_db_context(db) as db:
try: try:
function = db.get(Function, id) function = db.get(Function, id)
function.valves = valves function.valves = valves
function.updated_at = int(time.time()) function.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(function) db.refresh(function)
return self.get_function_by_id(id) return self.get_function_by_id(id, db=db)
except Exception: except Exception:
return None return None
def update_function_metadata_by_id( def update_function_metadata_by_id(
self, id: str, metadata: dict self, id: str, metadata: dict, db: Optional[Session] = None
) -> Optional[FunctionModel]: ) -> Optional[FunctionModel]:
with get_db() as db: with get_db_context(db) as db:
try: try:
function = db.get(Function, id) function = db.get(Function, id)
@ -301,7 +302,7 @@ class FunctionsTable:
function.updated_at = int(time.time()) function.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(function) db.refresh(function)
return self.get_function_by_id(id) return self.get_function_by_id(id, db=db)
else: else:
return None return None
except Exception as e: except Exception as e:
@ -309,10 +310,10 @@ class FunctionsTable:
return None return None
def get_user_valves_by_id_and_user_id( def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings # Check if user has "functions" and "valves" settings
@ -327,10 +328,10 @@ class FunctionsTable:
return None return None
def update_user_valves_by_id_and_user_id( def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict self, id: str, user_id: str, valves: dict, db: Optional[Session] = None
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings # Check if user has "functions" and "valves" settings
@ -342,7 +343,7 @@ class FunctionsTable:
user_settings["functions"]["valves"][id] = valves user_settings["functions"]["valves"][id] = valves
# Update the user settings in the database # Update the user settings in the database
Users.update_user_by_id(user_id, {"settings": user_settings}) Users.update_user_by_id(user_id, {"settings": user_settings}, db=db)
return user_settings["functions"]["valves"][id] return user_settings["functions"]["valves"][id]
except Exception as e: except Exception as e:
@ -351,8 +352,8 @@ class FunctionsTable:
) )
return None return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: def update_function_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[FunctionModel]:
with get_db() as db: with get_db_context(db) as db:
try: try:
db.query(Function).filter_by(id=id).update( db.query(Function).filter_by(id=id).update(
{ {
@ -361,12 +362,12 @@ class FunctionsTable:
} }
) )
db.commit() db.commit()
return self.get_function_by_id(id) return self.get_function_by_id(id, db=db)
except Exception: except Exception:
return None return None
def deactivate_all_functions(self) -> Optional[bool]: def deactivate_all_functions(self, db: Optional[Session] = None) -> Optional[bool]:
with get_db() as db: with get_db_context(db) as db:
try: try:
db.query(Function).update( db.query(Function).update(
{ {
@ -379,8 +380,8 @@ class FunctionsTable:
except Exception: except Exception:
return None return None
def delete_function_by_id(self, id: str) -> bool: def delete_function_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
try: try:
db.query(Function).filter_by(id=id).delete() db.query(Function).filter_by(id=id).delete()
db.commit() db.commit()

View file

@ -4,7 +4,8 @@ import time
from typing import Optional from typing import Optional
import uuid import uuid
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.files import FileMetadataResponse from open_webui.models.files import FileMetadataResponse
@ -120,9 +121,9 @@ class GroupListResponse(BaseModel):
class GroupTable: class GroupTable:
def insert_new_group( def insert_new_group(
self, user_id: str, form_data: GroupForm self, user_id: str, form_data: GroupForm, db: Optional[Session] = None
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
with get_db() as db: with get_db_context(db) as db:
group = GroupModel( group = GroupModel(
**{ **{
**form_data.model_dump(exclude_none=True), **form_data.model_dump(exclude_none=True),
@ -146,13 +147,13 @@ class GroupTable:
except Exception: except Exception:
return None return None
def get_all_groups(self) -> list[GroupModel]: def get_all_groups(self, db: Optional[Session] = None) -> list[GroupModel]:
with get_db() as db: with get_db_context(db) as db:
groups = db.query(Group).order_by(Group.updated_at.desc()).all() groups = db.query(Group).order_by(Group.updated_at.desc()).all()
return [GroupModel.model_validate(group) for group in groups] return [GroupModel.model_validate(group) for group in groups]
def get_groups(self, filter) -> list[GroupResponse]: def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse]:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Group) query = db.query(Group)
if filter: if filter:
@ -184,16 +185,16 @@ class GroupTable:
GroupResponse.model_validate( GroupResponse.model_validate(
{ {
**GroupModel.model_validate(group).model_dump(), **GroupModel.model_validate(group).model_dump(),
"member_count": self.get_group_member_count_by_id(group.id), "member_count": self.get_group_member_count_by_id(group.id, db=db),
} }
) )
for group in groups for group in groups
] ]
def search_groups( def search_groups(
self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30 self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30, db: Optional[Session] = None
) -> GroupListResponse: ) -> GroupListResponse:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Group) query = db.query(Group)
if filter: if filter:
@ -220,15 +221,15 @@ class GroupTable:
"items": [ "items": [
GroupResponse.model_validate( GroupResponse.model_validate(
**GroupModel.model_validate(group).model_dump(), **GroupModel.model_validate(group).model_dump(),
member_count=self.get_group_member_count_by_id(group.id), member_count=self.get_group_member_count_by_id(group.id, db=db),
) )
for group in groups for group in groups
], ],
"total": total, "total": total,
} }
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: def get_groups_by_member_id(self, user_id: str, db: Optional[Session] = None) -> list[GroupModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
GroupModel.model_validate(group) GroupModel.model_validate(group)
for group in db.query(Group) for group in db.query(Group)
@ -238,16 +239,16 @@ class GroupTable:
.all() .all()
] ]
def get_group_by_id(self, id: str) -> Optional[GroupModel]: def get_group_by_id(self, id: str, db: Optional[Session] = None) -> Optional[GroupModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
group = db.query(Group).filter_by(id=id).first() group = db.query(Group).filter_by(id=id).first()
return GroupModel.model_validate(group) if group else None return GroupModel.model_validate(group) if group else None
except Exception: except Exception:
return None return None
def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]: def get_group_user_ids_by_id(self, id: str, db: Optional[Session] = None) -> Optional[list[str]]:
with get_db() as db: with get_db_context(db) as db:
members = ( members = (
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all() db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
) )
@ -257,8 +258,8 @@ class GroupTable:
return [m[0] for m in members] return [m[0] for m in members]
def get_group_user_ids_by_ids(self, group_ids: list[str]) -> dict[str, list[str]]: def get_group_user_ids_by_ids(self, group_ids: list[str], db: Optional[Session] = None) -> dict[str, list[str]]:
with get_db() as db: with get_db_context(db) as db:
members = ( members = (
db.query(GroupMember.group_id, GroupMember.user_id) db.query(GroupMember.group_id, GroupMember.user_id)
.filter(GroupMember.group_id.in_(group_ids)) .filter(GroupMember.group_id.in_(group_ids))
@ -274,8 +275,8 @@ class GroupTable:
return group_user_ids return group_user_ids
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None: def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str], db: Optional[Session] = None) -> None:
with get_db() as db: with get_db_context(db) as db:
# Delete existing members # Delete existing members
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete() db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
@ -295,8 +296,8 @@ class GroupTable:
db.add_all(new_members) db.add_all(new_members)
db.commit() db.commit()
def get_group_member_count_by_id(self, id: str) -> int: def get_group_member_count_by_id(self, id: str, db: Optional[Session] = None) -> int:
with get_db() as db: with get_db_context(db) as db:
count = ( count = (
db.query(func.count(GroupMember.user_id)) db.query(func.count(GroupMember.user_id))
.filter(GroupMember.group_id == id) .filter(GroupMember.group_id == id)
@ -305,10 +306,10 @@ class GroupTable:
return count if count else 0 return count if count else 0
def update_group_by_id( def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False self, id: str, form_data: GroupUpdateForm, overwrite: bool = False, db: Optional[Session] = None
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Group).filter_by(id=id).update( db.query(Group).filter_by(id=id).update(
{ {
**form_data.model_dump(exclude_none=True), **form_data.model_dump(exclude_none=True),
@ -316,22 +317,22 @@ class GroupTable:
} }
) )
db.commit() db.commit()
return self.get_group_by_id(id=id) return self.get_group_by_id(id=id, db=db)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def delete_group_by_id(self, id: str) -> bool: def delete_group_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Group).filter_by(id=id).delete() db.query(Group).filter_by(id=id).delete()
db.commit() db.commit()
return True return True
except Exception: except Exception:
return False return False
def delete_all_groups(self) -> bool: def delete_all_groups(self, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
try: try:
db.query(Group).delete() db.query(Group).delete()
db.commit() db.commit()
@ -340,8 +341,8 @@ class GroupTable:
except Exception: except Exception:
return False return False
def remove_user_from_all_groups(self, user_id: str) -> bool: def remove_user_from_all_groups(self, user_id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
try: try:
# Find all groups the user belongs to # Find all groups the user belongs to
groups = ( groups = (
@ -369,16 +370,16 @@ class GroupTable:
return False return False
def create_groups_by_group_names( def create_groups_by_group_names(
self, user_id: str, group_names: list[str] self, user_id: str, group_names: list[str], db: Optional[Session] = None
) -> list[GroupModel]: ) -> list[GroupModel]:
# check for existing groups # check for existing groups
existing_groups = self.get_all_groups() existing_groups = self.get_all_groups(db=db)
existing_group_names = {group.name for group in existing_groups} existing_group_names = {group.name for group in existing_groups}
new_groups = [] new_groups = []
with get_db() as db: with get_db_context(db) as db:
for group_name in group_names: for group_name in group_names:
if group_name not in existing_group_names: if group_name not in existing_group_names:
new_group = GroupModel( new_group = GroupModel(
@ -400,8 +401,8 @@ class GroupTable:
continue continue
return new_groups return new_groups
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: def sync_groups_by_group_names(self, user_id: str, group_names: list[str], db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
try: try:
now = int(time.time()) now = int(time.time())
@ -461,10 +462,10 @@ class GroupTable:
return False return False
def add_users_to_group( def add_users_to_group(
self, id: str, user_ids: Optional[list[str]] = None self, id: str, user_ids: Optional[list[str]] = None, db: Optional[Session] = None
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
group = db.query(Group).filter_by(id=id).first() group = db.query(Group).filter_by(id=id).first()
if not group: if not group:
return None return None
@ -499,10 +500,10 @@ class GroupTable:
return None return None
def remove_users_from_group( def remove_users_from_group(
self, id: str, user_ids: Optional[list[str]] = None self, id: str, user_ids: Optional[list[str]] = None, db: Optional[Session] = None
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
group = db.query(Group).filter_by(id=id).first() group = db.query(Group).filter_by(id=id).first()
if not group: if not group:
return None return None

View file

@ -1,10 +1,12 @@
import json import json
import logging import logging
import time import time
from typing import Optional from typing import Optional
import uuid import uuid
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.files import ( from open_webui.models.files import (
File, File,
@ -157,9 +159,9 @@ class KnowledgeFileListResponse(BaseModel):
class KnowledgeTable: class KnowledgeTable:
def insert_new_knowledge( def insert_new_knowledge(
self, user_id: str, form_data: KnowledgeForm self, user_id: str, form_data: KnowledgeForm, db: Optional[Session] = None
) -> Optional[KnowledgeModel]: ) -> Optional[KnowledgeModel]:
with get_db() as db: with get_db_context(db) as db:
knowledge = KnowledgeModel( knowledge = KnowledgeModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -183,15 +185,15 @@ class KnowledgeTable:
return None return None
def get_knowledge_bases( def get_knowledge_bases(
self, skip: int = 0, limit: int = 30 self, skip: int = 0, limit: int = 30, db: Optional[Session] = None
) -> list[KnowledgeUserModel]: ) -> list[KnowledgeUserModel]:
with get_db() as db: with get_db_context(db) as db:
all_knowledge = ( all_knowledge = (
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
) )
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge)) user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
users = Users.get_users_by_user_ids(user_ids) if user_ids else [] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users} users_dict = {user.id: user for user in users}
knowledge_bases = [] knowledge_bases = []
@ -208,10 +210,10 @@ class KnowledgeTable:
return knowledge_bases return knowledge_bases
def search_knowledge_bases( def search_knowledge_bases(
self, user_id: str, filter: dict, skip: int = 0, limit: int = 30 self, user_id: str, filter: dict, skip: int = 0, limit: int = 30, db: Optional[Session] = None
) -> KnowledgeListResponse: ) -> KnowledgeListResponse:
try: try:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Knowledge, User).outerjoin( query = db.query(Knowledge, User).outerjoin(
User, User.id == Knowledge.user_id User, User.id == Knowledge.user_id
) )
@ -267,14 +269,14 @@ class KnowledgeTable:
return KnowledgeListResponse(items=[], total=0) return KnowledgeListResponse(items=[], total=0)
def search_knowledge_files( def search_knowledge_files(
self, filter: dict, skip: int = 0, limit: int = 30 self, filter: dict, skip: int = 0, limit: int = 30, db: Optional[Session] = None
) -> KnowledgeFileListResponse: ) -> KnowledgeFileListResponse:
""" """
Scalable version: search files across all knowledge bases the user has Scalable version: search files across all knowledge bases the user has
READ access to, without loading all KBs or using large IN() lists. READ access to, without loading all KBs or using large IN() lists.
""" """
try: try:
with get_db() as db: with get_db_context(db) as db:
# Base query: join Knowledge → KnowledgeFile → File # Base query: join Knowledge → KnowledgeFile → File
query = ( query = (
db.query(File, User) db.query(File, User)
@ -327,20 +329,20 @@ class KnowledgeTable:
print("search_knowledge_files error:", e) print("search_knowledge_files error:", e)
return KnowledgeFileListResponse(items=[], total=0) return KnowledgeFileListResponse(items=[], total=0)
def check_access_by_user_id(self, id, user_id, permission="write") -> bool: def check_access_by_user_id(self, id, user_id, permission="write", db: Optional[Session] = None) -> bool:
knowledge = self.get_knowledge_by_id(id) knowledge = self.get_knowledge_by_id(id, db=db)
if not knowledge: if not knowledge:
return False return False
if knowledge.user_id == user_id: if knowledge.user_id == user_id:
return True return True
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
return has_access(user_id, permission, knowledge.access_control, user_group_ids) return has_access(user_id, permission, knowledge.access_control, user_group_ids)
def get_knowledge_bases_by_user_id( def get_knowledge_bases_by_user_id(
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write", db: Optional[Session] = None
) -> list[KnowledgeUserModel]: ) -> list[KnowledgeUserModel]:
knowledge_bases = self.get_knowledge_bases() knowledge_bases = self.get_knowledge_bases(db=db)
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
return [ return [
knowledge_base knowledge_base
for knowledge_base in knowledge_bases for knowledge_base in knowledge_bases
@ -350,32 +352,32 @@ class KnowledgeTable:
) )
] ]
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: def get_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
knowledge = db.query(Knowledge).filter_by(id=id).first() knowledge = db.query(Knowledge).filter_by(id=id).first()
return KnowledgeModel.model_validate(knowledge) if knowledge else None return KnowledgeModel.model_validate(knowledge) if knowledge else None
except Exception: except Exception:
return None return None
def get_knowledge_by_id_and_user_id( def get_knowledge_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[KnowledgeModel]: ) -> Optional[KnowledgeModel]:
knowledge = self.get_knowledge_by_id(id) knowledge = self.get_knowledge_by_id(id, db=db)
if not knowledge: if not knowledge:
return None return None
if knowledge.user_id == user_id: if knowledge.user_id == user_id:
return knowledge return knowledge
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
if has_access(user_id, "write", knowledge.access_control, user_group_ids): if has_access(user_id, "write", knowledge.access_control, user_group_ids):
return knowledge return knowledge
return None return None
def get_knowledges_by_file_id(self, file_id: str) -> list[KnowledgeModel]: def get_knowledges_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[KnowledgeModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
knowledges = ( knowledges = (
db.query(Knowledge) db.query(Knowledge)
.join(KnowledgeFile, Knowledge.id == KnowledgeFile.knowledge_id) .join(KnowledgeFile, Knowledge.id == KnowledgeFile.knowledge_id)
@ -395,9 +397,10 @@ class KnowledgeTable:
filter: dict, filter: dict,
skip: int = 0, skip: int = 0,
limit: int = 30, limit: int = 30,
db: Optional[Session] = None,
) -> KnowledgeFileListResponse: ) -> KnowledgeFileListResponse:
try: try:
with get_db() as db: with get_db_context(db) as db:
query = ( query = (
db.query(File, User) db.query(File, User)
.join(KnowledgeFile, File.id == KnowledgeFile.file_id) .join(KnowledgeFile, File.id == KnowledgeFile.file_id)
@ -470,9 +473,9 @@ class KnowledgeTable:
print(e) print(e)
return KnowledgeFileListResponse(items=[], total=0) return KnowledgeFileListResponse(items=[], total=0)
def get_files_by_id(self, knowledge_id: str) -> list[FileModel]: def get_files_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
files = ( files = (
db.query(File) db.query(File)
.join(KnowledgeFile, File.id == KnowledgeFile.file_id) .join(KnowledgeFile, File.id == KnowledgeFile.file_id)
@ -483,18 +486,18 @@ class KnowledgeTable:
except Exception: except Exception:
return [] return []
def get_file_metadatas_by_id(self, knowledge_id: str) -> list[FileMetadataResponse]: def get_file_metadatas_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileMetadataResponse]:
try: try:
with get_db() as db: with get_db_context(db) as db:
files = self.get_files_by_id(knowledge_id) files = self.get_files_by_id(knowledge_id, db=db)
return [FileMetadataResponse(**file.model_dump()) for file in files] return [FileMetadataResponse(**file.model_dump()) for file in files]
except Exception: except Exception:
return [] return []
def add_file_to_knowledge_by_id( def add_file_to_knowledge_by_id(
self, knowledge_id: str, file_id: str, user_id: str self, knowledge_id: str, file_id: str, user_id: str, db: Optional[Session] = None
) -> Optional[KnowledgeFileModel]: ) -> Optional[KnowledgeFileModel]:
with get_db() as db: with get_db_context(db) as db:
knowledge_file = KnowledgeFileModel( knowledge_file = KnowledgeFileModel(
**{ **{
"id": str(uuid.uuid4()), "id": str(uuid.uuid4()),
@ -518,9 +521,9 @@ class KnowledgeTable:
except Exception: except Exception:
return None return None
def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str) -> bool: def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(KnowledgeFile).filter_by( db.query(KnowledgeFile).filter_by(
knowledge_id=knowledge_id, file_id=file_id knowledge_id=knowledge_id, file_id=file_id
).delete() ).delete()
@ -529,9 +532,9 @@ class KnowledgeTable:
except Exception: except Exception:
return False return False
def reset_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: def reset_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
# Delete all knowledge_file entries for this knowledge_id # Delete all knowledge_file entries for this knowledge_id
db.query(KnowledgeFile).filter_by(knowledge_id=id).delete() db.query(KnowledgeFile).filter_by(knowledge_id=id).delete()
db.commit() db.commit()
@ -544,17 +547,17 @@ class KnowledgeTable:
) )
db.commit() db.commit()
return self.get_knowledge_by_id(id=id) return self.get_knowledge_by_id(id=id, db=db)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def update_knowledge_by_id( def update_knowledge_by_id(
self, id: str, form_data: KnowledgeForm, overwrite: bool = False self, id: str, form_data: KnowledgeForm, overwrite: bool = False, db: Optional[Session] = None
) -> Optional[KnowledgeModel]: ) -> Optional[KnowledgeModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
knowledge = self.get_knowledge_by_id(id=id) knowledge = self.get_knowledge_by_id(id=id, db=db)
db.query(Knowledge).filter_by(id=id).update( db.query(Knowledge).filter_by(id=id).update(
{ {
**form_data.model_dump(), **form_data.model_dump(),
@ -562,17 +565,17 @@ class KnowledgeTable:
} }
) )
db.commit() db.commit()
return self.get_knowledge_by_id(id=id) return self.get_knowledge_by_id(id=id, db=db)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def update_knowledge_data_by_id( def update_knowledge_data_by_id(
self, id: str, data: dict self, id: str, data: dict, db: Optional[Session] = None
) -> Optional[KnowledgeModel]: ) -> Optional[KnowledgeModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
knowledge = self.get_knowledge_by_id(id=id) knowledge = self.get_knowledge_by_id(id=id, db=db)
db.query(Knowledge).filter_by(id=id).update( db.query(Knowledge).filter_by(id=id).update(
{ {
"data": data, "data": data,
@ -580,22 +583,22 @@ class KnowledgeTable:
} }
) )
db.commit() db.commit()
return self.get_knowledge_by_id(id=id) return self.get_knowledge_by_id(id=id, db=db)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def delete_knowledge_by_id(self, id: str) -> bool: def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Knowledge).filter_by(id=id).delete() db.query(Knowledge).filter_by(id=id).delete()
db.commit() db.commit()
return True return True
except Exception: except Exception:
return False return False
def delete_all_knowledge(self) -> bool: def delete_all_knowledge(self, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
try: try:
db.query(Knowledge).delete() db.query(Knowledge).delete()
db.commit() db.commit()

View file

@ -2,7 +2,8 @@ import time
import uuid import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, get_db, get_db_context
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text from sqlalchemy import BigInteger, Column, String, Text
@ -41,8 +42,9 @@ class MemoriesTable:
self, self,
user_id: str, user_id: str,
content: str, content: str,
db: Optional[Session] = None,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db: with get_db_context(db) as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
memory = MemoryModel( memory = MemoryModel(
@ -68,8 +70,9 @@ class MemoriesTable:
id: str, id: str,
user_id: str, user_id: str,
content: str, content: str,
db: Optional[Session] = None,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db: with get_db_context(db) as db:
try: try:
memory = db.get(Memory, id) memory = db.get(Memory, id)
if not memory or memory.user_id != user_id: if not memory or memory.user_id != user_id:
@ -83,32 +86,32 @@ class MemoriesTable:
except Exception: except Exception:
return None return None
def get_memories(self) -> list[MemoryModel]: def get_memories(self, db: Optional[Session] = None) -> list[MemoryModel]:
with get_db() as db: with get_db_context(db) as db:
try: try:
memories = db.query(Memory).all() memories = db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
except Exception: except Exception:
return None return None
def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]: def get_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[MemoryModel]:
with get_db() as db: with get_db_context(db) as db:
try: try:
memories = db.query(Memory).filter_by(user_id=user_id).all() memories = db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
except Exception: except Exception:
return None return None
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: def get_memory_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MemoryModel]:
with get_db() as db: with get_db_context(db) as db:
try: try:
memory = db.get(Memory, id) memory = db.get(Memory, id)
return MemoryModel.model_validate(memory) return MemoryModel.model_validate(memory)
except Exception: except Exception:
return None return None
def delete_memory_by_id(self, id: str) -> bool: def delete_memory_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
try: try:
db.query(Memory).filter_by(id=id).delete() db.query(Memory).filter_by(id=id).delete()
db.commit() db.commit()
@ -118,8 +121,8 @@ class MemoriesTable:
except Exception: except Exception:
return False return False
def delete_memories_by_user_id(self, user_id: str) -> bool: def delete_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
try: try:
db.query(Memory).filter_by(user_id=user_id).delete() db.query(Memory).filter_by(user_id=user_id).delete()
db.commit() db.commit()
@ -128,8 +131,8 @@ class MemoriesTable:
except Exception: except Exception:
return False return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
try: try:
memory = db.get(Memory, id) memory = db.get(Memory, id)
if not memory or memory.user_id != user_id: if not memory or memory.user_id != user_id:

View file

@ -3,7 +3,8 @@ import time
import uuid import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.tags import TagModel, Tag, Tags from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.models.users import Users, User, UserNameResponse from open_webui.models.users import Users, User, UserNameResponse
from open_webui.models.channels import Channels, ChannelMember from open_webui.models.channels import Channels, ChannelMember
@ -137,9 +138,9 @@ class MessageResponse(MessageReplyToResponse):
class MessageTable: class MessageTable:
def insert_new_message( def insert_new_message(
self, form_data: MessageForm, channel_id: str, user_id: str self, form_data: MessageForm, channel_id: str, user_id: str, db: Optional[Session] = None
) -> Optional[MessageModel]: ) -> Optional[MessageModel]:
with get_db() as db: with get_db_context(db) as db:
channel_member = Channels.join_channel(channel_id, user_id) channel_member = Channels.join_channel(channel_id, user_id)
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -169,22 +170,22 @@ class MessageTable:
db.refresh(result) db.refresh(result)
return MessageModel.model_validate(result) if result else None return MessageModel.model_validate(result) if result else None
def get_message_by_id(self, id: str) -> Optional[MessageResponse]: def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MessageResponse]:
with get_db() as db: with get_db_context(db) as db:
message = db.get(Message, id) message = db.get(Message, id)
if not message: if not message:
return None return None
reply_to_message = ( reply_to_message = (
self.get_message_by_id(message.reply_to_id) self.get_message_by_id(message.reply_to_id, db=db)
if message.reply_to_id if message.reply_to_id
else None else None
) )
reactions = self.get_reactions_by_message_id(id) reactions = self.get_reactions_by_message_id(id, db=db)
thread_replies = self.get_thread_replies_by_message_id(id) thread_replies = self.get_thread_replies_by_message_id(id, db=db)
user = Users.get_user_by_id(message.user_id) user = Users.get_user_by_id(message.user_id, db=db)
return MessageResponse.model_validate( return MessageResponse.model_validate(
{ {
**MessageModel.model_validate(message).model_dump(), **MessageModel.model_validate(message).model_dump(),
@ -200,8 +201,8 @@ class MessageTable:
} }
) )
def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]: def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]:
with get_db() as db: with get_db_context(db) as db:
all_messages = ( all_messages = (
db.query(Message) db.query(Message)
.filter_by(parent_id=id) .filter_by(parent_id=id)
@ -212,7 +213,7 @@ class MessageTable:
messages = [] messages = []
for message in all_messages: for message in all_messages:
reply_to_message = ( reply_to_message = (
self.get_message_by_id(message.reply_to_id) self.get_message_by_id(message.reply_to_id, db=db)
if message.reply_to_id if message.reply_to_id
else None else None
) )
@ -230,17 +231,17 @@ class MessageTable:
) )
return messages return messages
def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
message.user_id message.user_id
for message in db.query(Message).filter_by(parent_id=id).all() for message in db.query(Message).filter_by(parent_id=id).all()
] ]
def get_messages_by_channel_id( def get_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50 self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
) -> list[MessageReplyToResponse]: ) -> list[MessageReplyToResponse]:
with get_db() as db: with get_db_context(db) as db:
all_messages = ( all_messages = (
db.query(Message) db.query(Message)
.filter_by(channel_id=channel_id, parent_id=None) .filter_by(channel_id=channel_id, parent_id=None)
@ -253,7 +254,7 @@ class MessageTable:
messages = [] messages = []
for message in all_messages: for message in all_messages:
reply_to_message = ( reply_to_message = (
self.get_message_by_id(message.reply_to_id) self.get_message_by_id(message.reply_to_id, db=db)
if message.reply_to_id if message.reply_to_id
else None else None
) )
@ -272,9 +273,9 @@ class MessageTable:
return messages return messages
def get_messages_by_parent_id( def get_messages_by_parent_id(
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
) -> list[MessageReplyToResponse]: ) -> list[MessageReplyToResponse]:
with get_db() as db: with get_db_context(db) as db:
message = db.get(Message, parent_id) message = db.get(Message, parent_id)
if not message: if not message:
@ -296,7 +297,7 @@ class MessageTable:
messages = [] messages = []
for message in all_messages: for message in all_messages:
reply_to_message = ( reply_to_message = (
self.get_message_by_id(message.reply_to_id) self.get_message_by_id(message.reply_to_id, db=db)
if message.reply_to_id if message.reply_to_id
else None else None
) )
@ -314,8 +315,8 @@ class MessageTable:
) )
return messages return messages
def get_last_message_by_channel_id(self, channel_id: str) -> Optional[MessageModel]: def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]:
with get_db() as db: with get_db_context(db) as db:
message = ( message = (
db.query(Message) db.query(Message)
.filter_by(channel_id=channel_id) .filter_by(channel_id=channel_id)
@ -325,9 +326,9 @@ class MessageTable:
return MessageModel.model_validate(message) if message else None return MessageModel.model_validate(message) if message else None
def get_pinned_messages_by_channel_id( def get_pinned_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50 self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
) -> list[MessageModel]: ) -> list[MessageModel]:
with get_db() as db: with get_db_context(db) as db:
all_messages = ( all_messages = (
db.query(Message) db.query(Message)
.filter_by(channel_id=channel_id, is_pinned=True) .filter_by(channel_id=channel_id, is_pinned=True)
@ -339,9 +340,9 @@ class MessageTable:
return [MessageModel.model_validate(message) for message in all_messages] return [MessageModel.model_validate(message) for message in all_messages]
def update_message_by_id( def update_message_by_id(
self, id: str, form_data: MessageForm self, id: str, form_data: MessageForm, db: Optional[Session] = None
) -> Optional[MessageModel]: ) -> Optional[MessageModel]:
with get_db() as db: with get_db_context(db) as db:
message = db.get(Message, id) message = db.get(Message, id)
message.content = form_data.content message.content = form_data.content
message.data = { message.data = {
@ -358,9 +359,9 @@ class MessageTable:
return MessageModel.model_validate(message) if message else None return MessageModel.model_validate(message) if message else None
def update_is_pinned_by_id( def update_is_pinned_by_id(
self, id: str, is_pinned: bool, pinned_by: Optional[str] = None self, id: str, is_pinned: bool, pinned_by: Optional[str] = None, db: Optional[Session] = None
) -> Optional[MessageModel]: ) -> Optional[MessageModel]:
with get_db() as db: with get_db_context(db) as db:
message = db.get(Message, id) message = db.get(Message, id)
message.is_pinned = is_pinned message.is_pinned = is_pinned
message.pinned_at = int(time.time_ns()) if is_pinned else None message.pinned_at = int(time.time_ns()) if is_pinned else None
@ -370,9 +371,9 @@ class MessageTable:
return MessageModel.model_validate(message) if message else None return MessageModel.model_validate(message) if message else None
def get_unread_message_count( def get_unread_message_count(
self, channel_id: str, user_id: str, last_read_at: Optional[int] = None self, channel_id: str, user_id: str, last_read_at: Optional[int] = None, db: Optional[Session] = None
) -> int: ) -> int:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Message).filter( query = db.query(Message).filter(
Message.channel_id == channel_id, Message.channel_id == channel_id,
Message.parent_id == None, # only count top-level messages Message.parent_id == None, # only count top-level messages
@ -383,9 +384,9 @@ class MessageTable:
return query.count() return query.count()
def add_reaction_to_message( def add_reaction_to_message(
self, id: str, user_id: str, name: str self, id: str, user_id: str, name: str, db: Optional[Session] = None
) -> Optional[MessageReactionModel]: ) -> Optional[MessageReactionModel]:
with get_db() as db: with get_db_context(db) as db:
# check for existing reaction # check for existing reaction
existing_reaction = ( existing_reaction = (
db.query(MessageReaction) db.query(MessageReaction)
@ -409,8 +410,8 @@ class MessageTable:
db.refresh(result) db.refresh(result)
return MessageReactionModel.model_validate(result) if result else None return MessageReactionModel.model_validate(result) if result else None
def get_reactions_by_message_id(self, id: str) -> list[Reactions]: def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]:
with get_db() as db: with get_db_context(db) as db:
# JOIN User so all user info is fetched in one query # JOIN User so all user info is fetched in one query
results = ( results = (
db.query(MessageReaction, User) db.query(MessageReaction, User)
@ -440,29 +441,29 @@ class MessageTable:
return [Reactions(**reaction) for reaction in reactions.values()] return [Reactions(**reaction) for reaction in reactions.values()]
def remove_reaction_by_id_and_user_id_and_name( def remove_reaction_by_id_and_user_id_and_name(
self, id: str, user_id: str, name: str self, id: str, user_id: str, name: str, db: Optional[Session] = None
) -> bool: ) -> bool:
with get_db() as db: with get_db_context(db) as db:
db.query(MessageReaction).filter_by( db.query(MessageReaction).filter_by(
message_id=id, user_id=user_id, name=name message_id=id, user_id=user_id, name=name
).delete() ).delete()
db.commit() db.commit()
return True return True
def delete_reactions_by_id(self, id: str) -> bool: def delete_reactions_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
db.query(MessageReaction).filter_by(message_id=id).delete() db.query(MessageReaction).filter_by(message_id=id).delete()
db.commit() db.commit()
return True return True
def delete_replies_by_id(self, id: str) -> bool: def delete_replies_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
db.query(Message).filter_by(parent_id=id).delete() db.query(Message).filter_by(parent_id=id).delete()
db.commit() db.commit()
return True return True
def delete_message_by_id(self, id: str) -> bool: def delete_message_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
db.query(Message).filter_by(id=id).delete() db.query(Message).filter_by(id=id).delete()
# Delete all reactions to this message # Delete all reactions to this message

View file

@ -2,7 +2,8 @@ import logging
import time import time
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
from open_webui.models.users import User, UserModel, Users, UserResponse from open_webui.models.users import User, UserModel, Users, UserResponse
@ -150,7 +151,7 @@ class ModelForm(BaseModel):
class ModelsTable: class ModelsTable:
def insert_new_model( def insert_new_model(
self, form_data: ModelForm, user_id: str self, form_data: ModelForm, user_id: str, db: Optional[Session] = None
) -> Optional[ModelModel]: ) -> Optional[ModelModel]:
model = ModelModel( model = ModelModel(
**{ **{
@ -161,7 +162,7 @@ class ModelsTable:
} }
) )
try: try:
with get_db() as db: with get_db_context(db) as db:
result = Model(**model.model_dump()) result = Model(**model.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
@ -175,17 +176,17 @@ class ModelsTable:
log.exception(f"Failed to insert a new model: {e}") log.exception(f"Failed to insert a new model: {e}")
return None return None
def get_all_models(self) -> list[ModelModel]: def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]:
with get_db() as db: with get_db_context(db) as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()] return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_models(self) -> list[ModelUserResponse]: def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]:
with get_db() as db: with get_db_context(db) as db:
all_models = db.query(Model).filter(Model.base_model_id != None).all() all_models = db.query(Model).filter(Model.base_model_id != None).all()
user_ids = list(set(model.user_id for model in all_models)) user_ids = list(set(model.user_id for model in all_models))
users = Users.get_users_by_user_ids(user_ids) if user_ids else [] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users} users_dict = {user.id: user for user in users}
models = [] models = []
@ -201,18 +202,18 @@ class ModelsTable:
) )
return models return models
def get_base_models(self) -> list[ModelModel]: def get_base_models(self, db: Optional[Session] = None) -> list[ModelModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
ModelModel.model_validate(model) ModelModel.model_validate(model)
for model in db.query(Model).filter(Model.base_model_id == None).all() for model in db.query(Model).filter(Model.base_model_id == None).all()
] ]
def get_models_by_user_id( def get_models_by_user_id(
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write", db: Optional[Session] = None
) -> list[ModelUserResponse]: ) -> list[ModelUserResponse]:
models = self.get_models() models = self.get_models(db=db)
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
return [ return [
model model
for model in models for model in models
@ -263,9 +264,9 @@ class ModelsTable:
return query return query
def search_models( def search_models(
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30 self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30, db: Optional[Session] = None
) -> ModelListResponse: ) -> ModelListResponse:
with get_db() as db: with get_db_context(db) as db:
# Join GroupMember so we can order by group_id when requested # Join GroupMember so we can order by group_id when requested
query = db.query(Model, User).outerjoin(User, User.id == Model.user_id) query = db.query(Model, User).outerjoin(User, User.id == Model.user_id)
query = query.filter(Model.base_model_id != None) query = query.filter(Model.base_model_id != None)
@ -349,24 +350,24 @@ class ModelsTable:
return ModelListResponse(items=models, total=total) return ModelListResponse(items=models, total=total)
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
model = db.get(Model, id) model = db.get(Model, id)
return ModelModel.model_validate(model) return ModelModel.model_validate(model)
except Exception: except Exception:
return None return None
def get_models_by_ids(self, ids: list[str]) -> list[ModelModel]: def get_models_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[ModelModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
models = db.query(Model).filter(Model.id.in_(ids)).all() models = db.query(Model).filter(Model.id.in_(ids)).all()
return [ModelModel.model_validate(model) for model in models] return [ModelModel.model_validate(model) for model in models]
except Exception: except Exception:
return [] return []
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: def toggle_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
with get_db() as db: with get_db_context(db) as db:
try: try:
is_active = db.query(Model).filter_by(id=id).first().is_active is_active = db.query(Model).filter_by(id=id).first().is_active
@ -378,13 +379,13 @@ class ModelsTable:
) )
db.commit() db.commit()
return self.get_model_by_id(id) return self.get_model_by_id(id, db=db)
except Exception: except Exception:
return None return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: def update_model_by_id(self, id: str, model: ModelForm, db: Optional[Session] = None) -> Optional[ModelModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
# update only the fields that are present in the model # update only the fields that are present in the model
data = model.model_dump(exclude={"id"}) data = model.model_dump(exclude={"id"})
result = db.query(Model).filter_by(id=id).update(data) result = db.query(Model).filter_by(id=id).update(data)
@ -398,9 +399,9 @@ class ModelsTable:
log.exception(f"Failed to update the model by id {id}: {e}") log.exception(f"Failed to update the model by id {id}: {e}")
return None return None
def delete_model_by_id(self, id: str) -> bool: def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Model).filter_by(id=id).delete() db.query(Model).filter_by(id=id).delete()
db.commit() db.commit()
@ -408,9 +409,9 @@ class ModelsTable:
except Exception: except Exception:
return False return False
def delete_all_models(self) -> bool: def delete_all_models(self, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Model).delete() db.query(Model).delete()
db.commit() db.commit()
@ -418,9 +419,9 @@ class ModelsTable:
except Exception: except Exception:
return False return False
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]: def sync_models(self, user_id: str, models: list[ModelModel], db: Optional[Session] = None) -> list[ModelModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
# Get existing models # Get existing models
existing_models = db.query(Model).all() existing_models = db.query(Model).all()
existing_ids = {model.id for model in existing_models} existing_ids = {model.id for model in existing_models}

View file

@ -4,7 +4,8 @@ import uuid
from typing import Optional from typing import Optional
from functools import lru_cache from functools import lru_cache
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, get_db, get_db_context
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
from open_webui.models.users import User, UserModel, Users, UserResponse from open_webui.models.users import User, UserModel, Users, UserResponse
@ -211,11 +212,9 @@ class NoteTable:
return query return query
def insert_new_note( def insert_new_note(
self, self, user_id: str, form_data: NoteForm, db: Optional[Session] = None
form_data: NoteForm,
user_id: str,
) -> Optional[NoteModel]: ) -> Optional[NoteModel]:
with get_db() as db: with get_db_context(db) as db:
note = NoteModel( note = NoteModel(
**{ **{
"id": str(uuid.uuid4()), "id": str(uuid.uuid4()),
@ -233,9 +232,9 @@ class NoteTable:
return note return note
def get_notes( def get_notes(
self, skip: Optional[int] = None, limit: Optional[int] = None self, skip: int = 0, limit: int = 50, db: Optional[Session] = None
) -> list[NoteModel]: ) -> list[NoteModel]:
with get_db() as db: with get_db_context(db) as db:
query = db.query(Note).order_by(Note.updated_at.desc()) query = db.query(Note).order_by(Note.updated_at.desc())
if skip is not None: if skip is not None:
query = query.offset(skip) query = query.offset(skip)
@ -333,10 +332,11 @@ class NoteTable:
self, self,
user_id: str, user_id: str,
permission: str = "read", permission: str = "read",
skip: Optional[int] = None, skip: int = 0,
limit: Optional[int] = None, limit: int = 50,
db: Optional[Session] = None,
) -> list[NoteModel]: ) -> list[NoteModel]:
with get_db() as db: with get_db_context(db) as db:
user_group_ids = [ user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id) group.id for group in Groups.get_groups_by_member_id(user_id)
] ]
@ -354,15 +354,17 @@ class NoteTable:
notes = query.all() notes = query.all()
return [NoteModel.model_validate(note) for note in notes] return [NoteModel.model_validate(note) for note in notes]
def get_note_by_id(self, id: str) -> Optional[NoteModel]: def get_note_by_id(
with get_db() as db: self, id: str, db: Optional[Session] = None
) -> Optional[NoteModel]:
with get_db_context(db) as db:
note = db.query(Note).filter(Note.id == id).first() note = db.query(Note).filter(Note.id == id).first()
return NoteModel.model_validate(note) if note else None return NoteModel.model_validate(note) if note else None
def update_note_by_id( def update_note_by_id(
self, id: str, form_data: NoteUpdateForm self, id: str, form_data: NoteUpdateForm, db: Optional[Session] = None
) -> Optional[NoteModel]: ) -> Optional[NoteModel]:
with get_db() as db: with get_db_context(db) as db:
note = db.query(Note).filter(Note.id == id).first() note = db.query(Note).filter(Note.id == id).first()
if not note: if not note:
return None return None
@ -384,11 +386,14 @@ class NoteTable:
db.commit() db.commit()
return NoteModel.model_validate(note) if note else None return NoteModel.model_validate(note) if note else None
def delete_note_by_id(self, id: str): def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: try:
db.query(Note).filter(Note.id == id).delete() with get_db_context(db) as db:
db.commit() db.query(Note).filter(Note.id == id).delete()
return True db.commit()
return True
except Exception:
return False
Notes = NoteTable() Notes = NoteTable()

View file

@ -8,7 +8,8 @@ import json
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, get_db, get_db_context
from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -109,10 +110,11 @@ class OAuthSessionTable:
user_id: str, user_id: str,
provider: str, provider: str,
token: dict, token: dict,
db: Optional[Session] = None,
) -> Optional[OAuthSessionModel]: ) -> Optional[OAuthSessionModel]:
"""Create a new OAuth session""" """Create a new OAuth session"""
try: try:
with get_db() as db: with get_db_context(db) as db:
current_time = int(time.time()) current_time = int(time.time())
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -141,10 +143,10 @@ class OAuthSessionTable:
log.error(f"Error creating OAuth session: {e}") log.error(f"Error creating OAuth session: {e}")
return None return None
def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]: def get_session_by_id(self, session_id: str, db: Optional[Session] = None) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID""" """Get OAuth session by ID"""
try: try:
with get_db() as db: with get_db_context(db) as db:
session = db.query(OAuthSession).filter_by(id=session_id).first() session = db.query(OAuthSession).filter_by(id=session_id).first()
if session: if session:
session.token = self._decrypt_token(session.token) session.token = self._decrypt_token(session.token)
@ -156,11 +158,11 @@ class OAuthSessionTable:
return None return None
def get_session_by_id_and_user_id( def get_session_by_id_and_user_id(
self, session_id: str, user_id: str self, session_id: str, user_id: str, db: Optional[Session] = None
) -> Optional[OAuthSessionModel]: ) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID and user ID""" """Get OAuth session by ID and user ID"""
try: try:
with get_db() as db: with get_db_context(db) as db:
session = ( session = (
db.query(OAuthSession) db.query(OAuthSession)
.filter_by(id=session_id, user_id=user_id) .filter_by(id=session_id, user_id=user_id)
@ -176,11 +178,11 @@ class OAuthSessionTable:
return None return None
def get_session_by_provider_and_user_id( def get_session_by_provider_and_user_id(
self, provider: str, user_id: str self, provider: str, user_id: str, db: Optional[Session] = None
) -> Optional[OAuthSessionModel]: ) -> Optional[OAuthSessionModel]:
"""Get OAuth session by provider and user ID""" """Get OAuth session by provider and user ID"""
try: try:
with get_db() as db: with get_db_context(db) as db:
session = ( session = (
db.query(OAuthSession) db.query(OAuthSession)
.filter_by(provider=provider, user_id=user_id) .filter_by(provider=provider, user_id=user_id)
@ -195,10 +197,10 @@ class OAuthSessionTable:
log.error(f"Error getting OAuth session by provider and user ID: {e}") log.error(f"Error getting OAuth session by provider and user ID: {e}")
return None return None
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]: def get_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user""" """Get all OAuth sessions for a user"""
try: try:
with get_db() as db: with get_db_context(db) as db:
sessions = db.query(OAuthSession).filter_by(user_id=user_id).all() sessions = db.query(OAuthSession).filter_by(user_id=user_id).all()
results = [] results = []
@ -213,11 +215,11 @@ class OAuthSessionTable:
return [] return []
def update_session_by_id( def update_session_by_id(
self, session_id: str, token: dict self, session_id: str, token: dict, db: Optional[Session] = None
) -> Optional[OAuthSessionModel]: ) -> Optional[OAuthSessionModel]:
"""Update OAuth session tokens""" """Update OAuth session tokens"""
try: try:
with get_db() as db: with get_db_context(db) as db:
current_time = int(time.time()) current_time = int(time.time())
db.query(OAuthSession).filter_by(id=session_id).update( db.query(OAuthSession).filter_by(id=session_id).update(
@ -239,10 +241,10 @@ class OAuthSessionTable:
log.error(f"Error updating OAuth session tokens: {e}") log.error(f"Error updating OAuth session tokens: {e}")
return None return None
def delete_session_by_id(self, session_id: str) -> bool: def delete_session_by_id(self, session_id: str, db: Optional[Session] = None) -> bool:
"""Delete an OAuth session""" """Delete an OAuth session"""
try: try:
with get_db() as db: with get_db_context(db) as db:
result = db.query(OAuthSession).filter_by(id=session_id).delete() result = db.query(OAuthSession).filter_by(id=session_id).delete()
db.commit() db.commit()
return result > 0 return result > 0
@ -250,10 +252,10 @@ class OAuthSessionTable:
log.error(f"Error deleting OAuth session: {e}") log.error(f"Error deleting OAuth session: {e}")
return False return False
def delete_sessions_by_user_id(self, user_id: str) -> bool: def delete_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
"""Delete all OAuth sessions for a user""" """Delete all OAuth sessions for a user"""
try: try:
with get_db() as db: with get_db_context(db) as db:
result = db.query(OAuthSession).filter_by(user_id=user_id).delete() result = db.query(OAuthSession).filter_by(user_id=user_id).delete()
db.commit() db.commit()
return True return True
@ -261,10 +263,10 @@ class OAuthSessionTable:
log.error(f"Error deleting OAuth sessions by user ID: {e}") log.error(f"Error deleting OAuth sessions by user ID: {e}")
return False return False
def delete_sessions_by_provider(self, provider: str) -> bool: def delete_sessions_by_provider(self, provider: str, db: Optional[Session] = None) -> bool:
"""Delete all OAuth sessions for a provider""" """Delete all OAuth sessions for a provider"""
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(OAuthSession).filter_by(provider=provider).delete() db.query(OAuthSession).filter_by(provider=provider).delete()
db.commit() db.commit()
return True return True

View file

@ -1,7 +1,9 @@
import time import time
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
from open_webui.models.users import Users, UserResponse from open_webui.models.users import Users, UserResponse
@ -71,7 +73,7 @@ class PromptForm(BaseModel):
class PromptsTable: class PromptsTable:
def insert_new_prompt( def insert_new_prompt(
self, user_id: str, form_data: PromptForm self, user_id: str, form_data: PromptForm, db: Optional[Session] = None
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
prompt = PromptModel( prompt = PromptModel(
**{ **{
@ -82,7 +84,7 @@ class PromptsTable:
) )
try: try:
with get_db() as db: with get_db_context(db) as db:
result = Prompt(**prompt.model_dump()) result = Prompt(**prompt.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
@ -94,21 +96,21 @@ class PromptsTable:
except Exception: except Exception:
return None return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: def get_prompt_by_command(self, command: str, db: Optional[Session] = None) -> Optional[PromptModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
prompt = db.query(Prompt).filter_by(command=command).first() prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt) return PromptModel.model_validate(prompt)
except Exception: except Exception:
return None return None
def get_prompts(self) -> list[PromptUserResponse]: def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]:
with get_db() as db: with get_db_context(db) as db:
all_prompts = db.query(Prompt).order_by(Prompt.timestamp.desc()).all() all_prompts = db.query(Prompt).order_by(Prompt.timestamp.desc()).all()
user_ids = list(set(prompt.user_id for prompt in all_prompts)) user_ids = list(set(prompt.user_id for prompt in all_prompts))
users = Users.get_users_by_user_ids(user_ids) if user_ids else [] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users} users_dict = {user.id: user for user in users}
prompts = [] prompts = []
@ -126,10 +128,10 @@ class PromptsTable:
return prompts return prompts
def get_prompts_by_user_id( def get_prompts_by_user_id(
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write", db: Optional[Session] = None
) -> list[PromptUserResponse]: ) -> list[PromptUserResponse]:
prompts = self.get_prompts() prompts = self.get_prompts(db=db)
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
return [ return [
prompt prompt
@ -139,10 +141,10 @@ class PromptsTable:
] ]
def update_prompt_by_command( def update_prompt_by_command(
self, command: str, form_data: PromptForm self, command: str, form_data: PromptForm, db: Optional[Session] = None
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
prompt = db.query(Prompt).filter_by(command=command).first() prompt = db.query(Prompt).filter_by(command=command).first()
prompt.title = form_data.title prompt.title = form_data.title
prompt.content = form_data.content prompt.content = form_data.content
@ -153,9 +155,9 @@ class PromptsTable:
except Exception: except Exception:
return None return None
def delete_prompt_by_command(self, command: str) -> bool: def delete_prompt_by_command(self, command: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Prompt).filter_by(command=command).delete() db.query(Prompt).filter_by(command=command).delete()
db.commit() db.commit()

View file

@ -3,7 +3,8 @@ import time
import uuid import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -50,8 +51,8 @@ class TagChatIdForm(BaseModel):
class TagTable: class TagTable:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: def insert_new_tag(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
with get_db() as db: with get_db_context(db) as db:
id = name.replace(" ", "_").lower() id = name.replace(" ", "_").lower()
tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: try:
@ -68,27 +69,27 @@ class TagTable:
return None return None
def get_tag_by_name_and_user_id( def get_tag_by_name_and_user_id(
self, name: str, user_id: str self, name: str, user_id: str, db: Optional[Session] = None
) -> Optional[TagModel]: ) -> Optional[TagModel]:
try: try:
id = name.replace(" ", "_").lower() id = name.replace(" ", "_").lower()
with get_db() as db: with get_db_context(db) as db:
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first() tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
return TagModel.model_validate(tag) return TagModel.model_validate(tag)
except Exception: except Exception:
return None return None
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: def get_tags_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[TagModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in (db.query(Tag).filter_by(user_id=user_id).all()) for tag in (db.query(Tag).filter_by(user_id=user_id).all())
] ]
def get_tags_by_ids_and_user_id( def get_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str self, ids: list[str], user_id: str, db: Optional[Session] = None
) -> list[TagModel]: ) -> list[TagModel]:
with get_db() as db: with get_db_context(db) as db:
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in ( for tag in (
@ -96,9 +97,9 @@ class TagTable:
) )
] ]
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool: def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
id = name.replace(" ", "_").lower() id = name.replace(" ", "_").lower()
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete() res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
log.debug(f"res: {res}") log.debug(f"res: {res}")

View file

@ -2,7 +2,8 @@ import logging
import time import time
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.users import Users, UserResponse from open_webui.models.users import Users, UserResponse
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
@ -110,9 +111,9 @@ class ToolValves(BaseModel):
class ToolsTable: class ToolsTable:
def insert_new_tool( def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: list[dict] self, user_id: str, form_data: ToolForm, specs: list[dict], db: Optional[Session] = None
) -> Optional[ToolModel]: ) -> Optional[ToolModel]:
with get_db() as db: with get_db_context(db) as db:
tool = ToolModel( tool = ToolModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -136,21 +137,21 @@ class ToolsTable:
log.exception(f"Error creating a new tool: {e}") log.exception(f"Error creating a new tool: {e}")
return None return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]: def get_tool_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ToolModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
tool = db.get(Tool, id) tool = db.get(Tool, id)
return ToolModel.model_validate(tool) return ToolModel.model_validate(tool)
except Exception: except Exception:
return None return None
def get_tools(self) -> list[ToolUserModel]: def get_tools(self, db: Optional[Session] = None) -> list[ToolUserModel]:
with get_db() as db: with get_db_context(db) as db:
all_tools = db.query(Tool).order_by(Tool.updated_at.desc()).all() all_tools = db.query(Tool).order_by(Tool.updated_at.desc()).all()
user_ids = list(set(tool.user_id for tool in all_tools)) user_ids = list(set(tool.user_id for tool in all_tools))
users = Users.get_users_by_user_ids(user_ids) if user_ids else [] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users} users_dict = {user.id: user for user in users}
tools = [] tools = []
@ -167,10 +168,10 @@ class ToolsTable:
return tools return tools
def get_tools_by_user_id( def get_tools_by_user_id(
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write", db: Optional[Session] = None
) -> list[ToolUserModel]: ) -> list[ToolUserModel]:
tools = self.get_tools() tools = self.get_tools(db=db)
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
return [ return [
tool tool
@ -179,31 +180,31 @@ class ToolsTable:
or has_access(user_id, permission, tool.access_control, user_group_ids) or has_access(user_id, permission, tool.access_control, user_group_ids)
] ]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
try: try:
with get_db() as db: with get_db_context(db) as db:
tool = db.get(Tool, id) tool = db.get(Tool, id)
return tool.valves if tool.valves else {} return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
log.exception(f"Error getting tool valves by id {id}") log.exception(f"Error getting tool valves by id {id}")
return None return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict, db: Optional[Session] = None) -> Optional[ToolValves]:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Tool).filter_by(id=id).update( db.query(Tool).filter_by(id=id).update(
{"valves": valves, "updated_at": int(time.time())} {"valves": valves, "updated_at": int(time.time())}
) )
db.commit() db.commit()
return self.get_tool_by_id(id) return self.get_tool_by_id(id, db=db)
except Exception: except Exception:
return None return None
def get_user_valves_by_id_and_user_id( def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings # Check if user has "tools" and "valves" settings
@ -220,10 +221,10 @@ class ToolsTable:
return None return None
def update_user_valves_by_id_and_user_id( def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict self, id: str, user_id: str, valves: dict, db: Optional[Session] = None
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings # Check if user has "tools" and "valves" settings
@ -235,7 +236,7 @@ class ToolsTable:
user_settings["tools"]["valves"][id] = valves user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database # Update the user settings in the database
Users.update_user_by_id(user_id, {"settings": user_settings}) Users.update_user_by_id(user_id, {"settings": user_settings}, db=db)
return user_settings["tools"]["valves"][id] return user_settings["tools"]["valves"][id]
except Exception as e: except Exception as e:
@ -244,9 +245,9 @@ class ToolsTable:
) )
return None return None
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[ToolModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Tool).filter_by(id=id).update( db.query(Tool).filter_by(id=id).update(
{**updated, "updated_at": int(time.time())} {**updated, "updated_at": int(time.time())}
) )
@ -258,9 +259,9 @@ class ToolsTable:
except Exception: except Exception:
return None return None
def delete_tool_by_id(self, id: str) -> bool: def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(Tool).filter_by(id=id).delete() db.query(Tool).filter_by(id=id).delete()
db.commit() db.commit()

View file

@ -1,7 +1,8 @@
import time import time
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
@ -243,8 +244,9 @@ class UsersTable:
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth: Optional[dict] = None, oauth: Optional[dict] = None,
db: Optional[Session] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db: with get_db_context(db) as db:
user = UserModel( user = UserModel(
**{ **{
"id": id, "id": id,
@ -267,17 +269,17 @@ class UsersTable:
else: else:
return None return None
def get_user_by_id(self, id: str) -> Optional[UserModel]: def get_user_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def get_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
user = ( user = (
db.query(User) db.query(User)
.join(ApiKey, User.id == ApiKey.user_id) .join(ApiKey, User.id == ApiKey.user_id)
@ -288,17 +290,17 @@ class UsersTable:
except Exception: except Exception:
return None return None
def get_user_by_email(self, email: str) -> Optional[UserModel]: def get_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
user = db.query(User).filter_by(email=email).first() user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[Session] = None) -> Optional[UserModel]:
try: try:
with get_db() as db: # type: Session with get_db_context(db) as db: # type: Session
dialect_name = db.bind.dialect.name dialect_name = db.bind.dialect.name
query = db.query(User) query = db.query(User)
@ -320,8 +322,9 @@ class UsersTable:
filter: Optional[dict] = None, filter: Optional[dict] = None,
skip: Optional[int] = None, skip: Optional[int] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
db: Optional[Session] = None,
) -> dict: ) -> dict:
with get_db() as db: with get_db_context(db) as db:
# Join GroupMember so we can order by group_id when requested # Join GroupMember so we can order by group_id when requested
query = db.query(User) query = db.query(User)
@ -452,8 +455,8 @@ class UsersTable:
"total": total, "total": total,
} }
def get_users_by_group_id(self, group_id: str) -> list[UserModel]: def get_users_by_group_id(self, group_id: str, db: Optional[Session] = None) -> list[UserModel]:
with get_db() as db: with get_db_context(db) as db:
users = ( users = (
db.query(User) db.query(User)
.join(GroupMember, User.id == GroupMember.user_id) .join(GroupMember, User.id == GroupMember.user_id)
@ -462,30 +465,30 @@ class UsersTable:
) )
return [UserModel.model_validate(user) for user in users] return [UserModel.model_validate(user) for user in users]
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserStatusModel]: def get_users_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[UserStatusModel]:
with get_db() as db: with get_db_context(db) as db:
users = db.query(User).filter(User.id.in_(user_ids)).all() users = db.query(User).filter(User.id.in_(user_ids)).all()
return [UserModel.model_validate(user) for user in users] return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]: def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:
with get_db() as db: with get_db_context(db) as db:
return db.query(User).count() return db.query(User).count()
def has_users(self) -> bool: def has_users(self, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
return db.query(db.query(User).exists()).scalar() return db.query(db.query(User).exists()).scalar()
def get_first_user(self) -> UserModel: def get_first_user(self, db: Optional[Session] = None) -> UserModel:
try: try:
with get_db() as db: with get_db_context(db) as db:
user = db.query(User).order_by(User.created_at).first() user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
def get_user_webhook_url_by_id(self, id: str) -> Optional[str]: def get_user_webhook_url_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
try: try:
with get_db() as db: with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
if user.settings is None: if user.settings is None:
@ -499,8 +502,8 @@ class UsersTable:
except Exception: except Exception:
return None return None
def get_num_users_active_today(self) -> Optional[int]: def get_num_users_active_today(self, db: Optional[Session] = None) -> Optional[int]:
with get_db() as db: with get_db_context(db) as db:
current_timestamp = int(datetime.datetime.now().timestamp()) current_timestamp = int(datetime.datetime.now().timestamp())
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400) today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
query = db.query(User).filter( query = db.query(User).filter(
@ -508,9 +511,9 @@ class UsersTable:
) )
return query.count() return query.count()
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: def update_user_role_by_id(self, id: str, role: str, db: Optional[Session] = None) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(User).filter_by(id=id).update({"role": role}) db.query(User).filter_by(id=id).update({"role": role})
db.commit() db.commit()
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
@ -519,10 +522,10 @@ class UsersTable:
return None return None
def update_user_status_by_id( def update_user_status_by_id(
self, id: str, form_data: UserStatus self, id: str, form_data: UserStatus, db: Optional[Session] = None
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(User).filter_by(id=id).update( db.query(User).filter_by(id=id).update(
{**form_data.model_dump(exclude_none=True)} {**form_data.model_dump(exclude_none=True)}
) )
@ -534,10 +537,10 @@ class UsersTable:
return None return None
def update_user_profile_image_url_by_id( def update_user_profile_image_url_by_id(
self, id: str, profile_image_url: str self, id: str, profile_image_url: str, db: Optional[Session] = None
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(User).filter_by(id=id).update( db.query(User).filter_by(id=id).update(
{"profile_image_url": profile_image_url} {"profile_image_url": profile_image_url}
) )
@ -549,9 +552,9 @@ class UsersTable:
return None return None
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
def update_last_active_by_id(self, id: str) -> Optional[UserModel]: def update_last_active_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(User).filter_by(id=id).update( db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())} {"last_active_at": int(time.time())}
) )
@ -563,7 +566,7 @@ class UsersTable:
return None return None
def update_user_oauth_by_id( def update_user_oauth_by_id(
self, id: str, provider: str, sub: str self, id: str, provider: str, sub: str, db: Optional[Session] = None
) -> Optional[UserModel]: ) -> Optional[UserModel]:
""" """
Update or insert an OAuth provider/sub pair into the user's oauth JSON field. Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
@ -574,7 +577,7 @@ class UsersTable:
} }
""" """
try: try:
with get_db() as db: with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
if not user: if not user:
return None return None
@ -594,9 +597,9 @@ class UsersTable:
except Exception: except Exception:
return None return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(User).filter_by(id=id).update(updated) db.query(User).filter_by(id=id).update(updated)
db.commit() db.commit()
@ -607,9 +610,9 @@ class UsersTable:
print(e) print(e)
return None return None
def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db_context(db) as db:
user_settings = db.query(User).filter_by(id=id).first().settings user_settings = db.query(User).filter_by(id=id).first().settings
if user_settings is None: if user_settings is None:
@ -625,15 +628,15 @@ class UsersTable:
except Exception: except Exception:
return None return None
def delete_user_by_id(self, id: str) -> bool: def delete_user_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
# Remove User from Groups # Remove User from Groups
Groups.remove_user_from_all_groups(id) Groups.remove_user_from_all_groups(id)
# Delete User Chats # Delete User Chats
result = Chats.delete_chats_by_user_id(id) result = Chats.delete_chats_by_user_id(id, db=db)
if result: if result:
with get_db() as db: with get_db_context(db) as db:
# Delete User # Delete User
db.query(User).filter_by(id=id).delete() db.query(User).filter_by(id=id).delete()
db.commit() db.commit()
@ -644,17 +647,17 @@ class UsersTable:
except Exception: except Exception:
return False return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]: def get_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
try: try:
with get_db() as db: with get_db_context(db) as db:
api_key = db.query(ApiKey).filter_by(user_id=id).first() api_key = db.query(ApiKey).filter_by(user_id=id).first()
return api_key.key if api_key else None return api_key.key if api_key else None
except Exception: except Exception:
return None return None
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool: def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(ApiKey).filter_by(user_id=id).delete() db.query(ApiKey).filter_by(user_id=id).delete()
db.commit() db.commit()
@ -674,30 +677,30 @@ class UsersTable:
except Exception: except Exception:
return False return False
def delete_user_api_key_by_id(self, id: str) -> bool: def delete_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db() as db: with get_db_context(db) as db:
db.query(ApiKey).filter_by(user_id=id).delete() db.query(ApiKey).filter_by(user_id=id).delete()
db.commit() db.commit()
return True return True
except Exception: except Exception:
return False return False
def get_valid_user_ids(self, user_ids: list[str]) -> list[str]: def get_valid_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[str]:
with get_db() as db: with get_db_context(db) as db:
users = db.query(User).filter(User.id.in_(user_ids)).all() users = db.query(User).filter(User.id.in_(user_ids)).all()
return [user.id for user in users] return [user.id for user in users]
def get_super_admin_user(self) -> Optional[UserModel]: def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]:
with get_db() as db: with get_db_context(db) as db:
user = db.query(User).filter_by(role="admin").first() user = db.query(User).filter_by(role="admin").first()
if user: if user:
return UserModel.model_validate(user) return UserModel.model_validate(user)
else: else:
return None return None
def get_active_user_count(self) -> int: def get_active_user_count(self, db: Optional[Session] = None) -> int:
with get_db() as db: with get_db_context(db) as db:
# Consider user active if last_active_at within the last 3 minutes # Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180 three_minutes_ago = int(time.time()) - 180
count = ( count = (
@ -705,8 +708,8 @@ class UsersTable:
) )
return count return count
def is_user_active(self, user_id: str) -> bool: def is_user_active(self, user_id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
user = db.query(User).filter_by(id=user_id).first() user = db.query(User).filter_by(id=user_id).first()
if user and user.last_active_at: if user and user.last_active_at:
# Consider user active if last_active_at within the last 3 minutes # Consider user active if last_active_at within the last 3 minutes

View file

@ -142,7 +142,7 @@ async def create_new_note(
) )
try: try:
note = Notes.insert_new_note(form_data, user.id) note = Notes.insert_new_note(user.id, form_data)
return note return note
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)