refac/enh: db session sharing

This commit is contained in:
Timothy Jaeryang Baek 2025-12-29 00:21:18 +04:00
parent 6dd0f99b90
commit b1d0f00d8c
23 changed files with 1173 additions and 663 deletions

View file

@ -19,7 +19,7 @@ from open_webui.env import (
from peewee_migrate import Router
from sqlalchemy import Dialect, create_engine, MetaData, event, types
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.orm import scoped_session, sessionmaker, Session
from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.sql.type_api import _T
from typing_extensions import Self
@ -148,7 +148,7 @@ SessionLocal = sessionmaker(
)
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
Base = declarative_base(metadata=metadata_obj)
Session = scoped_session(SessionLocal)
ScopedSession = scoped_session(SessionLocal)
def get_session():
@ -169,4 +169,3 @@ def get_db_context(db: Optional[Session] = None):
else:
with get_db() as session:
yield session

View file

@ -102,7 +102,9 @@ from open_webui.routers.retrieval import (
get_rf,
)
from open_webui.internal.db import Session, engine
from sqlalchemy.orm import Session
from open_webui.internal.db import ScopedSession, engine, get_session
from open_webui.models.functions import Functions
from open_webui.models.models import Models
@ -1324,7 +1326,7 @@ app.add_middleware(APIKeyRestrictionMiddleware)
async def commit_session_after_request(request: Request, call_next):
response = await call_next(request)
# log.debug("Commit session after request")
Session.commit()
ScopedSession.commit()
return response
@ -2280,8 +2282,13 @@ async def oauth_login(provider: str, request: Request):
# - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/login/callback")
@app.get("/oauth/{provider}/callback") # Legacy endpoint
async def oauth_login_callback(provider: str, request: Request, response: Response):
return await oauth_manager.handle_callback(request, provider, response)
async def oauth_login_callback(
provider: str,
request: Request,
response: Response,
db: Session = Depends(get_session),
):
return await oauth_manager.handle_callback(request, provider, response, db=db)
@app.get("/manifest.json")
@ -2340,7 +2347,7 @@ async def healthcheck():
@app.get("/health/db")
async def healthcheck_with_db():
Session.execute(text("SELECT 1;")).all()
ScopedSession.execute(text("SELECT 1;")).all()
return {"status": True}

View file

@ -30,27 +30,27 @@ from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.dialects import registry
class OpenGaussDialect(PGDialect_psycopg2):
name = "opengauss"
def _get_server_version_info(self, connection):
try:
version = connection.exec_driver_sql("SELECT version()").scalar()
if not version:
return (9, 0, 0)
match = re.search(
r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?",
version,
re.IGNORECASE
r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", version, re.IGNORECASE
)
if match:
return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
return super()._get_server_version_info(connection)
except Exception:
return (9, 0, 0)
# Register dialect
registry.register("opengauss", __name__, "OpenGaussDialect")
@ -78,6 +78,7 @@ Base = declarative_base()
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class DocumentChunk(Base):
__tablename__ = "document_chunk"
@ -87,29 +88,30 @@ class DocumentChunk(Base):
text = Column(Text, nullable=True)
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class OpenGaussClient(VectorDBBase):
def __init__(self) -> None:
if not OPENGAUSS_DB_URL:
from open_webui.internal.db import Session
self.session = Session
from open_webui.internal.db import ScopedSession
self.session = ScopedSession
else:
engine_kwargs = {
"pool_pre_ping": True,
"dialect": OpenGaussDialect()
}
engine_kwargs = {"pool_pre_ping": True, "dialect": OpenGaussDialect()}
if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0:
engine_kwargs.update({
"pool_size": OPENGAUSS_POOL_SIZE,
"max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
"pool_timeout": OPENGAUSS_POOL_TIMEOUT,
"pool_recycle": OPENGAUSS_POOL_RECYCLE,
"poolclass": QueuePool
})
engine_kwargs.update(
{
"pool_size": OPENGAUSS_POOL_SIZE,
"max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
"pool_timeout": OPENGAUSS_POOL_TIMEOUT,
"pool_recycle": OPENGAUSS_POOL_RECYCLE,
"poolclass": QueuePool,
}
)
else:
engine_kwargs["poolclass"] = NullPool
engine = create_engine(OPENGAUSS_DB_URL,** engine_kwargs)
engine = create_engine(OPENGAUSS_DB_URL, **engine_kwargs)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
@ -160,7 +162,9 @@ class OpenGaussClient(VectorDBBase):
else:
raise Exception("The 'vector' column type is not Vector.")
else:
raise Exception("The 'vector' column does not exist in the 'document_chunk' table.")
raise Exception(
"The 'vector' column does not exist in the 'document_chunk' table."
)
def adjust_vector_length(self, vector: List[float]) -> List[float]:
current_length = len(vector)
@ -185,7 +189,9 @@ class OpenGaussClient(VectorDBBase):
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(f"Inserting {len(new_items)} items into collection '{collection_name}'.")
log.info(
f"Inserting {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
log.exception(f"Failed to insert data: {e}")
@ -215,7 +221,9 @@ class OpenGaussClient(VectorDBBase):
)
self.session.add(new_chunk)
self.session.commit()
log.info(f"Inserting/updating {len(items)} items in collection '{collection_name}'.")
log.info(
f"Inserting/updating {len(items)} items in collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
log.exception(f"Failed to insert or update data.: {e}")
@ -241,7 +249,9 @@ class OpenGaussClient(VectorDBBase):
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
query_vectors = (
values(qid_col, q_vector_col)
.data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)])
.data(
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
)
.alias("query_vectors")
)
@ -249,13 +259,17 @@ class OpenGaussClient(VectorDBBase):
DocumentChunk.id,
DocumentChunk.text,
DocumentChunk.vmetadata,
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label("distance"),
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
"distance"
),
]
subq = (
select(*result_fields)
.where(DocumentChunk.collection_name == collection_name)
.order_by(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
.order_by(
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
)
)
if limit is not None:
subq = subq.limit(limit)
@ -368,7 +382,9 @@ class OpenGaussClient(VectorDBBase):
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
deleted = query.delete(synchronize_session=False)
self.session.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'")
@ -395,7 +411,8 @@ class OpenGaussClient(VectorDBBase):
exists = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.collection_name == collection_name)
.first() is not None
.first()
is not None
)
self.session.rollback()
return exists
@ -406,4 +423,4 @@ class OpenGaussClient(VectorDBBase):
def delete_collection(self, collection_name: str) -> None:
self.delete(collection_name)
log.info(f"Collection '{collection_name}' has been deleted")
log.info(f"Collection '{collection_name}' has been deleted")

View file

@ -90,9 +90,9 @@ class PgvectorClient(VectorDBBase):
# if no pgvector uri, use the existing database connection
if not PGVECTOR_DB_URL:
from open_webui.internal.db import Session
from open_webui.internal.db import ScopedSession
self.session = Session
self.session = ScopedSession
else:
if isinstance(PGVECTOR_POOL_SIZE, int):
if PGVECTOR_POOL_SIZE > 0:

View file

