refac/enh: db session sharing

This commit is contained in:
Timothy Jaeryang Baek 2025-12-29 00:21:18 +04:00
parent 6dd0f99b90
commit b1d0f00d8c
23 changed files with 1173 additions and 663 deletions

View file

@ -19,7 +19,7 @@ from open_webui.env import (
from peewee_migrate import Router from peewee_migrate import Router
from sqlalchemy import Dialect, create_engine, MetaData, event, types from sqlalchemy import Dialect, create_engine, MetaData, event, types
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker, Session
from sqlalchemy.pool import QueuePool, NullPool from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.sql.type_api import _T from sqlalchemy.sql.type_api import _T
from typing_extensions import Self from typing_extensions import Self
@ -148,7 +148,7 @@ SessionLocal = sessionmaker(
) )
metadata_obj = MetaData(schema=DATABASE_SCHEMA) metadata_obj = MetaData(schema=DATABASE_SCHEMA)
Base = declarative_base(metadata=metadata_obj) Base = declarative_base(metadata=metadata_obj)
Session = scoped_session(SessionLocal) ScopedSession = scoped_session(SessionLocal)
def get_session(): def get_session():
@ -169,4 +169,3 @@ def get_db_context(db: Optional[Session] = None):
else: else:
with get_db() as session: with get_db() as session:
yield session yield session

View file

@ -102,7 +102,9 @@ from open_webui.routers.retrieval import (
get_rf, get_rf,
) )
from open_webui.internal.db import Session, engine
from sqlalchemy.orm import Session
from open_webui.internal.db import ScopedSession, engine, get_session
from open_webui.models.functions import Functions from open_webui.models.functions import Functions
from open_webui.models.models import Models from open_webui.models.models import Models
@ -1324,7 +1326,7 @@ app.add_middleware(APIKeyRestrictionMiddleware)
async def commit_session_after_request(request: Request, call_next): async def commit_session_after_request(request: Request, call_next):
response = await call_next(request) response = await call_next(request)
# log.debug("Commit session after request") # log.debug("Commit session after request")
Session.commit() ScopedSession.commit()
return response return response
@ -2280,8 +2282,13 @@ async def oauth_login(provider: str, request: Request):
# - Email addresses are considered unique, so we fail registration if the email address is already taken # - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/login/callback") @app.get("/oauth/{provider}/login/callback")
@app.get("/oauth/{provider}/callback") # Legacy endpoint @app.get("/oauth/{provider}/callback") # Legacy endpoint
async def oauth_login_callback(provider: str, request: Request, response: Response): async def oauth_login_callback(
return await oauth_manager.handle_callback(request, provider, response) provider: str,
request: Request,
response: Response,
db: Session = Depends(get_session),
):
return await oauth_manager.handle_callback(request, provider, response, db=db)
@app.get("/manifest.json") @app.get("/manifest.json")
@ -2340,7 +2347,7 @@ async def healthcheck():
@app.get("/health/db") @app.get("/health/db")
async def healthcheck_with_db(): async def healthcheck_with_db():
Session.execute(text("SELECT 1;")).all() ScopedSession.execute(text("SELECT 1;")).all()
return {"status": True} return {"status": True}

View file

@ -30,27 +30,27 @@ from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.dialects import registry from sqlalchemy.dialects import registry
class OpenGaussDialect(PGDialect_psycopg2): class OpenGaussDialect(PGDialect_psycopg2):
name = "opengauss" name = "opengauss"
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
try: try:
version = connection.exec_driver_sql("SELECT version()").scalar() version = connection.exec_driver_sql("SELECT version()").scalar()
if not version: if not version:
return (9, 0, 0) return (9, 0, 0)
match = re.search( match = re.search(
r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", version, re.IGNORECASE
version,
re.IGNORECASE
) )
if match: if match:
return (int(match.group(1)), int(match.group(2)), int(match.group(3))) return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
return super()._get_server_version_info(connection) return super()._get_server_version_info(connection)
except Exception: except Exception:
return (9, 0, 0) return (9, 0, 0)
# Register dialect # Register dialect
registry.register("opengauss", __name__, "OpenGaussDialect") registry.register("opengauss", __name__, "OpenGaussDialect")
@ -78,6 +78,7 @@ Base = declarative_base()
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
class DocumentChunk(Base): class DocumentChunk(Base):
__tablename__ = "document_chunk" __tablename__ = "document_chunk"
@ -87,29 +88,30 @@ class DocumentChunk(Base):
text = Column(Text, nullable=True) text = Column(Text, nullable=True)
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class OpenGaussClient(VectorDBBase): class OpenGaussClient(VectorDBBase):
def __init__(self) -> None: def __init__(self) -> None:
if not OPENGAUSS_DB_URL: if not OPENGAUSS_DB_URL:
from open_webui.internal.db import Session from open_webui.internal.db import ScopedSession
self.session = Session
self.session = ScopedSession
else: else:
engine_kwargs = { engine_kwargs = {"pool_pre_ping": True, "dialect": OpenGaussDialect()}
"pool_pre_ping": True,
"dialect": OpenGaussDialect()
}
if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0: if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0:
engine_kwargs.update({ engine_kwargs.update(
"pool_size": OPENGAUSS_POOL_SIZE, {
"max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW, "pool_size": OPENGAUSS_POOL_SIZE,
"pool_timeout": OPENGAUSS_POOL_TIMEOUT, "max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
"pool_recycle": OPENGAUSS_POOL_RECYCLE, "pool_timeout": OPENGAUSS_POOL_TIMEOUT,
"poolclass": QueuePool "pool_recycle": OPENGAUSS_POOL_RECYCLE,
}) "poolclass": QueuePool,
}
)
else: else:
engine_kwargs["poolclass"] = NullPool engine_kwargs["poolclass"] = NullPool
engine = create_engine(OPENGAUSS_DB_URL,** engine_kwargs) engine = create_engine(OPENGAUSS_DB_URL, **engine_kwargs)
SessionLocal = sessionmaker( SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
@ -160,7 +162,9 @@ class OpenGaussClient(VectorDBBase):
else: else:
raise Exception("The 'vector' column type is not Vector.") raise Exception("The 'vector' column type is not Vector.")
else: else:
raise Exception("The 'vector' column does not exist in the 'document_chunk' table.") raise Exception(
"The 'vector' column does not exist in the 'document_chunk' table."
)
def adjust_vector_length(self, vector: List[float]) -> List[float]: def adjust_vector_length(self, vector: List[float]) -> List[float]:
current_length = len(vector) current_length = len(vector)
@ -185,7 +189,9 @@ class OpenGaussClient(VectorDBBase):
new_items.append(new_chunk) new_items.append(new_chunk)
self.session.bulk_save_objects(new_items) self.session.bulk_save_objects(new_items)
self.session.commit() self.session.commit()
log.info(f"Inserting {len(new_items)} items into collection '{collection_name}'.") log.info(
f"Inserting {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e: except Exception as e:
self.session.rollback() self.session.rollback()
log.exception(f"Failed to insert data: {e}") log.exception(f"Failed to insert data: {e}")
@ -215,7 +221,9 @@ class OpenGaussClient(VectorDBBase):
) )
self.session.add(new_chunk) self.session.add(new_chunk)
self.session.commit() self.session.commit()
log.info(f"Inserting/updating {len(items)} items in collection '{collection_name}'.") log.info(
f"Inserting/updating {len(items)} items in collection '{collection_name}'."
)
except Exception as e: except Exception as e:
self.session.rollback() self.session.rollback()
log.exception(f"Failed to insert or update data.: {e}") log.exception(f"Failed to insert or update data.: {e}")
@ -241,7 +249,9 @@ class OpenGaussClient(VectorDBBase):
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH)) q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
query_vectors = ( query_vectors = (
values(qid_col, q_vector_col) values(qid_col, q_vector_col)
.data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]) .data(
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
)
.alias("query_vectors") .alias("query_vectors")
) )
@ -249,13 +259,17 @@ class OpenGaussClient(VectorDBBase):
DocumentChunk.id, DocumentChunk.id,
DocumentChunk.text, DocumentChunk.text,
DocumentChunk.vmetadata, DocumentChunk.vmetadata,
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label("distance"), (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
"distance"
),
] ]
subq = ( subq = (
select(*result_fields) select(*result_fields)
.where(DocumentChunk.collection_name == collection_name) .where(DocumentChunk.collection_name == collection_name)
.order_by(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) .order_by(
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
)
) )
if limit is not None: if limit is not None:
subq = subq.limit(limit) subq = subq.limit(limit)
@ -368,7 +382,9 @@ class OpenGaussClient(VectorDBBase):
query = query.filter(DocumentChunk.id.in_(ids)) query = query.filter(DocumentChunk.id.in_(ids))
if filter: if filter:
for key, value in filter.items(): for key, value in filter.items():
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
deleted = query.delete(synchronize_session=False) deleted = query.delete(synchronize_session=False)
self.session.commit() self.session.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'") log.info(f"Deleted {deleted} items from collection '{collection_name}'")
@ -395,7 +411,8 @@ class OpenGaussClient(VectorDBBase):
exists = ( exists = (
self.session.query(DocumentChunk) self.session.query(DocumentChunk)
.filter(DocumentChunk.collection_name == collection_name) .filter(DocumentChunk.collection_name == collection_name)
.first() is not None .first()
is not None
) )
self.session.rollback() self.session.rollback()
return exists return exists
@ -406,4 +423,4 @@ class OpenGaussClient(VectorDBBase):
def delete_collection(self, collection_name: str) -> None: def delete_collection(self, collection_name: str) -> None:
self.delete(collection_name) self.delete(collection_name)
log.info(f"Collection '{collection_name}' has been deleted") log.info(f"Collection '{collection_name}' has been deleted")

View file

@ -90,9 +90,9 @@ class PgvectorClient(VectorDBBase):
# if no pgvector uri, use the existing database connection # if no pgvector uri, use the existing database connection
if not PGVECTOR_DB_URL: if not PGVECTOR_DB_URL:
from open_webui.internal.db import Session from open_webui.internal.db import ScopedSession
self.session = Session self.session = ScopedSession
else: else:
if isinstance(PGVECTOR_POOL_SIZE, int): if isinstance(PGVECTOR_POOL_SIZE, int):
if PGVECTOR_POOL_SIZE > 0: if PGVECTOR_POOL_SIZE > 0:

View file

