mirror of
https://github.com/open-webui/open-webui.git
synced 2026-01-01 22:25:20 +00:00
refac/enh: db session sharing
This commit is contained in:
parent
6dd0f99b90
commit
b1d0f00d8c
23 changed files with 1173 additions and 663 deletions
|
|
@ -19,7 +19,7 @@ from open_webui.env import (
|
|||
from peewee_migrate import Router
|
||||
from sqlalchemy import Dialect, create_engine, MetaData, event, types
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker, Session
|
||||
from sqlalchemy.pool import QueuePool, NullPool
|
||||
from sqlalchemy.sql.type_api import _T
|
||||
from typing_extensions import Self
|
||||
|
|
@ -148,7 +148,7 @@ SessionLocal = sessionmaker(
|
|||
)
|
||||
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
|
||||
Base = declarative_base(metadata=metadata_obj)
|
||||
Session = scoped_session(SessionLocal)
|
||||
ScopedSession = scoped_session(SessionLocal)
|
||||
|
||||
|
||||
def get_session():
|
||||
|
|
@ -169,4 +169,3 @@ def get_db_context(db: Optional[Session] = None):
|
|||
else:
|
||||
with get_db() as session:
|
||||
yield session
|
||||
|
||||
|
|
|
|||
|
|
@ -102,7 +102,9 @@ from open_webui.routers.retrieval import (
|
|||
get_rf,
|
||||
)
|
||||
|
||||
from open_webui.internal.db import Session, engine
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import ScopedSession, engine, get_session
|
||||
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
|
|
@ -1324,7 +1326,7 @@ app.add_middleware(APIKeyRestrictionMiddleware)
|
|||
async def commit_session_after_request(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
# log.debug("Commit session after request")
|
||||
Session.commit()
|
||||
ScopedSession.commit()
|
||||
return response
|
||||
|
||||
|
||||
|
|
@ -2280,8 +2282,13 @@ async def oauth_login(provider: str, request: Request):
|
|||
# - Email addresses are considered unique, so we fail registration if the email address is already taken
|
||||
@app.get("/oauth/{provider}/login/callback")
|
||||
@app.get("/oauth/{provider}/callback") # Legacy endpoint
|
||||
async def oauth_login_callback(provider: str, request: Request, response: Response):
|
||||
return await oauth_manager.handle_callback(request, provider, response)
|
||||
async def oauth_login_callback(
|
||||
provider: str,
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
return await oauth_manager.handle_callback(request, provider, response, db=db)
|
||||
|
||||
|
||||
@app.get("/manifest.json")
|
||||
|
|
@ -2340,7 +2347,7 @@ async def healthcheck():
|
|||
|
||||
@app.get("/health/db")
|
||||
async def healthcheck_with_db():
|
||||
Session.execute(text("SELECT 1;")).all()
|
||||
ScopedSession.execute(text("SELECT 1;")).all()
|
||||
return {"status": True}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -30,27 +30,27 @@ from sqlalchemy.exc import NoSuchTableError
|
|||
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
|
||||
from sqlalchemy.dialects import registry
|
||||
|
||||
|
||||
class OpenGaussDialect(PGDialect_psycopg2):
|
||||
name = "opengauss"
|
||||
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
try:
|
||||
version = connection.exec_driver_sql("SELECT version()").scalar()
|
||||
if not version:
|
||||
return (9, 0, 0)
|
||||
|
||||
|
||||
match = re.search(
|
||||
r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?",
|
||||
version,
|
||||
re.IGNORECASE
|
||||
r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", version, re.IGNORECASE
|
||||
)
|
||||
if match:
|
||||
return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
|
||||
|
||||
|
||||
return super()._get_server_version_info(connection)
|
||||
except Exception:
|
||||
return (9, 0, 0)
|
||||
|
||||
|
||||
# Register dialect
|
||||
registry.register("opengauss", __name__, "OpenGaussDialect")
|
||||
|
||||
|
|
@ -78,6 +78,7 @@ Base = declarative_base()
|
|||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class DocumentChunk(Base):
|
||||
__tablename__ = "document_chunk"
|
||||
|
||||
|
|
@ -87,29 +88,30 @@ class DocumentChunk(Base):
|
|||
text = Column(Text, nullable=True)
|
||||
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
||||
|
||||
|
||||
class OpenGaussClient(VectorDBBase):
|
||||
def __init__(self) -> None:
|
||||
if not OPENGAUSS_DB_URL:
|
||||
from open_webui.internal.db import Session
|
||||
self.session = Session
|
||||
from open_webui.internal.db import ScopedSession
|
||||
|
||||
self.session = ScopedSession
|
||||
else:
|
||||
engine_kwargs = {
|
||||
"pool_pre_ping": True,
|
||||
"dialect": OpenGaussDialect()
|
||||
}
|
||||
|
||||
engine_kwargs = {"pool_pre_ping": True, "dialect": OpenGaussDialect()}
|
||||
|
||||
if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0:
|
||||
engine_kwargs.update({
|
||||
"pool_size": OPENGAUSS_POOL_SIZE,
|
||||
"max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
|
||||
"pool_timeout": OPENGAUSS_POOL_TIMEOUT,
|
||||
"pool_recycle": OPENGAUSS_POOL_RECYCLE,
|
||||
"poolclass": QueuePool
|
||||
})
|
||||
engine_kwargs.update(
|
||||
{
|
||||
"pool_size": OPENGAUSS_POOL_SIZE,
|
||||
"max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
|
||||
"pool_timeout": OPENGAUSS_POOL_TIMEOUT,
|
||||
"pool_recycle": OPENGAUSS_POOL_RECYCLE,
|
||||
"poolclass": QueuePool,
|
||||
}
|
||||
)
|
||||
else:
|
||||
engine_kwargs["poolclass"] = NullPool
|
||||
|
||||
engine = create_engine(OPENGAUSS_DB_URL,** engine_kwargs)
|
||||
engine = create_engine(OPENGAUSS_DB_URL, **engine_kwargs)
|
||||
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
||||
|
|
@ -160,7 +162,9 @@ class OpenGaussClient(VectorDBBase):
|
|||
else:
|
||||
raise Exception("The 'vector' column type is not Vector.")
|
||||
else:
|
||||
raise Exception("The 'vector' column does not exist in the 'document_chunk' table.")
|
||||
raise Exception(
|
||||
"The 'vector' column does not exist in the 'document_chunk' table."
|
||||
)
|
||||
|
||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||
current_length = len(vector)
|
||||
|
|
@ -185,7 +189,9 @@ class OpenGaussClient(VectorDBBase):
|
|||
new_items.append(new_chunk)
|
||||
self.session.bulk_save_objects(new_items)
|
||||
self.session.commit()
|
||||
log.info(f"Inserting {len(new_items)} items into collection '{collection_name}'.")
|
||||
log.info(
|
||||
f"Inserting {len(new_items)} items into collection '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Failed to insert data: {e}")
|
||||
|
|
@ -215,7 +221,9 @@ class OpenGaussClient(VectorDBBase):
|
|||
)
|
||||
self.session.add(new_chunk)
|
||||
self.session.commit()
|
||||
log.info(f"Inserting/updating {len(items)} items in collection '{collection_name}'.")
|
||||
log.info(
|
||||
f"Inserting/updating {len(items)} items in collection '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Failed to insert or update data.: {e}")
|
||||
|
|
@ -241,7 +249,9 @@ class OpenGaussClient(VectorDBBase):
|
|||
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
|
||||
query_vectors = (
|
||||
values(qid_col, q_vector_col)
|
||||
.data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)])
|
||||
.data(
|
||||
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
|
||||
)
|
||||
.alias("query_vectors")
|
||||
)
|
||||
|
||||
|
|
@ -249,13 +259,17 @@ class OpenGaussClient(VectorDBBase):
|
|||
DocumentChunk.id,
|
||||
DocumentChunk.text,
|
||||
DocumentChunk.vmetadata,
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label("distance"),
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
|
||||
"distance"
|
||||
),
|
||||
]
|
||||
|
||||
subq = (
|
||||
select(*result_fields)
|
||||
.where(DocumentChunk.collection_name == collection_name)
|
||||
.order_by(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
||||
.order_by(
|
||||
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
|
||||
)
|
||||
)
|
||||
if limit is not None:
|
||||
subq = subq.limit(limit)
|
||||
|
|
@ -368,7 +382,9 @@ class OpenGaussClient(VectorDBBase):
|
|||
query = query.filter(DocumentChunk.id.in_(ids))
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
query = query.filter(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
deleted = query.delete(synchronize_session=False)
|
||||
self.session.commit()
|
||||
log.info(f"Deleted {deleted} items from collection '{collection_name}'")
|
||||
|
|
@ -395,7 +411,8 @@ class OpenGaussClient(VectorDBBase):
|
|||
exists = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.collection_name == collection_name)
|
||||
.first() is not None
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
self.session.rollback()
|
||||
return exists
|
||||
|
|
@ -406,4 +423,4 @@ class OpenGaussClient(VectorDBBase):
|
|||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
self.delete(collection_name)
|
||||
log.info(f"Collection '{collection_name}' has been deleted")
|
||||
log.info(f"Collection '{collection_name}' has been deleted")
|
||||
|
|
|
|||
|
|
@ -90,9 +90,9 @@ class PgvectorClient(VectorDBBase):
|
|||
|
||||
# if no pgvector uri, use the existing database connection
|
||||
if not PGVECTOR_DB_URL:
|
||||
from open_webui.internal.db import Session
|
||||
from open_webui.internal.db import ScopedSession
|
||||
|
||||
self.session = Session
|
||||
self.session = ScopedSession
|
||||
else:
|
||||
if isinstance(PGVECTOR_POOL_SIZE, int):
|
||||
if PGVECTOR_POOL_SIZE > 0:
|
||||
|
|
|
|||
|
|
@ -62,6 +62,8 @@ from open_webui.utils.auth import (
|
|||
get_password_hash,
|
||||
get_http_authorization_cred,
|
||||
)
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
from open_webui.utils.access_control import get_permissions, has_permission
|
||||
from open_webui.utils.groups import apply_default_group_assignment
|
||||
|
|
@ -103,7 +105,10 @@ class SessionUserInfoResponse(SessionUserResponse, UserStatus):
|
|||
|
||||
@router.get("/", response_model=SessionUserInfoResponse)
|
||||
async def get_session_user(
|
||||
request: Request, response: Response, user=Depends(get_current_user)
|
||||
request: Request,
|
||||
response: Response,
|
||||
user=Depends(get_current_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
|
|
@ -137,7 +142,7 @@ async def get_session_user(
|
|||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
user.id, request.app.state.config.USER_PERMISSIONS, db=db
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -166,12 +171,15 @@ async def get_session_user(
|
|||
|
||||
@router.post("/update/profile", response_model=UserProfileImageResponse)
|
||||
async def update_profile(
|
||||
form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
|
||||
form_data: UpdateProfileForm,
|
||||
session_user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if session_user:
|
||||
user = Users.update_user_by_id(
|
||||
session_user.id,
|
||||
form_data.model_dump(),
|
||||
db=db,
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
|
@ -188,13 +196,17 @@ async def update_profile(
|
|||
|
||||
@router.post("/update/password", response_model=bool)
|
||||
async def update_password(
|
||||
form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
|
||||
form_data: UpdatePasswordForm,
|
||||
session_user=Depends(get_current_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
|
||||
if session_user:
|
||||
user = Auths.authenticate_user(
|
||||
session_user.email, lambda pw: verify_password(form_data.password, pw)
|
||||
session_user.email,
|
||||
lambda pw: verify_password(form_data.password, pw),
|
||||
db=db,
|
||||
)
|
||||
|
||||
if user:
|
||||
|
|
@ -203,7 +215,7 @@ async def update_password(
|
|||
except Exception as e:
|
||||
raise HTTPException(400, detail=str(e))
|
||||
hashed = get_password_hash(form_data.new_password)
|
||||
return Auths.update_user_password_by_id(user.id, hashed)
|
||||
return Auths.update_user_password_by_id(user.id, hashed, db=db)
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD)
|
||||
else:
|
||||
|
|
@ -214,7 +226,12 @@ async def update_password(
|
|||
# LDAP Authentication
|
||||
############################
|
||||
@router.post("/ldap", response_model=SessionUserResponse)
|
||||
async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
async def ldap_auth(
|
||||
request: Request,
|
||||
response: Response,
|
||||
form_data: LdapForm,
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
# Security checks FIRST - before loading any config
|
||||
if not request.app.state.config.ENABLE_LDAP:
|
||||
raise HTTPException(400, detail="LDAP authentication is not enabled")
|
||||
|
|
@ -400,12 +417,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
if not connection_user.bind():
|
||||
raise HTTPException(400, "Authentication failed.")
|
||||
|
||||
user = Users.get_user_by_email(email)
|
||||
user = Users.get_user_by_email(email, db=db)
|
||||
if not user:
|
||||
try:
|
||||
role = (
|
||||
"admin"
|
||||
if not Users.has_users()
|
||||
if not Users.has_users(db=db)
|
||||
else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
|
||||
|
|
@ -414,6 +431,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
password=str(uuid.uuid4()),
|
||||
name=cn,
|
||||
role=role,
|
||||
db=db,
|
||||
)
|
||||
|
||||
if not user:
|
||||
|
|
@ -424,6 +442,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
apply_default_group_assignment(
|
||||
request.app.state.config.DEFAULT_GROUP_ID,
|
||||
user.id,
|
||||
db=db,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -434,7 +453,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
500, detail="Internal error occurred during LDAP user creation."
|
||||
)
|
||||
|
||||
user = Auths.authenticate_user_by_email(email)
|
||||
user = Auths.authenticate_user_by_email(email, db=db)
|
||||
|
||||
if user:
|
||||
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
|
||||
|
|
@ -464,7 +483,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
user.id, request.app.state.config.USER_PERMISSIONS, db=db
|
||||
)
|
||||
|
||||
if (
|
||||
|
|
@ -473,9 +492,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
and user_groups
|
||||
):
|
||||
if ENABLE_LDAP_GROUP_CREATION:
|
||||
Groups.create_groups_by_group_names(user.id, user_groups)
|
||||
Groups.create_groups_by_group_names(user.id, user_groups, db=db)
|
||||
try:
|
||||
Groups.sync_groups_by_group_names(user.id, user_groups)
|
||||
Groups.sync_groups_by_group_names(user.id, user_groups, db=db)
|
||||
log.info(
|
||||
f"Successfully synced groups for user {user.id}: {user_groups}"
|
||||
)
|
||||
|
|
@ -508,7 +527,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
|
||||
|
||||
@router.post("/signin", response_model=SessionUserResponse)
|
||||
async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
async def signin(
|
||||
request: Request,
|
||||
response: Response,
|
||||
form_data: SigninForm,
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if not ENABLE_PASSWORD_AUTH:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
@ -529,14 +553,15 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
if not Users.get_user_by_email(email.lower()):
|
||||
if not Users.get_user_by_email(email.lower(), db=db):
|
||||
await signup(
|
||||
request,
|
||||
response,
|
||||
SignupForm(email=email, password=str(uuid.uuid4()), name=name),
|
||||
db=db,
|
||||
)
|
||||
|
||||
user = Auths.authenticate_user_by_email(email)
|
||||
user = Auths.authenticate_user_by_email(email, db=db)
|
||||
if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
|
||||
group_names = request.headers.get(
|
||||
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
|
||||
|
|
@ -544,28 +569,33 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
group_names = [name.strip() for name in group_names if name.strip()]
|
||||
|
||||
if group_names:
|
||||
Groups.sync_groups_by_group_names(user.id, group_names)
|
||||
Groups.sync_groups_by_group_names(user.id, group_names, db=db)
|
||||
|
||||
elif WEBUI_AUTH == False:
|
||||
admin_email = "admin@localhost"
|
||||
admin_password = "admin"
|
||||
|
||||
if Users.get_user_by_email(admin_email.lower()):
|
||||
if Users.get_user_by_email(admin_email.lower(), db=db):
|
||||
user = Auths.authenticate_user(
|
||||
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
|
||||
admin_email.lower(),
|
||||
lambda pw: verify_password(admin_password, pw),
|
||||
db=db,
|
||||
)
|
||||
else:
|
||||
if Users.has_users():
|
||||
if Users.has_users(db=db):
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
|
||||
|
||||
await signup(
|
||||
request,
|
||||
response,
|
||||
SignupForm(email=admin_email, password=admin_password, name="User"),
|
||||
db=db,
|
||||
)
|
||||
|
||||
user = Auths.authenticate_user(
|
||||
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
|
||||
admin_email.lower(),
|
||||
lambda pw: verify_password(admin_password, pw),
|
||||
db=db,
|
||||
)
|
||||
else:
|
||||
if signin_rate_limiter.is_limited(form_data.email.lower()):
|
||||
|
|
@ -584,7 +614,9 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
form_data.password = password_bytes.decode("utf-8", errors="ignore")
|
||||
|
||||
user = Auths.authenticate_user(
|
||||
form_data.email.lower(), lambda pw: verify_password(form_data.password, pw)
|
||||
form_data.email.lower(),
|
||||
lambda pw: verify_password(form_data.password, pw),
|
||||
db=db,
|
||||
)
|
||||
|
||||
if user:
|
||||
|
|
@ -616,7 +648,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
user.id, request.app.state.config.USER_PERMISSIONS, db=db
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -640,8 +672,13 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
|
||||
|
||||
@router.post("/signup", response_model=SessionUserResponse)
|
||||
async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
has_users = Users.has_users()
|
||||
async def signup(
|
||||
request: Request,
|
||||
response: Response,
|
||||
form_data: SignupForm,
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
has_users = Users.has_users(db=db)
|
||||
|
||||
if WEBUI_AUTH:
|
||||
if (
|
||||
|
|
@ -663,7 +700,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||
)
|
||||
|
||||
if Users.get_user_by_email(form_data.email.lower()):
|
||||
if Users.get_user_by_email(form_data.email.lower(), db=db):
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
try:
|
||||
|
|
@ -681,6 +718,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
form_data.name,
|
||||
form_data.profile_image_url,
|
||||
role,
|
||||
db=db,
|
||||
)
|
||||
|
||||
if user:
|
||||
|
|
@ -723,7 +761,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
user.id, request.app.state.config.USER_PERMISSIONS, db=db
|
||||
)
|
||||
|
||||
if not has_users:
|
||||
|
|
@ -733,6 +771,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
apply_default_group_assignment(
|
||||
request.app.state.config.DEFAULT_GROUP_ID,
|
||||
user.id,
|
||||
db=db,
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -754,7 +793,9 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
|
||||
|
||||
@router.get("/signout")
|
||||
async def signout(request: Request, response: Response):
|
||||
async def signout(
|
||||
request: Request, response: Response, db: Session = Depends(get_session)
|
||||
):
|
||||
|
||||
# get auth token from headers or cookies
|
||||
token = None
|
||||
|
|
@ -776,7 +817,7 @@ async def signout(request: Request, response: Response):
|
|||
if oauth_session_id:
|
||||
response.delete_cookie("oauth_session_id")
|
||||
|
||||
session = OAuthSessions.get_session_by_id(oauth_session_id)
|
||||
session = OAuthSessions.get_session_by_id(oauth_session_id, db=db)
|
||||
oauth_server_metadata_url = (
|
||||
request.app.state.oauth_manager.get_server_metadata_url(session.provider)
|
||||
if session
|
||||
|
|
@ -839,14 +880,17 @@ async def signout(request: Request, response: Response):
|
|||
|
||||
@router.post("/add", response_model=SigninResponse)
|
||||
async def add_user(
|
||||
request: Request, form_data: AddUserForm, user=Depends(get_admin_user)
|
||||
request: Request,
|
||||
form_data: AddUserForm,
|
||||
user=Depends(get_admin_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if not validate_email_format(form_data.email.lower()):
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||
)
|
||||
|
||||
if Users.get_user_by_email(form_data.email.lower()):
|
||||
if Users.get_user_by_email(form_data.email.lower(), db=db):
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
try:
|
||||
|
|
@ -862,12 +906,14 @@ async def add_user(
|
|||
form_data.name,
|
||||
form_data.profile_image_url,
|
||||
form_data.role,
|
||||
db=db,
|
||||
)
|
||||
|
||||
if user:
|
||||
apply_default_group_assignment(
|
||||
request.app.state.config.DEFAULT_GROUP_ID,
|
||||
user.id,
|
||||
db=db,
|
||||
)
|
||||
|
||||
token = create_token(data={"id": user.id})
|
||||
|
|
@ -895,7 +941,9 @@ async def add_user(
|
|||
|
||||
|
||||
@router.get("/admin/details")
|
||||
async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
||||
async def get_admin_details(
|
||||
request: Request, user=Depends(get_current_user), db: Session = Depends(get_session)
|
||||
):
|
||||
if request.app.state.config.SHOW_ADMIN_DETAILS:
|
||||
admin_email = request.app.state.config.ADMIN_EMAIL
|
||||
admin_name = None
|
||||
|
|
@ -903,11 +951,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
|||
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
|
||||
|
||||
if admin_email:
|
||||
admin = Users.get_user_by_email(admin_email)
|
||||
admin = Users.get_user_by_email(admin_email, db=db)
|
||||
if admin:
|
||||
admin_name = admin.name
|
||||
else:
|
||||
admin = Users.get_first_user()
|
||||
admin = Users.get_first_user(db=db)
|
||||
if admin:
|
||||
admin_email = admin.email
|
||||
admin_name = admin.name
|
||||
|
|
@ -1149,7 +1197,9 @@ async def update_ldap_config(
|
|||
|
||||
# create api key
|
||||
@router.post("/api_key", response_model=ApiKey)
|
||||
async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
||||
async def generate_api_key(
|
||||
request: Request, user=Depends(get_current_user), db: Session = Depends(get_session)
|
||||
):
|
||||
if not request.app.state.config.ENABLE_API_KEYS or not has_permission(
|
||||
user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS
|
||||
):
|
||||
|
|
@ -1159,7 +1209,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
|||
)
|
||||
|
||||
api_key = create_api_key()
|
||||
success = Users.update_user_api_key_by_id(user.id, api_key)
|
||||
success = Users.update_user_api_key_by_id(user.id, api_key, db=db)
|
||||
|
||||
if success:
|
||||
return {
|
||||
|
|
@ -1171,14 +1221,18 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
|||
|
||||
# delete api key
|
||||
@router.delete("/api_key", response_model=bool)
|
||||
async def delete_api_key(user=Depends(get_current_user)):
|
||||
return Users.delete_user_api_key_by_id(user.id)
|
||||
async def delete_api_key(
|
||||
user=Depends(get_current_user), db: Session = Depends(get_session)
|
||||
):
|
||||
return Users.delete_user_api_key_by_id(user.id, db=db)
|
||||
|
||||
|
||||
# get api key
|
||||
@router.get("/api_key", response_model=ApiKey)
|
||||
async def get_api_key(user=Depends(get_current_user)):
|
||||
api_key = Users.get_user_api_key_by_id(user.id)
|
||||
async def get_api_key(
|
||||
user=Depends(get_current_user), db: Session = Depends(get_session)
|
||||
):
|
||||
api_key = Users.get_user_api_key_by_id(user.id, db=db)
|
||||
if api_key:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
import asyncio
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
|
@ -23,6 +24,7 @@ from open_webui.models.chats import (
|
|||
)
|
||||
from open_webui.models.tags import TagModel, Tags
|
||||
from open_webui.models.folders import Folders
|
||||
from open_webui.internal.db import get_session
|
||||
|
||||
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
|
@ -49,6 +51,7 @@ def get_session_user_chat_list(
|
|||
page: Optional[int] = None,
|
||||
include_pinned: Optional[bool] = False,
|
||||
include_folders: Optional[bool] = False,
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
if page is not None:
|
||||
|
|
@ -61,10 +64,14 @@ def get_session_user_chat_list(
|
|||
include_pinned=include_pinned,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
db=db,
|
||||
)
|
||||
else:
|
||||
return Chats.get_chat_title_id_list_by_user_id(
|
||||
user.id, include_folders=include_folders, include_pinned=include_pinned
|
||||
user.id,
|
||||
include_folders=include_folders,
|
||||
include_pinned=include_pinned,
|
||||
db=db,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -84,12 +91,13 @@ def get_session_user_chat_usage_stats(
|
|||
items_per_page: Optional[int] = 50,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
limit = items_per_page
|
||||
skip = (page - 1) * limit
|
||||
|
||||
result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit)
|
||||
result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit, db=db)
|
||||
|
||||
chats = result.items
|
||||
total = result.total
|
||||
|
|
@ -216,6 +224,7 @@ class ChatStatsExportList(BaseModel):
|
|||
|
||||
def _process_chat_for_export(chat) -> Optional[ChatStatsExport]:
|
||||
try:
|
||||
|
||||
def get_message_content_length(message):
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, str):
|
||||
|
|
@ -348,7 +357,9 @@ def _process_chat_for_export(chat) -> Optional[ChatStatsExport]:
|
|||
return None
|
||||
|
||||
|
||||
def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
|
||||
def calculate_chat_stats(
|
||||
user_id, skip=0, limit=10, filter=None, db: Optional[Session] = None
|
||||
):
|
||||
if filter is None:
|
||||
filter = {}
|
||||
|
||||
|
|
@ -357,6 +368,7 @@ def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
|
|||
skip=skip,
|
||||
limit=limit,
|
||||
filter=filter,
|
||||
db=db,
|
||||
)
|
||||
|
||||
chat_stats_export_list = []
|
||||
|
|
@ -368,14 +380,21 @@ def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
|
|||
return chat_stats_export_list, result.total
|
||||
|
||||
|
||||
async def generate_chat_stats_jsonl_generator(user_id, filter):
|
||||
async def generate_chat_stats_jsonl_generator(
|
||||
user_id, filter, db: Optional[Session] = None
|
||||
):
|
||||
skip = 0
|
||||
limit = CHAT_EXPORT_PAGE_ITEM_COUNT
|
||||
|
||||
while True:
|
||||
# Use asyncio.to_thread to make the blocking DB call non-blocking
|
||||
result = await asyncio.to_thread(
|
||||
Chats.get_chats_by_user_id, user_id, filter=filter, skip=skip, limit=limit
|
||||
Chats.get_chats_by_user_id,
|
||||
user_id,
|
||||
filter=filter,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
db=db,
|
||||
)
|
||||
if not result.items:
|
||||
break
|
||||
|
|
@ -386,7 +405,7 @@ async def generate_chat_stats_jsonl_generator(user_id, filter):
|
|||
if chat_stat:
|
||||
yield chat_stat.model_dump_json() + "\n"
|
||||
except Exception as e:
|
||||
log.exception(f"Error processing chat {chat.id}: {e}")
|
||||
log.exception(f"Error processing chat {chat.id}: {e}")
|
||||
|
||||
skip += limit
|
||||
|
||||
|
|
@ -400,6 +419,7 @@ async def export_chat_stats(
|
|||
page: Optional[int] = 1,
|
||||
stream: bool = False,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
# Check if the user has permission to share/export chats
|
||||
if (user.role != "admin") and (
|
||||
|
|
@ -415,7 +435,7 @@ async def export_chat_stats(
|
|||
filter = {"order_by": "created_at", "direction": "asc"}
|
||||
|
||||
if chat_id:
|
||||
chat = Chats.get_chat_by_id(chat_id)
|
||||
chat = Chats.get_chat_by_id(chat_id, db=db)
|
||||
if chat:
|
||||
filter["start_time"] = chat.created_at
|
||||
|
||||
|
|
@ -426,7 +446,7 @@ async def export_chat_stats(
|
|||
|
||||
if stream:
|
||||
return StreamingResponse(
|
||||
generate_chat_stats_jsonl_generator(user.id, filter),
|
||||
generate_chat_stats_jsonl_generator(user.id, filter, db=db),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=chat-stats-export-{user.id}.jsonl"
|
||||
|
|
@ -437,7 +457,7 @@ async def export_chat_stats(
|
|||
skip = (page - 1) * limit
|
||||
|
||||
chat_stats_export_list, total = await asyncio.to_thread(
|
||||
calculate_chat_stats, user.id, skip, limit, filter
|
||||
calculate_chat_stats, user.id, skip, limit, filter, db=db
|
||||
)
|
||||
|
||||
return ChatStatsExportList(
|
||||
|
|
@ -452,7 +472,11 @@ async def export_chat_stats(
|
|||
|
||||
|
||||
@router.delete("/", response_model=bool)
|
||||
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
|
||||
async def delete_all_user_chats(
|
||||
request: Request,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
|
||||
if user.role == "user" and not has_permission(
|
||||
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
|
||||
|
|
@ -462,7 +486,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
result = Chats.delete_chats_by_user_id(user.id)
|
||||
result = Chats.delete_chats_by_user_id(user.id, db=db)
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -479,6 +503,7 @@ async def get_user_chat_list_by_user_id(
|
|||
order_by: Optional[str] = None,
|
||||
direction: Optional[str] = None,
|
||||
user=Depends(get_admin_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if not ENABLE_ADMIN_CHAT_ACCESS:
|
||||
raise HTTPException(
|
||||
|
|
@ -501,7 +526,7 @@ async def get_user_chat_list_by_user_id(
|
|||
filter["direction"] = direction
|
||||
|
||||
return Chats.get_chat_list_by_user_id(
|
||||
user_id, include_archived=True, filter=filter, skip=skip, limit=limit
|
||||
user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -511,9 +536,13 @@ async def get_user_chat_list_by_user_id(
|
|||
|
||||
|
||||
@router.post("/new", response_model=Optional[ChatResponse])
|
||||
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
||||
async def create_new_chat(
|
||||
form_data: ChatForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
chat = Chats.insert_new_chat(user.id, form_data)
|
||||
chat = Chats.insert_new_chat(user.id, form_data, db=db)
|
||||
return ChatResponse(**chat.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -528,9 +557,13 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.post("/import", response_model=list[ChatResponse])
|
||||
async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_user)):
|
||||
async def import_chats(
|
||||
form_data: ChatsImportForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
chats = Chats.import_chats(user.id, form_data.chats)
|
||||
chats = Chats.import_chats(user.id, form_data.chats, db=db)
|
||||
return chats
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -546,7 +579,10 @@ async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_use
|
|||
|
||||
@router.get("/search", response_model=list[ChatTitleIdResponse])
|
||||
def search_user_chats(
|
||||
text: str, page: Optional[int] = None, user=Depends(get_verified_user)
|
||||
text: str,
|
||||
page: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if page is None:
|
||||
page = 1
|
||||
|
|
@ -557,7 +593,7 @@ def search_user_chats(
|
|||
chat_list = [
|
||||
ChatTitleIdResponse(**chat.model_dump())
|
||||
for chat in Chats.get_chats_by_user_id_and_search_text(
|
||||
user.id, text, skip=skip, limit=limit
|
||||
user.id, text, skip=skip, limit=limit, db=db
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -566,9 +602,9 @@ def search_user_chats(
|
|||
if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
|
||||
tag_id = words[0].replace("tag:", "")
|
||||
if len(chat_list) == 0:
|
||||
if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
|
||||
if Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db):
|
||||
log.debug(f"deleting tag: {tag_id}")
|
||||
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
|
||||
Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db)
|
||||
|
||||
return chat_list
|
||||
|
||||
|
|
@ -579,23 +615,30 @@ def search_user_chats(
|
|||
|
||||
|
||||
@router.get("/folder/{folder_id}", response_model=list[ChatResponse])
|
||||
async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)):
|
||||
async def get_chats_by_folder_id(
|
||||
folder_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
folder_ids = [folder_id]
|
||||
children_folders = Folders.get_children_folders_by_id_and_user_id(
|
||||
folder_id, user.id
|
||||
folder_id, user.id, db=db
|
||||
)
|
||||
if children_folders:
|
||||
folder_ids.extend([folder.id for folder in children_folders])
|
||||
|
||||
return [
|
||||
ChatResponse(**chat.model_dump())
|
||||
for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id)
|
||||
for chat in Chats.get_chats_by_folder_ids_and_user_id(
|
||||
folder_ids, user.id, db=db
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@router.get("/folder/{folder_id}/list")
|
||||
async def get_chat_list_by_folder_id(
|
||||
folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user)
|
||||
folder_id: str,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
limit = 10
|
||||
|
|
@ -604,7 +647,7 @@ async def get_chat_list_by_folder_id(
|
|||
return [
|
||||
{"title": chat.title, "id": chat.id, "updated_at": chat.updated_at}
|
||||
for chat in Chats.get_chats_by_folder_id_and_user_id(
|
||||
folder_id, user.id, skip=skip, limit=limit
|
||||
folder_id, user.id, skip=skip, limit=limit, db=db
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -621,10 +664,12 @@ async def get_chat_list_by_folder_id(
|
|||
|
||||
|
||||
@router.get("/pinned", response_model=list[ChatTitleIdResponse])
|
||||
async def get_user_pinned_chats(user=Depends(get_verified_user)):
|
||||
async def get_user_pinned_chats(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
return [
|
||||
ChatTitleIdResponse(**chat.model_dump())
|
||||
for chat in Chats.get_pinned_chats_by_user_id(user.id)
|
||||
for chat in Chats.get_pinned_chats_by_user_id(user.id, db=db)
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -634,10 +679,12 @@ async def get_user_pinned_chats(user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/all", response_model=list[ChatResponse])
|
||||
async def get_user_chats(user=Depends(get_verified_user)):
|
||||
async def get_user_chats(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
return [
|
||||
ChatResponse(**chat.model_dump())
|
||||
for chat in Chats.get_chats_by_user_id(user.id)
|
||||
for chat in Chats.get_chats_by_user_id(user.id, db=db)
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -647,10 +694,12 @@ async def get_user_chats(user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/all/archived", response_model=list[ChatResponse])
|
||||
async def get_user_archived_chats(user=Depends(get_verified_user)):
|
||||
async def get_user_archived_chats(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
return [
|
||||
ChatResponse(**chat.model_dump())
|
||||
for chat in Chats.get_archived_chats_by_user_id(user.id)
|
||||
for chat in Chats.get_archived_chats_by_user_id(user.id, db=db)
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -660,9 +709,11 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/all/tags", response_model=list[TagModel])
|
||||
async def get_all_user_tags(user=Depends(get_verified_user)):
|
||||
async def get_all_user_tags(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
try:
|
||||
tags = Tags.get_tags_by_user_id(user.id)
|
||||
tags = Tags.get_tags_by_user_id(user.id, db=db)
|
||||
return tags
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -677,13 +728,15 @@ async def get_all_user_tags(user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/all/db", response_model=list[ChatResponse])
|
||||
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
|
||||
async def get_all_user_chats_in_db(
|
||||
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
if not ENABLE_ADMIN_EXPORT:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
|
||||
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats(db=db)]
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -698,6 +751,7 @@ async def get_archived_session_user_chat_list(
|
|||
order_by: Optional[str] = None,
|
||||
direction: Optional[str] = None,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if page is None:
|
||||
page = 1
|
||||
|
|
@ -720,6 +774,7 @@ async def get_archived_session_user_chat_list(
|
|||
filter=filter,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
db=db,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -732,8 +787,10 @@ async def get_archived_session_user_chat_list(
|
|||
|
||||
|
||||
@router.post("/archive/all", response_model=bool)
|
||||
async def archive_all_chats(user=Depends(get_verified_user)):
|
||||
return Chats.archive_all_chats_by_user_id(user.id)
|
||||
async def archive_all_chats(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
return Chats.archive_all_chats_by_user_id(user.id, db=db)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -742,8 +799,10 @@ async def archive_all_chats(user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.post("/unarchive/all", response_model=bool)
|
||||
async def unarchive_all_chats(user=Depends(get_verified_user)):
|
||||
return Chats.unarchive_all_chats_by_user_id(user.id)
|
||||
async def unarchive_all_chats(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
return Chats.unarchive_all_chats_by_user_id(user.id, db=db)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -752,16 +811,18 @@ async def unarchive_all_chats(user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
|
||||
async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
|
||||
async def get_shared_chat_by_id(
|
||||
share_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
if user.role == "pending":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS):
|
||||
chat = Chats.get_chat_by_share_id(share_id)
|
||||
chat = Chats.get_chat_by_share_id(share_id, db=db)
|
||||
elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
|
||||
chat = Chats.get_chat_by_id(share_id)
|
||||
chat = Chats.get_chat_by_id(share_id, db=db)
|
||||
|
||||
if chat:
|
||||
return ChatResponse(**chat.model_dump())
|
||||
|
|
@ -788,13 +849,15 @@ class TagFilterForm(TagForm):
|
|||
|
||||
@router.post("/tags", response_model=list[ChatTitleIdResponse])
|
||||
async def get_user_chat_list_by_tag_name(
|
||||
form_data: TagFilterForm, user=Depends(get_verified_user)
|
||||
form_data: TagFilterForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
chats = Chats.get_chat_list_by_user_id_and_tag_name(
|
||||
user.id, form_data.name, form_data.skip, form_data.limit
|
||||
user.id, form_data.name, form_data.skip, form_data.limit, db=db
|
||||
)
|
||||
if len(chats) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
|
||||
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db)
|
||||
|
||||
return chats
|
||||
|
||||
|
|
@ -805,8 +868,10 @@ async def get_user_chat_list_by_tag_name(
|
|||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[ChatResponse])
|
||||
async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
async def get_chat_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
|
||||
if chat:
|
||||
return ChatResponse(**chat.model_dump())
|
||||
|
|
@ -824,12 +889,15 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
@router.post("/{id}", response_model=Optional[ChatResponse])
|
||||
async def update_chat_by_id(
|
||||
id: str, form_data: ChatForm, user=Depends(get_verified_user)
|
||||
id: str,
|
||||
form_data: ChatForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
updated_chat = {**chat.chat, **form_data.chat}
|
||||
chat = Chats.update_chat_by_id(id, updated_chat)
|
||||
chat = Chats.update_chat_by_id(id, updated_chat, db=db)
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -847,9 +915,13 @@ class MessageForm(BaseModel):
|
|||
|
||||
@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
|
||||
async def update_chat_message_by_id(
|
||||
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
|
||||
id: str,
|
||||
message_id: str,
|
||||
form_data: MessageForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
chat = Chats.get_chat_by_id(id, db=db)
|
||||
|
||||
if not chat:
|
||||
raise HTTPException(
|
||||
|
|
@ -869,6 +941,7 @@ async def update_chat_message_by_id(
|
|||
{
|
||||
"content": form_data.content,
|
||||
},
|
||||
db=db,
|
||||
)
|
||||
|
||||
event_emitter = get_event_emitter(
|
||||
|
|
@ -905,9 +978,13 @@ class EventForm(BaseModel):
|
|||
|
||||
@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
|
||||
async def send_chat_message_event_by_id(
|
||||
id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
|
||||
id: str,
|
||||
message_id: str,
|
||||
form_data: EventForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
chat = Chats.get_chat_by_id(id, db=db)
|
||||
|
||||
if not chat:
|
||||
raise HTTPException(
|
||||
|
|
@ -945,14 +1022,19 @@ async def send_chat_message_event_by_id(
|
|||
|
||||
|
||||
@router.delete("/{id}", response_model=bool)
|
||||
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
||||
async def delete_chat_by_id(
|
||||
request: Request,
|
||||
id: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role == "admin":
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
chat = Chats.get_chat_by_id(id, db=db)
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
|
||||
|
||||
result = Chats.delete_chat_by_id(id)
|
||||
result = Chats.delete_chat_by_id(id, db=db)
|
||||
|
||||
return result
|
||||
else:
|
||||
|
|
@ -964,12 +1046,12 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
chat = Chats.get_chat_by_id(id, db=db)
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
|
||||
|
||||
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
|
||||
result = Chats.delete_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -979,8 +1061,10 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
|||
|
||||
|
||||
@router.get("/{id}/pinned", response_model=Optional[bool])
|
||||
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
async def get_pinned_status_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
return chat.pinned
|
||||
else:
|
||||
|
|
@ -995,10 +1079,12 @@ async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.post("/{id}/pin", response_model=Optional[ChatResponse])
|
||||
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
async def pin_chat_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
chat = Chats.toggle_chat_pinned_by_id(id)
|
||||
chat = Chats.toggle_chat_pinned_by_id(id, db=db)
|
||||
return chat
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -1017,9 +1103,12 @@ class CloneForm(BaseModel):
|
|||
|
||||
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
|
||||
async def clone_chat_by_id(
|
||||
form_data: CloneForm, id: str, user=Depends(get_verified_user)
|
||||
form_data: CloneForm,
|
||||
id: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
updated_chat = {
|
||||
**chat.chat,
|
||||
|
|
@ -1040,6 +1129,7 @@ async def clone_chat_by_id(
|
|||
}
|
||||
)
|
||||
],
|
||||
db=db,
|
||||
)
|
||||
|
||||
if chats:
|
||||
|
|
@ -1062,12 +1152,14 @@ async def clone_chat_by_id(
|
|||
|
||||
|
||||
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
|
||||
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
async def clone_shared_chat_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
|
||||
if user.role == "admin":
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
chat = Chats.get_chat_by_id(id, db=db)
|
||||
else:
|
||||
chat = Chats.get_chat_by_share_id(id)
|
||||
chat = Chats.get_chat_by_share_id(id, db=db)
|
||||
|
||||
if chat:
|
||||
updated_chat = {
|
||||
|
|
@ -1089,6 +1181,7 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
|||
}
|
||||
)
|
||||
],
|
||||
db=db,
|
||||
)
|
||||
|
||||
if chats:
|
||||
|
|
@ -1111,23 +1204,28 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.post("/{id}/archive", response_model=Optional[ChatResponse])
|
||||
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
async def archive_chat_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
chat = Chats.toggle_chat_archive_by_id(id)
|
||||
chat = Chats.toggle_chat_archive_by_id(id, db=db)
|
||||
|
||||
# Delete tags if chat is archived
|
||||
if chat.archived:
|
||||
for tag_id in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
|
||||
if (
|
||||
Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id, db=db)
|
||||
== 0
|
||||
):
|
||||
log.debug(f"deleting tag: {tag_id}")
|
||||
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
|
||||
Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db)
|
||||
else:
|
||||
for tag_id in chat.meta.get("tags", []):
|
||||
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
|
||||
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db)
|
||||
if tag is None:
|
||||
log.debug(f"inserting tag: {tag_id}")
|
||||
tag = Tags.insert_new_tag(tag_id, user.id)
|
||||
tag = Tags.insert_new_tag(tag_id, user.id, db=db)
|
||||
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
|
|
@ -1142,7 +1240,12 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.post("/{id}/share", response_model=Optional[ChatResponse])
|
||||
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
||||
async def share_chat_by_id(
|
||||
request: Request,
|
||||
id: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if (user.role != "admin") and (
|
||||
not has_permission(
|
||||
user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
|
||||
|
|
@ -1153,14 +1256,14 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
|
||||
if chat:
|
||||
if chat.share_id:
|
||||
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
|
||||
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id, db=db)
|
||||
return ChatResponse(**shared_chat.model_dump())
|
||||
|
||||
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
|
||||
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id, db=db)
|
||||
if not shared_chat:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -1181,14 +1284,16 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_
|
|||
|
||||
|
||||
@router.delete("/{id}/share", response_model=Optional[bool])
|
||||
async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
async def delete_shared_chat_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
if not chat.share_id:
|
||||
return False
|
||||
|
||||
result = Chats.delete_shared_chat_by_chat_id(id)
|
||||
update_result = Chats.update_chat_share_id_by_id(id, None)
|
||||
result = Chats.delete_shared_chat_by_chat_id(id, db=db)
|
||||
update_result = Chats.update_chat_share_id_by_id(id, None, db=db)
|
||||
|
||||
return result and update_result != None
|
||||
else:
|
||||
|
|
@ -1209,12 +1314,15 @@ class ChatFolderIdForm(BaseModel):
|
|||
|
||||
@router.post("/{id}/folder", response_model=Optional[ChatResponse])
|
||||
async def update_chat_folder_id_by_id(
|
||||
id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
|
||||
id: str,
|
||||
form_data: ChatFolderIdForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
chat = Chats.update_chat_folder_id_by_id_and_user_id(
|
||||
id, user.id, form_data.folder_id
|
||||
id, user.id, form_data.folder_id, db=db
|
||||
)
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
|
|
@ -1229,11 +1337,13 @@ async def update_chat_folder_id_by_id(
|
|||
|
||||
|
||||
@router.get("/{id}/tags", response_model=list[TagModel])
|
||||
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
async def get_chat_tags_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -1247,9 +1357,12 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
@router.post("/{id}/tags", response_model=list[TagModel])
|
||||
async def add_tag_by_id_and_tag_name(
|
||||
id: str, form_data: TagForm, user=Depends(get_verified_user)
|
||||
id: str,
|
||||
form_data: TagForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
tags = chat.meta.get("tags", [])
|
||||
tag_id = form_data.name.replace(" ", "_").lower()
|
||||
|
|
@ -1262,12 +1375,12 @@ async def add_tag_by_id_and_tag_name(
|
|||
|
||||
if tag_id not in tags:
|
||||
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
id, user.id, form_data.name
|
||||
id, user.id, form_data.name, db=db
|
||||
)
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
|
|
@ -1281,18 +1394,26 @@ async def add_tag_by_id_and_tag_name(
|
|||
|
||||
@router.delete("/{id}/tags", response_model=list[TagModel])
|
||||
async def delete_tag_by_id_and_tag_name(
|
||||
id: str, form_data: TagForm, user=Depends(get_verified_user)
|
||||
id: str,
|
||||
form_data: TagForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
|
||||
Chats.delete_tag_by_id_and_user_id_and_tag_name(
|
||||
id, user.id, form_data.name, db=db
|
||||
)
|
||||
|
||||
if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
|
||||
if (
|
||||
Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id, db=db)
|
||||
== 0
|
||||
):
|
||||
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db)
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -1305,14 +1426,16 @@ async def delete_tag_by_id_and_tag_name(
|
|||
|
||||
|
||||
@router.delete("/{id}/tags/all", response_model=Optional[bool])
|
||||
async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
async def delete_all_tags_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
Chats.delete_all_tags_by_id_and_user_id(id, user.id)
|
||||
Chats.delete_all_tags_by_id_and_user_id(id, user.id, db=db)
|
||||
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
|
||||
|
||||
return True
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ from open_webui.models.feedbacks import (
|
|||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -60,38 +62,50 @@ async def update_config(
|
|||
|
||||
|
||||
@router.get("/feedbacks/all", response_model=list[FeedbackResponse])
|
||||
async def get_all_feedbacks(user=Depends(get_admin_user)):
|
||||
feedbacks = Feedbacks.get_all_feedbacks()
|
||||
async def get_all_feedbacks(
|
||||
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
feedbacks = Feedbacks.get_all_feedbacks(db=db)
|
||||
return feedbacks
|
||||
|
||||
|
||||
@router.get("/feedbacks/all/ids", response_model=list[FeedbackIdResponse])
|
||||
async def get_all_feedback_ids(user=Depends(get_admin_user)):
|
||||
feedbacks = Feedbacks.get_all_feedbacks()
|
||||
async def get_all_feedback_ids(
|
||||
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
feedbacks = Feedbacks.get_all_feedbacks(db=db)
|
||||
return feedbacks
|
||||
|
||||
|
||||
@router.delete("/feedbacks/all")
|
||||
async def delete_all_feedbacks(user=Depends(get_admin_user)):
|
||||
success = Feedbacks.delete_all_feedbacks()
|
||||
async def delete_all_feedbacks(
|
||||
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
success = Feedbacks.delete_all_feedbacks(db=db)
|
||||
return success
|
||||
|
||||
|
||||
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
|
||||
async def export_all_feedbacks(user=Depends(get_admin_user)):
|
||||
feedbacks = Feedbacks.get_all_feedbacks()
|
||||
async def export_all_feedbacks(
|
||||
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
feedbacks = Feedbacks.get_all_feedbacks(db=db)
|
||||
return feedbacks
|
||||
|
||||
|
||||
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
|
||||
async def get_feedbacks(user=Depends(get_verified_user)):
|
||||
feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id)
|
||||
async def get_feedbacks(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id, db=db)
|
||||
return feedbacks
|
||||
|
||||
|
||||
@router.delete("/feedbacks", response_model=bool)
|
||||
async def delete_feedbacks(user=Depends(get_verified_user)):
|
||||
success = Feedbacks.delete_feedbacks_by_user_id(user.id)
|
||||
async def delete_feedbacks(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
success = Feedbacks.delete_feedbacks_by_user_id(user.id, db=db)
|
||||
return success
|
||||
|
||||
|
||||
|
|
@ -104,6 +118,7 @@ async def get_feedbacks(
|
|||
direction: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_admin_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
limit = PAGE_ITEM_COUNT
|
||||
|
||||
|
|
@ -116,7 +131,7 @@ async def get_feedbacks(
|
|||
if direction:
|
||||
filter["direction"] = direction
|
||||
|
||||
result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit)
|
||||
result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit, db=db)
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -125,8 +140,11 @@ async def create_feedback(
|
|||
request: Request,
|
||||
form_data: FeedbackForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data)
|
||||
feedback = Feedbacks.insert_new_feedback(
|
||||
user_id=user.id, form_data=form_data, db=db
|
||||
)
|
||||
if not feedback:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -137,11 +155,15 @@ async def create_feedback(
|
|||
|
||||
|
||||
@router.get("/feedback/{id}", response_model=FeedbackModel)
|
||||
async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
|
||||
async def get_feedback_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
if user.role == "admin":
|
||||
feedback = Feedbacks.get_feedback_by_id(id=id)
|
||||
feedback = Feedbacks.get_feedback_by_id(id=id, db=db)
|
||||
else:
|
||||
feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id)
|
||||
feedback = Feedbacks.get_feedback_by_id_and_user_id(
|
||||
id=id, user_id=user.id, db=db
|
||||
)
|
||||
|
||||
if not feedback:
|
||||
raise HTTPException(
|
||||
|
|
@ -153,13 +175,16 @@ async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
@router.post("/feedback/{id}", response_model=FeedbackModel)
|
||||
async def update_feedback_by_id(
|
||||
id: str, form_data: FeedbackForm, user=Depends(get_verified_user)
|
||||
id: str,
|
||||
form_data: FeedbackForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role == "admin":
|
||||
feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data)
|
||||
feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data, db=db)
|
||||
else:
|
||||
feedback = Feedbacks.update_feedback_by_id_and_user_id(
|
||||
id=id, user_id=user.id, form_data=form_data
|
||||
id=id, user_id=user.id, form_data=form_data, db=db
|
||||
)
|
||||
|
||||
if not feedback:
|
||||
|
|
@ -171,11 +196,15 @@ async def update_feedback_by_id(
|
|||
|
||||
|
||||
@router.delete("/feedback/{id}")
|
||||
async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)):
|
||||
async def delete_feedback_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
if user.role == "admin":
|
||||
success = Feedbacks.delete_feedback_by_id(id=id)
|
||||
success = Feedbacks.delete_feedback_by_id(id=id, db=db)
|
||||
else:
|
||||
success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id)
|
||||
success = Feedbacks.delete_feedback_by_id_and_user_id(
|
||||
id=id, user_id=user.id, db=db
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ from fastapi import (
|
|||
)
|
||||
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import get_session, SessionLocal
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||
|
|
@ -62,9 +64,12 @@ router = APIRouter()
|
|||
|
||||
# TODO: Optimize this function to use the knowledge_file table for faster lookups.
|
||||
def has_access_to_file(
|
||||
file_id: Optional[str], access_type: str, user=Depends(get_verified_user)
|
||||
file_id: Optional[str],
|
||||
access_type: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Optional[Session] = None,
|
||||
) -> bool:
|
||||
file = Files.get_file_by_id(file_id)
|
||||
file = Files.get_file_by_id(file_id, db=db)
|
||||
log.debug(f"Checking if user has {access_type} access to file")
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
|
|
@ -73,31 +78,33 @@ def has_access_to_file(
|
|||
)
|
||||
|
||||
# Check if the file is associated with any knowledge bases the user has access to
|
||||
knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id)
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
|
||||
knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id, db=db)
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
|
||||
}
|
||||
for knowledge_base in knowledge_bases:
|
||||
if knowledge_base.user_id == user.id or has_access(
|
||||
user.id, access_type, knowledge_base.access_control, user_group_ids
|
||||
user.id, access_type, knowledge_base.access_control, user_group_ids, db=db
|
||||
):
|
||||
return True
|
||||
|
||||
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
|
||||
if knowledge_base_id:
|
||||
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
|
||||
user.id, access_type
|
||||
user.id, access_type, db=db
|
||||
)
|
||||
for knowledge_base in knowledge_bases:
|
||||
if knowledge_base.id == knowledge_base_id:
|
||||
return True
|
||||
|
||||
# Check if the file is associated with any channels the user has access to
|
||||
channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id)
|
||||
channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id, db=db)
|
||||
if access_type == "read" and channels:
|
||||
return True
|
||||
|
||||
# Check if the file is associated with any chats the user has access to
|
||||
# TODO: Granular access control for chats
|
||||
chats = Chats.get_shared_chats_by_file_id(file_id)
|
||||
chats = Chats.get_shared_chats_by_file_id(file_id, db=db)
|
||||
if chats:
|
||||
return True
|
||||
|
||||
|
|
@ -109,47 +116,78 @@ def has_access_to_file(
|
|||
############################
|
||||
|
||||
|
||||
def process_uploaded_file(request, file, file_path, file_item, file_metadata, user):
|
||||
try:
|
||||
if file.content_type:
|
||||
stt_supported_content_types = getattr(
|
||||
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
|
||||
)
|
||||
def process_uploaded_file(
|
||||
request,
|
||||
file,
|
||||
file_path,
|
||||
file_item,
|
||||
file_metadata,
|
||||
user,
|
||||
db: Optional[Session] = None,
|
||||
):
|
||||
def _process_handler(db_session):
|
||||
try:
|
||||
if file.content_type:
|
||||
stt_supported_content_types = getattr(
|
||||
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
|
||||
)
|
||||
|
||||
if strict_match_mime_type(stt_supported_content_types, file.content_type):
|
||||
file_path = Storage.get_file(file_path)
|
||||
result = transcribe(request, file_path, file_metadata, user)
|
||||
if strict_match_mime_type(
|
||||
stt_supported_content_types, file.content_type
|
||||
):
|
||||
file_path_processed = Storage.get_file(file_path)
|
||||
result = transcribe(
|
||||
request, file_path_processed, file_metadata, user
|
||||
)
|
||||
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(
|
||||
file_id=file_item.id, content=result.get("text", "")
|
||||
),
|
||||
user=user,
|
||||
db=db_session,
|
||||
)
|
||||
elif (not file.content_type.startswith(("image/", "video/"))) or (
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
|
||||
):
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(file_id=file_item.id),
|
||||
user=user,
|
||||
db=db_session,
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"File type {file.content_type} is not supported for processing"
|
||||
)
|
||||
else:
|
||||
log.info(
|
||||
f"File type {file.content_type} is not provided, but trying to process anyway"
|
||||
)
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(
|
||||
file_id=file_item.id, content=result.get("text", "")
|
||||
),
|
||||
ProcessFileForm(file_id=file_item.id),
|
||||
user=user,
|
||||
db=db_session,
|
||||
)
|
||||
elif (not file.content_type.startswith(("image/", "video/"))) or (
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
|
||||
):
|
||||
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
|
||||
else:
|
||||
raise Exception(
|
||||
f"File type {file.content_type} is not supported for processing"
|
||||
)
|
||||
else:
|
||||
log.info(
|
||||
f"File type {file.content_type} is not provided, but trying to process anyway"
|
||||
)
|
||||
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error processing file: {file_item.id}")
|
||||
Files.update_file_data_by_id(
|
||||
file_item.id,
|
||||
{
|
||||
"status": "failed",
|
||||
"error": str(e.detail) if hasattr(e, "detail") else str(e),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error processing file: {file_item.id}")
|
||||
Files.update_file_data_by_id(
|
||||
file_item.id,
|
||||
{
|
||||
"status": "failed",
|
||||
"error": str(e.detail) if hasattr(e, "detail") else str(e),
|
||||
},
|
||||
db=db_session,
|
||||
)
|
||||
|
||||
if db:
|
||||
_process_handler(db)
|
||||
else:
|
||||
with SessionLocal() as db_session:
|
||||
_process_handler(db_session)
|
||||
|
||||
|
||||
@router.post("/", response_model=FileModelResponse)
|
||||
|
|
@ -161,6 +199,7 @@ def upload_file(
|
|||
process: bool = Query(True),
|
||||
process_in_background: bool = Query(True),
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
return upload_file_handler(
|
||||
request,
|
||||
|
|
@ -170,6 +209,7 @@ def upload_file(
|
|||
process_in_background=process_in_background,
|
||||
user=user,
|
||||
background_tasks=background_tasks,
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -181,6 +221,7 @@ def upload_file_handler(
|
|||
process_in_background: bool = Query(True),
|
||||
user=Depends(get_verified_user),
|
||||
background_tasks: Optional[BackgroundTasks] = None,
|
||||
db: Optional[Session] = None,
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type} {process}")
|
||||
|
||||
|
|
@ -248,14 +289,17 @@ def upload_file_handler(
|
|||
},
|
||||
}
|
||||
),
|
||||
db=db,
|
||||
)
|
||||
|
||||
if "channel_id" in file_metadata:
|
||||
channel = Channels.get_channel_by_id_and_user_id(
|
||||
file_metadata["channel_id"], user.id
|
||||
file_metadata["channel_id"], user.id, db=db
|
||||
)
|
||||
if channel:
|
||||
Channels.add_file_to_channel_by_id(channel.id, file_item.id, user.id)
|
||||
Channels.add_file_to_channel_by_id(
|
||||
channel.id, file_item.id, user.id, db=db
|
||||
)
|
||||
|
||||
if process:
|
||||
if background_tasks and process_in_background:
|
||||
|
|
@ -277,6 +321,7 @@ def upload_file_handler(
|
|||
file_item,
|
||||
file_metadata,
|
||||
user,
|
||||
db=db,
|
||||
)
|
||||
return {"status": True, **file_item.model_dump()}
|
||||
else:
|
||||
|
|
@ -302,11 +347,15 @@ def upload_file_handler(
|
|||
|
||||
|
||||
@router.get("/", response_model=list[FileModelResponse])
|
||||
async def list_files(user=Depends(get_verified_user), content: bool = Query(True)):
|
||||
async def list_files(
|
||||
user=Depends(get_verified_user),
|
||||
content: bool = Query(True),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role == "admin":
|
||||
files = Files.get_files()
|
||||
files = Files.get_files(db=db)
|
||||
else:
|
||||
files = Files.get_files_by_user_id(user.id)
|
||||
files = Files.get_files_by_user_id(user.id, db=db)
|
||||
|
||||
if not content:
|
||||
for file in files:
|
||||
|
|
@ -329,15 +378,16 @@ async def search_files(
|
|||
),
|
||||
content: bool = Query(True),
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""
|
||||
Search for files by filename with support for wildcard patterns.
|
||||
"""
|
||||
# Get files according to user role
|
||||
if user.role == "admin":
|
||||
files = Files.get_files()
|
||||
files = Files.get_files(db=db)
|
||||
else:
|
||||
files = Files.get_files_by_user_id(user.id)
|
||||
files = Files.get_files_by_user_id(user.id, db=db)
|
||||
|
||||
# Get matching files
|
||||
matching_files = [
|
||||
|
|
@ -364,8 +414,10 @@ async def search_files(
|
|||
|
||||
|
||||
@router.delete("/all")
|
||||
async def delete_all_files(user=Depends(get_admin_user)):
|
||||
result = Files.delete_all_files()
|
||||
async def delete_all_files(
|
||||
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
result = Files.delete_all_files(db=db)
|
||||
if result:
|
||||
try:
|
||||
Storage.delete_all_files()
|
||||
|
|
@ -391,8 +443,10 @@ async def delete_all_files(user=Depends(get_admin_user)):
|
|||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[FileModel])
|
||||
async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
async def get_file_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
file = Files.get_file_by_id(id, db=db)
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
|
|
@ -403,7 +457,7 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
|||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
or has_access_to_file(id, "read", user, db=db)
|
||||
):
|
||||
return file
|
||||
else:
|
||||
|
|
@ -415,9 +469,12 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
@router.get("/{id}/process/status")
|
||||
async def get_file_process_status(
|
||||
id: str, stream: bool = Query(False), user=Depends(get_verified_user)
|
||||
id: str,
|
||||
stream: bool = Query(False),
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
file = Files.get_file_by_id(id)
|
||||
file = Files.get_file_by_id(id, db=db)
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
|
|
@ -428,7 +485,7 @@ async def get_file_process_status(
|
|||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
or has_access_to_file(id, "read", user, db=db)
|
||||
):
|
||||
if stream:
|
||||
MAX_FILE_PROCESSING_DURATION = 3600 * 2
|
||||
|
|
@ -436,7 +493,7 @@ async def get_file_process_status(
|
|||
async def event_stream(file_item):
|
||||
if file_item:
|
||||
for _ in range(MAX_FILE_PROCESSING_DURATION):
|
||||
file_item = Files.get_file_by_id(file_item.id)
|
||||
file_item = Files.get_file_by_id(file_item.id, db=db)
|
||||
if file_item:
|
||||
data = file_item.model_dump().get("data", {})
|
||||
status = data.get("status")
|
||||
|
|
@ -476,8 +533,10 @@ async def get_file_process_status(
|
|||
|
||||
|
||||
@router.get("/{id}/data/content")
|
||||
async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
async def get_file_data_content_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
file = Files.get_file_by_id(id, db=db)
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
|
|
@ -488,7 +547,7 @@ async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
or has_access_to_file(id, "read", user, db=db)
|
||||
):
|
||||
return {"content": file.data.get("content", "")}
|
||||
else:
|
||||
|
|
@ -509,9 +568,13 @@ class ContentForm(BaseModel):
|
|||
|
||||
@router.post("/{id}/data/content/update")
|
||||
async def update_file_data_content_by_id(
|
||||
request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user)
|
||||
request: Request,
|
||||
id: str,
|
||||
form_data: ContentForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
file = Files.get_file_by_id(id)
|
||||
file = Files.get_file_by_id(id, db=db)
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
|
|
@ -522,7 +585,7 @@ async def update_file_data_content_by_id(
|
|||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "write", user)
|
||||
or has_access_to_file(id, "write", user, db=db)
|
||||
):
|
||||
try:
|
||||
process_file(
|
||||
|
|
@ -530,7 +593,7 @@ async def update_file_data_content_by_id(
|
|||
ProcessFileForm(file_id=id, content=form_data.content),
|
||||
user=user,
|
||||
)
|
||||
file = Files.get_file_by_id(id=id)
|
||||
file = Files.get_file_by_id(id=id, db=db)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error processing file: {file.id}")
|
||||
|
|
@ -550,9 +613,12 @@ async def update_file_data_content_by_id(
|
|||
|
||||
@router.get("/{id}/content")
|
||||
async def get_file_content_by_id(
|
||||
id: str, user=Depends(get_verified_user), attachment: bool = Query(False)
|
||||
id: str,
|
||||
user=Depends(get_verified_user),
|
||||
attachment: bool = Query(False),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
file = Files.get_file_by_id(id)
|
||||
file = Files.get_file_by_id(id, db=db)
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
|
|
@ -563,7 +629,7 @@ async def get_file_content_by_id(
|
|||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
or has_access_to_file(id, "read", user, db=db)
|
||||
):
|
||||
try:
|
||||
file_path = Storage.get_file(file.path)
|
||||
|
|
@ -619,8 +685,10 @@ async def get_file_content_by_id(
|
|||
|
||||
|
||||
@router.get("/{id}/content/html")
|
||||
async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
async def get_html_file_content_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
file = Files.get_file_by_id(id, db=db)
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
|
|
@ -628,7 +696,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
file_user = Users.get_user_by_id(file.user_id)
|
||||
file_user = Users.get_user_by_id(file.user_id, db=db)
|
||||
if not file_user.role == "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -638,7 +706,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
or has_access_to_file(id, "read", user, db=db)
|
||||
):
|
||||
try:
|
||||
file_path = Storage.get_file(file.path)
|
||||
|
|
@ -668,8 +736,10 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/{id}/content/{file_name}")
|
||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
async def get_file_content_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
file = Files.get_file_by_id(id, db=db)
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
|
|
@ -680,7 +750,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "read", user)
|
||||
or has_access_to_file(id, "read", user, db=db)
|
||||
):
|
||||
file_path = file.path
|
||||
|
||||
|
|
@ -730,8 +800,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.delete("/{id}")
|
||||
async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
async def delete_file_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
file = Files.get_file_by_id(id, db=db)
|
||||
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
|
|
@ -742,10 +814,10 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
|||
if (
|
||||
file.user_id == user.id
|
||||
or user.role == "admin"
|
||||
or has_access_to_file(id, "write", user)
|
||||
or has_access_to_file(id, "write", user, db=db)
|
||||
):
|
||||
|
||||
result = Files.delete_file_by_id(id)
|
||||
result = Files.delete_file_by_id(id, db=db)
|
||||
if result:
|
||||
try:
|
||||
Storage.delete_file(file.path)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ from open_webui.models.knowledge import Knowledges
|
|||
|
||||
from open_webui.config import UPLOAD_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
|
||||
|
|
@ -44,7 +46,11 @@ router = APIRouter()
|
|||
|
||||
|
||||
@router.get("/", response_model=list[FolderNameIdResponse])
|
||||
async def get_folders(request: Request, user=Depends(get_verified_user)):
|
||||
async def get_folders(
|
||||
request: Request,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if request.app.state.config.ENABLE_FOLDERS is False:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
@ -55,22 +61,23 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
|
|||
user.id,
|
||||
"features.folders",
|
||||
request.app.state.config.USER_PERMISSIONS,
|
||||
db=db,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
folders = Folders.get_folders_by_user_id(user.id)
|
||||
folders = Folders.get_folders_by_user_id(user.id, db=db)
|
||||
|
||||
# Verify folder data integrity
|
||||
folder_list = []
|
||||
for folder in folders:
|
||||
if folder.parent_id and not Folders.get_folder_by_id_and_user_id(
|
||||
folder.parent_id, user.id
|
||||
folder.parent_id, user.id, db=db
|
||||
):
|
||||
folder = Folders.update_folder_parent_id_by_id_and_user_id(
|
||||
folder.id, user.id, None
|
||||
folder.id, user.id, None, db=db
|
||||
)
|
||||
|
||||
if folder.data:
|
||||
|
|
@ -80,12 +87,12 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
|
|||
|
||||
if file.get("type") == "file":
|
||||
if Files.check_access_by_user_id(
|
||||
file.get("id"), user.id, "read"
|
||||
file.get("id"), user.id, "read", db=db
|
||||
):
|
||||
valid_files.append(file)
|
||||
elif file.get("type") == "collection":
|
||||
if Knowledges.check_access_by_user_id(
|
||||
file.get("id"), user.id, "read"
|
||||
file.get("id"), user.id, "read", db=db
|
||||
):
|
||||
valid_files.append(file)
|
||||
else:
|
||||
|
|
@ -93,7 +100,7 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
|
|||
|
||||
folder.data["files"] = valid_files
|
||||
Folders.update_folder_by_id_and_user_id(
|
||||
folder.id, user.id, FolderUpdateForm(data=folder.data)
|
||||
folder.id, user.id, FolderUpdateForm(data=folder.data), db=db
|
||||
)
|
||||
|
||||
folder_list.append(FolderNameIdResponse(**folder.model_dump()))
|
||||
|
|
@ -107,9 +114,13 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.post("/")
|
||||
def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
|
||||
def create_folder(
|
||||
form_data: FolderForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||
None, user.id, form_data.name
|
||||
None, user.id, form_data.name, db=db
|
||||
)
|
||||
|
||||
if folder:
|
||||
|
|
@ -119,7 +130,7 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
|
|||
)
|
||||
|
||||
try:
|
||||
folder = Folders.insert_new_folder(user.id, form_data)
|
||||
folder = Folders.insert_new_folder(user.id, form_data, db=db)
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -136,8 +147,10 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[FolderModel])
|
||||
async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
async def get_folder_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
|
||||
if folder:
|
||||
return folder
|
||||
else:
|
||||
|
|
@ -154,15 +167,18 @@ async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
@router.post("/{id}/update")
|
||||
async def update_folder_name_by_id(
|
||||
id: str, form_data: FolderUpdateForm, user=Depends(get_verified_user)
|
||||
id: str,
|
||||
form_data: FolderUpdateForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
|
||||
if folder:
|
||||
|
||||
if form_data.name is not None:
|
||||
# Check if folder with same name exists
|
||||
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||
folder.parent_id, user.id, form_data.name
|
||||
folder.parent_id, user.id, form_data.name, db=db
|
||||
)
|
||||
if existing_folder and existing_folder.id != id:
|
||||
raise HTTPException(
|
||||
|
|
@ -171,7 +187,9 @@ async def update_folder_name_by_id(
|
|||
)
|
||||
|
||||
try:
|
||||
folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data)
|
||||
folder = Folders.update_folder_by_id_and_user_id(
|
||||
id, user.id, form_data, db=db
|
||||
)
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -198,12 +216,15 @@ class FolderParentIdForm(BaseModel):
|
|||
|
||||
@router.post("/{id}/update/parent")
|
||||
async def update_folder_parent_id_by_id(
|
||||
id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user)
|
||||
id: str,
|
||||
form_data: FolderParentIdForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
|
||||
if folder:
|
||||
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||
form_data.parent_id, user.id, folder.name
|
||||
form_data.parent_id, user.id, folder.name, db=db
|
||||
)
|
||||
|
||||
if existing_folder:
|
||||
|
|
@ -214,7 +235,7 @@ async def update_folder_parent_id_by_id(
|
|||
|
||||
try:
|
||||
folder = Folders.update_folder_parent_id_by_id_and_user_id(
|
||||
id, user.id, form_data.parent_id
|
||||
id, user.id, form_data.parent_id, db=db
|
||||
)
|
||||
return folder
|
||||
except Exception as e:
|
||||
|
|
@ -242,13 +263,16 @@ class FolderIsExpandedForm(BaseModel):
|
|||
|
||||
@router.post("/{id}/update/expanded")
|
||||
async def update_folder_is_expanded_by_id(
|
||||
id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user)
|
||||
id: str,
|
||||
form_data: FolderIsExpandedForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
|
||||
if folder:
|
||||
try:
|
||||
folder = Folders.update_folder_is_expanded_by_id_and_user_id(
|
||||
id, user.id, form_data.is_expanded
|
||||
id, user.id, form_data.is_expanded, db=db
|
||||
)
|
||||
return folder
|
||||
except Exception as e:
|
||||
|
|
@ -276,10 +300,11 @@ async def delete_folder_by_id(
|
|||
id: str,
|
||||
delete_contents: Optional[bool] = True,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if Chats.count_chats_by_folder_id_and_user_id(id, user.id):
|
||||
if Chats.count_chats_by_folder_id_and_user_id(id, user.id, db=db):
|
||||
chat_delete_permission = has_permission(
|
||||
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
)
|
||||
if user.role != "admin" and not chat_delete_permission:
|
||||
raise HTTPException(
|
||||
|
|
@ -288,19 +313,21 @@ async def delete_folder_by_id(
|
|||
)
|
||||
|
||||
folders = []
|
||||
folders.append(Folders.get_folder_by_id_and_user_id(id, user.id))
|
||||
folders.append(Folders.get_folder_by_id_and_user_id(id, user.id, db=db))
|
||||
while folders:
|
||||
folder = folders.pop()
|
||||
if folder:
|
||||
try:
|
||||
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id)
|
||||
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id, db=db)
|
||||
|
||||
for folder_id in folder_ids:
|
||||
if delete_contents:
|
||||
Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id)
|
||||
Chats.delete_chats_by_user_id_and_folder_id(
|
||||
user.id, folder_id, db=db
|
||||
)
|
||||
else:
|
||||
Chats.move_chats_by_user_id_and_folder_id(
|
||||
user.id, folder_id, None
|
||||
user.id, folder_id, None, db=db
|
||||
)
|
||||
|
||||
return True
|
||||
|
|
@ -314,7 +341,7 @@ async def delete_folder_by_id(
|
|||
finally:
|
||||
# Get all subfolders
|
||||
subfolders = Folders.get_folders_by_parent_id_and_user_id(
|
||||
folder.id, user.id
|
||||
folder.id, user.id, db=db
|
||||
)
|
||||
folders.extend(subfolders)
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ from open_webui.constants import ERROR_MESSAGES
|
|||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -37,13 +39,13 @@ router = APIRouter()
|
|||
|
||||
|
||||
@router.get("/", response_model=list[FunctionResponse])
|
||||
async def get_functions(user=Depends(get_verified_user)):
|
||||
return Functions.get_functions()
|
||||
async def get_functions(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
return Functions.get_functions(db=db)
|
||||
|
||||
|
||||
@router.get("/list", response_model=list[FunctionUserResponse])
|
||||
async def get_function_list(user=Depends(get_admin_user)):
|
||||
return Functions.get_function_list()
|
||||
async def get_function_list(user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||
return Functions.get_function_list(db=db)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -52,8 +54,8 @@ async def get_function_list(user=Depends(get_admin_user)):
|
|||
|
||||
|
||||
@router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel])
|
||||
async def get_functions(include_valves: bool = False, user=Depends(get_admin_user)):
|
||||
return Functions.get_functions(include_valves=include_valves)
|
||||
async def get_functions(include_valves: bool = False, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||
return Functions.get_functions(include_valves=include_valves, db=db)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -142,7 +144,7 @@ class SyncFunctionsForm(BaseModel):
|
|||
|
||||
@router.post("/sync", response_model=list[FunctionWithValvesModel])
|
||||
async def sync_functions(
|
||||
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
|
||||
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
try:
|
||||
for function in form_data.functions:
|
||||
|
|
@ -164,7 +166,7 @@ async def sync_functions(
|
|||
)
|
||||
raise e
|
||||
|
||||
return Functions.sync_functions(user.id, form_data.functions)
|
||||
return Functions.sync_functions(user.id, form_data.functions, db=db)
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to load a function: {e}")
|
||||
raise HTTPException(
|
||||
|
|
@ -180,7 +182,7 @@ async def sync_functions(
|
|||
|
||||
@router.post("/create", response_model=Optional[FunctionResponse])
|
||||
async def create_new_function(
|
||||
request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
|
||||
request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
if not form_data.id.isidentifier():
|
||||
raise HTTPException(
|
||||
|
|
@ -190,7 +192,7 @@ async def create_new_function(
|
|||
|
||||
form_data.id = form_data.id.lower()
|
||||
|
||||
function = Functions.get_function_by_id(form_data.id)
|
||||
function = Functions.get_function_by_id(form_data.id, db=db)
|
||||
if function is None:
|
||||
try:
|
||||
form_data.content = replace_imports(form_data.content)
|
||||
|
|
@ -203,13 +205,13 @@ async def create_new_function(
|
|||
FUNCTIONS = request.app.state.FUNCTIONS
|
||||
FUNCTIONS[form_data.id] = function_module
|
||||
|
||||
function = Functions.insert_new_function(user.id, function_type, form_data)
|
||||
function = Functions.insert_new_function(user.id, function_type, form_data, db=db)
|
||||
|
||||
function_cache_dir = CACHE_DIR / "functions" / form_data.id
|
||||
function_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if function_type == "filter" and getattr(function_module, "toggle", None):
|
||||
Functions.update_function_metadata_by_id(id, {"toggle": True})
|
||||
Functions.update_function_metadata_by_id(form_data.id, {"toggle": True}, db=db)
|
||||
|
||||
if function:
|
||||
return function
|
||||
|
|
@ -237,8 +239,8 @@ async def create_new_function(
|
|||
|
||||
|
||||
@router.get("/id/{id}", response_model=Optional[FunctionModel])
|
||||
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
async def get_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||
function = Functions.get_function_by_id(id, db=db)
|
||||
|
||||
if function:
|
||||
return function
|
||||
|
|
@ -255,11 +257,11 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
|
|||
|
||||
|
||||
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
|
||||
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
async def toggle_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||
function = Functions.get_function_by_id(id, db=db)
|
||||
if function:
|
||||
function = Functions.update_function_by_id(
|
||||
id, {"is_active": not function.is_active}
|
||||
id, {"is_active": not function.is_active}, db=db
|
||||
)
|
||||
|
||||
if function:
|
||||
|
|
@ -282,11 +284,11 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
|||
|
||||
|
||||
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
|
||||
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
async def toggle_global_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||
function = Functions.get_function_by_id(id, db=db)
|
||||
if function:
|
||||
function = Functions.update_function_by_id(
|
||||
id, {"is_global": not function.is_global}
|
||||
id, {"is_global": not function.is_global}, db=db
|
||||
)
|
||||
|
||||
if function:
|
||||
|
|
@ -310,7 +312,7 @@ async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
|
|||
|
||||
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
|
||||
async def update_function_by_id(
|
||||
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
|
||||
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
try:
|
||||
form_data.content = replace_imports(form_data.content)
|
||||
|
|
@ -325,10 +327,10 @@ async def update_function_by_id(
|
|||
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
|
||||
log.debug(updated)
|
||||
|
||||
function = Functions.update_function_by_id(id, updated)
|
||||
function = Functions.update_function_by_id(id, updated, db=db)
|
||||
|
||||
if function_type == "filter" and getattr(function_module, "toggle", None):
|
||||
Functions.update_function_metadata_by_id(id, {"toggle": True})
|
||||
Functions.update_function_metadata_by_id(id, {"toggle": True}, db=db)
|
||||
|
||||
if function:
|
||||
return function
|
||||
|
|
@ -352,9 +354,9 @@ async def update_function_by_id(
|
|||
|
||||
@router.delete("/id/{id}/delete", response_model=bool)
|
||||
async def delete_function_by_id(
|
||||
request: Request, id: str, user=Depends(get_admin_user)
|
||||
request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
result = Functions.delete_function_by_id(id)
|
||||
result = Functions.delete_function_by_id(id, db=db)
|
||||
|
||||
if result:
|
||||
FUNCTIONS = request.app.state.FUNCTIONS
|
||||
|
|
@ -370,11 +372,11 @@ async def delete_function_by_id(
|
|||
|
||||
|
||||
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
||||
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||
function = Functions.get_function_by_id(id, db=db)
|
||||
if function:
|
||||
try:
|
||||
valves = Functions.get_function_valves_by_id(id)
|
||||
valves = Functions.get_function_valves_by_id(id, db=db)
|
||||
return valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -395,9 +397,9 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
|
|||
|
||||
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
|
||||
async def get_function_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_admin_user)
|
||||
request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = Functions.get_function_by_id(id, db=db)
|
||||
if function:
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
|
|
@ -421,9 +423,9 @@ async def get_function_valves_spec_by_id(
|
|||
|
||||
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
|
||||
async def update_function_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = Functions.get_function_by_id(id, db=db)
|
||||
if function:
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
|
|
@ -437,7 +439,7 @@ async def update_function_valves_by_id(
|
|||
valves = Valves(**form_data)
|
||||
|
||||
valves_dict = valves.model_dump(exclude_unset=True)
|
||||
Functions.update_function_valves_by_id(id, valves_dict)
|
||||
Functions.update_function_valves_by_id(id, valves_dict, db=db)
|
||||
return valves_dict
|
||||
except Exception as e:
|
||||
log.exception(f"Error updating function values by id {id}: {e}")
|
||||
|
|
@ -464,11 +466,11 @@ async def update_function_valves_by_id(
|
|||
|
||||
|
||||
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
||||
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
function = Functions.get_function_by_id(id, db=db)
|
||||
if function:
|
||||
try:
|
||||
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
|
||||
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id, db=db)
|
||||
return user_valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -484,9 +486,9 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user
|
|||
|
||||
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
|
||||
async def get_function_user_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = Functions.get_function_by_id(id, db=db)
|
||||
if function:
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
|
|
@ -505,9 +507,9 @@ async def get_function_user_valves_spec_by_id(
|
|||
|
||||
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
|
||||
async def update_function_user_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = Functions.get_function_by_id(id, db=db)
|
||||
|
||||
if function:
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
|
|
@ -522,7 +524,7 @@ async def update_function_user_valves_by_id(
|
|||
user_valves = UserValves(**form_data)
|
||||
user_valves_dict = user_valves.model_dump(exclude_unset=True)
|
||||
Functions.update_user_valves_by_id_and_user_id(
|
||||
id, user.id, user_valves_dict
|
||||
id, user.id, user_valves_dict, db=db
|
||||
)
|
||||
return user_valves_dict
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ from open_webui.config import CACHE_DIR
|
|||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
|
||||
|
||||
|
|
@ -29,7 +32,11 @@ router = APIRouter()
|
|||
|
||||
|
||||
@router.get("/", response_model=list[GroupResponse])
|
||||
async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)):
|
||||
async def get_groups(
|
||||
share: Optional[bool] = None,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
|
||||
filter = {}
|
||||
if user.role != "admin":
|
||||
|
|
@ -38,7 +45,7 @@ async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_use
|
|||
if share is not None:
|
||||
filter["share"] = share
|
||||
|
||||
groups = Groups.get_groups(filter=filter)
|
||||
groups = Groups.get_groups(filter=filter, db=db)
|
||||
|
||||
return groups
|
||||
|
||||
|
|
@ -49,13 +56,17 @@ async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_use
|
|||
|
||||
|
||||
@router.post("/create", response_model=Optional[GroupResponse])
|
||||
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||
async def create_new_group(
|
||||
form_data: GroupForm,
|
||||
user=Depends(get_admin_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
group = Groups.insert_new_group(user.id, form_data)
|
||||
group = Groups.insert_new_group(user.id, form_data, db=db)
|
||||
if group:
|
||||
return GroupResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -76,12 +87,14 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
|||
|
||||
|
||||
@router.get("/id/{id}", response_model=Optional[GroupResponse])
|
||||
async def get_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||
group = Groups.get_group_by_id(id)
|
||||
async def get_group_by_id(
|
||||
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
group = Groups.get_group_by_id(id, db=db)
|
||||
if group:
|
||||
return GroupResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -101,13 +114,15 @@ class GroupExportResponse(GroupResponse):
|
|||
|
||||
|
||||
@router.get("/id/{id}/export", response_model=Optional[GroupExportResponse])
|
||||
async def export_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||
group = Groups.get_group_by_id(id)
|
||||
async def export_group_by_id(
|
||||
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
group = Groups.get_group_by_id(id, db=db)
|
||||
if group:
|
||||
return GroupExportResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
user_ids=Groups.get_group_user_ids_by_id(group.id),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||
user_ids=Groups.get_group_user_ids_by_id(group.id, db=db),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -122,9 +137,11 @@ async def export_group_by_id(id: str, user=Depends(get_admin_user)):
|
|||
|
||||
|
||||
@router.post("/id/{id}/users", response_model=list[UserInfoResponse])
|
||||
async def get_users_in_group(id: str, user=Depends(get_admin_user)):
|
||||
async def get_users_in_group(
|
||||
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
try:
|
||||
users = Users.get_users_by_group_id(id)
|
||||
users = Users.get_users_by_group_id(id, db=db)
|
||||
return users
|
||||
except Exception as e:
|
||||
log.exception(f"Error adding users to group {id}: {e}")
|
||||
|
|
@ -141,14 +158,17 @@ async def get_users_in_group(id: str, user=Depends(get_admin_user)):
|
|||
|
||||
@router.post("/id/{id}/update", response_model=Optional[GroupResponse])
|
||||
async def update_group_by_id(
|
||||
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
|
||||
id: str,
|
||||
form_data: GroupUpdateForm,
|
||||
user=Depends(get_admin_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
group = Groups.update_group_by_id(id, form_data)
|
||||
group = Groups.update_group_by_id(id, form_data, db=db)
|
||||
if group:
|
||||
return GroupResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -170,17 +190,20 @@ async def update_group_by_id(
|
|||
|
||||
@router.post("/id/{id}/users/add", response_model=Optional[GroupResponse])
|
||||
async def add_user_to_group(
|
||||
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
|
||||
id: str,
|
||||
form_data: UserIdsForm,
|
||||
user=Depends(get_admin_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
if form_data.user_ids:
|
||||
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
|
||||
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids, db=db)
|
||||
|
||||
group = Groups.add_users_to_group(id, form_data.user_ids)
|
||||
group = Groups.add_users_to_group(id, form_data.user_ids, db=db)
|
||||
if group:
|
||||
return GroupResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -197,14 +220,17 @@ async def add_user_to_group(
|
|||
|
||||
@router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse])
|
||||
async def remove_users_from_group(
|
||||
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
|
||||
id: str,
|
||||
form_data: UserIdsForm,
|
||||
user=Depends(get_admin_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
group = Groups.remove_users_from_group(id, form_data.user_ids)
|
||||
group = Groups.remove_users_from_group(id, form_data.user_ids, db=db)
|
||||
if group:
|
||||
return GroupResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -225,9 +251,11 @@ async def remove_users_from_group(
|
|||
|
||||
|
||||
@router.delete("/id/{id}/delete", response_model=bool)
|
||||
async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||
async def delete_group_by_id(
|
||||
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
try:
|
||||
result = Groups.delete_group_by_id(id)
|
||||
result = Groups.delete_group_by_id(id, db=db)
|
||||
if result:
|
||||
return result
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
|
|||
from fastapi.concurrency import run_in_threadpool
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import get_session
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.knowledge import (
|
||||
KnowledgeFileListResponse,
|
||||
|
|
@ -52,21 +54,25 @@ class KnowledgeAccessListResponse(BaseModel):
|
|||
|
||||
|
||||
@router.get("/", response_model=KnowledgeAccessListResponse)
|
||||
async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified_user)):
|
||||
async def get_knowledge_bases(
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
page = max(page, 1)
|
||||
limit = PAGE_ITEM_COUNT
|
||||
skip = (page - 1) * limit
|
||||
|
||||
filter = {}
|
||||
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
groups = Groups.get_groups_by_member_id(user.id)
|
||||
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||
if groups:
|
||||
filter["group_ids"] = [group.id for group in groups]
|
||||
|
||||
filter["user_id"] = user.id
|
||||
|
||||
result = Knowledges.search_knowledge_bases(
|
||||
user.id, filter=filter, skip=skip, limit=limit
|
||||
user.id, filter=filter, skip=skip, limit=limit, db=db
|
||||
)
|
||||
|
||||
return KnowledgeAccessListResponse(
|
||||
|
|
@ -75,7 +81,9 @@ async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified
|
|||
**knowledge_base.model_dump(),
|
||||
write_access=(
|
||||
user.id == knowledge_base.user_id
|
||||
or has_access(user.id, "write", knowledge_base.access_control)
|
||||
or has_access(
|
||||
user.id, "write", knowledge_base.access_control, db=db
|
||||
)
|
||||
),
|
||||
)
|
||||
for knowledge_base in result.items
|
||||
|
|
@ -90,6 +98,7 @@ async def search_knowledge_bases(
|
|||
view_option: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
page = max(page, 1)
|
||||
limit = PAGE_ITEM_COUNT
|
||||
|
|
@ -102,14 +111,14 @@ async def search_knowledge_bases(
|
|||
filter["view_option"] = view_option
|
||||
|
||||
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
groups = Groups.get_groups_by_member_id(user.id)
|
||||
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||
if groups:
|
||||
filter["group_ids"] = [group.id for group in groups]
|
||||
|
||||
filter["user_id"] = user.id
|
||||
|
||||
result = Knowledges.search_knowledge_bases(
|
||||
user.id, filter=filter, skip=skip, limit=limit
|
||||
user.id, filter=filter, skip=skip, limit=limit, db=db
|
||||
)
|
||||
|
||||
return KnowledgeAccessListResponse(
|
||||
|
|
@ -118,7 +127,9 @@ async def search_knowledge_bases(
|
|||
**knowledge_base.model_dump(),
|
||||
write_access=(
|
||||
user.id == knowledge_base.user_id
|
||||
or has_access(user.id, "write", knowledge_base.access_control)
|
||||
or has_access(
|
||||
user.id, "write", knowledge_base.access_control, db=db
|
||||
)
|
||||
),
|
||||
)
|
||||
for knowledge_base in result.items
|
||||
|
|
@ -132,6 +143,7 @@ async def search_knowledge_files(
|
|||
query: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
page = max(page, 1)
|
||||
limit = PAGE_ITEM_COUNT
|
||||
|
|
@ -141,13 +153,15 @@ async def search_knowledge_files(
|
|||
if query:
|
||||
filter["query"] = query
|
||||
|
||||
groups = Groups.get_groups_by_member_id(user.id)
|
||||
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||
if groups:
|
||||
filter["group_ids"] = [group.id for group in groups]
|
||||
|
||||
filter["user_id"] = user.id
|
||||
|
||||
return Knowledges.search_knowledge_files(filter=filter, skip=skip, limit=limit)
|
||||
return Knowledges.search_knowledge_files(
|
||||
filter=filter, skip=skip, limit=limit, db=db
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -157,10 +171,13 @@ async def search_knowledge_files(
|
|||
|
||||
@router.post("/create", response_model=Optional[KnowledgeResponse])
|
||||
async def create_new_knowledge(
|
||||
request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user)
|
||||
request: Request,
|
||||
form_data: KnowledgeForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -175,11 +192,12 @@ async def create_new_knowledge(
|
|||
user.id,
|
||||
"sharing.public_knowledge",
|
||||
request.app.state.config.USER_PERMISSIONS,
|
||||
db=db,
|
||||
)
|
||||
):
|
||||
form_data.access_control = {}
|
||||
|
||||
knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
|
||||
knowledge = Knowledges.insert_new_knowledge(user.id, form_data, db=db)
|
||||
|
||||
if knowledge:
|
||||
return knowledge
|
||||
|
|
@ -196,20 +214,24 @@ async def create_new_knowledge(
|
|||
|
||||
|
||||
@router.post("/reindex", response_model=bool)
|
||||
async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
|
||||
async def reindex_knowledge_files(
|
||||
request: Request,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
knowledge_bases = Knowledges.get_knowledge_bases()
|
||||
knowledge_bases = Knowledges.get_knowledge_bases(db=db)
|
||||
|
||||
log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
|
||||
|
||||
for knowledge_base in knowledge_bases:
|
||||
try:
|
||||
files = Knowledges.get_files_by_id(knowledge_base.id)
|
||||
files = Knowledges.get_files_by_id(knowledge_base.id, db=db)
|
||||
try:
|
||||
if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
|
||||
VECTOR_DB_CLIENT.delete_collection(
|
||||
|
|
@ -229,6 +251,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
|
|||
file_id=file.id, collection_name=knowledge_base.id
|
||||
),
|
||||
user=user,
|
||||
db=db,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
|
|
@ -264,21 +287,23 @@ class KnowledgeFilesResponse(KnowledgeResponse):
|
|||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
|
||||
async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||
async def get_knowledge_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||
|
||||
if knowledge:
|
||||
if (
|
||||
user.role == "admin"
|
||||
or knowledge.user_id == user.id
|
||||
or has_access(user.id, "read", knowledge.access_control)
|
||||
or has_access(user.id, "read", knowledge.access_control, db=db)
|
||||
):
|
||||
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
write_access=(
|
||||
user.id == knowledge.user_id
|
||||
or has_access(user.id, "write", knowledge.access_control)
|
||||
or has_access(user.id, "write", knowledge.access_control, db=db)
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -299,8 +324,9 @@ async def update_knowledge_by_id(
|
|||
id: str,
|
||||
form_data: KnowledgeForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||
if not knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -309,7 +335,7 @@ async def update_knowledge_by_id(
|
|||
# Is the user the original creator, in a group with write access, or an admin
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -325,15 +351,16 @@ async def update_knowledge_by_id(
|
|||
user.id,
|
||||
"sharing.public_knowledge",
|
||||
request.app.state.config.USER_PERMISSIONS,
|
||||
db=db,
|
||||
)
|
||||
):
|
||||
form_data.access_control = {}
|
||||
|
||||
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
|
||||
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data, db=db)
|
||||
if knowledge:
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -356,9 +383,10 @@ async def get_knowledge_files_by_id(
|
|||
direction: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||
if not knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -368,7 +396,7 @@ async def get_knowledge_files_by_id(
|
|||
if not (
|
||||
user.role == "admin"
|
||||
or knowledge.user_id == user.id
|
||||
or has_access(user.id, "read", knowledge.access_control)
|
||||
or has_access(user.id, "read", knowledge.access_control, db=db)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -391,7 +419,7 @@ async def get_knowledge_files_by_id(
|
|||
filter["direction"] = direction
|
||||
|
||||
return Knowledges.search_files_by_id(
|
||||
id, user.id, filter=filter, skip=skip, limit=limit
|
||||
id, user.id, filter=filter, skip=skip, limit=limit, db=db
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -410,8 +438,9 @@ def add_file_to_knowledge_by_id(
|
|||
id: str,
|
||||
form_data: KnowledgeFileIdForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||
if not knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -420,7 +449,7 @@ def add_file_to_knowledge_by_id(
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -428,7 +457,7 @@ def add_file_to_knowledge_by_id(
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
file = Files.get_file_by_id(form_data.file_id)
|
||||
file = Files.get_file_by_id(form_data.file_id, db=db)
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -446,11 +475,12 @@ def add_file_to_knowledge_by_id(
|
|||
request,
|
||||
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
|
||||
user=user,
|
||||
db=db,
|
||||
)
|
||||
|
||||
# Add file to knowledge base
|
||||
Knowledges.add_file_to_knowledge_by_id(
|
||||
knowledge_id=id, file_id=form_data.file_id, user_id=user.id
|
||||
knowledge_id=id, file_id=form_data.file_id, user_id=user.id, db=db
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
|
|
@ -462,7 +492,7 @@ def add_file_to_knowledge_by_id(
|
|||
if knowledge:
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -477,8 +507,9 @@ def update_file_from_knowledge_by_id(
|
|||
id: str,
|
||||
form_data: KnowledgeFileIdForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||
if not knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -487,7 +518,7 @@ def update_file_from_knowledge_by_id(
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
|
||||
|
|
@ -496,7 +527,7 @@ def update_file_from_knowledge_by_id(
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
file = Files.get_file_by_id(form_data.file_id)
|
||||
file = Files.get_file_by_id(form_data.file_id, db=db)
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -514,6 +545,7 @@ def update_file_from_knowledge_by_id(
|
|||
request,
|
||||
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
|
||||
user=user,
|
||||
db=db,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -524,7 +556,7 @@ def update_file_from_knowledge_by_id(
|
|||
if knowledge:
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -544,8 +576,9 @@ def remove_file_from_knowledge_by_id(
|
|||
form_data: KnowledgeFileIdForm,
|
||||
delete_file: bool = Query(True),
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||
if not knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -554,7 +587,7 @@ def remove_file_from_knowledge_by_id(
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -562,7 +595,7 @@ def remove_file_from_knowledge_by_id(
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
file = Files.get_file_by_id(form_data.file_id)
|
||||
file = Files.get_file_by_id(form_data.file_id, db=db)
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -570,7 +603,7 @@ def remove_file_from_knowledge_by_id(
|
|||
)
|
||||
|
||||
Knowledges.remove_file_from_knowledge_by_id(
|
||||
knowledge_id=id, file_id=form_data.file_id
|
||||
knowledge_id=id, file_id=form_data.file_id, db=db
|
||||
)
|
||||
|
||||
# Remove content from the vector database
|
||||
|
|
@ -599,12 +632,12 @@ def remove_file_from_knowledge_by_id(
|
|||
pass
|
||||
|
||||
# Delete file from database
|
||||
Files.delete_file_by_id(form_data.file_id)
|
||||
Files.delete_file_by_id(form_data.file_id, db=db)
|
||||
|
||||
if knowledge:
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -619,8 +652,10 @@ def remove_file_from_knowledge_by_id(
|
|||
|
||||
|
||||
@router.delete("/{id}/delete", response_model=bool)
|
||||
async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||
async def delete_knowledge_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||
if not knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -629,7 +664,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -640,7 +675,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})")
|
||||
|
||||
# Get all models
|
||||
models = Models.get_all_models()
|
||||
models = Models.get_all_models(db=db)
|
||||
log.info(f"Found {len(models)} models to check for knowledge base {id}")
|
||||
|
||||
# Update models that reference this knowledge base
|
||||
|
|
@ -664,7 +699,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
access_control=model.access_control,
|
||||
is_active=model.is_active,
|
||||
)
|
||||
Models.update_model_by_id(model.id, model_form)
|
||||
Models.update_model_by_id(model.id, model_form, db=db)
|
||||
|
||||
# Clean up vector DB
|
||||
try:
|
||||
|
|
@ -672,7 +707,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
except Exception as e:
|
||||
log.debug(e)
|
||||
pass
|
||||
result = Knowledges.delete_knowledge_by_id(id=id)
|
||||
result = Knowledges.delete_knowledge_by_id(id=id, db=db)
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -682,8 +717,10 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
|
||||
async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||
async def reset_knowledge_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||
if not knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -692,7 +729,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -706,7 +743,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
log.debug(e)
|
||||
pass
|
||||
|
||||
knowledge = Knowledges.reset_knowledge_by_id(id=id)
|
||||
knowledge = Knowledges.reset_knowledge_by_id(id=id, db=db)
|
||||
return knowledge
|
||||
|
||||
|
||||
|
|
@ -721,11 +758,12 @@ async def add_files_to_knowledge_batch(
|
|||
id: str,
|
||||
form_data: list[KnowledgeFileIdForm],
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""
|
||||
Add multiple files to a knowledge base
|
||||
"""
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||
if not knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -734,7 +772,7 @@ async def add_files_to_knowledge_batch(
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -746,7 +784,7 @@ async def add_files_to_knowledge_batch(
|
|||
log.info(f"files/batch/add - {len(form_data)} files")
|
||||
files: List[FileModel] = []
|
||||
for form in form_data:
|
||||
file = Files.get_file_by_id(form.file_id)
|
||||
file = Files.get_file_by_id(form.file_id, db=db)
|
||||
if not file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -760,6 +798,7 @@ async def add_files_to_knowledge_batch(
|
|||
request=request,
|
||||
form_data=BatchProcessFilesForm(files=files, collection_name=id),
|
||||
user=user,
|
||||
db=db,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
|
|
@ -771,7 +810,7 @@ async def add_files_to_knowledge_batch(
|
|||
successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
|
||||
for file_id in successful_file_ids:
|
||||
Knowledges.add_file_to_knowledge_by_id(
|
||||
knowledge_id=id, file_id=file_id, user_id=user.id
|
||||
knowledge_id=id, file_id=file_id, user_id=user.id, db=db
|
||||
)
|
||||
|
||||
# If there were any errors, include them in the response
|
||||
|
|
@ -779,7 +818,7 @@ async def add_files_to_knowledge_batch(
|
|||
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||
warnings={
|
||||
"message": "Some files failed to process",
|
||||
"errors": error_details,
|
||||
|
|
@ -788,5 +827,5 @@ async def add_files_to_knowledge_batch(
|
|||
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,10 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
|||
from open_webui.utils.auth import get_verified_user
|
||||
|
||||
|
||||
from open_webui.utils.auth import get_verified_user
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -25,8 +29,8 @@ async def get_embeddings(request: Request):
|
|||
|
||||
|
||||
@router.get("/", response_model=list[MemoryModel])
|
||||
async def get_memories(user=Depends(get_verified_user)):
|
||||
return Memories.get_memories_by_user_id(user.id)
|
||||
async def get_memories(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
return Memories.get_memories_by_user_id(user.id, db=db)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -47,8 +51,9 @@ async def add_memory(
|
|||
request: Request,
|
||||
form_data: AddMemoryForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
memory = Memories.insert_new_memory(user.id, form_data.content)
|
||||
memory = Memories.insert_new_memory(user.id, form_data.content, db=db)
|
||||
|
||||
vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
|
||||
|
||||
|
|
@ -79,9 +84,9 @@ class QueryMemoryForm(BaseModel):
|
|||
|
||||
@router.post("/query")
|
||||
async def query_memory(
|
||||
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
|
||||
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
memories = Memories.get_memories_by_user_id(user.id)
|
||||
memories = Memories.get_memories_by_user_id(user.id, db=db)
|
||||
if not memories:
|
||||
raise HTTPException(status_code=404, detail="No memories found for user")
|
||||
|
||||
|
|
@ -101,11 +106,11 @@ async def query_memory(
|
|||
############################
|
||||
@router.post("/reset", response_model=bool)
|
||||
async def reset_memory_from_vector_db(
|
||||
request: Request, user=Depends(get_verified_user)
|
||||
request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
|
||||
|
||||
memories = Memories.get_memories_by_user_id(user.id)
|
||||
memories = Memories.get_memories_by_user_id(user.id, db=db)
|
||||
|
||||
# Generate vectors in parallel
|
||||
vectors = await asyncio.gather(
|
||||
|
|
@ -140,8 +145,8 @@ async def reset_memory_from_vector_db(
|
|||
|
||||
|
||||
@router.delete("/delete/user", response_model=bool)
|
||||
async def delete_memory_by_user_id(user=Depends(get_verified_user)):
|
||||
result = Memories.delete_memories_by_user_id(user.id)
|
||||
async def delete_memory_by_user_id(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
result = Memories.delete_memories_by_user_id(user.id, db=db)
|
||||
|
||||
if result:
|
||||
try:
|
||||
|
|
@ -164,9 +169,10 @@ async def update_memory_by_id(
|
|||
request: Request,
|
||||
form_data: MemoryUpdateModel,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
memory = Memories.update_memory_by_id_and_user_id(
|
||||
memory_id, user.id, form_data.content
|
||||
memory_id, user.id, form_data.content, db=db
|
||||
)
|
||||
if memory is None:
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
|
|
@ -198,8 +204,8 @@ async def update_memory_by_id(
|
|||
|
||||
|
||||
@router.delete("/{memory_id}", response_model=bool)
|
||||
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
|
||||
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
|
||||
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id, db=db)
|
||||
|
||||
if result:
|
||||
VECTOR_DB_CLIENT.delete(
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ from fastapi.responses import FileResponse, StreamingResponse
|
|||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -59,6 +61,7 @@ async def get_models(
|
|||
direction: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
|
||||
limit = PAGE_ITEM_COUNT
|
||||
|
|
@ -79,13 +82,13 @@ async def get_models(
|
|||
filter["direction"] = direction
|
||||
|
||||
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
groups = Groups.get_groups_by_member_id(user.id)
|
||||
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||
if groups:
|
||||
filter["group_ids"] = [group.id for group in groups]
|
||||
|
||||
filter["user_id"] = user.id
|
||||
|
||||
return Models.search_models(user.id, filter=filter, skip=skip, limit=limit)
|
||||
return Models.search_models(user.id, filter=filter, skip=skip, limit=limit, db=db)
|
||||
|
||||
|
||||
###########################
|
||||
|
|
@ -94,8 +97,8 @@ async def get_models(
|
|||
|
||||
|
||||
@router.get("/base", response_model=list[ModelResponse])
|
||||
async def get_base_models(user=Depends(get_admin_user)):
|
||||
return Models.get_base_models()
|
||||
async def get_base_models(user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||
return Models.get_base_models(db=db)
|
||||
|
||||
|
||||
###########################
|
||||
|
|
@ -104,11 +107,11 @@ async def get_base_models(user=Depends(get_admin_user)):
|
|||
|
||||
|
||||
@router.get("/tags", response_model=list[str])
|
||||
async def get_model_tags(user=Depends(get_verified_user)):
|
||||
async def get_model_tags(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
models = Models.get_models()
|
||||
models = Models.get_models(db=db)
|
||||
else:
|
||||
models = Models.get_models_by_user_id(user.id)
|
||||
models = Models.get_models_by_user_id(user.id, db=db)
|
||||
|
||||
tags_set = set()
|
||||
for model in models:
|
||||
|
|
@ -132,16 +135,17 @@ async def create_new_model(
|
|||
request: Request,
|
||||
form_data: ModelForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
model = Models.get_model_by_id(form_data.id)
|
||||
model = Models.get_model_by_id(form_data.id, db=db)
|
||||
if model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -155,7 +159,7 @@ async def create_new_model(
|
|||
)
|
||||
|
||||
else:
|
||||
model = Models.insert_new_model(form_data, user.id)
|
||||
model = Models.insert_new_model(form_data, user.id, db=db)
|
||||
if model:
|
||||
return model
|
||||
else:
|
||||
|
|
@ -171,9 +175,9 @@ async def create_new_model(
|
|||
|
||||
|
||||
@router.get("/export", response_model=list[ModelModel])
|
||||
async def export_models(request: Request, user=Depends(get_verified_user)):
|
||||
async def export_models(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -181,9 +185,9 @@ async def export_models(request: Request, user=Depends(get_verified_user)):
|
|||
)
|
||||
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
return Models.get_models()
|
||||
return Models.get_models(db=db)
|
||||
else:
|
||||
return Models.get_models_by_user_id(user.id)
|
||||
return Models.get_models_by_user_id(user.id, db=db)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -200,9 +204,10 @@ async def import_models(
|
|||
request: Request,
|
||||
user=Depends(get_verified_user),
|
||||
form_data: ModelsImportForm = (...),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -216,7 +221,7 @@ async def import_models(
|
|||
model_id = model_data.get("id")
|
||||
|
||||
if model_id and is_valid_model_id(model_id):
|
||||
existing_model = Models.get_model_by_id(model_id)
|
||||
existing_model = Models.get_model_by_id(model_id, db=db)
|
||||
if existing_model:
|
||||
# Update existing model
|
||||
model_data["meta"] = model_data.get("meta", {})
|
||||
|
|
@ -225,13 +230,13 @@ async def import_models(
|
|||
updated_model = ModelForm(
|
||||
**{**existing_model.model_dump(), **model_data}
|
||||
)
|
||||
Models.update_model_by_id(model_id, updated_model)
|
||||
Models.update_model_by_id(model_id, updated_model, db=db)
|
||||
else:
|
||||
# Insert new model
|
||||
model_data["meta"] = model_data.get("meta", {})
|
||||
model_data["params"] = model_data.get("params", {})
|
||||
new_model = ModelForm(**model_data)
|
||||
Models.insert_new_model(user_id=user.id, form_data=new_model)
|
||||
Models.insert_new_model(user_id=user.id, form_data=new_model, db=db)
|
||||
return True
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON format")
|
||||
|
|
@ -251,9 +256,9 @@ class SyncModelsForm(BaseModel):
|
|||
|
||||
@router.post("/sync", response_model=list[ModelModel])
|
||||
async def sync_models(
|
||||
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
|
||||
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
return Models.sync_models(user.id, form_data.models)
|
||||
return Models.sync_models(user.id, form_data.models, db=db)
|
||||
|
||||
|
||||
###########################
|
||||
|
|
@ -267,13 +272,13 @@ class ModelIdForm(BaseModel):
|
|||
|
||||
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
|
||||
@router.get("/model", response_model=Optional[ModelResponse])
|
||||
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||
model = Models.get_model_by_id(id)
|
||||
async def get_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
model = Models.get_model_by_id(id, db=db)
|
||||
if model:
|
||||
if (
|
||||
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or model.user_id == user.id
|
||||
or has_access(user.id, "read", model.access_control)
|
||||
or has_access(user.id, "read", model.access_control, db=db)
|
||||
):
|
||||
return model
|
||||
else:
|
||||
|
|
@ -289,8 +294,8 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/model/profile/image")
|
||||
async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
|
||||
model = Models.get_model_by_id(id)
|
||||
async def get_model_profile_image(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
model = Models.get_model_by_id(id, db=db)
|
||||
# Cache-control headers to prevent stale cached images
|
||||
cache_headers = {"Cache-Control": "no-cache, must-revalidate"}
|
||||
|
||||
|
|
@ -330,15 +335,15 @@ async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.post("/model/toggle", response_model=Optional[ModelResponse])
|
||||
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||
model = Models.get_model_by_id(id)
|
||||
async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
model = Models.get_model_by_id(id, db=db)
|
||||
if model:
|
||||
if (
|
||||
user.role == "admin"
|
||||
or model.user_id == user.id
|
||||
or has_access(user.id, "write", model.access_control)
|
||||
or has_access(user.id, "write", model.access_control, db=db)
|
||||
):
|
||||
model = Models.toggle_model_by_id(id)
|
||||
model = Models.toggle_model_by_id(id, db=db)
|
||||
|
||||
if model:
|
||||
return model
|
||||
|
|
@ -368,8 +373,9 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
|
|||
async def update_model_by_id(
|
||||
form_data: ModelForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
model = Models.get_model_by_id(form_data.id)
|
||||
model = Models.get_model_by_id(form_data.id, db=db)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -378,7 +384,7 @@ async def update_model_by_id(
|
|||
|
||||
if (
|
||||
model.user_id != user.id
|
||||
and not has_access(user.id, "write", model.access_control)
|
||||
and not has_access(user.id, "write", model.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -386,7 +392,7 @@ async def update_model_by_id(
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()))
|
||||
model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()), db=db)
|
||||
return model
|
||||
|
||||
|
||||
|
|
@ -396,8 +402,8 @@ async def update_model_by_id(
|
|||
|
||||
|
||||
@router.post("/model/delete", response_model=bool)
|
||||
async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user)):
|
||||
model = Models.get_model_by_id(form_data.id)
|
||||
async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
model = Models.get_model_by_id(form_data.id, db=db)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -407,18 +413,18 @@ async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_u
|
|||
if (
|
||||
user.role != "admin"
|
||||
and model.user_id != user.id
|
||||
and not has_access(user.id, "write", model.access_control)
|
||||
and not has_access(user.id, "write", model.access_control, db=db)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
result = Models.delete_model_by_id(form_data.id)
|
||||
result = Models.delete_model_by_id(form_data.id, db=db)
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/delete/all", response_model=bool)
|
||||
async def delete_all_models(user=Depends(get_admin_user)):
|
||||
result = Models.delete_all_models()
|
||||
async def delete_all_models(user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||
result = Models.delete_all_models(db=db)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ from open_webui.constants import ERROR_MESSAGES
|
|||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -49,10 +51,13 @@ class NoteItemResponse(BaseModel):
|
|||
|
||||
@router.get("/", response_model=list[NoteItemResponse])
|
||||
async def get_notes(
|
||||
request: Request, page: Optional[int] = None, user=Depends(get_verified_user)
|
||||
request: Request,
|
||||
page: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -69,10 +74,14 @@ async def get_notes(
|
|||
NoteUserResponse(
|
||||
**{
|
||||
**note.model_dump(),
|
||||
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
|
||||
"user": UserResponse(
|
||||
**Users.get_user_by_id(note.user_id, db=db).model_dump()
|
||||
),
|
||||
}
|
||||
)
|
||||
for note in Notes.get_notes_by_user_id(user.id, "read", skip=skip, limit=limit)
|
||||
for note in Notes.get_notes_by_user_id(
|
||||
user.id, "read", skip=skip, limit=limit, db=db
|
||||
)
|
||||
]
|
||||
return notes
|
||||
|
||||
|
|
@ -87,9 +96,10 @@ async def search_notes(
|
|||
direction: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -115,13 +125,13 @@ async def search_notes(
|
|||
filter["direction"] = direction
|
||||
|
||||
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
groups = Groups.get_groups_by_member_id(user.id)
|
||||
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||
if groups:
|
||||
filter["group_ids"] = [group.id for group in groups]
|
||||
|
||||
filter["user_id"] = user.id
|
||||
|
||||
return Notes.search_notes(user.id, filter, skip=skip, limit=limit)
|
||||
return Notes.search_notes(user.id, filter, skip=skip, limit=limit, db=db)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -131,10 +141,13 @@ async def search_notes(
|
|||
|
||||
@router.post("/create", response_model=Optional[NoteModel])
|
||||
async def create_new_note(
|
||||
request: Request, form_data: NoteForm, user=Depends(get_verified_user)
|
||||
request: Request,
|
||||
form_data: NoteForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -142,7 +155,7 @@ async def create_new_note(
|
|||
)
|
||||
|
||||
try:
|
||||
note = Notes.insert_new_note(user.id, form_data)
|
||||
note = Notes.insert_new_note(user.id, form_data, db=db)
|
||||
return note
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -161,16 +174,21 @@ class NoteResponse(NoteModel):
|
|||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[NoteResponse])
|
||||
async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
||||
async def get_note_by_id(
|
||||
request: Request,
|
||||
id: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
note = Notes.get_note_by_id(id)
|
||||
note = Notes.get_note_by_id(id, db=db)
|
||||
if not note:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -178,7 +196,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
|
|||
|
||||
if user.role != "admin" and (
|
||||
user.id != note.user_id
|
||||
and (not has_access(user.id, type="read", access_control=note.access_control))
|
||||
and (
|
||||
not has_access(
|
||||
user.id, type="read", access_control=note.access_control, db=db
|
||||
)
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
|
|
@ -188,7 +210,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
|
|||
user.role == "admin"
|
||||
or (user.id == note.user_id)
|
||||
or has_access(
|
||||
user.id, type="write", access_control=note.access_control, strict=False
|
||||
user.id,
|
||||
type="write",
|
||||
access_control=note.access_control,
|
||||
strict=False,
|
||||
db=db,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -202,17 +228,21 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
|
|||
|
||||
@router.post("/{id}/update", response_model=Optional[NoteModel])
|
||||
async def update_note_by_id(
|
||||
request: Request, id: str, form_data: NoteForm, user=Depends(get_verified_user)
|
||||
request: Request,
|
||||
id: str,
|
||||
form_data: NoteForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
note = Notes.get_note_by_id(id)
|
||||
note = Notes.get_note_by_id(id, db=db)
|
||||
if not note:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -220,7 +250,9 @@ async def update_note_by_id(
|
|||
|
||||
if user.role != "admin" and (
|
||||
user.id != note.user_id
|
||||
and not has_access(user.id, type="write", access_control=note.access_control)
|
||||
and not has_access(
|
||||
user.id, type="write", access_control=note.access_control, db=db
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
|
|
@ -234,12 +266,13 @@ async def update_note_by_id(
|
|||
user.id,
|
||||
"sharing.public_notes",
|
||||
request.app.state.config.USER_PERMISSIONS,
|
||||
db=db,
|
||||
)
|
||||
):
|
||||
form_data.access_control = {}
|
||||
|
||||
try:
|
||||
note = Notes.update_note_by_id(id, form_data)
|
||||
note = Notes.update_note_by_id(id, form_data, db=db)
|
||||
await sio.emit(
|
||||
"note-events",
|
||||
note.model_dump(),
|
||||
|
|
@ -260,16 +293,21 @@ async def update_note_by_id(
|
|||
|
||||
|
||||
@router.delete("/{id}/delete", response_model=bool)
|
||||
async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
||||
async def delete_note_by_id(
|
||||
request: Request,
|
||||
id: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
note = Notes.get_note_by_id(id)
|
||||
note = Notes.get_note_by_id(id, db=db)
|
||||
if not note:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -277,14 +315,16 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
|
|||
|
||||
if user.role != "admin" and (
|
||||
user.id != note.user_id
|
||||
and not has_access(user.id, type="write", access_control=note.access_control)
|
||||
and not has_access(
|
||||
user.id, type="write", access_control=note.access_control, db=db
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
note = Notes.delete_note_by_id(id)
|
||||
note = Notes.delete_note_by_id(id, db=db)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from open_webui.constants import ERROR_MESSAGES
|
|||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -20,21 +22,21 @@ router = APIRouter()
|
|||
|
||||
|
||||
@router.get("/", response_model=list[PromptModel])
|
||||
async def get_prompts(user=Depends(get_verified_user)):
|
||||
async def get_prompts(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
prompts = Prompts.get_prompts()
|
||||
prompts = Prompts.get_prompts(db=db)
|
||||
else:
|
||||
prompts = Prompts.get_prompts_by_user_id(user.id, "read")
|
||||
prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db)
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
@router.get("/list", response_model=list[PromptUserResponse])
|
||||
async def get_prompt_list(user=Depends(get_verified_user)):
|
||||
async def get_prompt_list(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
prompts = Prompts.get_prompts()
|
||||
prompts = Prompts.get_prompts(db=db)
|
||||
else:
|
||||
prompts = Prompts.get_prompts_by_user_id(user.id, "write")
|
||||
prompts = Prompts.get_prompts_by_user_id(user.id, "write", db=db)
|
||||
|
||||
return prompts
|
||||
|
||||
|
|
@ -46,16 +48,17 @@ async def get_prompt_list(user=Depends(get_verified_user)):
|
|||
|
||||
@router.post("/create", response_model=Optional[PromptModel])
|
||||
async def create_new_prompt(
|
||||
request: Request, form_data: PromptForm, user=Depends(get_verified_user)
|
||||
request: Request, form_data: PromptForm, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
if user.role != "admin" and not (
|
||||
has_permission(
|
||||
user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
)
|
||||
or has_permission(
|
||||
user.id,
|
||||
"workspace.prompts_import",
|
||||
request.app.state.config.USER_PERMISSIONS,
|
||||
db=db,
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -63,9 +66,9 @@ async def create_new_prompt(
|
|||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
prompt = Prompts.get_prompt_by_command(form_data.command)
|
||||
prompt = Prompts.get_prompt_by_command(form_data.command, db=db)
|
||||
if prompt is None:
|
||||
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
||||
prompt = Prompts.insert_new_prompt(user.id, form_data, db=db)
|
||||
|
||||
if prompt:
|
||||
return prompt
|
||||
|
|
@ -85,14 +88,14 @@ async def create_new_prompt(
|
|||
|
||||
|
||||
@router.get("/command/{command}", response_model=Optional[PromptModel])
|
||||
async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
||||
async def get_prompt_by_command(command: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
|
||||
|
||||
if prompt:
|
||||
if (
|
||||
user.role == "admin"
|
||||
or prompt.user_id == user.id
|
||||
or has_access(user.id, "read", prompt.access_control)
|
||||
or has_access(user.id, "read", prompt.access_control, db=db)
|
||||
):
|
||||
return prompt
|
||||
else:
|
||||
|
|
@ -112,8 +115,9 @@ async def update_prompt_by_command(
|
|||
command: str,
|
||||
form_data: PromptForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -123,7 +127,7 @@ async def update_prompt_by_command(
|
|||
# Is the user the original creator, in a group with write access, or an admin
|
||||
if (
|
||||
prompt.user_id != user.id
|
||||
and not has_access(user.id, "write", prompt.access_control)
|
||||
and not has_access(user.id, "write", prompt.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -131,7 +135,7 @@ async def update_prompt_by_command(
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
|
||||
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data, db=db)
|
||||
if prompt:
|
||||
return prompt
|
||||
else:
|
||||
|
|
@ -147,8 +151,8 @@ async def update_prompt_by_command(
|
|||
|
||||
|
||||
@router.delete("/command/{command}/delete", response_model=bool)
|
||||
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
||||
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -157,7 +161,7 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
|
|||
|
||||
if (
|
||||
prompt.user_id != user.id
|
||||
and not has_access(user.id, "write", prompt.access_control)
|
||||
and not has_access(user.id, "write", prompt.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -165,5 +169,5 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
result = Prompts.delete_prompt_by_command(f"/{command}")
|
||||
result = Prompts.delete_prompt_by_command(f"/{command}", db=db)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ from langchain_core.documents import Document
|
|||
from open_webui.models.files import FileModel, FileUpdateForm, Files
|
||||
from open_webui.models.knowledge import Knowledges
|
||||
from open_webui.storage.provider import Storage
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||
|
|
@ -1484,14 +1486,15 @@ def process_file(
|
|||
request: Request,
|
||||
form_data: ProcessFileForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""
|
||||
Process a file and save its content to the vector database.
|
||||
"""
|
||||
if user.role == "admin":
|
||||
file = Files.get_file_by_id(form_data.file_id)
|
||||
file = Files.get_file_by_id(form_data.file_id, db=db)
|
||||
else:
|
||||
file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id)
|
||||
file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id, db=db)
|
||||
|
||||
if file:
|
||||
try:
|
||||
|
|
@ -1633,12 +1636,13 @@ def process_file(
|
|||
Files.update_file_data_by_id(
|
||||
file.id,
|
||||
{"content": text_content},
|
||||
db=db,
|
||||
)
|
||||
hash = calculate_sha256_string(text_content)
|
||||
Files.update_file_hash_by_id(file.id, hash)
|
||||
Files.update_file_hash_by_id(file.id, hash, db=db)
|
||||
|
||||
if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
|
||||
Files.update_file_data_by_id(file.id, {"status": "completed"})
|
||||
Files.update_file_data_by_id(file.id, {"status": "completed"}, db=db)
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": None,
|
||||
|
|
@ -1667,11 +1671,13 @@ def process_file(
|
|||
{
|
||||
"collection_name": collection_name,
|
||||
},
|
||||
db=db,
|
||||
)
|
||||
|
||||
Files.update_file_data_by_id(
|
||||
file.id,
|
||||
{"status": "completed"},
|
||||
db=db,
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -1690,6 +1696,7 @@ def process_file(
|
|||
Files.update_file_data_by_id(
|
||||
file.id,
|
||||
{"status": "failed"},
|
||||
db=db,
|
||||
)
|
||||
|
||||
if "No pandoc was found" in str(e):
|
||||
|
|
@ -2417,10 +2424,10 @@ class DeleteForm(BaseModel):
|
|||
|
||||
|
||||
@router.post("/delete")
|
||||
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
|
||||
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||
try:
|
||||
if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
|
||||
file = Files.get_file_by_id(form_data.file_id)
|
||||
file = Files.get_file_by_id(form_data.file_id, db=db)
|
||||
hash = file.hash
|
||||
|
||||
VECTOR_DB_CLIENT.delete(
|
||||
|
|
@ -2436,9 +2443,9 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin
|
|||
|
||||
|
||||
@router.post("/reset/db")
|
||||
def reset_vector_db(user=Depends(get_admin_user)):
|
||||
def reset_vector_db(user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||
VECTOR_DB_CLIENT.reset()
|
||||
Knowledges.delete_all_knowledge()
|
||||
Knowledges.delete_all_knowledge(db=db)
|
||||
|
||||
|
||||
@router.post("/reset/uploads")
|
||||
|
|
@ -2496,6 +2503,7 @@ async def process_files_batch(
|
|||
request: Request,
|
||||
form_data: BatchProcessFilesForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
) -> BatchProcessFilesResponse:
|
||||
"""
|
||||
Process a batch of files and save them to the vector database.
|
||||
|
|
@ -2558,7 +2566,7 @@ async def process_files_batch(
|
|||
|
||||
# Update all files with collection name
|
||||
for file_update, file_result in zip(file_updates, file_results):
|
||||
Files.update_file_by_id(id=file_result.file_id, form_data=file_update)
|
||||
Files.update_file_by_id(id=file_result.file_id, form_data=file_update, db=db)
|
||||
file_result.status = "completed"
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ import aiohttp
|
|||
from open_webui.models.groups import Groups
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import get_session
|
||||
|
||||
|
||||
from open_webui.models.oauth_sessions import OAuthSessions
|
||||
|
|
@ -51,11 +53,11 @@ def get_tool_module(request, tool_id, load_from_db=True):
|
|||
|
||||
|
||||
@router.get("/", response_model=list[ToolUserResponse])
|
||||
async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||
async def get_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
tools = []
|
||||
|
||||
# Local Tools
|
||||
for tool in Tools.get_tools():
|
||||
for tool in Tools.get_tools(db=db):
|
||||
tool_module = get_tool_module(request, tool.id)
|
||||
tools.append(
|
||||
ToolUserResponse(
|
||||
|
|
@ -140,12 +142,12 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
|||
# Admin can see all tools
|
||||
return tools
|
||||
else:
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||
tools = [
|
||||
tool
|
||||
for tool in tools
|
||||
if tool.user_id == user.id
|
||||
or has_access(user.id, "read", tool.access_control, user_group_ids)
|
||||
or has_access(user.id, "read", tool.access_control, user_group_ids, db=db)
|
||||
]
|
||||
return tools
|
||||
|
||||
|
|
@ -156,11 +158,11 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/list", response_model=list[ToolUserResponse])
|
||||
async def get_tool_list(user=Depends(get_verified_user)):
|
||||
async def get_tool_list(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
tools = Tools.get_tools()
|
||||
tools = Tools.get_tools(db=db)
|
||||
else:
|
||||
tools = Tools.get_tools_by_user_id(user.id, "write")
|
||||
tools = Tools.get_tools_by_user_id(user.id, "write", db=db)
|
||||
return tools
|
||||
|
||||
|
||||
|
|
@ -245,9 +247,9 @@ async def load_tool_from_url(
|
|||
|
||||
|
||||
@router.get("/export", response_model=list[ToolModel])
|
||||
async def export_tools(request: Request, user=Depends(get_verified_user)):
|
||||
async def export_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -255,9 +257,9 @@ async def export_tools(request: Request, user=Depends(get_verified_user)):
|
|||
)
|
||||
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
return Tools.get_tools()
|
||||
return Tools.get_tools(db=db)
|
||||
else:
|
||||
return Tools.get_tools_by_user_id(user.id, "read")
|
||||
return Tools.get_tools_by_user_id(user.id, "read", db=db)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -270,13 +272,14 @@ async def create_new_tools(
|
|||
request: Request,
|
||||
form_data: ToolForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role != "admin" and not (
|
||||
has_permission(
|
||||
user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
)
|
||||
or has_permission(
|
||||
user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -292,7 +295,7 @@ async def create_new_tools(
|
|||
|
||||
form_data.id = form_data.id.lower()
|
||||
|
||||
tools = Tools.get_tool_by_id(form_data.id)
|
||||
tools = Tools.get_tool_by_id(form_data.id, db=db)
|
||||
if tools is None:
|
||||
try:
|
||||
form_data.content = replace_imports(form_data.content)
|
||||
|
|
@ -305,7 +308,7 @@ async def create_new_tools(
|
|||
TOOLS[form_data.id] = tool_module
|
||||
|
||||
specs = get_tool_specs(TOOLS[form_data.id])
|
||||
tools = Tools.insert_new_tool(user.id, form_data, specs)
|
||||
tools = Tools.insert_new_tool(user.id, form_data, specs, db=db)
|
||||
|
||||
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
|
||||
tool_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -336,14 +339,14 @@ async def create_new_tools(
|
|||
|
||||
|
||||
@router.get("/id/{id}", response_model=Optional[ToolModel])
|
||||
async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
tools = Tools.get_tool_by_id(id, db=db)
|
||||
|
||||
if tools:
|
||||
if (
|
||||
user.role == "admin"
|
||||
or tools.user_id == user.id
|
||||
or has_access(user.id, "read", tools.access_control)
|
||||
or has_access(user.id, "read", tools.access_control, db=db)
|
||||
):
|
||||
return tools
|
||||
else:
|
||||
|
|
@ -364,8 +367,9 @@ async def update_tools_by_id(
|
|||
id: str,
|
||||
form_data: ToolForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = Tools.get_tool_by_id(id, db=db)
|
||||
if not tools:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -375,7 +379,7 @@ async def update_tools_by_id(
|
|||
# Is the user the original creator, in a group with write access, or an admin
|
||||
if (
|
||||
tools.user_id != user.id
|
||||
and not has_access(user.id, "write", tools.access_control)
|
||||
and not has_access(user.id, "write", tools.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -399,7 +403,7 @@ async def update_tools_by_id(
|
|||
}
|
||||
|
||||
log.debug(updated)
|
||||
tools = Tools.update_tool_by_id(id, updated)
|
||||
tools = Tools.update_tool_by_id(id, updated, db=db)
|
||||
|
||||
if tools:
|
||||
return tools
|
||||
|
|
@ -423,9 +427,9 @@ async def update_tools_by_id(
|
|||
|
||||
@router.delete("/id/{id}/delete", response_model=bool)
|
||||
async def delete_tools_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = Tools.get_tool_by_id(id, db=db)
|
||||
if not tools:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -434,7 +438,7 @@ async def delete_tools_by_id(
|
|||
|
||||
if (
|
||||
tools.user_id != user.id
|
||||
and not has_access(user.id, "write", tools.access_control)
|
||||
and not has_access(user.id, "write", tools.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -442,7 +446,7 @@ async def delete_tools_by_id(
|
|||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
result = Tools.delete_tool_by_id(id)
|
||||
result = Tools.delete_tool_by_id(id, db=db)
|
||||
if result:
|
||||
TOOLS = request.app.state.TOOLS
|
||||
if id in TOOLS:
|
||||
|
|
@ -457,11 +461,11 @@ async def delete_tools_by_id(
|
|||
|
||||
|
||||
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
||||
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
tools = Tools.get_tool_by_id(id, db=db)
|
||||
if tools:
|
||||
try:
|
||||
valves = Tools.get_tool_valves_by_id(id)
|
||||
valves = Tools.get_tool_valves_by_id(id, db=db)
|
||||
return valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -482,9 +486,9 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
|
||||
async def get_tools_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = Tools.get_tool_by_id(id, db=db)
|
||||
if tools:
|
||||
if id in request.app.state.TOOLS:
|
||||
tools_module = request.app.state.TOOLS[id]
|
||||
|
|
@ -510,9 +514,9 @@ async def get_tools_valves_spec_by_id(
|
|||
|
||||
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
|
||||
async def update_tools_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = Tools.get_tool_by_id(id, db=db)
|
||||
if not tools:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -521,7 +525,7 @@ async def update_tools_valves_by_id(
|
|||
|
||||
if (
|
||||
tools.user_id != user.id
|
||||
and not has_access(user.id, "write", tools.access_control)
|
||||
and not has_access(user.id, "write", tools.access_control, db=db)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -546,7 +550,7 @@ async def update_tools_valves_by_id(
|
|||
form_data = {k: v for k, v in form_data.items() if v is not None}
|
||||
valves = Valves(**form_data)
|
||||
valves_dict = valves.model_dump(exclude_unset=True)
|
||||
Tools.update_tool_valves_by_id(id, valves_dict)
|
||||
Tools.update_tool_valves_by_id(id, valves_dict, db=db)
|
||||
return valves_dict
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to update tool valves by id {id}: {e}")
|
||||
|
|
@ -562,11 +566,11 @@ async def update_tools_valves_by_id(
|
|||
|
||||
|
||||
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
||||
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||
tools = Tools.get_tool_by_id(id, db=db)
|
||||
if tools:
|
||||
try:
|
||||
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
|
||||
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id, db=db)
|
||||
return user_valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -582,9 +586,9 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
|
||||
async def get_tools_user_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = Tools.get_tool_by_id(id, db=db)
|
||||
if tools:
|
||||
if id in request.app.state.TOOLS:
|
||||
tools_module = request.app.state.TOOLS[id]
|
||||
|
|
@ -605,9 +609,9 @@ async def get_tools_user_valves_spec_by_id(
|
|||
|
||||
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
|
||||
async def update_tools_user_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = Tools.get_tool_by_id(id, db=db)
|
||||
|
||||
if tools:
|
||||
if id in request.app.state.TOOLS:
|
||||
|
|
@ -624,7 +628,7 @@ async def update_tools_user_valves_by_id(
|
|||
user_valves = UserValves(**form_data)
|
||||
user_valves_dict = user_valves.model_dump(exclude_unset=True)
|
||||
Tools.update_user_valves_by_id_and_user_id(
|
||||
id, user.id, user_valves_dict
|
||||
id, user.id, user_valves_dict, db=db
|
||||
)
|
||||
return user_valves_dict
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
import base64
|
||||
import io
|
||||
|
||||
|
|
@ -29,6 +30,7 @@ from open_webui.models.users import (
|
|||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import STATIC_DIR
|
||||
from open_webui.internal.db import get_session
|
||||
|
||||
|
||||
from open_webui.utils.auth import (
|
||||
|
|
@ -60,6 +62,7 @@ async def get_users(
|
|||
direction: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_admin_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
limit = PAGE_ITEM_COUNT
|
||||
|
||||
|
|
@ -74,7 +77,9 @@ async def get_users(
|
|||
if direction:
|
||||
filter["direction"] = direction
|
||||
|
||||
result = Users.get_users(filter=filter, skip=skip, limit=limit)
|
||||
filter["direction"] = direction
|
||||
|
||||
result = Users.get_users(filter=filter, skip=skip, limit=limit, db=db)
|
||||
|
||||
users = result["users"]
|
||||
total = result["total"]
|
||||
|
|
@ -85,7 +90,8 @@ async def get_users(
|
|||
**{
|
||||
**user.model_dump(),
|
||||
"group_ids": [
|
||||
group.id for group in Groups.get_groups_by_member_id(user.id)
|
||||
group.id
|
||||
for group in Groups.get_groups_by_member_id(user.id, db=db)
|
||||
],
|
||||
}
|
||||
)
|
||||
|
|
@ -98,8 +104,9 @@ async def get_users(
|
|||
@router.get("/all", response_model=UserInfoListResponse)
|
||||
async def get_all_users(
|
||||
user=Depends(get_admin_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
return Users.get_users()
|
||||
return Users.get_users(db=db)
|
||||
|
||||
|
||||
@router.get("/search", response_model=UserInfoListResponse)
|
||||
|
|
@ -109,16 +116,13 @@ async def search_users(
|
|||
direction: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
limit = PAGE_ITEM_COUNT
|
||||
|
||||
page = max(1, page)
|
||||
skip = (page - 1) * limit
|
||||
|
||||
filter = {}
|
||||
if query:
|
||||
filter["query"] = query
|
||||
|
||||
filter = {}
|
||||
if query:
|
||||
filter["query"] = query
|
||||
|
|
@ -127,7 +131,7 @@ async def search_users(
|
|||
if direction:
|
||||
filter["direction"] = direction
|
||||
|
||||
return Users.get_users(filter=filter, skip=skip, limit=limit)
|
||||
return Users.get_users(filter=filter, skip=skip, limit=limit, db=db)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -136,8 +140,10 @@ async def search_users(
|
|||
|
||||
|
||||
@router.get("/groups")
|
||||
async def get_user_groups(user=Depends(get_verified_user)):
|
||||
return Groups.get_groups_by_member_id(user.id)
|
||||
async def get_user_groups(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
return Groups.get_groups_by_member_id(user.id, db=db)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -146,9 +152,13 @@ async def get_user_groups(user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/permissions")
|
||||
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
|
||||
async def get_user_permissisions(
|
||||
request: Request,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
user_permissions = get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
user.id, request.app.state.config.USER_PERMISSIONS, db=db
|
||||
)
|
||||
|
||||
return user_permissions
|
||||
|
|
@ -256,8 +266,10 @@ async def update_default_user_permissions(
|
|||
|
||||
|
||||
@router.get("/user/settings", response_model=Optional[UserSettings])
|
||||
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
async def get_user_settings_by_session_user(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
user = Users.get_user_by_id(user.id, db=db)
|
||||
if user:
|
||||
return user.settings
|
||||
else:
|
||||
|
|
@ -274,7 +286,10 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
|||
|
||||
@router.post("/user/settings/update", response_model=UserSettings)
|
||||
async def update_user_settings_by_session_user(
|
||||
request: Request, form_data: UserSettings, user=Depends(get_verified_user)
|
||||
request: Request,
|
||||
form_data: UserSettings,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
updated_user_settings = form_data.model_dump()
|
||||
if (
|
||||
|
|
@ -289,7 +304,7 @@ async def update_user_settings_by_session_user(
|
|||
# If the user is not an admin and does not have permission to use tool servers, remove the key
|
||||
updated_user_settings["ui"].pop("toolServers", None)
|
||||
|
||||
user = Users.update_user_settings_by_id(user.id, updated_user_settings)
|
||||
user = Users.update_user_settings_by_id(user.id, updated_user_settings, db=db)
|
||||
if user:
|
||||
return user.settings
|
||||
else:
|
||||
|
|
@ -305,8 +320,10 @@ async def update_user_settings_by_session_user(
|
|||
|
||||
|
||||
@router.get("/user/status")
|
||||
async def get_user_status_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
async def get_user_status_by_session_user(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
user = Users.get_user_by_id(user.id, db=db)
|
||||
if user:
|
||||
return user
|
||||
else:
|
||||
|
|
@ -323,11 +340,13 @@ async def get_user_status_by_session_user(user=Depends(get_verified_user)):
|
|||
|
||||
@router.post("/user/status/update")
|
||||
async def update_user_status_by_session_user(
|
||||
form_data: UserStatus, user=Depends(get_verified_user)
|
||||
form_data: UserStatus,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
user = Users.get_user_by_id(user.id, db=db)
|
||||
if user:
|
||||
user = Users.update_user_status_by_id(user.id, form_data)
|
||||
user = Users.update_user_status_by_id(user.id, form_data, db=db)
|
||||
return user
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -342,8 +361,10 @@ async def update_user_status_by_session_user(
|
|||
|
||||
|
||||
@router.get("/user/info", response_model=Optional[dict])
|
||||
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
async def get_user_info_by_session_user(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
user = Users.get_user_by_id(user.id, db=db)
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
|
|
@ -360,14 +381,16 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
|
|||
|
||||
@router.post("/user/info/update", response_model=Optional[dict])
|
||||
async def update_user_info_by_session_user(
|
||||
form_data: dict, user=Depends(get_verified_user)
|
||||
form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
user = Users.get_user_by_id(user.id, db=db)
|
||||
if user:
|
||||
if user.info is None:
|
||||
user.info = {}
|
||||
|
||||
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
|
||||
user = Users.update_user_by_id(
|
||||
user.id, {"info": {**user.info, **form_data}}, db=db
|
||||
)
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
|
|
@ -397,7 +420,9 @@ class UserActiveResponse(UserStatus):
|
|||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserActiveResponse)
|
||||
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
async def get_user_by_id(
|
||||
user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
# Check if user_id is a shared chat
|
||||
# If it is, get the user_id from the chat
|
||||
if user_id.startswith("shared-"):
|
||||
|
|
@ -411,14 +436,14 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
|||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
if user:
|
||||
groups = Groups.get_groups_by_member_id(user_id)
|
||||
groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||
return UserActiveResponse(
|
||||
**{
|
||||
**user.model_dump(),
|
||||
"groups": [{"id": group.id, "name": group.name} for group in groups],
|
||||
"is_active": Users.is_user_active(user_id),
|
||||
"is_active": Users.is_user_active(user_id, db=db),
|
||||
}
|
||||
)
|
||||
else:
|
||||
|
|
@ -429,8 +454,10 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.get("/{user_id}/oauth/sessions")
|
||||
async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||
sessions = OAuthSessions.get_sessions_by_user_id(user_id)
|
||||
async def get_user_oauth_sessions_by_id(
|
||||
user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
sessions = OAuthSessions.get_sessions_by_user_id(user_id, db=db)
|
||||
if sessions and len(sessions) > 0:
|
||||
return sessions
|
||||
else:
|
||||
|
|
@ -446,8 +473,10 @@ async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_use
|
|||
|
||||
|
||||
@router.get("/{user_id}/profile/image")
|
||||
async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user_id)
|
||||
async def get_user_profile_image_by_id(
|
||||
user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
if user:
|
||||
if user.profile_image_url:
|
||||
# check if it's url or base64
|
||||
|
|
@ -484,9 +513,11 @@ async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_u
|
|||
|
||||
|
||||
@router.get("/{user_id}/active", response_model=dict)
|
||||
async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
async def get_user_active_status_by_id(
|
||||
user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
return {
|
||||
"active": Users.is_user_active(user_id),
|
||||
"active": Users.is_user_active(user_id, db=db),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -500,10 +531,11 @@ async def update_user_by_id(
|
|||
user_id: str,
|
||||
form_data: UserUpdateForm,
|
||||
session_user=Depends(get_admin_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
# Prevent modification of the primary admin user by other admins
|
||||
try:
|
||||
first_user = Users.get_first_user()
|
||||
first_user = Users.get_first_user(db=db)
|
||||
if first_user:
|
||||
if user_id == first_user.id:
|
||||
if session_user.id != user_id:
|
||||
|
|
@ -527,11 +559,11 @@ async def update_user_by_id(
|
|||
detail="Could not verify primary admin status.",
|
||||
)
|
||||
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
|
||||
if user:
|
||||
if form_data.email.lower() != user.email:
|
||||
email_user = Users.get_user_by_email(form_data.email.lower())
|
||||
email_user = Users.get_user_by_email(form_data.email.lower(), db=db)
|
||||
if email_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -545,9 +577,9 @@ async def update_user_by_id(
|
|||
raise HTTPException(400, detail=str(e))
|
||||
|
||||
hashed = get_password_hash(form_data.password)
|
||||
Auths.update_user_password_by_id(user_id, hashed)
|
||||
Auths.update_user_password_by_id(user_id, hashed, db=db)
|
||||
|
||||
Auths.update_email_by_id(user_id, form_data.email.lower())
|
||||
Auths.update_email_by_id(user_id, form_data.email.lower(), db=db)
|
||||
updated_user = Users.update_user_by_id(
|
||||
user_id,
|
||||
{
|
||||
|
|
@ -556,6 +588,7 @@ async def update_user_by_id(
|
|||
"email": form_data.email.lower(),
|
||||
"profile_image_url": form_data.profile_image_url,
|
||||
},
|
||||
db=db,
|
||||
)
|
||||
|
||||
if updated_user:
|
||||
|
|
@ -578,10 +611,12 @@ async def update_user_by_id(
|
|||
|
||||
|
||||
@router.delete("/{user_id}", response_model=bool)
|
||||
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||
async def delete_user_by_id(
|
||||
user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
# Prevent deletion of the primary admin user
|
||||
try:
|
||||
first_user = Users.get_first_user()
|
||||
first_user = Users.get_first_user(db=db)
|
||||
if first_user and user_id == first_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
@ -595,7 +630,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
|||
)
|
||||
|
||||
if user.id != user_id:
|
||||
result = Auths.delete_auth_by_id(user_id)
|
||||
result = Auths.delete_auth_by_id(user_id, db=db)
|
||||
|
||||
if result:
|
||||
return True
|
||||
|
|
@ -618,5 +653,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
|||
|
||||
|
||||
@router.get("/{user_id}/groups")
|
||||
async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||
return Groups.get_groups_by_member_id(user_id)
|
||||
async def get_user_groups_by_id(
|
||||
user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||
):
|
||||
return Groups.get_groups_by_member_id(user_id, db=db)
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ def fill_missing_permissions(
|
|||
def get_permissions(
|
||||
user_id: str,
|
||||
default_permissions: Dict[str, Any],
|
||||
db: Optional[Any] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all permissions for a user by combining the permissions of all groups the user is a member of.
|
||||
|
|
@ -53,7 +54,7 @@ def get_permissions(
|
|||
) # Use the most permissive value (True > False)
|
||||
return permissions
|
||||
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||
|
||||
# Deep copy default permissions to avoid modifying the original dict
|
||||
permissions = json.loads(json.dumps(default_permissions))
|
||||
|
|
@ -72,6 +73,7 @@ def has_permission(
|
|||
user_id: str,
|
||||
permission_key: str,
|
||||
default_permissions: Dict[str, Any] = {},
|
||||
db: Optional[Any] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has a specific permission by checking the group permissions
|
||||
|
|
@ -92,7 +94,7 @@ def has_permission(
|
|||
permission_hierarchy = permission_key.split(".")
|
||||
|
||||
# Retrieve user group permissions
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||
|
||||
for group in user_groups:
|
||||
if get_permission(group.permissions or {}, permission_hierarchy):
|
||||
|
|
@ -127,6 +129,7 @@ def has_access(
|
|||
access_control: Optional[dict] = None,
|
||||
user_group_ids: Optional[Set[str]] = None,
|
||||
strict: bool = True,
|
||||
db: Optional[Any] = None,
|
||||
) -> bool:
|
||||
if access_control is None:
|
||||
if strict:
|
||||
|
|
@ -135,7 +138,7 @@ def has_access(
|
|||
return True
|
||||
|
||||
if user_group_ids is None:
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||
user_group_ids = {group.id for group in user_groups}
|
||||
|
||||
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
|
||||
|
|
@ -152,10 +155,10 @@ def has_access(
|
|||
|
||||
# Get all users with access to a resource
|
||||
def get_users_with_access(
|
||||
type: str = "write", access_control: Optional[dict] = None
|
||||
type: str = "write", access_control: Optional[dict] = None, db: Optional[Any] = None
|
||||
) -> list[UserModel]:
|
||||
if access_control is None:
|
||||
result = Users.get_users(filter={"roles": ["!pending"]})
|
||||
result = Users.get_users(filter={"roles": ["!pending"]}, db=db)
|
||||
return result.get("users", [])
|
||||
|
||||
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
|
||||
|
|
@ -167,8 +170,8 @@ def get_users_with_access(
|
|||
|
||||
user_ids_with_access = set(permitted_user_ids)
|
||||
|
||||
group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids)
|
||||
group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids, db=db)
|
||||
for user_ids in group_user_ids_map.values():
|
||||
user_ids_with_access.update(user_ids)
|
||||
|
||||
return Users.get_users_by_user_ids(list(user_ids_with_access))
|
||||
return Users.get_users_by_user_ids(list(user_ids_with_access), db=db)
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ from open_webui.env import (
|
|||
|
||||
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import get_session
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -271,6 +273,7 @@ async def get_current_user(
|
|||
response: Response,
|
||||
background_tasks: BackgroundTasks,
|
||||
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
token = None
|
||||
|
||||
|
|
@ -285,7 +288,7 @@ async def get_current_user(
|
|||
|
||||
# auth by api key
|
||||
if token.startswith("sk-"):
|
||||
user = get_current_user_by_api_key(request, token)
|
||||
user = get_current_user_by_api_key(request, token, db=db)
|
||||
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
|
|
@ -314,7 +317,7 @@ async def get_current_user(
|
|||
detail="Invalid token",
|
||||
)
|
||||
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
user = Users.get_user_by_id(data["id"], db=db)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -364,8 +367,8 @@ async def get_current_user(
|
|||
raise e
|
||||
|
||||
|
||||
def get_current_user_by_api_key(request, api_key: str):
|
||||
user = Users.get_user_by_api_key(api_key)
|
||||
def get_current_user_by_api_key(request, api_key: str, db: Session = None):
|
||||
user = Users.get_user_by_api_key(api_key, db=db)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
|
|
@ -393,7 +396,7 @@ def get_current_user_by_api_key(request, api_key: str):
|
|||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "api_key")
|
||||
|
||||
Users.update_last_active_by_id(user.id)
|
||||
Users.update_last_active_by_id(user.id, db=db)
|
||||
return user
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ log = logging.getLogger(__name__)
|
|||
def apply_default_group_assignment(
|
||||
default_group_id: str,
|
||||
user_id: str,
|
||||
db=None,
|
||||
) -> None:
|
||||
"""
|
||||
Apply default group assignment to a user if default_group_id is provided.
|
||||
|
|
@ -17,7 +18,7 @@ def apply_default_group_assignment(
|
|||
"""
|
||||
if default_group_id:
|
||||
try:
|
||||
Groups.add_users_to_group(default_group_id, [user_id])
|
||||
Groups.add_users_to_group(default_group_id, [user_id], db=db)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Failed to add user {user_id} to default group {default_group_id}: {e}"
|
||||
|
|
|
|||
|
|
@ -1336,7 +1336,7 @@ class OAuthManager:
|
|||
|
||||
return await client.authorize_redirect(request, redirect_uri, **kwargs)
|
||||
|
||||
async def handle_callback(self, request, provider, response):
|
||||
async def handle_callback(self, request, provider, response, db=None):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
|
||||
|
|
@ -1461,20 +1461,20 @@ class OAuthManager:
|
|||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
|
||||
# Check if the user exists
|
||||
user = Users.get_user_by_oauth_sub(provider, sub)
|
||||
user = Users.get_user_by_oauth_sub(provider, sub, db=db)
|
||||
if not user:
|
||||
# If the user does not exist, check if merging is enabled
|
||||
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
||||
# Check if the user exists by email
|
||||
user = Users.get_user_by_email(email)
|
||||
user = Users.get_user_by_email(email, db=db)
|
||||
if user:
|
||||
# Update the user with the new oauth sub
|
||||
Users.update_user_oauth_by_id(user.id, provider, sub)
|
||||
Users.update_user_oauth_by_id(user.id, provider, sub, db=db)
|
||||
|
||||
if user:
|
||||
determined_role = self.get_user_role(user, user_data)
|
||||
if user.role != determined_role:
|
||||
Users.update_user_role_by_id(user.id, determined_role)
|
||||
Users.update_user_role_by_id(user.id, determined_role, db=db)
|
||||
# Update the user object in memory as well,
|
||||
# to avoid problems with the ENABLE_OAUTH_GROUP_MANAGEMENT check below
|
||||
user.role = determined_role
|
||||
|
|
@ -1491,14 +1491,14 @@ class OAuthManager:
|
|||
)
|
||||
if processed_picture_url != user.profile_image_url:
|
||||
Users.update_user_profile_image_url_by_id(
|
||||
user.id, processed_picture_url
|
||||
user.id, processed_picture_url, db=db
|
||||
)
|
||||
log.debug(f"Updated profile picture for user {user.email}")
|
||||
else:
|
||||
# If the user does not exist, check if signups are enabled
|
||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
||||
# Check if an existing user with the same email already exists
|
||||
existing_user = Users.get_user_by_email(email)
|
||||
existing_user = Users.get_user_by_email(email, db=db)
|
||||
if existing_user:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
|
|
@ -1529,6 +1529,7 @@ class OAuthManager:
|
|||
profile_image_url=picture_url,
|
||||
role=self.get_user_role(None, user_data),
|
||||
oauth=oauth_data,
|
||||
db=db,
|
||||
)
|
||||
|
||||
if auth_manager_config.WEBHOOK_URL:
|
||||
|
|
@ -1544,8 +1545,7 @@ class OAuthManager:
|
|||
)
|
||||
|
||||
apply_default_group_assignment(
|
||||
request.app.state.config.DEFAULT_GROUP_ID,
|
||||
user.id,
|
||||
request.app.state.config.DEFAULT_GROUP_ID, user.id, db=db
|
||||
)
|
||||
|
||||
else:
|
||||
|
|
@ -1616,15 +1616,16 @@ class OAuthManager:
|
|||
token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
|
||||
|
||||
# Clean up any existing sessions for this user/provider first
|
||||
sessions = OAuthSessions.get_sessions_by_user_id(user.id)
|
||||
sessions = OAuthSessions.get_sessions_by_user_id(user.id, db=db)
|
||||
for session in sessions:
|
||||
if session.provider == provider:
|
||||
OAuthSessions.delete_session_by_id(session.id)
|
||||
OAuthSessions.delete_session_by_id(session.id, db=db)
|
||||
|
||||
session = OAuthSessions.create_session(
|
||||
user_id=user.id,
|
||||
provider=provider,
|
||||
token=token,
|
||||
db=db,
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
|
|
|
|||
Loading…
Reference in a new issue