From b1d0f00d8c4d024a55f84e4d465d2fddd637ef27 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 29 Dec 2025 00:21:18 +0400 Subject: [PATCH] refac/enh: db session sharing --- backend/open_webui/internal/db.py | 5 +- backend/open_webui/main.py | 17 +- .../retrieval/vector/dbs/opengauss.py | 77 ++-- .../retrieval/vector/dbs/pgvector.py | 4 +- backend/open_webui/routers/auths.py | 134 +++++-- backend/open_webui/routers/chats.py | 345 ++++++++++++------ backend/open_webui/routers/evaluations.py | 75 ++-- backend/open_webui/routers/files.py | 230 ++++++++---- backend/open_webui/routers/folders.py | 87 +++-- backend/open_webui/routers/functions.py | 84 ++--- backend/open_webui/routers/groups.py | 80 ++-- backend/open_webui/routers/knowledge.py | 153 +++++--- backend/open_webui/routers/memories.py | 30 +- backend/open_webui/routers/models.py | 82 +++-- backend/open_webui/routers/notes.py | 90 +++-- backend/open_webui/routers/prompts.py | 44 ++- backend/open_webui/routers/retrieval.py | 26 +- backend/open_webui/routers/tools.py | 88 ++--- backend/open_webui/routers/users.py | 129 ++++--- backend/open_webui/utils/access_control.py | 17 +- backend/open_webui/utils/auth.py | 13 +- backend/open_webui/utils/groups.py | 3 +- backend/open_webui/utils/oauth.py | 23 +- 23 files changed, 1173 insertions(+), 663 deletions(-) diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index e84dfe9aa7..e4e4e0d2f8 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -19,7 +19,7 @@ from open_webui.env import ( from peewee_migrate import Router from sqlalchemy import Dialect, create_engine, MetaData, event, types from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy.orm import scoped_session, sessionmaker, Session from sqlalchemy.pool import QueuePool, NullPool from sqlalchemy.sql.type_api import _T from typing_extensions import Self @@ -148,7 +148,7 @@ SessionLocal = sessionmaker( ) metadata_obj = MetaData(schema=DATABASE_SCHEMA) Base = declarative_base(metadata=metadata_obj) -Session = scoped_session(SessionLocal) +ScopedSession = scoped_session(SessionLocal) def get_session(): @@ -169,4 +169,3 @@ def get_db_context(db: Optional[Session] = None): else: with get_db() as session: yield session - diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 63ae72f7bc..97d91c0ee0 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -102,7 +102,9 @@ from open_webui.routers.retrieval import ( get_rf, ) -from open_webui.internal.db import Session, engine + +from sqlalchemy.orm import Session +from open_webui.internal.db import ScopedSession, engine, get_session from open_webui.models.functions import Functions from open_webui.models.models import Models @@ -1324,7 +1326,7 @@ app.add_middleware(APIKeyRestrictionMiddleware) async def commit_session_after_request(request: Request, call_next): response = await call_next(request) # log.debug("Commit session after request") - Session.commit() + ScopedSession.commit() return response @@ -2280,8 +2282,13 @@ async def oauth_login(provider: str, request: Request): # - Email addresses are considered unique, so we fail registration if the email address is already taken @app.get("/oauth/{provider}/login/callback") @app.get("/oauth/{provider}/callback") # Legacy endpoint -async def oauth_login_callback(provider: str, request: Request, response: Response): - return await oauth_manager.handle_callback(request, provider, response) +async def oauth_login_callback( + provider: str, + request: Request, + response: Response, + db: Session = Depends(get_session), +): + return await oauth_manager.handle_callback(request, provider, response, db=db) @app.get("/manifest.json") @@ -2340,7 +2347,7 @@ async def healthcheck(): @app.get("/health/db") async def healthcheck_with_db(): - Session.execute(text("SELECT 1;")).all() + ScopedSession.execute(text("SELECT 1;")).all() return {"status": True} diff --git a/backend/open_webui/retrieval/vector/dbs/opengauss.py b/backend/open_webui/retrieval/vector/dbs/opengauss.py index 13505e1852..056a8a61cc 100644 --- a/backend/open_webui/retrieval/vector/dbs/opengauss.py +++ b/backend/open_webui/retrieval/vector/dbs/opengauss.py @@ -30,27 +30,27 @@ from sqlalchemy.exc import NoSuchTableError from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 from sqlalchemy.dialects import registry + class OpenGaussDialect(PGDialect_psycopg2): name = "opengauss" - + def _get_server_version_info(self, connection): try: version = connection.exec_driver_sql("SELECT version()").scalar() if not version: return (9, 0, 0) - + match = re.search( - r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", - version, - re.IGNORECASE + r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", version, re.IGNORECASE ) if match: return (int(match.group(1)), int(match.group(2)), int(match.group(3))) - + return super()._get_server_version_info(connection) except Exception: return (9, 0, 0) + # Register dialect registry.register("opengauss", __name__, "OpenGaussDialect") @@ -78,6 +78,7 @@ Base = declarative_base() log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) + class DocumentChunk(Base): __tablename__ = "document_chunk" @@ -87,29 +88,30 @@ class DocumentChunk(Base): text = Column(Text, nullable=True) vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) + class OpenGaussClient(VectorDBBase): def __init__(self) -> None: if not OPENGAUSS_DB_URL: - from open_webui.internal.db import Session - self.session = Session + from open_webui.internal.db import ScopedSession + + self.session = ScopedSession else: - engine_kwargs = { - "pool_pre_ping": True, - "dialect": OpenGaussDialect() - } - + engine_kwargs = {"pool_pre_ping": True, "dialect": OpenGaussDialect()} + if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0: - engine_kwargs.update({ - "pool_size": OPENGAUSS_POOL_SIZE, - "max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW, - "pool_timeout": OPENGAUSS_POOL_TIMEOUT, - "pool_recycle": OPENGAUSS_POOL_RECYCLE, - "poolclass": QueuePool - }) + engine_kwargs.update( + { + "pool_size": OPENGAUSS_POOL_SIZE, + "max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW, + "pool_timeout": OPENGAUSS_POOL_TIMEOUT, + "pool_recycle": OPENGAUSS_POOL_RECYCLE, + "poolclass": QueuePool, + } + ) else: engine_kwargs["poolclass"] = NullPool - engine = create_engine(OPENGAUSS_DB_URL,** engine_kwargs) + engine = create_engine(OPENGAUSS_DB_URL, **engine_kwargs) SessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=engine, expire_on_commit=False @@ -160,7 +162,9 @@ class OpenGaussClient(VectorDBBase): else: raise Exception("The 'vector' column type is not Vector.") else: - raise Exception("The 'vector' column does not exist in the 'document_chunk' table.") + raise Exception( + "The 'vector' column does not exist in the 'document_chunk' table." + ) def adjust_vector_length(self, vector: List[float]) -> List[float]: current_length = len(vector) @@ -185,7 +189,9 @@ class OpenGaussClient(VectorDBBase): new_items.append(new_chunk) self.session.bulk_save_objects(new_items) self.session.commit() - log.info(f"Inserting {len(new_items)} items into collection '{collection_name}'.") + log.info( + f"Inserting {len(new_items)} items into collection '{collection_name}'." + ) except Exception as e: self.session.rollback() log.exception(f"Failed to insert data: {e}") @@ -215,7 +221,9 @@ class OpenGaussClient(VectorDBBase): ) self.session.add(new_chunk) self.session.commit() - log.info(f"Inserting/updating {len(items)} items in collection '{collection_name}'.") + log.info( + f"Inserting/updating {len(items)} items in collection '{collection_name}'." + ) except Exception as e: self.session.rollback() log.exception(f"Failed to insert or update data.: {e}") @@ -241,7 +249,9 @@ class OpenGaussClient(VectorDBBase): q_vector_col = column("q_vector", Vector(VECTOR_LENGTH)) query_vectors = ( values(qid_col, q_vector_col) - .data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]) + .data( + [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)] + ) .alias("query_vectors") ) @@ -249,13 +259,17 @@ class OpenGaussClient(VectorDBBase): DocumentChunk.id, DocumentChunk.text, DocumentChunk.vmetadata, - (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label("distance"), + (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label( + "distance" + ), ] subq = ( select(*result_fields) .where(DocumentChunk.collection_name == collection_name) - .order_by(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) + .order_by( + DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector) + ) ) if limit is not None: subq = subq.limit(limit) @@ -368,7 +382,9 @@ class OpenGaussClient(VectorDBBase): query = query.filter(DocumentChunk.id.in_(ids)) if filter: for key, value in filter.items(): - query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) + query = query.filter( + DocumentChunk.vmetadata[key].astext == str(value) + ) deleted = query.delete(synchronize_session=False) self.session.commit() log.info(f"Deleted {deleted} items from collection '{collection_name}'") @@ -395,7 +411,8 @@ class OpenGaussClient(VectorDBBase): exists = ( self.session.query(DocumentChunk) .filter(DocumentChunk.collection_name == collection_name) - .first() is not None + .first() + is not None ) self.session.rollback() return exists @@ -406,4 +423,4 @@ class OpenGaussClient(VectorDBBase): def delete_collection(self, collection_name: str) -> None: self.delete(collection_name) - log.info(f"Collection '{collection_name}' has been deleted") \ No newline at end of file + log.info(f"Collection '{collection_name}' has been deleted") diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index 2f4677995a..b9b1dba07b 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -90,9 +90,9 @@ class PgvectorClient(VectorDBBase): # if no pgvector uri, use the existing database connection if not PGVECTOR_DB_URL: - from open_webui.internal.db import Session + from open_webui.internal.db import ScopedSession - self.session = Session + self.session = ScopedSession else: if isinstance(PGVECTOR_POOL_SIZE, int): if PGVECTOR_POOL_SIZE > 0: diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 59c5322c7f..7482e9053e 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -62,6 +62,8 @@ from open_webui.utils.auth import ( get_password_hash, get_http_authorization_cred, ) +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session from open_webui.utils.webhook import post_webhook from open_webui.utils.access_control import get_permissions, has_permission from open_webui.utils.groups import apply_default_group_assignment @@ -103,7 +105,10 @@ class SessionUserInfoResponse(SessionUserResponse, UserStatus): @router.get("/", response_model=SessionUserInfoResponse) async def get_session_user( - request: Request, response: Response, user=Depends(get_current_user) + request: Request, + response: Response, + user=Depends(get_current_user), + db: Session = Depends(get_session), ): auth_header = request.headers.get("Authorization") @@ -137,7 +142,7 @@ async def get_session_user( ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS + user.id, request.app.state.config.USER_PERMISSIONS, db=db ) return { @@ -166,12 +171,15 @@ async def get_session_user( @router.post("/update/profile", response_model=UserProfileImageResponse) async def update_profile( - form_data: UpdateProfileForm, session_user=Depends(get_verified_user) + form_data: UpdateProfileForm, + session_user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if session_user: user = Users.update_user_by_id( session_user.id, form_data.model_dump(), + db=db, ) if user: return user @@ -188,13 +196,17 @@ async def update_profile( @router.post("/update/password", response_model=bool) async def update_password( - form_data: UpdatePasswordForm, session_user=Depends(get_current_user) + form_data: UpdatePasswordForm, + session_user=Depends(get_current_user), + db: Session = Depends(get_session), ): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) if session_user: user = Auths.authenticate_user( - session_user.email, lambda pw: verify_password(form_data.password, pw) + session_user.email, + lambda pw: verify_password(form_data.password, pw), + db=db, ) if user: @@ -203,7 +215,7 @@ async def update_password( except Exception as e: raise HTTPException(400, detail=str(e)) hashed = get_password_hash(form_data.new_password) - return Auths.update_user_password_by_id(user.id, hashed) + return Auths.update_user_password_by_id(user.id, hashed, db=db) else: raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD) else: @@ -214,7 +226,12 @@ async def update_password( # LDAP Authentication ############################ @router.post("/ldap", response_model=SessionUserResponse) -async def ldap_auth(request: Request, response: Response, form_data: LdapForm): +async def ldap_auth( + request: Request, + response: Response, + form_data: LdapForm, + db: Session = Depends(get_session), +): # Security checks FIRST - before loading any config if not request.app.state.config.ENABLE_LDAP: raise HTTPException(400, detail="LDAP authentication is not enabled") @@ -400,12 +417,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): if not connection_user.bind(): raise HTTPException(400, "Authentication failed.") - user = Users.get_user_by_email(email) + user = Users.get_user_by_email(email, db=db) if not user: try: role = ( "admin" - if not Users.has_users() + if not Users.has_users(db=db) else request.app.state.config.DEFAULT_USER_ROLE ) @@ -414,6 +431,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): password=str(uuid.uuid4()), name=cn, role=role, + db=db, ) if not user: @@ -424,6 +442,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): apply_default_group_assignment( request.app.state.config.DEFAULT_GROUP_ID, user.id, + db=db, ) except HTTPException: @@ -434,7 +453,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): 500, detail="Internal error occurred during LDAP user creation." ) - user = Auths.authenticate_user_by_email(email) + user = Auths.authenticate_user_by_email(email, db=db) if user: expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) @@ -464,7 +483,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS + user.id, request.app.state.config.USER_PERMISSIONS, db=db ) if ( @@ -473,9 +492,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): and user_groups ): if ENABLE_LDAP_GROUP_CREATION: - Groups.create_groups_by_group_names(user.id, user_groups) + Groups.create_groups_by_group_names(user.id, user_groups, db=db) try: - Groups.sync_groups_by_group_names(user.id, user_groups) + Groups.sync_groups_by_group_names(user.id, user_groups, db=db) log.info( f"Successfully synced groups for user {user.id}: {user_groups}" ) @@ -508,7 +527,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): @router.post("/signin", response_model=SessionUserResponse) -async def signin(request: Request, response: Response, form_data: SigninForm): +async def signin( + request: Request, + response: Response, + form_data: SigninForm, + db: Session = Depends(get_session), +): if not ENABLE_PASSWORD_AUTH: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -529,14 +553,15 @@ async def signin(request: Request, response: Response, form_data: SigninForm): except Exception as e: pass - if not Users.get_user_by_email(email.lower()): + if not Users.get_user_by_email(email.lower(), db=db): await signup( request, response, SignupForm(email=email, password=str(uuid.uuid4()), name=name), + db=db, ) - user = Auths.authenticate_user_by_email(email) + user = Auths.authenticate_user_by_email(email, db=db) if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin": group_names = request.headers.get( WEBUI_AUTH_TRUSTED_GROUPS_HEADER, "" @@ -544,28 +569,33 @@ async def signin(request: Request, response: Response, form_data: SigninForm): group_names = [name.strip() for name in group_names if name.strip()] if group_names: - Groups.sync_groups_by_group_names(user.id, group_names) + Groups.sync_groups_by_group_names(user.id, group_names, db=db) elif WEBUI_AUTH == False: admin_email = "admin@localhost" admin_password = "admin" - if Users.get_user_by_email(admin_email.lower()): + if Users.get_user_by_email(admin_email.lower(), db=db): user = Auths.authenticate_user( - admin_email.lower(), lambda pw: verify_password(admin_password, pw) + admin_email.lower(), + lambda pw: verify_password(admin_password, pw), + db=db, ) else: - if Users.has_users(): + if Users.has_users(db=db): raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) await signup( request, response, SignupForm(email=admin_email, password=admin_password, name="User"), + db=db, ) user = Auths.authenticate_user( - admin_email.lower(), lambda pw: verify_password(admin_password, pw) + admin_email.lower(), + lambda pw: verify_password(admin_password, pw), + db=db, ) else: if signin_rate_limiter.is_limited(form_data.email.lower()): @@ -584,7 +614,9 @@ async def signin(request: Request, response: Response, form_data: SigninForm): form_data.password = password_bytes.decode("utf-8", errors="ignore") user = Auths.authenticate_user( - form_data.email.lower(), lambda pw: verify_password(form_data.password, pw) + form_data.email.lower(), + lambda pw: verify_password(form_data.password, pw), + db=db, ) if user: @@ -616,7 +648,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS + user.id, request.app.state.config.USER_PERMISSIONS, db=db ) return { @@ -640,8 +672,13 @@ async def signin(request: Request, response: Response, form_data: SigninForm): @router.post("/signup", response_model=SessionUserResponse) -async def signup(request: Request, response: Response, form_data: SignupForm): - has_users = Users.has_users() +async def signup( + request: Request, + response: Response, + form_data: SignupForm, + db: Session = Depends(get_session), +): + has_users = Users.has_users(db=db) if WEBUI_AUTH: if ( @@ -663,7 +700,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(form_data.email.lower()): + if Users.get_user_by_email(form_data.email.lower(), db=db): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: @@ -681,6 +718,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): form_data.name, form_data.profile_image_url, role, + db=db, ) if user: @@ -723,7 +761,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS + user.id, request.app.state.config.USER_PERMISSIONS, db=db ) if not has_users: @@ -733,6 +771,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): apply_default_group_assignment( request.app.state.config.DEFAULT_GROUP_ID, user.id, + db=db, ) return { @@ -754,7 +793,9 @@ async def signup(request: Request, response: Response, form_data: SignupForm): @router.get("/signout") -async def signout(request: Request, response: Response): +async def signout( + request: Request, response: Response, db: Session = Depends(get_session) +): # get auth token from headers or cookies token = None @@ -776,7 +817,7 @@ async def signout(request: Request, response: Response): if oauth_session_id: response.delete_cookie("oauth_session_id") - session = OAuthSessions.get_session_by_id(oauth_session_id) + session = OAuthSessions.get_session_by_id(oauth_session_id, db=db) oauth_server_metadata_url = ( request.app.state.oauth_manager.get_server_metadata_url(session.provider) if session @@ -839,14 +880,17 @@ async def signout(request: Request, response: Response): @router.post("/add", response_model=SigninResponse) async def add_user( - request: Request, form_data: AddUserForm, user=Depends(get_admin_user) + request: Request, + form_data: AddUserForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): if not validate_email_format(form_data.email.lower()): raise HTTPException( status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(form_data.email.lower()): + if Users.get_user_by_email(form_data.email.lower(), db=db): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: @@ -862,12 +906,14 @@ async def add_user( form_data.name, form_data.profile_image_url, form_data.role, + db=db, ) if user: apply_default_group_assignment( request.app.state.config.DEFAULT_GROUP_ID, user.id, + db=db, ) token = create_token(data={"id": user.id}) @@ -895,7 +941,9 @@ async def add_user( @router.get("/admin/details") -async def get_admin_details(request: Request, user=Depends(get_current_user)): +async def get_admin_details( + request: Request, user=Depends(get_current_user), db: Session = Depends(get_session) +): if request.app.state.config.SHOW_ADMIN_DETAILS: admin_email = request.app.state.config.ADMIN_EMAIL admin_name = None @@ -903,11 +951,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)): log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}") if admin_email: - admin = Users.get_user_by_email(admin_email) + admin = Users.get_user_by_email(admin_email, db=db) if admin: admin_name = admin.name else: - admin = Users.get_first_user() + admin = Users.get_first_user(db=db) if admin: admin_email = admin.email admin_name = admin.name @@ -1149,7 +1197,9 @@ async def update_ldap_config( # create api key @router.post("/api_key", response_model=ApiKey) -async def generate_api_key(request: Request, user=Depends(get_current_user)): +async def generate_api_key( + request: Request, user=Depends(get_current_user), db: Session = Depends(get_session) +): if not request.app.state.config.ENABLE_API_KEYS or not has_permission( user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS ): @@ -1159,7 +1209,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)): ) api_key = create_api_key() - success = Users.update_user_api_key_by_id(user.id, api_key) + success = Users.update_user_api_key_by_id(user.id, api_key, db=db) if success: return { @@ -1171,14 +1221,18 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)): # delete api key @router.delete("/api_key", response_model=bool) -async def delete_api_key(user=Depends(get_current_user)): - return Users.delete_user_api_key_by_id(user.id) +async def delete_api_key( + user=Depends(get_current_user), db: Session = Depends(get_session) +): + return Users.delete_user_api_key_by_id(user.id, db=db) # get api key @router.get("/api_key", response_model=ApiKey) -async def get_api_key(user=Depends(get_current_user)): - api_key = Users.get_user_api_key_by_id(user.id) +async def get_api_key( + user=Depends(get_current_user), db: Session = Depends(get_session) +): + api_key = Users.get_user_api_key_by_id(user.id, db=db) if api_key: return { "api_key": api_key, diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 684f929134..c98e119da5 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -1,6 +1,7 @@ import json import logging from typing import Optional +from sqlalchemy.orm import Session import asyncio from fastapi.responses import StreamingResponse @@ -23,6 +24,7 @@ from open_webui.models.chats import ( ) from open_webui.models.tags import TagModel, Tags from open_webui.models.folders import Folders +from open_webui.internal.db import get_session from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES @@ -49,6 +51,7 @@ def get_session_user_chat_list( page: Optional[int] = None, include_pinned: Optional[bool] = False, include_folders: Optional[bool] = False, + db: Session = Depends(get_session), ): try: if page is not None: @@ -61,10 +64,14 @@ def get_session_user_chat_list( include_pinned=include_pinned, skip=skip, limit=limit, + db=db, ) else: return Chats.get_chat_title_id_list_by_user_id( - user.id, include_folders=include_folders, include_pinned=include_pinned + user.id, + include_folders=include_folders, + include_pinned=include_pinned, + db=db, ) except Exception as e: log.exception(e) @@ -84,12 +91,13 @@ def get_session_user_chat_usage_stats( items_per_page: Optional[int] = 50, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): try: limit = items_per_page skip = (page - 1) * limit - result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit) + result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit, db=db) chats = result.items total = result.total @@ -216,6 +224,7 @@ class ChatStatsExportList(BaseModel): def _process_chat_for_export(chat) -> Optional[ChatStatsExport]: try: + def get_message_content_length(message): content = message.get("content", "") if isinstance(content, str): @@ -348,7 +357,9 @@ def _process_chat_for_export(chat) -> Optional[ChatStatsExport]: return None -def calculate_chat_stats(user_id, skip=0, limit=10, filter=None): +def calculate_chat_stats( + user_id, skip=0, limit=10, filter=None, db: Optional[Session] = None +): if filter is None: filter = {} @@ -357,6 +368,7 @@ def calculate_chat_stats(user_id, skip=0, limit=10, filter=None): skip=skip, limit=limit, filter=filter, + db=db, ) chat_stats_export_list = [] @@ -368,14 +380,21 @@ def calculate_chat_stats(user_id, skip=0, limit=10, filter=None): return chat_stats_export_list, result.total -async def generate_chat_stats_jsonl_generator(user_id, filter): +async def generate_chat_stats_jsonl_generator( + user_id, filter, db: Optional[Session] = None +): skip = 0 limit = CHAT_EXPORT_PAGE_ITEM_COUNT while True: # Use asyncio.to_thread to make the blocking DB call non-blocking result = await asyncio.to_thread( - Chats.get_chats_by_user_id, user_id, filter=filter, skip=skip, limit=limit + Chats.get_chats_by_user_id, + user_id, + filter=filter, + skip=skip, + limit=limit, + db=db, ) if not result.items: break @@ -386,7 +405,7 @@ async def generate_chat_stats_jsonl_generator(user_id, filter): if chat_stat: yield chat_stat.model_dump_json() + "\n" except Exception as e: - log.exception(f"Error processing chat {chat.id}: {e}") + log.exception(f"Error processing chat {chat.id}: {e}") skip += limit @@ -400,6 +419,7 @@ async def export_chat_stats( page: Optional[int] = 1, stream: bool = False, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): # Check if the user has permission to share/export chats if (user.role != "admin") and ( @@ -415,7 +435,7 @@ async def export_chat_stats( filter = {"order_by": "created_at", "direction": "asc"} if chat_id: - chat = Chats.get_chat_by_id(chat_id) + chat = Chats.get_chat_by_id(chat_id, db=db) if chat: filter["start_time"] = chat.created_at @@ -426,7 +446,7 @@ async def export_chat_stats( if stream: return StreamingResponse( - generate_chat_stats_jsonl_generator(user.id, filter), + generate_chat_stats_jsonl_generator(user.id, filter, db=db), media_type="application/x-ndjson", headers={ "Content-Disposition": f"attachment; filename=chat-stats-export-{user.id}.jsonl" @@ -437,7 +457,7 @@ async def export_chat_stats( skip = (page - 1) * limit chat_stats_export_list, total = await asyncio.to_thread( - calculate_chat_stats, user.id, skip, limit, filter + calculate_chat_stats, user.id, skip, limit, filter, db=db ) return ChatStatsExportList( @@ -452,7 +472,11 @@ async def export_chat_stats( @router.delete("/", response_model=bool) -async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)): +async def delete_all_user_chats( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role == "user" and not has_permission( user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS @@ -462,7 +486,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - result = Chats.delete_chats_by_user_id(user.id) + result = Chats.delete_chats_by_user_id(user.id, db=db) return result @@ -479,6 +503,7 @@ async def get_user_chat_list_by_user_id( order_by: Optional[str] = None, direction: Optional[str] = None, user=Depends(get_admin_user), + db: Session = Depends(get_session), ): if not ENABLE_ADMIN_CHAT_ACCESS: raise HTTPException( @@ -501,7 +526,7 @@ async def get_user_chat_list_by_user_id( filter["direction"] = direction return Chats.get_chat_list_by_user_id( - user_id, include_archived=True, filter=filter, skip=skip, limit=limit + user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db ) @@ -511,9 +536,13 @@ async def get_user_chat_list_by_user_id( @router.post("/new", response_model=Optional[ChatResponse]) -async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): +async def create_new_chat( + form_data: ChatForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): try: - chat = Chats.insert_new_chat(user.id, form_data) + chat = Chats.insert_new_chat(user.id, form_data, db=db) return ChatResponse(**chat.model_dump()) except Exception as e: log.exception(e) @@ -528,9 +557,13 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): @router.post("/import", response_model=list[ChatResponse]) -async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_user)): +async def import_chats( + form_data: ChatsImportForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): try: - chats = Chats.import_chats(user.id, form_data.chats) + chats = Chats.import_chats(user.id, form_data.chats, db=db) return chats except Exception as e: log.exception(e) @@ -546,7 +579,10 @@ async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_use @router.get("/search", response_model=list[ChatTitleIdResponse]) def search_user_chats( - text: str, page: Optional[int] = None, user=Depends(get_verified_user) + text: str, + page: Optional[int] = None, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if page is None: page = 1 @@ -557,7 +593,7 @@ def search_user_chats( chat_list = [ ChatTitleIdResponse(**chat.model_dump()) for chat in Chats.get_chats_by_user_id_and_search_text( - user.id, text, skip=skip, limit=limit + user.id, text, skip=skip, limit=limit, db=db ) ] @@ -566,9 +602,9 @@ def search_user_chats( if page == 1 and len(words) == 1 and words[0].startswith("tag:"): tag_id = words[0].replace("tag:", "") if len(chat_list) == 0: - if Tags.get_tag_by_name_and_user_id(tag_id, user.id): + if Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db): log.debug(f"deleting tag: {tag_id}") - Tags.delete_tag_by_name_and_user_id(tag_id, user.id) + Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db) return chat_list @@ -579,23 +615,30 @@ def search_user_chats( @router.get("/folder/{folder_id}", response_model=list[ChatResponse]) -async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)): +async def get_chats_by_folder_id( + folder_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): folder_ids = [folder_id] children_folders = Folders.get_children_folders_by_id_and_user_id( - folder_id, user.id + folder_id, user.id, db=db ) if children_folders: folder_ids.extend([folder.id for folder in children_folders]) return [ ChatResponse(**chat.model_dump()) - for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id) + for chat in Chats.get_chats_by_folder_ids_and_user_id( + folder_ids, user.id, db=db + ) ] @router.get("/folder/{folder_id}/list") async def get_chat_list_by_folder_id( - folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user) + folder_id: str, + page: Optional[int] = 1, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): try: limit = 10 @@ -604,7 +647,7 @@ async def get_chat_list_by_folder_id( return [ {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at} for chat in Chats.get_chats_by_folder_id_and_user_id( - folder_id, user.id, skip=skip, limit=limit + folder_id, user.id, skip=skip, limit=limit, db=db ) ] @@ -621,10 +664,12 @@ async def get_chat_list_by_folder_id( @router.get("/pinned", response_model=list[ChatTitleIdResponse]) -async def get_user_pinned_chats(user=Depends(get_verified_user)): +async def get_user_pinned_chats( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): return [ ChatTitleIdResponse(**chat.model_dump()) - for chat in Chats.get_pinned_chats_by_user_id(user.id) + for chat in Chats.get_pinned_chats_by_user_id(user.id, db=db) ] @@ -634,10 +679,12 @@ async def get_user_pinned_chats(user=Depends(get_verified_user)): @router.get("/all", response_model=list[ChatResponse]) -async def get_user_chats(user=Depends(get_verified_user)): +async def get_user_chats( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): return [ ChatResponse(**chat.model_dump()) - for chat in Chats.get_chats_by_user_id(user.id) + for chat in Chats.get_chats_by_user_id(user.id, db=db) ] @@ -647,10 +694,12 @@ async def get_user_chats(user=Depends(get_verified_user)): @router.get("/all/archived", response_model=list[ChatResponse]) -async def get_user_archived_chats(user=Depends(get_verified_user)): +async def get_user_archived_chats( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): return [ ChatResponse(**chat.model_dump()) - for chat in Chats.get_archived_chats_by_user_id(user.id) + for chat in Chats.get_archived_chats_by_user_id(user.id, db=db) ] @@ -660,9 +709,11 @@ async def get_user_archived_chats(user=Depends(get_verified_user)): @router.get("/all/tags", response_model=list[TagModel]) -async def get_all_user_tags(user=Depends(get_verified_user)): +async def get_all_user_tags( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): try: - tags = Tags.get_tags_by_user_id(user.id) + tags = Tags.get_tags_by_user_id(user.id, db=db) return tags except Exception as e: log.exception(e) @@ -677,13 +728,15 @@ async def get_all_user_tags(user=Depends(get_verified_user)): @router.get("/all/db", response_model=list[ChatResponse]) -async def get_all_user_chats_in_db(user=Depends(get_admin_user)): +async def get_all_user_chats_in_db( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): if not ENABLE_ADMIN_EXPORT: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()] + return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats(db=db)] ############################ @@ -698,6 +751,7 @@ async def get_archived_session_user_chat_list( order_by: Optional[str] = None, direction: Optional[str] = None, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if page is None: page = 1 @@ -720,6 +774,7 @@ async def get_archived_session_user_chat_list( filter=filter, skip=skip, limit=limit, + db=db, ) ] @@ -732,8 +787,10 @@ async def get_archived_session_user_chat_list( @router.post("/archive/all", response_model=bool) -async def archive_all_chats(user=Depends(get_verified_user)): - return Chats.archive_all_chats_by_user_id(user.id) +async def archive_all_chats( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + return Chats.archive_all_chats_by_user_id(user.id, db=db) ############################ @@ -742,8 +799,10 @@ async def archive_all_chats(user=Depends(get_verified_user)): @router.post("/unarchive/all", response_model=bool) -async def unarchive_all_chats(user=Depends(get_verified_user)): - return Chats.unarchive_all_chats_by_user_id(user.id) +async def unarchive_all_chats( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + return Chats.unarchive_all_chats_by_user_id(user.id, db=db) ############################ @@ -752,16 +811,18 @@ async def unarchive_all_chats(user=Depends(get_verified_user)): @router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)): +async def get_shared_chat_by_id( + share_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "pending": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND ) if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS): - chat = Chats.get_chat_by_share_id(share_id) + chat = Chats.get_chat_by_share_id(share_id, db=db) elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS: - chat = Chats.get_chat_by_id(share_id) + chat = Chats.get_chat_by_id(share_id, db=db) if chat: return ChatResponse(**chat.model_dump()) @@ -788,13 +849,15 @@ class TagFilterForm(TagForm): @router.post("/tags", response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_tag_name( - form_data: TagFilterForm, user=Depends(get_verified_user) + form_data: TagFilterForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): chats = Chats.get_chat_list_by_user_id_and_tag_name( - user.id, form_data.name, form_data.skip, form_data.limit + user.id, form_data.name, form_data.skip, form_data.limit, db=db ) if len(chats) == 0: - Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) + Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db) return chats @@ -805,8 +868,10 @@ async def get_user_chat_list_by_tag_name( @router.get("/{id}", response_model=Optional[ChatResponse]) -async def get_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def get_chat_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: return ChatResponse(**chat.model_dump()) @@ -824,12 +889,15 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}", response_model=Optional[ChatResponse]) async def update_chat_by_id( - id: str, form_data: ChatForm, user=Depends(get_verified_user) + id: str, + form_data: ChatForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: updated_chat = {**chat.chat, **form_data.chat} - chat = Chats.update_chat_by_id(id, updated_chat) + chat = Chats.update_chat_by_id(id, updated_chat, db=db) return ChatResponse(**chat.model_dump()) else: raise HTTPException( @@ -847,9 +915,13 @@ class MessageForm(BaseModel): @router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse]) async def update_chat_message_by_id( - id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user) + id: str, + message_id: str, + form_data: MessageForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id(id) + chat = Chats.get_chat_by_id(id, db=db) if not chat: raise HTTPException( @@ -869,6 +941,7 @@ async def update_chat_message_by_id( { "content": form_data.content, }, + db=db, ) event_emitter = get_event_emitter( @@ -905,9 +978,13 @@ class EventForm(BaseModel): @router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool]) async def send_chat_message_event_by_id( - id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user) + id: str, + message_id: str, + form_data: EventForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id(id) + chat = Chats.get_chat_by_id(id, db=db) if not chat: raise HTTPException( @@ -945,14 +1022,19 @@ async def send_chat_message_event_by_id( @router.delete("/{id}", response_model=bool) -async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): +async def delete_chat_by_id( + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role == "admin": - chat = Chats.get_chat_by_id(id) + chat = Chats.get_chat_by_id(id, db=db) for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: - Tags.delete_tag_by_name_and_user_id(tag, user.id) + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1: + Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db) - result = Chats.delete_chat_by_id(id) + result = Chats.delete_chat_by_id(id, db=db) return result else: @@ -964,12 +1046,12 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - chat = Chats.get_chat_by_id(id) + chat = Chats.get_chat_by_id(id, db=db) for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: - Tags.delete_tag_by_name_and_user_id(tag, user.id) + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1: + Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db) - result = Chats.delete_chat_by_id_and_user_id(id, user.id) + result = Chats.delete_chat_by_id_and_user_id(id, user.id, db=db) return result @@ -979,8 +1061,10 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified @router.get("/{id}/pinned", response_model=Optional[bool]) -async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def get_pinned_status_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: return chat.pinned else: @@ -995,10 +1079,12 @@ async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/pin", response_model=Optional[ChatResponse]) -async def pin_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def pin_chat_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - chat = Chats.toggle_chat_pinned_by_id(id) + chat = Chats.toggle_chat_pinned_by_id(id, db=db) return chat else: raise HTTPException( @@ -1017,9 +1103,12 @@ class CloneForm(BaseModel): @router.post("/{id}/clone", response_model=Optional[ChatResponse]) async def clone_chat_by_id( - form_data: CloneForm, id: str, user=Depends(get_verified_user) + form_data: CloneForm, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: updated_chat = { **chat.chat, @@ -1040,6 +1129,7 @@ async def clone_chat_by_id( } ) ], + db=db, ) if chats: @@ -1062,12 +1152,14 @@ async def clone_chat_by_id( @router.post("/{id}/clone/shared", response_model=Optional[ChatResponse]) -async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): +async def clone_shared_chat_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin": - chat = Chats.get_chat_by_id(id) + chat = Chats.get_chat_by_id(id, db=db) else: - chat = Chats.get_chat_by_share_id(id) + chat = Chats.get_chat_by_share_id(id, db=db) if chat: updated_chat = { @@ -1089,6 +1181,7 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): } ) ], + db=db, ) if chats: @@ -1111,23 +1204,28 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/archive", response_model=Optional[ChatResponse]) -async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def archive_chat_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - chat = Chats.toggle_chat_archive_by_id(id) + chat = Chats.toggle_chat_archive_by_id(id, db=db) # Delete tags if chat is archived if chat.archived: for tag_id in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0: + if ( + Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id, db=db) + == 0 + ): log.debug(f"deleting tag: {tag_id}") - Tags.delete_tag_by_name_and_user_id(tag_id, user.id) + Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db) else: for tag_id in chat.meta.get("tags", []): - tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id) + tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db) if tag is None: log.debug(f"inserting tag: {tag_id}") - tag = Tags.insert_new_tag(tag_id, user.id) + tag = Tags.insert_new_tag(tag_id, user.id, db=db) return ChatResponse(**chat.model_dump()) else: @@ -1142,7 +1240,12 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/share", response_model=Optional[ChatResponse]) -async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): +async def share_chat_by_id( + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if (user.role != "admin") and ( not has_permission( user.id, "chat.share", request.app.state.config.USER_PERMISSIONS @@ -1153,14 +1256,14 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_ detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: if chat.share_id: - shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) + shared_chat = Chats.update_shared_chat_by_chat_id(chat.id, db=db) return ChatResponse(**shared_chat.model_dump()) - shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id) + shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id, db=db) if not shared_chat: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -1181,14 +1284,16 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_ @router.delete("/{id}/share", response_model=Optional[bool]) -async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def delete_shared_chat_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: if not chat.share_id: return False - result = Chats.delete_shared_chat_by_chat_id(id) - update_result = Chats.update_chat_share_id_by_id(id, None) + result = Chats.delete_shared_chat_by_chat_id(id, db=db) + update_result = Chats.update_chat_share_id_by_id(id, None, db=db) return result and update_result != None else: @@ -1209,12 +1314,15 @@ class ChatFolderIdForm(BaseModel): @router.post("/{id}/folder", response_model=Optional[ChatResponse]) async def update_chat_folder_id_by_id( - id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user) + id: str, + form_data: ChatFolderIdForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: chat = Chats.update_chat_folder_id_by_id_and_user_id( - id, user.id, form_data.folder_id + id, user.id, form_data.folder_id, db=db ) return ChatResponse(**chat.model_dump()) else: @@ -1229,11 +1337,13 @@ async def update_chat_folder_id_by_id( @router.get("/{id}/tags", response_model=list[TagModel]) -async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def get_chat_tags_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids_and_user_id(tags, user.id) + return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -1247,9 +1357,12 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/tags", response_model=list[TagModel]) async def add_tag_by_id_and_tag_name( - id: str, form_data: TagForm, user=Depends(get_verified_user) + id: str, + form_data: TagForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: tags = chat.meta.get("tags", []) tag_id = form_data.name.replace(" ", "_").lower() @@ -1262,12 +1375,12 @@ async def add_tag_by_id_and_tag_name( if tag_id not in tags: Chats.add_chat_tag_by_id_and_user_id_and_tag_name( - id, user.id, form_data.name + id, user.id, form_data.name, db=db ) - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids_and_user_id(tags, user.id) + return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -1281,18 +1394,26 @@ async def add_tag_by_id_and_tag_name( @router.delete("/{id}/tags", response_model=list[TagModel]) async def delete_tag_by_id_and_tag_name( - id: str, form_data: TagForm, user=Depends(get_verified_user) + id: str, + form_data: TagForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name) + Chats.delete_tag_by_id_and_user_id_and_tag_name( + id, user.id, form_data.name, db=db + ) - if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0: - Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) + if ( + Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id, db=db) + == 0 + ): + Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db) - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids_and_user_id(tags, user.id) + return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -1305,14 +1426,16 @@ async def delete_tag_by_id_and_tag_name( @router.delete("/{id}/tags/all", response_model=Optional[bool]) -async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def delete_all_tags_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - Chats.delete_all_tags_by_id_and_user_id(id, user.id) + Chats.delete_all_tags_by_id_and_user_id(id, user.id, db=db) for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: - Tags.delete_tag_by_name_and_user_id(tag, user.id) + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 0: + Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db) return True else: diff --git a/backend/open_webui/routers/evaluations.py b/backend/open_webui/routers/evaluations.py index cdcefe6ba7..a445be5f98 100644 --- a/backend/open_webui/routers/evaluations.py +++ b/backend/open_webui/routers/evaluations.py @@ -15,6 +15,8 @@ from open_webui.models.feedbacks import ( from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session router = APIRouter() @@ -60,38 +62,50 @@ async def update_config( @router.get("/feedbacks/all", response_model=list[FeedbackResponse]) -async def get_all_feedbacks(user=Depends(get_admin_user)): - feedbacks = Feedbacks.get_all_feedbacks() +async def get_all_feedbacks( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + feedbacks = Feedbacks.get_all_feedbacks(db=db) return feedbacks @router.get("/feedbacks/all/ids", response_model=list[FeedbackIdResponse]) -async def get_all_feedback_ids(user=Depends(get_admin_user)): - feedbacks = Feedbacks.get_all_feedbacks() +async def get_all_feedback_ids( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + feedbacks = Feedbacks.get_all_feedbacks(db=db) return feedbacks @router.delete("/feedbacks/all") -async def delete_all_feedbacks(user=Depends(get_admin_user)): - success = Feedbacks.delete_all_feedbacks() +async def delete_all_feedbacks( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + success = Feedbacks.delete_all_feedbacks(db=db) return success @router.get("/feedbacks/all/export", response_model=list[FeedbackModel]) -async def export_all_feedbacks(user=Depends(get_admin_user)): - feedbacks = Feedbacks.get_all_feedbacks() +async def export_all_feedbacks( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + feedbacks = Feedbacks.get_all_feedbacks(db=db) return feedbacks @router.get("/feedbacks/user", response_model=list[FeedbackUserResponse]) -async def get_feedbacks(user=Depends(get_verified_user)): - feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id) +async def get_feedbacks( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id, db=db) return feedbacks @router.delete("/feedbacks", response_model=bool) -async def delete_feedbacks(user=Depends(get_verified_user)): - success = Feedbacks.delete_feedbacks_by_user_id(user.id) +async def delete_feedbacks( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + success = Feedbacks.delete_feedbacks_by_user_id(user.id, db=db) return success @@ -104,6 +118,7 @@ async def get_feedbacks( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_admin_user), + db: Session = Depends(get_session), ): limit = PAGE_ITEM_COUNT @@ -116,7 +131,7 @@ async def get_feedbacks( if direction: filter["direction"] = direction - result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit) + result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit, db=db) return result @@ -125,8 +140,11 @@ async def create_feedback( request: Request, form_data: FeedbackForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data) + feedback = Feedbacks.insert_new_feedback( + user_id=user.id, form_data=form_data, db=db + ) if not feedback: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -137,11 +155,15 @@ async def create_feedback( @router.get("/feedback/{id}", response_model=FeedbackModel) -async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): +async def get_feedback_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin": - feedback = Feedbacks.get_feedback_by_id(id=id) + feedback = Feedbacks.get_feedback_by_id(id=id, db=db) else: - feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id) + feedback = Feedbacks.get_feedback_by_id_and_user_id( + id=id, user_id=user.id, db=db + ) if not feedback: raise HTTPException( @@ -153,13 +175,16 @@ async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): @router.post("/feedback/{id}", response_model=FeedbackModel) async def update_feedback_by_id( - id: str, form_data: FeedbackForm, user=Depends(get_verified_user) + id: str, + form_data: FeedbackForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role == "admin": - feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data) + feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data, db=db) else: feedback = Feedbacks.update_feedback_by_id_and_user_id( - id=id, user_id=user.id, form_data=form_data + id=id, user_id=user.id, form_data=form_data, db=db ) if not feedback: @@ -171,11 +196,15 @@ async def update_feedback_by_id( @router.delete("/feedback/{id}") -async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)): +async def delete_feedback_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin": - success = Feedbacks.delete_feedback_by_id(id=id) + success = Feedbacks.delete_feedback_by_id(id=id, db=db) else: - success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id) + success = Feedbacks.delete_feedback_by_id_and_user_id( + id=id, user_id=user.id, db=db + ) if not success: raise HTTPException( diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 7f01537e66..311f5f7b84 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -22,6 +22,8 @@ from fastapi import ( ) from fastapi.responses import FileResponse, StreamingResponse +from sqlalchemy.orm import Session +from open_webui.internal.db import get_session, SessionLocal from open_webui.constants import ERROR_MESSAGES from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT @@ -62,9 +64,12 @@ router = APIRouter() # TODO: Optimize this function to use the knowledge_file table for faster lookups. def has_access_to_file( - file_id: Optional[str], access_type: str, user=Depends(get_verified_user) + file_id: Optional[str], + access_type: str, + user=Depends(get_verified_user), + db: Optional[Session] = None, ) -> bool: - file = Files.get_file_by_id(file_id) + file = Files.get_file_by_id(file_id, db=db) log.debug(f"Checking if user has {access_type} access to file") if not file: raise HTTPException( @@ -73,31 +78,33 @@ def has_access_to_file( ) # Check if the file is associated with any knowledge bases the user has access to - knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id) - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} + knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id, db=db) + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id, db=db) + } for knowledge_base in knowledge_bases: if knowledge_base.user_id == user.id or has_access( - user.id, access_type, knowledge_base.access_control, user_group_ids + user.id, access_type, knowledge_base.access_control, user_group_ids, db=db ): return True knowledge_base_id = file.meta.get("collection_name") if file.meta else None if knowledge_base_id: knowledge_bases = Knowledges.get_knowledge_bases_by_user_id( - user.id, access_type + user.id, access_type, db=db ) for knowledge_base in knowledge_bases: if knowledge_base.id == knowledge_base_id: return True # Check if the file is associated with any channels the user has access to - channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id) + channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id, db=db) if access_type == "read" and channels: return True # Check if the file is associated with any chats the user has access to # TODO: Granular access control for chats - chats = Chats.get_shared_chats_by_file_id(file_id) + chats = Chats.get_shared_chats_by_file_id(file_id, db=db) if chats: return True @@ -109,47 +116,78 @@ def has_access_to_file( ############################ -def process_uploaded_file(request, file, file_path, file_item, file_metadata, user): - try: - if file.content_type: - stt_supported_content_types = getattr( - request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] - ) +def process_uploaded_file( + request, + file, + file_path, + file_item, + file_metadata, + user, + db: Optional[Session] = None, +): + def _process_handler(db_session): + try: + if file.content_type: + stt_supported_content_types = getattr( + request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] + ) - if strict_match_mime_type(stt_supported_content_types, file.content_type): - file_path = Storage.get_file(file_path) - result = transcribe(request, file_path, file_metadata, user) + if strict_match_mime_type( + stt_supported_content_types, file.content_type + ): + file_path_processed = Storage.get_file(file_path) + result = transcribe( + request, file_path_processed, file_metadata, user + ) + process_file( + request, + ProcessFileForm( + file_id=file_item.id, content=result.get("text", "") + ), + user=user, + db=db_session, + ) + elif (not file.content_type.startswith(("image/", "video/"))) or ( + request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external" + ): + process_file( + request, + ProcessFileForm(file_id=file_item.id), + user=user, + db=db_session, + ) + else: + raise Exception( + f"File type {file.content_type} is not supported for processing" + ) + else: + log.info( + f"File type {file.content_type} is not provided, but trying to process anyway" + ) process_file( request, - ProcessFileForm( - file_id=file_item.id, content=result.get("text", "") - ), + ProcessFileForm(file_id=file_item.id), user=user, + db=db_session, ) - elif (not file.content_type.startswith(("image/", "video/"))) or ( - request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external" - ): - process_file(request, ProcessFileForm(file_id=file_item.id), user=user) - else: - raise Exception( - f"File type {file.content_type} is not supported for processing" - ) - else: - log.info( - f"File type {file.content_type} is not provided, but trying to process anyway" - ) - process_file(request, ProcessFileForm(file_id=file_item.id), user=user) - except Exception as e: - log.error(f"Error processing file: {file_item.id}") - Files.update_file_data_by_id( - file_item.id, - { - "status": "failed", - "error": str(e.detail) if hasattr(e, "detail") else str(e), - }, - ) + except Exception as e: + log.error(f"Error processing file: {file_item.id}") + Files.update_file_data_by_id( + file_item.id, + { + "status": "failed", + "error": str(e.detail) if hasattr(e, "detail") else str(e), + }, + db=db_session, + ) + + if db: + _process_handler(db) + else: + with SessionLocal() as db_session: + _process_handler(db_session) @router.post("/", response_model=FileModelResponse) @@ -161,6 +199,7 @@ def upload_file( process: bool = Query(True), process_in_background: bool = Query(True), user=Depends(get_verified_user), + db: Session = Depends(get_session), ): return upload_file_handler( request, @@ -170,6 +209,7 @@ def upload_file( process_in_background=process_in_background, user=user, background_tasks=background_tasks, + db=db, ) @@ -181,6 +221,7 @@ def upload_file_handler( process_in_background: bool = Query(True), user=Depends(get_verified_user), background_tasks: Optional[BackgroundTasks] = None, + db: Optional[Session] = None, ): log.info(f"file.content_type: {file.content_type} {process}") @@ -248,14 +289,17 @@ def upload_file_handler( }, } ), + db=db, ) if "channel_id" in file_metadata: channel = Channels.get_channel_by_id_and_user_id( - file_metadata["channel_id"], user.id + file_metadata["channel_id"], user.id, db=db ) if channel: - Channels.add_file_to_channel_by_id(channel.id, file_item.id, user.id) + Channels.add_file_to_channel_by_id( + channel.id, file_item.id, user.id, db=db + ) if process: if background_tasks and process_in_background: @@ -277,6 +321,7 @@ def upload_file_handler( file_item, file_metadata, user, + db=db, ) return {"status": True, **file_item.model_dump()} else: @@ -302,11 +347,15 @@ def upload_file_handler( @router.get("/", response_model=list[FileModelResponse]) -async def list_files(user=Depends(get_verified_user), content: bool = Query(True)): +async def list_files( + user=Depends(get_verified_user), + content: bool = Query(True), + db: Session = Depends(get_session), +): if user.role == "admin": - files = Files.get_files() + files = Files.get_files(db=db) else: - files = Files.get_files_by_user_id(user.id) + files = Files.get_files_by_user_id(user.id, db=db) if not content: for file in files: @@ -329,15 +378,16 @@ async def search_files( ), content: bool = Query(True), user=Depends(get_verified_user), + db: Session = Depends(get_session), ): """ Search for files by filename with support for wildcard patterns. """ # Get files according to user role if user.role == "admin": - files = Files.get_files() + files = Files.get_files(db=db) else: - files = Files.get_files_by_user_id(user.id) + files = Files.get_files_by_user_id(user.id, db=db) # Get matching files matching_files = [ @@ -364,8 +414,10 @@ async def search_files( @router.delete("/all") -async def delete_all_files(user=Depends(get_admin_user)): - result = Files.delete_all_files() +async def delete_all_files( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + result = Files.delete_all_files(db=db) if result: try: Storage.delete_all_files() @@ -391,8 +443,10 @@ async def delete_all_files(user=Depends(get_admin_user)): @router.get("/{id}", response_model=Optional[FileModel]) -async def get_file_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def get_file_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -403,7 +457,7 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)): if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): return file else: @@ -415,9 +469,12 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)): @router.get("/{id}/process/status") async def get_file_process_status( - id: str, stream: bool = Query(False), user=Depends(get_verified_user) + id: str, + stream: bool = Query(False), + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - file = Files.get_file_by_id(id) + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -428,7 +485,7 @@ async def get_file_process_status( if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): if stream: MAX_FILE_PROCESSING_DURATION = 3600 * 2 @@ -436,7 +493,7 @@ async def get_file_process_status( async def event_stream(file_item): if file_item: for _ in range(MAX_FILE_PROCESSING_DURATION): - file_item = Files.get_file_by_id(file_item.id) + file_item = Files.get_file_by_id(file_item.id, db=db) if file_item: data = file_item.model_dump().get("data", {}) status = data.get("status") @@ -476,8 +533,10 @@ async def get_file_process_status( @router.get("/{id}/data/content") -async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def get_file_data_content_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -488,7 +547,7 @@ async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)): if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): return {"content": file.data.get("content", "")} else: @@ -509,9 +568,13 @@ class ContentForm(BaseModel): @router.post("/{id}/data/content/update") async def update_file_data_content_by_id( - request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user) + request: Request, + id: str, + form_data: ContentForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - file = Files.get_file_by_id(id) + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -522,7 +585,7 @@ async def update_file_data_content_by_id( if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "write", user) + or has_access_to_file(id, "write", user, db=db) ): try: process_file( @@ -530,7 +593,7 @@ async def update_file_data_content_by_id( ProcessFileForm(file_id=id, content=form_data.content), user=user, ) - file = Files.get_file_by_id(id=id) + file = Files.get_file_by_id(id=id, db=db) except Exception as e: log.exception(e) log.error(f"Error processing file: {file.id}") @@ -550,9 +613,12 @@ async def update_file_data_content_by_id( @router.get("/{id}/content") async def get_file_content_by_id( - id: str, user=Depends(get_verified_user), attachment: bool = Query(False) + id: str, + user=Depends(get_verified_user), + attachment: bool = Query(False), + db: Session = Depends(get_session), ): - file = Files.get_file_by_id(id) + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -563,7 +629,7 @@ async def get_file_content_by_id( if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): try: file_path = Storage.get_file(file.path) @@ -619,8 +685,10 @@ async def get_file_content_by_id( @router.get("/{id}/content/html") -async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def get_html_file_content_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -628,7 +696,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): detail=ERROR_MESSAGES.NOT_FOUND, ) - file_user = Users.get_user_by_id(file.user_id) + file_user = Users.get_user_by_id(file.user_id, db=db) if not file_user.role == "admin": raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -638,7 +706,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): try: file_path = Storage.get_file(file.path) @@ -668,8 +736,10 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): @router.get("/{id}/content/{file_name}") -async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def get_file_content_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -680,7 +750,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): file_path = file.path @@ -730,8 +800,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): @router.delete("/{id}") -async def delete_file_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def delete_file_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -742,10 +814,10 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)): if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "write", user) + or has_access_to_file(id, "write", user, db=db) ): - result = Files.delete_file_by_id(id) + result = Files.delete_file_by_id(id, db=db) if result: try: Storage.delete_file(file.path) diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py index 32911fa509..1c9b2229cf 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -22,6 +22,8 @@ from open_webui.models.knowledge import Knowledges from open_webui.config import UPLOAD_DIR from open_webui.constants import ERROR_MESSAGES +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request @@ -44,7 +46,11 @@ router = APIRouter() @router.get("/", response_model=list[FolderNameIdResponse]) -async def get_folders(request: Request, user=Depends(get_verified_user)): +async def get_folders( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if request.app.state.config.ENABLE_FOLDERS is False: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -55,22 +61,23 @@ async def get_folders(request: Request, user=Depends(get_verified_user)): user.id, "features.folders", request.app.state.config.USER_PERMISSIONS, + db=db, ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - folders = Folders.get_folders_by_user_id(user.id) + folders = Folders.get_folders_by_user_id(user.id, db=db) # Verify folder data integrity folder_list = [] for folder in folders: if folder.parent_id and not Folders.get_folder_by_id_and_user_id( - folder.parent_id, user.id + folder.parent_id, user.id, db=db ): folder = Folders.update_folder_parent_id_by_id_and_user_id( - folder.id, user.id, None + folder.id, user.id, None, db=db ) if folder.data: @@ -80,12 +87,12 @@ async def get_folders(request: Request, user=Depends(get_verified_user)): if file.get("type") == "file": if Files.check_access_by_user_id( - file.get("id"), user.id, "read" + file.get("id"), user.id, "read", db=db ): valid_files.append(file) elif file.get("type") == "collection": if Knowledges.check_access_by_user_id( - file.get("id"), user.id, "read" + file.get("id"), user.id, "read", db=db ): valid_files.append(file) else: @@ -93,7 +100,7 @@ async def get_folders(request: Request, user=Depends(get_verified_user)): folder.data["files"] = valid_files Folders.update_folder_by_id_and_user_id( - folder.id, user.id, FolderUpdateForm(data=folder.data) + folder.id, user.id, FolderUpdateForm(data=folder.data), db=db ) folder_list.append(FolderNameIdResponse(**folder.model_dump())) @@ -107,9 +114,13 @@ async def get_folders(request: Request, user=Depends(get_verified_user)): @router.post("/") -def create_folder(form_data: FolderForm, user=Depends(get_verified_user)): +def create_folder( + form_data: FolderForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): folder = Folders.get_folder_by_parent_id_and_user_id_and_name( - None, user.id, form_data.name + None, user.id, form_data.name, db=db ) if folder: @@ -119,7 +130,7 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)): ) try: - folder = Folders.insert_new_folder(user.id, form_data) + folder = Folders.insert_new_folder(user.id, form_data, db=db) return folder except Exception as e: log.exception(e) @@ -136,8 +147,10 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)): @router.get("/{id}", response_model=Optional[FolderModel]) -async def get_folder_by_id(id: str, user=Depends(get_verified_user)): - folder = Folders.get_folder_by_id_and_user_id(id, user.id) +async def get_folder_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: return folder else: @@ -154,15 +167,18 @@ async def get_folder_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/update") async def update_folder_name_by_id( - id: str, form_data: FolderUpdateForm, user=Depends(get_verified_user) + id: str, + form_data: FolderUpdateForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - folder = Folders.get_folder_by_id_and_user_id(id, user.id) + folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: if form_data.name is not None: # Check if folder with same name exists existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( - folder.parent_id, user.id, form_data.name + folder.parent_id, user.id, form_data.name, db=db ) if existing_folder and existing_folder.id != id: raise HTTPException( @@ -171,7 +187,9 @@ async def update_folder_name_by_id( ) try: - folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data) + folder = Folders.update_folder_by_id_and_user_id( + id, user.id, form_data, db=db + ) return folder except Exception as e: log.exception(e) @@ -198,12 +216,15 @@ class FolderParentIdForm(BaseModel): @router.post("/{id}/update/parent") async def update_folder_parent_id_by_id( - id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user) + id: str, + form_data: FolderParentIdForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - folder = Folders.get_folder_by_id_and_user_id(id, user.id) + folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( - form_data.parent_id, user.id, folder.name + form_data.parent_id, user.id, folder.name, db=db ) if existing_folder: @@ -214,7 +235,7 @@ async def update_folder_parent_id_by_id( try: folder = Folders.update_folder_parent_id_by_id_and_user_id( - id, user.id, form_data.parent_id + id, user.id, form_data.parent_id, db=db ) return folder except Exception as e: @@ -242,13 +263,16 @@ class FolderIsExpandedForm(BaseModel): @router.post("/{id}/update/expanded") async def update_folder_is_expanded_by_id( - id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user) + id: str, + form_data: FolderIsExpandedForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - folder = Folders.get_folder_by_id_and_user_id(id, user.id) + folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: try: folder = Folders.update_folder_is_expanded_by_id_and_user_id( - id, user.id, form_data.is_expanded + id, user.id, form_data.is_expanded, db=db ) return folder except Exception as e: @@ -276,10 +300,11 @@ async def delete_folder_by_id( id: str, delete_contents: Optional[bool] = True, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - if Chats.count_chats_by_folder_id_and_user_id(id, user.id): + if Chats.count_chats_by_folder_id_and_user_id(id, user.id, db=db): chat_delete_permission = has_permission( - user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS + user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS, db=db ) if user.role != "admin" and not chat_delete_permission: raise HTTPException( @@ -288,19 +313,21 @@ async def delete_folder_by_id( ) folders = [] - folders.append(Folders.get_folder_by_id_and_user_id(id, user.id)) + folders.append(Folders.get_folder_by_id_and_user_id(id, user.id, db=db)) while folders: folder = folders.pop() if folder: try: - folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id) + folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id, db=db) for folder_id in folder_ids: if delete_contents: - Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id) + Chats.delete_chats_by_user_id_and_folder_id( + user.id, folder_id, db=db + ) else: Chats.move_chats_by_user_id_and_folder_id( - user.id, folder_id, None + user.id, folder_id, None, db=db ) return True @@ -314,7 +341,7 @@ async def delete_folder_by_id( finally: # Get all subfolders subfolders = Folders.get_folders_by_parent_id_and_user_id( - folder.id, user.id + folder.id, user.id, db=db ) folders.extend(subfolders) diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index 82321cb546..b6dd07d736 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -24,6 +24,8 @@ from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.auth import get_admin_user, get_verified_user from pydantic import BaseModel, HttpUrl +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session log = logging.getLogger(__name__) @@ -37,13 +39,13 @@ router = APIRouter() @router.get("/", response_model=list[FunctionResponse]) -async def get_functions(user=Depends(get_verified_user)): - return Functions.get_functions() +async def get_functions(user=Depends(get_verified_user), db: Session = Depends(get_session)): + return Functions.get_functions(db=db) @router.get("/list", response_model=list[FunctionUserResponse]) -async def get_function_list(user=Depends(get_admin_user)): - return Functions.get_function_list() +async def get_function_list(user=Depends(get_admin_user), db: Session = Depends(get_session)): + return Functions.get_function_list(db=db) ############################ @@ -52,8 +54,8 @@ async def get_function_list(user=Depends(get_admin_user)): @router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel]) -async def get_functions(include_valves: bool = False, user=Depends(get_admin_user)): - return Functions.get_functions(include_valves=include_valves) +async def get_functions(include_valves: bool = False, user=Depends(get_admin_user), db: Session = Depends(get_session)): + return Functions.get_functions(include_valves=include_valves, db=db) ############################ @@ -142,7 +144,7 @@ class SyncFunctionsForm(BaseModel): @router.post("/sync", response_model=list[FunctionWithValvesModel]) async def sync_functions( - request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user) + request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user), db: Session = Depends(get_session) ): try: for function in form_data.functions: @@ -164,7 +166,7 @@ async def sync_functions( ) raise e - return Functions.sync_functions(user.id, form_data.functions) + return Functions.sync_functions(user.id, form_data.functions, db=db) except Exception as e: log.exception(f"Failed to load a function: {e}") raise HTTPException( @@ -180,7 +182,7 @@ async def sync_functions( @router.post("/create", response_model=Optional[FunctionResponse]) async def create_new_function( - request: Request, form_data: FunctionForm, user=Depends(get_admin_user) + request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session) ): if not form_data.id.isidentifier(): raise HTTPException( @@ -190,7 +192,7 @@ async def create_new_function( form_data.id = form_data.id.lower() - function = Functions.get_function_by_id(form_data.id) + function = Functions.get_function_by_id(form_data.id, db=db) if function is None: try: form_data.content = replace_imports(form_data.content) @@ -203,13 +205,13 @@ async def create_new_function( FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS[form_data.id] = function_module - function = Functions.insert_new_function(user.id, function_type, form_data) + function = Functions.insert_new_function(user.id, function_type, form_data, db=db) function_cache_dir = CACHE_DIR / "functions" / form_data.id function_cache_dir.mkdir(parents=True, exist_ok=True) if function_type == "filter" and getattr(function_module, "toggle", None): - Functions.update_function_metadata_by_id(id, {"toggle": True}) + Functions.update_function_metadata_by_id(form_data.id, {"toggle": True}, db=db) if function: return function @@ -237,8 +239,8 @@ async def create_new_function( @router.get("/id/{id}", response_model=Optional[FunctionModel]) -async def get_function_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) +async def get_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): + function = Functions.get_function_by_id(id, db=db) if function: return function @@ -255,11 +257,11 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/toggle", response_model=Optional[FunctionModel]) -async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) +async def toggle_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): + function = Functions.get_function_by_id(id, db=db) if function: function = Functions.update_function_by_id( - id, {"is_active": not function.is_active} + id, {"is_active": not function.is_active}, db=db ) if function: @@ -282,11 +284,11 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel]) -async def toggle_global_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) +async def toggle_global_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): + function = Functions.get_function_by_id(id, db=db) if function: function = Functions.update_function_by_id( - id, {"is_global": not function.is_global} + id, {"is_global": not function.is_global}, db=db ) if function: @@ -310,7 +312,7 @@ async def toggle_global_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/update", response_model=Optional[FunctionModel]) async def update_function_by_id( - request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user) + request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session) ): try: form_data.content = replace_imports(form_data.content) @@ -325,10 +327,10 @@ async def update_function_by_id( updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} log.debug(updated) - function = Functions.update_function_by_id(id, updated) + function = Functions.update_function_by_id(id, updated, db=db) if function_type == "filter" and getattr(function_module, "toggle", None): - Functions.update_function_metadata_by_id(id, {"toggle": True}) + Functions.update_function_metadata_by_id(id, {"toggle": True}, db=db) if function: return function @@ -352,9 +354,9 @@ async def update_function_by_id( @router.delete("/id/{id}/delete", response_model=bool) async def delete_function_by_id( - request: Request, id: str, user=Depends(get_admin_user) + request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) ): - result = Functions.delete_function_by_id(id) + result = Functions.delete_function_by_id(id, db=db) if result: FUNCTIONS = request.app.state.FUNCTIONS @@ -370,11 +372,11 @@ async def delete_function_by_id( @router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) +async def get_function_valves_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): + function = Functions.get_function_by_id(id, db=db) if function: try: - valves = Functions.get_function_valves_by_id(id) + valves = Functions.get_function_valves_by_id(id, db=db) return valves except Exception as e: raise HTTPException( @@ -395,9 +397,9 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): @router.get("/id/{id}/valves/spec", response_model=Optional[dict]) async def get_function_valves_spec_by_id( - request: Request, id: str, user=Depends(get_admin_user) + request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) ): - function = Functions.get_function_by_id(id) + function = Functions.get_function_by_id(id, db=db) if function: function_module, function_type, frontmatter = get_function_module_from_cache( request, id @@ -421,9 +423,9 @@ async def get_function_valves_spec_by_id( @router.post("/id/{id}/valves/update", response_model=Optional[dict]) async def update_function_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_admin_user) + request: Request, id: str, form_data: dict, user=Depends(get_admin_user), db: Session = Depends(get_session) ): - function = Functions.get_function_by_id(id) + function = Functions.get_function_by_id(id, db=db) if function: function_module, function_type, frontmatter = get_function_module_from_cache( request, id @@ -437,7 +439,7 @@ async def update_function_valves_by_id( valves = Valves(**form_data) valves_dict = valves.model_dump(exclude_unset=True) - Functions.update_function_valves_by_id(id, valves_dict) + Functions.update_function_valves_by_id(id, valves_dict, db=db) return valves_dict except Exception as e: log.exception(f"Error updating function values by id {id}: {e}") @@ -464,11 +466,11 @@ async def update_function_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)): - function = Functions.get_function_by_id(id) +async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + function = Functions.get_function_by_id(id, db=db) if function: try: - user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id) + user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id, db=db) return user_valves except Exception as e: raise HTTPException( @@ -484,9 +486,9 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) async def get_function_user_valves_spec_by_id( - request: Request, id: str, user=Depends(get_verified_user) + request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) ): - function = Functions.get_function_by_id(id) + function = Functions.get_function_by_id(id, db=db) if function: function_module, function_type, frontmatter = get_function_module_from_cache( request, id @@ -505,9 +507,9 @@ async def get_function_user_valves_spec_by_id( @router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) async def update_function_user_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_verified_user) + request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session) ): - function = Functions.get_function_by_id(id) + function = Functions.get_function_by_id(id, db=db) if function: function_module, function_type, frontmatter = get_function_module_from_cache( @@ -522,7 +524,7 @@ async def update_function_user_valves_by_id( user_valves = UserValves(**form_data) user_valves_dict = user_valves.model_dump(exclude_unset=True) Functions.update_user_valves_by_id_and_user_id( - id, user.id, user_valves_dict + id, user.id, user_valves_dict, db=db ) return user_valves_dict except Exception as e: diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index 423f6b1c67..f1c567fe81 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -16,6 +16,9 @@ from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session + from open_webui.utils.auth import get_admin_user, get_verified_user @@ -29,7 +32,11 @@ router = APIRouter() @router.get("/", response_model=list[GroupResponse]) -async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)): +async def get_groups( + share: Optional[bool] = None, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): filter = {} if user.role != "admin": @@ -38,7 +45,7 @@ async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_use if share is not None: filter["share"] = share - groups = Groups.get_groups(filter=filter) + groups = Groups.get_groups(filter=filter, db=db) return groups @@ -49,13 +56,17 @@ async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_use @router.post("/create", response_model=Optional[GroupResponse]) -async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)): +async def create_new_group( + form_data: GroupForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): try: - group = Groups.insert_new_group(user.id, form_data) + group = Groups.insert_new_group(user.id, form_data, db=db) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), ) else: raise HTTPException( @@ -76,12 +87,14 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)): @router.get("/id/{id}", response_model=Optional[GroupResponse]) -async def get_group_by_id(id: str, user=Depends(get_admin_user)): - group = Groups.get_group_by_id(id) +async def get_group_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + group = Groups.get_group_by_id(id, db=db) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), ) else: raise HTTPException( @@ -101,13 +114,15 @@ class GroupExportResponse(GroupResponse): @router.get("/id/{id}/export", response_model=Optional[GroupExportResponse]) -async def export_group_by_id(id: str, user=Depends(get_admin_user)): - group = Groups.get_group_by_id(id) +async def export_group_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + group = Groups.get_group_by_id(id, db=db) if group: return GroupExportResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), - user_ids=Groups.get_group_user_ids_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), + user_ids=Groups.get_group_user_ids_by_id(group.id, db=db), ) else: raise HTTPException( @@ -122,9 +137,11 @@ async def export_group_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/users", response_model=list[UserInfoResponse]) -async def get_users_in_group(id: str, user=Depends(get_admin_user)): +async def get_users_in_group( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): try: - users = Users.get_users_by_group_id(id) + users = Users.get_users_by_group_id(id, db=db) return users except Exception as e: log.exception(f"Error adding users to group {id}: {e}") @@ -141,14 +158,17 @@ async def get_users_in_group(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/update", response_model=Optional[GroupResponse]) async def update_group_by_id( - id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user) + id: str, + form_data: GroupUpdateForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): try: - group = Groups.update_group_by_id(id, form_data) + group = Groups.update_group_by_id(id, form_data, db=db) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), ) else: raise HTTPException( @@ -170,17 +190,20 @@ async def update_group_by_id( @router.post("/id/{id}/users/add", response_model=Optional[GroupResponse]) async def add_user_to_group( - id: str, form_data: UserIdsForm, user=Depends(get_admin_user) + id: str, + form_data: UserIdsForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): try: if form_data.user_ids: - form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids) + form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids, db=db) - group = Groups.add_users_to_group(id, form_data.user_ids) + group = Groups.add_users_to_group(id, form_data.user_ids, db=db) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), ) else: raise HTTPException( @@ -197,14 +220,17 @@ async def add_user_to_group( @router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse]) async def remove_users_from_group( - id: str, form_data: UserIdsForm, user=Depends(get_admin_user) + id: str, + form_data: UserIdsForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): try: - group = Groups.remove_users_from_group(id, form_data.user_ids) + group = Groups.remove_users_from_group(id, form_data.user_ids, db=db) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), ) else: raise HTTPException( @@ -225,9 +251,11 @@ async def remove_users_from_group( @router.delete("/id/{id}/delete", response_model=bool) -async def delete_group_by_id(id: str, user=Depends(get_admin_user)): +async def delete_group_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): try: - result = Groups.delete_group_by_id(id) + result = Groups.delete_group_by_id(id, db=db) if result: return result else: diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 467f6f3896..0b322fa862 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -4,6 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request, Query from fastapi.concurrency import run_in_threadpool import logging +from sqlalchemy.orm import Session +from open_webui.internal.db import get_session from open_webui.models.groups import Groups from open_webui.models.knowledge import ( KnowledgeFileListResponse, @@ -52,21 +54,25 @@ class KnowledgeAccessListResponse(BaseModel): @router.get("/", response_model=KnowledgeAccessListResponse) -async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified_user)): +async def get_knowledge_bases( + page: Optional[int] = 1, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): page = max(page, 1) limit = PAGE_ITEM_COUNT skip = (page - 1) * limit filter = {} if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: - groups = Groups.get_groups_by_member_id(user.id) + groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] filter["user_id"] = user.id result = Knowledges.search_knowledge_bases( - user.id, filter=filter, skip=skip, limit=limit + user.id, filter=filter, skip=skip, limit=limit, db=db ) return KnowledgeAccessListResponse( @@ -75,7 +81,9 @@ async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified **knowledge_base.model_dump(), write_access=( user.id == knowledge_base.user_id - or has_access(user.id, "write", knowledge_base.access_control) + or has_access( + user.id, "write", knowledge_base.access_control, db=db + ) ), ) for knowledge_base in result.items @@ -90,6 +98,7 @@ async def search_knowledge_bases( view_option: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): page = max(page, 1) limit = PAGE_ITEM_COUNT @@ -102,14 +111,14 @@ async def search_knowledge_bases( filter["view_option"] = view_option if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: - groups = Groups.get_groups_by_member_id(user.id) + groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] filter["user_id"] = user.id result = Knowledges.search_knowledge_bases( - user.id, filter=filter, skip=skip, limit=limit + user.id, filter=filter, skip=skip, limit=limit, db=db ) return KnowledgeAccessListResponse( @@ -118,7 +127,9 @@ async def search_knowledge_bases( **knowledge_base.model_dump(), write_access=( user.id == knowledge_base.user_id - or has_access(user.id, "write", knowledge_base.access_control) + or has_access( + user.id, "write", knowledge_base.access_control, db=db + ) ), ) for knowledge_base in result.items @@ -132,6 +143,7 @@ async def search_knowledge_files( query: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): page = max(page, 1) limit = PAGE_ITEM_COUNT @@ -141,13 +153,15 @@ async def search_knowledge_files( if query: filter["query"] = query - groups = Groups.get_groups_by_member_id(user.id) + groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] filter["user_id"] = user.id - return Knowledges.search_knowledge_files(filter=filter, skip=skip, limit=limit) + return Knowledges.search_knowledge_files( + filter=filter, skip=skip, limit=limit, db=db + ) ############################ @@ -157,10 +171,13 @@ async def search_knowledge_files( @router.post("/create", response_model=Optional[KnowledgeResponse]) async def create_new_knowledge( - request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user) + request: Request, + form_data: KnowledgeForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -175,11 +192,12 @@ async def create_new_knowledge( user.id, "sharing.public_knowledge", request.app.state.config.USER_PERMISSIONS, + db=db, ) ): form_data.access_control = {} - knowledge = Knowledges.insert_new_knowledge(user.id, form_data) + knowledge = Knowledges.insert_new_knowledge(user.id, form_data, db=db) if knowledge: return knowledge @@ -196,20 +214,24 @@ async def create_new_knowledge( @router.post("/reindex", response_model=bool) -async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)): +async def reindex_knowledge_files( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role != "admin": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - knowledge_bases = Knowledges.get_knowledge_bases() + knowledge_bases = Knowledges.get_knowledge_bases(db=db) log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases") for knowledge_base in knowledge_bases: try: - files = Knowledges.get_files_by_id(knowledge_base.id) + files = Knowledges.get_files_by_id(knowledge_base.id, db=db) try: if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id): VECTOR_DB_CLIENT.delete_collection( @@ -229,6 +251,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us file_id=file.id, collection_name=knowledge_base.id ), user=user, + db=db, ) except Exception as e: log.error( @@ -264,21 +287,23 @@ class KnowledgeFilesResponse(KnowledgeResponse): @router.get("/{id}", response_model=Optional[KnowledgeFilesResponse]) -async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): - knowledge = Knowledges.get_knowledge_by_id(id=id) +async def get_knowledge_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if knowledge: if ( user.role == "admin" or knowledge.user_id == user.id - or has_access(user.id, "read", knowledge.access_control) + or has_access(user.id, "read", knowledge.access_control, db=db) ): return KnowledgeFilesResponse( **knowledge.model_dump(), write_access=( user.id == knowledge.user_id - or has_access(user.id, "write", knowledge.access_control) + or has_access(user.id, "write", knowledge.access_control, db=db) ), ) else: @@ -299,8 +324,9 @@ async def update_knowledge_by_id( id: str, form_data: KnowledgeForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -309,7 +335,7 @@ async def update_knowledge_by_id( # Is the user the original creator, in a group with write access, or an admin if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -325,15 +351,16 @@ async def update_knowledge_by_id( user.id, "sharing.public_knowledge", request.app.state.config.USER_PERMISSIONS, + db=db, ) ): form_data.access_control = {} - knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) + knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data, db=db) if knowledge: return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), ) else: raise HTTPException( @@ -356,9 +383,10 @@ async def get_knowledge_files_by_id( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -368,7 +396,7 @@ async def get_knowledge_files_by_id( if not ( user.role == "admin" or knowledge.user_id == user.id - or has_access(user.id, "read", knowledge.access_control) + or has_access(user.id, "read", knowledge.access_control, db=db) ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -391,7 +419,7 @@ async def get_knowledge_files_by_id( filter["direction"] = direction return Knowledges.search_files_by_id( - id, user.id, filter=filter, skip=skip, limit=limit + id, user.id, filter=filter, skip=skip, limit=limit, db=db ) @@ -410,8 +438,9 @@ def add_file_to_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -420,7 +449,7 @@ def add_file_to_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -428,7 +457,7 @@ def add_file_to_knowledge_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - file = Files.get_file_by_id(form_data.file_id) + file = Files.get_file_by_id(form_data.file_id, db=db) if not file: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -446,11 +475,12 @@ def add_file_to_knowledge_by_id( request, ProcessFileForm(file_id=form_data.file_id, collection_name=id), user=user, + db=db, ) # Add file to knowledge base Knowledges.add_file_to_knowledge_by_id( - knowledge_id=id, file_id=form_data.file_id, user_id=user.id + knowledge_id=id, file_id=form_data.file_id, user_id=user.id, db=db ) except Exception as e: log.debug(e) @@ -462,7 +492,7 @@ def add_file_to_knowledge_by_id( if knowledge: return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), ) else: raise HTTPException( @@ -477,8 +507,9 @@ def update_file_from_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -487,7 +518,7 @@ def update_file_from_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): @@ -496,7 +527,7 @@ def update_file_from_knowledge_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - file = Files.get_file_by_id(form_data.file_id) + file = Files.get_file_by_id(form_data.file_id, db=db) if not file: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -514,6 +545,7 @@ def update_file_from_knowledge_by_id( request, ProcessFileForm(file_id=form_data.file_id, collection_name=id), user=user, + db=db, ) except Exception as e: raise HTTPException( @@ -524,7 +556,7 @@ def update_file_from_knowledge_by_id( if knowledge: return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), ) else: raise HTTPException( @@ -544,8 +576,9 @@ def remove_file_from_knowledge_by_id( form_data: KnowledgeFileIdForm, delete_file: bool = Query(True), user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -554,7 +587,7 @@ def remove_file_from_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -562,7 +595,7 @@ def remove_file_from_knowledge_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - file = Files.get_file_by_id(form_data.file_id) + file = Files.get_file_by_id(form_data.file_id, db=db) if not file: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -570,7 +603,7 @@ def remove_file_from_knowledge_by_id( ) Knowledges.remove_file_from_knowledge_by_id( - knowledge_id=id, file_id=form_data.file_id + knowledge_id=id, file_id=form_data.file_id, db=db ) # Remove content from the vector database @@ -599,12 +632,12 @@ def remove_file_from_knowledge_by_id( pass # Delete file from database - Files.delete_file_by_id(form_data.file_id) + Files.delete_file_by_id(form_data.file_id, db=db) if knowledge: return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), ) else: raise HTTPException( @@ -619,8 +652,10 @@ def remove_file_from_knowledge_by_id( @router.delete("/{id}/delete", response_model=bool) -async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): - knowledge = Knowledges.get_knowledge_by_id(id=id) +async def delete_knowledge_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -629,7 +664,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -640,7 +675,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})") # Get all models - models = Models.get_all_models() + models = Models.get_all_models(db=db) log.info(f"Found {len(models)} models to check for knowledge base {id}") # Update models that reference this knowledge base @@ -664,7 +699,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): access_control=model.access_control, is_active=model.is_active, ) - Models.update_model_by_id(model.id, model_form) + Models.update_model_by_id(model.id, model_form, db=db) # Clean up vector DB try: @@ -672,7 +707,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): except Exception as e: log.debug(e) pass - result = Knowledges.delete_knowledge_by_id(id=id) + result = Knowledges.delete_knowledge_by_id(id=id, db=db) return result @@ -682,8 +717,10 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/reset", response_model=Optional[KnowledgeResponse]) -async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): - knowledge = Knowledges.get_knowledge_by_id(id=id) +async def reset_knowledge_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -692,7 +729,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -706,7 +743,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): log.debug(e) pass - knowledge = Knowledges.reset_knowledge_by_id(id=id) + knowledge = Knowledges.reset_knowledge_by_id(id=id, db=db) return knowledge @@ -721,11 +758,12 @@ async def add_files_to_knowledge_batch( id: str, form_data: list[KnowledgeFileIdForm], user=Depends(get_verified_user), + db: Session = Depends(get_session), ): """ Add multiple files to a knowledge base """ - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -734,7 +772,7 @@ async def add_files_to_knowledge_batch( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -746,7 +784,7 @@ async def add_files_to_knowledge_batch( log.info(f"files/batch/add - {len(form_data)} files") files: List[FileModel] = [] for form in form_data: - file = Files.get_file_by_id(form.file_id) + file = Files.get_file_by_id(form.file_id, db=db) if not file: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -760,6 +798,7 @@ async def add_files_to_knowledge_batch( request=request, form_data=BatchProcessFilesForm(files=files, collection_name=id), user=user, + db=db, ) except Exception as e: log.error( @@ -771,7 +810,7 @@ async def add_files_to_knowledge_batch( successful_file_ids = [r.file_id for r in result.results if r.status == "completed"] for file_id in successful_file_ids: Knowledges.add_file_to_knowledge_by_id( - knowledge_id=id, file_id=file_id, user_id=user.id + knowledge_id=id, file_id=file_id, user_id=user.id, db=db ) # If there were any errors, include them in the response @@ -779,7 +818,7 @@ async def add_files_to_knowledge_batch( error_details = [f"{err.file_id}: {err.error}" for err in result.errors] return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), warnings={ "message": "Some files failed to process", "errors": error_details, @@ -788,5 +827,5 @@ async def add_files_to_knowledge_batch( return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), ) diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py index 9bb1ef518d..3747ab14ed 100644 --- a/backend/open_webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -9,6 +9,10 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT from open_webui.utils.auth import get_verified_user +from open_webui.utils.auth import get_verified_user +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session + log = logging.getLogger(__name__) router = APIRouter() @@ -25,8 +29,8 @@ async def get_embeddings(request: Request): @router.get("/", response_model=list[MemoryModel]) -async def get_memories(user=Depends(get_verified_user)): - return Memories.get_memories_by_user_id(user.id) +async def get_memories(user=Depends(get_verified_user), db: Session = Depends(get_session)): + return Memories.get_memories_by_user_id(user.id, db=db) ############################ @@ -47,8 +51,9 @@ async def add_memory( request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - memory = Memories.insert_new_memory(user.id, form_data.content) + memory = Memories.insert_new_memory(user.id, form_data.content, db=db) vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user) @@ -79,9 +84,9 @@ class QueryMemoryForm(BaseModel): @router.post("/query") async def query_memory( - request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) + request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user), db: Session = Depends(get_session) ): - memories = Memories.get_memories_by_user_id(user.id) + memories = Memories.get_memories_by_user_id(user.id, db=db) if not memories: raise HTTPException(status_code=404, detail="No memories found for user") @@ -101,11 +106,11 @@ async def query_memory( ############################ @router.post("/reset", response_model=bool) async def reset_memory_from_vector_db( - request: Request, user=Depends(get_verified_user) + request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session) ): VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") - memories = Memories.get_memories_by_user_id(user.id) + memories = Memories.get_memories_by_user_id(user.id, db=db) # Generate vectors in parallel vectors = await asyncio.gather( @@ -140,8 +145,8 @@ async def reset_memory_from_vector_db( @router.delete("/delete/user", response_model=bool) -async def delete_memory_by_user_id(user=Depends(get_verified_user)): - result = Memories.delete_memories_by_user_id(user.id) +async def delete_memory_by_user_id(user=Depends(get_verified_user), db: Session = Depends(get_session)): + result = Memories.delete_memories_by_user_id(user.id, db=db) if result: try: @@ -164,9 +169,10 @@ async def update_memory_by_id( request: Request, form_data: MemoryUpdateModel, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): memory = Memories.update_memory_by_id_and_user_id( - memory_id, user.id, form_data.content + memory_id, user.id, form_data.content, db=db ) if memory is None: raise HTTPException(status_code=404, detail="Memory not found") @@ -198,8 +204,8 @@ async def update_memory_by_id( @router.delete("/{memory_id}", response_model=bool) -async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)): - result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id) +async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id, db=db) if result: VECTOR_DB_CLIENT.delete( diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index c5d8b6d102..d65a28cc1f 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -30,6 +30,8 @@ from fastapi.responses import FileResponse, StreamingResponse from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session log = logging.getLogger(__name__) @@ -59,6 +61,7 @@ async def get_models( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): limit = PAGE_ITEM_COUNT @@ -79,13 +82,13 @@ async def get_models( filter["direction"] = direction if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: - groups = Groups.get_groups_by_member_id(user.id) + groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] filter["user_id"] = user.id - return Models.search_models(user.id, filter=filter, skip=skip, limit=limit) + return Models.search_models(user.id, filter=filter, skip=skip, limit=limit, db=db) ########################### @@ -94,8 +97,8 @@ async def get_models( @router.get("/base", response_model=list[ModelResponse]) -async def get_base_models(user=Depends(get_admin_user)): - return Models.get_base_models() +async def get_base_models(user=Depends(get_admin_user), db: Session = Depends(get_session)): + return Models.get_base_models(db=db) ########################### @@ -104,11 +107,11 @@ async def get_base_models(user=Depends(get_admin_user)): @router.get("/tags", response_model=list[str]) -async def get_model_tags(user=Depends(get_verified_user)): +async def get_model_tags(user=Depends(get_verified_user), db: Session = Depends(get_session)): if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - models = Models.get_models() + models = Models.get_models(db=db) else: - models = Models.get_models_by_user_id(user.id) + models = Models.get_models_by_user_id(user.id, db=db) tags_set = set() for model in models: @@ -132,16 +135,17 @@ async def create_new_model( request: Request, form_data: ModelForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - model = Models.get_model_by_id(form_data.id) + model = Models.get_model_by_id(form_data.id, db=db) if model: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -155,7 +159,7 @@ async def create_new_model( ) else: - model = Models.insert_new_model(form_data, user.id) + model = Models.insert_new_model(form_data, user.id, db=db) if model: return model else: @@ -171,9 +175,9 @@ async def create_new_model( @router.get("/export", response_model=list[ModelModel]) -async def export_models(request: Request, user=Depends(get_verified_user)): +async def export_models(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)): if user.role != "admin" and not has_permission( - user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -181,9 +185,9 @@ async def export_models(request: Request, user=Depends(get_verified_user)): ) if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - return Models.get_models() + return Models.get_models(db=db) else: - return Models.get_models_by_user_id(user.id) + return Models.get_models_by_user_id(user.id, db=db) ############################ @@ -200,9 +204,10 @@ async def import_models( request: Request, user=Depends(get_verified_user), form_data: ModelsImportForm = (...), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -216,7 +221,7 @@ async def import_models( model_id = model_data.get("id") if model_id and is_valid_model_id(model_id): - existing_model = Models.get_model_by_id(model_id) + existing_model = Models.get_model_by_id(model_id, db=db) if existing_model: # Update existing model model_data["meta"] = model_data.get("meta", {}) @@ -225,13 +230,13 @@ async def import_models( updated_model = ModelForm( **{**existing_model.model_dump(), **model_data} ) - Models.update_model_by_id(model_id, updated_model) + Models.update_model_by_id(model_id, updated_model, db=db) else: # Insert new model model_data["meta"] = model_data.get("meta", {}) model_data["params"] = model_data.get("params", {}) new_model = ModelForm(**model_data) - Models.insert_new_model(user_id=user.id, form_data=new_model) + Models.insert_new_model(user_id=user.id, form_data=new_model, db=db) return True else: raise HTTPException(status_code=400, detail="Invalid JSON format") @@ -251,9 +256,9 @@ class SyncModelsForm(BaseModel): @router.post("/sync", response_model=list[ModelModel]) async def sync_models( - request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user) + request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user), db: Session = Depends(get_session) ): - return Models.sync_models(user.id, form_data.models) + return Models.sync_models(user.id, form_data.models, db=db) ########################### @@ -267,13 +272,13 @@ class ModelIdForm(BaseModel): # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id @router.get("/model", response_model=Optional[ModelResponse]) -async def get_model_by_id(id: str, user=Depends(get_verified_user)): - model = Models.get_model_by_id(id) +async def get_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + model = Models.get_model_by_id(id, db=db) if model: if ( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or model.user_id == user.id - or has_access(user.id, "read", model.access_control) + or has_access(user.id, "read", model.access_control, db=db) ): return model else: @@ -289,8 +294,8 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): @router.get("/model/profile/image") -async def get_model_profile_image(id: str, user=Depends(get_verified_user)): - model = Models.get_model_by_id(id) +async def get_model_profile_image(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + model = Models.get_model_by_id(id, db=db) # Cache-control headers to prevent stale cached images cache_headers = {"Cache-Control": "no-cache, must-revalidate"} @@ -330,15 +335,15 @@ async def get_model_profile_image(id: str, user=Depends(get_verified_user)): @router.post("/model/toggle", response_model=Optional[ModelResponse]) -async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): - model = Models.get_model_by_id(id) +async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + model = Models.get_model_by_id(id, db=db) if model: if ( user.role == "admin" or model.user_id == user.id - or has_access(user.id, "write", model.access_control) + or has_access(user.id, "write", model.access_control, db=db) ): - model = Models.toggle_model_by_id(id) + model = Models.toggle_model_by_id(id, db=db) if model: return model @@ -368,8 +373,9 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): async def update_model_by_id( form_data: ModelForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - model = Models.get_model_by_id(form_data.id) + model = Models.get_model_by_id(form_data.id, db=db) if not model: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -378,7 +384,7 @@ async def update_model_by_id( if ( model.user_id != user.id - and not has_access(user.id, "write", model.access_control) + and not has_access(user.id, "write", model.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -386,7 +392,7 @@ async def update_model_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump())) + model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()), db=db) return model @@ -396,8 +402,8 @@ async def update_model_by_id( @router.post("/model/delete", response_model=bool) -async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user)): - model = Models.get_model_by_id(form_data.id) +async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user), db: Session = Depends(get_session)): + model = Models.get_model_by_id(form_data.id, db=db) if not model: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -407,18 +413,18 @@ async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_u if ( user.role != "admin" and model.user_id != user.id - and not has_access(user.id, "write", model.access_control) + and not has_access(user.id, "write", model.access_control, db=db) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - result = Models.delete_model_by_id(form_data.id) + result = Models.delete_model_by_id(form_data.id, db=db) return result @router.delete("/delete/all", response_model=bool) -async def delete_all_models(user=Depends(get_admin_user)): - result = Models.delete_all_models() +async def delete_all_models(user=Depends(get_admin_user), db: Session = Depends(get_session)): + result = Models.delete_all_models(db=db) return result diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index 853e7848cc..98138a7535 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -28,6 +28,8 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session log = logging.getLogger(__name__) @@ -49,10 +51,13 @@ class NoteItemResponse(BaseModel): @router.get("/", response_model=list[NoteItemResponse]) async def get_notes( - request: Request, page: Optional[int] = None, user=Depends(get_verified_user) + request: Request, + page: Optional[int] = None, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -69,10 +74,14 @@ async def get_notes( NoteUserResponse( **{ **note.model_dump(), - "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()), + "user": UserResponse( + **Users.get_user_by_id(note.user_id, db=db).model_dump() + ), } ) - for note in Notes.get_notes_by_user_id(user.id, "read", skip=skip, limit=limit) + for note in Notes.get_notes_by_user_id( + user.id, "read", skip=skip, limit=limit, db=db + ) ] return notes @@ -87,9 +96,10 @@ async def search_notes( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -115,13 +125,13 @@ async def search_notes( filter["direction"] = direction if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: - groups = Groups.get_groups_by_member_id(user.id) + groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] filter["user_id"] = user.id - return Notes.search_notes(user.id, filter, skip=skip, limit=limit) + return Notes.search_notes(user.id, filter, skip=skip, limit=limit, db=db) ############################ @@ -131,10 +141,13 @@ async def search_notes( @router.post("/create", response_model=Optional[NoteModel]) async def create_new_note( - request: Request, form_data: NoteForm, user=Depends(get_verified_user) + request: Request, + form_data: NoteForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -142,7 +155,7 @@ async def create_new_note( ) try: - note = Notes.insert_new_note(user.id, form_data) + note = Notes.insert_new_note(user.id, form_data, db=db) return note except Exception as e: log.exception(e) @@ -161,16 +174,21 @@ class NoteResponse(NoteModel): @router.get("/{id}", response_model=Optional[NoteResponse]) -async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_user)): +async def get_note_by_id( + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - note = Notes.get_note_by_id(id) + note = Notes.get_note_by_id(id, db=db) if not note: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -178,7 +196,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us if user.role != "admin" and ( user.id != note.user_id - and (not has_access(user.id, type="read", access_control=note.access_control)) + and ( + not has_access( + user.id, type="read", access_control=note.access_control, db=db + ) + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -188,7 +210,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us user.role == "admin" or (user.id == note.user_id) or has_access( - user.id, type="write", access_control=note.access_control, strict=False + user.id, + type="write", + access_control=note.access_control, + strict=False, + db=db, ) ) @@ -202,17 +228,21 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us @router.post("/{id}/update", response_model=Optional[NoteModel]) async def update_note_by_id( - request: Request, id: str, form_data: NoteForm, user=Depends(get_verified_user) + request: Request, + id: str, + form_data: NoteForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - note = Notes.get_note_by_id(id) + note = Notes.get_note_by_id(id, db=db) if not note: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -220,7 +250,9 @@ async def update_note_by_id( if user.role != "admin" and ( user.id != note.user_id - and not has_access(user.id, type="write", access_control=note.access_control) + and not has_access( + user.id, type="write", access_control=note.access_control, db=db + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -234,12 +266,13 @@ async def update_note_by_id( user.id, "sharing.public_notes", request.app.state.config.USER_PERMISSIONS, + db=db, ) ): form_data.access_control = {} try: - note = Notes.update_note_by_id(id, form_data) + note = Notes.update_note_by_id(id, form_data, db=db) await sio.emit( "note-events", note.model_dump(), @@ -260,16 +293,21 @@ async def update_note_by_id( @router.delete("/{id}/delete", response_model=bool) -async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified_user)): +async def delete_note_by_id( + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - note = Notes.get_note_by_id(id) + note = Notes.get_note_by_id(id, db=db) if not note: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -277,14 +315,16 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified if user.role != "admin" and ( user.id != note.user_id - and not has_access(user.id, type="write", access_control=note.access_control) + and not has_access( + user.id, type="write", access_control=note.access_control, db=db + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) try: - note = Notes.delete_note_by_id(id) + note = Notes.delete_note_by_id(id, db=db) return True except Exception as e: log.exception(e) diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 6a957f2547..4633cee86b 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -11,6 +11,8 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session router = APIRouter() @@ -20,21 +22,21 @@ router = APIRouter() @router.get("/", response_model=list[PromptModel]) -async def get_prompts(user=Depends(get_verified_user)): +async def get_prompts(user=Depends(get_verified_user), db: Session = Depends(get_session)): if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - prompts = Prompts.get_prompts() + prompts = Prompts.get_prompts(db=db) else: - prompts = Prompts.get_prompts_by_user_id(user.id, "read") + prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db) return prompts @router.get("/list", response_model=list[PromptUserResponse]) -async def get_prompt_list(user=Depends(get_verified_user)): +async def get_prompt_list(user=Depends(get_verified_user), db: Session = Depends(get_session)): if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - prompts = Prompts.get_prompts() + prompts = Prompts.get_prompts(db=db) else: - prompts = Prompts.get_prompts_by_user_id(user.id, "write") + prompts = Prompts.get_prompts_by_user_id(user.id, "write", db=db) return prompts @@ -46,16 +48,17 @@ async def get_prompt_list(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[PromptModel]) async def create_new_prompt( - request: Request, form_data: PromptForm, user=Depends(get_verified_user) + request: Request, form_data: PromptForm, user=Depends(get_verified_user), db: Session = Depends(get_session) ): if user.role != "admin" and not ( has_permission( - user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS, db=db ) or has_permission( user.id, "workspace.prompts_import", request.app.state.config.USER_PERMISSIONS, + db=db, ) ): raise HTTPException( @@ -63,9 +66,9 @@ async def create_new_prompt( detail=ERROR_MESSAGES.UNAUTHORIZED, ) - prompt = Prompts.get_prompt_by_command(form_data.command) + prompt = Prompts.get_prompt_by_command(form_data.command, db=db) if prompt is None: - prompt = Prompts.insert_new_prompt(user.id, form_data) + prompt = Prompts.insert_new_prompt(user.id, form_data, db=db) if prompt: return prompt @@ -85,14 +88,14 @@ async def create_new_prompt( @router.get("/command/{command}", response_model=Optional[PromptModel]) -async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): - prompt = Prompts.get_prompt_by_command(f"/{command}") +async def get_prompt_by_command(command: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + prompt = Prompts.get_prompt_by_command(f"/{command}", db=db) if prompt: if ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control) + or has_access(user.id, "read", prompt.access_control, db=db) ): return prompt else: @@ -112,8 +115,9 @@ async def update_prompt_by_command( command: str, form_data: PromptForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - prompt = Prompts.get_prompt_by_command(f"/{command}") + prompt = Prompts.get_prompt_by_command(f"/{command}", db=db) if not prompt: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -123,7 +127,7 @@ async def update_prompt_by_command( # Is the user the original creator, in a group with write access, or an admin if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control) + and not has_access(user.id, "write", prompt.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -131,7 +135,7 @@ async def update_prompt_by_command( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) + prompt = Prompts.update_prompt_by_command(f"/{command}", form_data, db=db) if prompt: return prompt else: @@ -147,8 +151,8 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) -async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)): - prompt = Prompts.get_prompt_by_command(f"/{command}") +async def delete_prompt_by_command(command: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + prompt = Prompts.get_prompt_by_command(f"/{command}", db=db) if not prompt: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -157,7 +161,7 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user) if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control) + and not has_access(user.id, "write", prompt.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -165,5 +169,5 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user) detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - result = Prompts.delete_prompt_by_command(f"/{command}") + result = Prompts.delete_prompt_by_command(f"/{command}", db=db) return result diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index a2c1cc80d5..c24ebf8d56 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -39,6 +39,8 @@ from langchain_core.documents import Document from open_webui.models.files import FileModel, FileUpdateForm, Files from open_webui.models.knowledge import Knowledges from open_webui.storage.provider import Storage +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT @@ -1484,14 +1486,15 @@ def process_file( request: Request, form_data: ProcessFileForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): """ Process a file and save its content to the vector database. """ if user.role == "admin": - file = Files.get_file_by_id(form_data.file_id) + file = Files.get_file_by_id(form_data.file_id, db=db) else: - file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id) + file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id, db=db) if file: try: @@ -1633,12 +1636,13 @@ def process_file( Files.update_file_data_by_id( file.id, {"content": text_content}, + db=db, ) hash = calculate_sha256_string(text_content) - Files.update_file_hash_by_id(file.id, hash) + Files.update_file_hash_by_id(file.id, hash, db=db) if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: - Files.update_file_data_by_id(file.id, {"status": "completed"}) + Files.update_file_data_by_id(file.id, {"status": "completed"}, db=db) return { "status": True, "collection_name": None, @@ -1667,11 +1671,13 @@ def process_file( { "collection_name": collection_name, }, + db=db, ) Files.update_file_data_by_id( file.id, {"status": "completed"}, + db=db, ) return { @@ -1690,6 +1696,7 @@ def process_file( Files.update_file_data_by_id( file.id, {"status": "failed"}, + db=db, ) if "No pandoc was found" in str(e): @@ -2417,10 +2424,10 @@ class DeleteForm(BaseModel): @router.post("/delete") -def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)): +def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user), db: Session = Depends(get_session)): try: if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name): - file = Files.get_file_by_id(form_data.file_id) + file = Files.get_file_by_id(form_data.file_id, db=db) hash = file.hash VECTOR_DB_CLIENT.delete( @@ -2436,9 +2443,9 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin @router.post("/reset/db") -def reset_vector_db(user=Depends(get_admin_user)): +def reset_vector_db(user=Depends(get_admin_user), db: Session = Depends(get_session)): VECTOR_DB_CLIENT.reset() - Knowledges.delete_all_knowledge() + Knowledges.delete_all_knowledge(db=db) @router.post("/reset/uploads") @@ -2496,6 +2503,7 @@ async def process_files_batch( request: Request, form_data: BatchProcessFilesForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ) -> BatchProcessFilesResponse: """ Process a batch of files and save them to the vector database. @@ -2558,7 +2566,7 @@ async def process_files_batch( # Update all files with collection name for file_update, file_result in zip(file_updates, file_results): - Files.update_file_by_id(id=file_result.file_id, form_data=file_update) + Files.update_file_by_id(id=file_result.file_id, form_data=file_update, db=db) file_result.status = "completed" except Exception as e: diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index fdcaf266fa..144f017edf 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -7,6 +7,8 @@ import aiohttp from open_webui.models.groups import Groups from pydantic import BaseModel, HttpUrl from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.orm import Session +from open_webui.internal.db import get_session from open_webui.models.oauth_sessions import OAuthSessions @@ -51,11 +53,11 @@ def get_tool_module(request, tool_id, load_from_db=True): @router.get("/", response_model=list[ToolUserResponse]) -async def get_tools(request: Request, user=Depends(get_verified_user)): +async def get_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)): tools = [] # Local Tools - for tool in Tools.get_tools(): + for tool in Tools.get_tools(db=db): tool_module = get_tool_module(request, tool.id) tools.append( ToolUserResponse( @@ -140,12 +142,12 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): # Admin can see all tools return tools else: - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} tools = [ tool for tool in tools if tool.user_id == user.id - or has_access(user.id, "read", tool.access_control, user_group_ids) + or has_access(user.id, "read", tool.access_control, user_group_ids, db=db) ] return tools @@ -156,11 +158,11 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): @router.get("/list", response_model=list[ToolUserResponse]) -async def get_tool_list(user=Depends(get_verified_user)): +async def get_tool_list(user=Depends(get_verified_user), db: Session = Depends(get_session)): if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - tools = Tools.get_tools() + tools = Tools.get_tools(db=db) else: - tools = Tools.get_tools_by_user_id(user.id, "write") + tools = Tools.get_tools_by_user_id(user.id, "write", db=db) return tools @@ -245,9 +247,9 @@ async def load_tool_from_url( @router.get("/export", response_model=list[ToolModel]) -async def export_tools(request: Request, user=Depends(get_verified_user)): +async def export_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)): if user.role != "admin" and not has_permission( - user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -255,9 +257,9 @@ async def export_tools(request: Request, user=Depends(get_verified_user)): ) if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - return Tools.get_tools() + return Tools.get_tools(db=db) else: - return Tools.get_tools_by_user_id(user.id, "read") + return Tools.get_tools_by_user_id(user.id, "read", db=db) ############################ @@ -270,13 +272,14 @@ async def create_new_tools( request: Request, form_data: ToolForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not ( has_permission( - user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS, db=db ) or has_permission( - user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS, db=db ) ): raise HTTPException( @@ -292,7 +295,7 @@ async def create_new_tools( form_data.id = form_data.id.lower() - tools = Tools.get_tool_by_id(form_data.id) + tools = Tools.get_tool_by_id(form_data.id, db=db) if tools is None: try: form_data.content = replace_imports(form_data.content) @@ -305,7 +308,7 @@ async def create_new_tools( TOOLS[form_data.id] = tool_module specs = get_tool_specs(TOOLS[form_data.id]) - tools = Tools.insert_new_tool(user.id, form_data, specs) + tools = Tools.insert_new_tool(user.id, form_data, specs, db=db) tool_cache_dir = CACHE_DIR / "tools" / form_data.id tool_cache_dir.mkdir(parents=True, exist_ok=True) @@ -336,14 +339,14 @@ async def create_new_tools( @router.get("/id/{id}", response_model=Optional[ToolModel]) -async def get_tools_by_id(id: str, user=Depends(get_verified_user)): - tools = Tools.get_tool_by_id(id) +async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + tools = Tools.get_tool_by_id(id, db=db) if tools: if ( user.role == "admin" or tools.user_id == user.id - or has_access(user.id, "read", tools.access_control) + or has_access(user.id, "read", tools.access_control, db=db) ): return tools else: @@ -364,8 +367,9 @@ async def update_tools_by_id( id: str, form_data: ToolForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -375,7 +379,7 @@ async def update_tools_by_id( # Is the user the original creator, in a group with write access, or an admin if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control) + and not has_access(user.id, "write", tools.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -399,7 +403,7 @@ async def update_tools_by_id( } log.debug(updated) - tools = Tools.update_tool_by_id(id, updated) + tools = Tools.update_tool_by_id(id, updated, db=db) if tools: return tools @@ -423,9 +427,9 @@ async def update_tools_by_id( @router.delete("/id/{id}/delete", response_model=bool) async def delete_tools_by_id( - request: Request, id: str, user=Depends(get_verified_user) + request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -434,7 +438,7 @@ async def delete_tools_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control) + and not has_access(user.id, "write", tools.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -442,7 +446,7 @@ async def delete_tools_by_id( detail=ERROR_MESSAGES.UNAUTHORIZED, ) - result = Tools.delete_tool_by_id(id) + result = Tools.delete_tool_by_id(id, db=db) if result: TOOLS = request.app.state.TOOLS if id in TOOLS: @@ -457,11 +461,11 @@ async def delete_tools_by_id( @router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): - tools = Tools.get_tool_by_id(id) +async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + tools = Tools.get_tool_by_id(id, db=db) if tools: try: - valves = Tools.get_tool_valves_by_id(id) + valves = Tools.get_tool_valves_by_id(id, db=db) return valves except Exception as e: raise HTTPException( @@ -482,9 +486,9 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): @router.get("/id/{id}/valves/spec", response_model=Optional[dict]) async def get_tools_valves_spec_by_id( - request: Request, id: str, user=Depends(get_verified_user) + request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if tools: if id in request.app.state.TOOLS: tools_module = request.app.state.TOOLS[id] @@ -510,9 +514,9 @@ async def get_tools_valves_spec_by_id( @router.post("/id/{id}/valves/update", response_model=Optional[dict]) async def update_tools_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_verified_user) + request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session) ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -521,7 +525,7 @@ async def update_tools_valves_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control) + and not has_access(user.id, "write", tools.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -546,7 +550,7 @@ async def update_tools_valves_by_id( form_data = {k: v for k, v in form_data.items() if v is not None} valves = Valves(**form_data) valves_dict = valves.model_dump(exclude_unset=True) - Tools.update_tool_valves_by_id(id, valves_dict) + Tools.update_tool_valves_by_id(id, valves_dict, db=db) return valves_dict except Exception as e: log.exception(f"Failed to update tool valves by id {id}: {e}") @@ -562,11 +566,11 @@ async def update_tools_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): - tools = Tools.get_tool_by_id(id) +async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): + tools = Tools.get_tool_by_id(id, db=db) if tools: try: - user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id) + user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id, db=db) return user_valves except Exception as e: raise HTTPException( @@ -582,9 +586,9 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) async def get_tools_user_valves_spec_by_id( - request: Request, id: str, user=Depends(get_verified_user) + request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if tools: if id in request.app.state.TOOLS: tools_module = request.app.state.TOOLS[id] @@ -605,9 +609,9 @@ async def get_tools_user_valves_spec_by_id( @router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) async def update_tools_user_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_verified_user) + request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session) ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if tools: if id in request.app.state.TOOLS: @@ -624,7 +628,7 @@ async def update_tools_user_valves_by_id( user_valves = UserValves(**form_data) user_valves_dict = user_valves.model_dump(exclude_unset=True) Tools.update_user_valves_by_id_and_user_id( - id, user.id, user_valves_dict + id, user.id, user_valves_dict, db=db ) return user_valves_dict except Exception as e: diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 12d1415059..a9686f1d15 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -1,5 +1,6 @@ import logging from typing import Optional +from sqlalchemy.orm import Session import base64 import io @@ -29,6 +30,7 @@ from open_webui.models.users import ( from open_webui.constants import ERROR_MESSAGES from open_webui.env import STATIC_DIR +from open_webui.internal.db import get_session from open_webui.utils.auth import ( @@ -60,6 +62,7 @@ async def get_users( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_admin_user), + db: Session = Depends(get_session), ): limit = PAGE_ITEM_COUNT @@ -74,7 +77,9 @@ async def get_users( if direction: filter["direction"] = direction - result = Users.get_users(filter=filter, skip=skip, limit=limit) + filter["direction"] = direction + + result = Users.get_users(filter=filter, skip=skip, limit=limit, db=db) users = result["users"] total = result["total"] @@ -85,7 +90,8 @@ async def get_users( **{ **user.model_dump(), "group_ids": [ - group.id for group in Groups.get_groups_by_member_id(user.id) + group.id + for group in Groups.get_groups_by_member_id(user.id, db=db) ], } ) @@ -98,8 +104,9 @@ async def get_users( @router.get("/all", response_model=UserInfoListResponse) async def get_all_users( user=Depends(get_admin_user), + db: Session = Depends(get_session), ): - return Users.get_users() + return Users.get_users(db=db) @router.get("/search", response_model=UserInfoListResponse) @@ -109,16 +116,13 @@ async def search_users( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): limit = PAGE_ITEM_COUNT page = max(1, page) skip = (page - 1) * limit - filter = {} - if query: - filter["query"] = query - filter = {} if query: filter["query"] = query @@ -127,7 +131,7 @@ async def search_users( if direction: filter["direction"] = direction - return Users.get_users(filter=filter, skip=skip, limit=limit) + return Users.get_users(filter=filter, skip=skip, limit=limit, db=db) ############################ @@ -136,8 +140,10 @@ async def search_users( @router.get("/groups") -async def get_user_groups(user=Depends(get_verified_user)): - return Groups.get_groups_by_member_id(user.id) +async def get_user_groups( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + return Groups.get_groups_by_member_id(user.id, db=db) ############################ @@ -146,9 +152,13 @@ async def get_user_groups(user=Depends(get_verified_user)): @router.get("/permissions") -async def get_user_permissisions(request: Request, user=Depends(get_verified_user)): +async def get_user_permissisions( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS + user.id, request.app.state.config.USER_PERMISSIONS, db=db ) return user_permissions @@ -256,8 +266,10 @@ async def update_default_user_permissions( @router.get("/user/settings", response_model=Optional[UserSettings]) -async def get_user_settings_by_session_user(user=Depends(get_verified_user)): - user = Users.get_user_by_id(user.id) +async def get_user_settings_by_session_user( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + user = Users.get_user_by_id(user.id, db=db) if user: return user.settings else: @@ -274,7 +286,10 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)): @router.post("/user/settings/update", response_model=UserSettings) async def update_user_settings_by_session_user( - request: Request, form_data: UserSettings, user=Depends(get_verified_user) + request: Request, + form_data: UserSettings, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): updated_user_settings = form_data.model_dump() if ( @@ -289,7 +304,7 @@ async def update_user_settings_by_session_user( # If the user is not an admin and does not have permission to use tool servers, remove the key updated_user_settings["ui"].pop("toolServers", None) - user = Users.update_user_settings_by_id(user.id, updated_user_settings) + user = Users.update_user_settings_by_id(user.id, updated_user_settings, db=db) if user: return user.settings else: @@ -305,8 +320,10 @@ async def update_user_settings_by_session_user( @router.get("/user/status") -async def get_user_status_by_session_user(user=Depends(get_verified_user)): - user = Users.get_user_by_id(user.id) +async def get_user_status_by_session_user( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + user = Users.get_user_by_id(user.id, db=db) if user: return user else: @@ -323,11 +340,13 @@ async def get_user_status_by_session_user(user=Depends(get_verified_user)): @router.post("/user/status/update") async def update_user_status_by_session_user( - form_data: UserStatus, user=Depends(get_verified_user) + form_data: UserStatus, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - user = Users.get_user_by_id(user.id) + user = Users.get_user_by_id(user.id, db=db) if user: - user = Users.update_user_status_by_id(user.id, form_data) + user = Users.update_user_status_by_id(user.id, form_data, db=db) return user else: raise HTTPException( @@ -342,8 +361,10 @@ async def update_user_status_by_session_user( @router.get("/user/info", response_model=Optional[dict]) -async def get_user_info_by_session_user(user=Depends(get_verified_user)): - user = Users.get_user_by_id(user.id) +async def get_user_info_by_session_user( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + user = Users.get_user_by_id(user.id, db=db) if user: return user.info else: @@ -360,14 +381,16 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)): @router.post("/user/info/update", response_model=Optional[dict]) async def update_user_info_by_session_user( - form_data: dict, user=Depends(get_verified_user) + form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session) ): - user = Users.get_user_by_id(user.id) + user = Users.get_user_by_id(user.id, db=db) if user: if user.info is None: user.info = {} - user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}}) + user = Users.update_user_by_id( + user.id, {"info": {**user.info, **form_data}}, db=db + ) if user: return user.info else: @@ -397,7 +420,9 @@ class UserActiveResponse(UserStatus): @router.get("/{user_id}", response_model=UserActiveResponse) -async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): +async def get_user_by_id( + user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): # Check if user_id is a shared chat # If it is, get the user_id from the chat if user_id.startswith("shared-"): @@ -411,14 +436,14 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): detail=ERROR_MESSAGES.USER_NOT_FOUND, ) - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if user: - groups = Groups.get_groups_by_member_id(user_id) + groups = Groups.get_groups_by_member_id(user_id, db=db) return UserActiveResponse( **{ **user.model_dump(), "groups": [{"id": group.id, "name": group.name} for group in groups], - "is_active": Users.is_user_active(user_id), + "is_active": Users.is_user_active(user_id, db=db), } ) else: @@ -429,8 +454,10 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): @router.get("/{user_id}/oauth/sessions") -async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user)): - sessions = OAuthSessions.get_sessions_by_user_id(user_id) +async def get_user_oauth_sessions_by_id( + user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + sessions = OAuthSessions.get_sessions_by_user_id(user_id, db=db) if sessions and len(sessions) > 0: return sessions else: @@ -446,8 +473,10 @@ async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_use @router.get("/{user_id}/profile/image") -async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)): - user = Users.get_user_by_id(user_id) +async def get_user_profile_image_by_id( + user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + user = Users.get_user_by_id(user_id, db=db) if user: if user.profile_image_url: # check if it's url or base64 @@ -484,9 +513,11 @@ async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_u @router.get("/{user_id}/active", response_model=dict) -async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)): +async def get_user_active_status_by_id( + user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): return { - "active": Users.is_user_active(user_id), + "active": Users.is_user_active(user_id, db=db), } @@ -500,10 +531,11 @@ async def update_user_by_id( user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user), + db: Session = Depends(get_session), ): # Prevent modification of the primary admin user by other admins try: - first_user = Users.get_first_user() + first_user = Users.get_first_user(db=db) if first_user: if user_id == first_user.id: if session_user.id != user_id: @@ -527,11 +559,11 @@ async def update_user_by_id( detail="Could not verify primary admin status.", ) - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if user: if form_data.email.lower() != user.email: - email_user = Users.get_user_by_email(form_data.email.lower()) + email_user = Users.get_user_by_email(form_data.email.lower(), db=db) if email_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -545,9 +577,9 @@ async def update_user_by_id( raise HTTPException(400, detail=str(e)) hashed = get_password_hash(form_data.password) - Auths.update_user_password_by_id(user_id, hashed) + Auths.update_user_password_by_id(user_id, hashed, db=db) - Auths.update_email_by_id(user_id, form_data.email.lower()) + Auths.update_email_by_id(user_id, form_data.email.lower(), db=db) updated_user = Users.update_user_by_id( user_id, { @@ -556,6 +588,7 @@ async def update_user_by_id( "email": form_data.email.lower(), "profile_image_url": form_data.profile_image_url, }, + db=db, ) if updated_user: @@ -578,10 +611,12 @@ async def update_user_by_id( @router.delete("/{user_id}", response_model=bool) -async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): +async def delete_user_by_id( + user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): # Prevent deletion of the primary admin user try: - first_user = Users.get_first_user() + first_user = Users.get_first_user(db=db) if first_user and user_id == first_user.id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -595,7 +630,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): ) if user.id != user_id: - result = Auths.delete_auth_by_id(user_id) + result = Auths.delete_auth_by_id(user_id, db=db) if result: return True @@ -618,5 +653,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): @router.get("/{user_id}/groups") -async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)): - return Groups.get_groups_by_member_id(user_id) +async def get_user_groups_by_id( + user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + return Groups.get_groups_by_member_id(user_id, db=db) diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index 97d0b41491..7784f6efd7 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -28,6 +28,7 @@ def fill_missing_permissions( def get_permissions( user_id: str, default_permissions: Dict[str, Any], + db: Optional[Any] = None, ) -> Dict[str, Any]: """ Get all permissions for a user by combining the permissions of all groups the user is a member of. @@ -53,7 +54,7 @@ def get_permissions( ) # Use the most permissive value (True > False) return permissions - user_groups = Groups.get_groups_by_member_id(user_id) + user_groups = Groups.get_groups_by_member_id(user_id, db=db) # Deep copy default permissions to avoid modifying the original dict permissions = json.loads(json.dumps(default_permissions)) @@ -72,6 +73,7 @@ def has_permission( user_id: str, permission_key: str, default_permissions: Dict[str, Any] = {}, + db: Optional[Any] = None, ) -> bool: """ Check if a user has a specific permission by checking the group permissions @@ -92,7 +94,7 @@ def has_permission( permission_hierarchy = permission_key.split(".") # Retrieve user group permissions - user_groups = Groups.get_groups_by_member_id(user_id) + user_groups = Groups.get_groups_by_member_id(user_id, db=db) for group in user_groups: if get_permission(group.permissions or {}, permission_hierarchy): @@ -127,6 +129,7 @@ def has_access( access_control: Optional[dict] = None, user_group_ids: Optional[Set[str]] = None, strict: bool = True, + db: Optional[Any] = None, ) -> bool: if access_control is None: if strict: @@ -135,7 +138,7 @@ def has_access( return True if user_group_ids is None: - user_groups = Groups.get_groups_by_member_id(user_id) + user_groups = Groups.get_groups_by_member_id(user_id, db=db) user_group_ids = {group.id for group in user_groups} permitted_ids = get_permitted_group_and_user_ids(type, access_control) @@ -152,10 +155,10 @@ def has_access( # Get all users with access to a resource def get_users_with_access( - type: str = "write", access_control: Optional[dict] = None + type: str = "write", access_control: Optional[dict] = None, db: Optional[Any] = None ) -> list[UserModel]: if access_control is None: - result = Users.get_users(filter={"roles": ["!pending"]}) + result = Users.get_users(filter={"roles": ["!pending"]}, db=db) return result.get("users", []) permitted_ids = get_permitted_group_and_user_ids(type, access_control) @@ -167,8 +170,8 @@ def get_users_with_access( user_ids_with_access = set(permitted_user_ids) - group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids) + group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids, db=db) for user_ids in group_user_ids_map.values(): user_ids_with_access.update(user_ids) - return Users.get_users_by_user_ids(list(user_ids_with_access)) + return Users.get_users_by_user_ids(list(user_ids_with_access), db=db) diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 6cd5fd2d1b..d89beb2ae5 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -42,6 +42,8 @@ from open_webui.env import ( from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy.orm import Session +from open_webui.internal.db import get_session log = logging.getLogger(__name__) @@ -271,6 +273,7 @@ async def get_current_user( response: Response, background_tasks: BackgroundTasks, auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), + db: Session = Depends(get_session), ): token = None @@ -285,7 +288,7 @@ async def get_current_user( # auth by api key if token.startswith("sk-"): - user = get_current_user_by_api_key(request, token) + user = get_current_user_by_api_key(request, token, db=db) # Add user info to current span current_span = trace.get_current_span() @@ -314,7 +317,7 @@ async def get_current_user( detail="Invalid token", ) - user = Users.get_user_by_id(data["id"]) + user = Users.get_user_by_id(data["id"], db=db) if user is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -364,8 +367,8 @@ async def get_current_user( raise e -def get_current_user_by_api_key(request, api_key: str): - user = Users.get_user_by_api_key(api_key) +def get_current_user_by_api_key(request, api_key: str, db: Session = None): + user = Users.get_user_by_api_key(api_key, db=db) if user is None: raise HTTPException( @@ -393,7 +396,7 @@ def get_current_user_by_api_key(request, api_key: str): current_span.set_attribute("client.user.role", user.role) current_span.set_attribute("client.auth.type", "api_key") - Users.update_last_active_by_id(user.id) + Users.update_last_active_by_id(user.id, db=db) return user diff --git a/backend/open_webui/utils/groups.py b/backend/open_webui/utils/groups.py index 0f15f27e2c..26fc5d8434 100644 --- a/backend/open_webui/utils/groups.py +++ b/backend/open_webui/utils/groups.py @@ -7,6 +7,7 @@ log = logging.getLogger(__name__) def apply_default_group_assignment( default_group_id: str, user_id: str, + db=None, ) -> None: """ Apply default group assignment to a user if default_group_id is provided. @@ -17,7 +18,7 @@ def apply_default_group_assignment( """ if default_group_id: try: - Groups.add_users_to_group(default_group_id, [user_id]) + Groups.add_users_to_group(default_group_id, [user_id], db=db) except Exception as e: log.error( f"Failed to add user {user_id} to default group {default_group_id}: {e}" diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 13108f78e5..ffdd99b878 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -1336,7 +1336,7 @@ class OAuthManager: return await client.authorize_redirect(request, redirect_uri, **kwargs) - async def handle_callback(self, request, provider, response): + async def handle_callback(self, request, provider, response, db=None): if provider not in OAUTH_PROVIDERS: raise HTTPException(404) @@ -1461,20 +1461,20 @@ class OAuthManager: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Check if the user exists - user = Users.get_user_by_oauth_sub(provider, sub) + user = Users.get_user_by_oauth_sub(provider, sub, db=db) if not user: # If the user does not exist, check if merging is enabled if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL: # Check if the user exists by email - user = Users.get_user_by_email(email) + user = Users.get_user_by_email(email, db=db) if user: # Update the user with the new oauth sub - Users.update_user_oauth_by_id(user.id, provider, sub) + Users.update_user_oauth_by_id(user.id, provider, sub, db=db) if user: determined_role = self.get_user_role(user, user_data) if user.role != determined_role: - Users.update_user_role_by_id(user.id, determined_role) + Users.update_user_role_by_id(user.id, determined_role, db=db) # Update the user object in memory as well, # to avoid problems with the ENABLE_OAUTH_GROUP_MANAGEMENT check below user.role = determined_role @@ -1491,14 +1491,14 @@ class OAuthManager: ) if processed_picture_url != user.profile_image_url: Users.update_user_profile_image_url_by_id( - user.id, processed_picture_url + user.id, processed_picture_url, db=db ) log.debug(f"Updated profile picture for user {user.email}") else: # If the user does not exist, check if signups are enabled if auth_manager_config.ENABLE_OAUTH_SIGNUP: # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email(email) + existing_user = Users.get_user_by_email(email, db=db) if existing_user: raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) @@ -1529,6 +1529,7 @@ class OAuthManager: profile_image_url=picture_url, role=self.get_user_role(None, user_data), oauth=oauth_data, + db=db, ) if auth_manager_config.WEBHOOK_URL: @@ -1544,8 +1545,7 @@ class OAuthManager: ) apply_default_group_assignment( - request.app.state.config.DEFAULT_GROUP_ID, - user.id, + request.app.state.config.DEFAULT_GROUP_ID, user.id, db=db ) else: @@ -1616,15 +1616,16 @@ class OAuthManager: token["expires_at"] = datetime.now().timestamp() + token["expires_in"] # Clean up any existing sessions for this user/provider first - sessions = OAuthSessions.get_sessions_by_user_id(user.id) + sessions = OAuthSessions.get_sessions_by_user_id(user.id, db=db) for session in sessions: if session.provider == provider: - OAuthSessions.delete_session_by_id(session.id) + OAuthSessions.delete_session_by_id(session.id, db=db) session = OAuthSessions.create_session( user_id=user.id, provider=provider, token=token, + db=db, ) response.set_cookie(