@ -62,6 +62,8 @@ from open_webui.utils.auth import (
get_password_hash, get_password_hash,
get_http_authorization_cred, get_http_authorization_cred,
) )
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
from open_webui.utils.webhook import post_webhook from open_webui.utils.webhook import post_webhook
from open_webui.utils.access_control import get_permissions, has_permission from open_webui.utils.access_control import get_permissions, has_permission
from open_webui.utils.groups import apply_default_group_assignment from open_webui.utils.groups import apply_default_group_assignment
@ -103,7 +105,10 @@ class SessionUserInfoResponse(SessionUserResponse, UserStatus):
@router.get("/", response_model=SessionUserInfoResponse) @router.get("/", response_model=SessionUserInfoResponse)
async def get_session_user( async def get_session_user(
request: Request, response: Response, user=Depends(get_current_user) request: Request,
response: Response,
user=Depends(get_current_user),
db: Session = Depends(get_session),
): ):
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
@ -137,7 +142,7 @@ async def get_session_user(
) )
user_permissions = get_permissions( user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS, db=db
) )
return { return {
@ -166,12 +171,15 @@ async def get_session_user(
@router.post("/update/profile", response_model=UserProfileImageResponse) @router.post("/update/profile", response_model=UserProfileImageResponse)
async def update_profile( async def update_profile(
form_data: UpdateProfileForm, session_user=Depends(get_verified_user) form_data: UpdateProfileForm,
session_user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if session_user: if session_user:
user = Users.update_user_by_id( user = Users.update_user_by_id(
session_user.id, session_user.id,
form_data.model_dump(), form_data.model_dump(),
db=db,
) )
if user: if user:
return user return user
@ -188,13 +196,17 @@ async def update_profile(
@router.post("/update/password", response_model=bool) @router.post("/update/password", response_model=bool)
async def update_password( async def update_password(
form_data: UpdatePasswordForm, session_user=Depends(get_current_user) form_data: UpdatePasswordForm,
session_user=Depends(get_current_user),
db: Session = Depends(get_session),
): ):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
if session_user: if session_user:
user = Auths.authenticate_user( user = Auths.authenticate_user(
session_user.email, lambda pw: verify_password(form_data.password, pw) session_user.email,
lambda pw: verify_password(form_data.password, pw),
db=db,
) )
if user: if user:
@ -203,7 +215,7 @@ async def update_password(
except Exception as e: except Exception as e:
raise HTTPException(400, detail=str(e)) raise HTTPException(400, detail=str(e))
hashed = get_password_hash(form_data.new_password) hashed = get_password_hash(form_data.new_password)
return Auths.update_user_password_by_id(user.id, hashed) return Auths.update_user_password_by_id(user.id, hashed, db=db)
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD) raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD)
else: else:
@ -214,7 +226,12 @@ async def update_password(
# LDAP Authentication # LDAP Authentication
############################ ############################
@router.post("/ldap", response_model=SessionUserResponse) @router.post("/ldap", response_model=SessionUserResponse)
async def ldap_auth(request: Request, response: Response, form_data: LdapForm): async def ldap_auth(
request: Request,
response: Response,
form_data: LdapForm,
db: Session = Depends(get_session),
):
# Security checks FIRST - before loading any config # Security checks FIRST - before loading any config
if not request.app.state.config.ENABLE_LDAP: if not request.app.state.config.ENABLE_LDAP:
raise HTTPException(400, detail="LDAP authentication is not enabled") raise HTTPException(400, detail="LDAP authentication is not enabled")
@ -400,12 +417,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
if not connection_user.bind(): if not connection_user.bind():
raise HTTPException(400, "Authentication failed.") raise HTTPException(400, "Authentication failed.")
user = Users.get_user_by_email(email) user = Users.get_user_by_email(email, db=db)
if not user: if not user:
try: try:
role = ( role = (
"admin" "admin"
if not Users.has_users() if not Users.has_users(db=db)
else request.app.state.config.DEFAULT_USER_ROLE else request.app.state.config.DEFAULT_USER_ROLE
) )
@ -414,6 +431,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
password=str(uuid.uuid4()), password=str(uuid.uuid4()),
name=cn, name=cn,
role=role, role=role,
db=db,
) )
if not user: if not user:
@ -424,6 +442,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
apply_default_group_assignment( apply_default_group_assignment(
request.app.state.config.DEFAULT_GROUP_ID, request.app.state.config.DEFAULT_GROUP_ID,
user.id, user.id,
db=db,
) )
except HTTPException: except HTTPException:
@ -434,7 +453,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
500, detail="Internal error occurred during LDAP user creation." 500, detail="Internal error occurred during LDAP user creation."
) )
user = Auths.authenticate_user_by_email(email) user = Auths.authenticate_user_by_email(email, db=db)
if user: if user:
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
@ -464,7 +483,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
) )
user_permissions = get_permissions( user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS, db=db
) )
if ( if (
@ -473,9 +492,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
and user_groups and user_groups
): ):
if ENABLE_LDAP_GROUP_CREATION: if ENABLE_LDAP_GROUP_CREATION:
Groups.create_groups_by_group_names(user.id, user_groups) Groups.create_groups_by_group_names(user.id, user_groups, db=db)
try: try:
Groups.sync_groups_by_group_names(user.id, user_groups) Groups.sync_groups_by_group_names(user.id, user_groups, db=db)
log.info( log.info(
f"Successfully synced groups for user {user.id}: {user_groups}" f"Successfully synced groups for user {user.id}: {user_groups}"
) )
@ -508,7 +527,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
@router.post("/signin", response_model=SessionUserResponse) @router.post("/signin", response_model=SessionUserResponse)
async def signin(request: Request, response: Response, form_data: SigninForm): async def signin(
request: Request,
response: Response,
form_data: SigninForm,
db: Session = Depends(get_session),
):
if not ENABLE_PASSWORD_AUTH: if not ENABLE_PASSWORD_AUTH:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@ -529,14 +553,15 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
except Exception as e: except Exception as e:
pass pass
if not Users.get_user_by_email(email.lower()): if not Users.get_user_by_email(email.lower(), db=db):
await signup( await signup(
request, request,
response, response,
SignupForm(email=email, password=str(uuid.uuid4()), name=name), SignupForm(email=email, password=str(uuid.uuid4()), name=name),
db=db,
) )
user = Auths.authenticate_user_by_email(email) user = Auths.authenticate_user_by_email(email, db=db)
if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin": if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
group_names = request.headers.get( group_names = request.headers.get(
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, "" WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
@ -544,28 +569,33 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
group_names = [name.strip() for name in group_names if name.strip()] group_names = [name.strip() for name in group_names if name.strip()]
if group_names: if group_names:
Groups.sync_groups_by_group_names(user.id, group_names) Groups.sync_groups_by_group_names(user.id, group_names, db=db)
elif WEBUI_AUTH == False: elif WEBUI_AUTH == False:
admin_email = "admin@localhost" admin_email = "admin@localhost"
admin_password = "admin" admin_password = "admin"
if Users.get_user_by_email(admin_email.lower()): if Users.get_user_by_email(admin_email.lower(), db=db):
user = Auths.authenticate_user( user = Auths.authenticate_user(
admin_email.lower(), lambda pw: verify_password(admin_password, pw) admin_email.lower(),
lambda pw: verify_password(admin_password, pw),
db=db,
) )
else: else:
if Users.has_users(): if Users.has_users(db=db):
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
await signup( await signup(
request, request,
response, response,
SignupForm(email=admin_email, password=admin_password, name="User"), SignupForm(email=admin_email, password=admin_password, name="User"),
db=db,
) )
user = Auths.authenticate_user( user = Auths.authenticate_user(
admin_email.lower(), lambda pw: verify_password(admin_password, pw) admin_email.lower(),
lambda pw: verify_password(admin_password, pw),
db=db,
) )
else: else:
if signin_rate_limiter.is_limited(form_data.email.lower()): if signin_rate_limiter.is_limited(form_data.email.lower()):
@ -584,7 +614,9 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
form_data.password = password_bytes.decode("utf-8", errors="ignore") form_data.password = password_bytes.decode("utf-8", errors="ignore")
user = Auths.authenticate_user( user = Auths.authenticate_user(
form_data.email.lower(), lambda pw: verify_password(form_data.password, pw) form_data.email.lower(),
lambda pw: verify_password(form_data.password, pw),
db=db,
) )
if user: if user:
@ -616,7 +648,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
) )
user_permissions = get_permissions( user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS, db=db
) )
return { return {
@ -640,8 +672,13 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@router.post("/signup", response_model=SessionUserResponse) @router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm): async def signup(
has_users = Users.has_users() request: Request,
response: Response,
form_data: SignupForm,
db: Session = Depends(get_session),
):
has_users = Users.has_users(db=db)
if WEBUI_AUTH: if WEBUI_AUTH:
if ( if (
@ -663,7 +700,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
) )
if Users.get_user_by_email(form_data.email.lower()): if Users.get_user_by_email(form_data.email.lower(), db=db):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
@ -681,6 +718,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
form_data.name, form_data.name,
form_data.profile_image_url, form_data.profile_image_url,
role, role,
db=db,
) )
if user: if user:
@ -723,7 +761,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
) )
user_permissions = get_permissions( user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS, db=db
) )
if not has_users: if not has_users:
@ -733,6 +771,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
apply_default_group_assignment( apply_default_group_assignment(
request.app.state.config.DEFAULT_GROUP_ID, request.app.state.config.DEFAULT_GROUP_ID,
user.id, user.id,
db=db,
) )
return { return {
@ -754,7 +793,9 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.get("/signout") @router.get("/signout")
async def signout(request: Request, response: Response): async def signout(
request: Request, response: Response, db: Session = Depends(get_session)
):
# get auth token from headers or cookies # get auth token from headers or cookies
token = None token = None
@ -776,7 +817,7 @@ async def signout(request: Request, response: Response):
if oauth_session_id: if oauth_session_id:
response.delete_cookie("oauth_session_id") response.delete_cookie("oauth_session_id")
session = OAuthSessions.get_session_by_id(oauth_session_id) session = OAuthSessions.get_session_by_id(oauth_session_id, db=db)
oauth_server_metadata_url = ( oauth_server_metadata_url = (
request.app.state.oauth_manager.get_server_metadata_url(session.provider) request.app.state.oauth_manager.get_server_metadata_url(session.provider)
if session if session
@ -839,14 +880,17 @@ async def signout(request: Request, response: Response):
@router.post("/add", response_model=SigninResponse) @router.post("/add", response_model=SigninResponse)
async def add_user( async def add_user(
request: Request, form_data: AddUserForm, user=Depends(get_admin_user) request: Request,
form_data: AddUserForm,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
): ):
if not validate_email_format(form_data.email.lower()): if not validate_email_format(form_data.email.lower()):
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
) )
if Users.get_user_by_email(form_data.email.lower()): if Users.get_user_by_email(form_data.email.lower(), db=db):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
@ -862,12 +906,14 @@ async def add_user(
form_data.name, form_data.name,
form_data.profile_image_url, form_data.profile_image_url,
form_data.role, form_data.role,
db=db,
) )
if user: if user:
apply_default_group_assignment( apply_default_group_assignment(
request.app.state.config.DEFAULT_GROUP_ID, request.app.state.config.DEFAULT_GROUP_ID,
user.id, user.id,
db=db,
) )
token = create_token(data={"id": user.id}) token = create_token(data={"id": user.id})
@ -895,7 +941,9 @@ async def add_user(
@router.get("/admin/details") @router.get("/admin/details")
async def get_admin_details(request: Request, user=Depends(get_current_user)): async def get_admin_details(
request: Request, user=Depends(get_current_user), db: Session = Depends(get_session)
):
if request.app.state.config.SHOW_ADMIN_DETAILS: if request.app.state.config.SHOW_ADMIN_DETAILS:
admin_email = request.app.state.config.ADMIN_EMAIL admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None admin_name = None
@ -903,11 +951,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}") log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
if admin_email: if admin_email:
admin = Users.get_user_by_email(admin_email) admin = Users.get_user_by_email(admin_email, db=db)
if admin: if admin:
admin_name = admin.name admin_name = admin.name
else: else:
admin = Users.get_first_user() admin = Users.get_first_user(db=db)
if admin: if admin:
admin_email = admin.email admin_email = admin.email
admin_name = admin.name admin_name = admin.name
@ -1149,7 +1197,9 @@ async def update_ldap_config(
# create api key # create api key
@router.post("/api_key", response_model=ApiKey) @router.post("/api_key", response_model=ApiKey)
async def generate_api_key(request: Request, user=Depends(get_current_user)): async def generate_api_key(
request: Request, user=Depends(get_current_user), db: Session = Depends(get_session)
):
if not request.app.state.config.ENABLE_API_KEYS or not has_permission( if not request.app.state.config.ENABLE_API_KEYS or not has_permission(
user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS
): ):
@ -1159,7 +1209,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
) )
api_key = create_api_key() api_key = create_api_key()
success = Users.update_user_api_key_by_id(user.id, api_key) success = Users.update_user_api_key_by_id(user.id, api_key, db=db)
if success: if success:
return { return {
@ -1171,14 +1221,18 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
# delete api key # delete api key
@router.delete("/api_key", response_model=bool) @router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user)): async def delete_api_key(
return Users.delete_user_api_key_by_id(user.id) user=Depends(get_current_user), db: Session = Depends(get_session)
):
return Users.delete_user_api_key_by_id(user.id, db=db)
# get api key # get api key
@router.get("/api_key", response_model=ApiKey) @router.get("/api_key", response_model=ApiKey)
async def get_api_key(user=Depends(get_current_user)): async def get_api_key(
api_key = Users.get_user_api_key_by_id(user.id) user=Depends(get_current_user), db: Session = Depends(get_session)
):
api_key = Users.get_user_api_key_by_id(user.id, db=db)
if api_key: if api_key:
return { return {
"api_key": api_key, "api_key": api_key,

View file

@ -1,6 +1,7 @@
import json import json
import logging import logging
from typing import Optional from typing import Optional
from sqlalchemy.orm import Session
import asyncio import asyncio
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
@ -23,6 +24,7 @@ from open_webui.models.chats import (
) )
from open_webui.models.tags import TagModel, Tags from open_webui.models.tags import TagModel, Tags
from open_webui.models.folders import Folders from open_webui.models.folders import Folders
from open_webui.internal.db import get_session
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
@ -49,6 +51,7 @@ def get_session_user_chat_list(
page: Optional[int] = None, page: Optional[int] = None,
include_pinned: Optional[bool] = False, include_pinned: Optional[bool] = False,
include_folders: Optional[bool] = False, include_folders: Optional[bool] = False,
db: Session = Depends(get_session),
): ):
try: try:
if page is not None: if page is not None:
@ -61,10 +64,14 @@ def get_session_user_chat_list(
include_pinned=include_pinned, include_pinned=include_pinned,
skip=skip, skip=skip,
limit=limit, limit=limit,
db=db,
) )
else: else:
return Chats.get_chat_title_id_list_by_user_id( return Chats.get_chat_title_id_list_by_user_id(
user.id, include_folders=include_folders, include_pinned=include_pinned user.id,
include_folders=include_folders,
include_pinned=include_pinned,
db=db,
) )
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -84,12 +91,13 @@ def get_session_user_chat_usage_stats(
items_per_page: Optional[int] = 50, items_per_page: Optional[int] = 50,
page: Optional[int] = 1, page: Optional[int] = 1,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
try: try:
limit = items_per_page limit = items_per_page
skip = (page - 1) * limit skip = (page - 1) * limit
result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit) result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit, db=db)
chats = result.items chats = result.items
total = result.total total = result.total
@ -216,6 +224,7 @@ class ChatStatsExportList(BaseModel):
def _process_chat_for_export(chat) -> Optional[ChatStatsExport]: def _process_chat_for_export(chat) -> Optional[ChatStatsExport]:
try: try:
def get_message_content_length(message): def get_message_content_length(message):
content = message.get("content", "") content = message.get("content", "")
if isinstance(content, str): if isinstance(content, str):
@ -348,7 +357,9 @@ def _process_chat_for_export(chat) -> Optional[ChatStatsExport]:
return None return None
def calculate_chat_stats(user_id, skip=0, limit=10, filter=None): def calculate_chat_stats(
user_id, skip=0, limit=10, filter=None, db: Optional[Session] = None
):
if filter is None: if filter is None:
filter = {} filter = {}
@ -357,6 +368,7 @@ def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
skip=skip, skip=skip,
limit=limit, limit=limit,
filter=filter, filter=filter,
db=db,
) )
chat_stats_export_list = [] chat_stats_export_list = []
@ -368,14 +380,21 @@ def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
return chat_stats_export_list, result.total return chat_stats_export_list, result.total
async def generate_chat_stats_jsonl_generator(user_id, filter): async def generate_chat_stats_jsonl_generator(
user_id, filter, db: Optional[Session] = None
):
skip = 0 skip = 0
limit = CHAT_EXPORT_PAGE_ITEM_COUNT limit = CHAT_EXPORT_PAGE_ITEM_COUNT
while True: while True:
# Use asyncio.to_thread to make the blocking DB call non-blocking # Use asyncio.to_thread to make the blocking DB call non-blocking
result = await asyncio.to_thread( result = await asyncio.to_thread(
Chats.get_chats_by_user_id, user_id, filter=filter, skip=skip, limit=limit Chats.get_chats_by_user_id,
user_id,
filter=filter,
skip=skip,
limit=limit,
db=db,
) )
if not result.items: if not result.items:
break break
@ -386,7 +405,7 @@ async def generate_chat_stats_jsonl_generator(user_id, filter):
if chat_stat: if chat_stat:
yield chat_stat.model_dump_json() + "\n" yield chat_stat.model_dump_json() + "\n"
except Exception as e: except Exception as e:
log.exception(f"Error processing chat {chat.id}: {e}") log.exception(f"Error processing chat {chat.id}: {e}")
skip += limit skip += limit
@ -400,6 +419,7 @@ async def export_chat_stats(
page: Optional[int] = 1, page: Optional[int] = 1,
stream: bool = False, stream: bool = False,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
# Check if the user has permission to share/export chats # Check if the user has permission to share/export chats
if (user.role != "admin") and ( if (user.role != "admin") and (
@ -415,7 +435,7 @@ async def export_chat_stats(
filter = {"order_by": "created_at", "direction": "asc"} filter = {"order_by": "created_at", "direction": "asc"}
if chat_id: if chat_id:
chat = Chats.get_chat_by_id(chat_id) chat = Chats.get_chat_by_id(chat_id, db=db)
if chat: if chat:
filter["start_time"] = chat.created_at filter["start_time"] = chat.created_at
@ -426,7 +446,7 @@ async def export_chat_stats(
if stream: if stream:
return StreamingResponse( return StreamingResponse(
generate_chat_stats_jsonl_generator(user.id, filter), generate_chat_stats_jsonl_generator(user.id, filter, db=db),
media_type="application/x-ndjson", media_type="application/x-ndjson",
headers={ headers={
"Content-Disposition": f"attachment; filename=chat-stats-export-{user.id}.jsonl" "Content-Disposition": f"attachment; filename=chat-stats-export-{user.id}.jsonl"
@ -437,7 +457,7 @@ async def export_chat_stats(
skip = (page - 1) * limit skip = (page - 1) * limit
chat_stats_export_list, total = await asyncio.to_thread( chat_stats_export_list, total = await asyncio.to_thread(
calculate_chat_stats, user.id, skip, limit, filter calculate_chat_stats, user.id, skip, limit, filter, db=db
) )
return ChatStatsExportList( return ChatStatsExportList(
@ -452,7 +472,11 @@ async def export_chat_stats(
@router.delete("/", response_model=bool) @router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)): async def delete_all_user_chats(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role == "user" and not has_permission( if user.role == "user" and not has_permission(
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
@ -462,7 +486,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
result = Chats.delete_chats_by_user_id(user.id) result = Chats.delete_chats_by_user_id(user.id, db=db)
return result return result
@ -479,6 +503,7 @@ async def get_user_chat_list_by_user_id(
order_by: Optional[str] = None, order_by: Optional[str] = None,
direction: Optional[str] = None, direction: Optional[str] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db: Session = Depends(get_session),
): ):
if not ENABLE_ADMIN_CHAT_ACCESS: if not ENABLE_ADMIN_CHAT_ACCESS:
raise HTTPException( raise HTTPException(
@ -501,7 +526,7 @@ async def get_user_chat_list_by_user_id(
filter["direction"] = direction filter["direction"] = direction
return Chats.get_chat_list_by_user_id( return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, filter=filter, skip=skip, limit=limit user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db
) )
@ -511,9 +536,13 @@ async def get_user_chat_list_by_user_id(
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): async def create_new_chat(
form_data: ChatForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
try: try:
chat = Chats.insert_new_chat(user.id, form_data) chat = Chats.insert_new_chat(user.id, form_data, db=db)
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -528,9 +557,13 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
@router.post("/import", response_model=list[ChatResponse]) @router.post("/import", response_model=list[ChatResponse])
async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_user)): async def import_chats(
form_data: ChatsImportForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
try: try:
chats = Chats.import_chats(user.id, form_data.chats) chats = Chats.import_chats(user.id, form_data.chats, db=db)
return chats return chats
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -546,7 +579,10 @@ async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_use
@router.get("/search", response_model=list[ChatTitleIdResponse]) @router.get("/search", response_model=list[ChatTitleIdResponse])
def search_user_chats( def search_user_chats(
text: str, page: Optional[int] = None, user=Depends(get_verified_user) text: str,
page: Optional[int] = None,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if page is None: if page is None:
page = 1 page = 1
@ -557,7 +593,7 @@ def search_user_chats(
chat_list = [ chat_list = [
ChatTitleIdResponse(**chat.model_dump()) ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id_and_search_text( for chat in Chats.get_chats_by_user_id_and_search_text(
user.id, text, skip=skip, limit=limit user.id, text, skip=skip, limit=limit, db=db
) )
] ]
@ -566,9 +602,9 @@ def search_user_chats(
if page == 1 and len(words) == 1 and words[0].startswith("tag:"): if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
tag_id = words[0].replace("tag:", "") tag_id = words[0].replace("tag:", "")
if len(chat_list) == 0: if len(chat_list) == 0:
if Tags.get_tag_by_name_and_user_id(tag_id, user.id): if Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db):
log.debug(f"deleting tag: {tag_id}") log.debug(f"deleting tag: {tag_id}")
Tags.delete_tag_by_name_and_user_id(tag_id, user.id) Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db)
return chat_list return chat_list
@ -579,23 +615,30 @@ def search_user_chats(
@router.get("/folder/{folder_id}", response_model=list[ChatResponse]) @router.get("/folder/{folder_id}", response_model=list[ChatResponse])
async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)): async def get_chats_by_folder_id(
folder_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
folder_ids = [folder_id] folder_ids = [folder_id]
children_folders = Folders.get_children_folders_by_id_and_user_id( children_folders = Folders.get_children_folders_by_id_and_user_id(
folder_id, user.id folder_id, user.id, db=db
) )
if children_folders: if children_folders:
folder_ids.extend([folder.id for folder in children_folders]) folder_ids.extend([folder.id for folder in children_folders])
return [ return [
ChatResponse(**chat.model_dump()) ChatResponse(**chat.model_dump())
for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id) for chat in Chats.get_chats_by_folder_ids_and_user_id(
folder_ids, user.id, db=db
)
] ]
@router.get("/folder/{folder_id}/list") @router.get("/folder/{folder_id}/list")
async def get_chat_list_by_folder_id( async def get_chat_list_by_folder_id(
folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user) folder_id: str,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
try: try:
limit = 10 limit = 10
@ -604,7 +647,7 @@ async def get_chat_list_by_folder_id(
return [ return [
{"title": chat.title, "id": chat.id, "updated_at": chat.updated_at} {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at}
for chat in Chats.get_chats_by_folder_id_and_user_id( for chat in Chats.get_chats_by_folder_id_and_user_id(
folder_id, user.id, skip=skip, limit=limit folder_id, user.id, skip=skip, limit=limit, db=db
) )
] ]
@ -621,10 +664,12 @@ async def get_chat_list_by_folder_id(
@router.get("/pinned", response_model=list[ChatTitleIdResponse]) @router.get("/pinned", response_model=list[ChatTitleIdResponse])
async def get_user_pinned_chats(user=Depends(get_verified_user)): async def get_user_pinned_chats(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return [ return [
ChatTitleIdResponse(**chat.model_dump()) ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_pinned_chats_by_user_id(user.id) for chat in Chats.get_pinned_chats_by_user_id(user.id, db=db)
] ]
@ -634,10 +679,12 @@ async def get_user_pinned_chats(user=Depends(get_verified_user)):
@router.get("/all", response_model=list[ChatResponse]) @router.get("/all", response_model=list[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)): async def get_user_chats(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return [ return [
ChatResponse(**chat.model_dump()) ChatResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id(user.id) for chat in Chats.get_chats_by_user_id(user.id, db=db)
] ]
@ -647,10 +694,12 @@ async def get_user_chats(user=Depends(get_verified_user)):
@router.get("/all/archived", response_model=list[ChatResponse]) @router.get("/all/archived", response_model=list[ChatResponse])
async def get_user_archived_chats(user=Depends(get_verified_user)): async def get_user_archived_chats(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return [ return [
ChatResponse(**chat.model_dump()) ChatResponse(**chat.model_dump())
for chat in Chats.get_archived_chats_by_user_id(user.id) for chat in Chats.get_archived_chats_by_user_id(user.id, db=db)
] ]
@ -660,9 +709,11 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
@router.get("/all/tags", response_model=list[TagModel]) @router.get("/all/tags", response_model=list[TagModel])
async def get_all_user_tags(user=Depends(get_verified_user)): async def get_all_user_tags(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
try: try:
tags = Tags.get_tags_by_user_id(user.id) tags = Tags.get_tags_by_user_id(user.id, db=db)
return tags return tags
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -677,13 +728,15 @@ async def get_all_user_tags(user=Depends(get_verified_user)):
@router.get("/all/db", response_model=list[ChatResponse]) @router.get("/all/db", response_model=list[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user)): async def get_all_user_chats_in_db(
user=Depends(get_admin_user), db: Session = Depends(get_session)
):
if not ENABLE_ADMIN_EXPORT: if not ENABLE_ADMIN_EXPORT:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()] return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats(db=db)]
############################ ############################
@ -698,6 +751,7 @@ async def get_archived_session_user_chat_list(
order_by: Optional[str] = None, order_by: Optional[str] = None,
direction: Optional[str] = None, direction: Optional[str] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if page is None: if page is None:
page = 1 page = 1
@ -720,6 +774,7 @@ async def get_archived_session_user_chat_list(
filter=filter, filter=filter,
skip=skip, skip=skip,
limit=limit, limit=limit,
db=db,
) )
] ]
@ -732,8 +787,10 @@ async def get_archived_session_user_chat_list(
@router.post("/archive/all", response_model=bool) @router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_verified_user)): async def archive_all_chats(
return Chats.archive_all_chats_by_user_id(user.id) user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return Chats.archive_all_chats_by_user_id(user.id, db=db)
############################ ############################
@ -742,8 +799,10 @@ async def archive_all_chats(user=Depends(get_verified_user)):
@router.post("/unarchive/all", response_model=bool) @router.post("/unarchive/all", response_model=bool)
async def unarchive_all_chats(user=Depends(get_verified_user)): async def unarchive_all_chats(
return Chats.unarchive_all_chats_by_user_id(user.id) user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return Chats.unarchive_all_chats_by_user_id(user.id, db=db)
############################ ############################
@ -752,16 +811,18 @@ async def unarchive_all_chats(user=Depends(get_verified_user)):
@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)): async def get_shared_chat_by_id(
share_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
if user.role == "pending": if user.role == "pending":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
) )
if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS): if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS):
chat = Chats.get_chat_by_share_id(share_id) chat = Chats.get_chat_by_share_id(share_id, db=db)
elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS: elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
chat = Chats.get_chat_by_id(share_id) chat = Chats.get_chat_by_id(share_id, db=db)
if chat: if chat:
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
@ -788,13 +849,15 @@ class TagFilterForm(TagForm):
@router.post("/tags", response_model=list[ChatTitleIdResponse]) @router.post("/tags", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name( async def get_user_chat_list_by_tag_name(
form_data: TagFilterForm, user=Depends(get_verified_user) form_data: TagFilterForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
chats = Chats.get_chat_list_by_user_id_and_tag_name( chats = Chats.get_chat_list_by_user_id_and_tag_name(
user.id, form_data.name, form_data.skip, form_data.limit user.id, form_data.name, form_data.skip, form_data.limit, db=db
) )
if len(chats) == 0: if len(chats) == 0:
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db)
return chats return chats
@ -805,8 +868,10 @@ async def get_user_chat_list_by_tag_name(
@router.get("/{id}", response_model=Optional[ChatResponse]) @router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, user=Depends(get_verified_user)): async def get_chat_by_id(
chat = Chats.get_chat_by_id_and_user_id(id, user.id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
@ -824,12 +889,15 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id( async def update_chat_by_id(
id: str, form_data: ChatForm, user=Depends(get_verified_user) id: str,
form_data: ChatForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
updated_chat = {**chat.chat, **form_data.chat} updated_chat = {**chat.chat, **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat) chat = Chats.update_chat_by_id(id, updated_chat, db=db)
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
else: else:
raise HTTPException( raise HTTPException(
@ -847,9 +915,13 @@ class MessageForm(BaseModel):
@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse]) @router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
async def update_chat_message_by_id( async def update_chat_message_by_id(
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user) id: str,
message_id: str,
form_data: MessageForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
chat = Chats.get_chat_by_id(id) chat = Chats.get_chat_by_id(id, db=db)
if not chat: if not chat:
raise HTTPException( raise HTTPException(
@ -869,6 +941,7 @@ async def update_chat_message_by_id(
{ {
"content": form_data.content, "content": form_data.content,
}, },
db=db,
) )
event_emitter = get_event_emitter( event_emitter = get_event_emitter(
@ -905,9 +978,13 @@ class EventForm(BaseModel):
@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool]) @router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
async def send_chat_message_event_by_id( async def send_chat_message_event_by_id(
id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user) id: str,
message_id: str,
form_data: EventForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
chat = Chats.get_chat_by_id(id) chat = Chats.get_chat_by_id(id, db=db)
if not chat: if not chat:
raise HTTPException( raise HTTPException(
@ -945,14 +1022,19 @@ async def send_chat_message_event_by_id(
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): async def delete_chat_by_id(
request: Request,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role == "admin": if user.role == "admin":
chat = Chats.get_chat_by_id(id) chat = Chats.get_chat_by_id(id, db=db)
for tag in chat.meta.get("tags", []): for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id) Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
result = Chats.delete_chat_by_id(id) result = Chats.delete_chat_by_id(id, db=db)
return result return result
else: else:
@ -964,12 +1046,12 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
chat = Chats.get_chat_by_id(id) chat = Chats.get_chat_by_id(id, db=db)
for tag in chat.meta.get("tags", []): for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id) Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
result = Chats.delete_chat_by_id_and_user_id(id, user.id) result = Chats.delete_chat_by_id_and_user_id(id, user.id, db=db)
return result return result
@ -979,8 +1061,10 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
@router.get("/{id}/pinned", response_model=Optional[bool]) @router.get("/{id}/pinned", response_model=Optional[bool])
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)): async def get_pinned_status_by_id(
chat = Chats.get_chat_by_id_and_user_id(id, user.id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
return chat.pinned return chat.pinned
else: else:
@ -995,10 +1079,12 @@ async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/pin", response_model=Optional[ChatResponse]) @router.post("/{id}/pin", response_model=Optional[ChatResponse])
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)): async def pin_chat_by_id(
chat = Chats.get_chat_by_id_and_user_id(id, user.id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
chat = Chats.toggle_chat_pinned_by_id(id) chat = Chats.toggle_chat_pinned_by_id(id, db=db)
return chat return chat
else: else:
raise HTTPException( raise HTTPException(
@ -1017,9 +1103,12 @@ class CloneForm(BaseModel):
@router.post("/{id}/clone", response_model=Optional[ChatResponse]) @router.post("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id( async def clone_chat_by_id(
form_data: CloneForm, id: str, user=Depends(get_verified_user) form_data: CloneForm,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
updated_chat = { updated_chat = {
**chat.chat, **chat.chat,
@ -1040,6 +1129,7 @@ async def clone_chat_by_id(
} }
) )
], ],
db=db,
) )
if chats: if chats:
@ -1062,12 +1152,14 @@ async def clone_chat_by_id(
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse]) @router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): async def clone_shared_chat_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
if user.role == "admin": if user.role == "admin":
chat = Chats.get_chat_by_id(id) chat = Chats.get_chat_by_id(id, db=db)
else: else:
chat = Chats.get_chat_by_share_id(id) chat = Chats.get_chat_by_share_id(id, db=db)
if chat: if chat:
updated_chat = { updated_chat = {
@ -1089,6 +1181,7 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
} }
) )
], ],
db=db,
) )
if chats: if chats:
@ -1111,23 +1204,28 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/archive", response_model=Optional[ChatResponse]) @router.post("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): async def archive_chat_by_id(
chat = Chats.get_chat_by_id_and_user_id(id, user.id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
chat = Chats.toggle_chat_archive_by_id(id) chat = Chats.toggle_chat_archive_by_id(id, db=db)
# Delete tags if chat is archived # Delete tags if chat is archived
if chat.archived: if chat.archived:
for tag_id in chat.meta.get("tags", []): for tag_id in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0: if (
Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id, db=db)
== 0
):
log.debug(f"deleting tag: {tag_id}") log.debug(f"deleting tag: {tag_id}")
Tags.delete_tag_by_name_and_user_id(tag_id, user.id) Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db)
else: else:
for tag_id in chat.meta.get("tags", []): for tag_id in chat.meta.get("tags", []):
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id) tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db)
if tag is None: if tag is None:
log.debug(f"inserting tag: {tag_id}") log.debug(f"inserting tag: {tag_id}")
tag = Tags.insert_new_tag(tag_id, user.id) tag = Tags.insert_new_tag(tag_id, user.id, db=db)
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
else: else:
@ -1142,7 +1240,12 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse]) @router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): async def share_chat_by_id(
request: Request,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if (user.role != "admin") and ( if (user.role != "admin") and (
not has_permission( not has_permission(
user.id, "chat.share", request.app.state.config.USER_PERMISSIONS user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
@ -1153,14 +1256,14 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
if chat.share_id: if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) shared_chat = Chats.update_shared_chat_by_chat_id(chat.id, db=db)
return ChatResponse(**shared_chat.model_dump()) return ChatResponse(**shared_chat.model_dump())
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id) shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id, db=db)
if not shared_chat: if not shared_chat:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -1181,14 +1284,16 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_
@router.delete("/{id}/share", response_model=Optional[bool]) @router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)): async def delete_shared_chat_by_id(
chat = Chats.get_chat_by_id_and_user_id(id, user.id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
if not chat.share_id: if not chat.share_id:
return False return False
result = Chats.delete_shared_chat_by_chat_id(id) result = Chats.delete_shared_chat_by_chat_id(id, db=db)
update_result = Chats.update_chat_share_id_by_id(id, None) update_result = Chats.update_chat_share_id_by_id(id, None, db=db)
return result and update_result != None return result and update_result != None
else: else:
@ -1209,12 +1314,15 @@ class ChatFolderIdForm(BaseModel):
@router.post("/{id}/folder", response_model=Optional[ChatResponse]) @router.post("/{id}/folder", response_model=Optional[ChatResponse])
async def update_chat_folder_id_by_id( async def update_chat_folder_id_by_id(
id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user) id: str,
form_data: ChatFolderIdForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
chat = Chats.update_chat_folder_id_by_id_and_user_id( chat = Chats.update_chat_folder_id_by_id_and_user_id(
id, user.id, form_data.folder_id id, user.id, form_data.folder_id, db=db
) )
return ChatResponse(**chat.model_dump()) return ChatResponse(**chat.model_dump())
else: else:
@ -1229,11 +1337,13 @@ async def update_chat_folder_id_by_id(
@router.get("/{id}/tags", response_model=list[TagModel]) @router.get("/{id}/tags", response_model=list[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): async def get_chat_tags_by_id(
chat = Chats.get_chat_by_id_and_user_id(id, user.id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id) return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -1247,9 +1357,12 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/tags", response_model=list[TagModel]) @router.post("/{id}/tags", response_model=list[TagModel])
async def add_tag_by_id_and_tag_name( async def add_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user) id: str,
form_data: TagForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
tag_id = form_data.name.replace(" ", "_").lower() tag_id = form_data.name.replace(" ", "_").lower()
@ -1262,12 +1375,12 @@ async def add_tag_by_id_and_tag_name(
if tag_id not in tags: if tag_id not in tags:
Chats.add_chat_tag_by_id_and_user_id_and_tag_name( Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
id, user.id, form_data.name id, user.id, form_data.name, db=db
) )
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id) return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@ -1281,18 +1394,26 @@ async def add_tag_by_id_and_tag_name(
@router.delete("/{id}/tags", response_model=list[TagModel]) @router.delete("/{id}/tags", response_model=list[TagModel])
async def delete_tag_by_id_and_tag_name( async def delete_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user) id: str,
form_data: TagForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name) Chats.delete_tag_by_id_and_user_id_and_tag_name(
id, user.id, form_data.name, db=db
)
if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0: if (
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id, db=db)
== 0
):
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db)
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
tags = chat.meta.get("tags", []) tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id) return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -1305,14 +1426,16 @@ async def delete_tag_by_id_and_tag_name(
@router.delete("/{id}/tags/all", response_model=Optional[bool]) @router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)): async def delete_all_tags_by_id(
chat = Chats.get_chat_by_id_and_user_id(id, user.id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat: if chat:
Chats.delete_all_tags_by_id_and_user_id(id, user.id) Chats.delete_all_tags_by_id_and_user_id(id, user.id, db=db)
for tag in chat.meta.get("tags", []): for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id) Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
return True return True
else: else:

View file

@ -15,6 +15,8 @@ from open_webui.models.feedbacks import (
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
router = APIRouter() router = APIRouter()
@ -60,38 +62,50 @@ async def update_config(
@router.get("/feedbacks/all", response_model=list[FeedbackResponse]) @router.get("/feedbacks/all", response_model=list[FeedbackResponse])
async def get_all_feedbacks(user=Depends(get_admin_user)): async def get_all_feedbacks(
feedbacks = Feedbacks.get_all_feedbacks() user=Depends(get_admin_user), db: Session = Depends(get_session)
):
feedbacks = Feedbacks.get_all_feedbacks(db=db)
return feedbacks return feedbacks
@router.get("/feedbacks/all/ids", response_model=list[FeedbackIdResponse]) @router.get("/feedbacks/all/ids", response_model=list[FeedbackIdResponse])
async def get_all_feedback_ids(user=Depends(get_admin_user)): async def get_all_feedback_ids(
feedbacks = Feedbacks.get_all_feedbacks() user=Depends(get_admin_user), db: Session = Depends(get_session)
):
feedbacks = Feedbacks.get_all_feedbacks(db=db)
return feedbacks return feedbacks
@router.delete("/feedbacks/all") @router.delete("/feedbacks/all")
async def delete_all_feedbacks(user=Depends(get_admin_user)): async def delete_all_feedbacks(
success = Feedbacks.delete_all_feedbacks() user=Depends(get_admin_user), db: Session = Depends(get_session)
):
success = Feedbacks.delete_all_feedbacks(db=db)
return success return success
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel]) @router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
async def export_all_feedbacks(user=Depends(get_admin_user)): async def export_all_feedbacks(
feedbacks = Feedbacks.get_all_feedbacks() user=Depends(get_admin_user), db: Session = Depends(get_session)
):
feedbacks = Feedbacks.get_all_feedbacks(db=db)
return feedbacks return feedbacks
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse]) @router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
async def get_feedbacks(user=Depends(get_verified_user)): async def get_feedbacks(
feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id) user=Depends(get_verified_user), db: Session = Depends(get_session)
):
feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id, db=db)
return feedbacks return feedbacks
@router.delete("/feedbacks", response_model=bool) @router.delete("/feedbacks", response_model=bool)
async def delete_feedbacks(user=Depends(get_verified_user)): async def delete_feedbacks(
success = Feedbacks.delete_feedbacks_by_user_id(user.id) user=Depends(get_verified_user), db: Session = Depends(get_session)
):
success = Feedbacks.delete_feedbacks_by_user_id(user.id, db=db)
return success return success
@ -104,6 +118,7 @@ async def get_feedbacks(
direction: Optional[str] = None, direction: Optional[str] = None,
page: Optional[int] = 1, page: Optional[int] = 1,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db: Session = Depends(get_session),
): ):
limit = PAGE_ITEM_COUNT limit = PAGE_ITEM_COUNT
@ -116,7 +131,7 @@ async def get_feedbacks(
if direction: if direction:
filter["direction"] = direction filter["direction"] = direction
result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit) result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit, db=db)
return result return result
@ -125,8 +140,11 @@ async def create_feedback(
request: Request, request: Request,
form_data: FeedbackForm, form_data: FeedbackForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data) feedback = Feedbacks.insert_new_feedback(
user_id=user.id, form_data=form_data, db=db
)
if not feedback: if not feedback:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -137,11 +155,15 @@ async def create_feedback(
@router.get("/feedback/{id}", response_model=FeedbackModel) @router.get("/feedback/{id}", response_model=FeedbackModel)
async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): async def get_feedback_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
if user.role == "admin": if user.role == "admin":
feedback = Feedbacks.get_feedback_by_id(id=id) feedback = Feedbacks.get_feedback_by_id(id=id, db=db)
else: else:
feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id) feedback = Feedbacks.get_feedback_by_id_and_user_id(
id=id, user_id=user.id, db=db
)
if not feedback: if not feedback:
raise HTTPException( raise HTTPException(
@ -153,13 +175,16 @@ async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/feedback/{id}", response_model=FeedbackModel) @router.post("/feedback/{id}", response_model=FeedbackModel)
async def update_feedback_by_id( async def update_feedback_by_id(
id: str, form_data: FeedbackForm, user=Depends(get_verified_user) id: str,
form_data: FeedbackForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if user.role == "admin": if user.role == "admin":
feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data) feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data, db=db)
else: else:
feedback = Feedbacks.update_feedback_by_id_and_user_id( feedback = Feedbacks.update_feedback_by_id_and_user_id(
id=id, user_id=user.id, form_data=form_data id=id, user_id=user.id, form_data=form_data, db=db
) )
if not feedback: if not feedback:
@ -171,11 +196,15 @@ async def update_feedback_by_id(
@router.delete("/feedback/{id}") @router.delete("/feedback/{id}")
async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)): async def delete_feedback_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
if user.role == "admin": if user.role == "admin":
success = Feedbacks.delete_feedback_by_id(id=id) success = Feedbacks.delete_feedback_by_id(id=id, db=db)
else: else:
success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id) success = Feedbacks.delete_feedback_by_id_and_user_id(
id=id, user_id=user.id, db=db
)
if not success: if not success:
raise HTTPException( raise HTTPException(

View file

@ -22,6 +22,8 @@ from fastapi import (
) )
from fastapi.responses import FileResponse, StreamingResponse from fastapi.responses import FileResponse, StreamingResponse
from sqlalchemy.orm import Session
from open_webui.internal.db import get_session, SessionLocal
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
@ -62,9 +64,12 @@ router = APIRouter()
# TODO: Optimize this function to use the knowledge_file table for faster lookups. # TODO: Optimize this function to use the knowledge_file table for faster lookups.
def has_access_to_file( def has_access_to_file(
file_id: Optional[str], access_type: str, user=Depends(get_verified_user) file_id: Optional[str],
access_type: str,
user=Depends(get_verified_user),
db: Optional[Session] = None,
) -> bool: ) -> bool:
file = Files.get_file_by_id(file_id) file = Files.get_file_by_id(file_id, db=db)
log.debug(f"Checking if user has {access_type} access to file") log.debug(f"Checking if user has {access_type} access to file")
if not file: if not file:
raise HTTPException( raise HTTPException(
@ -73,31 +78,33 @@ def has_access_to_file(
) )
# Check if the file is associated with any knowledge bases the user has access to # Check if the file is associated with any knowledge bases the user has access to
knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id) knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id, db=db)
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
}
for knowledge_base in knowledge_bases: for knowledge_base in knowledge_bases:
if knowledge_base.user_id == user.id or has_access( if knowledge_base.user_id == user.id or has_access(
user.id, access_type, knowledge_base.access_control, user_group_ids user.id, access_type, knowledge_base.access_control, user_group_ids, db=db
): ):
return True return True
knowledge_base_id = file.meta.get("collection_name") if file.meta else None knowledge_base_id = file.meta.get("collection_name") if file.meta else None
if knowledge_base_id: if knowledge_base_id:
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id( knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
user.id, access_type user.id, access_type, db=db
) )
for knowledge_base in knowledge_bases: for knowledge_base in knowledge_bases:
if knowledge_base.id == knowledge_base_id: if knowledge_base.id == knowledge_base_id:
return True return True
# Check if the file is associated with any channels the user has access to # Check if the file is associated with any channels the user has access to
channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id) channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id, db=db)
if access_type == "read" and channels: if access_type == "read" and channels:
return True return True
# Check if the file is associated with any chats the user has access to # Check if the file is associated with any chats the user has access to
# TODO: Granular access control for chats # TODO: Granular access control for chats
chats = Chats.get_shared_chats_by_file_id(file_id) chats = Chats.get_shared_chats_by_file_id(file_id, db=db)
if chats: if chats:
return True return True
@ -109,47 +116,78 @@ def has_access_to_file(
############################ ############################
def process_uploaded_file(request, file, file_path, file_item, file_metadata, user): def process_uploaded_file(
try: request,
if file.content_type: file,
stt_supported_content_types = getattr( file_path,
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] file_item,
) file_metadata,
user,
db: Optional[Session] = None,
):
def _process_handler(db_session):
try:
if file.content_type:
stt_supported_content_types = getattr(
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
)
if strict_match_mime_type(stt_supported_content_types, file.content_type): if strict_match_mime_type(
file_path = Storage.get_file(file_path) stt_supported_content_types, file.content_type
result = transcribe(request, file_path, file_metadata, user) ):
file_path_processed = Storage.get_file(file_path)
result = transcribe(
request, file_path_processed, file_metadata, user
)
process_file(
request,
ProcessFileForm(
file_id=file_item.id, content=result.get("text", "")
),
user=user,
db=db_session,
)
elif (not file.content_type.startswith(("image/", "video/"))) or (
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
):
process_file(
request,
ProcessFileForm(file_id=file_item.id),
user=user,
db=db_session,
)
else:
raise Exception(
f"File type {file.content_type} is not supported for processing"
)
else:
log.info(
f"File type {file.content_type} is not provided, but trying to process anyway"
)
process_file( process_file(
request, request,
ProcessFileForm( ProcessFileForm(file_id=file_item.id),
file_id=file_item.id, content=result.get("text", "")
),
user=user, user=user,
db=db_session,
) )
elif (not file.content_type.startswith(("image/", "video/"))) or (
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
):
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
else:
raise Exception(
f"File type {file.content_type} is not supported for processing"
)
else:
log.info(
f"File type {file.content_type} is not provided, but trying to process anyway"
)
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
except Exception as e: except Exception as e:
log.error(f"Error processing file: {file_item.id}") log.error(f"Error processing file: {file_item.id}")
Files.update_file_data_by_id( Files.update_file_data_by_id(
file_item.id, file_item.id,
{ {
"status": "failed", "status": "failed",
"error": str(e.detail) if hasattr(e, "detail") else str(e), "error": str(e.detail) if hasattr(e, "detail") else str(e),
}, },
) db=db_session,
)
if db:
_process_handler(db)
else:
with SessionLocal() as db_session:
_process_handler(db_session)
@router.post("/", response_model=FileModelResponse) @router.post("/", response_model=FileModelResponse)
@ -161,6 +199,7 @@ def upload_file(
process: bool = Query(True), process: bool = Query(True),
process_in_background: bool = Query(True), process_in_background: bool = Query(True),
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
return upload_file_handler( return upload_file_handler(
request, request,
@ -170,6 +209,7 @@ def upload_file(
process_in_background=process_in_background, process_in_background=process_in_background,
user=user, user=user,
background_tasks=background_tasks, background_tasks=background_tasks,
db=db,
) )
@ -181,6 +221,7 @@ def upload_file_handler(
process_in_background: bool = Query(True), process_in_background: bool = Query(True),
user=Depends(get_verified_user), user=Depends(get_verified_user),
background_tasks: Optional[BackgroundTasks] = None, background_tasks: Optional[BackgroundTasks] = None,
db: Optional[Session] = None,
): ):
log.info(f"file.content_type: {file.content_type} {process}") log.info(f"file.content_type: {file.content_type} {process}")
@ -248,14 +289,17 @@ def upload_file_handler(
}, },
} }
), ),
db=db,
) )
if "channel_id" in file_metadata: if "channel_id" in file_metadata:
channel = Channels.get_channel_by_id_and_user_id( channel = Channels.get_channel_by_id_and_user_id(
file_metadata["channel_id"], user.id file_metadata["channel_id"], user.id, db=db
) )
if channel: if channel:
Channels.add_file_to_channel_by_id(channel.id, file_item.id, user.id) Channels.add_file_to_channel_by_id(
channel.id, file_item.id, user.id, db=db
)
if process: if process:
if background_tasks and process_in_background: if background_tasks and process_in_background:
@ -277,6 +321,7 @@ def upload_file_handler(
file_item, file_item,
file_metadata, file_metadata,
user, user,
db=db,
) )
return {"status": True, **file_item.model_dump()} return {"status": True, **file_item.model_dump()}
else: else:
@ -302,11 +347,15 @@ def upload_file_handler(
@router.get("/", response_model=list[FileModelResponse]) @router.get("/", response_model=list[FileModelResponse])
async def list_files(user=Depends(get_verified_user), content: bool = Query(True)): async def list_files(
user=Depends(get_verified_user),
content: bool = Query(True),
db: Session = Depends(get_session),
):
if user.role == "admin": if user.role == "admin":
files = Files.get_files() files = Files.get_files(db=db)
else: else:
files = Files.get_files_by_user_id(user.id) files = Files.get_files_by_user_id(user.id, db=db)
if not content: if not content:
for file in files: for file in files:
@ -329,15 +378,16 @@ async def search_files(
), ),
content: bool = Query(True), content: bool = Query(True),
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
""" """
Search for files by filename with support for wildcard patterns. Search for files by filename with support for wildcard patterns.
""" """
# Get files according to user role # Get files according to user role
if user.role == "admin": if user.role == "admin":
files = Files.get_files() files = Files.get_files(db=db)
else: else:
files = Files.get_files_by_user_id(user.id) files = Files.get_files_by_user_id(user.id, db=db)
# Get matching files # Get matching files
matching_files = [ matching_files = [
@ -364,8 +414,10 @@ async def search_files(
@router.delete("/all") @router.delete("/all")
async def delete_all_files(user=Depends(get_admin_user)): async def delete_all_files(
result = Files.delete_all_files() user=Depends(get_admin_user), db: Session = Depends(get_session)
):
result = Files.delete_all_files(db=db)
if result: if result:
try: try:
Storage.delete_all_files() Storage.delete_all_files()
@ -391,8 +443,10 @@ async def delete_all_files(user=Depends(get_admin_user)):
@router.get("/{id}", response_model=Optional[FileModel]) @router.get("/{id}", response_model=Optional[FileModel])
async def get_file_by_id(id: str, user=Depends(get_verified_user)): async def get_file_by_id(
file = Files.get_file_by_id(id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
file = Files.get_file_by_id(id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
@ -403,7 +457,7 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
file.user_id == user.id file.user_id == user.id
or user.role == "admin" or user.role == "admin"
or has_access_to_file(id, "read", user) or has_access_to_file(id, "read", user, db=db)
): ):
return file return file
else: else:
@ -415,9 +469,12 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/process/status") @router.get("/{id}/process/status")
async def get_file_process_status( async def get_file_process_status(
id: str, stream: bool = Query(False), user=Depends(get_verified_user) id: str,
stream: bool = Query(False),
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
file = Files.get_file_by_id(id) file = Files.get_file_by_id(id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
@ -428,7 +485,7 @@ async def get_file_process_status(
if ( if (
file.user_id == user.id file.user_id == user.id
or user.role == "admin" or user.role == "admin"
or has_access_to_file(id, "read", user) or has_access_to_file(id, "read", user, db=db)
): ):
if stream: if stream:
MAX_FILE_PROCESSING_DURATION = 3600 * 2 MAX_FILE_PROCESSING_DURATION = 3600 * 2
@ -436,7 +493,7 @@ async def get_file_process_status(
async def event_stream(file_item): async def event_stream(file_item):
if file_item: if file_item:
for _ in range(MAX_FILE_PROCESSING_DURATION): for _ in range(MAX_FILE_PROCESSING_DURATION):
file_item = Files.get_file_by_id(file_item.id) file_item = Files.get_file_by_id(file_item.id, db=db)
if file_item: if file_item:
data = file_item.model_dump().get("data", {}) data = file_item.model_dump().get("data", {})
status = data.get("status") status = data.get("status")
@ -476,8 +533,10 @@ async def get_file_process_status(
@router.get("/{id}/data/content") @router.get("/{id}/data/content")
async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)): async def get_file_data_content_by_id(
file = Files.get_file_by_id(id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
file = Files.get_file_by_id(id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
@ -488,7 +547,7 @@ async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
file.user_id == user.id file.user_id == user.id
or user.role == "admin" or user.role == "admin"
or has_access_to_file(id, "read", user) or has_access_to_file(id, "read", user, db=db)
): ):
return {"content": file.data.get("content", "")} return {"content": file.data.get("content", "")}
else: else:
@ -509,9 +568,13 @@ class ContentForm(BaseModel):
@router.post("/{id}/data/content/update") @router.post("/{id}/data/content/update")
async def update_file_data_content_by_id( async def update_file_data_content_by_id(
request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user) request: Request,
id: str,
form_data: ContentForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
file = Files.get_file_by_id(id) file = Files.get_file_by_id(id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
@ -522,7 +585,7 @@ async def update_file_data_content_by_id(
if ( if (
file.user_id == user.id file.user_id == user.id
or user.role == "admin" or user.role == "admin"
or has_access_to_file(id, "write", user) or has_access_to_file(id, "write", user, db=db)
): ):
try: try:
process_file( process_file(
@ -530,7 +593,7 @@ async def update_file_data_content_by_id(
ProcessFileForm(file_id=id, content=form_data.content), ProcessFileForm(file_id=id, content=form_data.content),
user=user, user=user,
) )
file = Files.get_file_by_id(id=id) file = Files.get_file_by_id(id=id, db=db)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
log.error(f"Error processing file: {file.id}") log.error(f"Error processing file: {file.id}")
@ -550,9 +613,12 @@ async def update_file_data_content_by_id(
@router.get("/{id}/content") @router.get("/{id}/content")
async def get_file_content_by_id( async def get_file_content_by_id(
id: str, user=Depends(get_verified_user), attachment: bool = Query(False) id: str,
user=Depends(get_verified_user),
attachment: bool = Query(False),
db: Session = Depends(get_session),
): ):
file = Files.get_file_by_id(id) file = Files.get_file_by_id(id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
@ -563,7 +629,7 @@ async def get_file_content_by_id(
if ( if (
file.user_id == user.id file.user_id == user.id
or user.role == "admin" or user.role == "admin"
or has_access_to_file(id, "read", user) or has_access_to_file(id, "read", user, db=db)
): ):
try: try:
file_path = Storage.get_file(file.path) file_path = Storage.get_file(file.path)
@ -619,8 +685,10 @@ async def get_file_content_by_id(
@router.get("/{id}/content/html") @router.get("/{id}/content/html")
async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): async def get_html_file_content_by_id(
file = Files.get_file_by_id(id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
file = Files.get_file_by_id(id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
@ -628,7 +696,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND,
) )
file_user = Users.get_user_by_id(file.user_id) file_user = Users.get_user_by_id(file.user_id, db=db)
if not file_user.role == "admin": if not file_user.role == "admin":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -638,7 +706,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
file.user_id == user.id file.user_id == user.id
or user.role == "admin" or user.role == "admin"
or has_access_to_file(id, "read", user) or has_access_to_file(id, "read", user, db=db)
): ):
try: try:
file_path = Storage.get_file(file.path) file_path = Storage.get_file(file.path)
@ -668,8 +736,10 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/content/{file_name}") @router.get("/{id}/content/{file_name}")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): async def get_file_content_by_id(
file = Files.get_file_by_id(id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
file = Files.get_file_by_id(id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
@ -680,7 +750,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
file.user_id == user.id file.user_id == user.id
or user.role == "admin" or user.role == "admin"
or has_access_to_file(id, "read", user) or has_access_to_file(id, "read", user, db=db)
): ):
file_path = file.path file_path = file.path
@ -730,8 +800,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.delete("/{id}") @router.delete("/{id}")
async def delete_file_by_id(id: str, user=Depends(get_verified_user)): async def delete_file_by_id(
file = Files.get_file_by_id(id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
file = Files.get_file_by_id(id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
@ -742,10 +814,10 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
file.user_id == user.id file.user_id == user.id
or user.role == "admin" or user.role == "admin"
or has_access_to_file(id, "write", user) or has_access_to_file(id, "write", user, db=db)
): ):
result = Files.delete_file_by_id(id) result = Files.delete_file_by_id(id, db=db)
if result: if result:
try: try:
Storage.delete_file(file.path) Storage.delete_file(file.path)

View file

@ -22,6 +22,8 @@ from open_webui.models.knowledge import Knowledges
from open_webui.config import UPLOAD_DIR from open_webui.config import UPLOAD_DIR
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
@ -44,7 +46,11 @@ router = APIRouter()
@router.get("/", response_model=list[FolderNameIdResponse]) @router.get("/", response_model=list[FolderNameIdResponse])
async def get_folders(request: Request, user=Depends(get_verified_user)): async def get_folders(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if request.app.state.config.ENABLE_FOLDERS is False: if request.app.state.config.ENABLE_FOLDERS is False:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@ -55,22 +61,23 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
user.id, user.id,
"features.folders", "features.folders",
request.app.state.config.USER_PERMISSIONS, request.app.state.config.USER_PERMISSIONS,
db=db,
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
folders = Folders.get_folders_by_user_id(user.id) folders = Folders.get_folders_by_user_id(user.id, db=db)
# Verify folder data integrity # Verify folder data integrity
folder_list = [] folder_list = []
for folder in folders: for folder in folders:
if folder.parent_id and not Folders.get_folder_by_id_and_user_id( if folder.parent_id and not Folders.get_folder_by_id_and_user_id(
folder.parent_id, user.id folder.parent_id, user.id, db=db
): ):
folder = Folders.update_folder_parent_id_by_id_and_user_id( folder = Folders.update_folder_parent_id_by_id_and_user_id(
folder.id, user.id, None folder.id, user.id, None, db=db
) )
if folder.data: if folder.data:
@ -80,12 +87,12 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
if file.get("type") == "file": if file.get("type") == "file":
if Files.check_access_by_user_id( if Files.check_access_by_user_id(
file.get("id"), user.id, "read" file.get("id"), user.id, "read", db=db
): ):
valid_files.append(file) valid_files.append(file)
elif file.get("type") == "collection": elif file.get("type") == "collection":
if Knowledges.check_access_by_user_id( if Knowledges.check_access_by_user_id(
file.get("id"), user.id, "read" file.get("id"), user.id, "read", db=db
): ):
valid_files.append(file) valid_files.append(file)
else: else:
@ -93,7 +100,7 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
folder.data["files"] = valid_files folder.data["files"] = valid_files
Folders.update_folder_by_id_and_user_id( Folders.update_folder_by_id_and_user_id(
folder.id, user.id, FolderUpdateForm(data=folder.data) folder.id, user.id, FolderUpdateForm(data=folder.data), db=db
) )
folder_list.append(FolderNameIdResponse(**folder.model_dump())) folder_list.append(FolderNameIdResponse(**folder.model_dump()))
@ -107,9 +114,13 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
@router.post("/") @router.post("/")
def create_folder(form_data: FolderForm, user=Depends(get_verified_user)): def create_folder(
form_data: FolderForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
folder = Folders.get_folder_by_parent_id_and_user_id_and_name( folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
None, user.id, form_data.name None, user.id, form_data.name, db=db
) )
if folder: if folder:
@ -119,7 +130,7 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
) )
try: try:
folder = Folders.insert_new_folder(user.id, form_data) folder = Folders.insert_new_folder(user.id, form_data, db=db)
return folder return folder
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -136,8 +147,10 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
@router.get("/{id}", response_model=Optional[FolderModel]) @router.get("/{id}", response_model=Optional[FolderModel])
async def get_folder_by_id(id: str, user=Depends(get_verified_user)): async def get_folder_by_id(
folder = Folders.get_folder_by_id_and_user_id(id, user.id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
if folder: if folder:
return folder return folder
else: else:
@ -154,15 +167,18 @@ async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/update") @router.post("/{id}/update")
async def update_folder_name_by_id( async def update_folder_name_by_id(
id: str, form_data: FolderUpdateForm, user=Depends(get_verified_user) id: str,
form_data: FolderUpdateForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
folder = Folders.get_folder_by_id_and_user_id(id, user.id) folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
if folder: if folder:
if form_data.name is not None: if form_data.name is not None:
# Check if folder with same name exists # Check if folder with same name exists
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
folder.parent_id, user.id, form_data.name folder.parent_id, user.id, form_data.name, db=db
) )
if existing_folder and existing_folder.id != id: if existing_folder and existing_folder.id != id:
raise HTTPException( raise HTTPException(
@ -171,7 +187,9 @@ async def update_folder_name_by_id(
) )
try: try:
folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data) folder = Folders.update_folder_by_id_and_user_id(
id, user.id, form_data, db=db
)
return folder return folder
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -198,12 +216,15 @@ class FolderParentIdForm(BaseModel):
@router.post("/{id}/update/parent") @router.post("/{id}/update/parent")
async def update_folder_parent_id_by_id( async def update_folder_parent_id_by_id(
id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user) id: str,
form_data: FolderParentIdForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
folder = Folders.get_folder_by_id_and_user_id(id, user.id) folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
if folder: if folder:
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
form_data.parent_id, user.id, folder.name form_data.parent_id, user.id, folder.name, db=db
) )
if existing_folder: if existing_folder:
@ -214,7 +235,7 @@ async def update_folder_parent_id_by_id(
try: try:
folder = Folders.update_folder_parent_id_by_id_and_user_id( folder = Folders.update_folder_parent_id_by_id_and_user_id(
id, user.id, form_data.parent_id id, user.id, form_data.parent_id, db=db
) )
return folder return folder
except Exception as e: except Exception as e:
@ -242,13 +263,16 @@ class FolderIsExpandedForm(BaseModel):
@router.post("/{id}/update/expanded") @router.post("/{id}/update/expanded")
async def update_folder_is_expanded_by_id( async def update_folder_is_expanded_by_id(
id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user) id: str,
form_data: FolderIsExpandedForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
folder = Folders.get_folder_by_id_and_user_id(id, user.id) folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
if folder: if folder:
try: try:
folder = Folders.update_folder_is_expanded_by_id_and_user_id( folder = Folders.update_folder_is_expanded_by_id_and_user_id(
id, user.id, form_data.is_expanded id, user.id, form_data.is_expanded, db=db
) )
return folder return folder
except Exception as e: except Exception as e:
@ -276,10 +300,11 @@ async def delete_folder_by_id(
id: str, id: str,
delete_contents: Optional[bool] = True, delete_contents: Optional[bool] = True,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if Chats.count_chats_by_folder_id_and_user_id(id, user.id): if Chats.count_chats_by_folder_id_and_user_id(id, user.id, db=db):
chat_delete_permission = has_permission( chat_delete_permission = has_permission(
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS, db=db
) )
if user.role != "admin" and not chat_delete_permission: if user.role != "admin" and not chat_delete_permission:
raise HTTPException( raise HTTPException(
@ -288,19 +313,21 @@ async def delete_folder_by_id(
) )
folders = [] folders = []
folders.append(Folders.get_folder_by_id_and_user_id(id, user.id)) folders.append(Folders.get_folder_by_id_and_user_id(id, user.id, db=db))
while folders: while folders:
folder = folders.pop() folder = folders.pop()
if folder: if folder:
try: try:
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id) folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id, db=db)
for folder_id in folder_ids: for folder_id in folder_ids:
if delete_contents: if delete_contents:
Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id) Chats.delete_chats_by_user_id_and_folder_id(
user.id, folder_id, db=db
)
else: else:
Chats.move_chats_by_user_id_and_folder_id( Chats.move_chats_by_user_id_and_folder_id(
user.id, folder_id, None user.id, folder_id, None, db=db
) )
return True return True
@ -314,7 +341,7 @@ async def delete_folder_by_id(
finally: finally:
# Get all subfolders # Get all subfolders
subfolders = Folders.get_folders_by_parent_id_and_user_id( subfolders = Folders.get_folders_by_parent_id_and_user_id(
folder.id, user.id folder.id, user.id, db=db
) )
folders.extend(subfolders) folders.extend(subfolders)

View file

@ -24,6 +24,8 @@ from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from pydantic import BaseModel, HttpUrl from pydantic import BaseModel, HttpUrl
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -37,13 +39,13 @@ router = APIRouter()
@router.get("/", response_model=list[FunctionResponse]) @router.get("/", response_model=list[FunctionResponse])
async def get_functions(user=Depends(get_verified_user)): async def get_functions(user=Depends(get_verified_user), db: Session = Depends(get_session)):
return Functions.get_functions() return Functions.get_functions(db=db)
@router.get("/list", response_model=list[FunctionUserResponse]) @router.get("/list", response_model=list[FunctionUserResponse])
async def get_function_list(user=Depends(get_admin_user)): async def get_function_list(user=Depends(get_admin_user), db: Session = Depends(get_session)):
return Functions.get_function_list() return Functions.get_function_list(db=db)
############################ ############################
@ -52,8 +54,8 @@ async def get_function_list(user=Depends(get_admin_user)):
@router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel]) @router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel])
async def get_functions(include_valves: bool = False, user=Depends(get_admin_user)): async def get_functions(include_valves: bool = False, user=Depends(get_admin_user), db: Session = Depends(get_session)):
return Functions.get_functions(include_valves=include_valves) return Functions.get_functions(include_valves=include_valves, db=db)
############################ ############################
@ -142,7 +144,7 @@ class SyncFunctionsForm(BaseModel):
@router.post("/sync", response_model=list[FunctionWithValvesModel]) @router.post("/sync", response_model=list[FunctionWithValvesModel])
async def sync_functions( async def sync_functions(
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user) request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
): ):
try: try:
for function in form_data.functions: for function in form_data.functions:
@ -164,7 +166,7 @@ async def sync_functions(
) )
raise e raise e
return Functions.sync_functions(user.id, form_data.functions) return Functions.sync_functions(user.id, form_data.functions, db=db)
except Exception as e: except Exception as e:
log.exception(f"Failed to load a function: {e}") log.exception(f"Failed to load a function: {e}")
raise HTTPException( raise HTTPException(
@ -180,7 +182,7 @@ async def sync_functions(
@router.post("/create", response_model=Optional[FunctionResponse]) @router.post("/create", response_model=Optional[FunctionResponse])
async def create_new_function( async def create_new_function(
request: Request, form_data: FunctionForm, user=Depends(get_admin_user) request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
): ):
if not form_data.id.isidentifier(): if not form_data.id.isidentifier():
raise HTTPException( raise HTTPException(
@ -190,7 +192,7 @@ async def create_new_function(
form_data.id = form_data.id.lower() form_data.id = form_data.id.lower()
function = Functions.get_function_by_id(form_data.id) function = Functions.get_function_by_id(form_data.id, db=db)
if function is None: if function is None:
try: try:
form_data.content = replace_imports(form_data.content) form_data.content = replace_imports(form_data.content)
@ -203,13 +205,13 @@ async def create_new_function(
FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module FUNCTIONS[form_data.id] = function_module
function = Functions.insert_new_function(user.id, function_type, form_data) function = Functions.insert_new_function(user.id, function_type, form_data, db=db)
function_cache_dir = CACHE_DIR / "functions" / form_data.id function_cache_dir = CACHE_DIR / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True) function_cache_dir.mkdir(parents=True, exist_ok=True)
if function_type == "filter" and getattr(function_module, "toggle", None): if function_type == "filter" and getattr(function_module, "toggle", None):
Functions.update_function_metadata_by_id(id, {"toggle": True}) Functions.update_function_metadata_by_id(form_data.id, {"toggle": True}, db=db)
if function: if function:
return function return function
@ -237,8 +239,8 @@ async def create_new_function(
@router.get("/id/{id}", response_model=Optional[FunctionModel]) @router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user)): async def get_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
function = Functions.get_function_by_id(id) function = Functions.get_function_by_id(id, db=db)
if function: if function:
return function return function
@ -255,11 +257,11 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel]) @router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): async def toggle_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
function = Functions.get_function_by_id(id) function = Functions.get_function_by_id(id, db=db)
if function: if function:
function = Functions.update_function_by_id( function = Functions.update_function_by_id(
id, {"is_active": not function.is_active} id, {"is_active": not function.is_active}, db=db
) )
if function: if function:
@ -282,11 +284,11 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel]) @router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)): async def toggle_global_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
function = Functions.get_function_by_id(id) function = Functions.get_function_by_id(id, db=db)
if function: if function:
function = Functions.update_function_by_id( function = Functions.update_function_by_id(
id, {"is_global": not function.is_global} id, {"is_global": not function.is_global}, db=db
) )
if function: if function:
@ -310,7 +312,7 @@ async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[FunctionModel]) @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_function_by_id( async def update_function_by_id(
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user) request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
): ):
try: try:
form_data.content = replace_imports(form_data.content) form_data.content = replace_imports(form_data.content)
@ -325,10 +327,10 @@ async def update_function_by_id(
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
log.debug(updated) log.debug(updated)
function = Functions.update_function_by_id(id, updated) function = Functions.update_function_by_id(id, updated, db=db)
if function_type == "filter" and getattr(function_module, "toggle", None): if function_type == "filter" and getattr(function_module, "toggle", None):
Functions.update_function_metadata_by_id(id, {"toggle": True}) Functions.update_function_metadata_by_id(id, {"toggle": True}, db=db)
if function: if function:
return function return function
@ -352,9 +354,9 @@ async def update_function_by_id(
@router.delete("/id/{id}/delete", response_model=bool) @router.delete("/id/{id}/delete", response_model=bool)
async def delete_function_by_id( async def delete_function_by_id(
request: Request, id: str, user=Depends(get_admin_user) request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
): ):
result = Functions.delete_function_by_id(id) result = Functions.delete_function_by_id(id, db=db)
if result: if result:
FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS = request.app.state.FUNCTIONS
@ -370,11 +372,11 @@ async def delete_function_by_id(
@router.get("/id/{id}/valves", response_model=Optional[dict]) @router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): async def get_function_valves_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
function = Functions.get_function_by_id(id) function = Functions.get_function_by_id(id, db=db)
if function: if function:
try: try:
valves = Functions.get_function_valves_by_id(id) valves = Functions.get_function_valves_by_id(id, db=db)
return valves return valves
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@ -395,9 +397,9 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
@router.get("/id/{id}/valves/spec", response_model=Optional[dict]) @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_function_valves_spec_by_id( async def get_function_valves_spec_by_id(
request: Request, id: str, user=Depends(get_admin_user) request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
): ):
function = Functions.get_function_by_id(id) function = Functions.get_function_by_id(id, db=db)
if function: if function:
function_module, function_type, frontmatter = get_function_module_from_cache( function_module, function_type, frontmatter = get_function_module_from_cache(
request, id request, id
@ -421,9 +423,9 @@ async def get_function_valves_spec_by_id(
@router.post("/id/{id}/valves/update", response_model=Optional[dict]) @router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_function_valves_by_id( async def update_function_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_admin_user) request: Request, id: str, form_data: dict, user=Depends(get_admin_user), db: Session = Depends(get_session)
): ):
function = Functions.get_function_by_id(id) function = Functions.get_function_by_id(id, db=db)
if function: if function:
function_module, function_type, frontmatter = get_function_module_from_cache( function_module, function_type, frontmatter = get_function_module_from_cache(
request, id request, id
@ -437,7 +439,7 @@ async def update_function_valves_by_id(
valves = Valves(**form_data) valves = Valves(**form_data)
valves_dict = valves.model_dump(exclude_unset=True) valves_dict = valves.model_dump(exclude_unset=True)
Functions.update_function_valves_by_id(id, valves_dict) Functions.update_function_valves_by_id(id, valves_dict, db=db)
return valves_dict return valves_dict
except Exception as e: except Exception as e:
log.exception(f"Error updating function values by id {id}: {e}") log.exception(f"Error updating function values by id {id}: {e}")
@ -464,11 +466,11 @@ async def update_function_valves_by_id(
@router.get("/id/{id}/valves/user", response_model=Optional[dict]) @router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)): async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
function = Functions.get_function_by_id(id) function = Functions.get_function_by_id(id, db=db)
if function: if function:
try: try:
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id) user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id, db=db)
return user_valves return user_valves
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@ -484,9 +486,9 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_function_user_valves_spec_by_id( async def get_function_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user) request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
): ):
function = Functions.get_function_by_id(id) function = Functions.get_function_by_id(id, db=db)
if function: if function:
function_module, function_type, frontmatter = get_function_module_from_cache( function_module, function_type, frontmatter = get_function_module_from_cache(
request, id request, id
@ -505,9 +507,9 @@ async def get_function_user_valves_spec_by_id(
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_function_user_valves_by_id( async def update_function_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user) request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
): ):
function = Functions.get_function_by_id(id) function = Functions.get_function_by_id(id, db=db)
if function: if function:
function_module, function_type, frontmatter = get_function_module_from_cache( function_module, function_type, frontmatter = get_function_module_from_cache(
@ -522,7 +524,7 @@ async def update_function_user_valves_by_id(
user_valves = UserValves(**form_data) user_valves = UserValves(**form_data)
user_valves_dict = user_valves.model_dump(exclude_unset=True) user_valves_dict = user_valves.model_dump(exclude_unset=True)
Functions.update_user_valves_by_id_and_user_id( Functions.update_user_valves_by_id_and_user_id(
id, user.id, user_valves_dict id, user.id, user_valves_dict, db=db
) )
return user_valves_dict return user_valves_dict
except Exception as e: except Exception as e:

View file

@ -16,6 +16,9 @@ from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
@ -29,7 +32,11 @@ router = APIRouter()
@router.get("/", response_model=list[GroupResponse]) @router.get("/", response_model=list[GroupResponse])
async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)): async def get_groups(
share: Optional[bool] = None,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
filter = {} filter = {}
if user.role != "admin": if user.role != "admin":
@ -38,7 +45,7 @@ async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_use
if share is not None: if share is not None:
filter["share"] = share filter["share"] = share
groups = Groups.get_groups(filter=filter) groups = Groups.get_groups(filter=filter, db=db)
return groups return groups
@ -49,13 +56,17 @@ async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_use
@router.post("/create", response_model=Optional[GroupResponse]) @router.post("/create", response_model=Optional[GroupResponse])
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)): async def create_new_group(
form_data: GroupForm,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
try: try:
group = Groups.insert_new_group(user.id, form_data) group = Groups.insert_new_group(user.id, form_data, db=db)
if group: if group:
return GroupResponse( return GroupResponse(
**group.model_dump(), **group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id), member_count=Groups.get_group_member_count_by_id(group.id, db=db),
) )
else: else:
raise HTTPException( raise HTTPException(
@ -76,12 +87,14 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
@router.get("/id/{id}", response_model=Optional[GroupResponse]) @router.get("/id/{id}", response_model=Optional[GroupResponse])
async def get_group_by_id(id: str, user=Depends(get_admin_user)): async def get_group_by_id(
group = Groups.get_group_by_id(id) id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
group = Groups.get_group_by_id(id, db=db)
if group: if group:
return GroupResponse( return GroupResponse(
**group.model_dump(), **group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id), member_count=Groups.get_group_member_count_by_id(group.id, db=db),
) )
else: else:
raise HTTPException( raise HTTPException(
@ -101,13 +114,15 @@ class GroupExportResponse(GroupResponse):
@router.get("/id/{id}/export", response_model=Optional[GroupExportResponse]) @router.get("/id/{id}/export", response_model=Optional[GroupExportResponse])
async def export_group_by_id(id: str, user=Depends(get_admin_user)): async def export_group_by_id(
group = Groups.get_group_by_id(id) id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
group = Groups.get_group_by_id(id, db=db)
if group: if group:
return GroupExportResponse( return GroupExportResponse(
**group.model_dump(), **group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id), member_count=Groups.get_group_member_count_by_id(group.id, db=db),
user_ids=Groups.get_group_user_ids_by_id(group.id), user_ids=Groups.get_group_user_ids_by_id(group.id, db=db),
) )
else: else:
raise HTTPException( raise HTTPException(
@ -122,9 +137,11 @@ async def export_group_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/users", response_model=list[UserInfoResponse]) @router.post("/id/{id}/users", response_model=list[UserInfoResponse])
async def get_users_in_group(id: str, user=Depends(get_admin_user)): async def get_users_in_group(
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
try: try:
users = Users.get_users_by_group_id(id) users = Users.get_users_by_group_id(id, db=db)
return users return users
except Exception as e: except Exception as e:
log.exception(f"Error adding users to group {id}: {e}") log.exception(f"Error adding users to group {id}: {e}")
@ -141,14 +158,17 @@ async def get_users_in_group(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[GroupResponse]) @router.post("/id/{id}/update", response_model=Optional[GroupResponse])
async def update_group_by_id( async def update_group_by_id(
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user) id: str,
form_data: GroupUpdateForm,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
): ):
try: try:
group = Groups.update_group_by_id(id, form_data) group = Groups.update_group_by_id(id, form_data, db=db)
if group: if group:
return GroupResponse( return GroupResponse(
**group.model_dump(), **group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id), member_count=Groups.get_group_member_count_by_id(group.id, db=db),
) )
else: else:
raise HTTPException( raise HTTPException(
@ -170,17 +190,20 @@ async def update_group_by_id(
@router.post("/id/{id}/users/add", response_model=Optional[GroupResponse]) @router.post("/id/{id}/users/add", response_model=Optional[GroupResponse])
async def add_user_to_group( async def add_user_to_group(
id: str, form_data: UserIdsForm, user=Depends(get_admin_user) id: str,
form_data: UserIdsForm,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
): ):
try: try:
if form_data.user_ids: if form_data.user_ids:
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids) form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids, db=db)
group = Groups.add_users_to_group(id, form_data.user_ids) group = Groups.add_users_to_group(id, form_data.user_ids, db=db)
if group: if group:
return GroupResponse( return GroupResponse(
**group.model_dump(), **group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id), member_count=Groups.get_group_member_count_by_id(group.id, db=db),
) )
else: else:
raise HTTPException( raise HTTPException(
@ -197,14 +220,17 @@ async def add_user_to_group(
@router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse]) @router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse])
async def remove_users_from_group( async def remove_users_from_group(
id: str, form_data: UserIdsForm, user=Depends(get_admin_user) id: str,
form_data: UserIdsForm,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
): ):
try: try:
group = Groups.remove_users_from_group(id, form_data.user_ids) group = Groups.remove_users_from_group(id, form_data.user_ids, db=db)
if group: if group:
return GroupResponse( return GroupResponse(
**group.model_dump(), **group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id), member_count=Groups.get_group_member_count_by_id(group.id, db=db),
) )
else: else:
raise HTTPException( raise HTTPException(
@ -225,9 +251,11 @@ async def remove_users_from_group(
@router.delete("/id/{id}/delete", response_model=bool) @router.delete("/id/{id}/delete", response_model=bool)
async def delete_group_by_id(id: str, user=Depends(get_admin_user)): async def delete_group_by_id(
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
try: try:
result = Groups.delete_group_by_id(id) result = Groups.delete_group_by_id(id, db=db)
if result: if result:
return result return result
else: else:

View file

@ -4,6 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
from fastapi.concurrency import run_in_threadpool from fastapi.concurrency import run_in_threadpool
import logging import logging
from sqlalchemy.orm import Session
from open_webui.internal.db import get_session
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
from open_webui.models.knowledge import ( from open_webui.models.knowledge import (
KnowledgeFileListResponse, KnowledgeFileListResponse,
@ -52,21 +54,25 @@ class KnowledgeAccessListResponse(BaseModel):
@router.get("/", response_model=KnowledgeAccessListResponse) @router.get("/", response_model=KnowledgeAccessListResponse)
async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified_user)): async def get_knowledge_bases(
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
page = max(page, 1) page = max(page, 1)
limit = PAGE_ITEM_COUNT limit = PAGE_ITEM_COUNT
skip = (page - 1) * limit skip = (page - 1) * limit
filter = {} filter = {}
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id) groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups: if groups:
filter["group_ids"] = [group.id for group in groups] filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id filter["user_id"] = user.id
result = Knowledges.search_knowledge_bases( result = Knowledges.search_knowledge_bases(
user.id, filter=filter, skip=skip, limit=limit user.id, filter=filter, skip=skip, limit=limit, db=db
) )
return KnowledgeAccessListResponse( return KnowledgeAccessListResponse(
@ -75,7 +81,9 @@ async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified
**knowledge_base.model_dump(), **knowledge_base.model_dump(),
write_access=( write_access=(
user.id == knowledge_base.user_id user.id == knowledge_base.user_id
or has_access(user.id, "write", knowledge_base.access_control) or has_access(
user.id, "write", knowledge_base.access_control, db=db
)
), ),
) )
for knowledge_base in result.items for knowledge_base in result.items
@ -90,6 +98,7 @@ async def search_knowledge_bases(
view_option: Optional[str] = None, view_option: Optional[str] = None,
page: Optional[int] = 1, page: Optional[int] = 1,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
page = max(page, 1) page = max(page, 1)
limit = PAGE_ITEM_COUNT limit = PAGE_ITEM_COUNT
@ -102,14 +111,14 @@ async def search_knowledge_bases(
filter["view_option"] = view_option filter["view_option"] = view_option
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id) groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups: if groups:
filter["group_ids"] = [group.id for group in groups] filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id filter["user_id"] = user.id
result = Knowledges.search_knowledge_bases( result = Knowledges.search_knowledge_bases(
user.id, filter=filter, skip=skip, limit=limit user.id, filter=filter, skip=skip, limit=limit, db=db
) )
return KnowledgeAccessListResponse( return KnowledgeAccessListResponse(
@ -118,7 +127,9 @@ async def search_knowledge_bases(
**knowledge_base.model_dump(), **knowledge_base.model_dump(),
write_access=( write_access=(
user.id == knowledge_base.user_id user.id == knowledge_base.user_id
or has_access(user.id, "write", knowledge_base.access_control) or has_access(
user.id, "write", knowledge_base.access_control, db=db
)
), ),
) )
for knowledge_base in result.items for knowledge_base in result.items
@ -132,6 +143,7 @@ async def search_knowledge_files(
query: Optional[str] = None, query: Optional[str] = None,
page: Optional[int] = 1, page: Optional[int] = 1,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
page = max(page, 1) page = max(page, 1)
limit = PAGE_ITEM_COUNT limit = PAGE_ITEM_COUNT
@ -141,13 +153,15 @@ async def search_knowledge_files(
if query: if query:
filter["query"] = query filter["query"] = query
groups = Groups.get_groups_by_member_id(user.id) groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups: if groups:
filter["group_ids"] = [group.id for group in groups] filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id filter["user_id"] = user.id
return Knowledges.search_knowledge_files(filter=filter, skip=skip, limit=limit) return Knowledges.search_knowledge_files(
filter=filter, skip=skip, limit=limit, db=db
)
############################ ############################
@ -157,10 +171,13 @@ async def search_knowledge_files(
@router.post("/create", response_model=Optional[KnowledgeResponse]) @router.post("/create", response_model=Optional[KnowledgeResponse])
async def create_new_knowledge( async def create_new_knowledge(
request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user) request: Request,
form_data: KnowledgeForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS, db=db
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -175,11 +192,12 @@ async def create_new_knowledge(
user.id, user.id,
"sharing.public_knowledge", "sharing.public_knowledge",
request.app.state.config.USER_PERMISSIONS, request.app.state.config.USER_PERMISSIONS,
db=db,
) )
): ):
form_data.access_control = {} form_data.access_control = {}
knowledge = Knowledges.insert_new_knowledge(user.id, form_data) knowledge = Knowledges.insert_new_knowledge(user.id, form_data, db=db)
if knowledge: if knowledge:
return knowledge return knowledge
@ -196,20 +214,24 @@ async def create_new_knowledge(
@router.post("/reindex", response_model=bool) @router.post("/reindex", response_model=bool)
async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)): async def reindex_knowledge_files(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin": if user.role != "admin":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
knowledge_bases = Knowledges.get_knowledge_bases() knowledge_bases = Knowledges.get_knowledge_bases(db=db)
log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases") log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
for knowledge_base in knowledge_bases: for knowledge_base in knowledge_bases:
try: try:
files = Knowledges.get_files_by_id(knowledge_base.id) files = Knowledges.get_files_by_id(knowledge_base.id, db=db)
try: try:
if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id): if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
VECTOR_DB_CLIENT.delete_collection( VECTOR_DB_CLIENT.delete_collection(
@ -229,6 +251,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
file_id=file.id, collection_name=knowledge_base.id file_id=file.id, collection_name=knowledge_base.id
), ),
user=user, user=user,
db=db,
) )
except Exception as e: except Exception as e:
log.error( log.error(
@ -264,21 +287,23 @@ class KnowledgeFilesResponse(KnowledgeResponse):
@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse]) @router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): async def get_knowledge_by_id(
knowledge = Knowledges.get_knowledge_by_id(id=id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if knowledge: if knowledge:
if ( if (
user.role == "admin" user.role == "admin"
or knowledge.user_id == user.id or knowledge.user_id == user.id
or has_access(user.id, "read", knowledge.access_control) or has_access(user.id, "read", knowledge.access_control, db=db)
): ):
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
write_access=( write_access=(
user.id == knowledge.user_id user.id == knowledge.user_id
or has_access(user.id, "write", knowledge.access_control) or has_access(user.id, "write", knowledge.access_control, db=db)
), ),
) )
else: else:
@ -299,8 +324,9 @@ async def update_knowledge_by_id(
id: str, id: str,
form_data: KnowledgeForm, form_data: KnowledgeForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
knowledge = Knowledges.get_knowledge_by_id(id=id) knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge: if not knowledge:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -309,7 +335,7 @@ async def update_knowledge_by_id(
# Is the user the original creator, in a group with write access, or an admin # Is the user the original creator, in a group with write access, or an admin
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -325,15 +351,16 @@ async def update_knowledge_by_id(
user.id, user.id,
"sharing.public_knowledge", "sharing.public_knowledge",
request.app.state.config.USER_PERMISSIONS, request.app.state.config.USER_PERMISSIONS,
db=db,
) )
): ):
form_data.access_control = {} form_data.access_control = {}
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data, db=db)
if knowledge: if knowledge:
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id), files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
) )
else: else:
raise HTTPException( raise HTTPException(
@ -356,9 +383,10 @@ async def get_knowledge_files_by_id(
direction: Optional[str] = None, direction: Optional[str] = None,
page: Optional[int] = 1, page: Optional[int] = 1,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
knowledge = Knowledges.get_knowledge_by_id(id=id) knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge: if not knowledge:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -368,7 +396,7 @@ async def get_knowledge_files_by_id(
if not ( if not (
user.role == "admin" user.role == "admin"
or knowledge.user_id == user.id or knowledge.user_id == user.id
or has_access(user.id, "read", knowledge.access_control) or has_access(user.id, "read", knowledge.access_control, db=db)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -391,7 +419,7 @@ async def get_knowledge_files_by_id(
filter["direction"] = direction filter["direction"] = direction
return Knowledges.search_files_by_id( return Knowledges.search_files_by_id(
id, user.id, filter=filter, skip=skip, limit=limit id, user.id, filter=filter, skip=skip, limit=limit, db=db
) )
@ -410,8 +438,9 @@ def add_file_to_knowledge_by_id(
id: str, id: str,
form_data: KnowledgeFileIdForm, form_data: KnowledgeFileIdForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
knowledge = Knowledges.get_knowledge_by_id(id=id) knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge: if not knowledge:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -420,7 +449,7 @@ def add_file_to_knowledge_by_id(
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -428,7 +457,7 @@ def add_file_to_knowledge_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
file = Files.get_file_by_id(form_data.file_id) file = Files.get_file_by_id(form_data.file_id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -446,11 +475,12 @@ def add_file_to_knowledge_by_id(
request, request,
ProcessFileForm(file_id=form_data.file_id, collection_name=id), ProcessFileForm(file_id=form_data.file_id, collection_name=id),
user=user, user=user,
db=db,
) )
# Add file to knowledge base # Add file to knowledge base
Knowledges.add_file_to_knowledge_by_id( Knowledges.add_file_to_knowledge_by_id(
knowledge_id=id, file_id=form_data.file_id, user_id=user.id knowledge_id=id, file_id=form_data.file_id, user_id=user.id, db=db
) )
except Exception as e: except Exception as e:
log.debug(e) log.debug(e)
@ -462,7 +492,7 @@ def add_file_to_knowledge_by_id(
if knowledge: if knowledge:
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id), files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
) )
else: else:
raise HTTPException( raise HTTPException(
@ -477,8 +507,9 @@ def update_file_from_knowledge_by_id(
id: str, id: str,
form_data: KnowledgeFileIdForm, form_data: KnowledgeFileIdForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
knowledge = Knowledges.get_knowledge_by_id(id=id) knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge: if not knowledge:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -487,7 +518,7 @@ def update_file_from_knowledge_by_id(
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
@ -496,7 +527,7 @@ def update_file_from_knowledge_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
file = Files.get_file_by_id(form_data.file_id) file = Files.get_file_by_id(form_data.file_id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -514,6 +545,7 @@ def update_file_from_knowledge_by_id(
request, request,
ProcessFileForm(file_id=form_data.file_id, collection_name=id), ProcessFileForm(file_id=form_data.file_id, collection_name=id),
user=user, user=user,
db=db,
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@ -524,7 +556,7 @@ def update_file_from_knowledge_by_id(
if knowledge: if knowledge:
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id), files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
) )
else: else:
raise HTTPException( raise HTTPException(
@ -544,8 +576,9 @@ def remove_file_from_knowledge_by_id(
form_data: KnowledgeFileIdForm, form_data: KnowledgeFileIdForm,
delete_file: bool = Query(True), delete_file: bool = Query(True),
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
knowledge = Knowledges.get_knowledge_by_id(id=id) knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge: if not knowledge:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -554,7 +587,7 @@ def remove_file_from_knowledge_by_id(
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -562,7 +595,7 @@ def remove_file_from_knowledge_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
file = Files.get_file_by_id(form_data.file_id) file = Files.get_file_by_id(form_data.file_id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -570,7 +603,7 @@ def remove_file_from_knowledge_by_id(
) )
Knowledges.remove_file_from_knowledge_by_id( Knowledges.remove_file_from_knowledge_by_id(
knowledge_id=id, file_id=form_data.file_id knowledge_id=id, file_id=form_data.file_id, db=db
) )
# Remove content from the vector database # Remove content from the vector database
@ -599,12 +632,12 @@ def remove_file_from_knowledge_by_id(
pass pass
# Delete file from database # Delete file from database
Files.delete_file_by_id(form_data.file_id) Files.delete_file_by_id(form_data.file_id, db=db)
if knowledge: if knowledge:
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id), files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
) )
else: else:
raise HTTPException( raise HTTPException(
@ -619,8 +652,10 @@ def remove_file_from_knowledge_by_id(
@router.delete("/{id}/delete", response_model=bool) @router.delete("/{id}/delete", response_model=bool)
async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): async def delete_knowledge_by_id(
knowledge = Knowledges.get_knowledge_by_id(id=id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge: if not knowledge:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -629,7 +664,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -640,7 +675,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})") log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})")
# Get all models # Get all models
models = Models.get_all_models() models = Models.get_all_models(db=db)
log.info(f"Found {len(models)} models to check for knowledge base {id}") log.info(f"Found {len(models)} models to check for knowledge base {id}")
# Update models that reference this knowledge base # Update models that reference this knowledge base
@ -664,7 +699,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
access_control=model.access_control, access_control=model.access_control,
is_active=model.is_active, is_active=model.is_active,
) )
Models.update_model_by_id(model.id, model_form) Models.update_model_by_id(model.id, model_form, db=db)
# Clean up vector DB # Clean up vector DB
try: try:
@ -672,7 +707,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
except Exception as e: except Exception as e:
log.debug(e) log.debug(e)
pass pass
result = Knowledges.delete_knowledge_by_id(id=id) result = Knowledges.delete_knowledge_by_id(id=id, db=db)
return result return result
@ -682,8 +717,10 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse]) @router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): async def reset_knowledge_by_id(
knowledge = Knowledges.get_knowledge_by_id(id=id) id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge: if not knowledge:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -692,7 +729,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -706,7 +743,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
log.debug(e) log.debug(e)
pass pass
knowledge = Knowledges.reset_knowledge_by_id(id=id) knowledge = Knowledges.reset_knowledge_by_id(id=id, db=db)
return knowledge return knowledge
@ -721,11 +758,12 @@ async def add_files_to_knowledge_batch(
id: str, id: str,
form_data: list[KnowledgeFileIdForm], form_data: list[KnowledgeFileIdForm],
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
""" """
Add multiple files to a knowledge base Add multiple files to a knowledge base
""" """
knowledge = Knowledges.get_knowledge_by_id(id=id) knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge: if not knowledge:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -734,7 +772,7 @@ async def add_files_to_knowledge_batch(
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -746,7 +784,7 @@ async def add_files_to_knowledge_batch(
log.info(f"files/batch/add - {len(form_data)} files") log.info(f"files/batch/add - {len(form_data)} files")
files: List[FileModel] = [] files: List[FileModel] = []
for form in form_data: for form in form_data:
file = Files.get_file_by_id(form.file_id) file = Files.get_file_by_id(form.file_id, db=db)
if not file: if not file:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -760,6 +798,7 @@ async def add_files_to_knowledge_batch(
request=request, request=request,
form_data=BatchProcessFilesForm(files=files, collection_name=id), form_data=BatchProcessFilesForm(files=files, collection_name=id),
user=user, user=user,
db=db,
) )
except Exception as e: except Exception as e:
log.error( log.error(
@ -771,7 +810,7 @@ async def add_files_to_knowledge_batch(
successful_file_ids = [r.file_id for r in result.results if r.status == "completed"] successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
for file_id in successful_file_ids: for file_id in successful_file_ids:
Knowledges.add_file_to_knowledge_by_id( Knowledges.add_file_to_knowledge_by_id(
knowledge_id=id, file_id=file_id, user_id=user.id knowledge_id=id, file_id=file_id, user_id=user.id, db=db
) )
# If there were any errors, include them in the response # If there were any errors, include them in the response
@ -779,7 +818,7 @@ async def add_files_to_knowledge_batch(
error_details = [f"{err.file_id}: {err.error}" for err in result.errors] error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id), files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
warnings={ warnings={
"message": "Some files failed to process", "message": "Some files failed to process",
"errors": error_details, "errors": error_details,
@ -788,5 +827,5 @@ async def add_files_to_knowledge_batch(
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id), files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
) )

View file

@ -9,6 +9,10 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.utils.auth import get_verified_user from open_webui.utils.auth import get_verified_user
from open_webui.utils.auth import get_verified_user
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@ -25,8 +29,8 @@ async def get_embeddings(request: Request):
@router.get("/", response_model=list[MemoryModel]) @router.get("/", response_model=list[MemoryModel])
async def get_memories(user=Depends(get_verified_user)): async def get_memories(user=Depends(get_verified_user), db: Session = Depends(get_session)):
return Memories.get_memories_by_user_id(user.id) return Memories.get_memories_by_user_id(user.id, db=db)
############################ ############################
@ -47,8 +51,9 @@ async def add_memory(
request: Request, request: Request,
form_data: AddMemoryForm, form_data: AddMemoryForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
memory = Memories.insert_new_memory(user.id, form_data.content) memory = Memories.insert_new_memory(user.id, form_data.content, db=db)
vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user) vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
@ -79,9 +84,9 @@ class QueryMemoryForm(BaseModel):
@router.post("/query") @router.post("/query")
async def query_memory( async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user), db: Session = Depends(get_session)
): ):
memories = Memories.get_memories_by_user_id(user.id) memories = Memories.get_memories_by_user_id(user.id, db=db)
if not memories: if not memories:
raise HTTPException(status_code=404, detail="No memories found for user") raise HTTPException(status_code=404, detail="No memories found for user")
@ -101,11 +106,11 @@ async def query_memory(
############################ ############################
@router.post("/reset", response_model=bool) @router.post("/reset", response_model=bool)
async def reset_memory_from_vector_db( async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user) request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)
): ):
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(user.id) memories = Memories.get_memories_by_user_id(user.id, db=db)
# Generate vectors in parallel # Generate vectors in parallel
vectors = await asyncio.gather( vectors = await asyncio.gather(
@ -140,8 +145,8 @@ async def reset_memory_from_vector_db(
@router.delete("/delete/user", response_model=bool) @router.delete("/delete/user", response_model=bool)
async def delete_memory_by_user_id(user=Depends(get_verified_user)): async def delete_memory_by_user_id(user=Depends(get_verified_user), db: Session = Depends(get_session)):
result = Memories.delete_memories_by_user_id(user.id) result = Memories.delete_memories_by_user_id(user.id, db=db)
if result: if result:
try: try:
@ -164,9 +169,10 @@ async def update_memory_by_id(
request: Request, request: Request,
form_data: MemoryUpdateModel, form_data: MemoryUpdateModel,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
memory = Memories.update_memory_by_id_and_user_id( memory = Memories.update_memory_by_id_and_user_id(
memory_id, user.id, form_data.content memory_id, user.id, form_data.content, db=db
) )
if memory is None: if memory is None:
raise HTTPException(status_code=404, detail="Memory not found") raise HTTPException(status_code=404, detail="Memory not found")
@ -198,8 +204,8 @@ async def update_memory_by_id(
@router.delete("/{memory_id}", response_model=bool) @router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)): async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id) result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id, db=db)
if result: if result:
VECTOR_DB_CLIENT.delete( VECTOR_DB_CLIENT.delete(

View file

@ -30,6 +30,8 @@ from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission from open_webui.utils.access_control import has_access, has_permission
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -59,6 +61,7 @@ async def get_models(
direction: Optional[str] = None, direction: Optional[str] = None,
page: Optional[int] = 1, page: Optional[int] = 1,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
limit = PAGE_ITEM_COUNT limit = PAGE_ITEM_COUNT
@ -79,13 +82,13 @@ async def get_models(
filter["direction"] = direction filter["direction"] = direction
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id) groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups: if groups:
filter["group_ids"] = [group.id for group in groups] filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id filter["user_id"] = user.id
return Models.search_models(user.id, filter=filter, skip=skip, limit=limit) return Models.search_models(user.id, filter=filter, skip=skip, limit=limit, db=db)
########################### ###########################
@ -94,8 +97,8 @@ async def get_models(
@router.get("/base", response_model=list[ModelResponse]) @router.get("/base", response_model=list[ModelResponse])
async def get_base_models(user=Depends(get_admin_user)): async def get_base_models(user=Depends(get_admin_user), db: Session = Depends(get_session)):
return Models.get_base_models() return Models.get_base_models(db=db)
########################### ###########################
@ -104,11 +107,11 @@ async def get_base_models(user=Depends(get_admin_user)):
@router.get("/tags", response_model=list[str]) @router.get("/tags", response_model=list[str])
async def get_model_tags(user=Depends(get_verified_user)): async def get_model_tags(user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
models = Models.get_models() models = Models.get_models(db=db)
else: else:
models = Models.get_models_by_user_id(user.id) models = Models.get_models_by_user_id(user.id, db=db)
tags_set = set() tags_set = set()
for model in models: for model in models:
@ -132,16 +135,17 @@ async def create_new_model(
request: Request, request: Request,
form_data: ModelForm, form_data: ModelForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS, db=db
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
model = Models.get_model_by_id(form_data.id) model = Models.get_model_by_id(form_data.id, db=db)
if model: if model:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -155,7 +159,7 @@ async def create_new_model(
) )
else: else:
model = Models.insert_new_model(form_data, user.id) model = Models.insert_new_model(form_data, user.id, db=db)
if model: if model:
return model return model
else: else:
@ -171,9 +175,9 @@ async def create_new_model(
@router.get("/export", response_model=list[ModelModel]) @router.get("/export", response_model=list[ModelModel])
async def export_models(request: Request, user=Depends(get_verified_user)): async def export_models(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS, db=db
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -181,9 +185,9 @@ async def export_models(request: Request, user=Depends(get_verified_user)):
) )
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
return Models.get_models() return Models.get_models(db=db)
else: else:
return Models.get_models_by_user_id(user.id) return Models.get_models_by_user_id(user.id, db=db)
############################ ############################
@ -200,9 +204,10 @@ async def import_models(
request: Request, request: Request,
user=Depends(get_verified_user), user=Depends(get_verified_user),
form_data: ModelsImportForm = (...), form_data: ModelsImportForm = (...),
db: Session = Depends(get_session),
): ):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS, db=db
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -216,7 +221,7 @@ async def import_models(
model_id = model_data.get("id") model_id = model_data.get("id")
if model_id and is_valid_model_id(model_id): if model_id and is_valid_model_id(model_id):
existing_model = Models.get_model_by_id(model_id) existing_model = Models.get_model_by_id(model_id, db=db)
if existing_model: if existing_model:
# Update existing model # Update existing model
model_data["meta"] = model_data.get("meta", {}) model_data["meta"] = model_data.get("meta", {})
@ -225,13 +230,13 @@ async def import_models(
updated_model = ModelForm( updated_model = ModelForm(
**{**existing_model.model_dump(), **model_data} **{**existing_model.model_dump(), **model_data}
) )
Models.update_model_by_id(model_id, updated_model) Models.update_model_by_id(model_id, updated_model, db=db)
else: else:
# Insert new model # Insert new model
model_data["meta"] = model_data.get("meta", {}) model_data["meta"] = model_data.get("meta", {})
model_data["params"] = model_data.get("params", {}) model_data["params"] = model_data.get("params", {})
new_model = ModelForm(**model_data) new_model = ModelForm(**model_data)
Models.insert_new_model(user_id=user.id, form_data=new_model) Models.insert_new_model(user_id=user.id, form_data=new_model, db=db)
return True return True
else: else:
raise HTTPException(status_code=400, detail="Invalid JSON format") raise HTTPException(status_code=400, detail="Invalid JSON format")
@ -251,9 +256,9 @@ class SyncModelsForm(BaseModel):
@router.post("/sync", response_model=list[ModelModel]) @router.post("/sync", response_model=list[ModelModel])
async def sync_models( async def sync_models(
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user) request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
): ):
return Models.sync_models(user.id, form_data.models) return Models.sync_models(user.id, form_data.models, db=db)
########################### ###########################
@ -267,13 +272,13 @@ class ModelIdForm(BaseModel):
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
@router.get("/model", response_model=Optional[ModelResponse]) @router.get("/model", response_model=Optional[ModelResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user)): async def get_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
model = Models.get_model_by_id(id) model = Models.get_model_by_id(id, db=db)
if model: if model:
if ( if (
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or model.user_id == user.id or model.user_id == user.id
or has_access(user.id, "read", model.access_control) or has_access(user.id, "read", model.access_control, db=db)
): ):
return model return model
else: else:
@ -289,8 +294,8 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/model/profile/image") @router.get("/model/profile/image")
async def get_model_profile_image(id: str, user=Depends(get_verified_user)): async def get_model_profile_image(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
model = Models.get_model_by_id(id) model = Models.get_model_by_id(id, db=db)
# Cache-control headers to prevent stale cached images # Cache-control headers to prevent stale cached images
cache_headers = {"Cache-Control": "no-cache, must-revalidate"} cache_headers = {"Cache-Control": "no-cache, must-revalidate"}
@ -330,15 +335,15 @@ async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
@router.post("/model/toggle", response_model=Optional[ModelResponse]) @router.post("/model/toggle", response_model=Optional[ModelResponse])
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
model = Models.get_model_by_id(id) model = Models.get_model_by_id(id, db=db)
if model: if model:
if ( if (
user.role == "admin" user.role == "admin"
or model.user_id == user.id or model.user_id == user.id
or has_access(user.id, "write", model.access_control) or has_access(user.id, "write", model.access_control, db=db)
): ):
model = Models.toggle_model_by_id(id) model = Models.toggle_model_by_id(id, db=db)
if model: if model:
return model return model
@ -368,8 +373,9 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
async def update_model_by_id( async def update_model_by_id(
form_data: ModelForm, form_data: ModelForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
model = Models.get_model_by_id(form_data.id) model = Models.get_model_by_id(form_data.id, db=db)
if not model: if not model:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -378,7 +384,7 @@ async def update_model_by_id(
if ( if (
model.user_id != user.id model.user_id != user.id
and not has_access(user.id, "write", model.access_control) and not has_access(user.id, "write", model.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -386,7 +392,7 @@ async def update_model_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump())) model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()), db=db)
return model return model
@ -396,8 +402,8 @@ async def update_model_by_id(
@router.post("/model/delete", response_model=bool) @router.post("/model/delete", response_model=bool)
async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user)): async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user), db: Session = Depends(get_session)):
model = Models.get_model_by_id(form_data.id) model = Models.get_model_by_id(form_data.id, db=db)
if not model: if not model:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -407,18 +413,18 @@ async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_u
if ( if (
user.role != "admin" user.role != "admin"
and model.user_id != user.id and model.user_id != user.id
and not has_access(user.id, "write", model.access_control) and not has_access(user.id, "write", model.access_control, db=db)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
result = Models.delete_model_by_id(form_data.id) result = Models.delete_model_by_id(form_data.id, db=db)
return result return result
@router.delete("/delete/all", response_model=bool) @router.delete("/delete/all", response_model=bool)
async def delete_all_models(user=Depends(get_admin_user)): async def delete_all_models(user=Depends(get_admin_user), db: Session = Depends(get_session)):
result = Models.delete_all_models() result = Models.delete_all_models(db=db)
return result return result

View file

@ -28,6 +28,8 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission from open_webui.utils.access_control import has_access, has_permission
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -49,10 +51,13 @@ class NoteItemResponse(BaseModel):
@router.get("/", response_model=list[NoteItemResponse]) @router.get("/", response_model=list[NoteItemResponse])
async def get_notes( async def get_notes(
request: Request, page: Optional[int] = None, user=Depends(get_verified_user) request: Request,
page: Optional[int] = None,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -69,10 +74,14 @@ async def get_notes(
NoteUserResponse( NoteUserResponse(
**{ **{
**note.model_dump(), **note.model_dump(),
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()), "user": UserResponse(
**Users.get_user_by_id(note.user_id, db=db).model_dump()
),
} }
) )
for note in Notes.get_notes_by_user_id(user.id, "read", skip=skip, limit=limit) for note in Notes.get_notes_by_user_id(
user.id, "read", skip=skip, limit=limit, db=db
)
] ]
return notes return notes
@ -87,9 +96,10 @@ async def search_notes(
direction: Optional[str] = None, direction: Optional[str] = None,
page: Optional[int] = 1, page: Optional[int] = 1,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -115,13 +125,13 @@ async def search_notes(
filter["direction"] = direction filter["direction"] = direction
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id) groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups: if groups:
filter["group_ids"] = [group.id for group in groups] filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id filter["user_id"] = user.id
return Notes.search_notes(user.id, filter, skip=skip, limit=limit) return Notes.search_notes(user.id, filter, skip=skip, limit=limit, db=db)
############################ ############################
@ -131,10 +141,13 @@ async def search_notes(
@router.post("/create", response_model=Optional[NoteModel]) @router.post("/create", response_model=Optional[NoteModel])
async def create_new_note( async def create_new_note(
request: Request, form_data: NoteForm, user=Depends(get_verified_user) request: Request,
form_data: NoteForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -142,7 +155,7 @@ async def create_new_note(
) )
try: try:
note = Notes.insert_new_note(user.id, form_data) note = Notes.insert_new_note(user.id, form_data, db=db)
return note return note
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -161,16 +174,21 @@ class NoteResponse(NoteModel):
@router.get("/{id}", response_model=Optional[NoteResponse]) @router.get("/{id}", response_model=Optional[NoteResponse])
async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_user)): async def get_note_by_id(
request: Request,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
note = Notes.get_note_by_id(id) note = Notes.get_note_by_id(id, db=db)
if not note: if not note:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -178,7 +196,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
if user.role != "admin" and ( if user.role != "admin" and (
user.id != note.user_id user.id != note.user_id
and (not has_access(user.id, type="read", access_control=note.access_control)) and (
not has_access(
user.id, type="read", access_control=note.access_control, db=db
)
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -188,7 +210,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
user.role == "admin" user.role == "admin"
or (user.id == note.user_id) or (user.id == note.user_id)
or has_access( or has_access(
user.id, type="write", access_control=note.access_control, strict=False user.id,
type="write",
access_control=note.access_control,
strict=False,
db=db,
) )
) )
@ -202,17 +228,21 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
@router.post("/{id}/update", response_model=Optional[NoteModel]) @router.post("/{id}/update", response_model=Optional[NoteModel])
async def update_note_by_id( async def update_note_by_id(
request: Request, id: str, form_data: NoteForm, user=Depends(get_verified_user) request: Request,
id: str,
form_data: NoteForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
note = Notes.get_note_by_id(id) note = Notes.get_note_by_id(id, db=db)
if not note: if not note:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -220,7 +250,9 @@ async def update_note_by_id(
if user.role != "admin" and ( if user.role != "admin" and (
user.id != note.user_id user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control) and not has_access(
user.id, type="write", access_control=note.access_control, db=db
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -234,12 +266,13 @@ async def update_note_by_id(
user.id, user.id,
"sharing.public_notes", "sharing.public_notes",
request.app.state.config.USER_PERMISSIONS, request.app.state.config.USER_PERMISSIONS,
db=db,
) )
): ):
form_data.access_control = {} form_data.access_control = {}
try: try:
note = Notes.update_note_by_id(id, form_data) note = Notes.update_note_by_id(id, form_data, db=db)
await sio.emit( await sio.emit(
"note-events", "note-events",
note.model_dump(), note.model_dump(),
@ -260,16 +293,21 @@ async def update_note_by_id(
@router.delete("/{id}/delete", response_model=bool) @router.delete("/{id}/delete", response_model=bool)
async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified_user)): async def delete_note_by_id(
request: Request,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
note = Notes.get_note_by_id(id) note = Notes.get_note_by_id(id, db=db)
if not note: if not note:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -277,14 +315,16 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
if user.role != "admin" and ( if user.role != "admin" and (
user.id != note.user_id user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control) and not has_access(
user.id, type="write", access_control=note.access_control, db=db
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
) )
try: try:
note = Notes.delete_note_by_id(id) note = Notes.delete_note_by_id(id, db=db)
return True return True
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)

View file

@ -11,6 +11,8 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission from open_webui.utils.access_control import has_access, has_permission
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
router = APIRouter() router = APIRouter()
@ -20,21 +22,21 @@ router = APIRouter()
@router.get("/", response_model=list[PromptModel]) @router.get("/", response_model=list[PromptModel])
async def get_prompts(user=Depends(get_verified_user)): async def get_prompts(user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
prompts = Prompts.get_prompts() prompts = Prompts.get_prompts(db=db)
else: else:
prompts = Prompts.get_prompts_by_user_id(user.id, "read") prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db)
return prompts return prompts
@router.get("/list", response_model=list[PromptUserResponse]) @router.get("/list", response_model=list[PromptUserResponse])
async def get_prompt_list(user=Depends(get_verified_user)): async def get_prompt_list(user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
prompts = Prompts.get_prompts() prompts = Prompts.get_prompts(db=db)
else: else:
prompts = Prompts.get_prompts_by_user_id(user.id, "write") prompts = Prompts.get_prompts_by_user_id(user.id, "write", db=db)
return prompts return prompts
@ -46,16 +48,17 @@ async def get_prompt_list(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[PromptModel]) @router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt( async def create_new_prompt(
request: Request, form_data: PromptForm, user=Depends(get_verified_user) request: Request, form_data: PromptForm, user=Depends(get_verified_user), db: Session = Depends(get_session)
): ):
if user.role != "admin" and not ( if user.role != "admin" and not (
has_permission( has_permission(
user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS, db=db
) )
or has_permission( or has_permission(
user.id, user.id,
"workspace.prompts_import", "workspace.prompts_import",
request.app.state.config.USER_PERMISSIONS, request.app.state.config.USER_PERMISSIONS,
db=db,
) )
): ):
raise HTTPException( raise HTTPException(
@ -63,9 +66,9 @@ async def create_new_prompt(
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
prompt = Prompts.get_prompt_by_command(form_data.command) prompt = Prompts.get_prompt_by_command(form_data.command, db=db)
if prompt is None: if prompt is None:
prompt = Prompts.insert_new_prompt(user.id, form_data) prompt = Prompts.insert_new_prompt(user.id, form_data, db=db)
if prompt: if prompt:
return prompt return prompt
@ -85,14 +88,14 @@ async def create_new_prompt(
@router.get("/command/{command}", response_model=Optional[PromptModel]) @router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): async def get_prompt_by_command(command: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
prompt = Prompts.get_prompt_by_command(f"/{command}") prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
if prompt: if prompt:
if ( if (
user.role == "admin" user.role == "admin"
or prompt.user_id == user.id or prompt.user_id == user.id
or has_access(user.id, "read", prompt.access_control) or has_access(user.id, "read", prompt.access_control, db=db)
): ):
return prompt return prompt
else: else:
@ -112,8 +115,9 @@ async def update_prompt_by_command(
command: str, command: str,
form_data: PromptForm, form_data: PromptForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
prompt = Prompts.get_prompt_by_command(f"/{command}") prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
if not prompt: if not prompt:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -123,7 +127,7 @@ async def update_prompt_by_command(
# Is the user the original creator, in a group with write access, or an admin # Is the user the original creator, in a group with write access, or an admin
if ( if (
prompt.user_id != user.id prompt.user_id != user.id
and not has_access(user.id, "write", prompt.access_control) and not has_access(user.id, "write", prompt.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -131,7 +135,7 @@ async def update_prompt_by_command(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) prompt = Prompts.update_prompt_by_command(f"/{command}", form_data, db=db)
if prompt: if prompt:
return prompt return prompt
else: else:
@ -147,8 +151,8 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool) @router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)): async def delete_prompt_by_command(command: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
prompt = Prompts.get_prompt_by_command(f"/{command}") prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
if not prompt: if not prompt:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -157,7 +161,7 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
if ( if (
prompt.user_id != user.id prompt.user_id != user.id
and not has_access(user.id, "write", prompt.access_control) and not has_access(user.id, "write", prompt.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -165,5 +169,5 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
result = Prompts.delete_prompt_by_command(f"/{command}") result = Prompts.delete_prompt_by_command(f"/{command}", db=db)
return result return result

View file

@ -39,6 +39,8 @@ from langchain_core.documents import Document
from open_webui.models.files import FileModel, FileUpdateForm, Files from open_webui.models.files import FileModel, FileUpdateForm, Files
from open_webui.models.knowledge import Knowledges from open_webui.models.knowledge import Knowledges
from open_webui.storage.provider import Storage from open_webui.storage.provider import Storage
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
@ -1484,14 +1486,15 @@ def process_file(
request: Request, request: Request,
form_data: ProcessFileForm, form_data: ProcessFileForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
""" """
Process a file and save its content to the vector database. Process a file and save its content to the vector database.
""" """
if user.role == "admin": if user.role == "admin":
file = Files.get_file_by_id(form_data.file_id) file = Files.get_file_by_id(form_data.file_id, db=db)
else: else:
file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id) file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id, db=db)
if file: if file:
try: try:
@ -1633,12 +1636,13 @@ def process_file(
Files.update_file_data_by_id( Files.update_file_data_by_id(
file.id, file.id,
{"content": text_content}, {"content": text_content},
db=db,
) )
hash = calculate_sha256_string(text_content) hash = calculate_sha256_string(text_content)
Files.update_file_hash_by_id(file.id, hash) Files.update_file_hash_by_id(file.id, hash, db=db)
if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
Files.update_file_data_by_id(file.id, {"status": "completed"}) Files.update_file_data_by_id(file.id, {"status": "completed"}, db=db)
return { return {
"status": True, "status": True,
"collection_name": None, "collection_name": None,
@ -1667,11 +1671,13 @@ def process_file(
{ {
"collection_name": collection_name, "collection_name": collection_name,
}, },
db=db,
) )
Files.update_file_data_by_id( Files.update_file_data_by_id(
file.id, file.id,
{"status": "completed"}, {"status": "completed"},
db=db,
) )
return { return {
@ -1690,6 +1696,7 @@ def process_file(
Files.update_file_data_by_id( Files.update_file_data_by_id(
file.id, file.id,
{"status": "failed"}, {"status": "failed"},
db=db,
) )
if "No pandoc was found" in str(e): if "No pandoc was found" in str(e):
@ -2417,10 +2424,10 @@ class DeleteForm(BaseModel):
@router.post("/delete") @router.post("/delete")
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)): def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user), db: Session = Depends(get_session)):
try: try:
if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name): if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
file = Files.get_file_by_id(form_data.file_id) file = Files.get_file_by_id(form_data.file_id, db=db)
hash = file.hash hash = file.hash
VECTOR_DB_CLIENT.delete( VECTOR_DB_CLIENT.delete(
@ -2436,9 +2443,9 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin
@router.post("/reset/db") @router.post("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)): def reset_vector_db(user=Depends(get_admin_user), db: Session = Depends(get_session)):
VECTOR_DB_CLIENT.reset() VECTOR_DB_CLIENT.reset()
Knowledges.delete_all_knowledge() Knowledges.delete_all_knowledge(db=db)
@router.post("/reset/uploads") @router.post("/reset/uploads")
@ -2496,6 +2503,7 @@ async def process_files_batch(
request: Request, request: Request,
form_data: BatchProcessFilesForm, form_data: BatchProcessFilesForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
) -> BatchProcessFilesResponse: ) -> BatchProcessFilesResponse:
""" """
Process a batch of files and save them to the vector database. Process a batch of files and save them to the vector database.
@ -2558,7 +2566,7 @@ async def process_files_batch(
# Update all files with collection name # Update all files with collection name
for file_update, file_result in zip(file_updates, file_results): for file_update, file_result in zip(file_updates, file_results):
Files.update_file_by_id(id=file_result.file_id, form_data=file_update) Files.update_file_by_id(id=file_result.file_id, form_data=file_update, db=db)
file_result.status = "completed" file_result.status = "completed"
except Exception as e: except Exception as e:

View file

@ -7,6 +7,8 @@ import aiohttp
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
from pydantic import BaseModel, HttpUrl from pydantic import BaseModel, HttpUrl
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.orm import Session
from open_webui.internal.db import get_session
from open_webui.models.oauth_sessions import OAuthSessions from open_webui.models.oauth_sessions import OAuthSessions
@ -51,11 +53,11 @@ def get_tool_module(request, tool_id, load_from_db=True):
@router.get("/", response_model=list[ToolUserResponse]) @router.get("/", response_model=list[ToolUserResponse])
async def get_tools(request: Request, user=Depends(get_verified_user)): async def get_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
tools = [] tools = []
# Local Tools # Local Tools
for tool in Tools.get_tools(): for tool in Tools.get_tools(db=db):
tool_module = get_tool_module(request, tool.id) tool_module = get_tool_module(request, tool.id)
tools.append( tools.append(
ToolUserResponse( ToolUserResponse(
@ -140,12 +142,12 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
# Admin can see all tools # Admin can see all tools
return tools return tools
else: else:
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
tools = [ tools = [
tool tool
for tool in tools for tool in tools
if tool.user_id == user.id if tool.user_id == user.id
or has_access(user.id, "read", tool.access_control, user_group_ids) or has_access(user.id, "read", tool.access_control, user_group_ids, db=db)
] ]
return tools return tools
@ -156,11 +158,11 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
@router.get("/list", response_model=list[ToolUserResponse]) @router.get("/list", response_model=list[ToolUserResponse])
async def get_tool_list(user=Depends(get_verified_user)): async def get_tool_list(user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
tools = Tools.get_tools() tools = Tools.get_tools(db=db)
else: else:
tools = Tools.get_tools_by_user_id(user.id, "write") tools = Tools.get_tools_by_user_id(user.id, "write", db=db)
return tools return tools
@ -245,9 +247,9 @@ async def load_tool_from_url(
@router.get("/export", response_model=list[ToolModel]) @router.get("/export", response_model=list[ToolModel])
async def export_tools(request: Request, user=Depends(get_verified_user)): async def export_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS, db=db
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -255,9 +257,9 @@ async def export_tools(request: Request, user=Depends(get_verified_user)):
) )
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
return Tools.get_tools() return Tools.get_tools(db=db)
else: else:
return Tools.get_tools_by_user_id(user.id, "read") return Tools.get_tools_by_user_id(user.id, "read", db=db)
############################ ############################
@ -270,13 +272,14 @@ async def create_new_tools(
request: Request, request: Request,
form_data: ToolForm, form_data: ToolForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
if user.role != "admin" and not ( if user.role != "admin" and not (
has_permission( has_permission(
user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS, db=db
) )
or has_permission( or has_permission(
user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS, db=db
) )
): ):
raise HTTPException( raise HTTPException(
@ -292,7 +295,7 @@ async def create_new_tools(
form_data.id = form_data.id.lower() form_data.id = form_data.id.lower()
tools = Tools.get_tool_by_id(form_data.id) tools = Tools.get_tool_by_id(form_data.id, db=db)
if tools is None: if tools is None:
try: try:
form_data.content = replace_imports(form_data.content) form_data.content = replace_imports(form_data.content)
@ -305,7 +308,7 @@ async def create_new_tools(
TOOLS[form_data.id] = tool_module TOOLS[form_data.id] = tool_module
specs = get_tool_specs(TOOLS[form_data.id]) specs = get_tool_specs(TOOLS[form_data.id])
tools = Tools.insert_new_tool(user.id, form_data, specs) tools = Tools.insert_new_tool(user.id, form_data, specs, db=db)
tool_cache_dir = CACHE_DIR / "tools" / form_data.id tool_cache_dir = CACHE_DIR / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True) tool_cache_dir.mkdir(parents=True, exist_ok=True)
@ -336,14 +339,14 @@ async def create_new_tools(
@router.get("/id/{id}", response_model=Optional[ToolModel]) @router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_tools_by_id(id: str, user=Depends(get_verified_user)): async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
tools = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id, db=db)
if tools: if tools:
if ( if (
user.role == "admin" user.role == "admin"
or tools.user_id == user.id or tools.user_id == user.id
or has_access(user.id, "read", tools.access_control) or has_access(user.id, "read", tools.access_control, db=db)
): ):
return tools return tools
else: else:
@ -364,8 +367,9 @@ async def update_tools_by_id(
id: str, id: str,
form_data: ToolForm, form_data: ToolForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
tools = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id, db=db)
if not tools: if not tools:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -375,7 +379,7 @@ async def update_tools_by_id(
# Is the user the original creator, in a group with write access, or an admin # Is the user the original creator, in a group with write access, or an admin
if ( if (
tools.user_id != user.id tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control) and not has_access(user.id, "write", tools.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -399,7 +403,7 @@ async def update_tools_by_id(
} }
log.debug(updated) log.debug(updated)
tools = Tools.update_tool_by_id(id, updated) tools = Tools.update_tool_by_id(id, updated, db=db)
if tools: if tools:
return tools return tools
@ -423,9 +427,9 @@ async def update_tools_by_id(
@router.delete("/id/{id}/delete", response_model=bool) @router.delete("/id/{id}/delete", response_model=bool)
async def delete_tools_by_id( async def delete_tools_by_id(
request: Request, id: str, user=Depends(get_verified_user) request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
): ):
tools = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id, db=db)
if not tools: if not tools:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -434,7 +438,7 @@ async def delete_tools_by_id(
if ( if (
tools.user_id != user.id tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control) and not has_access(user.id, "write", tools.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -442,7 +446,7 @@ async def delete_tools_by_id(
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
result = Tools.delete_tool_by_id(id) result = Tools.delete_tool_by_id(id, db=db)
if result: if result:
TOOLS = request.app.state.TOOLS TOOLS = request.app.state.TOOLS
if id in TOOLS: if id in TOOLS:
@ -457,11 +461,11 @@ async def delete_tools_by_id(
@router.get("/id/{id}/valves", response_model=Optional[dict]) @router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
tools = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id, db=db)
if tools: if tools:
try: try:
valves = Tools.get_tool_valves_by_id(id) valves = Tools.get_tool_valves_by_id(id, db=db)
return valves return valves
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@ -482,9 +486,9 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/id/{id}/valves/spec", response_model=Optional[dict]) @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_tools_valves_spec_by_id( async def get_tools_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user) request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
): ):
tools = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id, db=db)
if tools: if tools:
if id in request.app.state.TOOLS: if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id] tools_module = request.app.state.TOOLS[id]
@ -510,9 +514,9 @@ async def get_tools_valves_spec_by_id(
@router.post("/id/{id}/valves/update", response_model=Optional[dict]) @router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_tools_valves_by_id( async def update_tools_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user) request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
): ):
tools = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id, db=db)
if not tools: if not tools:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -521,7 +525,7 @@ async def update_tools_valves_by_id(
if ( if (
tools.user_id != user.id tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control) and not has_access(user.id, "write", tools.access_control, db=db)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -546,7 +550,7 @@ async def update_tools_valves_by_id(
form_data = {k: v for k, v in form_data.items() if v is not None} form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data) valves = Valves(**form_data)
valves_dict = valves.model_dump(exclude_unset=True) valves_dict = valves.model_dump(exclude_unset=True)
Tools.update_tool_valves_by_id(id, valves_dict) Tools.update_tool_valves_by_id(id, valves_dict, db=db)
return valves_dict return valves_dict
except Exception as e: except Exception as e:
log.exception(f"Failed to update tool valves by id {id}: {e}") log.exception(f"Failed to update tool valves by id {id}: {e}")
@ -562,11 +566,11 @@ async def update_tools_valves_by_id(
@router.get("/id/{id}/valves/user", response_model=Optional[dict]) @router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
tools = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id, db=db)
if tools: if tools:
try: try:
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id) user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id, db=db)
return user_valves return user_valves
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@ -582,9 +586,9 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_tools_user_valves_spec_by_id( async def get_tools_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user) request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
): ):
tools = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id, db=db)
if tools: if tools:
if id in request.app.state.TOOLS: if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id] tools_module = request.app.state.TOOLS[id]
@ -605,9 +609,9 @@ async def get_tools_user_valves_spec_by_id(
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_tools_user_valves_by_id( async def update_tools_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user) request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
): ):
tools = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id, db=db)
if tools: if tools:
if id in request.app.state.TOOLS: if id in request.app.state.TOOLS:
@ -624,7 +628,7 @@ async def update_tools_user_valves_by_id(
user_valves = UserValves(**form_data) user_valves = UserValves(**form_data)
user_valves_dict = user_valves.model_dump(exclude_unset=True) user_valves_dict = user_valves.model_dump(exclude_unset=True)
Tools.update_user_valves_by_id_and_user_id( Tools.update_user_valves_by_id_and_user_id(
id, user.id, user_valves_dict id, user.id, user_valves_dict, db=db
) )
return user_valves_dict return user_valves_dict
except Exception as e: except Exception as e:

View file

@ -1,5 +1,6 @@
import logging import logging
from typing import Optional from typing import Optional
from sqlalchemy.orm import Session
import base64 import base64
import io import io
@ -29,6 +30,7 @@ from open_webui.models.users import (
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import STATIC_DIR from open_webui.env import STATIC_DIR
from open_webui.internal.db import get_session
from open_webui.utils.auth import ( from open_webui.utils.auth import (
@ -60,6 +62,7 @@ async def get_users(
direction: Optional[str] = None, direction: Optional[str] = None,
page: Optional[int] = 1, page: Optional[int] = 1,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db: Session = Depends(get_session),
): ):
limit = PAGE_ITEM_COUNT limit = PAGE_ITEM_COUNT
@ -74,7 +77,9 @@ async def get_users(
if direction: if direction:
filter["direction"] = direction filter["direction"] = direction
result = Users.get_users(filter=filter, skip=skip, limit=limit) filter["direction"] = direction
result = Users.get_users(filter=filter, skip=skip, limit=limit, db=db)
users = result["users"] users = result["users"]
total = result["total"] total = result["total"]
@ -85,7 +90,8 @@ async def get_users(
**{ **{
**user.model_dump(), **user.model_dump(),
"group_ids": [ "group_ids": [
group.id for group in Groups.get_groups_by_member_id(user.id) group.id
for group in Groups.get_groups_by_member_id(user.id, db=db)
], ],
} }
) )
@ -98,8 +104,9 @@ async def get_users(
@router.get("/all", response_model=UserInfoListResponse) @router.get("/all", response_model=UserInfoListResponse)
async def get_all_users( async def get_all_users(
user=Depends(get_admin_user), user=Depends(get_admin_user),
db: Session = Depends(get_session),
): ):
return Users.get_users() return Users.get_users(db=db)
@router.get("/search", response_model=UserInfoListResponse) @router.get("/search", response_model=UserInfoListResponse)
@ -109,16 +116,13 @@ async def search_users(
direction: Optional[str] = None, direction: Optional[str] = None,
page: Optional[int] = 1, page: Optional[int] = 1,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
limit = PAGE_ITEM_COUNT limit = PAGE_ITEM_COUNT
page = max(1, page) page = max(1, page)
skip = (page - 1) * limit skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
filter = {} filter = {}
if query: if query:
filter["query"] = query filter["query"] = query
@ -127,7 +131,7 @@ async def search_users(
if direction: if direction:
filter["direction"] = direction filter["direction"] = direction
return Users.get_users(filter=filter, skip=skip, limit=limit) return Users.get_users(filter=filter, skip=skip, limit=limit, db=db)
############################ ############################
@ -136,8 +140,10 @@ async def search_users(
@router.get("/groups") @router.get("/groups")
async def get_user_groups(user=Depends(get_verified_user)): async def get_user_groups(
return Groups.get_groups_by_member_id(user.id) user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return Groups.get_groups_by_member_id(user.id, db=db)
############################ ############################
@ -146,9 +152,13 @@ async def get_user_groups(user=Depends(get_verified_user)):
@router.get("/permissions") @router.get("/permissions")
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)): async def get_user_permissisions(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
user_permissions = get_permissions( user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS, db=db
) )
return user_permissions return user_permissions
@ -256,8 +266,10 @@ async def update_default_user_permissions(
@router.get("/user/settings", response_model=Optional[UserSettings]) @router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user(user=Depends(get_verified_user)): async def get_user_settings_by_session_user(
user = Users.get_user_by_id(user.id) user=Depends(get_verified_user), db: Session = Depends(get_session)
):
user = Users.get_user_by_id(user.id, db=db)
if user: if user:
return user.settings return user.settings
else: else:
@ -274,7 +286,10 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/settings/update", response_model=UserSettings) @router.post("/user/settings/update", response_model=UserSettings)
async def update_user_settings_by_session_user( async def update_user_settings_by_session_user(
request: Request, form_data: UserSettings, user=Depends(get_verified_user) request: Request,
form_data: UserSettings,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
updated_user_settings = form_data.model_dump() updated_user_settings = form_data.model_dump()
if ( if (
@ -289,7 +304,7 @@ async def update_user_settings_by_session_user(
# If the user is not an admin and does not have permission to use tool servers, remove the key # If the user is not an admin and does not have permission to use tool servers, remove the key
updated_user_settings["ui"].pop("toolServers", None) updated_user_settings["ui"].pop("toolServers", None)
user = Users.update_user_settings_by_id(user.id, updated_user_settings) user = Users.update_user_settings_by_id(user.id, updated_user_settings, db=db)
if user: if user:
return user.settings return user.settings
else: else:
@ -305,8 +320,10 @@ async def update_user_settings_by_session_user(
@router.get("/user/status") @router.get("/user/status")
async def get_user_status_by_session_user(user=Depends(get_verified_user)): async def get_user_status_by_session_user(
user = Users.get_user_by_id(user.id) user=Depends(get_verified_user), db: Session = Depends(get_session)
):
user = Users.get_user_by_id(user.id, db=db)
if user: if user:
return user return user
else: else:
@ -323,11 +340,13 @@ async def get_user_status_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/status/update") @router.post("/user/status/update")
async def update_user_status_by_session_user( async def update_user_status_by_session_user(
form_data: UserStatus, user=Depends(get_verified_user) form_data: UserStatus,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
): ):
user = Users.get_user_by_id(user.id) user = Users.get_user_by_id(user.id, db=db)
if user: if user:
user = Users.update_user_status_by_id(user.id, form_data) user = Users.update_user_status_by_id(user.id, form_data, db=db)
return user return user
else: else:
raise HTTPException( raise HTTPException(
@ -342,8 +361,10 @@ async def update_user_status_by_session_user(
@router.get("/user/info", response_model=Optional[dict]) @router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user(user=Depends(get_verified_user)): async def get_user_info_by_session_user(
user = Users.get_user_by_id(user.id) user=Depends(get_verified_user), db: Session = Depends(get_session)
):
user = Users.get_user_by_id(user.id, db=db)
if user: if user:
return user.info return user.info
else: else:
@ -360,14 +381,16 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/info/update", response_model=Optional[dict]) @router.post("/user/info/update", response_model=Optional[dict])
async def update_user_info_by_session_user( async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user) form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
): ):
user = Users.get_user_by_id(user.id) user = Users.get_user_by_id(user.id, db=db)
if user: if user:
if user.info is None: if user.info is None:
user.info = {} user.info = {}
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}}) user = Users.update_user_by_id(
user.id, {"info": {**user.info, **form_data}}, db=db
)
if user: if user:
return user.info return user.info
else: else:
@ -397,7 +420,9 @@ class UserActiveResponse(UserStatus):
@router.get("/{user_id}", response_model=UserActiveResponse) @router.get("/{user_id}", response_model=UserActiveResponse)
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): async def get_user_by_id(
user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
# Check if user_id is a shared chat # Check if user_id is a shared chat
# If it is, get the user_id from the chat # If it is, get the user_id from the chat
if user_id.startswith("shared-"): if user_id.startswith("shared-"):
@ -411,14 +436,14 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.USER_NOT_FOUND, detail=ERROR_MESSAGES.USER_NOT_FOUND,
) )
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id, db=db)
if user: if user:
groups = Groups.get_groups_by_member_id(user_id) groups = Groups.get_groups_by_member_id(user_id, db=db)
return UserActiveResponse( return UserActiveResponse(
**{ **{
**user.model_dump(), **user.model_dump(),
"groups": [{"id": group.id, "name": group.name} for group in groups], "groups": [{"id": group.id, "name": group.name} for group in groups],
"is_active": Users.is_user_active(user_id), "is_active": Users.is_user_active(user_id, db=db),
} }
) )
else: else:
@ -429,8 +454,10 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@router.get("/{user_id}/oauth/sessions") @router.get("/{user_id}/oauth/sessions")
async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user)): async def get_user_oauth_sessions_by_id(
sessions = OAuthSessions.get_sessions_by_user_id(user_id) user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
sessions = OAuthSessions.get_sessions_by_user_id(user_id, db=db)
if sessions and len(sessions) > 0: if sessions and len(sessions) > 0:
return sessions return sessions
else: else:
@ -446,8 +473,10 @@ async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_use
@router.get("/{user_id}/profile/image") @router.get("/{user_id}/profile/image")
async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)): async def get_user_profile_image_by_id(
user = Users.get_user_by_id(user_id) user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
user = Users.get_user_by_id(user_id, db=db)
if user: if user:
if user.profile_image_url: if user.profile_image_url:
# check if it's url or base64 # check if it's url or base64
@ -484,9 +513,11 @@ async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_u
@router.get("/{user_id}/active", response_model=dict) @router.get("/{user_id}/active", response_model=dict)
async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)): async def get_user_active_status_by_id(
user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return { return {
"active": Users.is_user_active(user_id), "active": Users.is_user_active(user_id, db=db),
} }
@ -500,10 +531,11 @@ async def update_user_by_id(
user_id: str, user_id: str,
form_data: UserUpdateForm, form_data: UserUpdateForm,
session_user=Depends(get_admin_user), session_user=Depends(get_admin_user),
db: Session = Depends(get_session),
): ):
# Prevent modification of the primary admin user by other admins # Prevent modification of the primary admin user by other admins
try: try:
first_user = Users.get_first_user() first_user = Users.get_first_user(db=db)
if first_user: if first_user:
if user_id == first_user.id: if user_id == first_user.id:
if session_user.id != user_id: if session_user.id != user_id:
@ -527,11 +559,11 @@ async def update_user_by_id(
detail="Could not verify primary admin status.", detail="Could not verify primary admin status.",
) )
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id, db=db)
if user: if user:
if form_data.email.lower() != user.email: if form_data.email.lower() != user.email:
email_user = Users.get_user_by_email(form_data.email.lower()) email_user = Users.get_user_by_email(form_data.email.lower(), db=db)
if email_user: if email_user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -545,9 +577,9 @@ async def update_user_by_id(
raise HTTPException(400, detail=str(e)) raise HTTPException(400, detail=str(e))
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
Auths.update_user_password_by_id(user_id, hashed) Auths.update_user_password_by_id(user_id, hashed, db=db)
Auths.update_email_by_id(user_id, form_data.email.lower()) Auths.update_email_by_id(user_id, form_data.email.lower(), db=db)
updated_user = Users.update_user_by_id( updated_user = Users.update_user_by_id(
user_id, user_id,
{ {
@ -556,6 +588,7 @@ async def update_user_by_id(
"email": form_data.email.lower(), "email": form_data.email.lower(),
"profile_image_url": form_data.profile_image_url, "profile_image_url": form_data.profile_image_url,
}, },
db=db,
) )
if updated_user: if updated_user:
@ -578,10 +611,12 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool) @router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): async def delete_user_by_id(
user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
# Prevent deletion of the primary admin user # Prevent deletion of the primary admin user
try: try:
first_user = Users.get_first_user() first_user = Users.get_first_user(db=db)
if first_user and user_id == first_user.id: if first_user and user_id == first_user.id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@ -595,7 +630,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
) )
if user.id != user_id: if user.id != user_id:
result = Auths.delete_auth_by_id(user_id) result = Auths.delete_auth_by_id(user_id, db=db)
if result: if result:
return True return True
@ -618,5 +653,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
@router.get("/{user_id}/groups") @router.get("/{user_id}/groups")
async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)): async def get_user_groups_by_id(
return Groups.get_groups_by_member_id(user_id) user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
return Groups.get_groups_by_member_id(user_id, db=db)

View file

@ -28,6 +28,7 @@ def fill_missing_permissions(
def get_permissions( def get_permissions(
user_id: str, user_id: str,
default_permissions: Dict[str, Any], default_permissions: Dict[str, Any],
db: Optional[Any] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Get all permissions for a user by combining the permissions of all groups the user is a member of. Get all permissions for a user by combining the permissions of all groups the user is a member of.
@ -53,7 +54,7 @@ def get_permissions(
) # Use the most permissive value (True > False) ) # Use the most permissive value (True > False)
return permissions return permissions
user_groups = Groups.get_groups_by_member_id(user_id) user_groups = Groups.get_groups_by_member_id(user_id, db=db)
# Deep copy default permissions to avoid modifying the original dict # Deep copy default permissions to avoid modifying the original dict
permissions = json.loads(json.dumps(default_permissions)) permissions = json.loads(json.dumps(default_permissions))
@ -72,6 +73,7 @@ def has_permission(
user_id: str, user_id: str,
permission_key: str, permission_key: str,
default_permissions: Dict[str, Any] = {}, default_permissions: Dict[str, Any] = {},
db: Optional[Any] = None,
) -> bool: ) -> bool:
""" """
Check if a user has a specific permission by checking the group permissions Check if a user has a specific permission by checking the group permissions
@ -92,7 +94,7 @@ def has_permission(
permission_hierarchy = permission_key.split(".") permission_hierarchy = permission_key.split(".")
# Retrieve user group permissions # Retrieve user group permissions
user_groups = Groups.get_groups_by_member_id(user_id) user_groups = Groups.get_groups_by_member_id(user_id, db=db)
for group in user_groups: for group in user_groups:
if get_permission(group.permissions or {}, permission_hierarchy): if get_permission(group.permissions or {}, permission_hierarchy):
@ -127,6 +129,7 @@ def has_access(
access_control: Optional[dict] = None, access_control: Optional[dict] = None,
user_group_ids: Optional[Set[str]] = None, user_group_ids: Optional[Set[str]] = None,
strict: bool = True, strict: bool = True,
db: Optional[Any] = None,
) -> bool: ) -> bool:
if access_control is None: if access_control is None:
if strict: if strict:
@ -135,7 +138,7 @@ def has_access(
return True return True
if user_group_ids is None: if user_group_ids is None:
user_groups = Groups.get_groups_by_member_id(user_id) user_groups = Groups.get_groups_by_member_id(user_id, db=db)
user_group_ids = {group.id for group in user_groups} user_group_ids = {group.id for group in user_groups}
permitted_ids = get_permitted_group_and_user_ids(type, access_control) permitted_ids = get_permitted_group_and_user_ids(type, access_control)
@ -152,10 +155,10 @@ def has_access(
# Get all users with access to a resource # Get all users with access to a resource
def get_users_with_access( def get_users_with_access(
type: str = "write", access_control: Optional[dict] = None type: str = "write", access_control: Optional[dict] = None, db: Optional[Any] = None
) -> list[UserModel]: ) -> list[UserModel]:
if access_control is None: if access_control is None:
result = Users.get_users(filter={"roles": ["!pending"]}) result = Users.get_users(filter={"roles": ["!pending"]}, db=db)
return result.get("users", []) return result.get("users", [])
permitted_ids = get_permitted_group_and_user_ids(type, access_control) permitted_ids = get_permitted_group_and_user_ids(type, access_control)
@ -167,8 +170,8 @@ def get_users_with_access(
user_ids_with_access = set(permitted_user_ids) user_ids_with_access = set(permitted_user_ids)
group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids) group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids, db=db)
for user_ids in group_user_ids_map.values(): for user_ids in group_user_ids_map.values():
user_ids_with_access.update(user_ids) user_ids_with_access.update(user_ids)
return Users.get_users_by_user_ids(list(user_ids_with_access)) return Users.get_users_by_user_ids(list(user_ids_with_access), db=db)

View file

@ -42,6 +42,8 @@ from open_webui.env import (
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from open_webui.internal.db import get_session
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -271,6 +273,7 @@ async def get_current_user(
response: Response, response: Response,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
db: Session = Depends(get_session),
): ):
token = None token = None
@ -285,7 +288,7 @@ async def get_current_user(
# auth by api key # auth by api key
if token.startswith("sk-"): if token.startswith("sk-"):
user = get_current_user_by_api_key(request, token) user = get_current_user_by_api_key(request, token, db=db)
# Add user info to current span # Add user info to current span
current_span = trace.get_current_span() current_span = trace.get_current_span()
@ -314,7 +317,7 @@ async def get_current_user(
detail="Invalid token", detail="Invalid token",
) )
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(data["id"], db=db)
if user is None: if user is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -364,8 +367,8 @@ async def get_current_user(
raise e raise e
def get_current_user_by_api_key(request, api_key: str): def get_current_user_by_api_key(request, api_key: str, db: Session = None):
user = Users.get_user_by_api_key(api_key) user = Users.get_user_by_api_key(api_key, db=db)
if user is None: if user is None:
raise HTTPException( raise HTTPException(
@ -393,7 +396,7 @@ def get_current_user_by_api_key(request, api_key: str):
current_span.set_attribute("client.user.role", user.role) current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key") current_span.set_attribute("client.auth.type", "api_key")
Users.update_last_active_by_id(user.id) Users.update_last_active_by_id(user.id, db=db)
return user return user

View file

@ -7,6 +7,7 @@ log = logging.getLogger(__name__)
def apply_default_group_assignment( def apply_default_group_assignment(
default_group_id: str, default_group_id: str,
user_id: str, user_id: str,
db=None,
) -> None: ) -> None:
""" """
Apply default group assignment to a user if default_group_id is provided. Apply default group assignment to a user if default_group_id is provided.
@ -17,7 +18,7 @@ def apply_default_group_assignment(
""" """
if default_group_id: if default_group_id:
try: try:
Groups.add_users_to_group(default_group_id, [user_id]) Groups.add_users_to_group(default_group_id, [user_id], db=db)
except Exception as e: except Exception as e:
log.error( log.error(
f"Failed to add user {user_id} to default group {default_group_id}: {e}" f"Failed to add user {user_id} to default group {default_group_id}: {e}"

View file

@ -1336,7 +1336,7 @@ class OAuthManager:
return await client.authorize_redirect(request, redirect_uri, **kwargs) return await client.authorize_redirect(request, redirect_uri, **kwargs)
async def handle_callback(self, request, provider, response): async def handle_callback(self, request, provider, response, db=None):
if provider not in OAUTH_PROVIDERS: if provider not in OAUTH_PROVIDERS:
raise HTTPException(404) raise HTTPException(404)
@ -1461,20 +1461,20 @@ class OAuthManager:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists # Check if the user exists
user = Users.get_user_by_oauth_sub(provider, sub) user = Users.get_user_by_oauth_sub(provider, sub, db=db)
if not user: if not user:
# If the user does not exist, check if merging is enabled # If the user does not exist, check if merging is enabled
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL: if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
# Check if the user exists by email # Check if the user exists by email
user = Users.get_user_by_email(email) user = Users.get_user_by_email(email, db=db)
if user: if user:
# Update the user with the new oauth sub # Update the user with the new oauth sub
Users.update_user_oauth_by_id(user.id, provider, sub) Users.update_user_oauth_by_id(user.id, provider, sub, db=db)
if user: if user:
determined_role = self.get_user_role(user, user_data) determined_role = self.get_user_role(user, user_data)
if user.role != determined_role: if user.role != determined_role:
Users.update_user_role_by_id(user.id, determined_role) Users.update_user_role_by_id(user.id, determined_role, db=db)
# Update the user object in memory as well, # Update the user object in memory as well,
# to avoid problems with the ENABLE_OAUTH_GROUP_MANAGEMENT check below # to avoid problems with the ENABLE_OAUTH_GROUP_MANAGEMENT check below
user.role = determined_role user.role = determined_role
@ -1491,14 +1491,14 @@ class OAuthManager:
) )
if processed_picture_url != user.profile_image_url: if processed_picture_url != user.profile_image_url:
Users.update_user_profile_image_url_by_id( Users.update_user_profile_image_url_by_id(
user.id, processed_picture_url user.id, processed_picture_url, db=db
) )
log.debug(f"Updated profile picture for user {user.email}") log.debug(f"Updated profile picture for user {user.email}")
else: else:
# If the user does not exist, check if signups are enabled # If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP: if auth_manager_config.ENABLE_OAUTH_SIGNUP:
# Check if an existing user with the same email already exists # Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(email) existing_user = Users.get_user_by_email(email, db=db)
if existing_user: if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@ -1529,6 +1529,7 @@ class OAuthManager:
profile_image_url=picture_url, profile_image_url=picture_url,
role=self.get_user_role(None, user_data), role=self.get_user_role(None, user_data),
oauth=oauth_data, oauth=oauth_data,
db=db,
) )
if auth_manager_config.WEBHOOK_URL: if auth_manager_config.WEBHOOK_URL:
@ -1544,8 +1545,7 @@ class OAuthManager:
) )
apply_default_group_assignment( apply_default_group_assignment(
request.app.state.config.DEFAULT_GROUP_ID, request.app.state.config.DEFAULT_GROUP_ID, user.id, db=db
user.id,
) )
else: else:
@ -1616,15 +1616,16 @@ class OAuthManager:
token["expires_at"] = datetime.now().timestamp() + token["expires_in"] token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
# Clean up any existing sessions for this user/provider first # Clean up any existing sessions for this user/provider first
sessions = OAuthSessions.get_sessions_by_user_id(user.id) sessions = OAuthSessions.get_sessions_by_user_id(user.id, db=db)
for session in sessions: for session in sessions:
if session.provider == provider: if session.provider == provider:
OAuthSessions.delete_session_by_id(session.id) OAuthSessions.delete_session_by_id(session.id, db=db)
session = OAuthSessions.create_session( session = OAuthSessions.create_session(
user_id=user.id, user_id=user.id,
provider=provider, provider=provider,
token=token, token=token,
db=db,
) )
response.set_cookie( response.set_cookie(