@ -62,6 +62,8 @@ from open_webui.utils.auth import (
get_password_hash,
get_http_authorization_cred,
)
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
from open_webui.utils.webhook import post_webhook
from open_webui.utils.access_control import get_permissions, has_permission
from open_webui.utils.groups import apply_default_group_assignment
@ -103,7 +105,10 @@ class SessionUserInfoResponse(SessionUserResponse, UserStatus):
@router.get("/", response_model=SessionUserInfoResponse)
async def get_session_user(
request: Request, response: Response, user=Depends(get_current_user)
request: Request,
response: Response,
user=Depends(get_current_user),
db: Session = Depends(get_session),
):
auth_header = request.headers.get("Authorization")
@ -137,7 +142,7 @@ async def get_session_user(
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
user.id, request.app.state.config.USER_PERMISSIONS, db=db
)
return {
@ -166,12 +171,15 @@ async def get_session_user(
@router.post("/update/profile", response_model=UserProfileImageResponse)
async def update_profile(
form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
form_data: UpdateProfileForm,
session_user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if session_user:
user = Users.update_user_by_id(
session_user.id,
form_data.model_dump(),
db=db,
)
if user:
return user
@ -188,13 +196,17 @@ async def update_profile(
@router.post("/update/password", response_model=bool)
async def update_password(
form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
form_data: UpdatePasswordForm,
session_user=Depends(get_current_user),
db: Session = Depends(get_session),
):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
if session_user:
user = Auths.authenticate_user(
session_user.email, lambda pw: verify_password(form_data.password, pw)
session_user.email,
lambda pw: verify_password(form_data.password, pw),
db=db,
)
if user:
@ -203,7 +215,7 @@ async def update_password(
except Exception as e:
raise HTTPException(400, detail=str(e))
hashed = get_password_hash(form_data.new_password)
return Auths.update_user_password_by_id(user.id, hashed)
return Auths.update_user_password_by_id(user.id, hashed, db=db)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD)
else:
@ -214,7 +226,12 @@ async def update_password(
# LDAP Authentication
############################
@router.post("/ldap", response_model=SessionUserResponse)
async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
async def ldap_auth(
request: Request,
response: Response,
form_data: LdapForm,
db: Session = Depends(get_session),
):
# Security checks FIRST - before loading any config
if not request.app.state.config.ENABLE_LDAP:
raise HTTPException(400, detail="LDAP authentication is not enabled")
@ -400,12 +417,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
if not connection_user.bind():
raise HTTPException(400, "Authentication failed.")
user = Users.get_user_by_email(email)
user = Users.get_user_by_email(email, db=db)
if not user:
try:
role = (
"admin"
if not Users.has_users()
if not Users.has_users(db=db)
else request.app.state.config.DEFAULT_USER_ROLE
)
@ -414,6 +431,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
password=str(uuid.uuid4()),
name=cn,
role=role,
db=db,
)
if not user:
@ -424,6 +442,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
apply_default_group_assignment(
request.app.state.config.DEFAULT_GROUP_ID,
user.id,
db=db,
)
except HTTPException:
@ -434,7 +453,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
500, detail="Internal error occurred during LDAP user creation."
)
user = Auths.authenticate_user_by_email(email)
user = Auths.authenticate_user_by_email(email, db=db)
if user:
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
@ -464,7 +483,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
user.id, request.app.state.config.USER_PERMISSIONS, db=db
)
if (
@ -473,9 +492,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
and user_groups
):
if ENABLE_LDAP_GROUP_CREATION:
Groups.create_groups_by_group_names(user.id, user_groups)
Groups.create_groups_by_group_names(user.id, user_groups, db=db)
try:
Groups.sync_groups_by_group_names(user.id, user_groups)
Groups.sync_groups_by_group_names(user.id, user_groups, db=db)
log.info(
f"Successfully synced groups for user {user.id}: {user_groups}"
)
@ -508,7 +527,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
@router.post("/signin", response_model=SessionUserResponse)
async def signin(request: Request, response: Response, form_data: SigninForm):
async def signin(
request: Request,
response: Response,
form_data: SigninForm,
db: Session = Depends(get_session),
):
if not ENABLE_PASSWORD_AUTH:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@ -529,14 +553,15 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
except Exception as e:
pass
if not Users.get_user_by_email(email.lower()):
if not Users.get_user_by_email(email.lower(), db=db):
await signup(
request,
response,
SignupForm(email=email, password=str(uuid.uuid4()), name=name),
db=db,
)
user = Auths.authenticate_user_by_email(email)
user = Auths.authenticate_user_by_email(email, db=db)
if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
group_names = request.headers.get(
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
@ -544,28 +569,33 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
group_names = [name.strip() for name in group_names if name.strip()]
if group_names:
Groups.sync_groups_by_group_names(user.id, group_names)
Groups.sync_groups_by_group_names(user.id, group_names, db=db)
elif WEBUI_AUTH == False:
admin_email = "admin@localhost"
admin_password = "admin"
if Users.get_user_by_email(admin_email.lower()):
if Users.get_user_by_email(admin_email.lower(), db=db):
user = Auths.authenticate_user(
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
admin_email.lower(),
lambda pw: verify_password(admin_password, pw),
db=db,
)
else:
if Users.has_users():
if Users.has_users(db=db):
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
await signup(
request,
response,
SignupForm(email=admin_email, password=admin_password, name="User"),
db=db,
)
user = Auths.authenticate_user(
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
admin_email.lower(),
lambda pw: verify_password(admin_password, pw),
db=db,
)
else:
if signin_rate_limiter.is_limited(form_data.email.lower()):
@ -584,7 +614,9 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
form_data.password = password_bytes.decode("utf-8", errors="ignore")
user = Auths.authenticate_user(
form_data.email.lower(), lambda pw: verify_password(form_data.password, pw)
form_data.email.lower(),
lambda pw: verify_password(form_data.password, pw),
db=db,
)
if user:
@ -616,7 +648,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
user.id, request.app.state.config.USER_PERMISSIONS, db=db
)
return {
@ -640,8 +672,13 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm):
has_users = Users.has_users()
async def signup(
request: Request,
response: Response,
form_data: SignupForm,
db: Session = Depends(get_session),
):
has_users = Users.has_users(db=db)
if WEBUI_AUTH:
if (
@ -663,7 +700,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
)
if Users.get_user_by_email(form_data.email.lower()):
if Users.get_user_by_email(form_data.email.lower(), db=db):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try:
@ -681,6 +718,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
form_data.name,
form_data.profile_image_url,
role,
db=db,
)
if user:
@ -723,7 +761,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
user.id, request.app.state.config.USER_PERMISSIONS, db=db
)
if not has_users:
@ -733,6 +771,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
apply_default_group_assignment(
request.app.state.config.DEFAULT_GROUP_ID,
user.id,
db=db,
)
return {
@ -754,7 +793,9 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.get("/signout")
async def signout(request: Request, response: Response):
async def signout(
request: Request, response: Response, db: Session = Depends(get_session)
):
# get auth token from headers or cookies
token = None
@ -776,7 +817,7 @@ async def signout(request: Request, response: Response):
if oauth_session_id:
response.delete_cookie("oauth_session_id")
session = OAuthSessions.get_session_by_id(oauth_session_id)
session = OAuthSessions.get_session_by_id(oauth_session_id, db=db)
oauth_server_metadata_url = (
request.app.state.oauth_manager.get_server_metadata_url(session.provider)
if session
@ -839,14 +880,17 @@ async def signout(request: Request, response: Response):
@router.post("/add", response_model=SigninResponse)
async def add_user(
request: Request, form_data: AddUserForm, user=Depends(get_admin_user)
request: Request,
form_data: AddUserForm,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
if not validate_email_format(form_data.email.lower()):
raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
)
if Users.get_user_by_email(form_data.email.lower()):
if Users.get_user_by_email(form_data.email.lower(), db=db):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try:
@ -862,12 +906,14 @@ async def add_user(
form_data.name,
form_data.profile_image_url,
form_data.role,
db=db,
)
if user:
apply_default_group_assignment(
request.app.state.config.DEFAULT_GROUP_ID,
user.id,
db=db,
)
token = create_token(data={"id": user.id})
@ -895,7 +941,9 @@ async def add_user(
@router.get("/admin/details")
async def get_admin_details(request: Request, user=Depends(get_current_user)):
async def get_admin_details(
request: Request, user=Depends(get_current_user), db: Session = Depends(get_session)
):
if request.app.state.config.SHOW_ADMIN_DETAILS:
admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None
@ -903,11 +951,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
if admin_email:
admin = Users.get_user_by_email(admin_email)
admin = Users.get_user_by_email(admin_email, db=db)
if admin:
admin_name = admin.name
else:
admin = Users.get_first_user()
admin = Users.get_first_user(db=db)
if admin:
admin_email = admin.email
admin_name = admin.name
@ -1149,7 +1197,9 @@ async def update_ldap_config(
# create api key
@router.post("/api_key", response_model=ApiKey)
async def generate_api_key(request: Request, user=Depends(get_current_user)):
async def generate_api_key(
request: Request, user=Depends(get_current_user), db: Session = Depends(get_session)
):
if not request.app.state.config.ENABLE_API_KEYS or not has_permission(
user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS
):
@ -1159,7 +1209,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
)
api_key = create_api_key()
success = Users.update_user_api_key_by_id(user.id, api_key)
success = Users.update_user_api_key_by_id(user.id, api_key, db=db)
if success:
return {
@ -1171,14 +1221,18 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
# delete api key
@router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user)):
return Users.delete_user_api_key_by_id(user.id)
async def delete_api_key(
user=Depends(get_current_user), db: Session = Depends(get_session)
):
return Users.delete_user_api_key_by_id(user.id, db=db)
# get api key
@router.get("/api_key", response_model=ApiKey)
async def get_api_key(user=Depends(get_current_user)):
api_key = Users.get_user_api_key_by_id(user.id)
async def get_api_key(
user=Depends(get_current_user), db: Session = Depends(get_session)
):
api_key = Users.get_user_api_key_by_id(user.id, db=db)
if api_key:
return {
"api_key": api_key,

View file

@ -1,6 +1,7 @@
import json
import logging
from typing import Optional
from sqlalchemy.orm import Session
import asyncio
from fastapi.responses import StreamingResponse
@ -23,6 +24,7 @@ from open_webui.models.chats import (
)
from open_webui.models.tags import TagModel, Tags
from open_webui.models.folders import Folders
from open_webui.internal.db import get_session
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
@ -49,6 +51,7 @@ def get_session_user_chat_list(
page: Optional[int] = None,
include_pinned: Optional[bool] = False,
include_folders: Optional[bool] = False,
db: Session = Depends(get_session),
):
try:
if page is not None:
@ -61,10 +64,14 @@ def get_session_user_chat_list(
include_pinned=include_pinned,
skip=skip,
limit=limit,
db=db,
)
else:
return Chats.get_chat_title_id_list_by_user_id(
user.id, include_folders=include_folders, include_pinned=include_pinned
user.id,
include_folders=include_folders,
include_pinned=include_pinned,
db=db,
)
except Exception as e:
log.exception(e)
@ -84,12 +91,13 @@ def get_session_user_chat_usage_stats(
items_per_page: Optional[int] = 50,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
try:
limit = items_per_page
skip = (page - 1) * limit
result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit)
result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit, db=db)
chats = result.items
total = result.total
@ -216,6 +224,7 @@ class ChatStatsExportList(BaseModel):
def _process_chat_for_export(chat) -> Optional[ChatStatsExport]:
try:
def get_message_content_length(message):
content = message.get("content", "")
if isinstance(content, str):
@ -348,7 +357,9 @@ def _process_chat_for_export(chat) -> Optional[ChatStatsExport]:
return None
def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
def calculate_chat_stats(
user_id, skip=0, limit=10, filter=None, db: Optional[Session] = None
):
if filter is None:
filter = {}
@ -357,6 +368,7 @@ def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
skip=skip,
limit=limit,
filter=filter,
db=db,
)
chat_stats_export_list = []
@ -368,14 +380,21 @@ def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
return chat_stats_export_list, result.total
async def generate_chat_stats_jsonl_generator(user_id, filter):
async def generate_chat_stats_jsonl_generator(
user_id, filter, db: Optional[Session] = None
):
skip = 0
limit = CHAT_EXPORT_PAGE_ITEM_COUNT
while True:
# Use asyncio.to_thread to make the blocking DB call non-blocking
result = await asyncio.to_thread(
Chats.get_chats_by_user_id, user_id, filter=filter, skip=skip, limit=limit
Chats.get_chats_by_user_id,
user_id,
filter=filter,
skip=skip,
limit=limit,
db=db,
)
if not result.items:
break
@ -386,7 +405,7 @@ async def generate_chat_stats_jsonl_generator(user_id, filter):
if chat_stat:
yield chat_stat.model_dump_json() + "\n"
except Exception as e:
log.exception(f"Error processing chat {chat.id}: {e}")
log.exception(f"Error processing chat {chat.id}: {e}")
skip += limit
@ -400,6 +419,7 @@ async def export_chat_stats(
page: Optional[int] = 1,
stream: bool = False,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
# Check if the user has permission to share/export chats
if (user.role != "admin") and (
@ -415,7 +435,7 @@ async def export_chat_stats(
filter = {"order_by": "created_at", "direction": "asc"}
if chat_id:
chat = Chats.get_chat_by_id(chat_id)
chat = Chats.get_chat_by_id(chat_id, db=db)
if chat:
filter["start_time"] = chat.created_at
@ -426,7 +446,7 @@ async def export_chat_stats(
if stream:
return StreamingResponse(
generate_chat_stats_jsonl_generator(user.id, filter),
generate_chat_stats_jsonl_generator(user.id, filter, db=db),
media_type="application/x-ndjson",
headers={
"Content-Disposition": f"attachment; filename=chat-stats-export-{user.id}.jsonl"
@ -437,7 +457,7 @@ async def export_chat_stats(
skip = (page - 1) * limit
chat_stats_export_list, total = await asyncio.to_thread(
calculate_chat_stats, user.id, skip, limit, filter
calculate_chat_stats, user.id, skip, limit, filter, db=db
)
return ChatStatsExportList(
@ -452,7 +472,11 @@ async def export_chat_stats(
@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
async def delete_all_user_chats(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role == "user" and not has_permission(
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
@ -462,7 +486,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Chats.delete_chats_by_user_id(user.id)
result = Chats.delete_chats_by_user_id(user.id, db=db)
return result
@ -479,6 +503,7 @@ async def get_user_chat_list_by_user_id(
order_by: Optional[str] = None,
direction: Optional[str] = None,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
if not ENABLE_ADMIN_CHAT_ACCESS:
raise HTTPException(
@ -501,7 +526,7 @@ async def get_user_chat_list_by_user_id(
filter["direction"] = direction
return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, filter=filter, skip=skip, limit=limit
user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db
)
@ -511,9 +536,13 @@ async def get_user_chat_list_by_user_id(
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
async def create_new_chat(
form_data: ChatForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
try:
chat = Chats.insert_new_chat(user.id, form_data)
chat = Chats.insert_new_chat(user.id, form_data, db=db)
return ChatResponse(**chat.model_dump())
except Exception as e:
log.exception(e)
@ -528,9 +557,13 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
@router.post("/import", response_model=list[ChatResponse])
async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_user)):
async def import_chats(
form_data: ChatsImportForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
try:
chats = Chats.import_chats(user.id, form_data.chats)
chats = Chats.import_chats(user.id, form_data.chats, db=db)
return chats
except Exception as e:
log.exception(e)
@ -546,7 +579,10 @@ async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_use
@router.get("/search", response_model=list[ChatTitleIdResponse])
def search_user_chats(
text: str, page: Optional[int] = None, user=Depends(get_verified_user)
text: str,
page: Optional[int] = None,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if page is None:
page = 1
@ -557,7 +593,7 @@ def search_user_chats(
chat_list = [
ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id_and_search_text(
user.id, text, skip=skip, limit=limit
user.id, text, skip=skip, limit=limit, db=db
)
]
@ -566,9 +602,9 @@ def search_user_chats(
if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
tag_id = words[0].replace("tag:", "")
if len(chat_list) == 0:
if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
if Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db):
log.debug(f"deleting tag: {tag_id}")
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db)
return chat_list
@ -579,23 +615,30 @@ def search_user_chats(
@router.get("/folder/{folder_id}", response_model=list[ChatResponse])
async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)):
async def get_chats_by_folder_id(
folder_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
folder_ids = [folder_id]
children_folders = Folders.get_children_folders_by_id_and_user_id(
folder_id, user.id
folder_id, user.id, db=db
)
if children_folders:
folder_ids.extend([folder.id for folder in children_folders])
return [
ChatResponse(**chat.model_dump())
for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id)
for chat in Chats.get_chats_by_folder_ids_and_user_id(
folder_ids, user.id, db=db
)
]
@router.get("/folder/{folder_id}/list")
async def get_chat_list_by_folder_id(
folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user)
folder_id: str,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
try:
limit = 10
@ -604,7 +647,7 @@ async def get_chat_list_by_folder_id(
return [
{"title": chat.title, "id": chat.id, "updated_at": chat.updated_at}
for chat in Chats.get_chats_by_folder_id_and_user_id(
folder_id, user.id, skip=skip, limit=limit
folder_id, user.id, skip=skip, limit=limit, db=db
)
]
@ -621,10 +664,12 @@ async def get_chat_list_by_folder_id(
@router.get("/pinned", response_model=list[ChatTitleIdResponse])
async def get_user_pinned_chats(user=Depends(get_verified_user)):
async def get_user_pinned_chats(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return [
ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_pinned_chats_by_user_id(user.id)
for chat in Chats.get_pinned_chats_by_user_id(user.id, db=db)
]
@ -634,10 +679,12 @@ async def get_user_pinned_chats(user=Depends(get_verified_user)):
@router.get("/all", response_model=list[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)):
async def get_user_chats(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return [
ChatResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id(user.id)
for chat in Chats.get_chats_by_user_id(user.id, db=db)
]
@ -647,10 +694,12 @@ async def get_user_chats(user=Depends(get_verified_user)):
@router.get("/all/archived", response_model=list[ChatResponse])
async def get_user_archived_chats(user=Depends(get_verified_user)):
async def get_user_archived_chats(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return [
ChatResponse(**chat.model_dump())
for chat in Chats.get_archived_chats_by_user_id(user.id)
for chat in Chats.get_archived_chats_by_user_id(user.id, db=db)
]
@ -660,9 +709,11 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
@router.get("/all/tags", response_model=list[TagModel])
async def get_all_user_tags(user=Depends(get_verified_user)):
async def get_all_user_tags(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
try:
tags = Tags.get_tags_by_user_id(user.id)
tags = Tags.get_tags_by_user_id(user.id, db=db)
return tags
except Exception as e:
log.exception(e)
@ -677,13 +728,15 @@ async def get_all_user_tags(user=Depends(get_verified_user)):
@router.get("/all/db", response_model=list[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
async def get_all_user_chats_in_db(
user=Depends(get_admin_user), db: Session = Depends(get_session)
):
if not ENABLE_ADMIN_EXPORT:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats(db=db)]
############################
@ -698,6 +751,7 @@ async def get_archived_session_user_chat_list(
order_by: Optional[str] = None,
direction: Optional[str] = None,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if page is None:
page = 1
@ -720,6 +774,7 @@ async def get_archived_session_user_chat_list(
filter=filter,
skip=skip,
limit=limit,
db=db,
)
]
@ -732,8 +787,10 @@ async def get_archived_session_user_chat_list(
@router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_verified_user)):
return Chats.archive_all_chats_by_user_id(user.id)
async def archive_all_chats(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return Chats.archive_all_chats_by_user_id(user.id, db=db)
############################
@ -742,8 +799,10 @@ async def archive_all_chats(user=Depends(get_verified_user)):
@router.post("/unarchive/all", response_model=bool)
async def unarchive_all_chats(user=Depends(get_verified_user)):
return Chats.unarchive_all_chats_by_user_id(user.id)
async def unarchive_all_chats(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return Chats.unarchive_all_chats_by_user_id(user.id, db=db)
############################
@ -752,16 +811,18 @@ async def unarchive_all_chats(user=Depends(get_verified_user)):
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
async def get_shared_chat_by_id(
share_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS):
chat = Chats.get_chat_by_share_id(share_id)
chat = Chats.get_chat_by_share_id(share_id, db=db)
elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
chat = Chats.get_chat_by_id(share_id)
chat = Chats.get_chat_by_id(share_id, db=db)
if chat:
return ChatResponse(**chat.model_dump())
@ -788,13 +849,15 @@ class TagFilterForm(TagForm):
@router.post("/tags", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
form_data: TagFilterForm, user=Depends(get_verified_user)
form_data: TagFilterForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
chats = Chats.get_chat_list_by_user_id_and_tag_name(
user.id, form_data.name, form_data.skip, form_data.limit
user.id, form_data.name, form_data.skip, form_data.limit, db=db
)
if len(chats) == 0:
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db)
return chats
@ -805,8 +868,10 @@ async def get_user_chat_list_by_tag_name(
@router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def get_chat_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
return ChatResponse(**chat.model_dump())
@ -824,12 +889,15 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(
id: str, form_data: ChatForm, user=Depends(get_verified_user)
id: str,
form_data: ChatForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
updated_chat = {**chat.chat, **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat)
chat = Chats.update_chat_by_id(id, updated_chat, db=db)
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
@ -847,9 +915,13 @@ class MessageForm(BaseModel):
@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
async def update_chat_message_by_id(
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
id: str,
message_id: str,
form_data: MessageForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
chat = Chats.get_chat_by_id(id)
chat = Chats.get_chat_by_id(id, db=db)
if not chat:
raise HTTPException(
@ -869,6 +941,7 @@ async def update_chat_message_by_id(
{
"content": form_data.content,
},
db=db,
)
event_emitter = get_event_emitter(
@ -905,9 +978,13 @@ class EventForm(BaseModel):
@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
async def send_chat_message_event_by_id(
id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
id: str,
message_id: str,
form_data: EventForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
chat = Chats.get_chat_by_id(id)
chat = Chats.get_chat_by_id(id, db=db)
if not chat:
raise HTTPException(
@ -945,14 +1022,19 @@ async def send_chat_message_event_by_id(
@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
async def delete_chat_by_id(
request: Request,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role == "admin":
chat = Chats.get_chat_by_id(id)
chat = Chats.get_chat_by_id(id, db=db)
for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
result = Chats.delete_chat_by_id(id)
result = Chats.delete_chat_by_id(id, db=db)
return result
else:
@ -964,12 +1046,12 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
chat = Chats.get_chat_by_id(id)
chat = Chats.get_chat_by_id(id, db=db)
for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
result = Chats.delete_chat_by_id_and_user_id(id, user.id, db=db)
return result
@ -979,8 +1061,10 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
@router.get("/{id}/pinned", response_model=Optional[bool])
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def get_pinned_status_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
return chat.pinned
else:
@ -995,10 +1079,12 @@ async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/pin", response_model=Optional[ChatResponse])
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def pin_chat_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
chat = Chats.toggle_chat_pinned_by_id(id)
chat = Chats.toggle_chat_pinned_by_id(id, db=db)
return chat
else:
raise HTTPException(
@ -1017,9 +1103,12 @@ class CloneForm(BaseModel):
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(
form_data: CloneForm, id: str, user=Depends(get_verified_user)
form_data: CloneForm,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
updated_chat = {
**chat.chat,
@ -1040,6 +1129,7 @@ async def clone_chat_by_id(
}
)
],
db=db,
)
if chats:
@ -1062,12 +1152,14 @@ async def clone_chat_by_id(
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
async def clone_shared_chat_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
if user.role == "admin":
chat = Chats.get_chat_by_id(id)
chat = Chats.get_chat_by_id(id, db=db)
else:
chat = Chats.get_chat_by_share_id(id)
chat = Chats.get_chat_by_share_id(id, db=db)
if chat:
updated_chat = {
@ -1089,6 +1181,7 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
}
)
],
db=db,
)
if chats:
@ -1111,23 +1204,28 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def archive_chat_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
chat = Chats.toggle_chat_archive_by_id(id)
chat = Chats.toggle_chat_archive_by_id(id, db=db)
# Delete tags if chat is archived
if chat.archived:
for tag_id in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
if (
Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id, db=db)
== 0
):
log.debug(f"deleting tag: {tag_id}")
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db)
else:
for tag_id in chat.meta.get("tags", []):
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db)
if tag is None:
log.debug(f"inserting tag: {tag_id}")
tag = Tags.insert_new_tag(tag_id, user.id)
tag = Tags.insert_new_tag(tag_id, user.id, db=db)
return ChatResponse(**chat.model_dump())
else:
@ -1142,7 +1240,12 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
async def share_chat_by_id(
request: Request,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if (user.role != "admin") and (
not has_permission(
user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
@ -1153,14 +1256,14 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id, db=db)
return ChatResponse(**shared_chat.model_dump())
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id, db=db)
if not shared_chat:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -1181,14 +1284,16 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_
@router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def delete_shared_chat_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
if not chat.share_id:
return False
result = Chats.delete_shared_chat_by_chat_id(id)
update_result = Chats.update_chat_share_id_by_id(id, None)
result = Chats.delete_shared_chat_by_chat_id(id, db=db)
update_result = Chats.update_chat_share_id_by_id(id, None, db=db)
return result and update_result != None
else:
@ -1209,12 +1314,15 @@ class ChatFolderIdForm(BaseModel):
@router.post("/{id}/folder", response_model=Optional[ChatResponse])
async def update_chat_folder_id_by_id(
id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
id: str,
form_data: ChatFolderIdForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
chat = Chats.update_chat_folder_id_by_id_and_user_id(
id, user.id, form_data.folder_id
id, user.id, form_data.folder_id, db=db
)
return ChatResponse(**chat.model_dump())
else:
@ -1229,11 +1337,13 @@ async def update_chat_folder_id_by_id(
@router.get("/{id}/tags", response_model=list[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def get_chat_tags_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -1247,9 +1357,12 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/tags", response_model=list[TagModel])
async def add_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user)
id: str,
form_data: TagForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
tags = chat.meta.get("tags", [])
tag_id = form_data.name.replace(" ", "_").lower()
@ -1262,12 +1375,12 @@ async def add_tag_by_id_and_tag_name(
if tag_id not in tags:
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
id, user.id, form_data.name
id, user.id, form_data.name, db=db
)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@ -1281,18 +1394,26 @@ async def add_tag_by_id_and_tag_name(
@router.delete("/{id}/tags", response_model=list[TagModel])
async def delete_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user)
id: str,
form_data: TagForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
Chats.delete_tag_by_id_and_user_id_and_tag_name(
id, user.id, form_data.name, db=db
)
if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
if (
Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id, db=db)
== 0
):
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -1305,14 +1426,16 @@ async def delete_tag_by_id_and_tag_name(
@router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def delete_all_tags_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
if chat:
Chats.delete_all_tags_by_id_and_user_id(id, user.id)
Chats.delete_all_tags_by_id_and_user_id(id, user.id, db=db)
for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
return True
else:

View file

@ -15,6 +15,8 @@ from open_webui.models.feedbacks import (
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
router = APIRouter()
@ -60,38 +62,50 @@ async def update_config(
@router.get("/feedbacks/all", response_model=list[FeedbackResponse])
async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks()
async def get_all_feedbacks(
user=Depends(get_admin_user), db: Session = Depends(get_session)
):
feedbacks = Feedbacks.get_all_feedbacks(db=db)
return feedbacks
@router.get("/feedbacks/all/ids", response_model=list[FeedbackIdResponse])
async def get_all_feedback_ids(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks()
async def get_all_feedback_ids(
user=Depends(get_admin_user), db: Session = Depends(get_session)
):
feedbacks = Feedbacks.get_all_feedbacks(db=db)
return feedbacks
@router.delete("/feedbacks/all")
async def delete_all_feedbacks(user=Depends(get_admin_user)):
success = Feedbacks.delete_all_feedbacks()
async def delete_all_feedbacks(
user=Depends(get_admin_user), db: Session = Depends(get_session)
):
success = Feedbacks.delete_all_feedbacks(db=db)
return success
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
async def export_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks()
async def export_all_feedbacks(
user=Depends(get_admin_user), db: Session = Depends(get_session)
):
feedbacks = Feedbacks.get_all_feedbacks(db=db)
return feedbacks
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
async def get_feedbacks(user=Depends(get_verified_user)):
feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id)
async def get_feedbacks(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id, db=db)
return feedbacks
@router.delete("/feedbacks", response_model=bool)
async def delete_feedbacks(user=Depends(get_verified_user)):
success = Feedbacks.delete_feedbacks_by_user_id(user.id)
async def delete_feedbacks(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
success = Feedbacks.delete_feedbacks_by_user_id(user.id, db=db)
return success
@ -104,6 +118,7 @@ async def get_feedbacks(
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
limit = PAGE_ITEM_COUNT
@ -116,7 +131,7 @@ async def get_feedbacks(
if direction:
filter["direction"] = direction
result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit)
result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit, db=db)
return result
@ -125,8 +140,11 @@ async def create_feedback(
request: Request,
form_data: FeedbackForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data)
feedback = Feedbacks.insert_new_feedback(
user_id=user.id, form_data=form_data, db=db
)
if not feedback:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -137,11 +155,15 @@ async def create_feedback(
@router.get("/feedback/{id}", response_model=FeedbackModel)
async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
async def get_feedback_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
if user.role == "admin":
feedback = Feedbacks.get_feedback_by_id(id=id)
feedback = Feedbacks.get_feedback_by_id(id=id, db=db)
else:
feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id)
feedback = Feedbacks.get_feedback_by_id_and_user_id(
id=id, user_id=user.id, db=db
)
if not feedback:
raise HTTPException(
@ -153,13 +175,16 @@ async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/feedback/{id}", response_model=FeedbackModel)
async def update_feedback_by_id(
id: str, form_data: FeedbackForm, user=Depends(get_verified_user)
id: str,
form_data: FeedbackForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role == "admin":
feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data)
feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data, db=db)
else:
feedback = Feedbacks.update_feedback_by_id_and_user_id(
id=id, user_id=user.id, form_data=form_data
id=id, user_id=user.id, form_data=form_data, db=db
)
if not feedback:
@ -171,11 +196,15 @@ async def update_feedback_by_id(
@router.delete("/feedback/{id}")
async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)):
async def delete_feedback_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
if user.role == "admin":
success = Feedbacks.delete_feedback_by_id(id=id)
success = Feedbacks.delete_feedback_by_id(id=id, db=db)
else:
success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id)
success = Feedbacks.delete_feedback_by_id_and_user_id(
id=id, user_id=user.id, db=db
)
if not success:
raise HTTPException(

View file

@ -22,6 +22,8 @@ from fastapi import (
)
from fastapi.responses import FileResponse, StreamingResponse
from sqlalchemy.orm import Session
from open_webui.internal.db import get_session, SessionLocal
from open_webui.constants import ERROR_MESSAGES
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
@ -62,9 +64,12 @@ router = APIRouter()
# TODO: Optimize this function to use the knowledge_file table for faster lookups.
def has_access_to_file(
file_id: Optional[str], access_type: str, user=Depends(get_verified_user)
file_id: Optional[str],
access_type: str,
user=Depends(get_verified_user),
db: Optional[Session] = None,
) -> bool:
file = Files.get_file_by_id(file_id)
file = Files.get_file_by_id(file_id, db=db)
log.debug(f"Checking if user has {access_type} access to file")
if not file:
raise HTTPException(
@ -73,31 +78,33 @@ def has_access_to_file(
)
# Check if the file is associated with any knowledge bases the user has access to
knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id)
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id, db=db)
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
}
for knowledge_base in knowledge_bases:
if knowledge_base.user_id == user.id or has_access(
user.id, access_type, knowledge_base.access_control, user_group_ids
user.id, access_type, knowledge_base.access_control, user_group_ids, db=db
):
return True
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
if knowledge_base_id:
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
user.id, access_type
user.id, access_type, db=db
)
for knowledge_base in knowledge_bases:
if knowledge_base.id == knowledge_base_id:
return True
# Check if the file is associated with any channels the user has access to
channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id)
channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id, db=db)
if access_type == "read" and channels:
return True
# Check if the file is associated with any chats the user has access to
# TODO: Granular access control for chats
chats = Chats.get_shared_chats_by_file_id(file_id)
chats = Chats.get_shared_chats_by_file_id(file_id, db=db)
if chats:
return True
@ -109,47 +116,78 @@ def has_access_to_file(
############################
def process_uploaded_file(request, file, file_path, file_item, file_metadata, user):
try:
if file.content_type:
stt_supported_content_types = getattr(
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
)
def process_uploaded_file(
request,
file,
file_path,
file_item,
file_metadata,
user,
db: Optional[Session] = None,
):
def _process_handler(db_session):
try:
if file.content_type:
stt_supported_content_types = getattr(
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
)
if strict_match_mime_type(stt_supported_content_types, file.content_type):
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path, file_metadata, user)
if strict_match_mime_type(
stt_supported_content_types, file.content_type
):
file_path_processed = Storage.get_file(file_path)
result = transcribe(
request, file_path_processed, file_metadata, user
)
process_file(
request,
ProcessFileForm(
file_id=file_item.id, content=result.get("text", "")
),
user=user,
db=db_session,
)
elif (not file.content_type.startswith(("image/", "video/"))) or (
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
):
process_file(
request,
ProcessFileForm(file_id=file_item.id),
user=user,
db=db_session,
)
else:
raise Exception(
f"File type {file.content_type} is not supported for processing"
)
else:
log.info(
f"File type {file.content_type} is not provided, but trying to process anyway"
)
process_file(
request,
ProcessFileForm(
file_id=file_item.id, content=result.get("text", "")
),
ProcessFileForm(file_id=file_item.id),
user=user,
db=db_session,
)
elif (not file.content_type.startswith(("image/", "video/"))) or (
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
):
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
else:
raise Exception(
f"File type {file.content_type} is not supported for processing"
)
else:
log.info(
f"File type {file.content_type} is not provided, but trying to process anyway"
)
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
except Exception as e:
log.error(f"Error processing file: {file_item.id}")
Files.update_file_data_by_id(
file_item.id,
{
"status": "failed",
"error": str(e.detail) if hasattr(e, "detail") else str(e),
},
)
except Exception as e:
log.error(f"Error processing file: {file_item.id}")
Files.update_file_data_by_id(
file_item.id,
{
"status": "failed",
"error": str(e.detail) if hasattr(e, "detail") else str(e),
},
db=db_session,
)
if db:
_process_handler(db)
else:
with SessionLocal() as db_session:
_process_handler(db_session)
@router.post("/", response_model=FileModelResponse)
@ -161,6 +199,7 @@ def upload_file(
process: bool = Query(True),
process_in_background: bool = Query(True),
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
return upload_file_handler(
request,
@ -170,6 +209,7 @@ def upload_file(
process_in_background=process_in_background,
user=user,
background_tasks=background_tasks,
db=db,
)
@ -181,6 +221,7 @@ def upload_file_handler(
process_in_background: bool = Query(True),
user=Depends(get_verified_user),
background_tasks: Optional[BackgroundTasks] = None,
db: Optional[Session] = None,
):
log.info(f"file.content_type: {file.content_type} {process}")
@ -248,14 +289,17 @@ def upload_file_handler(
},
}
),
db=db,
)
if "channel_id" in file_metadata:
channel = Channels.get_channel_by_id_and_user_id(
file_metadata["channel_id"], user.id
file_metadata["channel_id"], user.id, db=db
)
if channel:
Channels.add_file_to_channel_by_id(channel.id, file_item.id, user.id)
Channels.add_file_to_channel_by_id(
channel.id, file_item.id, user.id, db=db
)
if process:
if background_tasks and process_in_background:
@ -277,6 +321,7 @@ def upload_file_handler(
file_item,
file_metadata,
user,
db=db,
)
return {"status": True, **file_item.model_dump()}
else:
@ -302,11 +347,15 @@ def upload_file_handler(
@router.get("/", response_model=list[FileModelResponse])
async def list_files(user=Depends(get_verified_user), content: bool = Query(True)):
async def list_files(
user=Depends(get_verified_user),
content: bool = Query(True),
db: Session = Depends(get_session),
):
if user.role == "admin":
files = Files.get_files()
files = Files.get_files(db=db)
else:
files = Files.get_files_by_user_id(user.id)
files = Files.get_files_by_user_id(user.id, db=db)
if not content:
for file in files:
@ -329,15 +378,16 @@ async def search_files(
),
content: bool = Query(True),
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
"""
Search for files by filename with support for wildcard patterns.
"""
# Get files according to user role
if user.role == "admin":
files = Files.get_files()
files = Files.get_files(db=db)
else:
files = Files.get_files_by_user_id(user.id)
files = Files.get_files_by_user_id(user.id, db=db)
# Get matching files
matching_files = [
@ -364,8 +414,10 @@ async def search_files(
@router.delete("/all")
async def delete_all_files(user=Depends(get_admin_user)):
result = Files.delete_all_files()
async def delete_all_files(
user=Depends(get_admin_user), db: Session = Depends(get_session)
):
result = Files.delete_all_files(db=db)
if result:
try:
Storage.delete_all_files()
@ -391,8 +443,10 @@ async def delete_all_files(user=Depends(get_admin_user)):
@router.get("/{id}", response_model=Optional[FileModel])
async def get_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
async def get_file_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
file = Files.get_file_by_id(id, db=db)
if not file:
raise HTTPException(
@ -403,7 +457,7 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
or has_access_to_file(id, "read", user, db=db)
):
return file
else:
@ -415,9 +469,12 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/process/status")
async def get_file_process_status(
id: str, stream: bool = Query(False), user=Depends(get_verified_user)
id: str,
stream: bool = Query(False),
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
file = Files.get_file_by_id(id)
file = Files.get_file_by_id(id, db=db)
if not file:
raise HTTPException(
@ -428,7 +485,7 @@ async def get_file_process_status(
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
or has_access_to_file(id, "read", user, db=db)
):
if stream:
MAX_FILE_PROCESSING_DURATION = 3600 * 2
@ -436,7 +493,7 @@ async def get_file_process_status(
async def event_stream(file_item):
if file_item:
for _ in range(MAX_FILE_PROCESSING_DURATION):
file_item = Files.get_file_by_id(file_item.id)
file_item = Files.get_file_by_id(file_item.id, db=db)
if file_item:
data = file_item.model_dump().get("data", {})
status = data.get("status")
@ -476,8 +533,10 @@ async def get_file_process_status(
@router.get("/{id}/data/content")
async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
async def get_file_data_content_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
file = Files.get_file_by_id(id, db=db)
if not file:
raise HTTPException(
@ -488,7 +547,7 @@ async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
or has_access_to_file(id, "read", user, db=db)
):
return {"content": file.data.get("content", "")}
else:
@ -509,9 +568,13 @@ class ContentForm(BaseModel):
@router.post("/{id}/data/content/update")
async def update_file_data_content_by_id(
request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user)
request: Request,
id: str,
form_data: ContentForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
file = Files.get_file_by_id(id)
file = Files.get_file_by_id(id, db=db)
if not file:
raise HTTPException(
@ -522,7 +585,7 @@ async def update_file_data_content_by_id(
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "write", user)
or has_access_to_file(id, "write", user, db=db)
):
try:
process_file(
@ -530,7 +593,7 @@ async def update_file_data_content_by_id(
ProcessFileForm(file_id=id, content=form_data.content),
user=user,
)
file = Files.get_file_by_id(id=id)
file = Files.get_file_by_id(id=id, db=db)
except Exception as e:
log.exception(e)
log.error(f"Error processing file: {file.id}")
@ -550,9 +613,12 @@ async def update_file_data_content_by_id(
@router.get("/{id}/content")
async def get_file_content_by_id(
id: str, user=Depends(get_verified_user), attachment: bool = Query(False)
id: str,
user=Depends(get_verified_user),
attachment: bool = Query(False),
db: Session = Depends(get_session),
):
file = Files.get_file_by_id(id)
file = Files.get_file_by_id(id, db=db)
if not file:
raise HTTPException(
@ -563,7 +629,7 @@ async def get_file_content_by_id(
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
or has_access_to_file(id, "read", user, db=db)
):
try:
file_path = Storage.get_file(file.path)
@ -619,8 +685,10 @@ async def get_file_content_by_id(
@router.get("/{id}/content/html")
async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
async def get_html_file_content_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
file = Files.get_file_by_id(id, db=db)
if not file:
raise HTTPException(
@ -628,7 +696,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.NOT_FOUND,
)
file_user = Users.get_user_by_id(file.user_id)
file_user = Users.get_user_by_id(file.user_id, db=db)
if not file_user.role == "admin":
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -638,7 +706,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
or has_access_to_file(id, "read", user, db=db)
):
try:
file_path = Storage.get_file(file.path)
@ -668,8 +736,10 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/content/{file_name}")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
async def get_file_content_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
file = Files.get_file_by_id(id, db=db)
if not file:
raise HTTPException(
@ -680,7 +750,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
or has_access_to_file(id, "read", user, db=db)
):
file_path = file.path
@ -730,8 +800,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.delete("/{id}")
async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
async def delete_file_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
file = Files.get_file_by_id(id, db=db)
if not file:
raise HTTPException(
@ -742,10 +814,10 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "write", user)
or has_access_to_file(id, "write", user, db=db)
):
result = Files.delete_file_by_id(id)
result = Files.delete_file_by_id(id, db=db)
if result:
try:
Storage.delete_file(file.path)

View file

@ -22,6 +22,8 @@ from open_webui.models.knowledge import Knowledges
from open_webui.config import UPLOAD_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
@ -44,7 +46,11 @@ router = APIRouter()
@router.get("/", response_model=list[FolderNameIdResponse])
async def get_folders(request: Request, user=Depends(get_verified_user)):
async def get_folders(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if request.app.state.config.ENABLE_FOLDERS is False:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@ -55,22 +61,23 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
user.id,
"features.folders",
request.app.state.config.USER_PERMISSIONS,
db=db,
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
folders = Folders.get_folders_by_user_id(user.id)
folders = Folders.get_folders_by_user_id(user.id, db=db)
# Verify folder data integrity
folder_list = []
for folder in folders:
if folder.parent_id and not Folders.get_folder_by_id_and_user_id(
folder.parent_id, user.id
folder.parent_id, user.id, db=db
):
folder = Folders.update_folder_parent_id_by_id_and_user_id(
folder.id, user.id, None
folder.id, user.id, None, db=db
)
if folder.data:
@ -80,12 +87,12 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
if file.get("type") == "file":
if Files.check_access_by_user_id(
file.get("id"), user.id, "read"
file.get("id"), user.id, "read", db=db
):
valid_files.append(file)
elif file.get("type") == "collection":
if Knowledges.check_access_by_user_id(
file.get("id"), user.id, "read"
file.get("id"), user.id, "read", db=db
):
valid_files.append(file)
else:
@ -93,7 +100,7 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
folder.data["files"] = valid_files
Folders.update_folder_by_id_and_user_id(
folder.id, user.id, FolderUpdateForm(data=folder.data)
folder.id, user.id, FolderUpdateForm(data=folder.data), db=db
)
folder_list.append(FolderNameIdResponse(**folder.model_dump()))
@ -107,9 +114,13 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
@router.post("/")
def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
def create_folder(
form_data: FolderForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
None, user.id, form_data.name
None, user.id, form_data.name, db=db
)
if folder:
@ -119,7 +130,7 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
)
try:
folder = Folders.insert_new_folder(user.id, form_data)
folder = Folders.insert_new_folder(user.id, form_data, db=db)
return folder
except Exception as e:
log.exception(e)
@ -136,8 +147,10 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
@router.get("/{id}", response_model=Optional[FolderModel])
async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
async def get_folder_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
if folder:
return folder
else:
@ -154,15 +167,18 @@ async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/update")
async def update_folder_name_by_id(
id: str, form_data: FolderUpdateForm, user=Depends(get_verified_user)
id: str,
form_data: FolderUpdateForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
if folder:
if form_data.name is not None:
# Check if folder with same name exists
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
folder.parent_id, user.id, form_data.name
folder.parent_id, user.id, form_data.name, db=db
)
if existing_folder and existing_folder.id != id:
raise HTTPException(
@ -171,7 +187,9 @@ async def update_folder_name_by_id(
)
try:
folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data)
folder = Folders.update_folder_by_id_and_user_id(
id, user.id, form_data, db=db
)
return folder
except Exception as e:
log.exception(e)
@ -198,12 +216,15 @@ class FolderParentIdForm(BaseModel):
@router.post("/{id}/update/parent")
async def update_folder_parent_id_by_id(
id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user)
id: str,
form_data: FolderParentIdForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
if folder:
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
form_data.parent_id, user.id, folder.name
form_data.parent_id, user.id, folder.name, db=db
)
if existing_folder:
@ -214,7 +235,7 @@ async def update_folder_parent_id_by_id(
try:
folder = Folders.update_folder_parent_id_by_id_and_user_id(
id, user.id, form_data.parent_id
id, user.id, form_data.parent_id, db=db
)
return folder
except Exception as e:
@ -242,13 +263,16 @@ class FolderIsExpandedForm(BaseModel):
@router.post("/{id}/update/expanded")
async def update_folder_is_expanded_by_id(
id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user)
id: str,
form_data: FolderIsExpandedForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
if folder:
try:
folder = Folders.update_folder_is_expanded_by_id_and_user_id(
id, user.id, form_data.is_expanded
id, user.id, form_data.is_expanded, db=db
)
return folder
except Exception as e:
@ -276,10 +300,11 @@ async def delete_folder_by_id(
id: str,
delete_contents: Optional[bool] = True,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if Chats.count_chats_by_folder_id_and_user_id(id, user.id):
if Chats.count_chats_by_folder_id_and_user_id(id, user.id, db=db):
chat_delete_permission = has_permission(
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS, db=db
)
if user.role != "admin" and not chat_delete_permission:
raise HTTPException(
@ -288,19 +313,21 @@ async def delete_folder_by_id(
)
folders = []
folders.append(Folders.get_folder_by_id_and_user_id(id, user.id))
folders.append(Folders.get_folder_by_id_and_user_id(id, user.id, db=db))
while folders:
folder = folders.pop()
if folder:
try:
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id)
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id, db=db)
for folder_id in folder_ids:
if delete_contents:
Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id)
Chats.delete_chats_by_user_id_and_folder_id(
user.id, folder_id, db=db
)
else:
Chats.move_chats_by_user_id_and_folder_id(
user.id, folder_id, None
user.id, folder_id, None, db=db
)
return True
@ -314,7 +341,7 @@ async def delete_folder_by_id(
finally:
# Get all subfolders
subfolders = Folders.get_folders_by_parent_id_and_user_id(
folder.id, user.id
folder.id, user.id, db=db
)
folders.extend(subfolders)

View file

@ -24,6 +24,8 @@ from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.auth import get_admin_user, get_verified_user
from pydantic import BaseModel, HttpUrl
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
log = logging.getLogger(__name__)
@ -37,13 +39,13 @@ router = APIRouter()
@router.get("/", response_model=list[FunctionResponse])
async def get_functions(user=Depends(get_verified_user)):
return Functions.get_functions()
async def get_functions(user=Depends(get_verified_user), db: Session = Depends(get_session)):
return Functions.get_functions(db=db)
@router.get("/list", response_model=list[FunctionUserResponse])
async def get_function_list(user=Depends(get_admin_user)):
return Functions.get_function_list()
async def get_function_list(user=Depends(get_admin_user), db: Session = Depends(get_session)):
return Functions.get_function_list(db=db)
############################
@ -52,8 +54,8 @@ async def get_function_list(user=Depends(get_admin_user)):
@router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel])
async def get_functions(include_valves: bool = False, user=Depends(get_admin_user)):
return Functions.get_functions(include_valves=include_valves)
async def get_functions(include_valves: bool = False, user=Depends(get_admin_user), db: Session = Depends(get_session)):
return Functions.get_functions(include_valves=include_valves, db=db)
############################
@ -142,7 +144,7 @@ class SyncFunctionsForm(BaseModel):
@router.post("/sync", response_model=list[FunctionWithValvesModel])
async def sync_functions(
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
try:
for function in form_data.functions:
@ -164,7 +166,7 @@ async def sync_functions(
)
raise e
return Functions.sync_functions(user.id, form_data.functions)
return Functions.sync_functions(user.id, form_data.functions, db=db)
except Exception as e:
log.exception(f"Failed to load a function: {e}")
raise HTTPException(
@ -180,7 +182,7 @@ async def sync_functions(
@router.post("/create", response_model=Optional[FunctionResponse])
async def create_new_function(
request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
if not form_data.id.isidentifier():
raise HTTPException(
@ -190,7 +192,7 @@ async def create_new_function(
form_data.id = form_data.id.lower()
function = Functions.get_function_by_id(form_data.id)
function = Functions.get_function_by_id(form_data.id, db=db)
if function is None:
try:
form_data.content = replace_imports(form_data.content)
@ -203,13 +205,13 @@ async def create_new_function(
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module
function = Functions.insert_new_function(user.id, function_type, form_data)
function = Functions.insert_new_function(user.id, function_type, form_data, db=db)
function_cache_dir = CACHE_DIR / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True)
if function_type == "filter" and getattr(function_module, "toggle", None):
Functions.update_function_metadata_by_id(id, {"toggle": True})
Functions.update_function_metadata_by_id(form_data.id, {"toggle": True}, db=db)
if function:
return function
@ -237,8 +239,8 @@ async def create_new_function(
@router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
async def get_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
function = Functions.get_function_by_id(id, db=db)
if function:
return function
@ -255,11 +257,11 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
async def toggle_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
function = Functions.get_function_by_id(id, db=db)
if function:
function = Functions.update_function_by_id(
id, {"is_active": not function.is_active}
id, {"is_active": not function.is_active}, db=db
)
if function:
@ -282,11 +284,11 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
async def toggle_global_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
function = Functions.get_function_by_id(id, db=db)
if function:
function = Functions.update_function_by_id(
id, {"is_global": not function.is_global}
id, {"is_global": not function.is_global}, db=db
)
if function:
@ -310,7 +312,7 @@ async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_function_by_id(
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
try:
form_data.content = replace_imports(form_data.content)
@ -325,10 +327,10 @@ async def update_function_by_id(
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
log.debug(updated)
function = Functions.update_function_by_id(id, updated)
function = Functions.update_function_by_id(id, updated, db=db)
if function_type == "filter" and getattr(function_module, "toggle", None):
Functions.update_function_metadata_by_id(id, {"toggle": True})
Functions.update_function_metadata_by_id(id, {"toggle": True}, db=db)
if function:
return function
@ -352,9 +354,9 @@ async def update_function_by_id(
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_function_by_id(
request: Request, id: str, user=Depends(get_admin_user)
request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
result = Functions.delete_function_by_id(id)
result = Functions.delete_function_by_id(id, db=db)
if result:
FUNCTIONS = request.app.state.FUNCTIONS
@ -370,11 +372,11 @@ async def delete_function_by_id(
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
function = Functions.get_function_by_id(id, db=db)
if function:
try:
valves = Functions.get_function_valves_by_id(id)
valves = Functions.get_function_valves_by_id(id, db=db)
return valves
except Exception as e:
raise HTTPException(
@ -395,9 +397,9 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_function_valves_spec_by_id(
request: Request, id: str, user=Depends(get_admin_user)
request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
function = Functions.get_function_by_id(id)
function = Functions.get_function_by_id(id, db=db)
if function:
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
@ -421,9 +423,9 @@ async def get_function_valves_spec_by_id(
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_function_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
request: Request, id: str, form_data: dict, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
function = Functions.get_function_by_id(id)
function = Functions.get_function_by_id(id, db=db)
if function:
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
@ -437,7 +439,7 @@ async def update_function_valves_by_id(
valves = Valves(**form_data)
valves_dict = valves.model_dump(exclude_unset=True)
Functions.update_function_valves_by_id(id, valves_dict)
Functions.update_function_valves_by_id(id, valves_dict, db=db)
return valves_dict
except Exception as e:
log.exception(f"Error updating function values by id {id}: {e}")
@ -464,11 +466,11 @@ async def update_function_valves_by_id(
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
function = Functions.get_function_by_id(id)
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
function = Functions.get_function_by_id(id, db=db)
if function:
try:
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id, db=db)
return user_valves
except Exception as e:
raise HTTPException(
@ -484,9 +486,9 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_function_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
function = Functions.get_function_by_id(id)
function = Functions.get_function_by_id(id, db=db)
if function:
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
@ -505,9 +507,9 @@ async def get_function_user_valves_spec_by_id(
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_function_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
function = Functions.get_function_by_id(id)
function = Functions.get_function_by_id(id, db=db)
if function:
function_module, function_type, frontmatter = get_function_module_from_cache(
@ -522,7 +524,7 @@ async def update_function_user_valves_by_id(
user_valves = UserValves(**form_data)
user_valves_dict = user_valves.model_dump(exclude_unset=True)
Functions.update_user_valves_by_id_and_user_id(
id, user.id, user_valves_dict
id, user.id, user_valves_dict, db=db
)
return user_valves_dict
except Exception as e:

View file

@ -16,6 +16,9 @@ from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
from open_webui.utils.auth import get_admin_user, get_verified_user
@ -29,7 +32,11 @@ router = APIRouter()
@router.get("/", response_model=list[GroupResponse])
async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)):
async def get_groups(
share: Optional[bool] = None,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
filter = {}
if user.role != "admin":
@ -38,7 +45,7 @@ async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_use
if share is not None:
filter["share"] = share
groups = Groups.get_groups(filter=filter)
groups = Groups.get_groups(filter=filter, db=db)
return groups
@ -49,13 +56,17 @@ async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_use
@router.post("/create", response_model=Optional[GroupResponse])
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
async def create_new_group(
form_data: GroupForm,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
try:
group = Groups.insert_new_group(user.id, form_data)
group = Groups.insert_new_group(user.id, form_data, db=db)
if group:
return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
)
else:
raise HTTPException(
@ -76,12 +87,14 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
@router.get("/id/{id}", response_model=Optional[GroupResponse])
async def get_group_by_id(id: str, user=Depends(get_admin_user)):
group = Groups.get_group_by_id(id)
async def get_group_by_id(
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
group = Groups.get_group_by_id(id, db=db)
if group:
return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
)
else:
raise HTTPException(
@ -101,13 +114,15 @@ class GroupExportResponse(GroupResponse):
@router.get("/id/{id}/export", response_model=Optional[GroupExportResponse])
async def export_group_by_id(id: str, user=Depends(get_admin_user)):
group = Groups.get_group_by_id(id)
async def export_group_by_id(
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
group = Groups.get_group_by_id(id, db=db)
if group:
return GroupExportResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
user_ids=Groups.get_group_user_ids_by_id(group.id),
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
user_ids=Groups.get_group_user_ids_by_id(group.id, db=db),
)
else:
raise HTTPException(
@ -122,9 +137,11 @@ async def export_group_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/users", response_model=list[UserInfoResponse])
async def get_users_in_group(id: str, user=Depends(get_admin_user)):
async def get_users_in_group(
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
try:
users = Users.get_users_by_group_id(id)
users = Users.get_users_by_group_id(id, db=db)
return users
except Exception as e:
log.exception(f"Error adding users to group {id}: {e}")
@ -141,14 +158,17 @@ async def get_users_in_group(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[GroupResponse])
async def update_group_by_id(
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
id: str,
form_data: GroupUpdateForm,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
try:
group = Groups.update_group_by_id(id, form_data)
group = Groups.update_group_by_id(id, form_data, db=db)
if group:
return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
)
else:
raise HTTPException(
@ -170,17 +190,20 @@ async def update_group_by_id(
@router.post("/id/{id}/users/add", response_model=Optional[GroupResponse])
async def add_user_to_group(
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
id: str,
form_data: UserIdsForm,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
try:
if form_data.user_ids:
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids, db=db)
group = Groups.add_users_to_group(id, form_data.user_ids)
group = Groups.add_users_to_group(id, form_data.user_ids, db=db)
if group:
return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
)
else:
raise HTTPException(
@ -197,14 +220,17 @@ async def add_user_to_group(
@router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse])
async def remove_users_from_group(
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
id: str,
form_data: UserIdsForm,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
try:
group = Groups.remove_users_from_group(id, form_data.user_ids)
group = Groups.remove_users_from_group(id, form_data.user_ids, db=db)
if group:
return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
)
else:
raise HTTPException(
@ -225,9 +251,11 @@ async def remove_users_from_group(
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
async def delete_group_by_id(
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
try:
result = Groups.delete_group_by_id(id)
result = Groups.delete_group_by_id(id, db=db)
if result:
return result
else:

View file

@ -4,6 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
from fastapi.concurrency import run_in_threadpool
import logging
from sqlalchemy.orm import Session
from open_webui.internal.db import get_session
from open_webui.models.groups import Groups
from open_webui.models.knowledge import (
KnowledgeFileListResponse,
@ -52,21 +54,25 @@ class KnowledgeAccessListResponse(BaseModel):
@router.get("/", response_model=KnowledgeAccessListResponse)
async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified_user)):
async def get_knowledge_bases(
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
page = max(page, 1)
limit = PAGE_ITEM_COUNT
skip = (page - 1) * limit
filter = {}
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id)
groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups:
filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id
result = Knowledges.search_knowledge_bases(
user.id, filter=filter, skip=skip, limit=limit
user.id, filter=filter, skip=skip, limit=limit, db=db
)
return KnowledgeAccessListResponse(
@ -75,7 +81,9 @@ async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified
**knowledge_base.model_dump(),
write_access=(
user.id == knowledge_base.user_id
or has_access(user.id, "write", knowledge_base.access_control)
or has_access(
user.id, "write", knowledge_base.access_control, db=db
)
),
)
for knowledge_base in result.items
@ -90,6 +98,7 @@ async def search_knowledge_bases(
view_option: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
page = max(page, 1)
limit = PAGE_ITEM_COUNT
@ -102,14 +111,14 @@ async def search_knowledge_bases(
filter["view_option"] = view_option
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id)
groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups:
filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id
result = Knowledges.search_knowledge_bases(
user.id, filter=filter, skip=skip, limit=limit
user.id, filter=filter, skip=skip, limit=limit, db=db
)
return KnowledgeAccessListResponse(
@ -118,7 +127,9 @@ async def search_knowledge_bases(
**knowledge_base.model_dump(),
write_access=(
user.id == knowledge_base.user_id
or has_access(user.id, "write", knowledge_base.access_control)
or has_access(
user.id, "write", knowledge_base.access_control, db=db
)
),
)
for knowledge_base in result.items
@ -132,6 +143,7 @@ async def search_knowledge_files(
query: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
page = max(page, 1)
limit = PAGE_ITEM_COUNT
@ -141,13 +153,15 @@ async def search_knowledge_files(
if query:
filter["query"] = query
groups = Groups.get_groups_by_member_id(user.id)
groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups:
filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id
return Knowledges.search_knowledge_files(filter=filter, skip=skip, limit=limit)
return Knowledges.search_knowledge_files(
filter=filter, skip=skip, limit=limit, db=db
)
############################
@ -157,10 +171,13 @@ async def search_knowledge_files(
@router.post("/create", response_model=Optional[KnowledgeResponse])
async def create_new_knowledge(
request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user)
request: Request,
form_data: KnowledgeForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission(
user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -175,11 +192,12 @@ async def create_new_knowledge(
user.id,
"sharing.public_knowledge",
request.app.state.config.USER_PERMISSIONS,
db=db,
)
):
form_data.access_control = {}
knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
knowledge = Knowledges.insert_new_knowledge(user.id, form_data, db=db)
if knowledge:
return knowledge
@ -196,20 +214,24 @@ async def create_new_knowledge(
@router.post("/reindex", response_model=bool)
async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
async def reindex_knowledge_files(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
knowledge_bases = Knowledges.get_knowledge_bases()
knowledge_bases = Knowledges.get_knowledge_bases(db=db)
log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
for knowledge_base in knowledge_bases:
try:
files = Knowledges.get_files_by_id(knowledge_base.id)
files = Knowledges.get_files_by_id(knowledge_base.id, db=db)
try:
if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
VECTOR_DB_CLIENT.delete_collection(
@ -229,6 +251,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
file_id=file.id, collection_name=knowledge_base.id
),
user=user,
db=db,
)
except Exception as e:
log.error(
@ -264,21 +287,23 @@ class KnowledgeFilesResponse(KnowledgeResponse):
@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
knowledge = Knowledges.get_knowledge_by_id(id=id)
async def get_knowledge_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if knowledge:
if (
user.role == "admin"
or knowledge.user_id == user.id
or has_access(user.id, "read", knowledge.access_control)
or has_access(user.id, "read", knowledge.access_control, db=db)
):
return KnowledgeFilesResponse(
**knowledge.model_dump(),
write_access=(
user.id == knowledge.user_id
or has_access(user.id, "write", knowledge.access_control)
or has_access(user.id, "write", knowledge.access_control, db=db)
),
)
else:
@ -299,8 +324,9 @@ async def update_knowledge_by_id(
id: str,
form_data: KnowledgeForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
knowledge = Knowledges.get_knowledge_by_id(id=id)
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -309,7 +335,7 @@ async def update_knowledge_by_id(
# Is the user the original creator, in a group with write access, or an admin
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -325,15 +351,16 @@ async def update_knowledge_by_id(
user.id,
"sharing.public_knowledge",
request.app.state.config.USER_PERMISSIONS,
db=db,
)
):
form_data.access_control = {}
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data, db=db)
if knowledge:
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
)
else:
raise HTTPException(
@ -356,9 +383,10 @@ async def get_knowledge_files_by_id(
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
knowledge = Knowledges.get_knowledge_by_id(id=id)
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -368,7 +396,7 @@ async def get_knowledge_files_by_id(
if not (
user.role == "admin"
or knowledge.user_id == user.id
or has_access(user.id, "read", knowledge.access_control)
or has_access(user.id, "read", knowledge.access_control, db=db)
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -391,7 +419,7 @@ async def get_knowledge_files_by_id(
filter["direction"] = direction
return Knowledges.search_files_by_id(
id, user.id, filter=filter, skip=skip, limit=limit
id, user.id, filter=filter, skip=skip, limit=limit, db=db
)
@ -410,8 +438,9 @@ def add_file_to_knowledge_by_id(
id: str,
form_data: KnowledgeFileIdForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
knowledge = Knowledges.get_knowledge_by_id(id=id)
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -420,7 +449,7 @@ def add_file_to_knowledge_by_id(
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -428,7 +457,7 @@ def add_file_to_knowledge_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
file = Files.get_file_by_id(form_data.file_id)
file = Files.get_file_by_id(form_data.file_id, db=db)
if not file:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -446,11 +475,12 @@ def add_file_to_knowledge_by_id(
request,
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
user=user,
db=db,
)
# Add file to knowledge base
Knowledges.add_file_to_knowledge_by_id(
knowledge_id=id, file_id=form_data.file_id, user_id=user.id
knowledge_id=id, file_id=form_data.file_id, user_id=user.id, db=db
)
except Exception as e:
log.debug(e)
@ -462,7 +492,7 @@ def add_file_to_knowledge_by_id(
if knowledge:
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
)
else:
raise HTTPException(
@ -477,8 +507,9 @@ def update_file_from_knowledge_by_id(
id: str,
form_data: KnowledgeFileIdForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
knowledge = Knowledges.get_knowledge_by_id(id=id)
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -487,7 +518,7 @@ def update_file_from_knowledge_by_id(
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin"
):
@ -496,7 +527,7 @@ def update_file_from_knowledge_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
file = Files.get_file_by_id(form_data.file_id)
file = Files.get_file_by_id(form_data.file_id, db=db)
if not file:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -514,6 +545,7 @@ def update_file_from_knowledge_by_id(
request,
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
user=user,
db=db,
)
except Exception as e:
raise HTTPException(
@ -524,7 +556,7 @@ def update_file_from_knowledge_by_id(
if knowledge:
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
)
else:
raise HTTPException(
@ -544,8 +576,9 @@ def remove_file_from_knowledge_by_id(
form_data: KnowledgeFileIdForm,
delete_file: bool = Query(True),
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
knowledge = Knowledges.get_knowledge_by_id(id=id)
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -554,7 +587,7 @@ def remove_file_from_knowledge_by_id(
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -562,7 +595,7 @@ def remove_file_from_knowledge_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
file = Files.get_file_by_id(form_data.file_id)
file = Files.get_file_by_id(form_data.file_id, db=db)
if not file:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -570,7 +603,7 @@ def remove_file_from_knowledge_by_id(
)
Knowledges.remove_file_from_knowledge_by_id(
knowledge_id=id, file_id=form_data.file_id
knowledge_id=id, file_id=form_data.file_id, db=db
)
# Remove content from the vector database
@ -599,12 +632,12 @@ def remove_file_from_knowledge_by_id(
pass
# Delete file from database
Files.delete_file_by_id(form_data.file_id)
Files.delete_file_by_id(form_data.file_id, db=db)
if knowledge:
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
)
else:
raise HTTPException(
@ -619,8 +652,10 @@ def remove_file_from_knowledge_by_id(
@router.delete("/{id}/delete", response_model=bool)
async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
knowledge = Knowledges.get_knowledge_by_id(id=id)
async def delete_knowledge_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -629,7 +664,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -640,7 +675,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})")
# Get all models
models = Models.get_all_models()
models = Models.get_all_models(db=db)
log.info(f"Found {len(models)} models to check for knowledge base {id}")
# Update models that reference this knowledge base
@ -664,7 +699,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
access_control=model.access_control,
is_active=model.is_active,
)
Models.update_model_by_id(model.id, model_form)
Models.update_model_by_id(model.id, model_form, db=db)
# Clean up vector DB
try:
@ -672,7 +707,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
except Exception as e:
log.debug(e)
pass
result = Knowledges.delete_knowledge_by_id(id=id)
result = Knowledges.delete_knowledge_by_id(id=id, db=db)
return result
@ -682,8 +717,10 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
knowledge = Knowledges.get_knowledge_by_id(id=id)
async def reset_knowledge_by_id(
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -692,7 +729,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -706,7 +743,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
log.debug(e)
pass
knowledge = Knowledges.reset_knowledge_by_id(id=id)
knowledge = Knowledges.reset_knowledge_by_id(id=id, db=db)
return knowledge
@ -721,11 +758,12 @@ async def add_files_to_knowledge_batch(
id: str,
form_data: list[KnowledgeFileIdForm],
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
"""
Add multiple files to a knowledge base
"""
knowledge = Knowledges.get_knowledge_by_id(id=id)
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -734,7 +772,7 @@ async def add_files_to_knowledge_batch(
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not has_access(user.id, "write", knowledge.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -746,7 +784,7 @@ async def add_files_to_knowledge_batch(
log.info(f"files/batch/add - {len(form_data)} files")
files: List[FileModel] = []
for form in form_data:
file = Files.get_file_by_id(form.file_id)
file = Files.get_file_by_id(form.file_id, db=db)
if not file:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -760,6 +798,7 @@ async def add_files_to_knowledge_batch(
request=request,
form_data=BatchProcessFilesForm(files=files, collection_name=id),
user=user,
db=db,
)
except Exception as e:
log.error(
@ -771,7 +810,7 @@ async def add_files_to_knowledge_batch(
successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
for file_id in successful_file_ids:
Knowledges.add_file_to_knowledge_by_id(
knowledge_id=id, file_id=file_id, user_id=user.id
knowledge_id=id, file_id=file_id, user_id=user.id, db=db
)
# If there were any errors, include them in the response
@ -779,7 +818,7 @@ async def add_files_to_knowledge_batch(
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
warnings={
"message": "Some files failed to process",
"errors": error_details,
@ -788,5 +827,5 @@ async def add_files_to_knowledge_batch(
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
)

View file

@ -9,6 +9,10 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.utils.auth import get_verified_user
from open_webui.utils.auth import get_verified_user
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
log = logging.getLogger(__name__)
router = APIRouter()
@ -25,8 +29,8 @@ async def get_embeddings(request: Request):
@router.get("/", response_model=list[MemoryModel])
async def get_memories(user=Depends(get_verified_user)):
return Memories.get_memories_by_user_id(user.id)
async def get_memories(user=Depends(get_verified_user), db: Session = Depends(get_session)):
return Memories.get_memories_by_user_id(user.id, db=db)
############################
@ -47,8 +51,9 @@ async def add_memory(
request: Request,
form_data: AddMemoryForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
memory = Memories.insert_new_memory(user.id, form_data.content)
memory = Memories.insert_new_memory(user.id, form_data.content, db=db)
vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
@ -79,9 +84,9 @@ class QueryMemoryForm(BaseModel):
@router.post("/query")
async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
memories = Memories.get_memories_by_user_id(user.id)
memories = Memories.get_memories_by_user_id(user.id, db=db)
if not memories:
raise HTTPException(status_code=404, detail="No memories found for user")
@ -101,11 +106,11 @@ async def query_memory(
############################
@router.post("/reset", response_model=bool)
async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user)
request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(user.id)
memories = Memories.get_memories_by_user_id(user.id, db=db)
# Generate vectors in parallel
vectors = await asyncio.gather(
@ -140,8 +145,8 @@ async def reset_memory_from_vector_db(
@router.delete("/delete/user", response_model=bool)
async def delete_memory_by_user_id(user=Depends(get_verified_user)):
result = Memories.delete_memories_by_user_id(user.id)
async def delete_memory_by_user_id(user=Depends(get_verified_user), db: Session = Depends(get_session)):
result = Memories.delete_memories_by_user_id(user.id, db=db)
if result:
try:
@ -164,9 +169,10 @@ async def update_memory_by_id(
request: Request,
form_data: MemoryUpdateModel,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
memory = Memories.update_memory_by_id_and_user_id(
memory_id, user.id, form_data.content
memory_id, user.id, form_data.content, db=db
)
if memory is None:
raise HTTPException(status_code=404, detail="Memory not found")
@ -198,8 +204,8 @@ async def update_memory_by_id(
@router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id, db=db)
if result:
VECTOR_DB_CLIENT.delete(

View file

@ -30,6 +30,8 @@ from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
log = logging.getLogger(__name__)
@ -59,6 +61,7 @@ async def get_models(
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
limit = PAGE_ITEM_COUNT
@ -79,13 +82,13 @@ async def get_models(
filter["direction"] = direction
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id)
groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups:
filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id
return Models.search_models(user.id, filter=filter, skip=skip, limit=limit)
return Models.search_models(user.id, filter=filter, skip=skip, limit=limit, db=db)
###########################
@ -94,8 +97,8 @@ async def get_models(
@router.get("/base", response_model=list[ModelResponse])
async def get_base_models(user=Depends(get_admin_user)):
return Models.get_base_models()
async def get_base_models(user=Depends(get_admin_user), db: Session = Depends(get_session)):
return Models.get_base_models(db=db)
###########################
@ -104,11 +107,11 @@ async def get_base_models(user=Depends(get_admin_user)):
@router.get("/tags", response_model=list[str])
async def get_model_tags(user=Depends(get_verified_user)):
async def get_model_tags(user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
models = Models.get_models()
models = Models.get_models(db=db)
else:
models = Models.get_models_by_user_id(user.id)
models = Models.get_models_by_user_id(user.id, db=db)
tags_set = set()
for model in models:
@ -132,16 +135,17 @@ async def create_new_model(
request: Request,
form_data: ModelForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission(
user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS
user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
model = Models.get_model_by_id(form_data.id)
model = Models.get_model_by_id(form_data.id, db=db)
if model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -155,7 +159,7 @@ async def create_new_model(
)
else:
model = Models.insert_new_model(form_data, user.id)
model = Models.insert_new_model(form_data, user.id, db=db)
if model:
return model
else:
@ -171,9 +175,9 @@ async def create_new_model(
@router.get("/export", response_model=list[ModelModel])
async def export_models(request: Request, user=Depends(get_verified_user)):
async def export_models(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role != "admin" and not has_permission(
user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS
user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -181,9 +185,9 @@ async def export_models(request: Request, user=Depends(get_verified_user)):
)
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
return Models.get_models()
return Models.get_models(db=db)
else:
return Models.get_models_by_user_id(user.id)
return Models.get_models_by_user_id(user.id, db=db)
############################
@ -200,9 +204,10 @@ async def import_models(
request: Request,
user=Depends(get_verified_user),
form_data: ModelsImportForm = (...),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission(
user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS
user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -216,7 +221,7 @@ async def import_models(
model_id = model_data.get("id")
if model_id and is_valid_model_id(model_id):
existing_model = Models.get_model_by_id(model_id)
existing_model = Models.get_model_by_id(model_id, db=db)
if existing_model:
# Update existing model
model_data["meta"] = model_data.get("meta", {})
@ -225,13 +230,13 @@ async def import_models(
updated_model = ModelForm(
**{**existing_model.model_dump(), **model_data}
)
Models.update_model_by_id(model_id, updated_model)
Models.update_model_by_id(model_id, updated_model, db=db)
else:
# Insert new model
model_data["meta"] = model_data.get("meta", {})
model_data["params"] = model_data.get("params", {})
new_model = ModelForm(**model_data)
Models.insert_new_model(user_id=user.id, form_data=new_model)
Models.insert_new_model(user_id=user.id, form_data=new_model, db=db)
return True
else:
raise HTTPException(status_code=400, detail="Invalid JSON format")
@ -251,9 +256,9 @@ class SyncModelsForm(BaseModel):
@router.post("/sync", response_model=list[ModelModel])
async def sync_models(
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
return Models.sync_models(user.id, form_data.models)
return Models.sync_models(user.id, form_data.models, db=db)
###########################
@ -267,13 +272,13 @@ class ModelIdForm(BaseModel):
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
@router.get("/model", response_model=Optional[ModelResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
async def get_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
model = Models.get_model_by_id(id, db=db)
if model:
if (
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or model.user_id == user.id
or has_access(user.id, "read", model.access_control)
or has_access(user.id, "read", model.access_control, db=db)
):
return model
else:
@ -289,8 +294,8 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/model/profile/image")
async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
async def get_model_profile_image(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
model = Models.get_model_by_id(id, db=db)
# Cache-control headers to prevent stale cached images
cache_headers = {"Cache-Control": "no-cache, must-revalidate"}
@ -330,15 +335,15 @@ async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
@router.post("/model/toggle", response_model=Optional[ModelResponse])
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
model = Models.get_model_by_id(id, db=db)
if model:
if (
user.role == "admin"
or model.user_id == user.id
or has_access(user.id, "write", model.access_control)
or has_access(user.id, "write", model.access_control, db=db)
):
model = Models.toggle_model_by_id(id)
model = Models.toggle_model_by_id(id, db=db)
if model:
return model
@ -368,8 +373,9 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
async def update_model_by_id(
form_data: ModelForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
model = Models.get_model_by_id(form_data.id)
model = Models.get_model_by_id(form_data.id, db=db)
if not model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -378,7 +384,7 @@ async def update_model_by_id(
if (
model.user_id != user.id
and not has_access(user.id, "write", model.access_control)
and not has_access(user.id, "write", model.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -386,7 +392,7 @@ async def update_model_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()))
model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()), db=db)
return model
@ -396,8 +402,8 @@ async def update_model_by_id(
@router.post("/model/delete", response_model=bool)
async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user)):
model = Models.get_model_by_id(form_data.id)
async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user), db: Session = Depends(get_session)):
model = Models.get_model_by_id(form_data.id, db=db)
if not model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -407,18 +413,18 @@ async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_u
if (
user.role != "admin"
and model.user_id != user.id
and not has_access(user.id, "write", model.access_control)
and not has_access(user.id, "write", model.access_control, db=db)
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
result = Models.delete_model_by_id(form_data.id)
result = Models.delete_model_by_id(form_data.id, db=db)
return result
@router.delete("/delete/all", response_model=bool)
async def delete_all_models(user=Depends(get_admin_user)):
result = Models.delete_all_models()
async def delete_all_models(user=Depends(get_admin_user), db: Session = Depends(get_session)):
result = Models.delete_all_models(db=db)
return result

View file

@ -28,6 +28,8 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
log = logging.getLogger(__name__)
@ -49,10 +51,13 @@ class NoteItemResponse(BaseModel):
@router.get("/", response_model=list[NoteItemResponse])
async def get_notes(
request: Request, page: Optional[int] = None, user=Depends(get_verified_user)
request: Request,
page: Optional[int] = None,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -69,10 +74,14 @@ async def get_notes(
NoteUserResponse(
**{
**note.model_dump(),
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
"user": UserResponse(
**Users.get_user_by_id(note.user_id, db=db).model_dump()
),
}
)
for note in Notes.get_notes_by_user_id(user.id, "read", skip=skip, limit=limit)
for note in Notes.get_notes_by_user_id(
user.id, "read", skip=skip, limit=limit, db=db
)
]
return notes
@ -87,9 +96,10 @@ async def search_notes(
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -115,13 +125,13 @@ async def search_notes(
filter["direction"] = direction
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id)
groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups:
filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id
return Notes.search_notes(user.id, filter, skip=skip, limit=limit)
return Notes.search_notes(user.id, filter, skip=skip, limit=limit, db=db)
############################
@ -131,10 +141,13 @@ async def search_notes(
@router.post("/create", response_model=Optional[NoteModel])
async def create_new_note(
request: Request, form_data: NoteForm, user=Depends(get_verified_user)
request: Request,
form_data: NoteForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -142,7 +155,7 @@ async def create_new_note(
)
try:
note = Notes.insert_new_note(user.id, form_data)
note = Notes.insert_new_note(user.id, form_data, db=db)
return note
except Exception as e:
log.exception(e)
@ -161,16 +174,21 @@ class NoteResponse(NoteModel):
@router.get("/{id}", response_model=Optional[NoteResponse])
async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
async def get_note_by_id(
request: Request,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
note = Notes.get_note_by_id(id)
note = Notes.get_note_by_id(id, db=db)
if not note:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -178,7 +196,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
if user.role != "admin" and (
user.id != note.user_id
and (not has_access(user.id, type="read", access_control=note.access_control))
and (
not has_access(
user.id, type="read", access_control=note.access_control, db=db
)
)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -188,7 +210,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
user.role == "admin"
or (user.id == note.user_id)
or has_access(
user.id, type="write", access_control=note.access_control, strict=False
user.id,
type="write",
access_control=note.access_control,
strict=False,
db=db,
)
)
@ -202,17 +228,21 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
@router.post("/{id}/update", response_model=Optional[NoteModel])
async def update_note_by_id(
request: Request, id: str, form_data: NoteForm, user=Depends(get_verified_user)
request: Request,
id: str,
form_data: NoteForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
note = Notes.get_note_by_id(id)
note = Notes.get_note_by_id(id, db=db)
if not note:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -220,7 +250,9 @@ async def update_note_by_id(
if user.role != "admin" and (
user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
and not has_access(
user.id, type="write", access_control=note.access_control, db=db
)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -234,12 +266,13 @@ async def update_note_by_id(
user.id,
"sharing.public_notes",
request.app.state.config.USER_PERMISSIONS,
db=db,
)
):
form_data.access_control = {}
try:
note = Notes.update_note_by_id(id, form_data)
note = Notes.update_note_by_id(id, form_data, db=db)
await sio.emit(
"note-events",
note.model_dump(),
@ -260,16 +293,21 @@ async def update_note_by_id(
@router.delete("/{id}/delete", response_model=bool)
async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
async def delete_note_by_id(
request: Request,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
note = Notes.get_note_by_id(id)
note = Notes.get_note_by_id(id, db=db)
if not note:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -277,14 +315,16 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
if user.role != "admin" and (
user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
and not has_access(
user.id, type="write", access_control=note.access_control, db=db
)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try:
note = Notes.delete_note_by_id(id)
note = Notes.delete_note_by_id(id, db=db)
return True
except Exception as e:
log.exception(e)

View file

@ -11,6 +11,8 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
router = APIRouter()
@ -20,21 +22,21 @@ router = APIRouter()
@router.get("/", response_model=list[PromptModel])
async def get_prompts(user=Depends(get_verified_user)):
async def get_prompts(user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
prompts = Prompts.get_prompts()
prompts = Prompts.get_prompts(db=db)
else:
prompts = Prompts.get_prompts_by_user_id(user.id, "read")
prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db)
return prompts
@router.get("/list", response_model=list[PromptUserResponse])
async def get_prompt_list(user=Depends(get_verified_user)):
async def get_prompt_list(user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
prompts = Prompts.get_prompts()
prompts = Prompts.get_prompts(db=db)
else:
prompts = Prompts.get_prompts_by_user_id(user.id, "write")
prompts = Prompts.get_prompts_by_user_id(user.id, "write", db=db)
return prompts
@ -46,16 +48,17 @@ async def get_prompt_list(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(
request: Request, form_data: PromptForm, user=Depends(get_verified_user)
request: Request, form_data: PromptForm, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
if user.role != "admin" and not (
has_permission(
user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS
user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS, db=db
)
or has_permission(
user.id,
"workspace.prompts_import",
request.app.state.config.USER_PERMISSIONS,
db=db,
)
):
raise HTTPException(
@ -63,9 +66,9 @@ async def create_new_prompt(
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
prompt = Prompts.get_prompt_by_command(form_data.command)
prompt = Prompts.get_prompt_by_command(form_data.command, db=db)
if prompt is None:
prompt = Prompts.insert_new_prompt(user.id, form_data)
prompt = Prompts.insert_new_prompt(user.id, form_data, db=db)
if prompt:
return prompt
@ -85,14 +88,14 @@ async def create_new_prompt(
@router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}")
async def get_prompt_by_command(command: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
if prompt:
if (
user.role == "admin"
or prompt.user_id == user.id
or has_access(user.id, "read", prompt.access_control)
or has_access(user.id, "read", prompt.access_control, db=db)
):
return prompt
else:
@ -112,8 +115,9 @@ async def update_prompt_by_command(
command: str,
form_data: PromptForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
prompt = Prompts.get_prompt_by_command(f"/{command}")
prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
if not prompt:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -123,7 +127,7 @@ async def update_prompt_by_command(
# Is the user the original creator, in a group with write access, or an admin
if (
prompt.user_id != user.id
and not has_access(user.id, "write", prompt.access_control)
and not has_access(user.id, "write", prompt.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -131,7 +135,7 @@ async def update_prompt_by_command(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data, db=db)
if prompt:
return prompt
else:
@ -147,8 +151,8 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}")
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
if not prompt:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -157,7 +161,7 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
if (
prompt.user_id != user.id
and not has_access(user.id, "write", prompt.access_control)
and not has_access(user.id, "write", prompt.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -165,5 +169,5 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Prompts.delete_prompt_by_command(f"/{command}")
result = Prompts.delete_prompt_by_command(f"/{command}", db=db)
return result

View file

@ -39,6 +39,8 @@ from langchain_core.documents import Document
from open_webui.models.files import FileModel, FileUpdateForm, Files
from open_webui.models.knowledge import Knowledges
from open_webui.storage.provider import Storage
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
@ -1484,14 +1486,15 @@ def process_file(
request: Request,
form_data: ProcessFileForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
"""
Process a file and save its content to the vector database.
"""
if user.role == "admin":
file = Files.get_file_by_id(form_data.file_id)
file = Files.get_file_by_id(form_data.file_id, db=db)
else:
file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id)
file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id, db=db)
if file:
try:
@ -1633,12 +1636,13 @@ def process_file(
Files.update_file_data_by_id(
file.id,
{"content": text_content},
db=db,
)
hash = calculate_sha256_string(text_content)
Files.update_file_hash_by_id(file.id, hash)
Files.update_file_hash_by_id(file.id, hash, db=db)
if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
Files.update_file_data_by_id(file.id, {"status": "completed"})
Files.update_file_data_by_id(file.id, {"status": "completed"}, db=db)
return {
"status": True,
"collection_name": None,
@ -1667,11 +1671,13 @@ def process_file(
{
"collection_name": collection_name,
},
db=db,
)
Files.update_file_data_by_id(
file.id,
{"status": "completed"},
db=db,
)
return {
@ -1690,6 +1696,7 @@ def process_file(
Files.update_file_data_by_id(
file.id,
{"status": "failed"},
db=db,
)
if "No pandoc was found" in str(e):
@ -2417,10 +2424,10 @@ class DeleteForm(BaseModel):
@router.post("/delete")
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user), db: Session = Depends(get_session)):
try:
if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
file = Files.get_file_by_id(form_data.file_id)
file = Files.get_file_by_id(form_data.file_id, db=db)
hash = file.hash
VECTOR_DB_CLIENT.delete(
@ -2436,9 +2443,9 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin
@router.post("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
def reset_vector_db(user=Depends(get_admin_user), db: Session = Depends(get_session)):
VECTOR_DB_CLIENT.reset()
Knowledges.delete_all_knowledge()
Knowledges.delete_all_knowledge(db=db)
@router.post("/reset/uploads")
@ -2496,6 +2503,7 @@ async def process_files_batch(
request: Request,
form_data: BatchProcessFilesForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
) -> BatchProcessFilesResponse:
"""
Process a batch of files and save them to the vector database.
@ -2558,7 +2566,7 @@ async def process_files_batch(
# Update all files with collection name
for file_update, file_result in zip(file_updates, file_results):
Files.update_file_by_id(id=file_result.file_id, form_data=file_update)
Files.update_file_by_id(id=file_result.file_id, form_data=file_update, db=db)
file_result.status = "completed"
except Exception as e:

View file

@ -7,6 +7,8 @@ import aiohttp
from open_webui.models.groups import Groups
from pydantic import BaseModel, HttpUrl
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.orm import Session
from open_webui.internal.db import get_session
from open_webui.models.oauth_sessions import OAuthSessions
@ -51,11 +53,11 @@ def get_tool_module(request, tool_id, load_from_db=True):
@router.get("/", response_model=list[ToolUserResponse])
async def get_tools(request: Request, user=Depends(get_verified_user)):
async def get_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
tools = []
# Local Tools
for tool in Tools.get_tools():
for tool in Tools.get_tools(db=db):
tool_module = get_tool_module(request, tool.id)
tools.append(
ToolUserResponse(
@ -140,12 +142,12 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
# Admin can see all tools
return tools
else:
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
tools = [
tool
for tool in tools
if tool.user_id == user.id
or has_access(user.id, "read", tool.access_control, user_group_ids)
or has_access(user.id, "read", tool.access_control, user_group_ids, db=db)
]
return tools
@ -156,11 +158,11 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
@router.get("/list", response_model=list[ToolUserResponse])
async def get_tool_list(user=Depends(get_verified_user)):
async def get_tool_list(user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
tools = Tools.get_tools()
tools = Tools.get_tools(db=db)
else:
tools = Tools.get_tools_by_user_id(user.id, "write")
tools = Tools.get_tools_by_user_id(user.id, "write", db=db)
return tools
@ -245,9 +247,9 @@ async def load_tool_from_url(
@router.get("/export", response_model=list[ToolModel])
async def export_tools(request: Request, user=Depends(get_verified_user)):
async def export_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
if user.role != "admin" and not has_permission(
user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS
user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -255,9 +257,9 @@ async def export_tools(request: Request, user=Depends(get_verified_user)):
)
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
return Tools.get_tools()
return Tools.get_tools(db=db)
else:
return Tools.get_tools_by_user_id(user.id, "read")
return Tools.get_tools_by_user_id(user.id, "read", db=db)
############################
@ -270,13 +272,14 @@ async def create_new_tools(
request: Request,
form_data: ToolForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role != "admin" and not (
has_permission(
user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS, db=db
)
or has_permission(
user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS
user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS, db=db
)
):
raise HTTPException(
@ -292,7 +295,7 @@ async def create_new_tools(
form_data.id = form_data.id.lower()
tools = Tools.get_tool_by_id(form_data.id)
tools = Tools.get_tool_by_id(form_data.id, db=db)
if tools is None:
try:
form_data.content = replace_imports(form_data.content)
@ -305,7 +308,7 @@ async def create_new_tools(
TOOLS[form_data.id] = tool_module
specs = get_tool_specs(TOOLS[form_data.id])
tools = Tools.insert_new_tool(user.id, form_data, specs)
tools = Tools.insert_new_tool(user.id, form_data, specs, db=db)
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True)
@ -336,14 +339,14 @@ async def create_new_tools(
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
tools = Tools.get_tool_by_id(id, db=db)
if tools:
if (
user.role == "admin"
or tools.user_id == user.id
or has_access(user.id, "read", tools.access_control)
or has_access(user.id, "read", tools.access_control, db=db)
):
return tools
else:
@ -364,8 +367,9 @@ async def update_tools_by_id(
id: str,
form_data: ToolForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
tools = Tools.get_tool_by_id(id)
tools = Tools.get_tool_by_id(id, db=db)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -375,7 +379,7 @@ async def update_tools_by_id(
# Is the user the original creator, in a group with write access, or an admin
if (
tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control)
and not has_access(user.id, "write", tools.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -399,7 +403,7 @@ async def update_tools_by_id(
}
log.debug(updated)
tools = Tools.update_tool_by_id(id, updated)
tools = Tools.update_tool_by_id(id, updated, db=db)
if tools:
return tools
@ -423,9 +427,9 @@ async def update_tools_by_id(
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_tools_by_id(
request: Request, id: str, user=Depends(get_verified_user)
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
tools = Tools.get_tool_by_id(id)
tools = Tools.get_tool_by_id(id, db=db)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -434,7 +438,7 @@ async def delete_tools_by_id(
if (
tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control)
and not has_access(user.id, "write", tools.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -442,7 +446,7 @@ async def delete_tools_by_id(
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
result = Tools.delete_tool_by_id(id)
result = Tools.delete_tool_by_id(id, db=db)
if result:
TOOLS = request.app.state.TOOLS
if id in TOOLS:
@ -457,11 +461,11 @@ async def delete_tools_by_id(
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
tools = Tools.get_tool_by_id(id, db=db)
if tools:
try:
valves = Tools.get_tool_valves_by_id(id)
valves = Tools.get_tool_valves_by_id(id, db=db)
return valves
except Exception as e:
raise HTTPException(
@ -482,9 +486,9 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_tools_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
tools = Tools.get_tool_by_id(id)
tools = Tools.get_tool_by_id(id, db=db)
if tools:
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
@ -510,9 +514,9 @@ async def get_tools_valves_spec_by_id(
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_tools_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
tools = Tools.get_tool_by_id(id)
tools = Tools.get_tool_by_id(id, db=db)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -521,7 +525,7 @@ async def update_tools_valves_by_id(
if (
tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control)
and not has_access(user.id, "write", tools.access_control, db=db)
and user.role != "admin"
):
raise HTTPException(
@ -546,7 +550,7 @@ async def update_tools_valves_by_id(
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
valves_dict = valves.model_dump(exclude_unset=True)
Tools.update_tool_valves_by_id(id, valves_dict)
Tools.update_tool_valves_by_id(id, valves_dict, db=db)
return valves_dict
except Exception as e:
log.exception(f"Failed to update tool valves by id {id}: {e}")
@ -562,11 +566,11 @@ async def update_tools_valves_by_id(
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
tools = Tools.get_tool_by_id(id, db=db)
if tools:
try:
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id, db=db)
return user_valves
except Exception as e:
raise HTTPException(
@ -582,9 +586,9 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_tools_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
tools = Tools.get_tool_by_id(id)
tools = Tools.get_tool_by_id(id, db=db)
if tools:
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
@ -605,9 +609,9 @@ async def get_tools_user_valves_spec_by_id(
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_tools_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
tools = Tools.get_tool_by_id(id)
tools = Tools.get_tool_by_id(id, db=db)
if tools:
if id in request.app.state.TOOLS:
@ -624,7 +628,7 @@ async def update_tools_user_valves_by_id(
user_valves = UserValves(**form_data)
user_valves_dict = user_valves.model_dump(exclude_unset=True)
Tools.update_user_valves_by_id_and_user_id(
id, user.id, user_valves_dict
id, user.id, user_valves_dict, db=db
)
return user_valves_dict
except Exception as e:

View file

@ -1,5 +1,6 @@
import logging
from typing import Optional
from sqlalchemy.orm import Session
import base64
import io
@ -29,6 +30,7 @@ from open_webui.models.users import (
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import STATIC_DIR
from open_webui.internal.db import get_session
from open_webui.utils.auth import (
@ -60,6 +62,7 @@ async def get_users(
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
limit = PAGE_ITEM_COUNT
@ -74,7 +77,9 @@ async def get_users(
if direction:
filter["direction"] = direction
result = Users.get_users(filter=filter, skip=skip, limit=limit)
filter["direction"] = direction
result = Users.get_users(filter=filter, skip=skip, limit=limit, db=db)
users = result["users"]
total = result["total"]
@ -85,7 +90,8 @@ async def get_users(
**{
**user.model_dump(),
"group_ids": [
group.id for group in Groups.get_groups_by_member_id(user.id)
group.id
for group in Groups.get_groups_by_member_id(user.id, db=db)
],
}
)
@ -98,8 +104,9 @@ async def get_users(
@router.get("/all", response_model=UserInfoListResponse)
async def get_all_users(
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
return Users.get_users()
return Users.get_users(db=db)
@router.get("/search", response_model=UserInfoListResponse)
@ -109,16 +116,13 @@ async def search_users(
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
limit = PAGE_ITEM_COUNT
page = max(1, page)
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
filter = {}
if query:
filter["query"] = query
@ -127,7 +131,7 @@ async def search_users(
if direction:
filter["direction"] = direction
return Users.get_users(filter=filter, skip=skip, limit=limit)
return Users.get_users(filter=filter, skip=skip, limit=limit, db=db)
############################
@ -136,8 +140,10 @@ async def search_users(
@router.get("/groups")
async def get_user_groups(user=Depends(get_verified_user)):
return Groups.get_groups_by_member_id(user.id)
async def get_user_groups(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return Groups.get_groups_by_member_id(user.id, db=db)
############################
@ -146,9 +152,13 @@ async def get_user_groups(user=Depends(get_verified_user)):
@router.get("/permissions")
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
async def get_user_permissisions(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
user.id, request.app.state.config.USER_PERMISSIONS, db=db
)
return user_permissions
@ -256,8 +266,10 @@ async def update_default_user_permissions(
@router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
async def get_user_settings_by_session_user(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
user = Users.get_user_by_id(user.id, db=db)
if user:
return user.settings
else:
@ -274,7 +286,10 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/settings/update", response_model=UserSettings)
async def update_user_settings_by_session_user(
request: Request, form_data: UserSettings, user=Depends(get_verified_user)
request: Request,
form_data: UserSettings,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
updated_user_settings = form_data.model_dump()
if (
@ -289,7 +304,7 @@ async def update_user_settings_by_session_user(
# If the user is not an admin and does not have permission to use tool servers, remove the key
updated_user_settings["ui"].pop("toolServers", None)
user = Users.update_user_settings_by_id(user.id, updated_user_settings)
user = Users.update_user_settings_by_id(user.id, updated_user_settings, db=db)
if user:
return user.settings
else:
@ -305,8 +320,10 @@ async def update_user_settings_by_session_user(
@router.get("/user/status")
async def get_user_status_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
async def get_user_status_by_session_user(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
user = Users.get_user_by_id(user.id, db=db)
if user:
return user
else:
@ -323,11 +340,13 @@ async def get_user_status_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/status/update")
async def update_user_status_by_session_user(
form_data: UserStatus, user=Depends(get_verified_user)
form_data: UserStatus,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
user = Users.get_user_by_id(user.id)
user = Users.get_user_by_id(user.id, db=db)
if user:
user = Users.update_user_status_by_id(user.id, form_data)
user = Users.update_user_status_by_id(user.id, form_data, db=db)
return user
else:
raise HTTPException(
@ -342,8 +361,10 @@ async def update_user_status_by_session_user(
@router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
async def get_user_info_by_session_user(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
user = Users.get_user_by_id(user.id, db=db)
if user:
return user.info
else:
@ -360,14 +381,16 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/info/update", response_model=Optional[dict])
async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user)
form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
user = Users.get_user_by_id(user.id)
user = Users.get_user_by_id(user.id, db=db)
if user:
if user.info is None:
user.info = {}
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
user = Users.update_user_by_id(
user.id, {"info": {**user.info, **form_data}}, db=db
)
if user:
return user.info
else:
@ -397,7 +420,9 @@ class UserActiveResponse(UserStatus):
@router.get("/{user_id}", response_model=UserActiveResponse)
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
async def get_user_by_id(
user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
# Check if user_id is a shared chat
# If it is, get the user_id from the chat
if user_id.startswith("shared-"):
@ -411,14 +436,14 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
user = Users.get_user_by_id(user_id)
user = Users.get_user_by_id(user_id, db=db)
if user:
groups = Groups.get_groups_by_member_id(user_id)
groups = Groups.get_groups_by_member_id(user_id, db=db)
return UserActiveResponse(
**{
**user.model_dump(),
"groups": [{"id": group.id, "name": group.name} for group in groups],
"is_active": Users.is_user_active(user_id),
"is_active": Users.is_user_active(user_id, db=db),
}
)
else:
@ -429,8 +454,10 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@router.get("/{user_id}/oauth/sessions")
async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user)):
sessions = OAuthSessions.get_sessions_by_user_id(user_id)
async def get_user_oauth_sessions_by_id(
user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
sessions = OAuthSessions.get_sessions_by_user_id(user_id, db=db)
if sessions and len(sessions) > 0:
return sessions
else:
@ -446,8 +473,10 @@ async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_use
@router.get("/{user_id}/profile/image")
async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)):
user = Users.get_user_by_id(user_id)
async def get_user_profile_image_by_id(
user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
user = Users.get_user_by_id(user_id, db=db)
if user:
if user.profile_image_url:
# check if it's url or base64
@ -484,9 +513,11 @@ async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_u
@router.get("/{user_id}/active", response_model=dict)
async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)):
async def get_user_active_status_by_id(
user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
return {
"active": Users.is_user_active(user_id),
"active": Users.is_user_active(user_id, db=db),
}
@ -500,10 +531,11 @@ async def update_user_by_id(
user_id: str,
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
# Prevent modification of the primary admin user by other admins
try:
first_user = Users.get_first_user()
first_user = Users.get_first_user(db=db)
if first_user:
if user_id == first_user.id:
if session_user.id != user_id:
@ -527,11 +559,11 @@ async def update_user_by_id(
detail="Could not verify primary admin status.",
)
user = Users.get_user_by_id(user_id)
user = Users.get_user_by_id(user_id, db=db)
if user:
if form_data.email.lower() != user.email:
email_user = Users.get_user_by_email(form_data.email.lower())
email_user = Users.get_user_by_email(form_data.email.lower(), db=db)
if email_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -545,9 +577,9 @@ async def update_user_by_id(
raise HTTPException(400, detail=str(e))
hashed = get_password_hash(form_data.password)
Auths.update_user_password_by_id(user_id, hashed)
Auths.update_user_password_by_id(user_id, hashed, db=db)
Auths.update_email_by_id(user_id, form_data.email.lower())
Auths.update_email_by_id(user_id, form_data.email.lower(), db=db)
updated_user = Users.update_user_by_id(
user_id,
{
@ -556,6 +588,7 @@ async def update_user_by_id(
"email": form_data.email.lower(),
"profile_image_url": form_data.profile_image_url,
},
db=db,
)
if updated_user:
@ -578,10 +611,12 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
async def delete_user_by_id(
user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
# Prevent deletion of the primary admin user
try:
first_user = Users.get_first_user()
first_user = Users.get_first_user(db=db)
if first_user and user_id == first_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@ -595,7 +630,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
)
if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
result = Auths.delete_auth_by_id(user_id, db=db)
if result:
return True
@ -618,5 +653,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
@router.get("/{user_id}/groups")
async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)):
return Groups.get_groups_by_member_id(user_id)
async def get_user_groups_by_id(
user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
):
return Groups.get_groups_by_member_id(user_id, db=db)

View file

@ -28,6 +28,7 @@ def fill_missing_permissions(
def get_permissions(
user_id: str,
default_permissions: Dict[str, Any],
db: Optional[Any] = None,
) -> Dict[str, Any]:
"""
Get all permissions for a user by combining the permissions of all groups the user is a member of.
@ -53,7 +54,7 @@ def get_permissions(
) # Use the most permissive value (True > False)
return permissions
user_groups = Groups.get_groups_by_member_id(user_id)
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
# Deep copy default permissions to avoid modifying the original dict
permissions = json.loads(json.dumps(default_permissions))
@ -72,6 +73,7 @@ def has_permission(
user_id: str,
permission_key: str,
default_permissions: Dict[str, Any] = {},
db: Optional[Any] = None,
) -> bool:
"""
Check if a user has a specific permission by checking the group permissions
@ -92,7 +94,7 @@ def has_permission(
permission_hierarchy = permission_key.split(".")
# Retrieve user group permissions
user_groups = Groups.get_groups_by_member_id(user_id)
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
for group in user_groups:
if get_permission(group.permissions or {}, permission_hierarchy):
@ -127,6 +129,7 @@ def has_access(
access_control: Optional[dict] = None,
user_group_ids: Optional[Set[str]] = None,
strict: bool = True,
db: Optional[Any] = None,
) -> bool:
if access_control is None:
if strict:
@ -135,7 +138,7 @@ def has_access(
return True
if user_group_ids is None:
user_groups = Groups.get_groups_by_member_id(user_id)
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
user_group_ids = {group.id for group in user_groups}
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
@ -152,10 +155,10 @@ def has_access(
# Get all users with access to a resource
def get_users_with_access(
type: str = "write", access_control: Optional[dict] = None
type: str = "write", access_control: Optional[dict] = None, db: Optional[Any] = None
) -> list[UserModel]:
if access_control is None:
result = Users.get_users(filter={"roles": ["!pending"]})
result = Users.get_users(filter={"roles": ["!pending"]}, db=db)
return result.get("users", [])
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
@ -167,8 +170,8 @@ def get_users_with_access(
user_ids_with_access = set(permitted_user_ids)
group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids)
group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids, db=db)
for user_ids in group_user_ids_map.values():
user_ids_with_access.update(user_ids)
return Users.get_users_by_user_ids(list(user_ids_with_access))
return Users.get_users_by_user_ids(list(user_ids_with_access), db=db)

View file

@ -42,6 +42,8 @@ from open_webui.env import (
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from open_webui.internal.db import get_session
log = logging.getLogger(__name__)
@ -271,6 +273,7 @@ async def get_current_user(
response: Response,
background_tasks: BackgroundTasks,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
db: Session = Depends(get_session),
):
token = None
@ -285,7 +288,7 @@ async def get_current_user(
# auth by api key
if token.startswith("sk-"):
user = get_current_user_by_api_key(request, token)
user = get_current_user_by_api_key(request, token, db=db)
# Add user info to current span
current_span = trace.get_current_span()
@ -314,7 +317,7 @@ async def get_current_user(
detail="Invalid token",
)
user = Users.get_user_by_id(data["id"])
user = Users.get_user_by_id(data["id"], db=db)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -364,8 +367,8 @@ async def get_current_user(
raise e
def get_current_user_by_api_key(request, api_key: str):
user = Users.get_user_by_api_key(api_key)
def get_current_user_by_api_key(request, api_key: str, db: Session = None):
user = Users.get_user_by_api_key(api_key, db=db)
if user is None:
raise HTTPException(
@ -393,7 +396,7 @@ def get_current_user_by_api_key(request, api_key: str):
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key")
Users.update_last_active_by_id(user.id)
Users.update_last_active_by_id(user.id, db=db)
return user

View file

@ -7,6 +7,7 @@ log = logging.getLogger(__name__)
def apply_default_group_assignment(
default_group_id: str,
user_id: str,
db=None,
) -> None:
"""
Apply default group assignment to a user if default_group_id is provided.
@ -17,7 +18,7 @@ def apply_default_group_assignment(
"""
if default_group_id:
try:
Groups.add_users_to_group(default_group_id, [user_id])
Groups.add_users_to_group(default_group_id, [user_id], db=db)
except Exception as e:
log.error(
f"Failed to add user {user_id} to default group {default_group_id}: {e}"

View file

@ -1336,7 +1336,7 @@ class OAuthManager:
return await client.authorize_redirect(request, redirect_uri, **kwargs)
async def handle_callback(self, request, provider, response):
async def handle_callback(self, request, provider, response, db=None):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
@ -1461,20 +1461,20 @@ class OAuthManager:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider, sub)
user = Users.get_user_by_oauth_sub(provider, sub, db=db)
if not user:
# If the user does not exist, check if merging is enabled
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
# Check if the user exists by email
user = Users.get_user_by_email(email)
user = Users.get_user_by_email(email, db=db)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_by_id(user.id, provider, sub)
Users.update_user_oauth_by_id(user.id, provider, sub, db=db)
if user:
determined_role = self.get_user_role(user, user_data)
if user.role != determined_role:
Users.update_user_role_by_id(user.id, determined_role)
Users.update_user_role_by_id(user.id, determined_role, db=db)
# Update the user object in memory as well,
# to avoid problems with the ENABLE_OAUTH_GROUP_MANAGEMENT check below
user.role = determined_role
@ -1491,14 +1491,14 @@ class OAuthManager:
)
if processed_picture_url != user.profile_image_url:
Users.update_user_profile_image_url_by_id(
user.id, processed_picture_url
user.id, processed_picture_url, db=db
)
log.debug(f"Updated profile picture for user {user.email}")
else:
# If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(email)
existing_user = Users.get_user_by_email(email, db=db)
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@ -1529,6 +1529,7 @@ class OAuthManager:
profile_image_url=picture_url,
role=self.get_user_role(None, user_data),
oauth=oauth_data,
db=db,
)
if auth_manager_config.WEBHOOK_URL:
@ -1544,8 +1545,7 @@ class OAuthManager:
)
apply_default_group_assignment(
request.app.state.config.DEFAULT_GROUP_ID,
user.id,
request.app.state.config.DEFAULT_GROUP_ID, user.id, db=db
)
else:
@ -1616,15 +1616,16 @@ class OAuthManager:
token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
# Clean up any existing sessions for this user/provider first
sessions = OAuthSessions.get_sessions_by_user_id(user.id)
sessions = OAuthSessions.get_sessions_by_user_id(user.id, db=db)
for session in sessions:
if session.provider == provider:
OAuthSessions.delete_session_by_id(session.id)
OAuthSessions.delete_session_by_id(session.id, db=db)
session = OAuthSessions.create_session(
user_id=user.id,
provider=provider,
token=token,
db=db,
)
response.set_cookie(