mirror of
https://github.com/open-webui/open-webui.git
synced 2026-01-02 14:45:18 +00:00
refac/enh: db session sharing
This commit is contained in:
parent
d4de26bd05
commit
2041ab483e
20 changed files with 600 additions and 562 deletions
|
|
@ -160,3 +160,13 @@ def get_session():
|
|||
|
||||
|
||||
get_db = contextmanager(get_session)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_context(db: Optional[Session] = None):
|
||||
if db:
|
||||
yield db
|
||||
else:
|
||||
with get_db() as session:
|
||||
yield session
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import logging
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
from open_webui.models.users import UserModel, UserProfileImageResponse, Users
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Boolean, Column, String, Text
|
||||
|
|
@ -87,8 +88,9 @@ class AuthsTable:
|
|||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
oauth: Optional[dict] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[UserModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
log.info("insert_new_auth")
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
|
@ -100,7 +102,7 @@ class AuthsTable:
|
|||
db.add(result)
|
||||
|
||||
user = Users.insert_new_user(
|
||||
id, name, email, profile_image_url, role, oauth=oauth
|
||||
id, name, email, profile_image_url, role, oauth=oauth, db=db
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
|
@ -112,16 +114,16 @@ class AuthsTable:
|
|||
return None
|
||||
|
||||
def authenticate_user(
|
||||
self, email: str, verify_password: callable
|
||||
self, email: str, verify_password: callable, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user: {email}")
|
||||
|
||||
user = Users.get_user_by_email(email)
|
||||
user = Users.get_user_by_email(email, db=db)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
auth = db.query(Auth).filter_by(id=user.id, active=True).first()
|
||||
if auth:
|
||||
if verify_password(auth.password):
|
||||
|
|
@ -133,32 +135,32 @@ class AuthsTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||
def authenticate_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_api_key: {api_key}")
|
||||
# if no api_key, return None
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
user = Users.get_user_by_api_key(api_key)
|
||||
user = Users.get_user_by_api_key(api_key, db=db)
|
||||
return user if user else None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def authenticate_user_by_email(self, email: str) -> Optional[UserModel]:
|
||||
def authenticate_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_email: {email}")
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
auth = db.query(Auth).filter_by(email=email, active=True).first()
|
||||
if auth:
|
||||
user = Users.get_user_by_id(auth.id)
|
||||
user = Users.get_user_by_id(auth.id, db=db)
|
||||
return user
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
|
||||
def update_user_password_by_id(self, id: str, new_password: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
result = (
|
||||
db.query(Auth).filter_by(id=id).update({"password": new_password})
|
||||
)
|
||||
|
|
@ -167,20 +169,20 @@ class AuthsTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def update_email_by_id(self, id: str, email: str) -> bool:
|
||||
def update_email_by_id(self, id: str, email: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
result = db.query(Auth).filter_by(id=id).update({"email": email})
|
||||
db.commit()
|
||||
return True if result == 1 else False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_auth_by_id(self, id: str) -> bool:
|
||||
def delete_auth_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# Delete User
|
||||
result = Users.delete_user_by_id(id)
|
||||
result = Users.delete_user_by_id(id, db=db)
|
||||
|
||||
if result:
|
||||
db.query(Auth).filter_by(id=id).delete()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
|
@ -337,8 +339,8 @@ class ChannelTable:
|
|||
db.commit()
|
||||
return channel
|
||||
|
||||
def get_channels(self) -> list[ChannelModel]:
|
||||
with get_db() as db:
|
||||
def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
channels = db.query(Channel).all()
|
||||
return [ChannelModel.model_validate(channel) for channel in channels]
|
||||
|
||||
|
|
@ -384,8 +386,8 @@ class ChannelTable:
|
|||
|
||||
return query
|
||||
|
||||
def get_channels_by_user_id(self, user_id: str) -> list[ChannelModel]:
|
||||
with get_db() as db:
|
||||
def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
user_group_ids = [
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id)
|
||||
]
|
||||
|
|
@ -683,10 +685,13 @@ class ChannelTable:
|
|||
)
|
||||
return membership is not None
|
||||
|
||||
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||
return ChannelModel.model_validate(channel) if channel else None
|
||||
def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||
return ChannelModel.model_validate(channel) if channel else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]:
|
||||
with get_db() as db:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ import time
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
from open_webui.models.folders import Folders
|
||||
from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db
|
||||
|
|
@ -280,8 +281,8 @@ class ChatTable:
|
|||
|
||||
return changed
|
||||
|
||||
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
|
||||
with get_db() as db:
|
||||
def insert_new_chat(self, user_id: str, form_data: ChatForm, db: Optional[Session] = None) -> Optional[ChatModel]:
|
||||
with get_db_context(db) as db:
|
||||
id = str(uuid.uuid4())
|
||||
chat = ChatModel(
|
||||
**{
|
||||
|
|
@ -331,9 +332,9 @@ class ChatTable:
|
|||
return chat
|
||||
|
||||
def import_chats(
|
||||
self, user_id: str, chat_import_forms: list[ChatImportForm]
|
||||
self, user_id: str, chat_import_forms: list[ChatImportForm], db: Optional[Session] = None
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chats = []
|
||||
|
||||
for form_data in chat_import_forms:
|
||||
|
|
@ -344,9 +345,9 @@ class ChatTable:
|
|||
db.commit()
|
||||
return [ChatModel.model_validate(chat) for chat in chats]
|
||||
|
||||
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
||||
def update_chat_by_id(self, id: str, chat: dict, db: Optional[Session] = None) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chat_item = db.get(Chat, id)
|
||||
chat_item.chat = self._clean_null_bytes(chat)
|
||||
chat_item.title = (
|
||||
|
|
@ -483,13 +484,13 @@ class ChatTable:
|
|||
self.update_chat_by_id(id, chat)
|
||||
return message_files
|
||||
|
||||
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||
with get_db() as db:
|
||||
def insert_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
|
||||
with get_db_context(db) as db:
|
||||
# Get the existing chat to share
|
||||
chat = db.get(Chat, chat_id)
|
||||
# Check if the chat is already shared
|
||||
if chat.share_id:
|
||||
return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
|
||||
return self.get_chat_by_id_and_user_id(chat.share_id, "shared", db=db)
|
||||
# Create a new chat with the same data, but with a new ID
|
||||
shared_chat = ChatModel(
|
||||
**{
|
||||
|
|
@ -518,16 +519,16 @@ class ChatTable:
|
|||
db.commit()
|
||||
return shared_chat if (shared_result and result) else None
|
||||
|
||||
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||
def update_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.get(Chat, chat_id)
|
||||
shared_chat = (
|
||||
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first()
|
||||
)
|
||||
|
||||
if shared_chat is None:
|
||||
return self.insert_shared_chat_by_chat_id(chat_id)
|
||||
return self.insert_shared_chat_by_chat_id(chat_id, db=db)
|
||||
|
||||
shared_chat.title = chat.title
|
||||
shared_chat.chat = chat.chat
|
||||
|
|
@ -542,9 +543,9 @@ class ChatTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
|
||||
def delete_shared_chat_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
|
||||
db.commit()
|
||||
|
||||
|
|
@ -552,9 +553,9 @@ class ChatTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def unarchive_all_chats_by_user_id(self, user_id: str) -> bool:
|
||||
def unarchive_all_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Chat).filter_by(user_id=user_id).update({"archived": False})
|
||||
db.commit()
|
||||
return True
|
||||
|
|
@ -562,10 +563,10 @@ class ChatTable:
|
|||
return False
|
||||
|
||||
def update_chat_share_id_by_id(
|
||||
self, id: str, share_id: Optional[str]
|
||||
self, id: str, share_id: Optional[str], db: Optional[Session] = None
|
||||
) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.share_id = share_id
|
||||
db.commit()
|
||||
|
|
@ -574,9 +575,9 @@ class ChatTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
|
||||
def toggle_chat_pinned_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.pinned = not chat.pinned
|
||||
chat.updated_at = int(time.time())
|
||||
|
|
@ -586,9 +587,9 @@ class ChatTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
|
||||
def toggle_chat_archive_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.archived = not chat.archived
|
||||
chat.folder_id = None
|
||||
|
|
@ -599,9 +600,9 @@ class ChatTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def archive_all_chats_by_user_id(self, user_id: str) -> bool:
|
||||
def archive_all_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
|
||||
db.commit()
|
||||
return True
|
||||
|
|
@ -614,9 +615,10 @@ class ChatTable:
|
|||
filter: Optional[dict] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[ChatModel]:
|
||||
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id, archived=True)
|
||||
|
||||
if filter:
|
||||
|
|
@ -655,8 +657,9 @@ class ChatTable:
|
|||
filter: Optional[dict] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
if not include_archived:
|
||||
query = query.filter_by(archived=False)
|
||||
|
|
@ -695,8 +698,9 @@ class ChatTable:
|
|||
include_pinned: bool = False,
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
|
||||
if not include_folders:
|
||||
|
|
@ -733,9 +737,9 @@ class ChatTable:
|
|||
]
|
||||
|
||||
def get_chat_list_by_chat_ids(
|
||||
self, chat_ids: list[str], skip: int = 0, limit: int = 50
|
||||
self, chat_ids: list[str], skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
.filter(Chat.id.in_(chat_ids))
|
||||
|
|
@ -745,9 +749,9 @@ class ChatTable:
|
|||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
|
||||
def get_chat_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chat_item = db.get(Chat, id)
|
||||
if chat_item is None:
|
||||
return None
|
||||
|
|
@ -760,30 +764,30 @@ class ChatTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
|
||||
def get_chat_by_share_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# it is possible that the shared link was deleted. hence,
|
||||
# we check if the chat is still shared by checking if a chat with the share_id exists
|
||||
chat = db.query(Chat).filter_by(share_id=id).first()
|
||||
|
||||
if chat:
|
||||
return self.get_chat_by_id(id)
|
||||
return self.get_chat_by_id(id, db=db)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
|
||||
def get_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
def get_chats(self, skip: int = 0, limit: int = 50, db: Optional[Session] = None) -> list[ChatModel]:
|
||||
with get_db_context(db) as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
# .limit(limit).offset(skip)
|
||||
|
|
@ -797,8 +801,9 @@ class ChatTable:
|
|||
filter: Optional[dict] = None,
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> ChatListResponse:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
|
||||
if filter:
|
||||
|
|
@ -838,8 +843,8 @@ class ChatTable:
|
|||
}
|
||||
)
|
||||
|
||||
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
def get_pinned_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChatModel]:
|
||||
with get_db_context(db) as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
.filter_by(user_id=user_id, pinned=True, archived=False)
|
||||
|
|
@ -847,8 +852,8 @@ class ChatTable:
|
|||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
def get_archived_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChatModel]:
|
||||
with get_db_context(db) as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
.filter_by(user_id=user_id, archived=True)
|
||||
|
|
@ -863,6 +868,7 @@ class ChatTable:
|
|||
include_archived: bool = False,
|
||||
skip: int = 0,
|
||||
limit: int = 60,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[ChatModel]:
|
||||
"""
|
||||
Filters chats based on a search query using Python, allowing pagination using skip and limit.
|
||||
|
|
@ -871,7 +877,7 @@ class ChatTable:
|
|||
|
||||
if not search_text:
|
||||
return self.get_chat_list_by_user_id(
|
||||
user_id, include_archived, filter={}, skip=skip, limit=limit
|
||||
user_id, include_archived, filter={}, skip=skip, limit=limit, db=db
|
||||
)
|
||||
|
||||
search_text_words = search_text.split(" ")
|
||||
|
|
@ -926,7 +932,7 @@ class ChatTable:
|
|||
|
||||
search_text = " ".join(search_text_words)
|
||||
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Chat).filter(Chat.user_id == user_id)
|
||||
|
||||
if is_archived is not None:
|
||||
|
|
@ -1067,9 +1073,9 @@ class ChatTable:
|
|||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chats_by_folder_id_and_user_id(
|
||||
self, folder_id: str, user_id: str, skip: int = 0, limit: int = 60
|
||||
self, folder_id: str, user_id: str, skip: int = 0, limit: int = 60, db: Optional[Session] = None
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
|
||||
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
|
||||
query = query.filter_by(archived=False)
|
||||
|
|
@ -1085,9 +1091,9 @@ class ChatTable:
|
|||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chats_by_folder_ids_and_user_id(
|
||||
self, folder_ids: list[str], user_id: str
|
||||
self, folder_ids: list[str], user_id: str, db: Optional[Session] = None
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Chat).filter(
|
||||
Chat.folder_id.in_(folder_ids), Chat.user_id == user_id
|
||||
)
|
||||
|
|
@ -1100,10 +1106,10 @@ class ChatTable:
|
|||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def update_chat_folder_id_by_id_and_user_id(
|
||||
self, id: str, user_id: str, folder_id: str
|
||||
self, id: str, user_id: str, folder_id: str, db: Optional[Session] = None
|
||||
) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.folder_id = folder_id
|
||||
chat.updated_at = int(time.time())
|
||||
|
|
@ -1114,16 +1120,16 @@ class ChatTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[TagModel]:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.get(Chat, id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
|
||||
|
||||
def get_chat_list_by_user_id_and_tag_name(
|
||||
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
|
||||
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
||||
|
|
@ -1152,13 +1158,13 @@ class ChatTable:
|
|||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
self, id: str, user_id: str, tag_name: str
|
||||
self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None
|
||||
) -> Optional[ChatModel]:
|
||||
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
|
||||
if tag is None:
|
||||
tag = Tags.insert_new_tag(tag_name, user_id)
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.get(Chat, id)
|
||||
|
||||
tag_id = tag.id
|
||||
|
|
@ -1174,8 +1180,8 @@ class ChatTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
|
||||
with get_db() as db: # Assuming `get_db()` returns a session object
|
||||
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str, db: Optional[Session] = None) -> int:
|
||||
with get_db_context(db) as db: # Assuming `get_db()` returns a session object
|
||||
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
|
||||
|
||||
# Normalize the tag_name for consistency
|
||||
|
|
@ -1210,8 +1216,8 @@ class ChatTable:
|
|||
|
||||
return count
|
||||
|
||||
def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str) -> int:
|
||||
with get_db() as db:
|
||||
def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str, db: Optional[Session] = None) -> int:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
|
||||
query = query.filter_by(folder_id=folder_id)
|
||||
|
|
@ -1221,10 +1227,10 @@ class ChatTable:
|
|||
return count
|
||||
|
||||
def delete_tag_by_id_and_user_id_and_tag_name(
|
||||
self, id: str, user_id: str, tag_name: str
|
||||
self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.get(Chat, id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
|
@ -1239,9 +1245,9 @@ class ChatTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.meta = {
|
||||
**chat.meta,
|
||||
|
|
@ -1253,30 +1259,30 @@ class ChatTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_chat_by_id(self, id: str) -> bool:
|
||||
def delete_chat_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Chat).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
return True and self.delete_shared_chat_by_chat_id(id)
|
||||
return True and self.delete_shared_chat_by_chat_id(id, db=db)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
|
||||
db.commit()
|
||||
|
||||
return True and self.delete_shared_chat_by_chat_id(id)
|
||||
return True and self.delete_shared_chat_by_chat_id(id, db=db)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_chats_by_user_id(self, user_id: str) -> bool:
|
||||
def delete_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
self.delete_shared_chats_by_user_id(user_id)
|
||||
with get_db_context(db) as db:
|
||||
self.delete_shared_chats_by_user_id(user_id, db=db)
|
||||
|
||||
db.query(Chat).filter_by(user_id=user_id).delete()
|
||||
db.commit()
|
||||
|
|
@ -1286,10 +1292,10 @@ class ChatTable:
|
|||
return False
|
||||
|
||||
def delete_chats_by_user_id_and_folder_id(
|
||||
self, user_id: str, folder_id: str
|
||||
self, user_id: str, folder_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete()
|
||||
db.commit()
|
||||
|
||||
|
|
@ -1298,10 +1304,10 @@ class ChatTable:
|
|||
return False
|
||||
|
||||
def move_chats_by_user_id_and_folder_id(
|
||||
self, user_id: str, folder_id: str, new_folder_id: Optional[str]
|
||||
self, user_id: str, folder_id: str, new_folder_id: Optional[str], db: Optional[Session] = None
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update(
|
||||
{"folder_id": new_folder_id}
|
||||
)
|
||||
|
|
@ -1311,9 +1317,9 @@ class ChatTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
|
||||
def delete_shared_chats_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
|
||||
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
|
||||
|
||||
|
|
@ -1325,7 +1331,7 @@ class ChatTable:
|
|||
return False
|
||||
|
||||
def insert_chat_files(
|
||||
self, chat_id: str, message_id: str, file_ids: list[str], user_id: str
|
||||
self, chat_id: str, message_id: str, file_ids: list[str], user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[list[ChatFileModel]]:
|
||||
if not file_ids:
|
||||
return None
|
||||
|
|
@ -1333,7 +1339,7 @@ class ChatTable:
|
|||
chat_message_file_ids = [
|
||||
item.id
|
||||
for item in self.get_chat_files_by_chat_id_and_message_id(
|
||||
chat_id, message_id
|
||||
chat_id, message_id, db=db
|
||||
)
|
||||
]
|
||||
# Remove duplicates and existing file_ids
|
||||
|
|
@ -1350,7 +1356,7 @@ class ChatTable:
|
|||
return None
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
now = int(time.time())
|
||||
|
||||
chat_files = [
|
||||
|
|
@ -1378,9 +1384,9 @@ class ChatTable:
|
|||
return None
|
||||
|
||||
def get_chat_files_by_chat_id_and_message_id(
|
||||
self, chat_id: str, message_id: str
|
||||
self, chat_id: str, message_id: str, db: Optional[Session] = None
|
||||
) -> list[ChatFileModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
all_chat_files = (
|
||||
db.query(ChatFile)
|
||||
.filter_by(chat_id=chat_id, message_id=message_id)
|
||||
|
|
@ -1391,17 +1397,17 @@ class ChatTable:
|
|||
ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files
|
||||
]
|
||||
|
||||
def delete_chat_file(self, chat_id: str, file_id: str) -> bool:
|
||||
def delete_chat_file(self, chat_id: str, file_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(ChatFile).filter_by(chat_id=chat_id, file_id=file_id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_shared_chats_by_file_id(self, file_id: str) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
def get_shared_chats_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChatModel]:
|
||||
with get_db_context(db) as db:
|
||||
# Join Chat and ChatFile tables to get shared chats associated with the file_id
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ import time
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
from open_webui.models.users import User
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
|
@ -121,9 +122,9 @@ class FeedbackListResponse(BaseModel):
|
|||
|
||||
class FeedbackTable:
|
||||
def insert_new_feedback(
|
||||
self, user_id: str, form_data: FeedbackForm
|
||||
self, user_id: str, form_data: FeedbackForm, db: Optional[Session] = None
|
||||
) -> Optional[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
id = str(uuid.uuid4())
|
||||
feedback = FeedbackModel(
|
||||
**{
|
||||
|
|
@ -148,9 +149,9 @@ class FeedbackTable:
|
|||
log.exception(f"Error creating a new feedback: {e}")
|
||||
return None
|
||||
|
||||
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:
|
||||
def get_feedback_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FeedbackModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id).first()
|
||||
if not feedback:
|
||||
return None
|
||||
|
|
@ -159,10 +160,10 @@ class FeedbackTable:
|
|||
return None
|
||||
|
||||
def get_feedback_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[FeedbackModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
|
||||
if not feedback:
|
||||
return None
|
||||
|
|
@ -171,9 +172,9 @@ class FeedbackTable:
|
|||
return None
|
||||
|
||||
def get_feedback_items(
|
||||
self, filter: dict = {}, skip: int = 0, limit: int = 30
|
||||
self, filter: dict = {}, skip: int = 0, limit: int = 30, db: Optional[Session] = None
|
||||
) -> FeedbackListResponse:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
|
||||
|
||||
if filter:
|
||||
|
|
@ -234,8 +235,8 @@ class FeedbackTable:
|
|||
|
||||
return FeedbackListResponse(items=feedbacks, total=total)
|
||||
|
||||
def get_all_feedbacks(self) -> list[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
def get_all_feedbacks(self, db: Optional[Session] = None) -> list[FeedbackModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FeedbackModel.model_validate(feedback)
|
||||
for feedback in db.query(Feedback)
|
||||
|
|
@ -243,8 +244,8 @@ class FeedbackTable:
|
|||
.all()
|
||||
]
|
||||
|
||||
def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
def get_feedbacks_by_type(self, type: str, db: Optional[Session] = None) -> list[FeedbackModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FeedbackModel.model_validate(feedback)
|
||||
for feedback in db.query(Feedback)
|
||||
|
|
@ -253,8 +254,8 @@ class FeedbackTable:
|
|||
.all()
|
||||
]
|
||||
|
||||
def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
def get_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FeedbackModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FeedbackModel.model_validate(feedback)
|
||||
for feedback in db.query(Feedback)
|
||||
|
|
@ -264,9 +265,9 @@ class FeedbackTable:
|
|||
]
|
||||
|
||||
def update_feedback_by_id(
|
||||
self, id: str, form_data: FeedbackForm
|
||||
self, id: str, form_data: FeedbackForm, db: Optional[Session] = None
|
||||
) -> Optional[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id).first()
|
||||
if not feedback:
|
||||
return None
|
||||
|
|
@ -284,9 +285,9 @@ class FeedbackTable:
|
|||
return FeedbackModel.model_validate(feedback)
|
||||
|
||||
def update_feedback_by_id_and_user_id(
|
||||
self, id: str, user_id: str, form_data: FeedbackForm
|
||||
self, id: str, user_id: str, form_data: FeedbackForm, db: Optional[Session] = None
|
||||
) -> Optional[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
|
||||
if not feedback:
|
||||
return None
|
||||
|
|
@ -303,8 +304,8 @@ class FeedbackTable:
|
|||
db.commit()
|
||||
return FeedbackModel.model_validate(feedback)
|
||||
|
||||
def delete_feedback_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_feedback_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id).first()
|
||||
if not feedback:
|
||||
return False
|
||||
|
|
@ -312,8 +313,8 @@ class FeedbackTable:
|
|||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_feedback_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
|
||||
if not feedback:
|
||||
return False
|
||||
|
|
@ -321,8 +322,8 @@ class FeedbackTable:
|
|||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_feedbacks_by_user_id(self, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
feedbacks = db.query(Feedback).filter_by(user_id=user_id).all()
|
||||
if not feedbacks:
|
||||
return False
|
||||
|
|
@ -331,8 +332,8 @@ class FeedbackTable:
|
|||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_all_feedbacks(self) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_all_feedbacks(self, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
feedbacks = db.query(Feedback).all()
|
||||
if not feedbacks:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import logging
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
|
|
@ -108,8 +109,10 @@ class FileListResponse(BaseModel):
|
|||
|
||||
|
||||
class FilesTable:
|
||||
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
def insert_new_file(
|
||||
self, user_id: str, form_data: FileForm, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
file = FileModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
|
|
@ -132,13 +135,16 @@ class FilesTable:
|
|||
log.exception(f"Error inserting a new file: {e}")
|
||||
return None
|
||||
|
||||
def get_file_by_id(self, id: str) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
file = db.get(File, id)
|
||||
return FileModel.model_validate(file)
|
||||
except Exception:
|
||||
return None
|
||||
def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.get(File, id)
|
||||
return FileModel.model_validate(file)
|
||||
except Exception:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
|
|
@ -165,8 +171,8 @@ class FilesTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_files(self) -> list[FileModel]:
|
||||
with get_db() as db:
|
||||
def get_files(self, db: Optional[Session] = None) -> list[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
||||
|
||||
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
|
||||
|
|
@ -206,8 +212,8 @@ class FilesTable:
|
|||
.all()
|
||||
]
|
||||
|
||||
def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
|
||||
with get_db() as db:
|
||||
def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FileModel.model_validate(file)
|
||||
for file in db.query(File).filter_by(user_id=user_id).all()
|
||||
|
|
@ -271,24 +277,6 @@ class FilesTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_file_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(File).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_files(self) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(File).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,9 @@ import re
|
|||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -83,9 +84,9 @@ class FolderUpdateForm(BaseModel):
|
|||
|
||||
class FolderTable:
|
||||
def insert_new_folder(
|
||||
self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None
|
||||
self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None, db: Optional[Session] = None
|
||||
) -> Optional[FolderModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
id = str(uuid.uuid4())
|
||||
folder = FolderModel(
|
||||
**{
|
||||
|
|
@ -111,10 +112,10 @@ class FolderTable:
|
|||
return None
|
||||
|
||||
def get_folder_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
|
|
@ -125,15 +126,15 @@ class FolderTable:
|
|||
return None
|
||||
|
||||
def get_children_folders_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[list[FolderModel]]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
folders = []
|
||||
|
||||
def get_children(folder):
|
||||
children = self.get_folders_by_parent_id_and_user_id(
|
||||
folder.id, user_id
|
||||
folder.id, user_id, db=db
|
||||
)
|
||||
for child in children:
|
||||
get_children(child)
|
||||
|
|
@ -148,18 +149,18 @@ class FolderTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]:
|
||||
with get_db() as db:
|
||||
def get_folders_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FolderModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FolderModel.model_validate(folder)
|
||||
for folder in db.query(Folder).filter_by(user_id=user_id).all()
|
||||
]
|
||||
|
||||
def get_folder_by_parent_id_and_user_id_and_name(
|
||||
self, parent_id: Optional[str], user_id: str, name: str
|
||||
self, parent_id: Optional[str], user_id: str, name: str, db: Optional[Session] = None
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# Check if folder exists
|
||||
folder = (
|
||||
db.query(Folder)
|
||||
|
|
@ -177,9 +178,9 @@ class FolderTable:
|
|||
return None
|
||||
|
||||
def get_folders_by_parent_id_and_user_id(
|
||||
self, parent_id: Optional[str], user_id: str
|
||||
self, parent_id: Optional[str], user_id: str, db: Optional[Session] = None
|
||||
) -> list[FolderModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FolderModel.model_validate(folder)
|
||||
for folder in db.query(Folder)
|
||||
|
|
@ -192,9 +193,10 @@ class FolderTable:
|
|||
id: str,
|
||||
user_id: str,
|
||||
parent_id: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
|
|
@ -211,10 +213,10 @@ class FolderTable:
|
|||
return
|
||||
|
||||
def update_folder_by_id_and_user_id(
|
||||
self, id: str, user_id: str, form_data: FolderUpdateForm
|
||||
self, id: str, user_id: str, form_data: FolderUpdateForm, db: Optional[Session] = None
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
|
|
@ -257,10 +259,10 @@ class FolderTable:
|
|||
return
|
||||
|
||||
def update_folder_is_expanded_by_id_and_user_id(
|
||||
self, id: str, user_id: str, is_expanded: bool
|
||||
self, id: str, user_id: str, is_expanded: bool, db: Optional[Session] = None
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
|
|
@ -276,10 +278,10 @@ class FolderTable:
|
|||
log.error(f"update_folder: {e}")
|
||||
return
|
||||
|
||||
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]:
|
||||
def delete_folder_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[str]:
|
||||
try:
|
||||
folder_ids = []
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
if not folder:
|
||||
return folder_ids
|
||||
|
|
@ -289,7 +291,7 @@ class FolderTable:
|
|||
# Delete all children folders
|
||||
def delete_children(folder):
|
||||
folder_children = self.get_folders_by_parent_id_and_user_id(
|
||||
folder.id, user_id
|
||||
folder.id, user_id, db=db
|
||||
)
|
||||
for folder_child in folder_children:
|
||||
|
||||
|
|
@ -314,7 +316,7 @@ class FolderTable:
|
|||
return name.strip().lower()
|
||||
|
||||
def search_folders_by_names(
|
||||
self, user_id: str, queries: list[str]
|
||||
self, user_id: str, queries: list[str], db: Optional[Session] = None
|
||||
) -> list[FolderModel]:
|
||||
"""
|
||||
Search for folders for a user where the name matches any of the queries, treating _ and space as equivalent, case-insensitive.
|
||||
|
|
@ -324,7 +326,7 @@ class FolderTable:
|
|||
return []
|
||||
|
||||
results = {}
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
folders = db.query(Folder).filter_by(user_id=user_id).all()
|
||||
for folder in folders:
|
||||
if self.normalize_folder_name(folder.name) in normalized_queries:
|
||||
|
|
@ -332,7 +334,7 @@ class FolderTable:
|
|||
|
||||
# get children folders
|
||||
children = self.get_children_folders_by_id_and_user_id(
|
||||
folder.id, user_id
|
||||
folder.id, user_id, db=db
|
||||
)
|
||||
for child in children:
|
||||
results[child.id] = child
|
||||
|
|
@ -345,14 +347,14 @@ class FolderTable:
|
|||
return results
|
||||
|
||||
def search_folders_by_name_contains(
|
||||
self, user_id: str, query: str
|
||||
self, user_id: str, query: str, db: Optional[Session] = None
|
||||
) -> list[FolderModel]:
|
||||
"""
|
||||
Partial match: normalized name contains (as substring) the normalized query.
|
||||
"""
|
||||
normalized_query = self.normalize_folder_name(query)
|
||||
results = []
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
folders = db.query(Folder).filter_by(user_id=user_id).all()
|
||||
for folder in folders:
|
||||
norm_name = self.normalize_folder_name(folder.name)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import logging
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
from open_webui.models.users import Users, UserModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index
|
||||
|
|
@ -103,7 +104,7 @@ class FunctionValves(BaseModel):
|
|||
|
||||
class FunctionsTable:
|
||||
def insert_new_function(
|
||||
self, user_id: str, type: str, form_data: FunctionForm
|
||||
self, user_id: str, type: str, form_data: FunctionForm, db: Optional[Session] = None
|
||||
) -> Optional[FunctionModel]:
|
||||
function = FunctionModel(
|
||||
**{
|
||||
|
|
@ -116,7 +117,7 @@ class FunctionsTable:
|
|||
)
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
result = Function(**function.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
|
|
@ -130,11 +131,11 @@ class FunctionsTable:
|
|||
return None
|
||||
|
||||
def sync_functions(
|
||||
self, user_id: str, functions: list[FunctionWithValvesModel]
|
||||
self, user_id: str, functions: list[FunctionWithValvesModel], db: Optional[Session] = None
|
||||
) -> list[FunctionWithValvesModel]:
|
||||
# Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# Get existing functions
|
||||
existing_functions = db.query(Function).all()
|
||||
existing_ids = {func.id for func in existing_functions}
|
||||
|
|
@ -177,18 +178,18 @@ class FunctionsTable:
|
|||
log.exception(f"Error syncing functions for user {user_id}: {e}")
|
||||
return []
|
||||
|
||||
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
|
||||
def get_function_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FunctionModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
function = db.get(Function, id)
|
||||
return FunctionModel.model_validate(function)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_functions(
|
||||
self, active_only=False, include_valves=False
|
||||
self, active_only=False, include_valves=False, db: Optional[Session] = None
|
||||
) -> list[FunctionModel | FunctionWithValvesModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
if active_only:
|
||||
functions = db.query(Function).filter_by(is_active=True).all()
|
||||
|
||||
|
|
@ -205,12 +206,12 @@ class FunctionsTable:
|
|||
FunctionModel.model_validate(function) for function in functions
|
||||
]
|
||||
|
||||
def get_function_list(self) -> list[FunctionUserResponse]:
|
||||
with get_db() as db:
|
||||
def get_function_list(self, db: Optional[Session] = None) -> list[FunctionUserResponse]:
|
||||
with get_db_context(db) as db:
|
||||
functions = db.query(Function).order_by(Function.updated_at.desc()).all()
|
||||
user_ids = list(set(func.user_id for func in functions))
|
||||
|
||||
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
|
||||
return [
|
||||
|
|
@ -228,9 +229,9 @@ class FunctionsTable:
|
|||
]
|
||||
|
||||
def get_functions_by_type(
|
||||
self, type: str, active_only=False
|
||||
self, type: str, active_only=False, db: Optional[Session] = None
|
||||
) -> list[FunctionModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
if active_only:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
|
|
@ -244,8 +245,8 @@ class FunctionsTable:
|
|||
for function in db.query(Function).filter_by(type=type).all()
|
||||
]
|
||||
|
||||
def get_global_filter_functions(self) -> list[FunctionModel]:
|
||||
with get_db() as db:
|
||||
def get_global_filter_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function)
|
||||
|
|
@ -253,8 +254,8 @@ class FunctionsTable:
|
|||
.all()
|
||||
]
|
||||
|
||||
def get_global_action_functions(self) -> list[FunctionModel]:
|
||||
with get_db() as db:
|
||||
def get_global_action_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function)
|
||||
|
|
@ -262,8 +263,8 @@ class FunctionsTable:
|
|||
.all()
|
||||
]
|
||||
|
||||
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
with get_db() as db:
|
||||
def get_function_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
function = db.get(Function, id)
|
||||
return function.valves if function.valves else {}
|
||||
|
|
@ -272,23 +273,23 @@ class FunctionsTable:
|
|||
return None
|
||||
|
||||
def update_function_valves_by_id(
|
||||
self, id: str, valves: dict
|
||||
self, id: str, valves: dict, db: Optional[Session] = None
|
||||
) -> Optional[FunctionValves]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
function = db.get(Function, id)
|
||||
function.valves = valves
|
||||
function.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(function)
|
||||
return self.get_function_by_id(id)
|
||||
return self.get_function_by_id(id, db=db)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_function_metadata_by_id(
|
||||
self, id: str, metadata: dict
|
||||
self, id: str, metadata: dict, db: Optional[Session] = None
|
||||
) -> Optional[FunctionModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
function = db.get(Function, id)
|
||||
|
||||
|
|
@ -301,7 +302,7 @@ class FunctionsTable:
|
|||
function.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(function)
|
||||
return self.get_function_by_id(id)
|
||||
return self.get_function_by_id(id, db=db)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
|
|
@ -309,10 +310,10 @@ class FunctionsTable:
|
|||
return None
|
||||
|
||||
def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "functions" and "valves" settings
|
||||
|
|
@ -327,10 +328,10 @@ class FunctionsTable:
|
|||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, valves: dict
|
||||
self, id: str, user_id: str, valves: dict, db: Optional[Session] = None
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "functions" and "valves" settings
|
||||
|
|
@ -342,7 +343,7 @@ class FunctionsTable:
|
|||
user_settings["functions"]["valves"][id] = valves
|
||||
|
||||
# Update the user settings in the database
|
||||
Users.update_user_by_id(user_id, {"settings": user_settings})
|
||||
Users.update_user_by_id(user_id, {"settings": user_settings}, db=db)
|
||||
|
||||
return user_settings["functions"]["valves"][id]
|
||||
except Exception as e:
|
||||
|
|
@ -351,8 +352,8 @@ class FunctionsTable:
|
|||
)
|
||||
return None
|
||||
|
||||
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
|
||||
with get_db() as db:
|
||||
def update_function_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[FunctionModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
db.query(Function).filter_by(id=id).update(
|
||||
{
|
||||
|
|
@ -361,12 +362,12 @@ class FunctionsTable:
|
|||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_function_by_id(id)
|
||||
return self.get_function_by_id(id, db=db)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def deactivate_all_functions(self) -> Optional[bool]:
|
||||
with get_db() as db:
|
||||
def deactivate_all_functions(self, db: Optional[Session] = None) -> Optional[bool]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
db.query(Function).update(
|
||||
{
|
||||
|
|
@ -379,8 +380,8 @@ class FunctionsTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_function_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_function_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
db.query(Function).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ import time
|
|||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
|
||||
from open_webui.models.files import FileMetadataResponse
|
||||
|
||||
|
|
@ -120,9 +121,9 @@ class GroupListResponse(BaseModel):
|
|||
|
||||
class GroupTable:
|
||||
def insert_new_group(
|
||||
self, user_id: str, form_data: GroupForm
|
||||
self, user_id: str, form_data: GroupForm, db: Optional[Session] = None
|
||||
) -> Optional[GroupModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
group = GroupModel(
|
||||
**{
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
|
|
@ -146,13 +147,13 @@ class GroupTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_all_groups(self) -> list[GroupModel]:
|
||||
with get_db() as db:
|
||||
def get_all_groups(self, db: Optional[Session] = None) -> list[GroupModel]:
|
||||
with get_db_context(db) as db:
|
||||
groups = db.query(Group).order_by(Group.updated_at.desc()).all()
|
||||
return [GroupModel.model_validate(group) for group in groups]
|
||||
|
||||
def get_groups(self, filter) -> list[GroupResponse]:
|
||||
with get_db() as db:
|
||||
def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse]:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Group)
|
||||
|
||||
if filter:
|
||||
|
|
@ -184,16 +185,16 @@ class GroupTable:
|
|||
GroupResponse.model_validate(
|
||||
{
|
||||
**GroupModel.model_validate(group).model_dump(),
|
||||
"member_count": self.get_group_member_count_by_id(group.id),
|
||||
"member_count": self.get_group_member_count_by_id(group.id, db=db),
|
||||
}
|
||||
)
|
||||
for group in groups
|
||||
]
|
||||
|
||||
def search_groups(
|
||||
self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30
|
||||
self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30, db: Optional[Session] = None
|
||||
) -> GroupListResponse:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Group)
|
||||
|
||||
if filter:
|
||||
|
|
@ -220,15 +221,15 @@ class GroupTable:
|
|||
"items": [
|
||||
GroupResponse.model_validate(
|
||||
**GroupModel.model_validate(group).model_dump(),
|
||||
member_count=self.get_group_member_count_by_id(group.id),
|
||||
member_count=self.get_group_member_count_by_id(group.id, db=db),
|
||||
)
|
||||
for group in groups
|
||||
],
|
||||
"total": total,
|
||||
}
|
||||
|
||||
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
|
||||
with get_db() as db:
|
||||
def get_groups_by_member_id(self, user_id: str, db: Optional[Session] = None) -> list[GroupModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
GroupModel.model_validate(group)
|
||||
for group in db.query(Group)
|
||||
|
|
@ -238,16 +239,16 @@ class GroupTable:
|
|||
.all()
|
||||
]
|
||||
|
||||
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
|
||||
def get_group_by_id(self, id: str, db: Optional[Session] = None) -> Optional[GroupModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
group = db.query(Group).filter_by(id=id).first()
|
||||
return GroupModel.model_validate(group) if group else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
|
||||
with get_db() as db:
|
||||
def get_group_user_ids_by_id(self, id: str, db: Optional[Session] = None) -> Optional[list[str]]:
|
||||
with get_db_context(db) as db:
|
||||
members = (
|
||||
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
|
||||
)
|
||||
|
|
@ -257,8 +258,8 @@ class GroupTable:
|
|||
|
||||
return [m[0] for m in members]
|
||||
|
||||
def get_group_user_ids_by_ids(self, group_ids: list[str]) -> dict[str, list[str]]:
|
||||
with get_db() as db:
|
||||
def get_group_user_ids_by_ids(self, group_ids: list[str], db: Optional[Session] = None) -> dict[str, list[str]]:
|
||||
with get_db_context(db) as db:
|
||||
members = (
|
||||
db.query(GroupMember.group_id, GroupMember.user_id)
|
||||
.filter(GroupMember.group_id.in_(group_ids))
|
||||
|
|
@ -274,8 +275,8 @@ class GroupTable:
|
|||
|
||||
return group_user_ids
|
||||
|
||||
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
|
||||
with get_db() as db:
|
||||
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str], db: Optional[Session] = None) -> None:
|
||||
with get_db_context(db) as db:
|
||||
# Delete existing members
|
||||
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
|
||||
|
||||
|
|
@ -295,8 +296,8 @@ class GroupTable:
|
|||
db.add_all(new_members)
|
||||
db.commit()
|
||||
|
||||
def get_group_member_count_by_id(self, id: str) -> int:
|
||||
with get_db() as db:
|
||||
def get_group_member_count_by_id(self, id: str, db: Optional[Session] = None) -> int:
|
||||
with get_db_context(db) as db:
|
||||
count = (
|
||||
db.query(func.count(GroupMember.user_id))
|
||||
.filter(GroupMember.group_id == id)
|
||||
|
|
@ -305,10 +306,10 @@ class GroupTable:
|
|||
return count if count else 0
|
||||
|
||||
def update_group_by_id(
|
||||
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
|
||||
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False, db: Optional[Session] = None
|
||||
) -> Optional[GroupModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Group).filter_by(id=id).update(
|
||||
{
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
|
|
@ -316,22 +317,22 @@ class GroupTable:
|
|||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_group_by_id(id=id)
|
||||
return self.get_group_by_id(id=id, db=db)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def delete_group_by_id(self, id: str) -> bool:
|
||||
def delete_group_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Group).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_groups(self) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_all_groups(self, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
db.query(Group).delete()
|
||||
db.commit()
|
||||
|
|
@ -340,8 +341,8 @@ class GroupTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def remove_user_from_all_groups(self, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def remove_user_from_all_groups(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
# Find all groups the user belongs to
|
||||
groups = (
|
||||
|
|
@ -369,16 +370,16 @@ class GroupTable:
|
|||
return False
|
||||
|
||||
def create_groups_by_group_names(
|
||||
self, user_id: str, group_names: list[str]
|
||||
self, user_id: str, group_names: list[str], db: Optional[Session] = None
|
||||
) -> list[GroupModel]:
|
||||
|
||||
# check for existing groups
|
||||
existing_groups = self.get_all_groups()
|
||||
existing_groups = self.get_all_groups(db=db)
|
||||
existing_group_names = {group.name for group in existing_groups}
|
||||
|
||||
new_groups = []
|
||||
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
for group_name in group_names:
|
||||
if group_name not in existing_group_names:
|
||||
new_group = GroupModel(
|
||||
|
|
@ -400,8 +401,8 @@ class GroupTable:
|
|||
continue
|
||||
return new_groups
|
||||
|
||||
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
|
||||
with get_db() as db:
|
||||
def sync_groups_by_group_names(self, user_id: str, group_names: list[str], db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
now = int(time.time())
|
||||
|
||||
|
|
@ -461,10 +462,10 @@ class GroupTable:
|
|||
return False
|
||||
|
||||
def add_users_to_group(
|
||||
self, id: str, user_ids: Optional[list[str]] = None
|
||||
self, id: str, user_ids: Optional[list[str]] = None, db: Optional[Session] = None
|
||||
) -> Optional[GroupModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
group = db.query(Group).filter_by(id=id).first()
|
||||
if not group:
|
||||
return None
|
||||
|
|
@ -499,10 +500,10 @@ class GroupTable:
|
|||
return None
|
||||
|
||||
def remove_users_from_group(
|
||||
self, id: str, user_ids: Optional[list[str]] = None
|
||||
self, id: str, user_ids: Optional[list[str]] = None, db: Optional[Session] = None
|
||||
) -> Optional[GroupModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
group = db.query(Group).filter_by(id=id).first()
|
||||
if not group:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
|
||||
from open_webui.models.files import (
|
||||
File,
|
||||
|
|
@ -157,9 +159,9 @@ class KnowledgeFileListResponse(BaseModel):
|
|||
|
||||
class KnowledgeTable:
|
||||
def insert_new_knowledge(
|
||||
self, user_id: str, form_data: KnowledgeForm
|
||||
self, user_id: str, form_data: KnowledgeForm, db: Optional[Session] = None
|
||||
) -> Optional[KnowledgeModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
knowledge = KnowledgeModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
|
|
@ -183,15 +185,15 @@ class KnowledgeTable:
|
|||
return None
|
||||
|
||||
def get_knowledge_bases(
|
||||
self, skip: int = 0, limit: int = 30
|
||||
self, skip: int = 0, limit: int = 30, db: Optional[Session] = None
|
||||
) -> list[KnowledgeUserModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
all_knowledge = (
|
||||
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
|
||||
)
|
||||
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
|
||||
|
||||
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
|
||||
knowledge_bases = []
|
||||
|
|
@ -208,10 +210,10 @@ class KnowledgeTable:
|
|||
return knowledge_bases
|
||||
|
||||
def search_knowledge_bases(
|
||||
self, user_id: str, filter: dict, skip: int = 0, limit: int = 30
|
||||
self, user_id: str, filter: dict, skip: int = 0, limit: int = 30, db: Optional[Session] = None
|
||||
) -> KnowledgeListResponse:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Knowledge, User).outerjoin(
|
||||
User, User.id == Knowledge.user_id
|
||||
)
|
||||
|
|
@ -267,14 +269,14 @@ class KnowledgeTable:
|
|||
return KnowledgeListResponse(items=[], total=0)
|
||||
|
||||
def search_knowledge_files(
|
||||
self, filter: dict, skip: int = 0, limit: int = 30
|
||||
self, filter: dict, skip: int = 0, limit: int = 30, db: Optional[Session] = None
|
||||
) -> KnowledgeFileListResponse:
|
||||
"""
|
||||
Scalable version: search files across all knowledge bases the user has
|
||||
READ access to, without loading all KBs or using large IN() lists.
|
||||
"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# Base query: join Knowledge → KnowledgeFile → File
|
||||
query = (
|
||||
db.query(File, User)
|
||||
|
|
@ -327,20 +329,20 @@ class KnowledgeTable:
|
|||
print("search_knowledge_files error:", e)
|
||||
return KnowledgeFileListResponse(items=[], total=0)
|
||||
|
||||
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
|
||||
knowledge = self.get_knowledge_by_id(id)
|
||||
def check_access_by_user_id(self, id, user_id, permission="write", db: Optional[Session] = None) -> bool:
|
||||
knowledge = self.get_knowledge_by_id(id, db=db)
|
||||
if not knowledge:
|
||||
return False
|
||||
if knowledge.user_id == user_id:
|
||||
return True
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
return has_access(user_id, permission, knowledge.access_control, user_group_ids)
|
||||
|
||||
def get_knowledge_bases_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
||||
) -> list[KnowledgeUserModel]:
|
||||
knowledge_bases = self.get_knowledge_bases()
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
||||
knowledge_bases = self.get_knowledge_bases(db=db)
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
return [
|
||||
knowledge_base
|
||||
for knowledge_base in knowledge_bases
|
||||
|
|
@ -350,32 +352,32 @@ class KnowledgeTable:
|
|||
)
|
||||
]
|
||||
|
||||
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
|
||||
def get_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
knowledge = db.query(Knowledge).filter_by(id=id).first()
|
||||
return KnowledgeModel.model_validate(knowledge) if knowledge else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_knowledge_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[KnowledgeModel]:
|
||||
knowledge = self.get_knowledge_by_id(id)
|
||||
knowledge = self.get_knowledge_by_id(id, db=db)
|
||||
if not knowledge:
|
||||
return None
|
||||
|
||||
if knowledge.user_id == user_id:
|
||||
return knowledge
|
||||
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
if has_access(user_id, "write", knowledge.access_control, user_group_ids):
|
||||
return knowledge
|
||||
return None
|
||||
|
||||
def get_knowledges_by_file_id(self, file_id: str) -> list[KnowledgeModel]:
|
||||
def get_knowledges_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[KnowledgeModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
knowledges = (
|
||||
db.query(Knowledge)
|
||||
.join(KnowledgeFile, Knowledge.id == KnowledgeFile.knowledge_id)
|
||||
|
|
@ -395,9 +397,10 @@ class KnowledgeTable:
|
|||
filter: dict,
|
||||
skip: int = 0,
|
||||
limit: int = 30,
|
||||
db: Optional[Session] = None,
|
||||
) -> KnowledgeFileListResponse:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = (
|
||||
db.query(File, User)
|
||||
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
|
||||
|
|
@ -470,9 +473,9 @@ class KnowledgeTable:
|
|||
print(e)
|
||||
return KnowledgeFileListResponse(items=[], total=0)
|
||||
|
||||
def get_files_by_id(self, knowledge_id: str) -> list[FileModel]:
|
||||
def get_files_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
files = (
|
||||
db.query(File)
|
||||
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
|
||||
|
|
@ -483,18 +486,18 @@ class KnowledgeTable:
|
|||
except Exception:
|
||||
return []
|
||||
|
||||
def get_file_metadatas_by_id(self, knowledge_id: str) -> list[FileMetadataResponse]:
|
||||
def get_file_metadatas_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileMetadataResponse]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
files = self.get_files_by_id(knowledge_id)
|
||||
with get_db_context(db) as db:
|
||||
files = self.get_files_by_id(knowledge_id, db=db)
|
||||
return [FileMetadataResponse(**file.model_dump()) for file in files]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def add_file_to_knowledge_by_id(
|
||||
self, knowledge_id: str, file_id: str, user_id: str
|
||||
self, knowledge_id: str, file_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[KnowledgeFileModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
knowledge_file = KnowledgeFileModel(
|
||||
**{
|
||||
"id": str(uuid.uuid4()),
|
||||
|
|
@ -518,9 +521,9 @@ class KnowledgeTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str) -> bool:
|
||||
def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(KnowledgeFile).filter_by(
|
||||
knowledge_id=knowledge_id, file_id=file_id
|
||||
).delete()
|
||||
|
|
@ -529,9 +532,9 @@ class KnowledgeTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def reset_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
|
||||
def reset_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# Delete all knowledge_file entries for this knowledge_id
|
||||
db.query(KnowledgeFile).filter_by(knowledge_id=id).delete()
|
||||
db.commit()
|
||||
|
|
@ -544,17 +547,17 @@ class KnowledgeTable:
|
|||
)
|
||||
db.commit()
|
||||
|
||||
return self.get_knowledge_by_id(id=id)
|
||||
return self.get_knowledge_by_id(id=id, db=db)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def update_knowledge_by_id(
|
||||
self, id: str, form_data: KnowledgeForm, overwrite: bool = False
|
||||
self, id: str, form_data: KnowledgeForm, overwrite: bool = False, db: Optional[Session] = None
|
||||
) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
knowledge = self.get_knowledge_by_id(id=id)
|
||||
with get_db_context(db) as db:
|
||||
knowledge = self.get_knowledge_by_id(id=id, db=db)
|
||||
db.query(Knowledge).filter_by(id=id).update(
|
||||
{
|
||||
**form_data.model_dump(),
|
||||
|
|
@ -562,17 +565,17 @@ class KnowledgeTable:
|
|||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_knowledge_by_id(id=id)
|
||||
return self.get_knowledge_by_id(id=id, db=db)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def update_knowledge_data_by_id(
|
||||
self, id: str, data: dict
|
||||
self, id: str, data: dict, db: Optional[Session] = None
|
||||
) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
knowledge = self.get_knowledge_by_id(id=id)
|
||||
with get_db_context(db) as db:
|
||||
knowledge = self.get_knowledge_by_id(id=id, db=db)
|
||||
db.query(Knowledge).filter_by(id=id).update(
|
||||
{
|
||||
"data": data,
|
||||
|
|
@ -580,22 +583,22 @@ class KnowledgeTable:
|
|||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_knowledge_by_id(id=id)
|
||||
return self.get_knowledge_by_id(id=id, db=db)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def delete_knowledge_by_id(self, id: str) -> bool:
|
||||
def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Knowledge).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_knowledge(self) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_all_knowledge(self, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
db.query(Knowledge).delete()
|
||||
db.commit()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import time
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, get_db, get_db_context
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text
|
||||
|
||||
|
|
@ -41,8 +42,9 @@ class MemoriesTable:
|
|||
self,
|
||||
user_id: str,
|
||||
content: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[MemoryModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
memory = MemoryModel(
|
||||
|
|
@ -68,8 +70,9 @@ class MemoriesTable:
|
|||
id: str,
|
||||
user_id: str,
|
||||
content: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[MemoryModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
memory = db.get(Memory, id)
|
||||
if not memory or memory.user_id != user_id:
|
||||
|
|
@ -83,32 +86,32 @@ class MemoriesTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_memories(self) -> list[MemoryModel]:
|
||||
with get_db() as db:
|
||||
def get_memories(self, db: Optional[Session] = None) -> list[MemoryModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
memories = db.query(Memory).all()
|
||||
return [MemoryModel.model_validate(memory) for memory in memories]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
|
||||
with get_db() as db:
|
||||
def get_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[MemoryModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
memories = db.query(Memory).filter_by(user_id=user_id).all()
|
||||
return [MemoryModel.model_validate(memory) for memory in memories]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
|
||||
with get_db() as db:
|
||||
def get_memory_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MemoryModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
memory = db.get(Memory, id)
|
||||
return MemoryModel.model_validate(memory)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_memory_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_memory_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
db.query(Memory).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
|
@ -118,8 +121,8 @@ class MemoriesTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_memories_by_user_id(self, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
db.query(Memory).filter_by(user_id=user_id).delete()
|
||||
db.commit()
|
||||
|
|
@ -128,8 +131,8 @@ class MemoriesTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
memory = db.get(Memory, id)
|
||||
if not memory or memory.user_id != user_id:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ import time
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
from open_webui.models.users import Users, User, UserNameResponse
|
||||
from open_webui.models.channels import Channels, ChannelMember
|
||||
|
|
@ -137,9 +138,9 @@ class MessageResponse(MessageReplyToResponse):
|
|||
|
||||
class MessageTable:
|
||||
def insert_new_message(
|
||||
self, form_data: MessageForm, channel_id: str, user_id: str
|
||||
self, form_data: MessageForm, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[MessageModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
channel_member = Channels.join_channel(channel_id, user_id)
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
|
@ -169,22 +170,22 @@ class MessageTable:
|
|||
db.refresh(result)
|
||||
return MessageModel.model_validate(result) if result else None
|
||||
|
||||
def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
|
||||
with get_db() as db:
|
||||
def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MessageResponse]:
|
||||
with get_db_context(db) as db:
|
||||
message = db.get(Message, id)
|
||||
if not message:
|
||||
return None
|
||||
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(message.reply_to_id)
|
||||
self.get_message_by_id(message.reply_to_id, db=db)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
|
||||
reactions = self.get_reactions_by_message_id(id)
|
||||
thread_replies = self.get_thread_replies_by_message_id(id)
|
||||
reactions = self.get_reactions_by_message_id(id, db=db)
|
||||
thread_replies = self.get_thread_replies_by_message_id(id, db=db)
|
||||
|
||||
user = Users.get_user_by_id(message.user_id)
|
||||
user = Users.get_user_by_id(message.user_id, db=db)
|
||||
return MessageResponse.model_validate(
|
||||
{
|
||||
**MessageModel.model_validate(message).model_dump(),
|
||||
|
|
@ -200,8 +201,8 @@ class MessageTable:
|
|||
}
|
||||
)
|
||||
|
||||
def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]:
|
||||
with get_db() as db:
|
||||
def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]:
|
||||
with get_db_context(db) as db:
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(parent_id=id)
|
||||
|
|
@ -212,7 +213,7 @@ class MessageTable:
|
|||
messages = []
|
||||
for message in all_messages:
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(message.reply_to_id)
|
||||
self.get_message_by_id(message.reply_to_id, db=db)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
|
|
@ -230,17 +231,17 @@ class MessageTable:
|
|||
)
|
||||
return messages
|
||||
|
||||
def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
|
||||
with get_db() as db:
|
||||
def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
message.user_id
|
||||
for message in db.query(Message).filter_by(parent_id=id).all()
|
||||
]
|
||||
|
||||
def get_messages_by_channel_id(
|
||||
self, channel_id: str, skip: int = 0, limit: int = 50
|
||||
self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
||||
) -> list[MessageReplyToResponse]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(channel_id=channel_id, parent_id=None)
|
||||
|
|
@ -253,7 +254,7 @@ class MessageTable:
|
|||
messages = []
|
||||
for message in all_messages:
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(message.reply_to_id)
|
||||
self.get_message_by_id(message.reply_to_id, db=db)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
|
|
@ -272,9 +273,9 @@ class MessageTable:
|
|||
return messages
|
||||
|
||||
def get_messages_by_parent_id(
|
||||
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
|
||||
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
||||
) -> list[MessageReplyToResponse]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
message = db.get(Message, parent_id)
|
||||
|
||||
if not message:
|
||||
|
|
@ -296,7 +297,7 @@ class MessageTable:
|
|||
messages = []
|
||||
for message in all_messages:
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(message.reply_to_id)
|
||||
self.get_message_by_id(message.reply_to_id, db=db)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
|
|
@ -314,8 +315,8 @@ class MessageTable:
|
|||
)
|
||||
return messages
|
||||
|
||||
def get_last_message_by_channel_id(self, channel_id: str) -> Optional[MessageModel]:
|
||||
with get_db() as db:
|
||||
def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]:
|
||||
with get_db_context(db) as db:
|
||||
message = (
|
||||
db.query(Message)
|
||||
.filter_by(channel_id=channel_id)
|
||||
|
|
@ -325,9 +326,9 @@ class MessageTable:
|
|||
return MessageModel.model_validate(message) if message else None
|
||||
|
||||
def get_pinned_messages_by_channel_id(
|
||||
self, channel_id: str, skip: int = 0, limit: int = 50
|
||||
self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
||||
) -> list[MessageModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(channel_id=channel_id, is_pinned=True)
|
||||
|
|
@ -339,9 +340,9 @@ class MessageTable:
|
|||
return [MessageModel.model_validate(message) for message in all_messages]
|
||||
|
||||
def update_message_by_id(
|
||||
self, id: str, form_data: MessageForm
|
||||
self, id: str, form_data: MessageForm, db: Optional[Session] = None
|
||||
) -> Optional[MessageModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
message = db.get(Message, id)
|
||||
message.content = form_data.content
|
||||
message.data = {
|
||||
|
|
@ -358,9 +359,9 @@ class MessageTable:
|
|||
return MessageModel.model_validate(message) if message else None
|
||||
|
||||
def update_is_pinned_by_id(
|
||||
self, id: str, is_pinned: bool, pinned_by: Optional[str] = None
|
||||
self, id: str, is_pinned: bool, pinned_by: Optional[str] = None, db: Optional[Session] = None
|
||||
) -> Optional[MessageModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
message = db.get(Message, id)
|
||||
message.is_pinned = is_pinned
|
||||
message.pinned_at = int(time.time_ns()) if is_pinned else None
|
||||
|
|
@ -370,9 +371,9 @@ class MessageTable:
|
|||
return MessageModel.model_validate(message) if message else None
|
||||
|
||||
def get_unread_message_count(
|
||||
self, channel_id: str, user_id: str, last_read_at: Optional[int] = None
|
||||
self, channel_id: str, user_id: str, last_read_at: Optional[int] = None, db: Optional[Session] = None
|
||||
) -> int:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Message).filter(
|
||||
Message.channel_id == channel_id,
|
||||
Message.parent_id == None, # only count top-level messages
|
||||
|
|
@ -383,9 +384,9 @@ class MessageTable:
|
|||
return query.count()
|
||||
|
||||
def add_reaction_to_message(
|
||||
self, id: str, user_id: str, name: str
|
||||
self, id: str, user_id: str, name: str, db: Optional[Session] = None
|
||||
) -> Optional[MessageReactionModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# check for existing reaction
|
||||
existing_reaction = (
|
||||
db.query(MessageReaction)
|
||||
|
|
@ -409,8 +410,8 @@ class MessageTable:
|
|||
db.refresh(result)
|
||||
return MessageReactionModel.model_validate(result) if result else None
|
||||
|
||||
def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
|
||||
with get_db() as db:
|
||||
def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]:
|
||||
with get_db_context(db) as db:
|
||||
# JOIN User so all user info is fetched in one query
|
||||
results = (
|
||||
db.query(MessageReaction, User)
|
||||
|
|
@ -440,29 +441,29 @@ class MessageTable:
|
|||
return [Reactions(**reaction) for reaction in reactions.values()]
|
||||
|
||||
def remove_reaction_by_id_and_user_id_and_name(
|
||||
self, id: str, user_id: str, name: str
|
||||
self, id: str, user_id: str, name: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(MessageReaction).filter_by(
|
||||
message_id=id, user_id=user_id, name=name
|
||||
).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_reactions_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_reactions_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
db.query(MessageReaction).filter_by(message_id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_replies_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_replies_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Message).filter_by(parent_id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_message_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def delete_message_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Message).filter_by(id=id).delete()
|
||||
|
||||
# Delete all reactions to this message
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import logging
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import User, UserModel, Users, UserResponse
|
||||
|
|
@ -150,7 +151,7 @@ class ModelForm(BaseModel):
|
|||
|
||||
class ModelsTable:
|
||||
def insert_new_model(
|
||||
self, form_data: ModelForm, user_id: str
|
||||
self, form_data: ModelForm, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[ModelModel]:
|
||||
model = ModelModel(
|
||||
**{
|
||||
|
|
@ -161,7 +162,7 @@ class ModelsTable:
|
|||
}
|
||||
)
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
result = Model(**model.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
|
|
@ -175,17 +176,17 @@ class ModelsTable:
|
|||
log.exception(f"Failed to insert a new model: {e}")
|
||||
return None
|
||||
|
||||
def get_all_models(self) -> list[ModelModel]:
|
||||
with get_db() as db:
|
||||
def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
|
||||
|
||||
def get_models(self) -> list[ModelUserResponse]:
|
||||
with get_db() as db:
|
||||
def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]:
|
||||
with get_db_context(db) as db:
|
||||
all_models = db.query(Model).filter(Model.base_model_id != None).all()
|
||||
|
||||
user_ids = list(set(model.user_id for model in all_models))
|
||||
|
||||
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
|
||||
models = []
|
||||
|
|
@ -201,18 +202,18 @@ class ModelsTable:
|
|||
)
|
||||
return models
|
||||
|
||||
def get_base_models(self) -> list[ModelModel]:
|
||||
with get_db() as db:
|
||||
def get_base_models(self, db: Optional[Session] = None) -> list[ModelModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
ModelModel.model_validate(model)
|
||||
for model in db.query(Model).filter(Model.base_model_id == None).all()
|
||||
]
|
||||
|
||||
def get_models_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
||||
) -> list[ModelUserResponse]:
|
||||
models = self.get_models()
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
||||
models = self.get_models(db=db)
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
return [
|
||||
model
|
||||
for model in models
|
||||
|
|
@ -263,9 +264,9 @@ class ModelsTable:
|
|||
return query
|
||||
|
||||
def search_models(
|
||||
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
|
||||
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30, db: Optional[Session] = None
|
||||
) -> ModelListResponse:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# Join GroupMember so we can order by group_id when requested
|
||||
query = db.query(Model, User).outerjoin(User, User.id == Model.user_id)
|
||||
query = query.filter(Model.base_model_id != None)
|
||||
|
|
@ -349,24 +350,24 @@ class ModelsTable:
|
|||
|
||||
return ModelListResponse(items=models, total=total)
|
||||
|
||||
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||
def get_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
model = db.get(Model, id)
|
||||
return ModelModel.model_validate(model)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_models_by_ids(self, ids: list[str]) -> list[ModelModel]:
|
||||
def get_models_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[ModelModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
models = db.query(Model).filter(Model.id.in_(ids)).all()
|
||||
return [ModelModel.model_validate(model) for model in models]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||
with get_db() as db:
|
||||
def toggle_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
is_active = db.query(Model).filter_by(id=id).first().is_active
|
||||
|
||||
|
|
@ -378,13 +379,13 @@ class ModelsTable:
|
|||
)
|
||||
db.commit()
|
||||
|
||||
return self.get_model_by_id(id)
|
||||
return self.get_model_by_id(id, db=db)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
|
||||
def update_model_by_id(self, id: str, model: ModelForm, db: Optional[Session] = None) -> Optional[ModelModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# update only the fields that are present in the model
|
||||
data = model.model_dump(exclude={"id"})
|
||||
result = db.query(Model).filter_by(id=id).update(data)
|
||||
|
|
@ -398,9 +399,9 @@ class ModelsTable:
|
|||
log.exception(f"Failed to update the model by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_model_by_id(self, id: str) -> bool:
|
||||
def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Model).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
|
|
@ -408,9 +409,9 @@ class ModelsTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_models(self) -> bool:
|
||||
def delete_all_models(self, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Model).delete()
|
||||
db.commit()
|
||||
|
||||
|
|
@ -418,9 +419,9 @@ class ModelsTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
|
||||
def sync_models(self, user_id: str, models: list[ModelModel], db: Optional[Session] = None) -> list[ModelModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# Get existing models
|
||||
existing_models = db.query(Model).all()
|
||||
existing_ids = {model.id for model in existing_models}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ import uuid
|
|||
from typing import Optional
|
||||
from functools import lru_cache
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, get_db, get_db_context
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.utils.access_control import has_access
|
||||
from open_webui.models.users import User, UserModel, Users, UserResponse
|
||||
|
|
@ -211,11 +212,9 @@ class NoteTable:
|
|||
return query
|
||||
|
||||
def insert_new_note(
|
||||
self,
|
||||
form_data: NoteForm,
|
||||
user_id: str,
|
||||
self, user_id: str, form_data: NoteForm, db: Optional[Session] = None
|
||||
) -> Optional[NoteModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
note = NoteModel(
|
||||
**{
|
||||
"id": str(uuid.uuid4()),
|
||||
|
|
@ -233,9 +232,9 @@ class NoteTable:
|
|||
return note
|
||||
|
||||
def get_notes(
|
||||
self, skip: Optional[int] = None, limit: Optional[int] = None
|
||||
self, skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
||||
) -> list[NoteModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Note).order_by(Note.updated_at.desc())
|
||||
if skip is not None:
|
||||
query = query.offset(skip)
|
||||
|
|
@ -333,10 +332,11 @@ class NoteTable:
|
|||
self,
|
||||
user_id: str,
|
||||
permission: str = "read",
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[NoteModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
user_group_ids = [
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id)
|
||||
]
|
||||
|
|
@ -354,15 +354,17 @@ class NoteTable:
|
|||
notes = query.all()
|
||||
return [NoteModel.model_validate(note) for note in notes]
|
||||
|
||||
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
|
||||
with get_db() as db:
|
||||
def get_note_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[NoteModel]:
|
||||
with get_db_context(db) as db:
|
||||
note = db.query(Note).filter(Note.id == id).first()
|
||||
return NoteModel.model_validate(note) if note else None
|
||||
|
||||
def update_note_by_id(
|
||||
self, id: str, form_data: NoteUpdateForm
|
||||
self, id: str, form_data: NoteUpdateForm, db: Optional[Session] = None
|
||||
) -> Optional[NoteModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
note = db.query(Note).filter(Note.id == id).first()
|
||||
if not note:
|
||||
return None
|
||||
|
|
@ -384,11 +386,14 @@ class NoteTable:
|
|||
db.commit()
|
||||
return NoteModel.model_validate(note) if note else None
|
||||
|
||||
def delete_note_by_id(self, id: str):
|
||||
with get_db() as db:
|
||||
db.query(Note).filter(Note.id == id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Note).filter(Note.id == id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Notes = NoteTable()
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ import json
|
|||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, get_db, get_db_context
|
||||
from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
|
@ -109,10 +110,11 @@ class OAuthSessionTable:
|
|||
user_id: str,
|
||||
provider: str,
|
||||
token: dict,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
"""Create a new OAuth session"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
current_time = int(time.time())
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
|
|
@ -141,10 +143,10 @@ class OAuthSessionTable:
|
|||
log.error(f"Error creating OAuth session: {e}")
|
||||
return None
|
||||
|
||||
def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]:
|
||||
def get_session_by_id(self, session_id: str, db: Optional[Session] = None) -> Optional[OAuthSessionModel]:
|
||||
"""Get OAuth session by ID"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
session = db.query(OAuthSession).filter_by(id=session_id).first()
|
||||
if session:
|
||||
session.token = self._decrypt_token(session.token)
|
||||
|
|
@ -156,11 +158,11 @@ class OAuthSessionTable:
|
|||
return None
|
||||
|
||||
def get_session_by_id_and_user_id(
|
||||
self, session_id: str, user_id: str
|
||||
self, session_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
"""Get OAuth session by ID and user ID"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
session = (
|
||||
db.query(OAuthSession)
|
||||
.filter_by(id=session_id, user_id=user_id)
|
||||
|
|
@ -176,11 +178,11 @@ class OAuthSessionTable:
|
|||
return None
|
||||
|
||||
def get_session_by_provider_and_user_id(
|
||||
self, provider: str, user_id: str
|
||||
self, provider: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
"""Get OAuth session by provider and user ID"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
session = (
|
||||
db.query(OAuthSession)
|
||||
.filter_by(provider=provider, user_id=user_id)
|
||||
|
|
@ -195,10 +197,10 @@ class OAuthSessionTable:
|
|||
log.error(f"Error getting OAuth session by provider and user ID: {e}")
|
||||
return None
|
||||
|
||||
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
|
||||
def get_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> List[OAuthSessionModel]:
|
||||
"""Get all OAuth sessions for a user"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
sessions = db.query(OAuthSession).filter_by(user_id=user_id).all()
|
||||
|
||||
results = []
|
||||
|
|
@ -213,11 +215,11 @@ class OAuthSessionTable:
|
|||
return []
|
||||
|
||||
def update_session_by_id(
|
||||
self, session_id: str, token: dict
|
||||
self, session_id: str, token: dict, db: Optional[Session] = None
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
"""Update OAuth session tokens"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
current_time = int(time.time())
|
||||
|
||||
db.query(OAuthSession).filter_by(id=session_id).update(
|
||||
|
|
@ -239,10 +241,10 @@ class OAuthSessionTable:
|
|||
log.error(f"Error updating OAuth session tokens: {e}")
|
||||
return None
|
||||
|
||||
def delete_session_by_id(self, session_id: str) -> bool:
|
||||
def delete_session_by_id(self, session_id: str, db: Optional[Session] = None) -> bool:
|
||||
"""Delete an OAuth session"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
result = db.query(OAuthSession).filter_by(id=session_id).delete()
|
||||
db.commit()
|
||||
return result > 0
|
||||
|
|
@ -250,10 +252,10 @@ class OAuthSessionTable:
|
|||
log.error(f"Error deleting OAuth session: {e}")
|
||||
return False
|
||||
|
||||
def delete_sessions_by_user_id(self, user_id: str) -> bool:
|
||||
def delete_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
"""Delete all OAuth sessions for a user"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
result = db.query(OAuthSession).filter_by(user_id=user_id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
|
@ -261,10 +263,10 @@ class OAuthSessionTable:
|
|||
log.error(f"Error deleting OAuth sessions by user ID: {e}")
|
||||
return False
|
||||
|
||||
def delete_sessions_by_provider(self, provider: str) -> bool:
|
||||
def delete_sessions_by_provider(self, provider: str, db: Optional[Session] = None) -> bool:
|
||||
"""Delete all OAuth sessions for a provider"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(OAuthSession).filter_by(provider=provider).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
|
||||
|
|
@ -71,7 +73,7 @@ class PromptForm(BaseModel):
|
|||
|
||||
class PromptsTable:
|
||||
def insert_new_prompt(
|
||||
self, user_id: str, form_data: PromptForm
|
||||
self, user_id: str, form_data: PromptForm, db: Optional[Session] = None
|
||||
) -> Optional[PromptModel]:
|
||||
prompt = PromptModel(
|
||||
**{
|
||||
|
|
@ -82,7 +84,7 @@ class PromptsTable:
|
|||
)
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
result = Prompt(**prompt.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
|
|
@ -94,21 +96,21 @@ class PromptsTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
|
||||
def get_prompt_by_command(self, command: str, db: Optional[Session] = None) -> Optional[PromptModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
return PromptModel.model_validate(prompt)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_prompts(self) -> list[PromptUserResponse]:
|
||||
with get_db() as db:
|
||||
def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]:
|
||||
with get_db_context(db) as db:
|
||||
all_prompts = db.query(Prompt).order_by(Prompt.timestamp.desc()).all()
|
||||
|
||||
user_ids = list(set(prompt.user_id for prompt in all_prompts))
|
||||
|
||||
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
|
||||
prompts = []
|
||||
|
|
@ -126,10 +128,10 @@ class PromptsTable:
|
|||
return prompts
|
||||
|
||||
def get_prompts_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
||||
) -> list[PromptUserResponse]:
|
||||
prompts = self.get_prompts()
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
||||
prompts = self.get_prompts(db=db)
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
|
||||
return [
|
||||
prompt
|
||||
|
|
@ -139,10 +141,10 @@ class PromptsTable:
|
|||
]
|
||||
|
||||
def update_prompt_by_command(
|
||||
self, command: str, form_data: PromptForm
|
||||
self, command: str, form_data: PromptForm, db: Optional[Session] = None
|
||||
) -> Optional[PromptModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
prompt.title = form_data.title
|
||||
prompt.content = form_data.content
|
||||
|
|
@ -153,9 +155,9 @@ class PromptsTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_prompt_by_command(self, command: str) -> bool:
|
||||
def delete_prompt_by_command(self, command: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Prompt).filter_by(command=command).delete()
|
||||
db.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ import time
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
|
@ -50,8 +51,8 @@ class TagChatIdForm(BaseModel):
|
|||
|
||||
|
||||
class TagTable:
|
||||
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
||||
with get_db() as db:
|
||||
def insert_new_tag(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
|
||||
with get_db_context(db) as db:
|
||||
id = name.replace(" ", "_").lower()
|
||||
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
||||
try:
|
||||
|
|
@ -68,27 +69,27 @@ class TagTable:
|
|||
return None
|
||||
|
||||
def get_tag_by_name_and_user_id(
|
||||
self, name: str, user_id: str
|
||||
self, name: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[TagModel]:
|
||||
try:
|
||||
id = name.replace(" ", "_").lower()
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
|
||||
return TagModel.model_validate(tag)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
def get_tags_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[TagModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
|
||||
]
|
||||
|
||||
def get_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str
|
||||
self, ids: list[str], user_id: str, db: Optional[Session] = None
|
||||
) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (
|
||||
|
|
@ -96,9 +97,9 @@ class TagTable:
|
|||
)
|
||||
]
|
||||
|
||||
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
|
||||
def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
id = name.replace(" ", "_").lower()
|
||||
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
|
||||
log.debug(f"res: {res}")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import logging
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
|
|
@ -110,9 +111,9 @@ class ToolValves(BaseModel):
|
|||
|
||||
class ToolsTable:
|
||||
def insert_new_tool(
|
||||
self, user_id: str, form_data: ToolForm, specs: list[dict]
|
||||
self, user_id: str, form_data: ToolForm, specs: list[dict], db: Optional[Session] = None
|
||||
) -> Optional[ToolModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
tool = ToolModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
|
|
@ -136,21 +137,21 @@ class ToolsTable:
|
|||
log.exception(f"Error creating a new tool: {e}")
|
||||
return None
|
||||
|
||||
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
|
||||
def get_tool_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ToolModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
tool = db.get(Tool, id)
|
||||
return ToolModel.model_validate(tool)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_tools(self) -> list[ToolUserModel]:
|
||||
with get_db() as db:
|
||||
def get_tools(self, db: Optional[Session] = None) -> list[ToolUserModel]:
|
||||
with get_db_context(db) as db:
|
||||
all_tools = db.query(Tool).order_by(Tool.updated_at.desc()).all()
|
||||
|
||||
user_ids = list(set(tool.user_id for tool in all_tools))
|
||||
|
||||
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
|
||||
tools = []
|
||||
|
|
@ -167,10 +168,10 @@ class ToolsTable:
|
|||
return tools
|
||||
|
||||
def get_tools_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
||||
) -> list[ToolUserModel]:
|
||||
tools = self.get_tools()
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
||||
tools = self.get_tools(db=db)
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
|
||||
return [
|
||||
tool
|
||||
|
|
@ -179,31 +180,31 @@ class ToolsTable:
|
|||
or has_access(user_id, permission, tool.access_control, user_group_ids)
|
||||
]
|
||||
|
||||
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
def get_tool_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
tool = db.get(Tool, id)
|
||||
return tool.valves if tool.valves else {}
|
||||
except Exception as e:
|
||||
log.exception(f"Error getting tool valves by id {id}")
|
||||
return None
|
||||
|
||||
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
|
||||
def update_tool_valves_by_id(self, id: str, valves: dict, db: Optional[Session] = None) -> Optional[ToolValves]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Tool).filter_by(id=id).update(
|
||||
{"valves": valves, "updated_at": int(time.time())}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_tool_by_id(id)
|
||||
return self.get_tool_by_id(id, db=db)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "tools" and "valves" settings
|
||||
|
|
@ -220,10 +221,10 @@ class ToolsTable:
|
|||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, valves: dict
|
||||
self, id: str, user_id: str, valves: dict, db: Optional[Session] = None
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "tools" and "valves" settings
|
||||
|
|
@ -235,7 +236,7 @@ class ToolsTable:
|
|||
user_settings["tools"]["valves"][id] = valves
|
||||
|
||||
# Update the user settings in the database
|
||||
Users.update_user_by_id(user_id, {"settings": user_settings})
|
||||
Users.update_user_by_id(user_id, {"settings": user_settings}, db=db)
|
||||
|
||||
return user_settings["tools"]["valves"][id]
|
||||
except Exception as e:
|
||||
|
|
@ -244,9 +245,9 @@ class ToolsTable:
|
|||
)
|
||||
return None
|
||||
|
||||
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
|
||||
def update_tool_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[ToolModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Tool).filter_by(id=id).update(
|
||||
{**updated, "updated_at": int(time.time())}
|
||||
)
|
||||
|
|
@ -258,9 +259,9 @@ class ToolsTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_tool_by_id(self, id: str) -> bool:
|
||||
def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Tool).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
|
||||
|
||||
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
|
||||
|
|
@ -243,8 +244,9 @@ class UsersTable:
|
|||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
oauth: Optional[dict] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[UserModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
user = UserModel(
|
||||
**{
|
||||
"id": id,
|
||||
|
|
@ -267,17 +269,17 @@ class UsersTable:
|
|||
else:
|
||||
return None
|
||||
|
||||
def get_user_by_id(self, id: str) -> Optional[UserModel]:
|
||||
def get_user_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||
def get_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
user = (
|
||||
db.query(User)
|
||||
.join(ApiKey, User.id == ApiKey.user_id)
|
||||
|
|
@ -288,17 +290,17 @@ class UsersTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_by_email(self, email: str) -> Optional[UserModel]:
|
||||
def get_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(email=email).first()
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]:
|
||||
def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db: # type: Session
|
||||
with get_db_context(db) as db: # type: Session
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
query = db.query(User)
|
||||
|
|
@ -320,8 +322,9 @@ class UsersTable:
|
|||
filter: Optional[dict] = None,
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> dict:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# Join GroupMember so we can order by group_id when requested
|
||||
query = db.query(User)
|
||||
|
||||
|
|
@ -452,8 +455,8 @@ class UsersTable:
|
|||
"total": total,
|
||||
}
|
||||
|
||||
def get_users_by_group_id(self, group_id: str) -> list[UserModel]:
|
||||
with get_db() as db:
|
||||
def get_users_by_group_id(self, group_id: str, db: Optional[Session] = None) -> list[UserModel]:
|
||||
with get_db_context(db) as db:
|
||||
users = (
|
||||
db.query(User)
|
||||
.join(GroupMember, User.id == GroupMember.user_id)
|
||||
|
|
@ -462,30 +465,30 @@ class UsersTable:
|
|||
)
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserStatusModel]:
|
||||
with get_db() as db:
|
||||
def get_users_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[UserStatusModel]:
|
||||
with get_db_context(db) as db:
|
||||
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
def get_num_users(self) -> Optional[int]:
|
||||
with get_db() as db:
|
||||
def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:
|
||||
with get_db_context(db) as db:
|
||||
return db.query(User).count()
|
||||
|
||||
def has_users(self) -> bool:
|
||||
with get_db() as db:
|
||||
def has_users(self, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
return db.query(db.query(User).exists()).scalar()
|
||||
|
||||
def get_first_user(self) -> UserModel:
|
||||
def get_first_user(self, db: Optional[Session] = None) -> UserModel:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).order_by(User.created_at).first()
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
|
||||
def get_user_webhook_url_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
|
||||
if user.settings is None:
|
||||
|
|
@ -499,8 +502,8 @@ class UsersTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_num_users_active_today(self) -> Optional[int]:
|
||||
with get_db() as db:
|
||||
def get_num_users_active_today(self, db: Optional[Session] = None) -> Optional[int]:
|
||||
with get_db_context(db) as db:
|
||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
||||
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
|
||||
query = db.query(User).filter(
|
||||
|
|
@ -508,9 +511,9 @@ class UsersTable:
|
|||
)
|
||||
return query.count()
|
||||
|
||||
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
|
||||
def update_user_role_by_id(self, id: str, role: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(User).filter_by(id=id).update({"role": role})
|
||||
db.commit()
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
|
|
@ -519,10 +522,10 @@ class UsersTable:
|
|||
return None
|
||||
|
||||
def update_user_status_by_id(
|
||||
self, id: str, form_data: UserStatus
|
||||
self, id: str, form_data: UserStatus, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(User).filter_by(id=id).update(
|
||||
{**form_data.model_dump(exclude_none=True)}
|
||||
)
|
||||
|
|
@ -534,10 +537,10 @@ class UsersTable:
|
|||
return None
|
||||
|
||||
def update_user_profile_image_url_by_id(
|
||||
self, id: str, profile_image_url: str
|
||||
self, id: str, profile_image_url: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(User).filter_by(id=id).update(
|
||||
{"profile_image_url": profile_image_url}
|
||||
)
|
||||
|
|
@ -549,9 +552,9 @@ class UsersTable:
|
|||
return None
|
||||
|
||||
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
||||
def update_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
||||
def update_last_active_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(User).filter_by(id=id).update(
|
||||
{"last_active_at": int(time.time())}
|
||||
)
|
||||
|
|
@ -563,7 +566,7 @@ class UsersTable:
|
|||
return None
|
||||
|
||||
def update_user_oauth_by_id(
|
||||
self, id: str, provider: str, sub: str
|
||||
self, id: str, provider: str, sub: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
"""
|
||||
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
|
||||
|
|
@ -574,7 +577,7 @@ class UsersTable:
|
|||
}
|
||||
"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
if not user:
|
||||
return None
|
||||
|
|
@ -594,9 +597,9 @@ class UsersTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
|
||||
def update_user_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(User).filter_by(id=id).update(updated)
|
||||
db.commit()
|
||||
|
||||
|
|
@ -607,9 +610,9 @@ class UsersTable:
|
|||
print(e)
|
||||
return None
|
||||
|
||||
def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
|
||||
def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
user_settings = db.query(User).filter_by(id=id).first().settings
|
||||
|
||||
if user_settings is None:
|
||||
|
|
@ -625,15 +628,15 @@ class UsersTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_user_by_id(self, id: str) -> bool:
|
||||
def delete_user_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
# Remove User from Groups
|
||||
Groups.remove_user_from_all_groups(id)
|
||||
|
||||
# Delete User Chats
|
||||
result = Chats.delete_chats_by_user_id(id)
|
||||
result = Chats.delete_chats_by_user_id(id, db=db)
|
||||
if result:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# Delete User
|
||||
db.query(User).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
|
@ -644,17 +647,17 @@ class UsersTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
|
||||
def get_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
api_key = db.query(ApiKey).filter_by(user_id=id).first()
|
||||
return api_key.key if api_key else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
|
||||
def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(ApiKey).filter_by(user_id=id).delete()
|
||||
db.commit()
|
||||
|
||||
|
|
@ -674,30 +677,30 @@ class UsersTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_user_api_key_by_id(self, id: str) -> bool:
|
||||
def delete_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(ApiKey).filter_by(user_id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
|
||||
with get_db() as db:
|
||||
def get_valid_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[str]:
|
||||
with get_db_context(db) as db:
|
||||
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
||||
return [user.id for user in users]
|
||||
|
||||
def get_super_admin_user(self) -> Optional[UserModel]:
|
||||
with get_db() as db:
|
||||
def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(role="admin").first()
|
||||
if user:
|
||||
return UserModel.model_validate(user)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_active_user_count(self) -> int:
|
||||
with get_db() as db:
|
||||
def get_active_user_count(self, db: Optional[Session] = None) -> int:
|
||||
with get_db_context(db) as db:
|
||||
# Consider user active if last_active_at within the last 3 minutes
|
||||
three_minutes_ago = int(time.time()) - 180
|
||||
count = (
|
||||
|
|
@ -705,8 +708,8 @@ class UsersTable:
|
|||
)
|
||||
return count
|
||||
|
||||
def is_user_active(self, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def is_user_active(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=user_id).first()
|
||||
if user and user.last_active_at:
|
||||
# Consider user active if last_active_at within the last 3 minutes
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ async def create_new_note(
|
|||
)
|
||||
|
||||
try:
|
||||
note = Notes.insert_new_note(form_data, user.id)
|
||||
note = Notes.insert_new_note(user.id, form_data)
|
||||
return note
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
|
|||
Loading…
Reference in a new issue