mirror of
https://github.com/open-webui/open-webui.git
synced 2026-01-02 06:35:20 +00:00
refac/enh: db session sharing
This commit is contained in:
parent
6dd0f99b90
commit
b1d0f00d8c
23 changed files with 1173 additions and 663 deletions
|
|
@ -19,7 +19,7 @@ from open_webui.env import (
|
||||||
from peewee_migrate import Router
|
from peewee_migrate import Router
|
||||||
from sqlalchemy import Dialect, create_engine, MetaData, event, types
|
from sqlalchemy import Dialect, create_engine, MetaData, event, types
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
from sqlalchemy.orm import scoped_session, sessionmaker, Session
|
||||||
from sqlalchemy.pool import QueuePool, NullPool
|
from sqlalchemy.pool import QueuePool, NullPool
|
||||||
from sqlalchemy.sql.type_api import _T
|
from sqlalchemy.sql.type_api import _T
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
@ -148,7 +148,7 @@ SessionLocal = sessionmaker(
|
||||||
)
|
)
|
||||||
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
|
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
|
||||||
Base = declarative_base(metadata=metadata_obj)
|
Base = declarative_base(metadata=metadata_obj)
|
||||||
Session = scoped_session(SessionLocal)
|
ScopedSession = scoped_session(SessionLocal)
|
||||||
|
|
||||||
|
|
||||||
def get_session():
|
def get_session():
|
||||||
|
|
@ -169,4 +169,3 @@ def get_db_context(db: Optional[Session] = None):
|
||||||
else:
|
else:
|
||||||
with get_db() as session:
|
with get_db() as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,9 @@ from open_webui.routers.retrieval import (
|
||||||
get_rf,
|
get_rf,
|
||||||
)
|
)
|
||||||
|
|
||||||
from open_webui.internal.db import Session, engine
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from open_webui.internal.db import ScopedSession, engine, get_session
|
||||||
|
|
||||||
from open_webui.models.functions import Functions
|
from open_webui.models.functions import Functions
|
||||||
from open_webui.models.models import Models
|
from open_webui.models.models import Models
|
||||||
|
|
@ -1324,7 +1326,7 @@ app.add_middleware(APIKeyRestrictionMiddleware)
|
||||||
async def commit_session_after_request(request: Request, call_next):
|
async def commit_session_after_request(request: Request, call_next):
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
# log.debug("Commit session after request")
|
# log.debug("Commit session after request")
|
||||||
Session.commit()
|
ScopedSession.commit()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2280,8 +2282,13 @@ async def oauth_login(provider: str, request: Request):
|
||||||
# - Email addresses are considered unique, so we fail registration if the email address is already taken
|
# - Email addresses are considered unique, so we fail registration if the email address is already taken
|
||||||
@app.get("/oauth/{provider}/login/callback")
|
@app.get("/oauth/{provider}/login/callback")
|
||||||
@app.get("/oauth/{provider}/callback") # Legacy endpoint
|
@app.get("/oauth/{provider}/callback") # Legacy endpoint
|
||||||
async def oauth_login_callback(provider: str, request: Request, response: Response):
|
async def oauth_login_callback(
|
||||||
return await oauth_manager.handle_callback(request, provider, response)
|
provider: str,
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
|
return await oauth_manager.handle_callback(request, provider, response, db=db)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/manifest.json")
|
@app.get("/manifest.json")
|
||||||
|
|
@ -2340,7 +2347,7 @@ async def healthcheck():
|
||||||
|
|
||||||
@app.get("/health/db")
|
@app.get("/health/db")
|
||||||
async def healthcheck_with_db():
|
async def healthcheck_with_db():
|
||||||
Session.execute(text("SELECT 1;")).all()
|
ScopedSession.execute(text("SELECT 1;")).all()
|
||||||
return {"status": True}
|
return {"status": True}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ from sqlalchemy.exc import NoSuchTableError
|
||||||
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
|
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
|
||||||
from sqlalchemy.dialects import registry
|
from sqlalchemy.dialects import registry
|
||||||
|
|
||||||
|
|
||||||
class OpenGaussDialect(PGDialect_psycopg2):
|
class OpenGaussDialect(PGDialect_psycopg2):
|
||||||
name = "opengauss"
|
name = "opengauss"
|
||||||
|
|
||||||
|
|
@ -40,9 +41,7 @@ class OpenGaussDialect(PGDialect_psycopg2):
|
||||||
return (9, 0, 0)
|
return (9, 0, 0)
|
||||||
|
|
||||||
match = re.search(
|
match = re.search(
|
||||||
r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?",
|
r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", version, re.IGNORECASE
|
||||||
version,
|
|
||||||
re.IGNORECASE
|
|
||||||
)
|
)
|
||||||
if match:
|
if match:
|
||||||
return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
|
return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
|
||||||
|
|
@ -51,6 +50,7 @@ class OpenGaussDialect(PGDialect_psycopg2):
|
||||||
except Exception:
|
except Exception:
|
||||||
return (9, 0, 0)
|
return (9, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
# Register dialect
|
# Register dialect
|
||||||
registry.register("opengauss", __name__, "OpenGaussDialect")
|
registry.register("opengauss", __name__, "OpenGaussDialect")
|
||||||
|
|
||||||
|
|
@ -78,6 +78,7 @@ Base = declarative_base()
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
class DocumentChunk(Base):
|
class DocumentChunk(Base):
|
||||||
__tablename__ = "document_chunk"
|
__tablename__ = "document_chunk"
|
||||||
|
|
||||||
|
|
@ -87,25 +88,26 @@ class DocumentChunk(Base):
|
||||||
text = Column(Text, nullable=True)
|
text = Column(Text, nullable=True)
|
||||||
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class OpenGaussClient(VectorDBBase):
|
class OpenGaussClient(VectorDBBase):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
if not OPENGAUSS_DB_URL:
|
if not OPENGAUSS_DB_URL:
|
||||||
from open_webui.internal.db import Session
|
from open_webui.internal.db import ScopedSession
|
||||||
self.session = Session
|
|
||||||
|
self.session = ScopedSession
|
||||||
else:
|
else:
|
||||||
engine_kwargs = {
|
engine_kwargs = {"pool_pre_ping": True, "dialect": OpenGaussDialect()}
|
||||||
"pool_pre_ping": True,
|
|
||||||
"dialect": OpenGaussDialect()
|
|
||||||
}
|
|
||||||
|
|
||||||
if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0:
|
if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0:
|
||||||
engine_kwargs.update({
|
engine_kwargs.update(
|
||||||
|
{
|
||||||
"pool_size": OPENGAUSS_POOL_SIZE,
|
"pool_size": OPENGAUSS_POOL_SIZE,
|
||||||
"max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
|
"max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
|
||||||
"pool_timeout": OPENGAUSS_POOL_TIMEOUT,
|
"pool_timeout": OPENGAUSS_POOL_TIMEOUT,
|
||||||
"pool_recycle": OPENGAUSS_POOL_RECYCLE,
|
"pool_recycle": OPENGAUSS_POOL_RECYCLE,
|
||||||
"poolclass": QueuePool
|
"poolclass": QueuePool,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
engine_kwargs["poolclass"] = NullPool
|
engine_kwargs["poolclass"] = NullPool
|
||||||
|
|
||||||
|
|
@ -160,7 +162,9 @@ class OpenGaussClient(VectorDBBase):
|
||||||
else:
|
else:
|
||||||
raise Exception("The 'vector' column type is not Vector.")
|
raise Exception("The 'vector' column type is not Vector.")
|
||||||
else:
|
else:
|
||||||
raise Exception("The 'vector' column does not exist in the 'document_chunk' table.")
|
raise Exception(
|
||||||
|
"The 'vector' column does not exist in the 'document_chunk' table."
|
||||||
|
)
|
||||||
|
|
||||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||||
current_length = len(vector)
|
current_length = len(vector)
|
||||||
|
|
@ -185,7 +189,9 @@ class OpenGaussClient(VectorDBBase):
|
||||||
new_items.append(new_chunk)
|
new_items.append(new_chunk)
|
||||||
self.session.bulk_save_objects(new_items)
|
self.session.bulk_save_objects(new_items)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
log.info(f"Inserting {len(new_items)} items into collection '{collection_name}'.")
|
log.info(
|
||||||
|
f"Inserting {len(new_items)} items into collection '{collection_name}'."
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.session.rollback()
|
self.session.rollback()
|
||||||
log.exception(f"Failed to insert data: {e}")
|
log.exception(f"Failed to insert data: {e}")
|
||||||
|
|
@ -215,7 +221,9 @@ class OpenGaussClient(VectorDBBase):
|
||||||
)
|
)
|
||||||
self.session.add(new_chunk)
|
self.session.add(new_chunk)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
log.info(f"Inserting/updating {len(items)} items in collection '{collection_name}'.")
|
log.info(
|
||||||
|
f"Inserting/updating {len(items)} items in collection '{collection_name}'."
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.session.rollback()
|
self.session.rollback()
|
||||||
log.exception(f"Failed to insert or update data.: {e}")
|
log.exception(f"Failed to insert or update data.: {e}")
|
||||||
|
|
@ -241,7 +249,9 @@ class OpenGaussClient(VectorDBBase):
|
||||||
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
|
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
|
||||||
query_vectors = (
|
query_vectors = (
|
||||||
values(qid_col, q_vector_col)
|
values(qid_col, q_vector_col)
|
||||||
.data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)])
|
.data(
|
||||||
|
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
|
||||||
|
)
|
||||||
.alias("query_vectors")
|
.alias("query_vectors")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -249,13 +259,17 @@ class OpenGaussClient(VectorDBBase):
|
||||||
DocumentChunk.id,
|
DocumentChunk.id,
|
||||||
DocumentChunk.text,
|
DocumentChunk.text,
|
||||||
DocumentChunk.vmetadata,
|
DocumentChunk.vmetadata,
|
||||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label("distance"),
|
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
|
||||||
|
"distance"
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
subq = (
|
subq = (
|
||||||
select(*result_fields)
|
select(*result_fields)
|
||||||
.where(DocumentChunk.collection_name == collection_name)
|
.where(DocumentChunk.collection_name == collection_name)
|
||||||
.order_by(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
.order_by(
|
||||||
|
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
subq = subq.limit(limit)
|
subq = subq.limit(limit)
|
||||||
|
|
@ -368,7 +382,9 @@ class OpenGaussClient(VectorDBBase):
|
||||||
query = query.filter(DocumentChunk.id.in_(ids))
|
query = query.filter(DocumentChunk.id.in_(ids))
|
||||||
if filter:
|
if filter:
|
||||||
for key, value in filter.items():
|
for key, value in filter.items():
|
||||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
query = query.filter(
|
||||||
|
DocumentChunk.vmetadata[key].astext == str(value)
|
||||||
|
)
|
||||||
deleted = query.delete(synchronize_session=False)
|
deleted = query.delete(synchronize_session=False)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
log.info(f"Deleted {deleted} items from collection '{collection_name}'")
|
log.info(f"Deleted {deleted} items from collection '{collection_name}'")
|
||||||
|
|
@ -395,7 +411,8 @@ class OpenGaussClient(VectorDBBase):
|
||||||
exists = (
|
exists = (
|
||||||
self.session.query(DocumentChunk)
|
self.session.query(DocumentChunk)
|
||||||
.filter(DocumentChunk.collection_name == collection_name)
|
.filter(DocumentChunk.collection_name == collection_name)
|
||||||
.first() is not None
|
.first()
|
||||||
|
is not None
|
||||||
)
|
)
|
||||||
self.session.rollback()
|
self.session.rollback()
|
||||||
return exists
|
return exists
|
||||||
|
|
|
||||||
|
|
@ -90,9 +90,9 @@ class PgvectorClient(VectorDBBase):
|
||||||
|
|
||||||
# if no pgvector uri, use the existing database connection
|
# if no pgvector uri, use the existing database connection
|
||||||
if not PGVECTOR_DB_URL:
|
if not PGVECTOR_DB_URL:
|
||||||
from open_webui.internal.db import Session
|
from open_webui.internal.db import ScopedSession
|
||||||
|
|
||||||
self.session = Session
|
self.session = ScopedSession
|
||||||
else:
|
else:
|
||||||
if isinstance(PGVECTOR_POOL_SIZE, int):
|
if isinstance(PGVECTOR_POOL_SIZE, int):
|
||||||
if PGVECTOR_POOL_SIZE > 0:
|
if PGVECTOR_POOL_SIZE > 0:
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,8 @@ from open_webui.utils.auth import (
|
||||||
get_password_hash,
|
get_password_hash,
|
||||||
get_http_authorization_cred,
|
get_http_authorization_cred,
|
||||||
)
|
)
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from open_webui.utils.webhook import post_webhook
|
from open_webui.utils.webhook import post_webhook
|
||||||
from open_webui.utils.access_control import get_permissions, has_permission
|
from open_webui.utils.access_control import get_permissions, has_permission
|
||||||
from open_webui.utils.groups import apply_default_group_assignment
|
from open_webui.utils.groups import apply_default_group_assignment
|
||||||
|
|
@ -103,7 +105,10 @@ class SessionUserInfoResponse(SessionUserResponse, UserStatus):
|
||||||
|
|
||||||
@router.get("/", response_model=SessionUserInfoResponse)
|
@router.get("/", response_model=SessionUserInfoResponse)
|
||||||
async def get_session_user(
|
async def get_session_user(
|
||||||
request: Request, response: Response, user=Depends(get_current_user)
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
user=Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
|
|
||||||
auth_header = request.headers.get("Authorization")
|
auth_header = request.headers.get("Authorization")
|
||||||
|
|
@ -137,7 +142,7 @@ async def get_session_user(
|
||||||
)
|
)
|
||||||
|
|
||||||
user_permissions = get_permissions(
|
user_permissions = get_permissions(
|
||||||
user.id, request.app.state.config.USER_PERMISSIONS
|
user.id, request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -166,12 +171,15 @@ async def get_session_user(
|
||||||
|
|
||||||
@router.post("/update/profile", response_model=UserProfileImageResponse)
|
@router.post("/update/profile", response_model=UserProfileImageResponse)
|
||||||
async def update_profile(
|
async def update_profile(
|
||||||
form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
|
form_data: UpdateProfileForm,
|
||||||
|
session_user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if session_user:
|
if session_user:
|
||||||
user = Users.update_user_by_id(
|
user = Users.update_user_by_id(
|
||||||
session_user.id,
|
session_user.id,
|
||||||
form_data.model_dump(),
|
form_data.model_dump(),
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
if user:
|
if user:
|
||||||
return user
|
return user
|
||||||
|
|
@ -188,13 +196,17 @@ async def update_profile(
|
||||||
|
|
||||||
@router.post("/update/password", response_model=bool)
|
@router.post("/update/password", response_model=bool)
|
||||||
async def update_password(
|
async def update_password(
|
||||||
form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
|
form_data: UpdatePasswordForm,
|
||||||
|
session_user=Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
|
||||||
if session_user:
|
if session_user:
|
||||||
user = Auths.authenticate_user(
|
user = Auths.authenticate_user(
|
||||||
session_user.email, lambda pw: verify_password(form_data.password, pw)
|
session_user.email,
|
||||||
|
lambda pw: verify_password(form_data.password, pw),
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
|
|
@ -203,7 +215,7 @@ async def update_password(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(400, detail=str(e))
|
raise HTTPException(400, detail=str(e))
|
||||||
hashed = get_password_hash(form_data.new_password)
|
hashed = get_password_hash(form_data.new_password)
|
||||||
return Auths.update_user_password_by_id(user.id, hashed)
|
return Auths.update_user_password_by_id(user.id, hashed, db=db)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD)
|
||||||
else:
|
else:
|
||||||
|
|
@ -214,7 +226,12 @@ async def update_password(
|
||||||
# LDAP Authentication
|
# LDAP Authentication
|
||||||
############################
|
############################
|
||||||
@router.post("/ldap", response_model=SessionUserResponse)
|
@router.post("/ldap", response_model=SessionUserResponse)
|
||||||
async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
async def ldap_auth(
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
form_data: LdapForm,
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
# Security checks FIRST - before loading any config
|
# Security checks FIRST - before loading any config
|
||||||
if not request.app.state.config.ENABLE_LDAP:
|
if not request.app.state.config.ENABLE_LDAP:
|
||||||
raise HTTPException(400, detail="LDAP authentication is not enabled")
|
raise HTTPException(400, detail="LDAP authentication is not enabled")
|
||||||
|
|
@ -400,12 +417,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
if not connection_user.bind():
|
if not connection_user.bind():
|
||||||
raise HTTPException(400, "Authentication failed.")
|
raise HTTPException(400, "Authentication failed.")
|
||||||
|
|
||||||
user = Users.get_user_by_email(email)
|
user = Users.get_user_by_email(email, db=db)
|
||||||
if not user:
|
if not user:
|
||||||
try:
|
try:
|
||||||
role = (
|
role = (
|
||||||
"admin"
|
"admin"
|
||||||
if not Users.has_users()
|
if not Users.has_users(db=db)
|
||||||
else request.app.state.config.DEFAULT_USER_ROLE
|
else request.app.state.config.DEFAULT_USER_ROLE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -414,6 +431,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
password=str(uuid.uuid4()),
|
password=str(uuid.uuid4()),
|
||||||
name=cn,
|
name=cn,
|
||||||
role=role,
|
role=role,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
|
|
@ -424,6 +442,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
apply_default_group_assignment(
|
apply_default_group_assignment(
|
||||||
request.app.state.config.DEFAULT_GROUP_ID,
|
request.app.state.config.DEFAULT_GROUP_ID,
|
||||||
user.id,
|
user.id,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|
@ -434,7 +453,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
500, detail="Internal error occurred during LDAP user creation."
|
500, detail="Internal error occurred during LDAP user creation."
|
||||||
)
|
)
|
||||||
|
|
||||||
user = Auths.authenticate_user_by_email(email)
|
user = Auths.authenticate_user_by_email(email, db=db)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
|
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
|
||||||
|
|
@ -464,7 +483,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
)
|
)
|
||||||
|
|
||||||
user_permissions = get_permissions(
|
user_permissions = get_permissions(
|
||||||
user.id, request.app.state.config.USER_PERMISSIONS
|
user.id, request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|
@ -473,9 +492,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
and user_groups
|
and user_groups
|
||||||
):
|
):
|
||||||
if ENABLE_LDAP_GROUP_CREATION:
|
if ENABLE_LDAP_GROUP_CREATION:
|
||||||
Groups.create_groups_by_group_names(user.id, user_groups)
|
Groups.create_groups_by_group_names(user.id, user_groups, db=db)
|
||||||
try:
|
try:
|
||||||
Groups.sync_groups_by_group_names(user.id, user_groups)
|
Groups.sync_groups_by_group_names(user.id, user_groups, db=db)
|
||||||
log.info(
|
log.info(
|
||||||
f"Successfully synced groups for user {user.id}: {user_groups}"
|
f"Successfully synced groups for user {user.id}: {user_groups}"
|
||||||
)
|
)
|
||||||
|
|
@ -508,7 +527,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/signin", response_model=SessionUserResponse)
|
@router.post("/signin", response_model=SessionUserResponse)
|
||||||
async def signin(request: Request, response: Response, form_data: SigninForm):
|
async def signin(
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
form_data: SigninForm,
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
if not ENABLE_PASSWORD_AUTH:
|
if not ENABLE_PASSWORD_AUTH:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
|
@ -529,14 +553,15 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if not Users.get_user_by_email(email.lower()):
|
if not Users.get_user_by_email(email.lower(), db=db):
|
||||||
await signup(
|
await signup(
|
||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
SignupForm(email=email, password=str(uuid.uuid4()), name=name),
|
SignupForm(email=email, password=str(uuid.uuid4()), name=name),
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
user = Auths.authenticate_user_by_email(email)
|
user = Auths.authenticate_user_by_email(email, db=db)
|
||||||
if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
|
if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
|
||||||
group_names = request.headers.get(
|
group_names = request.headers.get(
|
||||||
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
|
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
|
||||||
|
|
@ -544,28 +569,33 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
group_names = [name.strip() for name in group_names if name.strip()]
|
group_names = [name.strip() for name in group_names if name.strip()]
|
||||||
|
|
||||||
if group_names:
|
if group_names:
|
||||||
Groups.sync_groups_by_group_names(user.id, group_names)
|
Groups.sync_groups_by_group_names(user.id, group_names, db=db)
|
||||||
|
|
||||||
elif WEBUI_AUTH == False:
|
elif WEBUI_AUTH == False:
|
||||||
admin_email = "admin@localhost"
|
admin_email = "admin@localhost"
|
||||||
admin_password = "admin"
|
admin_password = "admin"
|
||||||
|
|
||||||
if Users.get_user_by_email(admin_email.lower()):
|
if Users.get_user_by_email(admin_email.lower(), db=db):
|
||||||
user = Auths.authenticate_user(
|
user = Auths.authenticate_user(
|
||||||
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
|
admin_email.lower(),
|
||||||
|
lambda pw: verify_password(admin_password, pw),
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if Users.has_users():
|
if Users.has_users(db=db):
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
|
||||||
|
|
||||||
await signup(
|
await signup(
|
||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
SignupForm(email=admin_email, password=admin_password, name="User"),
|
SignupForm(email=admin_email, password=admin_password, name="User"),
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
user = Auths.authenticate_user(
|
user = Auths.authenticate_user(
|
||||||
admin_email.lower(), lambda pw: verify_password(admin_password, pw)
|
admin_email.lower(),
|
||||||
|
lambda pw: verify_password(admin_password, pw),
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if signin_rate_limiter.is_limited(form_data.email.lower()):
|
if signin_rate_limiter.is_limited(form_data.email.lower()):
|
||||||
|
|
@ -584,7 +614,9 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
form_data.password = password_bytes.decode("utf-8", errors="ignore")
|
form_data.password = password_bytes.decode("utf-8", errors="ignore")
|
||||||
|
|
||||||
user = Auths.authenticate_user(
|
user = Auths.authenticate_user(
|
||||||
form_data.email.lower(), lambda pw: verify_password(form_data.password, pw)
|
form_data.email.lower(),
|
||||||
|
lambda pw: verify_password(form_data.password, pw),
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
|
|
@ -616,7 +648,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
)
|
)
|
||||||
|
|
||||||
user_permissions = get_permissions(
|
user_permissions = get_permissions(
|
||||||
user.id, request.app.state.config.USER_PERMISSIONS
|
user.id, request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -640,8 +672,13 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/signup", response_model=SessionUserResponse)
|
@router.post("/signup", response_model=SessionUserResponse)
|
||||||
async def signup(request: Request, response: Response, form_data: SignupForm):
|
async def signup(
|
||||||
has_users = Users.has_users()
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
form_data: SignupForm,
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
|
has_users = Users.has_users(db=db)
|
||||||
|
|
||||||
if WEBUI_AUTH:
|
if WEBUI_AUTH:
|
||||||
if (
|
if (
|
||||||
|
|
@ -663,7 +700,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||||
)
|
)
|
||||||
|
|
||||||
if Users.get_user_by_email(form_data.email.lower()):
|
if Users.get_user_by_email(form_data.email.lower(), db=db):
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -681,6 +718,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
form_data.name,
|
form_data.name,
|
||||||
form_data.profile_image_url,
|
form_data.profile_image_url,
|
||||||
role,
|
role,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
|
|
@ -723,7 +761,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
)
|
)
|
||||||
|
|
||||||
user_permissions = get_permissions(
|
user_permissions = get_permissions(
|
||||||
user.id, request.app.state.config.USER_PERMISSIONS
|
user.id, request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
if not has_users:
|
if not has_users:
|
||||||
|
|
@ -733,6 +771,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
apply_default_group_assignment(
|
apply_default_group_assignment(
|
||||||
request.app.state.config.DEFAULT_GROUP_ID,
|
request.app.state.config.DEFAULT_GROUP_ID,
|
||||||
user.id,
|
user.id,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -754,7 +793,9 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/signout")
|
@router.get("/signout")
|
||||||
async def signout(request: Request, response: Response):
|
async def signout(
|
||||||
|
request: Request, response: Response, db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
|
||||||
# get auth token from headers or cookies
|
# get auth token from headers or cookies
|
||||||
token = None
|
token = None
|
||||||
|
|
@ -776,7 +817,7 @@ async def signout(request: Request, response: Response):
|
||||||
if oauth_session_id:
|
if oauth_session_id:
|
||||||
response.delete_cookie("oauth_session_id")
|
response.delete_cookie("oauth_session_id")
|
||||||
|
|
||||||
session = OAuthSessions.get_session_by_id(oauth_session_id)
|
session = OAuthSessions.get_session_by_id(oauth_session_id, db=db)
|
||||||
oauth_server_metadata_url = (
|
oauth_server_metadata_url = (
|
||||||
request.app.state.oauth_manager.get_server_metadata_url(session.provider)
|
request.app.state.oauth_manager.get_server_metadata_url(session.provider)
|
||||||
if session
|
if session
|
||||||
|
|
@ -839,14 +880,17 @@ async def signout(request: Request, response: Response):
|
||||||
|
|
||||||
@router.post("/add", response_model=SigninResponse)
|
@router.post("/add", response_model=SigninResponse)
|
||||||
async def add_user(
|
async def add_user(
|
||||||
request: Request, form_data: AddUserForm, user=Depends(get_admin_user)
|
request: Request,
|
||||||
|
form_data: AddUserForm,
|
||||||
|
user=Depends(get_admin_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if not validate_email_format(form_data.email.lower()):
|
if not validate_email_format(form_data.email.lower()):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||||
)
|
)
|
||||||
|
|
||||||
if Users.get_user_by_email(form_data.email.lower()):
|
if Users.get_user_by_email(form_data.email.lower(), db=db):
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -862,12 +906,14 @@ async def add_user(
|
||||||
form_data.name,
|
form_data.name,
|
||||||
form_data.profile_image_url,
|
form_data.profile_image_url,
|
||||||
form_data.role,
|
form_data.role,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
apply_default_group_assignment(
|
apply_default_group_assignment(
|
||||||
request.app.state.config.DEFAULT_GROUP_ID,
|
request.app.state.config.DEFAULT_GROUP_ID,
|
||||||
user.id,
|
user.id,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
token = create_token(data={"id": user.id})
|
token = create_token(data={"id": user.id})
|
||||||
|
|
@ -895,7 +941,9 @@ async def add_user(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/admin/details")
|
@router.get("/admin/details")
|
||||||
async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
async def get_admin_details(
|
||||||
|
request: Request, user=Depends(get_current_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
if request.app.state.config.SHOW_ADMIN_DETAILS:
|
if request.app.state.config.SHOW_ADMIN_DETAILS:
|
||||||
admin_email = request.app.state.config.ADMIN_EMAIL
|
admin_email = request.app.state.config.ADMIN_EMAIL
|
||||||
admin_name = None
|
admin_name = None
|
||||||
|
|
@ -903,11 +951,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
||||||
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
|
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
|
||||||
|
|
||||||
if admin_email:
|
if admin_email:
|
||||||
admin = Users.get_user_by_email(admin_email)
|
admin = Users.get_user_by_email(admin_email, db=db)
|
||||||
if admin:
|
if admin:
|
||||||
admin_name = admin.name
|
admin_name = admin.name
|
||||||
else:
|
else:
|
||||||
admin = Users.get_first_user()
|
admin = Users.get_first_user(db=db)
|
||||||
if admin:
|
if admin:
|
||||||
admin_email = admin.email
|
admin_email = admin.email
|
||||||
admin_name = admin.name
|
admin_name = admin.name
|
||||||
|
|
@ -1149,7 +1197,9 @@ async def update_ldap_config(
|
||||||
|
|
||||||
# create api key
|
# create api key
|
||||||
@router.post("/api_key", response_model=ApiKey)
|
@router.post("/api_key", response_model=ApiKey)
|
||||||
async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
async def generate_api_key(
|
||||||
|
request: Request, user=Depends(get_current_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
if not request.app.state.config.ENABLE_API_KEYS or not has_permission(
|
if not request.app.state.config.ENABLE_API_KEYS or not has_permission(
|
||||||
user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS
|
user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS
|
||||||
):
|
):
|
||||||
|
|
@ -1159,7 +1209,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
api_key = create_api_key()
|
api_key = create_api_key()
|
||||||
success = Users.update_user_api_key_by_id(user.id, api_key)
|
success = Users.update_user_api_key_by_id(user.id, api_key, db=db)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
return {
|
return {
|
||||||
|
|
@ -1171,14 +1221,18 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
||||||
|
|
||||||
# delete api key
|
# delete api key
|
||||||
@router.delete("/api_key", response_model=bool)
|
@router.delete("/api_key", response_model=bool)
|
||||||
async def delete_api_key(user=Depends(get_current_user)):
|
async def delete_api_key(
|
||||||
return Users.delete_user_api_key_by_id(user.id)
|
user=Depends(get_current_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
return Users.delete_user_api_key_by_id(user.id, db=db)
|
||||||
|
|
||||||
|
|
||||||
# get api key
|
# get api key
|
||||||
@router.get("/api_key", response_model=ApiKey)
|
@router.get("/api_key", response_model=ApiKey)
|
||||||
async def get_api_key(user=Depends(get_current_user)):
|
async def get_api_key(
|
||||||
api_key = Users.get_user_api_key_by_id(user.id)
|
user=Depends(get_current_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
api_key = Users.get_user_api_key_by_id(user.id, db=db)
|
||||||
if api_key:
|
if api_key:
|
||||||
return {
|
return {
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
import asyncio
|
import asyncio
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
|
@ -23,6 +24,7 @@ from open_webui.models.chats import (
|
||||||
)
|
)
|
||||||
from open_webui.models.tags import TagModel, Tags
|
from open_webui.models.tags import TagModel, Tags
|
||||||
from open_webui.models.folders import Folders
|
from open_webui.models.folders import Folders
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
|
||||||
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
|
@ -49,6 +51,7 @@ def get_session_user_chat_list(
|
||||||
page: Optional[int] = None,
|
page: Optional[int] = None,
|
||||||
include_pinned: Optional[bool] = False,
|
include_pinned: Optional[bool] = False,
|
||||||
include_folders: Optional[bool] = False,
|
include_folders: Optional[bool] = False,
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if page is not None:
|
if page is not None:
|
||||||
|
|
@ -61,10 +64,14 @@ def get_session_user_chat_list(
|
||||||
include_pinned=include_pinned,
|
include_pinned=include_pinned,
|
||||||
skip=skip,
|
skip=skip,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return Chats.get_chat_title_id_list_by_user_id(
|
return Chats.get_chat_title_id_list_by_user_id(
|
||||||
user.id, include_folders=include_folders, include_pinned=include_pinned
|
user.id,
|
||||||
|
include_folders=include_folders,
|
||||||
|
include_pinned=include_pinned,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
@ -84,12 +91,13 @@ def get_session_user_chat_usage_stats(
|
||||||
items_per_page: Optional[int] = 50,
|
items_per_page: Optional[int] = 50,
|
||||||
page: Optional[int] = 1,
|
page: Optional[int] = 1,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
limit = items_per_page
|
limit = items_per_page
|
||||||
skip = (page - 1) * limit
|
skip = (page - 1) * limit
|
||||||
|
|
||||||
result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit)
|
result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit, db=db)
|
||||||
|
|
||||||
chats = result.items
|
chats = result.items
|
||||||
total = result.total
|
total = result.total
|
||||||
|
|
@ -216,6 +224,7 @@ class ChatStatsExportList(BaseModel):
|
||||||
|
|
||||||
def _process_chat_for_export(chat) -> Optional[ChatStatsExport]:
|
def _process_chat_for_export(chat) -> Optional[ChatStatsExport]:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
def get_message_content_length(message):
|
def get_message_content_length(message):
|
||||||
content = message.get("content", "")
|
content = message.get("content", "")
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
|
|
@ -348,7 +357,9 @@ def _process_chat_for_export(chat) -> Optional[ChatStatsExport]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
|
def calculate_chat_stats(
|
||||||
|
user_id, skip=0, limit=10, filter=None, db: Optional[Session] = None
|
||||||
|
):
|
||||||
if filter is None:
|
if filter is None:
|
||||||
filter = {}
|
filter = {}
|
||||||
|
|
||||||
|
|
@ -357,6 +368,7 @@ def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
|
||||||
skip=skip,
|
skip=skip,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_stats_export_list = []
|
chat_stats_export_list = []
|
||||||
|
|
@ -368,14 +380,21 @@ def calculate_chat_stats(user_id, skip=0, limit=10, filter=None):
|
||||||
return chat_stats_export_list, result.total
|
return chat_stats_export_list, result.total
|
||||||
|
|
||||||
|
|
||||||
async def generate_chat_stats_jsonl_generator(user_id, filter):
|
async def generate_chat_stats_jsonl_generator(
|
||||||
|
user_id, filter, db: Optional[Session] = None
|
||||||
|
):
|
||||||
skip = 0
|
skip = 0
|
||||||
limit = CHAT_EXPORT_PAGE_ITEM_COUNT
|
limit = CHAT_EXPORT_PAGE_ITEM_COUNT
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# Use asyncio.to_thread to make the blocking DB call non-blocking
|
# Use asyncio.to_thread to make the blocking DB call non-blocking
|
||||||
result = await asyncio.to_thread(
|
result = await asyncio.to_thread(
|
||||||
Chats.get_chats_by_user_id, user_id, filter=filter, skip=skip, limit=limit
|
Chats.get_chats_by_user_id,
|
||||||
|
user_id,
|
||||||
|
filter=filter,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
if not result.items:
|
if not result.items:
|
||||||
break
|
break
|
||||||
|
|
@ -400,6 +419,7 @@ async def export_chat_stats(
|
||||||
page: Optional[int] = 1,
|
page: Optional[int] = 1,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
# Check if the user has permission to share/export chats
|
# Check if the user has permission to share/export chats
|
||||||
if (user.role != "admin") and (
|
if (user.role != "admin") and (
|
||||||
|
|
@ -415,7 +435,7 @@ async def export_chat_stats(
|
||||||
filter = {"order_by": "created_at", "direction": "asc"}
|
filter = {"order_by": "created_at", "direction": "asc"}
|
||||||
|
|
||||||
if chat_id:
|
if chat_id:
|
||||||
chat = Chats.get_chat_by_id(chat_id)
|
chat = Chats.get_chat_by_id(chat_id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
filter["start_time"] = chat.created_at
|
filter["start_time"] = chat.created_at
|
||||||
|
|
||||||
|
|
@ -426,7 +446,7 @@ async def export_chat_stats(
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
generate_chat_stats_jsonl_generator(user.id, filter),
|
generate_chat_stats_jsonl_generator(user.id, filter, db=db),
|
||||||
media_type="application/x-ndjson",
|
media_type="application/x-ndjson",
|
||||||
headers={
|
headers={
|
||||||
"Content-Disposition": f"attachment; filename=chat-stats-export-{user.id}.jsonl"
|
"Content-Disposition": f"attachment; filename=chat-stats-export-{user.id}.jsonl"
|
||||||
|
|
@ -437,7 +457,7 @@ async def export_chat_stats(
|
||||||
skip = (page - 1) * limit
|
skip = (page - 1) * limit
|
||||||
|
|
||||||
chat_stats_export_list, total = await asyncio.to_thread(
|
chat_stats_export_list, total = await asyncio.to_thread(
|
||||||
calculate_chat_stats, user.id, skip, limit, filter
|
calculate_chat_stats, user.id, skip, limit, filter, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatStatsExportList(
|
return ChatStatsExportList(
|
||||||
|
|
@ -452,7 +472,11 @@ async def export_chat_stats(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/", response_model=bool)
|
@router.delete("/", response_model=bool)
|
||||||
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
|
async def delete_all_user_chats(
|
||||||
|
request: Request,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
|
|
||||||
if user.role == "user" and not has_permission(
|
if user.role == "user" and not has_permission(
|
||||||
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
|
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
|
||||||
|
|
@ -462,7 +486,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = Chats.delete_chats_by_user_id(user.id)
|
result = Chats.delete_chats_by_user_id(user.id, db=db)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -479,6 +503,7 @@ async def get_user_chat_list_by_user_id(
|
||||||
order_by: Optional[str] = None,
|
order_by: Optional[str] = None,
|
||||||
direction: Optional[str] = None,
|
direction: Optional[str] = None,
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if not ENABLE_ADMIN_CHAT_ACCESS:
|
if not ENABLE_ADMIN_CHAT_ACCESS:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -501,7 +526,7 @@ async def get_user_chat_list_by_user_id(
|
||||||
filter["direction"] = direction
|
filter["direction"] = direction
|
||||||
|
|
||||||
return Chats.get_chat_list_by_user_id(
|
return Chats.get_chat_list_by_user_id(
|
||||||
user_id, include_archived=True, filter=filter, skip=skip, limit=limit
|
user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -511,9 +536,13 @@ async def get_user_chat_list_by_user_id(
|
||||||
|
|
||||||
|
|
||||||
@router.post("/new", response_model=Optional[ChatResponse])
|
@router.post("/new", response_model=Optional[ChatResponse])
|
||||||
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
async def create_new_chat(
|
||||||
|
form_data: ChatForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
chat = Chats.insert_new_chat(user.id, form_data)
|
chat = Chats.insert_new_chat(user.id, form_data, db=db)
|
||||||
return ChatResponse(**chat.model_dump())
|
return ChatResponse(**chat.model_dump())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
@ -528,9 +557,13 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/import", response_model=list[ChatResponse])
|
@router.post("/import", response_model=list[ChatResponse])
|
||||||
async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_user)):
|
async def import_chats(
|
||||||
|
form_data: ChatsImportForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
chats = Chats.import_chats(user.id, form_data.chats)
|
chats = Chats.import_chats(user.id, form_data.chats, db=db)
|
||||||
return chats
|
return chats
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
@ -546,7 +579,10 @@ async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_use
|
||||||
|
|
||||||
@router.get("/search", response_model=list[ChatTitleIdResponse])
|
@router.get("/search", response_model=list[ChatTitleIdResponse])
|
||||||
def search_user_chats(
|
def search_user_chats(
|
||||||
text: str, page: Optional[int] = None, user=Depends(get_verified_user)
|
text: str,
|
||||||
|
page: Optional[int] = None,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if page is None:
|
if page is None:
|
||||||
page = 1
|
page = 1
|
||||||
|
|
@ -557,7 +593,7 @@ def search_user_chats(
|
||||||
chat_list = [
|
chat_list = [
|
||||||
ChatTitleIdResponse(**chat.model_dump())
|
ChatTitleIdResponse(**chat.model_dump())
|
||||||
for chat in Chats.get_chats_by_user_id_and_search_text(
|
for chat in Chats.get_chats_by_user_id_and_search_text(
|
||||||
user.id, text, skip=skip, limit=limit
|
user.id, text, skip=skip, limit=limit, db=db
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -566,9 +602,9 @@ def search_user_chats(
|
||||||
if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
|
if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
|
||||||
tag_id = words[0].replace("tag:", "")
|
tag_id = words[0].replace("tag:", "")
|
||||||
if len(chat_list) == 0:
|
if len(chat_list) == 0:
|
||||||
if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
|
if Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db):
|
||||||
log.debug(f"deleting tag: {tag_id}")
|
log.debug(f"deleting tag: {tag_id}")
|
||||||
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
|
Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db)
|
||||||
|
|
||||||
return chat_list
|
return chat_list
|
||||||
|
|
||||||
|
|
@ -579,23 +615,30 @@ def search_user_chats(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/folder/{folder_id}", response_model=list[ChatResponse])
|
@router.get("/folder/{folder_id}", response_model=list[ChatResponse])
|
||||||
async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)):
|
async def get_chats_by_folder_id(
|
||||||
|
folder_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
folder_ids = [folder_id]
|
folder_ids = [folder_id]
|
||||||
children_folders = Folders.get_children_folders_by_id_and_user_id(
|
children_folders = Folders.get_children_folders_by_id_and_user_id(
|
||||||
folder_id, user.id
|
folder_id, user.id, db=db
|
||||||
)
|
)
|
||||||
if children_folders:
|
if children_folders:
|
||||||
folder_ids.extend([folder.id for folder in children_folders])
|
folder_ids.extend([folder.id for folder in children_folders])
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ChatResponse(**chat.model_dump())
|
ChatResponse(**chat.model_dump())
|
||||||
for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id)
|
for chat in Chats.get_chats_by_folder_ids_and_user_id(
|
||||||
|
folder_ids, user.id, db=db
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@router.get("/folder/{folder_id}/list")
|
@router.get("/folder/{folder_id}/list")
|
||||||
async def get_chat_list_by_folder_id(
|
async def get_chat_list_by_folder_id(
|
||||||
folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user)
|
folder_id: str,
|
||||||
|
page: Optional[int] = 1,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
limit = 10
|
limit = 10
|
||||||
|
|
@ -604,7 +647,7 @@ async def get_chat_list_by_folder_id(
|
||||||
return [
|
return [
|
||||||
{"title": chat.title, "id": chat.id, "updated_at": chat.updated_at}
|
{"title": chat.title, "id": chat.id, "updated_at": chat.updated_at}
|
||||||
for chat in Chats.get_chats_by_folder_id_and_user_id(
|
for chat in Chats.get_chats_by_folder_id_and_user_id(
|
||||||
folder_id, user.id, skip=skip, limit=limit
|
folder_id, user.id, skip=skip, limit=limit, db=db
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -621,10 +664,12 @@ async def get_chat_list_by_folder_id(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/pinned", response_model=list[ChatTitleIdResponse])
|
@router.get("/pinned", response_model=list[ChatTitleIdResponse])
|
||||||
async def get_user_pinned_chats(user=Depends(get_verified_user)):
|
async def get_user_pinned_chats(
|
||||||
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
return [
|
return [
|
||||||
ChatTitleIdResponse(**chat.model_dump())
|
ChatTitleIdResponse(**chat.model_dump())
|
||||||
for chat in Chats.get_pinned_chats_by_user_id(user.id)
|
for chat in Chats.get_pinned_chats_by_user_id(user.id, db=db)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -634,10 +679,12 @@ async def get_user_pinned_chats(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/all", response_model=list[ChatResponse])
|
@router.get("/all", response_model=list[ChatResponse])
|
||||||
async def get_user_chats(user=Depends(get_verified_user)):
|
async def get_user_chats(
|
||||||
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
return [
|
return [
|
||||||
ChatResponse(**chat.model_dump())
|
ChatResponse(**chat.model_dump())
|
||||||
for chat in Chats.get_chats_by_user_id(user.id)
|
for chat in Chats.get_chats_by_user_id(user.id, db=db)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -647,10 +694,12 @@ async def get_user_chats(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/all/archived", response_model=list[ChatResponse])
|
@router.get("/all/archived", response_model=list[ChatResponse])
|
||||||
async def get_user_archived_chats(user=Depends(get_verified_user)):
|
async def get_user_archived_chats(
|
||||||
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
return [
|
return [
|
||||||
ChatResponse(**chat.model_dump())
|
ChatResponse(**chat.model_dump())
|
||||||
for chat in Chats.get_archived_chats_by_user_id(user.id)
|
for chat in Chats.get_archived_chats_by_user_id(user.id, db=db)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -660,9 +709,11 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/all/tags", response_model=list[TagModel])
|
@router.get("/all/tags", response_model=list[TagModel])
|
||||||
async def get_all_user_tags(user=Depends(get_verified_user)):
|
async def get_all_user_tags(
|
||||||
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
tags = Tags.get_tags_by_user_id(user.id)
|
tags = Tags.get_tags_by_user_id(user.id, db=db)
|
||||||
return tags
|
return tags
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
@ -677,13 +728,15 @@ async def get_all_user_tags(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/all/db", response_model=list[ChatResponse])
|
@router.get("/all/db", response_model=list[ChatResponse])
|
||||||
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
|
async def get_all_user_chats_in_db(
|
||||||
|
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
if not ENABLE_ADMIN_EXPORT:
|
if not ENABLE_ADMIN_EXPORT:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
|
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats(db=db)]
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -698,6 +751,7 @@ async def get_archived_session_user_chat_list(
|
||||||
order_by: Optional[str] = None,
|
order_by: Optional[str] = None,
|
||||||
direction: Optional[str] = None,
|
direction: Optional[str] = None,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if page is None:
|
if page is None:
|
||||||
page = 1
|
page = 1
|
||||||
|
|
@ -720,6 +774,7 @@ async def get_archived_session_user_chat_list(
|
||||||
filter=filter,
|
filter=filter,
|
||||||
skip=skip,
|
skip=skip,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -732,8 +787,10 @@ async def get_archived_session_user_chat_list(
|
||||||
|
|
||||||
|
|
||||||
@router.post("/archive/all", response_model=bool)
|
@router.post("/archive/all", response_model=bool)
|
||||||
async def archive_all_chats(user=Depends(get_verified_user)):
|
async def archive_all_chats(
|
||||||
return Chats.archive_all_chats_by_user_id(user.id)
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
return Chats.archive_all_chats_by_user_id(user.id, db=db)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -742,8 +799,10 @@ async def archive_all_chats(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/unarchive/all", response_model=bool)
|
@router.post("/unarchive/all", response_model=bool)
|
||||||
async def unarchive_all_chats(user=Depends(get_verified_user)):
|
async def unarchive_all_chats(
|
||||||
return Chats.unarchive_all_chats_by_user_id(user.id)
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
return Chats.unarchive_all_chats_by_user_id(user.id, db=db)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -752,16 +811,18 @@ async def unarchive_all_chats(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
|
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
|
||||||
async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
|
async def get_shared_chat_by_id(
|
||||||
|
share_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
if user.role == "pending":
|
if user.role == "pending":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS):
|
if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS):
|
||||||
chat = Chats.get_chat_by_share_id(share_id)
|
chat = Chats.get_chat_by_share_id(share_id, db=db)
|
||||||
elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
|
elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
|
||||||
chat = Chats.get_chat_by_id(share_id)
|
chat = Chats.get_chat_by_id(share_id, db=db)
|
||||||
|
|
||||||
if chat:
|
if chat:
|
||||||
return ChatResponse(**chat.model_dump())
|
return ChatResponse(**chat.model_dump())
|
||||||
|
|
@ -788,13 +849,15 @@ class TagFilterForm(TagForm):
|
||||||
|
|
||||||
@router.post("/tags", response_model=list[ChatTitleIdResponse])
|
@router.post("/tags", response_model=list[ChatTitleIdResponse])
|
||||||
async def get_user_chat_list_by_tag_name(
|
async def get_user_chat_list_by_tag_name(
|
||||||
form_data: TagFilterForm, user=Depends(get_verified_user)
|
form_data: TagFilterForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
chats = Chats.get_chat_list_by_user_id_and_tag_name(
|
chats = Chats.get_chat_list_by_user_id_and_tag_name(
|
||||||
user.id, form_data.name, form_data.skip, form_data.limit
|
user.id, form_data.name, form_data.skip, form_data.limit, db=db
|
||||||
)
|
)
|
||||||
if len(chats) == 0:
|
if len(chats) == 0:
|
||||||
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
|
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db)
|
||||||
|
|
||||||
return chats
|
return chats
|
||||||
|
|
||||||
|
|
@ -805,8 +868,10 @@ async def get_user_chat_list_by_tag_name(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{id}", response_model=Optional[ChatResponse])
|
@router.get("/{id}", response_model=Optional[ChatResponse])
|
||||||
async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_chat_by_id(
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
|
|
||||||
if chat:
|
if chat:
|
||||||
return ChatResponse(**chat.model_dump())
|
return ChatResponse(**chat.model_dump())
|
||||||
|
|
@ -824,12 +889,15 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.post("/{id}", response_model=Optional[ChatResponse])
|
@router.post("/{id}", response_model=Optional[ChatResponse])
|
||||||
async def update_chat_by_id(
|
async def update_chat_by_id(
|
||||||
id: str, form_data: ChatForm, user=Depends(get_verified_user)
|
id: str,
|
||||||
|
form_data: ChatForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
updated_chat = {**chat.chat, **form_data.chat}
|
updated_chat = {**chat.chat, **form_data.chat}
|
||||||
chat = Chats.update_chat_by_id(id, updated_chat)
|
chat = Chats.update_chat_by_id(id, updated_chat, db=db)
|
||||||
return ChatResponse(**chat.model_dump())
|
return ChatResponse(**chat.model_dump())
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -847,9 +915,13 @@ class MessageForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
|
@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
|
||||||
async def update_chat_message_by_id(
|
async def update_chat_message_by_id(
|
||||||
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
|
id: str,
|
||||||
|
message_id: str,
|
||||||
|
form_data: MessageForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
chat = Chats.get_chat_by_id(id)
|
chat = Chats.get_chat_by_id(id, db=db)
|
||||||
|
|
||||||
if not chat:
|
if not chat:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -869,6 +941,7 @@ async def update_chat_message_by_id(
|
||||||
{
|
{
|
||||||
"content": form_data.content,
|
"content": form_data.content,
|
||||||
},
|
},
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
event_emitter = get_event_emitter(
|
event_emitter = get_event_emitter(
|
||||||
|
|
@ -905,9 +978,13 @@ class EventForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
|
@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
|
||||||
async def send_chat_message_event_by_id(
|
async def send_chat_message_event_by_id(
|
||||||
id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
|
id: str,
|
||||||
|
message_id: str,
|
||||||
|
form_data: EventForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
chat = Chats.get_chat_by_id(id)
|
chat = Chats.get_chat_by_id(id, db=db)
|
||||||
|
|
||||||
if not chat:
|
if not chat:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -945,14 +1022,19 @@ async def send_chat_message_event_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{id}", response_model=bool)
|
@router.delete("/{id}", response_model=bool)
|
||||||
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
async def delete_chat_by_id(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
chat = Chats.get_chat_by_id(id)
|
chat = Chats.get_chat_by_id(id, db=db)
|
||||||
for tag in chat.meta.get("tags", []):
|
for tag in chat.meta.get("tags", []):
|
||||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
|
||||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
|
||||||
|
|
||||||
result = Chats.delete_chat_by_id(id)
|
result = Chats.delete_chat_by_id(id, db=db)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
|
|
@ -964,12 +1046,12 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat = Chats.get_chat_by_id(id)
|
chat = Chats.get_chat_by_id(id, db=db)
|
||||||
for tag in chat.meta.get("tags", []):
|
for tag in chat.meta.get("tags", []):
|
||||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1:
|
||||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
|
||||||
|
|
||||||
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
|
result = Chats.delete_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -979,8 +1061,10 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{id}/pinned", response_model=Optional[bool])
|
@router.get("/{id}/pinned", response_model=Optional[bool])
|
||||||
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_pinned_status_by_id(
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
return chat.pinned
|
return chat.pinned
|
||||||
else:
|
else:
|
||||||
|
|
@ -995,10 +1079,12 @@ async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{id}/pin", response_model=Optional[ChatResponse])
|
@router.post("/{id}/pin", response_model=Optional[ChatResponse])
|
||||||
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
|
async def pin_chat_by_id(
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
chat = Chats.toggle_chat_pinned_by_id(id)
|
chat = Chats.toggle_chat_pinned_by_id(id, db=db)
|
||||||
return chat
|
return chat
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -1017,9 +1103,12 @@ class CloneForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
|
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
|
||||||
async def clone_chat_by_id(
|
async def clone_chat_by_id(
|
||||||
form_data: CloneForm, id: str, user=Depends(get_verified_user)
|
form_data: CloneForm,
|
||||||
|
id: str,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
updated_chat = {
|
updated_chat = {
|
||||||
**chat.chat,
|
**chat.chat,
|
||||||
|
|
@ -1040,6 +1129,7 @@ async def clone_chat_by_id(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
if chats:
|
if chats:
|
||||||
|
|
@ -1062,12 +1152,14 @@ async def clone_chat_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
|
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
|
||||||
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
async def clone_shared_chat_by_id(
|
||||||
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
|
||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
chat = Chats.get_chat_by_id(id)
|
chat = Chats.get_chat_by_id(id, db=db)
|
||||||
else:
|
else:
|
||||||
chat = Chats.get_chat_by_share_id(id)
|
chat = Chats.get_chat_by_share_id(id, db=db)
|
||||||
|
|
||||||
if chat:
|
if chat:
|
||||||
updated_chat = {
|
updated_chat = {
|
||||||
|
|
@ -1089,6 +1181,7 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
if chats:
|
if chats:
|
||||||
|
|
@ -1111,23 +1204,28 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{id}/archive", response_model=Optional[ChatResponse])
|
@router.post("/{id}/archive", response_model=Optional[ChatResponse])
|
||||||
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
async def archive_chat_by_id(
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
chat = Chats.toggle_chat_archive_by_id(id)
|
chat = Chats.toggle_chat_archive_by_id(id, db=db)
|
||||||
|
|
||||||
# Delete tags if chat is archived
|
# Delete tags if chat is archived
|
||||||
if chat.archived:
|
if chat.archived:
|
||||||
for tag_id in chat.meta.get("tags", []):
|
for tag_id in chat.meta.get("tags", []):
|
||||||
if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
|
if (
|
||||||
|
Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id, db=db)
|
||||||
|
== 0
|
||||||
|
):
|
||||||
log.debug(f"deleting tag: {tag_id}")
|
log.debug(f"deleting tag: {tag_id}")
|
||||||
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
|
Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db)
|
||||||
else:
|
else:
|
||||||
for tag_id in chat.meta.get("tags", []):
|
for tag_id in chat.meta.get("tags", []):
|
||||||
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
|
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db)
|
||||||
if tag is None:
|
if tag is None:
|
||||||
log.debug(f"inserting tag: {tag_id}")
|
log.debug(f"inserting tag: {tag_id}")
|
||||||
tag = Tags.insert_new_tag(tag_id, user.id)
|
tag = Tags.insert_new_tag(tag_id, user.id, db=db)
|
||||||
|
|
||||||
return ChatResponse(**chat.model_dump())
|
return ChatResponse(**chat.model_dump())
|
||||||
else:
|
else:
|
||||||
|
|
@ -1142,7 +1240,12 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{id}/share", response_model=Optional[ChatResponse])
|
@router.post("/{id}/share", response_model=Optional[ChatResponse])
|
||||||
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
async def share_chat_by_id(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
if (user.role != "admin") and (
|
if (user.role != "admin") and (
|
||||||
not has_permission(
|
not has_permission(
|
||||||
user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
|
user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
|
||||||
|
|
@ -1153,14 +1256,14 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
|
|
||||||
if chat:
|
if chat:
|
||||||
if chat.share_id:
|
if chat.share_id:
|
||||||
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
|
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id, db=db)
|
||||||
return ChatResponse(**shared_chat.model_dump())
|
return ChatResponse(**shared_chat.model_dump())
|
||||||
|
|
||||||
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
|
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id, db=db)
|
||||||
if not shared_chat:
|
if not shared_chat:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
|
@ -1181,14 +1284,16 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{id}/share", response_model=Optional[bool])
|
@router.delete("/{id}/share", response_model=Optional[bool])
|
||||||
async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
async def delete_shared_chat_by_id(
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
if not chat.share_id:
|
if not chat.share_id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
result = Chats.delete_shared_chat_by_chat_id(id)
|
result = Chats.delete_shared_chat_by_chat_id(id, db=db)
|
||||||
update_result = Chats.update_chat_share_id_by_id(id, None)
|
update_result = Chats.update_chat_share_id_by_id(id, None, db=db)
|
||||||
|
|
||||||
return result and update_result != None
|
return result and update_result != None
|
||||||
else:
|
else:
|
||||||
|
|
@ -1209,12 +1314,15 @@ class ChatFolderIdForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/{id}/folder", response_model=Optional[ChatResponse])
|
@router.post("/{id}/folder", response_model=Optional[ChatResponse])
|
||||||
async def update_chat_folder_id_by_id(
|
async def update_chat_folder_id_by_id(
|
||||||
id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
|
id: str,
|
||||||
|
form_data: ChatFolderIdForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
chat = Chats.update_chat_folder_id_by_id_and_user_id(
|
chat = Chats.update_chat_folder_id_by_id_and_user_id(
|
||||||
id, user.id, form_data.folder_id
|
id, user.id, form_data.folder_id, db=db
|
||||||
)
|
)
|
||||||
return ChatResponse(**chat.model_dump())
|
return ChatResponse(**chat.model_dump())
|
||||||
else:
|
else:
|
||||||
|
|
@ -1229,11 +1337,13 @@ async def update_chat_folder_id_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{id}/tags", response_model=list[TagModel])
|
@router.get("/{id}/tags", response_model=list[TagModel])
|
||||||
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_chat_tags_by_id(
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
tags = chat.meta.get("tags", [])
|
tags = chat.meta.get("tags", [])
|
||||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
|
|
@ -1247,9 +1357,12 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.post("/{id}/tags", response_model=list[TagModel])
|
@router.post("/{id}/tags", response_model=list[TagModel])
|
||||||
async def add_tag_by_id_and_tag_name(
|
async def add_tag_by_id_and_tag_name(
|
||||||
id: str, form_data: TagForm, user=Depends(get_verified_user)
|
id: str,
|
||||||
|
form_data: TagForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
tags = chat.meta.get("tags", [])
|
tags = chat.meta.get("tags", [])
|
||||||
tag_id = form_data.name.replace(" ", "_").lower()
|
tag_id = form_data.name.replace(" ", "_").lower()
|
||||||
|
|
@ -1262,12 +1375,12 @@ async def add_tag_by_id_and_tag_name(
|
||||||
|
|
||||||
if tag_id not in tags:
|
if tag_id not in tags:
|
||||||
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
|
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||||
id, user.id, form_data.name
|
id, user.id, form_data.name, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
tags = chat.meta.get("tags", [])
|
tags = chat.meta.get("tags", [])
|
||||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -1281,18 +1394,26 @@ async def add_tag_by_id_and_tag_name(
|
||||||
|
|
||||||
@router.delete("/{id}/tags", response_model=list[TagModel])
|
@router.delete("/{id}/tags", response_model=list[TagModel])
|
||||||
async def delete_tag_by_id_and_tag_name(
|
async def delete_tag_by_id_and_tag_name(
|
||||||
id: str, form_data: TagForm, user=Depends(get_verified_user)
|
id: str,
|
||||||
|
form_data: TagForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
|
Chats.delete_tag_by_id_and_user_id_and_tag_name(
|
||||||
|
id, user.id, form_data.name, db=db
|
||||||
|
)
|
||||||
|
|
||||||
if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
|
if (
|
||||||
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
|
Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id, db=db)
|
||||||
|
== 0
|
||||||
|
):
|
||||||
|
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db)
|
||||||
|
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
tags = chat.meta.get("tags", [])
|
tags = chat.meta.get("tags", [])
|
||||||
return Tags.get_tags_by_ids_and_user_id(tags, user.id)
|
return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
|
|
@ -1305,14 +1426,16 @@ async def delete_tag_by_id_and_tag_name(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{id}/tags/all", response_model=Optional[bool])
|
@router.delete("/{id}/tags/all", response_model=Optional[bool])
|
||||||
async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)):
|
async def delete_all_tags_by_id(
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||||
if chat:
|
if chat:
|
||||||
Chats.delete_all_tags_by_id_and_user_id(id, user.id)
|
Chats.delete_all_tags_by_id_and_user_id(id, user.id, db=db)
|
||||||
|
|
||||||
for tag in chat.meta.get("tags", []):
|
for tag in chat.meta.get("tags", []):
|
||||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
|
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 0:
|
||||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,8 @@ from open_webui.models.feedbacks import (
|
||||||
|
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -60,38 +62,50 @@ async def update_config(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/feedbacks/all", response_model=list[FeedbackResponse])
|
@router.get("/feedbacks/all", response_model=list[FeedbackResponse])
|
||||||
async def get_all_feedbacks(user=Depends(get_admin_user)):
|
async def get_all_feedbacks(
|
||||||
feedbacks = Feedbacks.get_all_feedbacks()
|
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
feedbacks = Feedbacks.get_all_feedbacks(db=db)
|
||||||
return feedbacks
|
return feedbacks
|
||||||
|
|
||||||
|
|
||||||
@router.get("/feedbacks/all/ids", response_model=list[FeedbackIdResponse])
|
@router.get("/feedbacks/all/ids", response_model=list[FeedbackIdResponse])
|
||||||
async def get_all_feedback_ids(user=Depends(get_admin_user)):
|
async def get_all_feedback_ids(
|
||||||
feedbacks = Feedbacks.get_all_feedbacks()
|
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
feedbacks = Feedbacks.get_all_feedbacks(db=db)
|
||||||
return feedbacks
|
return feedbacks
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/feedbacks/all")
|
@router.delete("/feedbacks/all")
|
||||||
async def delete_all_feedbacks(user=Depends(get_admin_user)):
|
async def delete_all_feedbacks(
|
||||||
success = Feedbacks.delete_all_feedbacks()
|
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
success = Feedbacks.delete_all_feedbacks(db=db)
|
||||||
return success
|
return success
|
||||||
|
|
||||||
|
|
||||||
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
|
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
|
||||||
async def export_all_feedbacks(user=Depends(get_admin_user)):
|
async def export_all_feedbacks(
|
||||||
feedbacks = Feedbacks.get_all_feedbacks()
|
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
feedbacks = Feedbacks.get_all_feedbacks(db=db)
|
||||||
return feedbacks
|
return feedbacks
|
||||||
|
|
||||||
|
|
||||||
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
|
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
|
||||||
async def get_feedbacks(user=Depends(get_verified_user)):
|
async def get_feedbacks(
|
||||||
feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id)
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id, db=db)
|
||||||
return feedbacks
|
return feedbacks
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/feedbacks", response_model=bool)
|
@router.delete("/feedbacks", response_model=bool)
|
||||||
async def delete_feedbacks(user=Depends(get_verified_user)):
|
async def delete_feedbacks(
|
||||||
success = Feedbacks.delete_feedbacks_by_user_id(user.id)
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
success = Feedbacks.delete_feedbacks_by_user_id(user.id, db=db)
|
||||||
return success
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -104,6 +118,7 @@ async def get_feedbacks(
|
||||||
direction: Optional[str] = None,
|
direction: Optional[str] = None,
|
||||||
page: Optional[int] = 1,
|
page: Optional[int] = 1,
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
limit = PAGE_ITEM_COUNT
|
limit = PAGE_ITEM_COUNT
|
||||||
|
|
||||||
|
|
@ -116,7 +131,7 @@ async def get_feedbacks(
|
||||||
if direction:
|
if direction:
|
||||||
filter["direction"] = direction
|
filter["direction"] = direction
|
||||||
|
|
||||||
result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit)
|
result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit, db=db)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -125,8 +140,11 @@ async def create_feedback(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: FeedbackForm,
|
form_data: FeedbackForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data)
|
feedback = Feedbacks.insert_new_feedback(
|
||||||
|
user_id=user.id, form_data=form_data, db=db
|
||||||
|
)
|
||||||
if not feedback:
|
if not feedback:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -137,11 +155,15 @@ async def create_feedback(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/feedback/{id}", response_model=FeedbackModel)
|
@router.get("/feedback/{id}", response_model=FeedbackModel)
|
||||||
async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_feedback_by_id(
|
||||||
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
feedback = Feedbacks.get_feedback_by_id(id=id)
|
feedback = Feedbacks.get_feedback_by_id(id=id, db=db)
|
||||||
else:
|
else:
|
||||||
feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id)
|
feedback = Feedbacks.get_feedback_by_id_and_user_id(
|
||||||
|
id=id, user_id=user.id, db=db
|
||||||
|
)
|
||||||
|
|
||||||
if not feedback:
|
if not feedback:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -153,13 +175,16 @@ async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.post("/feedback/{id}", response_model=FeedbackModel)
|
@router.post("/feedback/{id}", response_model=FeedbackModel)
|
||||||
async def update_feedback_by_id(
|
async def update_feedback_by_id(
|
||||||
id: str, form_data: FeedbackForm, user=Depends(get_verified_user)
|
id: str,
|
||||||
|
form_data: FeedbackForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data)
|
feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data, db=db)
|
||||||
else:
|
else:
|
||||||
feedback = Feedbacks.update_feedback_by_id_and_user_id(
|
feedback = Feedbacks.update_feedback_by_id_and_user_id(
|
||||||
id=id, user_id=user.id, form_data=form_data
|
id=id, user_id=user.id, form_data=form_data, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
if not feedback:
|
if not feedback:
|
||||||
|
|
@ -171,11 +196,15 @@ async def update_feedback_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/feedback/{id}")
|
@router.delete("/feedback/{id}")
|
||||||
async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)):
|
async def delete_feedback_by_id(
|
||||||
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
success = Feedbacks.delete_feedback_by_id(id=id)
|
success = Feedbacks.delete_feedback_by_id(id=id, db=db)
|
||||||
else:
|
else:
|
||||||
success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id)
|
success = Feedbacks.delete_feedback_by_id_and_user_id(
|
||||||
|
id=id, user_id=user.id, db=db
|
||||||
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,8 @@ from fastapi import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from fastapi.responses import FileResponse, StreamingResponse
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from open_webui.internal.db import get_session, SessionLocal
|
||||||
|
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||||
|
|
@ -62,9 +64,12 @@ router = APIRouter()
|
||||||
|
|
||||||
# TODO: Optimize this function to use the knowledge_file table for faster lookups.
|
# TODO: Optimize this function to use the knowledge_file table for faster lookups.
|
||||||
def has_access_to_file(
|
def has_access_to_file(
|
||||||
file_id: Optional[str], access_type: str, user=Depends(get_verified_user)
|
file_id: Optional[str],
|
||||||
|
access_type: str,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Optional[Session] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
file = Files.get_file_by_id(file_id)
|
file = Files.get_file_by_id(file_id, db=db)
|
||||||
log.debug(f"Checking if user has {access_type} access to file")
|
log.debug(f"Checking if user has {access_type} access to file")
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -73,31 +78,33 @@ def has_access_to_file(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the file is associated with any knowledge bases the user has access to
|
# Check if the file is associated with any knowledge bases the user has access to
|
||||||
knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id)
|
knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id, db=db)
|
||||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
|
user_group_ids = {
|
||||||
|
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
|
||||||
|
}
|
||||||
for knowledge_base in knowledge_bases:
|
for knowledge_base in knowledge_bases:
|
||||||
if knowledge_base.user_id == user.id or has_access(
|
if knowledge_base.user_id == user.id or has_access(
|
||||||
user.id, access_type, knowledge_base.access_control, user_group_ids
|
user.id, access_type, knowledge_base.access_control, user_group_ids, db=db
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
|
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
|
||||||
if knowledge_base_id:
|
if knowledge_base_id:
|
||||||
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
|
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
|
||||||
user.id, access_type
|
user.id, access_type, db=db
|
||||||
)
|
)
|
||||||
for knowledge_base in knowledge_bases:
|
for knowledge_base in knowledge_bases:
|
||||||
if knowledge_base.id == knowledge_base_id:
|
if knowledge_base.id == knowledge_base_id:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check if the file is associated with any channels the user has access to
|
# Check if the file is associated with any channels the user has access to
|
||||||
channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id)
|
channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id, db=db)
|
||||||
if access_type == "read" and channels:
|
if access_type == "read" and channels:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check if the file is associated with any chats the user has access to
|
# Check if the file is associated with any chats the user has access to
|
||||||
# TODO: Granular access control for chats
|
# TODO: Granular access control for chats
|
||||||
chats = Chats.get_shared_chats_by_file_id(file_id)
|
chats = Chats.get_shared_chats_by_file_id(file_id, db=db)
|
||||||
if chats:
|
if chats:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
@ -109,16 +116,29 @@ def has_access_to_file(
|
||||||
############################
|
############################
|
||||||
|
|
||||||
|
|
||||||
def process_uploaded_file(request, file, file_path, file_item, file_metadata, user):
|
def process_uploaded_file(
|
||||||
|
request,
|
||||||
|
file,
|
||||||
|
file_path,
|
||||||
|
file_item,
|
||||||
|
file_metadata,
|
||||||
|
user,
|
||||||
|
db: Optional[Session] = None,
|
||||||
|
):
|
||||||
|
def _process_handler(db_session):
|
||||||
try:
|
try:
|
||||||
if file.content_type:
|
if file.content_type:
|
||||||
stt_supported_content_types = getattr(
|
stt_supported_content_types = getattr(
|
||||||
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
|
request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
|
||||||
)
|
)
|
||||||
|
|
||||||
if strict_match_mime_type(stt_supported_content_types, file.content_type):
|
if strict_match_mime_type(
|
||||||
file_path = Storage.get_file(file_path)
|
stt_supported_content_types, file.content_type
|
||||||
result = transcribe(request, file_path, file_metadata, user)
|
):
|
||||||
|
file_path_processed = Storage.get_file(file_path)
|
||||||
|
result = transcribe(
|
||||||
|
request, file_path_processed, file_metadata, user
|
||||||
|
)
|
||||||
|
|
||||||
process_file(
|
process_file(
|
||||||
request,
|
request,
|
||||||
|
|
@ -126,11 +146,17 @@ def process_uploaded_file(request, file, file_path, file_item, file_metadata, us
|
||||||
file_id=file_item.id, content=result.get("text", "")
|
file_id=file_item.id, content=result.get("text", "")
|
||||||
),
|
),
|
||||||
user=user,
|
user=user,
|
||||||
|
db=db_session,
|
||||||
)
|
)
|
||||||
elif (not file.content_type.startswith(("image/", "video/"))) or (
|
elif (not file.content_type.startswith(("image/", "video/"))) or (
|
||||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
|
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
|
||||||
):
|
):
|
||||||
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
|
process_file(
|
||||||
|
request,
|
||||||
|
ProcessFileForm(file_id=file_item.id),
|
||||||
|
user=user,
|
||||||
|
db=db_session,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"File type {file.content_type} is not supported for processing"
|
f"File type {file.content_type} is not supported for processing"
|
||||||
|
|
@ -139,7 +165,12 @@ def process_uploaded_file(request, file, file_path, file_item, file_metadata, us
|
||||||
log.info(
|
log.info(
|
||||||
f"File type {file.content_type} is not provided, but trying to process anyway"
|
f"File type {file.content_type} is not provided, but trying to process anyway"
|
||||||
)
|
)
|
||||||
process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
|
process_file(
|
||||||
|
request,
|
||||||
|
ProcessFileForm(file_id=file_item.id),
|
||||||
|
user=user,
|
||||||
|
db=db_session,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error processing file: {file_item.id}")
|
log.error(f"Error processing file: {file_item.id}")
|
||||||
|
|
@ -149,8 +180,15 @@ def process_uploaded_file(request, file, file_path, file_item, file_metadata, us
|
||||||
"status": "failed",
|
"status": "failed",
|
||||||
"error": str(e.detail) if hasattr(e, "detail") else str(e),
|
"error": str(e.detail) if hasattr(e, "detail") else str(e),
|
||||||
},
|
},
|
||||||
|
db=db_session,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if db:
|
||||||
|
_process_handler(db)
|
||||||
|
else:
|
||||||
|
with SessionLocal() as db_session:
|
||||||
|
_process_handler(db_session)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=FileModelResponse)
|
@router.post("/", response_model=FileModelResponse)
|
||||||
def upload_file(
|
def upload_file(
|
||||||
|
|
@ -161,6 +199,7 @@ def upload_file(
|
||||||
process: bool = Query(True),
|
process: bool = Query(True),
|
||||||
process_in_background: bool = Query(True),
|
process_in_background: bool = Query(True),
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
return upload_file_handler(
|
return upload_file_handler(
|
||||||
request,
|
request,
|
||||||
|
|
@ -170,6 +209,7 @@ def upload_file(
|
||||||
process_in_background=process_in_background,
|
process_in_background=process_in_background,
|
||||||
user=user,
|
user=user,
|
||||||
background_tasks=background_tasks,
|
background_tasks=background_tasks,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -181,6 +221,7 @@ def upload_file_handler(
|
||||||
process_in_background: bool = Query(True),
|
process_in_background: bool = Query(True),
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
background_tasks: Optional[BackgroundTasks] = None,
|
background_tasks: Optional[BackgroundTasks] = None,
|
||||||
|
db: Optional[Session] = None,
|
||||||
):
|
):
|
||||||
log.info(f"file.content_type: {file.content_type} {process}")
|
log.info(f"file.content_type: {file.content_type} {process}")
|
||||||
|
|
||||||
|
|
@ -248,14 +289,17 @@ def upload_file_handler(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "channel_id" in file_metadata:
|
if "channel_id" in file_metadata:
|
||||||
channel = Channels.get_channel_by_id_and_user_id(
|
channel = Channels.get_channel_by_id_and_user_id(
|
||||||
file_metadata["channel_id"], user.id
|
file_metadata["channel_id"], user.id, db=db
|
||||||
)
|
)
|
||||||
if channel:
|
if channel:
|
||||||
Channels.add_file_to_channel_by_id(channel.id, file_item.id, user.id)
|
Channels.add_file_to_channel_by_id(
|
||||||
|
channel.id, file_item.id, user.id, db=db
|
||||||
|
)
|
||||||
|
|
||||||
if process:
|
if process:
|
||||||
if background_tasks and process_in_background:
|
if background_tasks and process_in_background:
|
||||||
|
|
@ -277,6 +321,7 @@ def upload_file_handler(
|
||||||
file_item,
|
file_item,
|
||||||
file_metadata,
|
file_metadata,
|
||||||
user,
|
user,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
return {"status": True, **file_item.model_dump()}
|
return {"status": True, **file_item.model_dump()}
|
||||||
else:
|
else:
|
||||||
|
|
@ -302,11 +347,15 @@ def upload_file_handler(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=list[FileModelResponse])
|
@router.get("/", response_model=list[FileModelResponse])
|
||||||
async def list_files(user=Depends(get_verified_user), content: bool = Query(True)):
|
async def list_files(
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
content: bool = Query(True),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
files = Files.get_files()
|
files = Files.get_files(db=db)
|
||||||
else:
|
else:
|
||||||
files = Files.get_files_by_user_id(user.id)
|
files = Files.get_files_by_user_id(user.id, db=db)
|
||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
for file in files:
|
for file in files:
|
||||||
|
|
@ -329,15 +378,16 @@ async def search_files(
|
||||||
),
|
),
|
||||||
content: bool = Query(True),
|
content: bool = Query(True),
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Search for files by filename with support for wildcard patterns.
|
Search for files by filename with support for wildcard patterns.
|
||||||
"""
|
"""
|
||||||
# Get files according to user role
|
# Get files according to user role
|
||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
files = Files.get_files()
|
files = Files.get_files(db=db)
|
||||||
else:
|
else:
|
||||||
files = Files.get_files_by_user_id(user.id)
|
files = Files.get_files_by_user_id(user.id, db=db)
|
||||||
|
|
||||||
# Get matching files
|
# Get matching files
|
||||||
matching_files = [
|
matching_files = [
|
||||||
|
|
@ -364,8 +414,10 @@ async def search_files(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/all")
|
@router.delete("/all")
|
||||||
async def delete_all_files(user=Depends(get_admin_user)):
|
async def delete_all_files(
|
||||||
result = Files.delete_all_files()
|
user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
result = Files.delete_all_files(db=db)
|
||||||
if result:
|
if result:
|
||||||
try:
|
try:
|
||||||
Storage.delete_all_files()
|
Storage.delete_all_files()
|
||||||
|
|
@ -391,8 +443,10 @@ async def delete_all_files(user=Depends(get_admin_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{id}", response_model=Optional[FileModel])
|
@router.get("/{id}", response_model=Optional[FileModel])
|
||||||
async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_file_by_id(
|
||||||
file = Files.get_file_by_id(id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
file = Files.get_file_by_id(id, db=db)
|
||||||
|
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -403,7 +457,7 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
if (
|
if (
|
||||||
file.user_id == user.id
|
file.user_id == user.id
|
||||||
or user.role == "admin"
|
or user.role == "admin"
|
||||||
or has_access_to_file(id, "read", user)
|
or has_access_to_file(id, "read", user, db=db)
|
||||||
):
|
):
|
||||||
return file
|
return file
|
||||||
else:
|
else:
|
||||||
|
|
@ -415,9 +469,12 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.get("/{id}/process/status")
|
@router.get("/{id}/process/status")
|
||||||
async def get_file_process_status(
|
async def get_file_process_status(
|
||||||
id: str, stream: bool = Query(False), user=Depends(get_verified_user)
|
id: str,
|
||||||
|
stream: bool = Query(False),
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
file = Files.get_file_by_id(id)
|
file = Files.get_file_by_id(id, db=db)
|
||||||
|
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -428,7 +485,7 @@ async def get_file_process_status(
|
||||||
if (
|
if (
|
||||||
file.user_id == user.id
|
file.user_id == user.id
|
||||||
or user.role == "admin"
|
or user.role == "admin"
|
||||||
or has_access_to_file(id, "read", user)
|
or has_access_to_file(id, "read", user, db=db)
|
||||||
):
|
):
|
||||||
if stream:
|
if stream:
|
||||||
MAX_FILE_PROCESSING_DURATION = 3600 * 2
|
MAX_FILE_PROCESSING_DURATION = 3600 * 2
|
||||||
|
|
@ -436,7 +493,7 @@ async def get_file_process_status(
|
||||||
async def event_stream(file_item):
|
async def event_stream(file_item):
|
||||||
if file_item:
|
if file_item:
|
||||||
for _ in range(MAX_FILE_PROCESSING_DURATION):
|
for _ in range(MAX_FILE_PROCESSING_DURATION):
|
||||||
file_item = Files.get_file_by_id(file_item.id)
|
file_item = Files.get_file_by_id(file_item.id, db=db)
|
||||||
if file_item:
|
if file_item:
|
||||||
data = file_item.model_dump().get("data", {})
|
data = file_item.model_dump().get("data", {})
|
||||||
status = data.get("status")
|
status = data.get("status")
|
||||||
|
|
@ -476,8 +533,10 @@ async def get_file_process_status(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{id}/data/content")
|
@router.get("/{id}/data/content")
|
||||||
async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_file_data_content_by_id(
|
||||||
file = Files.get_file_by_id(id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
file = Files.get_file_by_id(id, db=db)
|
||||||
|
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -488,7 +547,7 @@ async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
if (
|
if (
|
||||||
file.user_id == user.id
|
file.user_id == user.id
|
||||||
or user.role == "admin"
|
or user.role == "admin"
|
||||||
or has_access_to_file(id, "read", user)
|
or has_access_to_file(id, "read", user, db=db)
|
||||||
):
|
):
|
||||||
return {"content": file.data.get("content", "")}
|
return {"content": file.data.get("content", "")}
|
||||||
else:
|
else:
|
||||||
|
|
@ -509,9 +568,13 @@ class ContentForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/{id}/data/content/update")
|
@router.post("/{id}/data/content/update")
|
||||||
async def update_file_data_content_by_id(
|
async def update_file_data_content_by_id(
|
||||||
request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user)
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
form_data: ContentForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
file = Files.get_file_by_id(id)
|
file = Files.get_file_by_id(id, db=db)
|
||||||
|
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -522,7 +585,7 @@ async def update_file_data_content_by_id(
|
||||||
if (
|
if (
|
||||||
file.user_id == user.id
|
file.user_id == user.id
|
||||||
or user.role == "admin"
|
or user.role == "admin"
|
||||||
or has_access_to_file(id, "write", user)
|
or has_access_to_file(id, "write", user, db=db)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
process_file(
|
process_file(
|
||||||
|
|
@ -530,7 +593,7 @@ async def update_file_data_content_by_id(
|
||||||
ProcessFileForm(file_id=id, content=form_data.content),
|
ProcessFileForm(file_id=id, content=form_data.content),
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
file = Files.get_file_by_id(id=id)
|
file = Files.get_file_by_id(id=id, db=db)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
log.error(f"Error processing file: {file.id}")
|
log.error(f"Error processing file: {file.id}")
|
||||||
|
|
@ -550,9 +613,12 @@ async def update_file_data_content_by_id(
|
||||||
|
|
||||||
@router.get("/{id}/content")
|
@router.get("/{id}/content")
|
||||||
async def get_file_content_by_id(
|
async def get_file_content_by_id(
|
||||||
id: str, user=Depends(get_verified_user), attachment: bool = Query(False)
|
id: str,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
attachment: bool = Query(False),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
file = Files.get_file_by_id(id)
|
file = Files.get_file_by_id(id, db=db)
|
||||||
|
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -563,7 +629,7 @@ async def get_file_content_by_id(
|
||||||
if (
|
if (
|
||||||
file.user_id == user.id
|
file.user_id == user.id
|
||||||
or user.role == "admin"
|
or user.role == "admin"
|
||||||
or has_access_to_file(id, "read", user)
|
or has_access_to_file(id, "read", user, db=db)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
file_path = Storage.get_file(file.path)
|
file_path = Storage.get_file(file.path)
|
||||||
|
|
@ -619,8 +685,10 @@ async def get_file_content_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{id}/content/html")
|
@router.get("/{id}/content/html")
|
||||||
async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_html_file_content_by_id(
|
||||||
file = Files.get_file_by_id(id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
file = Files.get_file_by_id(id, db=db)
|
||||||
|
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -628,7 +696,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
file_user = Users.get_user_by_id(file.user_id)
|
file_user = Users.get_user_by_id(file.user_id, db=db)
|
||||||
if not file_user.role == "admin":
|
if not file_user.role == "admin":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
|
@ -638,7 +706,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
if (
|
if (
|
||||||
file.user_id == user.id
|
file.user_id == user.id
|
||||||
or user.role == "admin"
|
or user.role == "admin"
|
||||||
or has_access_to_file(id, "read", user)
|
or has_access_to_file(id, "read", user, db=db)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
file_path = Storage.get_file(file.path)
|
file_path = Storage.get_file(file.path)
|
||||||
|
|
@ -668,8 +736,10 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{id}/content/{file_name}")
|
@router.get("/{id}/content/{file_name}")
|
||||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_file_content_by_id(
|
||||||
file = Files.get_file_by_id(id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
file = Files.get_file_by_id(id, db=db)
|
||||||
|
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -680,7 +750,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
if (
|
if (
|
||||||
file.user_id == user.id
|
file.user_id == user.id
|
||||||
or user.role == "admin"
|
or user.role == "admin"
|
||||||
or has_access_to_file(id, "read", user)
|
or has_access_to_file(id, "read", user, db=db)
|
||||||
):
|
):
|
||||||
file_path = file.path
|
file_path = file.path
|
||||||
|
|
||||||
|
|
@ -730,8 +800,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{id}")
|
@router.delete("/{id}")
|
||||||
async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
async def delete_file_by_id(
|
||||||
file = Files.get_file_by_id(id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
file = Files.get_file_by_id(id, db=db)
|
||||||
|
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -742,10 +814,10 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
if (
|
if (
|
||||||
file.user_id == user.id
|
file.user_id == user.id
|
||||||
or user.role == "admin"
|
or user.role == "admin"
|
||||||
or has_access_to_file(id, "write", user)
|
or has_access_to_file(id, "write", user, db=db)
|
||||||
):
|
):
|
||||||
|
|
||||||
result = Files.delete_file_by_id(id)
|
result = Files.delete_file_by_id(id, db=db)
|
||||||
if result:
|
if result:
|
||||||
try:
|
try:
|
||||||
Storage.delete_file(file.path)
|
Storage.delete_file(file.path)
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,8 @@ from open_webui.models.knowledge import Knowledges
|
||||||
|
|
||||||
from open_webui.config import UPLOAD_DIR
|
from open_webui.config import UPLOAD_DIR
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
|
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
|
||||||
|
|
@ -44,7 +46,11 @@ router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=list[FolderNameIdResponse])
|
@router.get("/", response_model=list[FolderNameIdResponse])
|
||||||
async def get_folders(request: Request, user=Depends(get_verified_user)):
|
async def get_folders(
|
||||||
|
request: Request,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
if request.app.state.config.ENABLE_FOLDERS is False:
|
if request.app.state.config.ENABLE_FOLDERS is False:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
|
@ -55,22 +61,23 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
|
||||||
user.id,
|
user.id,
|
||||||
"features.folders",
|
"features.folders",
|
||||||
request.app.state.config.USER_PERMISSIONS,
|
request.app.state.config.USER_PERMISSIONS,
|
||||||
|
db=db,
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
folders = Folders.get_folders_by_user_id(user.id)
|
folders = Folders.get_folders_by_user_id(user.id, db=db)
|
||||||
|
|
||||||
# Verify folder data integrity
|
# Verify folder data integrity
|
||||||
folder_list = []
|
folder_list = []
|
||||||
for folder in folders:
|
for folder in folders:
|
||||||
if folder.parent_id and not Folders.get_folder_by_id_and_user_id(
|
if folder.parent_id and not Folders.get_folder_by_id_and_user_id(
|
||||||
folder.parent_id, user.id
|
folder.parent_id, user.id, db=db
|
||||||
):
|
):
|
||||||
folder = Folders.update_folder_parent_id_by_id_and_user_id(
|
folder = Folders.update_folder_parent_id_by_id_and_user_id(
|
||||||
folder.id, user.id, None
|
folder.id, user.id, None, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
if folder.data:
|
if folder.data:
|
||||||
|
|
@ -80,12 +87,12 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
if file.get("type") == "file":
|
if file.get("type") == "file":
|
||||||
if Files.check_access_by_user_id(
|
if Files.check_access_by_user_id(
|
||||||
file.get("id"), user.id, "read"
|
file.get("id"), user.id, "read", db=db
|
||||||
):
|
):
|
||||||
valid_files.append(file)
|
valid_files.append(file)
|
||||||
elif file.get("type") == "collection":
|
elif file.get("type") == "collection":
|
||||||
if Knowledges.check_access_by_user_id(
|
if Knowledges.check_access_by_user_id(
|
||||||
file.get("id"), user.id, "read"
|
file.get("id"), user.id, "read", db=db
|
||||||
):
|
):
|
||||||
valid_files.append(file)
|
valid_files.append(file)
|
||||||
else:
|
else:
|
||||||
|
|
@ -93,7 +100,7 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
folder.data["files"] = valid_files
|
folder.data["files"] = valid_files
|
||||||
Folders.update_folder_by_id_and_user_id(
|
Folders.update_folder_by_id_and_user_id(
|
||||||
folder.id, user.id, FolderUpdateForm(data=folder.data)
|
folder.id, user.id, FolderUpdateForm(data=folder.data), db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
folder_list.append(FolderNameIdResponse(**folder.model_dump()))
|
folder_list.append(FolderNameIdResponse(**folder.model_dump()))
|
||||||
|
|
@ -107,9 +114,13 @@ async def get_folders(request: Request, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/")
|
@router.post("/")
|
||||||
def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
|
def create_folder(
|
||||||
|
form_data: FolderForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||||
None, user.id, form_data.name
|
None, user.id, form_data.name, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
if folder:
|
if folder:
|
||||||
|
|
@ -119,7 +130,7 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
folder = Folders.insert_new_folder(user.id, form_data)
|
folder = Folders.insert_new_folder(user.id, form_data, db=db)
|
||||||
return folder
|
return folder
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
@ -136,8 +147,10 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{id}", response_model=Optional[FolderModel])
|
@router.get("/{id}", response_model=Optional[FolderModel])
|
||||||
async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_folder_by_id(
|
||||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
|
||||||
if folder:
|
if folder:
|
||||||
return folder
|
return folder
|
||||||
else:
|
else:
|
||||||
|
|
@ -154,15 +167,18 @@ async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.post("/{id}/update")
|
@router.post("/{id}/update")
|
||||||
async def update_folder_name_by_id(
|
async def update_folder_name_by_id(
|
||||||
id: str, form_data: FolderUpdateForm, user=Depends(get_verified_user)
|
id: str,
|
||||||
|
form_data: FolderUpdateForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
|
||||||
if folder:
|
if folder:
|
||||||
|
|
||||||
if form_data.name is not None:
|
if form_data.name is not None:
|
||||||
# Check if folder with same name exists
|
# Check if folder with same name exists
|
||||||
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||||
folder.parent_id, user.id, form_data.name
|
folder.parent_id, user.id, form_data.name, db=db
|
||||||
)
|
)
|
||||||
if existing_folder and existing_folder.id != id:
|
if existing_folder and existing_folder.id != id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -171,7 +187,9 @@ async def update_folder_name_by_id(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data)
|
folder = Folders.update_folder_by_id_and_user_id(
|
||||||
|
id, user.id, form_data, db=db
|
||||||
|
)
|
||||||
return folder
|
return folder
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
@ -198,12 +216,15 @@ class FolderParentIdForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/{id}/update/parent")
|
@router.post("/{id}/update/parent")
|
||||||
async def update_folder_parent_id_by_id(
|
async def update_folder_parent_id_by_id(
|
||||||
id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user)
|
id: str,
|
||||||
|
form_data: FolderParentIdForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
|
||||||
if folder:
|
if folder:
|
||||||
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||||
form_data.parent_id, user.id, folder.name
|
form_data.parent_id, user.id, folder.name, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
if existing_folder:
|
if existing_folder:
|
||||||
|
|
@ -214,7 +235,7 @@ async def update_folder_parent_id_by_id(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
folder = Folders.update_folder_parent_id_by_id_and_user_id(
|
folder = Folders.update_folder_parent_id_by_id_and_user_id(
|
||||||
id, user.id, form_data.parent_id
|
id, user.id, form_data.parent_id, db=db
|
||||||
)
|
)
|
||||||
return folder
|
return folder
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -242,13 +263,16 @@ class FolderIsExpandedForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/{id}/update/expanded")
|
@router.post("/{id}/update/expanded")
|
||||||
async def update_folder_is_expanded_by_id(
|
async def update_folder_is_expanded_by_id(
|
||||||
id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user)
|
id: str,
|
||||||
|
form_data: FolderIsExpandedForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
|
||||||
if folder:
|
if folder:
|
||||||
try:
|
try:
|
||||||
folder = Folders.update_folder_is_expanded_by_id_and_user_id(
|
folder = Folders.update_folder_is_expanded_by_id_and_user_id(
|
||||||
id, user.id, form_data.is_expanded
|
id, user.id, form_data.is_expanded, db=db
|
||||||
)
|
)
|
||||||
return folder
|
return folder
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -276,10 +300,11 @@ async def delete_folder_by_id(
|
||||||
id: str,
|
id: str,
|
||||||
delete_contents: Optional[bool] = True,
|
delete_contents: Optional[bool] = True,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if Chats.count_chats_by_folder_id_and_user_id(id, user.id):
|
if Chats.count_chats_by_folder_id_and_user_id(id, user.id, db=db):
|
||||||
chat_delete_permission = has_permission(
|
chat_delete_permission = has_permission(
|
||||||
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
|
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
)
|
)
|
||||||
if user.role != "admin" and not chat_delete_permission:
|
if user.role != "admin" and not chat_delete_permission:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -288,19 +313,21 @@ async def delete_folder_by_id(
|
||||||
)
|
)
|
||||||
|
|
||||||
folders = []
|
folders = []
|
||||||
folders.append(Folders.get_folder_by_id_and_user_id(id, user.id))
|
folders.append(Folders.get_folder_by_id_and_user_id(id, user.id, db=db))
|
||||||
while folders:
|
while folders:
|
||||||
folder = folders.pop()
|
folder = folders.pop()
|
||||||
if folder:
|
if folder:
|
||||||
try:
|
try:
|
||||||
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id)
|
folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id, db=db)
|
||||||
|
|
||||||
for folder_id in folder_ids:
|
for folder_id in folder_ids:
|
||||||
if delete_contents:
|
if delete_contents:
|
||||||
Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id)
|
Chats.delete_chats_by_user_id_and_folder_id(
|
||||||
|
user.id, folder_id, db=db
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
Chats.move_chats_by_user_id_and_folder_id(
|
Chats.move_chats_by_user_id_and_folder_id(
|
||||||
user.id, folder_id, None
|
user.id, folder_id, None, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
@ -314,7 +341,7 @@ async def delete_folder_by_id(
|
||||||
finally:
|
finally:
|
||||||
# Get all subfolders
|
# Get all subfolders
|
||||||
subfolders = Folders.get_folders_by_parent_id_and_user_id(
|
subfolders = Folders.get_folders_by_parent_id_and_user_id(
|
||||||
folder.id, user.id
|
folder.id, user.id, db=db
|
||||||
)
|
)
|
||||||
folders.extend(subfolders)
|
folders.extend(subfolders)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,8 @@ from open_webui.constants import ERROR_MESSAGES
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from pydantic import BaseModel, HttpUrl
|
from pydantic import BaseModel, HttpUrl
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
@ -37,13 +39,13 @@ router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=list[FunctionResponse])
|
@router.get("/", response_model=list[FunctionResponse])
|
||||||
async def get_functions(user=Depends(get_verified_user)):
|
async def get_functions(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
return Functions.get_functions()
|
return Functions.get_functions(db=db)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/list", response_model=list[FunctionUserResponse])
|
@router.get("/list", response_model=list[FunctionUserResponse])
|
||||||
async def get_function_list(user=Depends(get_admin_user)):
|
async def get_function_list(user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||||
return Functions.get_function_list()
|
return Functions.get_function_list(db=db)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -52,8 +54,8 @@ async def get_function_list(user=Depends(get_admin_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel])
|
@router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel])
|
||||||
async def get_functions(include_valves: bool = False, user=Depends(get_admin_user)):
|
async def get_functions(include_valves: bool = False, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||||
return Functions.get_functions(include_valves=include_valves)
|
return Functions.get_functions(include_valves=include_valves, db=db)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -142,7 +144,7 @@ class SyncFunctionsForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/sync", response_model=list[FunctionWithValvesModel])
|
@router.post("/sync", response_model=list[FunctionWithValvesModel])
|
||||||
async def sync_functions(
|
async def sync_functions(
|
||||||
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
|
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
for function in form_data.functions:
|
for function in form_data.functions:
|
||||||
|
|
@ -164,7 +166,7 @@ async def sync_functions(
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return Functions.sync_functions(user.id, form_data.functions)
|
return Functions.sync_functions(user.id, form_data.functions, db=db)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Failed to load a function: {e}")
|
log.exception(f"Failed to load a function: {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -180,7 +182,7 @@ async def sync_functions(
|
||||||
|
|
||||||
@router.post("/create", response_model=Optional[FunctionResponse])
|
@router.post("/create", response_model=Optional[FunctionResponse])
|
||||||
async def create_new_function(
|
async def create_new_function(
|
||||||
request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
|
request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
if not form_data.id.isidentifier():
|
if not form_data.id.isidentifier():
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -190,7 +192,7 @@ async def create_new_function(
|
||||||
|
|
||||||
form_data.id = form_data.id.lower()
|
form_data.id = form_data.id.lower()
|
||||||
|
|
||||||
function = Functions.get_function_by_id(form_data.id)
|
function = Functions.get_function_by_id(form_data.id, db=db)
|
||||||
if function is None:
|
if function is None:
|
||||||
try:
|
try:
|
||||||
form_data.content = replace_imports(form_data.content)
|
form_data.content = replace_imports(form_data.content)
|
||||||
|
|
@ -203,13 +205,13 @@ async def create_new_function(
|
||||||
FUNCTIONS = request.app.state.FUNCTIONS
|
FUNCTIONS = request.app.state.FUNCTIONS
|
||||||
FUNCTIONS[form_data.id] = function_module
|
FUNCTIONS[form_data.id] = function_module
|
||||||
|
|
||||||
function = Functions.insert_new_function(user.id, function_type, form_data)
|
function = Functions.insert_new_function(user.id, function_type, form_data, db=db)
|
||||||
|
|
||||||
function_cache_dir = CACHE_DIR / "functions" / form_data.id
|
function_cache_dir = CACHE_DIR / "functions" / form_data.id
|
||||||
function_cache_dir.mkdir(parents=True, exist_ok=True)
|
function_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if function_type == "filter" and getattr(function_module, "toggle", None):
|
if function_type == "filter" and getattr(function_module, "toggle", None):
|
||||||
Functions.update_function_metadata_by_id(id, {"toggle": True})
|
Functions.update_function_metadata_by_id(form_data.id, {"toggle": True}, db=db)
|
||||||
|
|
||||||
if function:
|
if function:
|
||||||
return function
|
return function
|
||||||
|
|
@ -237,8 +239,8 @@ async def create_new_function(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/id/{id}", response_model=Optional[FunctionModel])
|
@router.get("/id/{id}", response_model=Optional[FunctionModel])
|
||||||
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
|
async def get_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id, db=db)
|
||||||
|
|
||||||
if function:
|
if function:
|
||||||
return function
|
return function
|
||||||
|
|
@ -255,11 +257,11 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
|
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
|
||||||
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
async def toggle_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id, db=db)
|
||||||
if function:
|
if function:
|
||||||
function = Functions.update_function_by_id(
|
function = Functions.update_function_by_id(
|
||||||
id, {"is_active": not function.is_active}
|
id, {"is_active": not function.is_active}, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
if function:
|
if function:
|
||||||
|
|
@ -282,11 +284,11 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
|
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
|
||||||
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
|
async def toggle_global_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id, db=db)
|
||||||
if function:
|
if function:
|
||||||
function = Functions.update_function_by_id(
|
function = Functions.update_function_by_id(
|
||||||
id, {"is_global": not function.is_global}
|
id, {"is_global": not function.is_global}, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
if function:
|
if function:
|
||||||
|
|
@ -310,7 +312,7 @@ async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
|
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
|
||||||
async def update_function_by_id(
|
async def update_function_by_id(
|
||||||
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
|
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
form_data.content = replace_imports(form_data.content)
|
form_data.content = replace_imports(form_data.content)
|
||||||
|
|
@ -325,10 +327,10 @@ async def update_function_by_id(
|
||||||
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
|
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
|
||||||
log.debug(updated)
|
log.debug(updated)
|
||||||
|
|
||||||
function = Functions.update_function_by_id(id, updated)
|
function = Functions.update_function_by_id(id, updated, db=db)
|
||||||
|
|
||||||
if function_type == "filter" and getattr(function_module, "toggle", None):
|
if function_type == "filter" and getattr(function_module, "toggle", None):
|
||||||
Functions.update_function_metadata_by_id(id, {"toggle": True})
|
Functions.update_function_metadata_by_id(id, {"toggle": True}, db=db)
|
||||||
|
|
||||||
if function:
|
if function:
|
||||||
return function
|
return function
|
||||||
|
|
@ -352,9 +354,9 @@ async def update_function_by_id(
|
||||||
|
|
||||||
@router.delete("/id/{id}/delete", response_model=bool)
|
@router.delete("/id/{id}/delete", response_model=bool)
|
||||||
async def delete_function_by_id(
|
async def delete_function_by_id(
|
||||||
request: Request, id: str, user=Depends(get_admin_user)
|
request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
result = Functions.delete_function_by_id(id)
|
result = Functions.delete_function_by_id(id, db=db)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
FUNCTIONS = request.app.state.FUNCTIONS
|
FUNCTIONS = request.app.state.FUNCTIONS
|
||||||
|
|
@ -370,11 +372,11 @@ async def delete_function_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
||||||
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
|
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id, db=db)
|
||||||
if function:
|
if function:
|
||||||
try:
|
try:
|
||||||
valves = Functions.get_function_valves_by_id(id)
|
valves = Functions.get_function_valves_by_id(id, db=db)
|
||||||
return valves
|
return valves
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -395,9 +397,9 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
|
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
|
||||||
async def get_function_valves_spec_by_id(
|
async def get_function_valves_spec_by_id(
|
||||||
request: Request, id: str, user=Depends(get_admin_user)
|
request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id, db=db)
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||||
request, id
|
request, id
|
||||||
|
|
@ -421,9 +423,9 @@ async def get_function_valves_spec_by_id(
|
||||||
|
|
||||||
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
|
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
|
||||||
async def update_function_valves_by_id(
|
async def update_function_valves_by_id(
|
||||||
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
|
request: Request, id: str, form_data: dict, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id, db=db)
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||||
request, id
|
request, id
|
||||||
|
|
@ -437,7 +439,7 @@ async def update_function_valves_by_id(
|
||||||
valves = Valves(**form_data)
|
valves = Valves(**form_data)
|
||||||
|
|
||||||
valves_dict = valves.model_dump(exclude_unset=True)
|
valves_dict = valves.model_dump(exclude_unset=True)
|
||||||
Functions.update_function_valves_by_id(id, valves_dict)
|
Functions.update_function_valves_by_id(id, valves_dict, db=db)
|
||||||
return valves_dict
|
return valves_dict
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error updating function values by id {id}: {e}")
|
log.exception(f"Error updating function values by id {id}: {e}")
|
||||||
|
|
@ -464,11 +466,11 @@ async def update_function_valves_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
||||||
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id, db=db)
|
||||||
if function:
|
if function:
|
||||||
try:
|
try:
|
||||||
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
|
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id, db=db)
|
||||||
return user_valves
|
return user_valves
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -484,9 +486,9 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user
|
||||||
|
|
||||||
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
|
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
|
||||||
async def get_function_user_valves_spec_by_id(
|
async def get_function_user_valves_spec_by_id(
|
||||||
request: Request, id: str, user=Depends(get_verified_user)
|
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id, db=db)
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||||
request, id
|
request, id
|
||||||
|
|
@ -505,9 +507,9 @@ async def get_function_user_valves_spec_by_id(
|
||||||
|
|
||||||
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
|
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
|
||||||
async def update_function_user_valves_by_id(
|
async def update_function_user_valves_by_id(
|
||||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id, db=db)
|
||||||
|
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||||
|
|
@ -522,7 +524,7 @@ async def update_function_user_valves_by_id(
|
||||||
user_valves = UserValves(**form_data)
|
user_valves = UserValves(**form_data)
|
||||||
user_valves_dict = user_valves.model_dump(exclude_unset=True)
|
user_valves_dict = user_valves.model_dump(exclude_unset=True)
|
||||||
Functions.update_user_valves_by_id_and_user_id(
|
Functions.update_user_valves_by_id_and_user_id(
|
||||||
id, user.id, user_valves_dict
|
id, user.id, user_valves_dict, db=db
|
||||||
)
|
)
|
||||||
return user_valves_dict
|
return user_valves_dict
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,9 @@ from open_webui.config import CACHE_DIR
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -29,7 +32,11 @@ router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=list[GroupResponse])
|
@router.get("/", response_model=list[GroupResponse])
|
||||||
async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)):
|
async def get_groups(
|
||||||
|
share: Optional[bool] = None,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
|
|
||||||
filter = {}
|
filter = {}
|
||||||
if user.role != "admin":
|
if user.role != "admin":
|
||||||
|
|
@ -38,7 +45,7 @@ async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_use
|
||||||
if share is not None:
|
if share is not None:
|
||||||
filter["share"] = share
|
filter["share"] = share
|
||||||
|
|
||||||
groups = Groups.get_groups(filter=filter)
|
groups = Groups.get_groups(filter=filter, db=db)
|
||||||
|
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
|
|
@ -49,13 +56,17 @@ async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_use
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create", response_model=Optional[GroupResponse])
|
@router.post("/create", response_model=Optional[GroupResponse])
|
||||||
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
async def create_new_group(
|
||||||
|
form_data: GroupForm,
|
||||||
|
user=Depends(get_admin_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
group = Groups.insert_new_group(user.id, form_data)
|
group = Groups.insert_new_group(user.id, form_data, db=db)
|
||||||
if group:
|
if group:
|
||||||
return GroupResponse(
|
return GroupResponse(
|
||||||
**group.model_dump(),
|
**group.model_dump(),
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -76,12 +87,14 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/id/{id}", response_model=Optional[GroupResponse])
|
@router.get("/id/{id}", response_model=Optional[GroupResponse])
|
||||||
async def get_group_by_id(id: str, user=Depends(get_admin_user)):
|
async def get_group_by_id(
|
||||||
group = Groups.get_group_by_id(id)
|
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
group = Groups.get_group_by_id(id, db=db)
|
||||||
if group:
|
if group:
|
||||||
return GroupResponse(
|
return GroupResponse(
|
||||||
**group.model_dump(),
|
**group.model_dump(),
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -101,13 +114,15 @@ class GroupExportResponse(GroupResponse):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/id/{id}/export", response_model=Optional[GroupExportResponse])
|
@router.get("/id/{id}/export", response_model=Optional[GroupExportResponse])
|
||||||
async def export_group_by_id(id: str, user=Depends(get_admin_user)):
|
async def export_group_by_id(
|
||||||
group = Groups.get_group_by_id(id)
|
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
group = Groups.get_group_by_id(id, db=db)
|
||||||
if group:
|
if group:
|
||||||
return GroupExportResponse(
|
return GroupExportResponse(
|
||||||
**group.model_dump(),
|
**group.model_dump(),
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||||
user_ids=Groups.get_group_user_ids_by_id(group.id),
|
user_ids=Groups.get_group_user_ids_by_id(group.id, db=db),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -122,9 +137,11 @@ async def export_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/id/{id}/users", response_model=list[UserInfoResponse])
|
@router.post("/id/{id}/users", response_model=list[UserInfoResponse])
|
||||||
async def get_users_in_group(id: str, user=Depends(get_admin_user)):
|
async def get_users_in_group(
|
||||||
|
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
users = Users.get_users_by_group_id(id)
|
users = Users.get_users_by_group_id(id, db=db)
|
||||||
return users
|
return users
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error adding users to group {id}: {e}")
|
log.exception(f"Error adding users to group {id}: {e}")
|
||||||
|
|
@ -141,14 +158,17 @@ async def get_users_in_group(id: str, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
@router.post("/id/{id}/update", response_model=Optional[GroupResponse])
|
@router.post("/id/{id}/update", response_model=Optional[GroupResponse])
|
||||||
async def update_group_by_id(
|
async def update_group_by_id(
|
||||||
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
|
id: str,
|
||||||
|
form_data: GroupUpdateForm,
|
||||||
|
user=Depends(get_admin_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
group = Groups.update_group_by_id(id, form_data)
|
group = Groups.update_group_by_id(id, form_data, db=db)
|
||||||
if group:
|
if group:
|
||||||
return GroupResponse(
|
return GroupResponse(
|
||||||
**group.model_dump(),
|
**group.model_dump(),
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -170,17 +190,20 @@ async def update_group_by_id(
|
||||||
|
|
||||||
@router.post("/id/{id}/users/add", response_model=Optional[GroupResponse])
|
@router.post("/id/{id}/users/add", response_model=Optional[GroupResponse])
|
||||||
async def add_user_to_group(
|
async def add_user_to_group(
|
||||||
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
|
id: str,
|
||||||
|
form_data: UserIdsForm,
|
||||||
|
user=Depends(get_admin_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if form_data.user_ids:
|
if form_data.user_ids:
|
||||||
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
|
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids, db=db)
|
||||||
|
|
||||||
group = Groups.add_users_to_group(id, form_data.user_ids)
|
group = Groups.add_users_to_group(id, form_data.user_ids, db=db)
|
||||||
if group:
|
if group:
|
||||||
return GroupResponse(
|
return GroupResponse(
|
||||||
**group.model_dump(),
|
**group.model_dump(),
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -197,14 +220,17 @@ async def add_user_to_group(
|
||||||
|
|
||||||
@router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse])
|
@router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse])
|
||||||
async def remove_users_from_group(
|
async def remove_users_from_group(
|
||||||
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
|
id: str,
|
||||||
|
form_data: UserIdsForm,
|
||||||
|
user=Depends(get_admin_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
group = Groups.remove_users_from_group(id, form_data.user_ids)
|
group = Groups.remove_users_from_group(id, form_data.user_ids, db=db)
|
||||||
if group:
|
if group:
|
||||||
return GroupResponse(
|
return GroupResponse(
|
||||||
**group.model_dump(),
|
**group.model_dump(),
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
member_count=Groups.get_group_member_count_by_id(group.id, db=db),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -225,9 +251,11 @@ async def remove_users_from_group(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/id/{id}/delete", response_model=bool)
|
@router.delete("/id/{id}/delete", response_model=bool)
|
||||||
async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
|
async def delete_group_by_id(
|
||||||
|
id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
result = Groups.delete_group_by_id(id)
|
result = Groups.delete_group_by_id(id, db=db)
|
||||||
if result:
|
if result:
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
|
||||||
from fastapi.concurrency import run_in_threadpool
|
from fastapi.concurrency import run_in_threadpool
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.models.groups import Groups
|
||||||
from open_webui.models.knowledge import (
|
from open_webui.models.knowledge import (
|
||||||
KnowledgeFileListResponse,
|
KnowledgeFileListResponse,
|
||||||
|
|
@ -52,21 +54,25 @@ class KnowledgeAccessListResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=KnowledgeAccessListResponse)
|
@router.get("/", response_model=KnowledgeAccessListResponse)
|
||||||
async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified_user)):
|
async def get_knowledge_bases(
|
||||||
|
page: Optional[int] = 1,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
page = max(page, 1)
|
page = max(page, 1)
|
||||||
limit = PAGE_ITEM_COUNT
|
limit = PAGE_ITEM_COUNT
|
||||||
skip = (page - 1) * limit
|
skip = (page - 1) * limit
|
||||||
|
|
||||||
filter = {}
|
filter = {}
|
||||||
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
||||||
groups = Groups.get_groups_by_member_id(user.id)
|
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||||
if groups:
|
if groups:
|
||||||
filter["group_ids"] = [group.id for group in groups]
|
filter["group_ids"] = [group.id for group in groups]
|
||||||
|
|
||||||
filter["user_id"] = user.id
|
filter["user_id"] = user.id
|
||||||
|
|
||||||
result = Knowledges.search_knowledge_bases(
|
result = Knowledges.search_knowledge_bases(
|
||||||
user.id, filter=filter, skip=skip, limit=limit
|
user.id, filter=filter, skip=skip, limit=limit, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
return KnowledgeAccessListResponse(
|
return KnowledgeAccessListResponse(
|
||||||
|
|
@ -75,7 +81,9 @@ async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified
|
||||||
**knowledge_base.model_dump(),
|
**knowledge_base.model_dump(),
|
||||||
write_access=(
|
write_access=(
|
||||||
user.id == knowledge_base.user_id
|
user.id == knowledge_base.user_id
|
||||||
or has_access(user.id, "write", knowledge_base.access_control)
|
or has_access(
|
||||||
|
user.id, "write", knowledge_base.access_control, db=db
|
||||||
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
for knowledge_base in result.items
|
for knowledge_base in result.items
|
||||||
|
|
@ -90,6 +98,7 @@ async def search_knowledge_bases(
|
||||||
view_option: Optional[str] = None,
|
view_option: Optional[str] = None,
|
||||||
page: Optional[int] = 1,
|
page: Optional[int] = 1,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
page = max(page, 1)
|
page = max(page, 1)
|
||||||
limit = PAGE_ITEM_COUNT
|
limit = PAGE_ITEM_COUNT
|
||||||
|
|
@ -102,14 +111,14 @@ async def search_knowledge_bases(
|
||||||
filter["view_option"] = view_option
|
filter["view_option"] = view_option
|
||||||
|
|
||||||
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
||||||
groups = Groups.get_groups_by_member_id(user.id)
|
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||||
if groups:
|
if groups:
|
||||||
filter["group_ids"] = [group.id for group in groups]
|
filter["group_ids"] = [group.id for group in groups]
|
||||||
|
|
||||||
filter["user_id"] = user.id
|
filter["user_id"] = user.id
|
||||||
|
|
||||||
result = Knowledges.search_knowledge_bases(
|
result = Knowledges.search_knowledge_bases(
|
||||||
user.id, filter=filter, skip=skip, limit=limit
|
user.id, filter=filter, skip=skip, limit=limit, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
return KnowledgeAccessListResponse(
|
return KnowledgeAccessListResponse(
|
||||||
|
|
@ -118,7 +127,9 @@ async def search_knowledge_bases(
|
||||||
**knowledge_base.model_dump(),
|
**knowledge_base.model_dump(),
|
||||||
write_access=(
|
write_access=(
|
||||||
user.id == knowledge_base.user_id
|
user.id == knowledge_base.user_id
|
||||||
or has_access(user.id, "write", knowledge_base.access_control)
|
or has_access(
|
||||||
|
user.id, "write", knowledge_base.access_control, db=db
|
||||||
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
for knowledge_base in result.items
|
for knowledge_base in result.items
|
||||||
|
|
@ -132,6 +143,7 @@ async def search_knowledge_files(
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
page: Optional[int] = 1,
|
page: Optional[int] = 1,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
page = max(page, 1)
|
page = max(page, 1)
|
||||||
limit = PAGE_ITEM_COUNT
|
limit = PAGE_ITEM_COUNT
|
||||||
|
|
@ -141,13 +153,15 @@ async def search_knowledge_files(
|
||||||
if query:
|
if query:
|
||||||
filter["query"] = query
|
filter["query"] = query
|
||||||
|
|
||||||
groups = Groups.get_groups_by_member_id(user.id)
|
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||||
if groups:
|
if groups:
|
||||||
filter["group_ids"] = [group.id for group in groups]
|
filter["group_ids"] = [group.id for group in groups]
|
||||||
|
|
||||||
filter["user_id"] = user.id
|
filter["user_id"] = user.id
|
||||||
|
|
||||||
return Knowledges.search_knowledge_files(filter=filter, skip=skip, limit=limit)
|
return Knowledges.search_knowledge_files(
|
||||||
|
filter=filter, skip=skip, limit=limit, db=db
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -157,10 +171,13 @@ async def search_knowledge_files(
|
||||||
|
|
||||||
@router.post("/create", response_model=Optional[KnowledgeResponse])
|
@router.post("/create", response_model=Optional[KnowledgeResponse])
|
||||||
async def create_new_knowledge(
|
async def create_new_knowledge(
|
||||||
request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user)
|
request: Request,
|
||||||
|
form_data: KnowledgeForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if user.role != "admin" and not has_permission(
|
if user.role != "admin" and not has_permission(
|
||||||
user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
|
user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -175,11 +192,12 @@ async def create_new_knowledge(
|
||||||
user.id,
|
user.id,
|
||||||
"sharing.public_knowledge",
|
"sharing.public_knowledge",
|
||||||
request.app.state.config.USER_PERMISSIONS,
|
request.app.state.config.USER_PERMISSIONS,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
form_data.access_control = {}
|
form_data.access_control = {}
|
||||||
|
|
||||||
knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
|
knowledge = Knowledges.insert_new_knowledge(user.id, form_data, db=db)
|
||||||
|
|
||||||
if knowledge:
|
if knowledge:
|
||||||
return knowledge
|
return knowledge
|
||||||
|
|
@ -196,20 +214,24 @@ async def create_new_knowledge(
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reindex", response_model=bool)
|
@router.post("/reindex", response_model=bool)
|
||||||
async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
|
async def reindex_knowledge_files(
|
||||||
|
request: Request,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
if user.role != "admin":
|
if user.role != "admin":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
knowledge_bases = Knowledges.get_knowledge_bases()
|
knowledge_bases = Knowledges.get_knowledge_bases(db=db)
|
||||||
|
|
||||||
log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
|
log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
|
||||||
|
|
||||||
for knowledge_base in knowledge_bases:
|
for knowledge_base in knowledge_bases:
|
||||||
try:
|
try:
|
||||||
files = Knowledges.get_files_by_id(knowledge_base.id)
|
files = Knowledges.get_files_by_id(knowledge_base.id, db=db)
|
||||||
try:
|
try:
|
||||||
if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
|
if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
|
||||||
VECTOR_DB_CLIENT.delete_collection(
|
VECTOR_DB_CLIENT.delete_collection(
|
||||||
|
|
@ -229,6 +251,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
|
||||||
file_id=file.id, collection_name=knowledge_base.id
|
file_id=file.id, collection_name=knowledge_base.id
|
||||||
),
|
),
|
||||||
user=user,
|
user=user,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(
|
log.error(
|
||||||
|
|
@ -264,21 +287,23 @@ class KnowledgeFilesResponse(KnowledgeResponse):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
|
@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
|
||||||
async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_knowledge_by_id(
|
||||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||||
|
|
||||||
if knowledge:
|
if knowledge:
|
||||||
if (
|
if (
|
||||||
user.role == "admin"
|
user.role == "admin"
|
||||||
or knowledge.user_id == user.id
|
or knowledge.user_id == user.id
|
||||||
or has_access(user.id, "read", knowledge.access_control)
|
or has_access(user.id, "read", knowledge.access_control, db=db)
|
||||||
):
|
):
|
||||||
|
|
||||||
return KnowledgeFilesResponse(
|
return KnowledgeFilesResponse(
|
||||||
**knowledge.model_dump(),
|
**knowledge.model_dump(),
|
||||||
write_access=(
|
write_access=(
|
||||||
user.id == knowledge.user_id
|
user.id == knowledge.user_id
|
||||||
or has_access(user.id, "write", knowledge.access_control)
|
or has_access(user.id, "write", knowledge.access_control, db=db)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
@ -299,8 +324,9 @@ async def update_knowledge_by_id(
|
||||||
id: str,
|
id: str,
|
||||||
form_data: KnowledgeForm,
|
form_data: KnowledgeForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||||
if not knowledge:
|
if not knowledge:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -309,7 +335,7 @@ async def update_knowledge_by_id(
|
||||||
# Is the user the original creator, in a group with write access, or an admin
|
# Is the user the original creator, in a group with write access, or an admin
|
||||||
if (
|
if (
|
||||||
knowledge.user_id != user.id
|
knowledge.user_id != user.id
|
||||||
and not has_access(user.id, "write", knowledge.access_control)
|
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -325,15 +351,16 @@ async def update_knowledge_by_id(
|
||||||
user.id,
|
user.id,
|
||||||
"sharing.public_knowledge",
|
"sharing.public_knowledge",
|
||||||
request.app.state.config.USER_PERMISSIONS,
|
request.app.state.config.USER_PERMISSIONS,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
form_data.access_control = {}
|
form_data.access_control = {}
|
||||||
|
|
||||||
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
|
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data, db=db)
|
||||||
if knowledge:
|
if knowledge:
|
||||||
return KnowledgeFilesResponse(
|
return KnowledgeFilesResponse(
|
||||||
**knowledge.model_dump(),
|
**knowledge.model_dump(),
|
||||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -356,9 +383,10 @@ async def get_knowledge_files_by_id(
|
||||||
direction: Optional[str] = None,
|
direction: Optional[str] = None,
|
||||||
page: Optional[int] = 1,
|
page: Optional[int] = 1,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
|
|
||||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||||
if not knowledge:
|
if not knowledge:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -368,7 +396,7 @@ async def get_knowledge_files_by_id(
|
||||||
if not (
|
if not (
|
||||||
user.role == "admin"
|
user.role == "admin"
|
||||||
or knowledge.user_id == user.id
|
or knowledge.user_id == user.id
|
||||||
or has_access(user.id, "read", knowledge.access_control)
|
or has_access(user.id, "read", knowledge.access_control, db=db)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -391,7 +419,7 @@ async def get_knowledge_files_by_id(
|
||||||
filter["direction"] = direction
|
filter["direction"] = direction
|
||||||
|
|
||||||
return Knowledges.search_files_by_id(
|
return Knowledges.search_files_by_id(
|
||||||
id, user.id, filter=filter, skip=skip, limit=limit
|
id, user.id, filter=filter, skip=skip, limit=limit, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -410,8 +438,9 @@ def add_file_to_knowledge_by_id(
|
||||||
id: str,
|
id: str,
|
||||||
form_data: KnowledgeFileIdForm,
|
form_data: KnowledgeFileIdForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||||
if not knowledge:
|
if not knowledge:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -420,7 +449,7 @@ def add_file_to_knowledge_by_id(
|
||||||
|
|
||||||
if (
|
if (
|
||||||
knowledge.user_id != user.id
|
knowledge.user_id != user.id
|
||||||
and not has_access(user.id, "write", knowledge.access_control)
|
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -428,7 +457,7 @@ def add_file_to_knowledge_by_id(
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
file = Files.get_file_by_id(form_data.file_id)
|
file = Files.get_file_by_id(form_data.file_id, db=db)
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -446,11 +475,12 @@ def add_file_to_knowledge_by_id(
|
||||||
request,
|
request,
|
||||||
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
|
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
|
||||||
user=user,
|
user=user,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add file to knowledge base
|
# Add file to knowledge base
|
||||||
Knowledges.add_file_to_knowledge_by_id(
|
Knowledges.add_file_to_knowledge_by_id(
|
||||||
knowledge_id=id, file_id=form_data.file_id, user_id=user.id
|
knowledge_id=id, file_id=form_data.file_id, user_id=user.id, db=db
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(e)
|
log.debug(e)
|
||||||
|
|
@ -462,7 +492,7 @@ def add_file_to_knowledge_by_id(
|
||||||
if knowledge:
|
if knowledge:
|
||||||
return KnowledgeFilesResponse(
|
return KnowledgeFilesResponse(
|
||||||
**knowledge.model_dump(),
|
**knowledge.model_dump(),
|
||||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -477,8 +507,9 @@ def update_file_from_knowledge_by_id(
|
||||||
id: str,
|
id: str,
|
||||||
form_data: KnowledgeFileIdForm,
|
form_data: KnowledgeFileIdForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||||
if not knowledge:
|
if not knowledge:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -487,7 +518,7 @@ def update_file_from_knowledge_by_id(
|
||||||
|
|
||||||
if (
|
if (
|
||||||
knowledge.user_id != user.id
|
knowledge.user_id != user.id
|
||||||
and not has_access(user.id, "write", knowledge.access_control)
|
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
@ -496,7 +527,7 @@ def update_file_from_knowledge_by_id(
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
file = Files.get_file_by_id(form_data.file_id)
|
file = Files.get_file_by_id(form_data.file_id, db=db)
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -514,6 +545,7 @@ def update_file_from_knowledge_by_id(
|
||||||
request,
|
request,
|
||||||
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
|
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
|
||||||
user=user,
|
user=user,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -524,7 +556,7 @@ def update_file_from_knowledge_by_id(
|
||||||
if knowledge:
|
if knowledge:
|
||||||
return KnowledgeFilesResponse(
|
return KnowledgeFilesResponse(
|
||||||
**knowledge.model_dump(),
|
**knowledge.model_dump(),
|
||||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -544,8 +576,9 @@ def remove_file_from_knowledge_by_id(
|
||||||
form_data: KnowledgeFileIdForm,
|
form_data: KnowledgeFileIdForm,
|
||||||
delete_file: bool = Query(True),
|
delete_file: bool = Query(True),
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||||
if not knowledge:
|
if not knowledge:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -554,7 +587,7 @@ def remove_file_from_knowledge_by_id(
|
||||||
|
|
||||||
if (
|
if (
|
||||||
knowledge.user_id != user.id
|
knowledge.user_id != user.id
|
||||||
and not has_access(user.id, "write", knowledge.access_control)
|
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -562,7 +595,7 @@ def remove_file_from_knowledge_by_id(
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
file = Files.get_file_by_id(form_data.file_id)
|
file = Files.get_file_by_id(form_data.file_id, db=db)
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -570,7 +603,7 @@ def remove_file_from_knowledge_by_id(
|
||||||
)
|
)
|
||||||
|
|
||||||
Knowledges.remove_file_from_knowledge_by_id(
|
Knowledges.remove_file_from_knowledge_by_id(
|
||||||
knowledge_id=id, file_id=form_data.file_id
|
knowledge_id=id, file_id=form_data.file_id, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove content from the vector database
|
# Remove content from the vector database
|
||||||
|
|
@ -599,12 +632,12 @@ def remove_file_from_knowledge_by_id(
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Delete file from database
|
# Delete file from database
|
||||||
Files.delete_file_by_id(form_data.file_id)
|
Files.delete_file_by_id(form_data.file_id, db=db)
|
||||||
|
|
||||||
if knowledge:
|
if knowledge:
|
||||||
return KnowledgeFilesResponse(
|
return KnowledgeFilesResponse(
|
||||||
**knowledge.model_dump(),
|
**knowledge.model_dump(),
|
||||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -619,8 +652,10 @@ def remove_file_from_knowledge_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{id}/delete", response_model=bool)
|
@router.delete("/{id}/delete", response_model=bool)
|
||||||
async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
async def delete_knowledge_by_id(
|
||||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||||
if not knowledge:
|
if not knowledge:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -629,7 +664,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
if (
|
if (
|
||||||
knowledge.user_id != user.id
|
knowledge.user_id != user.id
|
||||||
and not has_access(user.id, "write", knowledge.access_control)
|
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -640,7 +675,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})")
|
log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})")
|
||||||
|
|
||||||
# Get all models
|
# Get all models
|
||||||
models = Models.get_all_models()
|
models = Models.get_all_models(db=db)
|
||||||
log.info(f"Found {len(models)} models to check for knowledge base {id}")
|
log.info(f"Found {len(models)} models to check for knowledge base {id}")
|
||||||
|
|
||||||
# Update models that reference this knowledge base
|
# Update models that reference this knowledge base
|
||||||
|
|
@ -664,7 +699,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
access_control=model.access_control,
|
access_control=model.access_control,
|
||||||
is_active=model.is_active,
|
is_active=model.is_active,
|
||||||
)
|
)
|
||||||
Models.update_model_by_id(model.id, model_form)
|
Models.update_model_by_id(model.id, model_form, db=db)
|
||||||
|
|
||||||
# Clean up vector DB
|
# Clean up vector DB
|
||||||
try:
|
try:
|
||||||
|
|
@ -672,7 +707,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(e)
|
log.debug(e)
|
||||||
pass
|
pass
|
||||||
result = Knowledges.delete_knowledge_by_id(id=id)
|
result = Knowledges.delete_knowledge_by_id(id=id, db=db)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -682,8 +717,10 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
|
@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
|
||||||
async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
async def reset_knowledge_by_id(
|
||||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||||
if not knowledge:
|
if not knowledge:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -692,7 +729,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
if (
|
if (
|
||||||
knowledge.user_id != user.id
|
knowledge.user_id != user.id
|
||||||
and not has_access(user.id, "write", knowledge.access_control)
|
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -706,7 +743,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
log.debug(e)
|
log.debug(e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
knowledge = Knowledges.reset_knowledge_by_id(id=id)
|
knowledge = Knowledges.reset_knowledge_by_id(id=id, db=db)
|
||||||
return knowledge
|
return knowledge
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -721,11 +758,12 @@ async def add_files_to_knowledge_batch(
|
||||||
id: str,
|
id: str,
|
||||||
form_data: list[KnowledgeFileIdForm],
|
form_data: list[KnowledgeFileIdForm],
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Add multiple files to a knowledge base
|
Add multiple files to a knowledge base
|
||||||
"""
|
"""
|
||||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
knowledge = Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||||
if not knowledge:
|
if not knowledge:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -734,7 +772,7 @@ async def add_files_to_knowledge_batch(
|
||||||
|
|
||||||
if (
|
if (
|
||||||
knowledge.user_id != user.id
|
knowledge.user_id != user.id
|
||||||
and not has_access(user.id, "write", knowledge.access_control)
|
and not has_access(user.id, "write", knowledge.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -746,7 +784,7 @@ async def add_files_to_knowledge_batch(
|
||||||
log.info(f"files/batch/add - {len(form_data)} files")
|
log.info(f"files/batch/add - {len(form_data)} files")
|
||||||
files: List[FileModel] = []
|
files: List[FileModel] = []
|
||||||
for form in form_data:
|
for form in form_data:
|
||||||
file = Files.get_file_by_id(form.file_id)
|
file = Files.get_file_by_id(form.file_id, db=db)
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -760,6 +798,7 @@ async def add_files_to_knowledge_batch(
|
||||||
request=request,
|
request=request,
|
||||||
form_data=BatchProcessFilesForm(files=files, collection_name=id),
|
form_data=BatchProcessFilesForm(files=files, collection_name=id),
|
||||||
user=user,
|
user=user,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(
|
log.error(
|
||||||
|
|
@ -771,7 +810,7 @@ async def add_files_to_knowledge_batch(
|
||||||
successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
|
successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
|
||||||
for file_id in successful_file_ids:
|
for file_id in successful_file_ids:
|
||||||
Knowledges.add_file_to_knowledge_by_id(
|
Knowledges.add_file_to_knowledge_by_id(
|
||||||
knowledge_id=id, file_id=file_id, user_id=user.id
|
knowledge_id=id, file_id=file_id, user_id=user.id, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
# If there were any errors, include them in the response
|
# If there were any errors, include them in the response
|
||||||
|
|
@ -779,7 +818,7 @@ async def add_files_to_knowledge_batch(
|
||||||
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
|
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
|
||||||
return KnowledgeFilesResponse(
|
return KnowledgeFilesResponse(
|
||||||
**knowledge.model_dump(),
|
**knowledge.model_dump(),
|
||||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||||
warnings={
|
warnings={
|
||||||
"message": "Some files failed to process",
|
"message": "Some files failed to process",
|
||||||
"errors": error_details,
|
"errors": error_details,
|
||||||
|
|
@ -788,5 +827,5 @@ async def add_files_to_knowledge_batch(
|
||||||
|
|
||||||
return KnowledgeFilesResponse(
|
return KnowledgeFilesResponse(
|
||||||
**knowledge.model_dump(),
|
**knowledge.model_dump(),
|
||||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,10 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||||
from open_webui.utils.auth import get_verified_user
|
from open_webui.utils.auth import get_verified_user
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.utils.auth import get_verified_user
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
@ -25,8 +29,8 @@ async def get_embeddings(request: Request):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=list[MemoryModel])
|
@router.get("/", response_model=list[MemoryModel])
|
||||||
async def get_memories(user=Depends(get_verified_user)):
|
async def get_memories(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
return Memories.get_memories_by_user_id(user.id)
|
return Memories.get_memories_by_user_id(user.id, db=db)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -47,8 +51,9 @@ async def add_memory(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: AddMemoryForm,
|
form_data: AddMemoryForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
memory = Memories.insert_new_memory(user.id, form_data.content)
|
memory = Memories.insert_new_memory(user.id, form_data.content, db=db)
|
||||||
|
|
||||||
vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
|
vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
|
||||||
|
|
||||||
|
|
@ -79,9 +84,9 @@ class QueryMemoryForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/query")
|
@router.post("/query")
|
||||||
async def query_memory(
|
async def query_memory(
|
||||||
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
|
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
memories = Memories.get_memories_by_user_id(user.id)
|
memories = Memories.get_memories_by_user_id(user.id, db=db)
|
||||||
if not memories:
|
if not memories:
|
||||||
raise HTTPException(status_code=404, detail="No memories found for user")
|
raise HTTPException(status_code=404, detail="No memories found for user")
|
||||||
|
|
||||||
|
|
@ -101,11 +106,11 @@ async def query_memory(
|
||||||
############################
|
############################
|
||||||
@router.post("/reset", response_model=bool)
|
@router.post("/reset", response_model=bool)
|
||||||
async def reset_memory_from_vector_db(
|
async def reset_memory_from_vector_db(
|
||||||
request: Request, user=Depends(get_verified_user)
|
request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
|
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
|
||||||
|
|
||||||
memories = Memories.get_memories_by_user_id(user.id)
|
memories = Memories.get_memories_by_user_id(user.id, db=db)
|
||||||
|
|
||||||
# Generate vectors in parallel
|
# Generate vectors in parallel
|
||||||
vectors = await asyncio.gather(
|
vectors = await asyncio.gather(
|
||||||
|
|
@ -140,8 +145,8 @@ async def reset_memory_from_vector_db(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/delete/user", response_model=bool)
|
@router.delete("/delete/user", response_model=bool)
|
||||||
async def delete_memory_by_user_id(user=Depends(get_verified_user)):
|
async def delete_memory_by_user_id(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
result = Memories.delete_memories_by_user_id(user.id)
|
result = Memories.delete_memories_by_user_id(user.id, db=db)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
try:
|
try:
|
||||||
|
|
@ -164,9 +169,10 @@ async def update_memory_by_id(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: MemoryUpdateModel,
|
form_data: MemoryUpdateModel,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
memory = Memories.update_memory_by_id_and_user_id(
|
memory = Memories.update_memory_by_id_and_user_id(
|
||||||
memory_id, user.id, form_data.content
|
memory_id, user.id, form_data.content, db=db
|
||||||
)
|
)
|
||||||
if memory is None:
|
if memory is None:
|
||||||
raise HTTPException(status_code=404, detail="Memory not found")
|
raise HTTPException(status_code=404, detail="Memory not found")
|
||||||
|
|
@ -198,8 +204,8 @@ async def update_memory_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{memory_id}", response_model=bool)
|
@router.delete("/{memory_id}", response_model=bool)
|
||||||
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
|
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
|
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id, db=db)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
VECTOR_DB_CLIENT.delete(
|
VECTOR_DB_CLIENT.delete(
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,8 @@ from fastapi.responses import FileResponse, StreamingResponse
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.access_control import has_access, has_permission
|
from open_webui.utils.access_control import has_access, has_permission
|
||||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
|
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -59,6 +61,7 @@ async def get_models(
|
||||||
direction: Optional[str] = None,
|
direction: Optional[str] = None,
|
||||||
page: Optional[int] = 1,
|
page: Optional[int] = 1,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
|
|
||||||
limit = PAGE_ITEM_COUNT
|
limit = PAGE_ITEM_COUNT
|
||||||
|
|
@ -79,13 +82,13 @@ async def get_models(
|
||||||
filter["direction"] = direction
|
filter["direction"] = direction
|
||||||
|
|
||||||
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
||||||
groups = Groups.get_groups_by_member_id(user.id)
|
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||||
if groups:
|
if groups:
|
||||||
filter["group_ids"] = [group.id for group in groups]
|
filter["group_ids"] = [group.id for group in groups]
|
||||||
|
|
||||||
filter["user_id"] = user.id
|
filter["user_id"] = user.id
|
||||||
|
|
||||||
return Models.search_models(user.id, filter=filter, skip=skip, limit=limit)
|
return Models.search_models(user.id, filter=filter, skip=skip, limit=limit, db=db)
|
||||||
|
|
||||||
|
|
||||||
###########################
|
###########################
|
||||||
|
|
@ -94,8 +97,8 @@ async def get_models(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/base", response_model=list[ModelResponse])
|
@router.get("/base", response_model=list[ModelResponse])
|
||||||
async def get_base_models(user=Depends(get_admin_user)):
|
async def get_base_models(user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||||
return Models.get_base_models()
|
return Models.get_base_models(db=db)
|
||||||
|
|
||||||
|
|
||||||
###########################
|
###########################
|
||||||
|
|
@ -104,11 +107,11 @@ async def get_base_models(user=Depends(get_admin_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tags", response_model=list[str])
|
@router.get("/tags", response_model=list[str])
|
||||||
async def get_model_tags(user=Depends(get_verified_user)):
|
async def get_model_tags(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||||
models = Models.get_models()
|
models = Models.get_models(db=db)
|
||||||
else:
|
else:
|
||||||
models = Models.get_models_by_user_id(user.id)
|
models = Models.get_models_by_user_id(user.id, db=db)
|
||||||
|
|
||||||
tags_set = set()
|
tags_set = set()
|
||||||
for model in models:
|
for model in models:
|
||||||
|
|
@ -132,16 +135,17 @@ async def create_new_model(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: ModelForm,
|
form_data: ModelForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if user.role != "admin" and not has_permission(
|
if user.role != "admin" and not has_permission(
|
||||||
user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS
|
user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = Models.get_model_by_id(form_data.id)
|
model = Models.get_model_by_id(form_data.id, db=db)
|
||||||
if model:
|
if model:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -155,7 +159,7 @@ async def create_new_model(
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model = Models.insert_new_model(form_data, user.id)
|
model = Models.insert_new_model(form_data, user.id, db=db)
|
||||||
if model:
|
if model:
|
||||||
return model
|
return model
|
||||||
else:
|
else:
|
||||||
|
|
@ -171,9 +175,9 @@ async def create_new_model(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/export", response_model=list[ModelModel])
|
@router.get("/export", response_model=list[ModelModel])
|
||||||
async def export_models(request: Request, user=Depends(get_verified_user)):
|
async def export_models(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
if user.role != "admin" and not has_permission(
|
if user.role != "admin" and not has_permission(
|
||||||
user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS
|
user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -181,9 +185,9 @@ async def export_models(request: Request, user=Depends(get_verified_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||||
return Models.get_models()
|
return Models.get_models(db=db)
|
||||||
else:
|
else:
|
||||||
return Models.get_models_by_user_id(user.id)
|
return Models.get_models_by_user_id(user.id, db=db)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -200,9 +204,10 @@ async def import_models(
|
||||||
request: Request,
|
request: Request,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
form_data: ModelsImportForm = (...),
|
form_data: ModelsImportForm = (...),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if user.role != "admin" and not has_permission(
|
if user.role != "admin" and not has_permission(
|
||||||
user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS
|
user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -216,7 +221,7 @@ async def import_models(
|
||||||
model_id = model_data.get("id")
|
model_id = model_data.get("id")
|
||||||
|
|
||||||
if model_id and is_valid_model_id(model_id):
|
if model_id and is_valid_model_id(model_id):
|
||||||
existing_model = Models.get_model_by_id(model_id)
|
existing_model = Models.get_model_by_id(model_id, db=db)
|
||||||
if existing_model:
|
if existing_model:
|
||||||
# Update existing model
|
# Update existing model
|
||||||
model_data["meta"] = model_data.get("meta", {})
|
model_data["meta"] = model_data.get("meta", {})
|
||||||
|
|
@ -225,13 +230,13 @@ async def import_models(
|
||||||
updated_model = ModelForm(
|
updated_model = ModelForm(
|
||||||
**{**existing_model.model_dump(), **model_data}
|
**{**existing_model.model_dump(), **model_data}
|
||||||
)
|
)
|
||||||
Models.update_model_by_id(model_id, updated_model)
|
Models.update_model_by_id(model_id, updated_model, db=db)
|
||||||
else:
|
else:
|
||||||
# Insert new model
|
# Insert new model
|
||||||
model_data["meta"] = model_data.get("meta", {})
|
model_data["meta"] = model_data.get("meta", {})
|
||||||
model_data["params"] = model_data.get("params", {})
|
model_data["params"] = model_data.get("params", {})
|
||||||
new_model = ModelForm(**model_data)
|
new_model = ModelForm(**model_data)
|
||||||
Models.insert_new_model(user_id=user.id, form_data=new_model)
|
Models.insert_new_model(user_id=user.id, form_data=new_model, db=db)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=400, detail="Invalid JSON format")
|
raise HTTPException(status_code=400, detail="Invalid JSON format")
|
||||||
|
|
@ -251,9 +256,9 @@ class SyncModelsForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/sync", response_model=list[ModelModel])
|
@router.post("/sync", response_model=list[ModelModel])
|
||||||
async def sync_models(
|
async def sync_models(
|
||||||
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
|
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
return Models.sync_models(user.id, form_data.models)
|
return Models.sync_models(user.id, form_data.models, db=db)
|
||||||
|
|
||||||
|
|
||||||
###########################
|
###########################
|
||||||
|
|
@ -267,13 +272,13 @@ class ModelIdForm(BaseModel):
|
||||||
|
|
||||||
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
|
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
|
||||||
@router.get("/model", response_model=Optional[ModelResponse])
|
@router.get("/model", response_model=Optional[ModelResponse])
|
||||||
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
model = Models.get_model_by_id(id)
|
model = Models.get_model_by_id(id, db=db)
|
||||||
if model:
|
if model:
|
||||||
if (
|
if (
|
||||||
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||||
or model.user_id == user.id
|
or model.user_id == user.id
|
||||||
or has_access(user.id, "read", model.access_control)
|
or has_access(user.id, "read", model.access_control, db=db)
|
||||||
):
|
):
|
||||||
return model
|
return model
|
||||||
else:
|
else:
|
||||||
|
|
@ -289,8 +294,8 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/model/profile/image")
|
@router.get("/model/profile/image")
|
||||||
async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
|
async def get_model_profile_image(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
model = Models.get_model_by_id(id)
|
model = Models.get_model_by_id(id, db=db)
|
||||||
# Cache-control headers to prevent stale cached images
|
# Cache-control headers to prevent stale cached images
|
||||||
cache_headers = {"Cache-Control": "no-cache, must-revalidate"}
|
cache_headers = {"Cache-Control": "no-cache, must-revalidate"}
|
||||||
|
|
||||||
|
|
@ -330,15 +335,15 @@ async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/model/toggle", response_model=Optional[ModelResponse])
|
@router.post("/model/toggle", response_model=Optional[ModelResponse])
|
||||||
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
|
async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
model = Models.get_model_by_id(id)
|
model = Models.get_model_by_id(id, db=db)
|
||||||
if model:
|
if model:
|
||||||
if (
|
if (
|
||||||
user.role == "admin"
|
user.role == "admin"
|
||||||
or model.user_id == user.id
|
or model.user_id == user.id
|
||||||
or has_access(user.id, "write", model.access_control)
|
or has_access(user.id, "write", model.access_control, db=db)
|
||||||
):
|
):
|
||||||
model = Models.toggle_model_by_id(id)
|
model = Models.toggle_model_by_id(id, db=db)
|
||||||
|
|
||||||
if model:
|
if model:
|
||||||
return model
|
return model
|
||||||
|
|
@ -368,8 +373,9 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
async def update_model_by_id(
|
async def update_model_by_id(
|
||||||
form_data: ModelForm,
|
form_data: ModelForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
model = Models.get_model_by_id(form_data.id)
|
model = Models.get_model_by_id(form_data.id, db=db)
|
||||||
if not model:
|
if not model:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -378,7 +384,7 @@ async def update_model_by_id(
|
||||||
|
|
||||||
if (
|
if (
|
||||||
model.user_id != user.id
|
model.user_id != user.id
|
||||||
and not has_access(user.id, "write", model.access_control)
|
and not has_access(user.id, "write", model.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -386,7 +392,7 @@ async def update_model_by_id(
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()))
|
model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()), db=db)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -396,8 +402,8 @@ async def update_model_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.post("/model/delete", response_model=bool)
|
@router.post("/model/delete", response_model=bool)
|
||||||
async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user)):
|
async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
model = Models.get_model_by_id(form_data.id)
|
model = Models.get_model_by_id(form_data.id, db=db)
|
||||||
if not model:
|
if not model:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -407,18 +413,18 @@ async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_u
|
||||||
if (
|
if (
|
||||||
user.role != "admin"
|
user.role != "admin"
|
||||||
and model.user_id != user.id
|
and model.user_id != user.id
|
||||||
and not has_access(user.id, "write", model.access_control)
|
and not has_access(user.id, "write", model.access_control, db=db)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = Models.delete_model_by_id(form_data.id)
|
result = Models.delete_model_by_id(form_data.id, db=db)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/delete/all", response_model=bool)
|
@router.delete("/delete/all", response_model=bool)
|
||||||
async def delete_all_models(user=Depends(get_admin_user)):
|
async def delete_all_models(user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||||
result = Models.delete_all_models()
|
result = Models.delete_all_models(db=db)
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,8 @@ from open_webui.constants import ERROR_MESSAGES
|
||||||
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.access_control import has_access, has_permission
|
from open_webui.utils.access_control import has_access, has_permission
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -49,10 +51,13 @@ class NoteItemResponse(BaseModel):
|
||||||
|
|
||||||
@router.get("/", response_model=list[NoteItemResponse])
|
@router.get("/", response_model=list[NoteItemResponse])
|
||||||
async def get_notes(
|
async def get_notes(
|
||||||
request: Request, page: Optional[int] = None, user=Depends(get_verified_user)
|
request: Request,
|
||||||
|
page: Optional[int] = None,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if user.role != "admin" and not has_permission(
|
if user.role != "admin" and not has_permission(
|
||||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -69,10 +74,14 @@ async def get_notes(
|
||||||
NoteUserResponse(
|
NoteUserResponse(
|
||||||
**{
|
**{
|
||||||
**note.model_dump(),
|
**note.model_dump(),
|
||||||
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
|
"user": UserResponse(
|
||||||
|
**Users.get_user_by_id(note.user_id, db=db).model_dump()
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
for note in Notes.get_notes_by_user_id(user.id, "read", skip=skip, limit=limit)
|
for note in Notes.get_notes_by_user_id(
|
||||||
|
user.id, "read", skip=skip, limit=limit, db=db
|
||||||
|
)
|
||||||
]
|
]
|
||||||
return notes
|
return notes
|
||||||
|
|
||||||
|
|
@ -87,9 +96,10 @@ async def search_notes(
|
||||||
direction: Optional[str] = None,
|
direction: Optional[str] = None,
|
||||||
page: Optional[int] = 1,
|
page: Optional[int] = 1,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if user.role != "admin" and not has_permission(
|
if user.role != "admin" and not has_permission(
|
||||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -115,13 +125,13 @@ async def search_notes(
|
||||||
filter["direction"] = direction
|
filter["direction"] = direction
|
||||||
|
|
||||||
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
||||||
groups = Groups.get_groups_by_member_id(user.id)
|
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||||
if groups:
|
if groups:
|
||||||
filter["group_ids"] = [group.id for group in groups]
|
filter["group_ids"] = [group.id for group in groups]
|
||||||
|
|
||||||
filter["user_id"] = user.id
|
filter["user_id"] = user.id
|
||||||
|
|
||||||
return Notes.search_notes(user.id, filter, skip=skip, limit=limit)
|
return Notes.search_notes(user.id, filter, skip=skip, limit=limit, db=db)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -131,10 +141,13 @@ async def search_notes(
|
||||||
|
|
||||||
@router.post("/create", response_model=Optional[NoteModel])
|
@router.post("/create", response_model=Optional[NoteModel])
|
||||||
async def create_new_note(
|
async def create_new_note(
|
||||||
request: Request, form_data: NoteForm, user=Depends(get_verified_user)
|
request: Request,
|
||||||
|
form_data: NoteForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if user.role != "admin" and not has_permission(
|
if user.role != "admin" and not has_permission(
|
||||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -142,7 +155,7 @@ async def create_new_note(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
note = Notes.insert_new_note(user.id, form_data)
|
note = Notes.insert_new_note(user.id, form_data, db=db)
|
||||||
return note
|
return note
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
@ -161,16 +174,21 @@ class NoteResponse(NoteModel):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{id}", response_model=Optional[NoteResponse])
|
@router.get("/{id}", response_model=Optional[NoteResponse])
|
||||||
async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
async def get_note_by_id(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
if user.role != "admin" and not has_permission(
|
if user.role != "admin" and not has_permission(
|
||||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
note = Notes.get_note_by_id(id)
|
note = Notes.get_note_by_id(id, db=db)
|
||||||
if not note:
|
if not note:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
|
|
@ -178,7 +196,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
|
||||||
|
|
||||||
if user.role != "admin" and (
|
if user.role != "admin" and (
|
||||||
user.id != note.user_id
|
user.id != note.user_id
|
||||||
and (not has_access(user.id, type="read", access_control=note.access_control))
|
and (
|
||||||
|
not has_access(
|
||||||
|
user.id, type="read", access_control=note.access_control, db=db
|
||||||
|
)
|
||||||
|
)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -188,7 +210,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
|
||||||
user.role == "admin"
|
user.role == "admin"
|
||||||
or (user.id == note.user_id)
|
or (user.id == note.user_id)
|
||||||
or has_access(
|
or has_access(
|
||||||
user.id, type="write", access_control=note.access_control, strict=False
|
user.id,
|
||||||
|
type="write",
|
||||||
|
access_control=note.access_control,
|
||||||
|
strict=False,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -202,17 +228,21 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
|
||||||
|
|
||||||
@router.post("/{id}/update", response_model=Optional[NoteModel])
|
@router.post("/{id}/update", response_model=Optional[NoteModel])
|
||||||
async def update_note_by_id(
|
async def update_note_by_id(
|
||||||
request: Request, id: str, form_data: NoteForm, user=Depends(get_verified_user)
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
form_data: NoteForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if user.role != "admin" and not has_permission(
|
if user.role != "admin" and not has_permission(
|
||||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
note = Notes.get_note_by_id(id)
|
note = Notes.get_note_by_id(id, db=db)
|
||||||
if not note:
|
if not note:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
|
|
@ -220,7 +250,9 @@ async def update_note_by_id(
|
||||||
|
|
||||||
if user.role != "admin" and (
|
if user.role != "admin" and (
|
||||||
user.id != note.user_id
|
user.id != note.user_id
|
||||||
and not has_access(user.id, type="write", access_control=note.access_control)
|
and not has_access(
|
||||||
|
user.id, type="write", access_control=note.access_control, db=db
|
||||||
|
)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -234,12 +266,13 @@ async def update_note_by_id(
|
||||||
user.id,
|
user.id,
|
||||||
"sharing.public_notes",
|
"sharing.public_notes",
|
||||||
request.app.state.config.USER_PERMISSIONS,
|
request.app.state.config.USER_PERMISSIONS,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
form_data.access_control = {}
|
form_data.access_control = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
note = Notes.update_note_by_id(id, form_data)
|
note = Notes.update_note_by_id(id, form_data, db=db)
|
||||||
await sio.emit(
|
await sio.emit(
|
||||||
"note-events",
|
"note-events",
|
||||||
note.model_dump(),
|
note.model_dump(),
|
||||||
|
|
@ -260,16 +293,21 @@ async def update_note_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{id}/delete", response_model=bool)
|
@router.delete("/{id}/delete", response_model=bool)
|
||||||
async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
async def delete_note_by_id(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
if user.role != "admin" and not has_permission(
|
if user.role != "admin" and not has_permission(
|
||||||
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
|
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
note = Notes.get_note_by_id(id)
|
note = Notes.get_note_by_id(id, db=db)
|
||||||
if not note:
|
if not note:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
|
|
@ -277,14 +315,16 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
|
||||||
|
|
||||||
if user.role != "admin" and (
|
if user.role != "admin" and (
|
||||||
user.id != note.user_id
|
user.id != note.user_id
|
||||||
and not has_access(user.id, type="write", access_control=note.access_control)
|
and not has_access(
|
||||||
|
user.id, type="write", access_control=note.access_control, db=db
|
||||||
|
)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
note = Notes.delete_note_by_id(id)
|
note = Notes.delete_note_by_id(id, db=db)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,8 @@ from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.access_control import has_access, has_permission
|
from open_webui.utils.access_control import has_access, has_permission
|
||||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
|
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -20,21 +22,21 @@ router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=list[PromptModel])
|
@router.get("/", response_model=list[PromptModel])
|
||||||
async def get_prompts(user=Depends(get_verified_user)):
|
async def get_prompts(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||||
prompts = Prompts.get_prompts()
|
prompts = Prompts.get_prompts(db=db)
|
||||||
else:
|
else:
|
||||||
prompts = Prompts.get_prompts_by_user_id(user.id, "read")
|
prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db)
|
||||||
|
|
||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
@router.get("/list", response_model=list[PromptUserResponse])
|
@router.get("/list", response_model=list[PromptUserResponse])
|
||||||
async def get_prompt_list(user=Depends(get_verified_user)):
|
async def get_prompt_list(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||||
prompts = Prompts.get_prompts()
|
prompts = Prompts.get_prompts(db=db)
|
||||||
else:
|
else:
|
||||||
prompts = Prompts.get_prompts_by_user_id(user.id, "write")
|
prompts = Prompts.get_prompts_by_user_id(user.id, "write", db=db)
|
||||||
|
|
||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
|
|
@ -46,16 +48,17 @@ async def get_prompt_list(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.post("/create", response_model=Optional[PromptModel])
|
@router.post("/create", response_model=Optional[PromptModel])
|
||||||
async def create_new_prompt(
|
async def create_new_prompt(
|
||||||
request: Request, form_data: PromptForm, user=Depends(get_verified_user)
|
request: Request, form_data: PromptForm, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
if user.role != "admin" and not (
|
if user.role != "admin" and not (
|
||||||
has_permission(
|
has_permission(
|
||||||
user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS
|
user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
)
|
)
|
||||||
or has_permission(
|
or has_permission(
|
||||||
user.id,
|
user.id,
|
||||||
"workspace.prompts_import",
|
"workspace.prompts_import",
|
||||||
request.app.state.config.USER_PERMISSIONS,
|
request.app.state.config.USER_PERMISSIONS,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -63,9 +66,9 @@ async def create_new_prompt(
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = Prompts.get_prompt_by_command(form_data.command)
|
prompt = Prompts.get_prompt_by_command(form_data.command, db=db)
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
prompt = Prompts.insert_new_prompt(user.id, form_data, db=db)
|
||||||
|
|
||||||
if prompt:
|
if prompt:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
@ -85,14 +88,14 @@ async def create_new_prompt(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/command/{command}", response_model=Optional[PromptModel])
|
@router.get("/command/{command}", response_model=Optional[PromptModel])
|
||||||
async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
async def get_prompt_by_command(command: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
|
||||||
|
|
||||||
if prompt:
|
if prompt:
|
||||||
if (
|
if (
|
||||||
user.role == "admin"
|
user.role == "admin"
|
||||||
or prompt.user_id == user.id
|
or prompt.user_id == user.id
|
||||||
or has_access(user.id, "read", prompt.access_control)
|
or has_access(user.id, "read", prompt.access_control, db=db)
|
||||||
):
|
):
|
||||||
return prompt
|
return prompt
|
||||||
else:
|
else:
|
||||||
|
|
@ -112,8 +115,9 @@ async def update_prompt_by_command(
|
||||||
command: str,
|
command: str,
|
||||||
form_data: PromptForm,
|
form_data: PromptForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
|
||||||
if not prompt:
|
if not prompt:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -123,7 +127,7 @@ async def update_prompt_by_command(
|
||||||
# Is the user the original creator, in a group with write access, or an admin
|
# Is the user the original creator, in a group with write access, or an admin
|
||||||
if (
|
if (
|
||||||
prompt.user_id != user.id
|
prompt.user_id != user.id
|
||||||
and not has_access(user.id, "write", prompt.access_control)
|
and not has_access(user.id, "write", prompt.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -131,7 +135,7 @@ async def update_prompt_by_command(
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
|
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data, db=db)
|
||||||
if prompt:
|
if prompt:
|
||||||
return prompt
|
return prompt
|
||||||
else:
|
else:
|
||||||
|
|
@ -147,8 +151,8 @@ async def update_prompt_by_command(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/command/{command}/delete", response_model=bool)
|
@router.delete("/command/{command}/delete", response_model=bool)
|
||||||
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
prompt = Prompts.get_prompt_by_command(f"/{command}", db=db)
|
||||||
if not prompt:
|
if not prompt:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -157,7 +161,7 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
prompt.user_id != user.id
|
prompt.user_id != user.id
|
||||||
and not has_access(user.id, "write", prompt.access_control)
|
and not has_access(user.id, "write", prompt.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -165,5 +169,5 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = Prompts.delete_prompt_by_command(f"/{command}")
|
result = Prompts.delete_prompt_by_command(f"/{command}", db=db)
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,8 @@ from langchain_core.documents import Document
|
||||||
from open_webui.models.files import FileModel, FileUpdateForm, Files
|
from open_webui.models.files import FileModel, FileUpdateForm, Files
|
||||||
from open_webui.models.knowledge import Knowledges
|
from open_webui.models.knowledge import Knowledges
|
||||||
from open_webui.storage.provider import Storage
|
from open_webui.storage.provider import Storage
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||||
|
|
@ -1484,14 +1486,15 @@ def process_file(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: ProcessFileForm,
|
form_data: ProcessFileForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Process a file and save its content to the vector database.
|
Process a file and save its content to the vector database.
|
||||||
"""
|
"""
|
||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
file = Files.get_file_by_id(form_data.file_id)
|
file = Files.get_file_by_id(form_data.file_id, db=db)
|
||||||
else:
|
else:
|
||||||
file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id)
|
file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id, db=db)
|
||||||
|
|
||||||
if file:
|
if file:
|
||||||
try:
|
try:
|
||||||
|
|
@ -1633,12 +1636,13 @@ def process_file(
|
||||||
Files.update_file_data_by_id(
|
Files.update_file_data_by_id(
|
||||||
file.id,
|
file.id,
|
||||||
{"content": text_content},
|
{"content": text_content},
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
hash = calculate_sha256_string(text_content)
|
hash = calculate_sha256_string(text_content)
|
||||||
Files.update_file_hash_by_id(file.id, hash)
|
Files.update_file_hash_by_id(file.id, hash, db=db)
|
||||||
|
|
||||||
if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
|
if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
|
||||||
Files.update_file_data_by_id(file.id, {"status": "completed"})
|
Files.update_file_data_by_id(file.id, {"status": "completed"}, db=db)
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"collection_name": None,
|
"collection_name": None,
|
||||||
|
|
@ -1667,11 +1671,13 @@ def process_file(
|
||||||
{
|
{
|
||||||
"collection_name": collection_name,
|
"collection_name": collection_name,
|
||||||
},
|
},
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
Files.update_file_data_by_id(
|
Files.update_file_data_by_id(
|
||||||
file.id,
|
file.id,
|
||||||
{"status": "completed"},
|
{"status": "completed"},
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -1690,6 +1696,7 @@ def process_file(
|
||||||
Files.update_file_data_by_id(
|
Files.update_file_data_by_id(
|
||||||
file.id,
|
file.id,
|
||||||
{"status": "failed"},
|
{"status": "failed"},
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "No pandoc was found" in str(e):
|
if "No pandoc was found" in str(e):
|
||||||
|
|
@ -2417,10 +2424,10 @@ class DeleteForm(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/delete")
|
@router.post("/delete")
|
||||||
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
|
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||||
try:
|
try:
|
||||||
if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
|
if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
|
||||||
file = Files.get_file_by_id(form_data.file_id)
|
file = Files.get_file_by_id(form_data.file_id, db=db)
|
||||||
hash = file.hash
|
hash = file.hash
|
||||||
|
|
||||||
VECTOR_DB_CLIENT.delete(
|
VECTOR_DB_CLIENT.delete(
|
||||||
|
|
@ -2436,9 +2443,9 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reset/db")
|
@router.post("/reset/db")
|
||||||
def reset_vector_db(user=Depends(get_admin_user)):
|
def reset_vector_db(user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
||||||
VECTOR_DB_CLIENT.reset()
|
VECTOR_DB_CLIENT.reset()
|
||||||
Knowledges.delete_all_knowledge()
|
Knowledges.delete_all_knowledge(db=db)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reset/uploads")
|
@router.post("/reset/uploads")
|
||||||
|
|
@ -2496,6 +2503,7 @@ async def process_files_batch(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: BatchProcessFilesForm,
|
form_data: BatchProcessFilesForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
) -> BatchProcessFilesResponse:
|
) -> BatchProcessFilesResponse:
|
||||||
"""
|
"""
|
||||||
Process a batch of files and save them to the vector database.
|
Process a batch of files and save them to the vector database.
|
||||||
|
|
@ -2558,7 +2566,7 @@ async def process_files_batch(
|
||||||
|
|
||||||
# Update all files with collection name
|
# Update all files with collection name
|
||||||
for file_update, file_result in zip(file_updates, file_results):
|
for file_update, file_result in zip(file_updates, file_results):
|
||||||
Files.update_file_by_id(id=file_result.file_id, form_data=file_update)
|
Files.update_file_by_id(id=file_result.file_id, form_data=file_update, db=db)
|
||||||
file_result.status = "completed"
|
file_result.status = "completed"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ import aiohttp
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.models.groups import Groups
|
||||||
from pydantic import BaseModel, HttpUrl
|
from pydantic import BaseModel, HttpUrl
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
|
||||||
|
|
||||||
from open_webui.models.oauth_sessions import OAuthSessions
|
from open_webui.models.oauth_sessions import OAuthSessions
|
||||||
|
|
@ -51,11 +53,11 @@ def get_tool_module(request, tool_id, load_from_db=True):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=list[ToolUserResponse])
|
@router.get("/", response_model=list[ToolUserResponse])
|
||||||
async def get_tools(request: Request, user=Depends(get_verified_user)):
|
async def get_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
# Local Tools
|
# Local Tools
|
||||||
for tool in Tools.get_tools():
|
for tool in Tools.get_tools(db=db):
|
||||||
tool_module = get_tool_module(request, tool.id)
|
tool_module = get_tool_module(request, tool.id)
|
||||||
tools.append(
|
tools.append(
|
||||||
ToolUserResponse(
|
ToolUserResponse(
|
||||||
|
|
@ -140,12 +142,12 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||||
# Admin can see all tools
|
# Admin can see all tools
|
||||||
return tools
|
return tools
|
||||||
else:
|
else:
|
||||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||||
tools = [
|
tools = [
|
||||||
tool
|
tool
|
||||||
for tool in tools
|
for tool in tools
|
||||||
if tool.user_id == user.id
|
if tool.user_id == user.id
|
||||||
or has_access(user.id, "read", tool.access_control, user_group_ids)
|
or has_access(user.id, "read", tool.access_control, user_group_ids, db=db)
|
||||||
]
|
]
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
@ -156,11 +158,11 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/list", response_model=list[ToolUserResponse])
|
@router.get("/list", response_model=list[ToolUserResponse])
|
||||||
async def get_tool_list(user=Depends(get_verified_user)):
|
async def get_tool_list(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||||
tools = Tools.get_tools()
|
tools = Tools.get_tools(db=db)
|
||||||
else:
|
else:
|
||||||
tools = Tools.get_tools_by_user_id(user.id, "write")
|
tools = Tools.get_tools_by_user_id(user.id, "write", db=db)
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -245,9 +247,9 @@ async def load_tool_from_url(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/export", response_model=list[ToolModel])
|
@router.get("/export", response_model=list[ToolModel])
|
||||||
async def export_tools(request: Request, user=Depends(get_verified_user)):
|
async def export_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
if user.role != "admin" and not has_permission(
|
if user.role != "admin" and not has_permission(
|
||||||
user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS
|
user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -255,9 +257,9 @@ async def export_tools(request: Request, user=Depends(get_verified_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||||
return Tools.get_tools()
|
return Tools.get_tools(db=db)
|
||||||
else:
|
else:
|
||||||
return Tools.get_tools_by_user_id(user.id, "read")
|
return Tools.get_tools_by_user_id(user.id, "read", db=db)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -270,13 +272,14 @@ async def create_new_tools(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: ToolForm,
|
form_data: ToolForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
if user.role != "admin" and not (
|
if user.role != "admin" and not (
|
||||||
has_permission(
|
has_permission(
|
||||||
user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
|
user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
)
|
)
|
||||||
or has_permission(
|
or has_permission(
|
||||||
user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS
|
user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -292,7 +295,7 @@ async def create_new_tools(
|
||||||
|
|
||||||
form_data.id = form_data.id.lower()
|
form_data.id = form_data.id.lower()
|
||||||
|
|
||||||
tools = Tools.get_tool_by_id(form_data.id)
|
tools = Tools.get_tool_by_id(form_data.id, db=db)
|
||||||
if tools is None:
|
if tools is None:
|
||||||
try:
|
try:
|
||||||
form_data.content = replace_imports(form_data.content)
|
form_data.content = replace_imports(form_data.content)
|
||||||
|
|
@ -305,7 +308,7 @@ async def create_new_tools(
|
||||||
TOOLS[form_data.id] = tool_module
|
TOOLS[form_data.id] = tool_module
|
||||||
|
|
||||||
specs = get_tool_specs(TOOLS[form_data.id])
|
specs = get_tool_specs(TOOLS[form_data.id])
|
||||||
tools = Tools.insert_new_tool(user.id, form_data, specs)
|
tools = Tools.insert_new_tool(user.id, form_data, specs, db=db)
|
||||||
|
|
||||||
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
|
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
|
||||||
tool_cache_dir.mkdir(parents=True, exist_ok=True)
|
tool_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -336,14 +339,14 @@ async def create_new_tools(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/id/{id}", response_model=Optional[ToolModel])
|
@router.get("/id/{id}", response_model=Optional[ToolModel])
|
||||||
async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
tools = Tools.get_tool_by_id(id)
|
tools = Tools.get_tool_by_id(id, db=db)
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
if (
|
if (
|
||||||
user.role == "admin"
|
user.role == "admin"
|
||||||
or tools.user_id == user.id
|
or tools.user_id == user.id
|
||||||
or has_access(user.id, "read", tools.access_control)
|
or has_access(user.id, "read", tools.access_control, db=db)
|
||||||
):
|
):
|
||||||
return tools
|
return tools
|
||||||
else:
|
else:
|
||||||
|
|
@ -364,8 +367,9 @@ async def update_tools_by_id(
|
||||||
id: str,
|
id: str,
|
||||||
form_data: ToolForm,
|
form_data: ToolForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
tools = Tools.get_tool_by_id(id)
|
tools = Tools.get_tool_by_id(id, db=db)
|
||||||
if not tools:
|
if not tools:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -375,7 +379,7 @@ async def update_tools_by_id(
|
||||||
# Is the user the original creator, in a group with write access, or an admin
|
# Is the user the original creator, in a group with write access, or an admin
|
||||||
if (
|
if (
|
||||||
tools.user_id != user.id
|
tools.user_id != user.id
|
||||||
and not has_access(user.id, "write", tools.access_control)
|
and not has_access(user.id, "write", tools.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -399,7 +403,7 @@ async def update_tools_by_id(
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug(updated)
|
log.debug(updated)
|
||||||
tools = Tools.update_tool_by_id(id, updated)
|
tools = Tools.update_tool_by_id(id, updated, db=db)
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
return tools
|
return tools
|
||||||
|
|
@ -423,9 +427,9 @@ async def update_tools_by_id(
|
||||||
|
|
||||||
@router.delete("/id/{id}/delete", response_model=bool)
|
@router.delete("/id/{id}/delete", response_model=bool)
|
||||||
async def delete_tools_by_id(
|
async def delete_tools_by_id(
|
||||||
request: Request, id: str, user=Depends(get_verified_user)
|
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
tools = Tools.get_tool_by_id(id)
|
tools = Tools.get_tool_by_id(id, db=db)
|
||||||
if not tools:
|
if not tools:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -434,7 +438,7 @@ async def delete_tools_by_id(
|
||||||
|
|
||||||
if (
|
if (
|
||||||
tools.user_id != user.id
|
tools.user_id != user.id
|
||||||
and not has_access(user.id, "write", tools.access_control)
|
and not has_access(user.id, "write", tools.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -442,7 +446,7 @@ async def delete_tools_by_id(
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = Tools.delete_tool_by_id(id)
|
result = Tools.delete_tool_by_id(id, db=db)
|
||||||
if result:
|
if result:
|
||||||
TOOLS = request.app.state.TOOLS
|
TOOLS = request.app.state.TOOLS
|
||||||
if id in TOOLS:
|
if id in TOOLS:
|
||||||
|
|
@ -457,11 +461,11 @@ async def delete_tools_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
||||||
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
tools = Tools.get_tool_by_id(id)
|
tools = Tools.get_tool_by_id(id, db=db)
|
||||||
if tools:
|
if tools:
|
||||||
try:
|
try:
|
||||||
valves = Tools.get_tool_valves_by_id(id)
|
valves = Tools.get_tool_valves_by_id(id, db=db)
|
||||||
return valves
|
return valves
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -482,9 +486,9 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
|
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
|
||||||
async def get_tools_valves_spec_by_id(
|
async def get_tools_valves_spec_by_id(
|
||||||
request: Request, id: str, user=Depends(get_verified_user)
|
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
tools = Tools.get_tool_by_id(id)
|
tools = Tools.get_tool_by_id(id, db=db)
|
||||||
if tools:
|
if tools:
|
||||||
if id in request.app.state.TOOLS:
|
if id in request.app.state.TOOLS:
|
||||||
tools_module = request.app.state.TOOLS[id]
|
tools_module = request.app.state.TOOLS[id]
|
||||||
|
|
@ -510,9 +514,9 @@ async def get_tools_valves_spec_by_id(
|
||||||
|
|
||||||
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
|
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
|
||||||
async def update_tools_valves_by_id(
|
async def update_tools_valves_by_id(
|
||||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
tools = Tools.get_tool_by_id(id)
|
tools = Tools.get_tool_by_id(id, db=db)
|
||||||
if not tools:
|
if not tools:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -521,7 +525,7 @@ async def update_tools_valves_by_id(
|
||||||
|
|
||||||
if (
|
if (
|
||||||
tools.user_id != user.id
|
tools.user_id != user.id
|
||||||
and not has_access(user.id, "write", tools.access_control)
|
and not has_access(user.id, "write", tools.access_control, db=db)
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -546,7 +550,7 @@ async def update_tools_valves_by_id(
|
||||||
form_data = {k: v for k, v in form_data.items() if v is not None}
|
form_data = {k: v for k, v in form_data.items() if v is not None}
|
||||||
valves = Valves(**form_data)
|
valves = Valves(**form_data)
|
||||||
valves_dict = valves.model_dump(exclude_unset=True)
|
valves_dict = valves.model_dump(exclude_unset=True)
|
||||||
Tools.update_tool_valves_by_id(id, valves_dict)
|
Tools.update_tool_valves_by_id(id, valves_dict, db=db)
|
||||||
return valves_dict
|
return valves_dict
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Failed to update tool valves by id {id}: {e}")
|
log.exception(f"Failed to update tool valves by id {id}: {e}")
|
||||||
|
|
@ -562,11 +566,11 @@ async def update_tools_valves_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
||||||
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
||||||
tools = Tools.get_tool_by_id(id)
|
tools = Tools.get_tool_by_id(id, db=db)
|
||||||
if tools:
|
if tools:
|
||||||
try:
|
try:
|
||||||
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
|
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id, db=db)
|
||||||
return user_valves
|
return user_valves
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -582,9 +586,9 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
|
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
|
||||||
async def get_tools_user_valves_spec_by_id(
|
async def get_tools_user_valves_spec_by_id(
|
||||||
request: Request, id: str, user=Depends(get_verified_user)
|
request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
tools = Tools.get_tool_by_id(id)
|
tools = Tools.get_tool_by_id(id, db=db)
|
||||||
if tools:
|
if tools:
|
||||||
if id in request.app.state.TOOLS:
|
if id in request.app.state.TOOLS:
|
||||||
tools_module = request.app.state.TOOLS[id]
|
tools_module = request.app.state.TOOLS[id]
|
||||||
|
|
@ -605,9 +609,9 @@ async def get_tools_user_valves_spec_by_id(
|
||||||
|
|
||||||
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
|
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
|
||||||
async def update_tools_user_valves_by_id(
|
async def update_tools_user_valves_by_id(
|
||||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
tools = Tools.get_tool_by_id(id)
|
tools = Tools.get_tool_by_id(id, db=db)
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
if id in request.app.state.TOOLS:
|
if id in request.app.state.TOOLS:
|
||||||
|
|
@ -624,7 +628,7 @@ async def update_tools_user_valves_by_id(
|
||||||
user_valves = UserValves(**form_data)
|
user_valves = UserValves(**form_data)
|
||||||
user_valves_dict = user_valves.model_dump(exclude_unset=True)
|
user_valves_dict = user_valves.model_dump(exclude_unset=True)
|
||||||
Tools.update_user_valves_by_id_and_user_id(
|
Tools.update_user_valves_by_id_and_user_id(
|
||||||
id, user.id, user_valves_dict
|
id, user.id, user_valves_dict, db=db
|
||||||
)
|
)
|
||||||
return user_valves_dict
|
return user_valves_dict
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
|
|
||||||
|
|
@ -29,6 +30,7 @@ from open_webui.models.users import (
|
||||||
|
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.env import STATIC_DIR
|
from open_webui.env import STATIC_DIR
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.auth import (
|
from open_webui.utils.auth import (
|
||||||
|
|
@ -60,6 +62,7 @@ async def get_users(
|
||||||
direction: Optional[str] = None,
|
direction: Optional[str] = None,
|
||||||
page: Optional[int] = 1,
|
page: Optional[int] = 1,
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
limit = PAGE_ITEM_COUNT
|
limit = PAGE_ITEM_COUNT
|
||||||
|
|
||||||
|
|
@ -74,7 +77,9 @@ async def get_users(
|
||||||
if direction:
|
if direction:
|
||||||
filter["direction"] = direction
|
filter["direction"] = direction
|
||||||
|
|
||||||
result = Users.get_users(filter=filter, skip=skip, limit=limit)
|
filter["direction"] = direction
|
||||||
|
|
||||||
|
result = Users.get_users(filter=filter, skip=skip, limit=limit, db=db)
|
||||||
|
|
||||||
users = result["users"]
|
users = result["users"]
|
||||||
total = result["total"]
|
total = result["total"]
|
||||||
|
|
@ -85,7 +90,8 @@ async def get_users(
|
||||||
**{
|
**{
|
||||||
**user.model_dump(),
|
**user.model_dump(),
|
||||||
"group_ids": [
|
"group_ids": [
|
||||||
group.id for group in Groups.get_groups_by_member_id(user.id)
|
group.id
|
||||||
|
for group in Groups.get_groups_by_member_id(user.id, db=db)
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -98,8 +104,9 @@ async def get_users(
|
||||||
@router.get("/all", response_model=UserInfoListResponse)
|
@router.get("/all", response_model=UserInfoListResponse)
|
||||||
async def get_all_users(
|
async def get_all_users(
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
return Users.get_users()
|
return Users.get_users(db=db)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/search", response_model=UserInfoListResponse)
|
@router.get("/search", response_model=UserInfoListResponse)
|
||||||
|
|
@ -109,16 +116,13 @@ async def search_users(
|
||||||
direction: Optional[str] = None,
|
direction: Optional[str] = None,
|
||||||
page: Optional[int] = 1,
|
page: Optional[int] = 1,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
limit = PAGE_ITEM_COUNT
|
limit = PAGE_ITEM_COUNT
|
||||||
|
|
||||||
page = max(1, page)
|
page = max(1, page)
|
||||||
skip = (page - 1) * limit
|
skip = (page - 1) * limit
|
||||||
|
|
||||||
filter = {}
|
|
||||||
if query:
|
|
||||||
filter["query"] = query
|
|
||||||
|
|
||||||
filter = {}
|
filter = {}
|
||||||
if query:
|
if query:
|
||||||
filter["query"] = query
|
filter["query"] = query
|
||||||
|
|
@ -127,7 +131,7 @@ async def search_users(
|
||||||
if direction:
|
if direction:
|
||||||
filter["direction"] = direction
|
filter["direction"] = direction
|
||||||
|
|
||||||
return Users.get_users(filter=filter, skip=skip, limit=limit)
|
return Users.get_users(filter=filter, skip=skip, limit=limit, db=db)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -136,8 +140,10 @@ async def search_users(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/groups")
|
@router.get("/groups")
|
||||||
async def get_user_groups(user=Depends(get_verified_user)):
|
async def get_user_groups(
|
||||||
return Groups.get_groups_by_member_id(user.id)
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
return Groups.get_groups_by_member_id(user.id, db=db)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -146,9 +152,13 @@ async def get_user_groups(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/permissions")
|
@router.get("/permissions")
|
||||||
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
|
async def get_user_permissisions(
|
||||||
|
request: Request,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
|
):
|
||||||
user_permissions = get_permissions(
|
user_permissions = get_permissions(
|
||||||
user.id, request.app.state.config.USER_PERMISSIONS
|
user.id, request.app.state.config.USER_PERMISSIONS, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
return user_permissions
|
return user_permissions
|
||||||
|
|
@ -256,8 +266,10 @@ async def update_default_user_permissions(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user/settings", response_model=Optional[UserSettings])
|
@router.get("/user/settings", response_model=Optional[UserSettings])
|
||||||
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
async def get_user_settings_by_session_user(
|
||||||
user = Users.get_user_by_id(user.id)
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
user = Users.get_user_by_id(user.id, db=db)
|
||||||
if user:
|
if user:
|
||||||
return user.settings
|
return user.settings
|
||||||
else:
|
else:
|
||||||
|
|
@ -274,7 +286,10 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.post("/user/settings/update", response_model=UserSettings)
|
@router.post("/user/settings/update", response_model=UserSettings)
|
||||||
async def update_user_settings_by_session_user(
|
async def update_user_settings_by_session_user(
|
||||||
request: Request, form_data: UserSettings, user=Depends(get_verified_user)
|
request: Request,
|
||||||
|
form_data: UserSettings,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
updated_user_settings = form_data.model_dump()
|
updated_user_settings = form_data.model_dump()
|
||||||
if (
|
if (
|
||||||
|
|
@ -289,7 +304,7 @@ async def update_user_settings_by_session_user(
|
||||||
# If the user is not an admin and does not have permission to use tool servers, remove the key
|
# If the user is not an admin and does not have permission to use tool servers, remove the key
|
||||||
updated_user_settings["ui"].pop("toolServers", None)
|
updated_user_settings["ui"].pop("toolServers", None)
|
||||||
|
|
||||||
user = Users.update_user_settings_by_id(user.id, updated_user_settings)
|
user = Users.update_user_settings_by_id(user.id, updated_user_settings, db=db)
|
||||||
if user:
|
if user:
|
||||||
return user.settings
|
return user.settings
|
||||||
else:
|
else:
|
||||||
|
|
@ -305,8 +320,10 @@ async def update_user_settings_by_session_user(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user/status")
|
@router.get("/user/status")
|
||||||
async def get_user_status_by_session_user(user=Depends(get_verified_user)):
|
async def get_user_status_by_session_user(
|
||||||
user = Users.get_user_by_id(user.id)
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
user = Users.get_user_by_id(user.id, db=db)
|
||||||
if user:
|
if user:
|
||||||
return user
|
return user
|
||||||
else:
|
else:
|
||||||
|
|
@ -323,11 +340,13 @@ async def get_user_status_by_session_user(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.post("/user/status/update")
|
@router.post("/user/status/update")
|
||||||
async def update_user_status_by_session_user(
|
async def update_user_status_by_session_user(
|
||||||
form_data: UserStatus, user=Depends(get_verified_user)
|
form_data: UserStatus,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
user = Users.get_user_by_id(user.id)
|
user = Users.get_user_by_id(user.id, db=db)
|
||||||
if user:
|
if user:
|
||||||
user = Users.update_user_status_by_id(user.id, form_data)
|
user = Users.update_user_status_by_id(user.id, form_data, db=db)
|
||||||
return user
|
return user
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -342,8 +361,10 @@ async def update_user_status_by_session_user(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user/info", response_model=Optional[dict])
|
@router.get("/user/info", response_model=Optional[dict])
|
||||||
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
|
async def get_user_info_by_session_user(
|
||||||
user = Users.get_user_by_id(user.id)
|
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
user = Users.get_user_by_id(user.id, db=db)
|
||||||
if user:
|
if user:
|
||||||
return user.info
|
return user.info
|
||||||
else:
|
else:
|
||||||
|
|
@ -360,14 +381,16 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.post("/user/info/update", response_model=Optional[dict])
|
@router.post("/user/info/update", response_model=Optional[dict])
|
||||||
async def update_user_info_by_session_user(
|
async def update_user_info_by_session_user(
|
||||||
form_data: dict, user=Depends(get_verified_user)
|
form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
):
|
):
|
||||||
user = Users.get_user_by_id(user.id)
|
user = Users.get_user_by_id(user.id, db=db)
|
||||||
if user:
|
if user:
|
||||||
if user.info is None:
|
if user.info is None:
|
||||||
user.info = {}
|
user.info = {}
|
||||||
|
|
||||||
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
|
user = Users.update_user_by_id(
|
||||||
|
user.id, {"info": {**user.info, **form_data}}, db=db
|
||||||
|
)
|
||||||
if user:
|
if user:
|
||||||
return user.info
|
return user.info
|
||||||
else:
|
else:
|
||||||
|
|
@ -397,7 +420,9 @@ class UserActiveResponse(UserStatus):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=UserActiveResponse)
|
@router.get("/{user_id}", response_model=UserActiveResponse)
|
||||||
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
async def get_user_by_id(
|
||||||
|
user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
# Check if user_id is a shared chat
|
# Check if user_id is a shared chat
|
||||||
# If it is, get the user_id from the chat
|
# If it is, get the user_id from the chat
|
||||||
if user_id.startswith("shared-"):
|
if user_id.startswith("shared-"):
|
||||||
|
|
@ -411,14 +436,14 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
user = Users.get_user_by_id(user_id)
|
user = Users.get_user_by_id(user_id, db=db)
|
||||||
if user:
|
if user:
|
||||||
groups = Groups.get_groups_by_member_id(user_id)
|
groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||||
return UserActiveResponse(
|
return UserActiveResponse(
|
||||||
**{
|
**{
|
||||||
**user.model_dump(),
|
**user.model_dump(),
|
||||||
"groups": [{"id": group.id, "name": group.name} for group in groups],
|
"groups": [{"id": group.id, "name": group.name} for group in groups],
|
||||||
"is_active": Users.is_user_active(user_id),
|
"is_active": Users.is_user_active(user_id, db=db),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
@ -429,8 +454,10 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}/oauth/sessions")
|
@router.get("/{user_id}/oauth/sessions")
|
||||||
async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user)):
|
async def get_user_oauth_sessions_by_id(
|
||||||
sessions = OAuthSessions.get_sessions_by_user_id(user_id)
|
user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
sessions = OAuthSessions.get_sessions_by_user_id(user_id, db=db)
|
||||||
if sessions and len(sessions) > 0:
|
if sessions and len(sessions) > 0:
|
||||||
return sessions
|
return sessions
|
||||||
else:
|
else:
|
||||||
|
|
@ -446,8 +473,10 @@ async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_use
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}/profile/image")
|
@router.get("/{user_id}/profile/image")
|
||||||
async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)):
|
async def get_user_profile_image_by_id(
|
||||||
user = Users.get_user_by_id(user_id)
|
user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
user = Users.get_user_by_id(user_id, db=db)
|
||||||
if user:
|
if user:
|
||||||
if user.profile_image_url:
|
if user.profile_image_url:
|
||||||
# check if it's url or base64
|
# check if it's url or base64
|
||||||
|
|
@ -484,9 +513,11 @@ async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_u
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}/active", response_model=dict)
|
@router.get("/{user_id}/active", response_model=dict)
|
||||||
async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)):
|
async def get_user_active_status_by_id(
|
||||||
|
user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
return {
|
return {
|
||||||
"active": Users.is_user_active(user_id),
|
"active": Users.is_user_active(user_id, db=db),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -500,10 +531,11 @@ async def update_user_by_id(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
form_data: UserUpdateForm,
|
form_data: UserUpdateForm,
|
||||||
session_user=Depends(get_admin_user),
|
session_user=Depends(get_admin_user),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
# Prevent modification of the primary admin user by other admins
|
# Prevent modification of the primary admin user by other admins
|
||||||
try:
|
try:
|
||||||
first_user = Users.get_first_user()
|
first_user = Users.get_first_user(db=db)
|
||||||
if first_user:
|
if first_user:
|
||||||
if user_id == first_user.id:
|
if user_id == first_user.id:
|
||||||
if session_user.id != user_id:
|
if session_user.id != user_id:
|
||||||
|
|
@ -527,11 +559,11 @@ async def update_user_by_id(
|
||||||
detail="Could not verify primary admin status.",
|
detail="Could not verify primary admin status.",
|
||||||
)
|
)
|
||||||
|
|
||||||
user = Users.get_user_by_id(user_id)
|
user = Users.get_user_by_id(user_id, db=db)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
if form_data.email.lower() != user.email:
|
if form_data.email.lower() != user.email:
|
||||||
email_user = Users.get_user_by_email(form_data.email.lower())
|
email_user = Users.get_user_by_email(form_data.email.lower(), db=db)
|
||||||
if email_user:
|
if email_user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -545,9 +577,9 @@ async def update_user_by_id(
|
||||||
raise HTTPException(400, detail=str(e))
|
raise HTTPException(400, detail=str(e))
|
||||||
|
|
||||||
hashed = get_password_hash(form_data.password)
|
hashed = get_password_hash(form_data.password)
|
||||||
Auths.update_user_password_by_id(user_id, hashed)
|
Auths.update_user_password_by_id(user_id, hashed, db=db)
|
||||||
|
|
||||||
Auths.update_email_by_id(user_id, form_data.email.lower())
|
Auths.update_email_by_id(user_id, form_data.email.lower(), db=db)
|
||||||
updated_user = Users.update_user_by_id(
|
updated_user = Users.update_user_by_id(
|
||||||
user_id,
|
user_id,
|
||||||
{
|
{
|
||||||
|
|
@ -556,6 +588,7 @@ async def update_user_by_id(
|
||||||
"email": form_data.email.lower(),
|
"email": form_data.email.lower(),
|
||||||
"profile_image_url": form_data.profile_image_url,
|
"profile_image_url": form_data.profile_image_url,
|
||||||
},
|
},
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
if updated_user:
|
if updated_user:
|
||||||
|
|
@ -578,10 +611,12 @@ async def update_user_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{user_id}", response_model=bool)
|
@router.delete("/{user_id}", response_model=bool)
|
||||||
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
async def delete_user_by_id(
|
||||||
|
user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
# Prevent deletion of the primary admin user
|
# Prevent deletion of the primary admin user
|
||||||
try:
|
try:
|
||||||
first_user = Users.get_first_user()
|
first_user = Users.get_first_user(db=db)
|
||||||
if first_user and user_id == first_user.id:
|
if first_user and user_id == first_user.id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
|
@ -595,7 +630,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
if user.id != user_id:
|
if user.id != user_id:
|
||||||
result = Auths.delete_auth_by_id(user_id)
|
result = Auths.delete_auth_by_id(user_id, db=db)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
return True
|
return True
|
||||||
|
|
@ -618,5 +653,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}/groups")
|
@router.get("/{user_id}/groups")
|
||||||
async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)):
|
async def get_user_groups_by_id(
|
||||||
return Groups.get_groups_by_member_id(user_id)
|
user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
return Groups.get_groups_by_member_id(user_id, db=db)
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ def fill_missing_permissions(
|
||||||
def get_permissions(
|
def get_permissions(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
default_permissions: Dict[str, Any],
|
default_permissions: Dict[str, Any],
|
||||||
|
db: Optional[Any] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Get all permissions for a user by combining the permissions of all groups the user is a member of.
|
Get all permissions for a user by combining the permissions of all groups the user is a member of.
|
||||||
|
|
@ -53,7 +54,7 @@ def get_permissions(
|
||||||
) # Use the most permissive value (True > False)
|
) # Use the most permissive value (True > False)
|
||||||
return permissions
|
return permissions
|
||||||
|
|
||||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||||
|
|
||||||
# Deep copy default permissions to avoid modifying the original dict
|
# Deep copy default permissions to avoid modifying the original dict
|
||||||
permissions = json.loads(json.dumps(default_permissions))
|
permissions = json.loads(json.dumps(default_permissions))
|
||||||
|
|
@ -72,6 +73,7 @@ def has_permission(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
permission_key: str,
|
permission_key: str,
|
||||||
default_permissions: Dict[str, Any] = {},
|
default_permissions: Dict[str, Any] = {},
|
||||||
|
db: Optional[Any] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a user has a specific permission by checking the group permissions
|
Check if a user has a specific permission by checking the group permissions
|
||||||
|
|
@ -92,7 +94,7 @@ def has_permission(
|
||||||
permission_hierarchy = permission_key.split(".")
|
permission_hierarchy = permission_key.split(".")
|
||||||
|
|
||||||
# Retrieve user group permissions
|
# Retrieve user group permissions
|
||||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||||
|
|
||||||
for group in user_groups:
|
for group in user_groups:
|
||||||
if get_permission(group.permissions or {}, permission_hierarchy):
|
if get_permission(group.permissions or {}, permission_hierarchy):
|
||||||
|
|
@ -127,6 +129,7 @@ def has_access(
|
||||||
access_control: Optional[dict] = None,
|
access_control: Optional[dict] = None,
|
||||||
user_group_ids: Optional[Set[str]] = None,
|
user_group_ids: Optional[Set[str]] = None,
|
||||||
strict: bool = True,
|
strict: bool = True,
|
||||||
|
db: Optional[Any] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
if access_control is None:
|
if access_control is None:
|
||||||
if strict:
|
if strict:
|
||||||
|
|
@ -135,7 +138,7 @@ def has_access(
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if user_group_ids is None:
|
if user_group_ids is None:
|
||||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||||
user_group_ids = {group.id for group in user_groups}
|
user_group_ids = {group.id for group in user_groups}
|
||||||
|
|
||||||
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
|
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
|
||||||
|
|
@ -152,10 +155,10 @@ def has_access(
|
||||||
|
|
||||||
# Get all users with access to a resource
|
# Get all users with access to a resource
|
||||||
def get_users_with_access(
|
def get_users_with_access(
|
||||||
type: str = "write", access_control: Optional[dict] = None
|
type: str = "write", access_control: Optional[dict] = None, db: Optional[Any] = None
|
||||||
) -> list[UserModel]:
|
) -> list[UserModel]:
|
||||||
if access_control is None:
|
if access_control is None:
|
||||||
result = Users.get_users(filter={"roles": ["!pending"]})
|
result = Users.get_users(filter={"roles": ["!pending"]}, db=db)
|
||||||
return result.get("users", [])
|
return result.get("users", [])
|
||||||
|
|
||||||
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
|
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
|
||||||
|
|
@ -167,8 +170,8 @@ def get_users_with_access(
|
||||||
|
|
||||||
user_ids_with_access = set(permitted_user_ids)
|
user_ids_with_access = set(permitted_user_ids)
|
||||||
|
|
||||||
group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids)
|
group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids, db=db)
|
||||||
for user_ids in group_user_ids_map.values():
|
for user_ids in group_user_ids_map.values():
|
||||||
user_ids_with_access.update(user_ids)
|
user_ids_with_access.update(user_ids)
|
||||||
|
|
||||||
return Users.get_users_by_user_ids(list(user_ids_with_access))
|
return Users.get_users_by_user_ids(list(user_ids_with_access), db=db)
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,8 @@ from open_webui.env import (
|
||||||
|
|
||||||
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
|
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
@ -271,6 +273,7 @@ async def get_current_user(
|
||||||
response: Response,
|
response: Response,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
|
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
token = None
|
token = None
|
||||||
|
|
||||||
|
|
@ -285,7 +288,7 @@ async def get_current_user(
|
||||||
|
|
||||||
# auth by api key
|
# auth by api key
|
||||||
if token.startswith("sk-"):
|
if token.startswith("sk-"):
|
||||||
user = get_current_user_by_api_key(request, token)
|
user = get_current_user_by_api_key(request, token, db=db)
|
||||||
|
|
||||||
# Add user info to current span
|
# Add user info to current span
|
||||||
current_span = trace.get_current_span()
|
current_span = trace.get_current_span()
|
||||||
|
|
@ -314,7 +317,7 @@ async def get_current_user(
|
||||||
detail="Invalid token",
|
detail="Invalid token",
|
||||||
)
|
)
|
||||||
|
|
||||||
user = Users.get_user_by_id(data["id"])
|
user = Users.get_user_by_id(data["id"], db=db)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -364,8 +367,8 @@ async def get_current_user(
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def get_current_user_by_api_key(request, api_key: str):
|
def get_current_user_by_api_key(request, api_key: str, db: Session = None):
|
||||||
user = Users.get_user_by_api_key(api_key)
|
user = Users.get_user_by_api_key(api_key, db=db)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -393,7 +396,7 @@ def get_current_user_by_api_key(request, api_key: str):
|
||||||
current_span.set_attribute("client.user.role", user.role)
|
current_span.set_attribute("client.user.role", user.role)
|
||||||
current_span.set_attribute("client.auth.type", "api_key")
|
current_span.set_attribute("client.auth.type", "api_key")
|
||||||
|
|
||||||
Users.update_last_active_by_id(user.id)
|
Users.update_last_active_by_id(user.id, db=db)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ log = logging.getLogger(__name__)
|
||||||
def apply_default_group_assignment(
|
def apply_default_group_assignment(
|
||||||
default_group_id: str,
|
default_group_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
db=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Apply default group assignment to a user if default_group_id is provided.
|
Apply default group assignment to a user if default_group_id is provided.
|
||||||
|
|
@ -17,7 +18,7 @@ def apply_default_group_assignment(
|
||||||
"""
|
"""
|
||||||
if default_group_id:
|
if default_group_id:
|
||||||
try:
|
try:
|
||||||
Groups.add_users_to_group(default_group_id, [user_id])
|
Groups.add_users_to_group(default_group_id, [user_id], db=db)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(
|
log.error(
|
||||||
f"Failed to add user {user_id} to default group {default_group_id}: {e}"
|
f"Failed to add user {user_id} to default group {default_group_id}: {e}"
|
||||||
|
|
|
||||||
|
|
@ -1336,7 +1336,7 @@ class OAuthManager:
|
||||||
|
|
||||||
return await client.authorize_redirect(request, redirect_uri, **kwargs)
|
return await client.authorize_redirect(request, redirect_uri, **kwargs)
|
||||||
|
|
||||||
async def handle_callback(self, request, provider, response):
|
async def handle_callback(self, request, provider, response, db=None):
|
||||||
if provider not in OAUTH_PROVIDERS:
|
if provider not in OAUTH_PROVIDERS:
|
||||||
raise HTTPException(404)
|
raise HTTPException(404)
|
||||||
|
|
||||||
|
|
@ -1461,20 +1461,20 @@ class OAuthManager:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
|
||||||
# Check if the user exists
|
# Check if the user exists
|
||||||
user = Users.get_user_by_oauth_sub(provider, sub)
|
user = Users.get_user_by_oauth_sub(provider, sub, db=db)
|
||||||
if not user:
|
if not user:
|
||||||
# If the user does not exist, check if merging is enabled
|
# If the user does not exist, check if merging is enabled
|
||||||
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
||||||
# Check if the user exists by email
|
# Check if the user exists by email
|
||||||
user = Users.get_user_by_email(email)
|
user = Users.get_user_by_email(email, db=db)
|
||||||
if user:
|
if user:
|
||||||
# Update the user with the new oauth sub
|
# Update the user with the new oauth sub
|
||||||
Users.update_user_oauth_by_id(user.id, provider, sub)
|
Users.update_user_oauth_by_id(user.id, provider, sub, db=db)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
determined_role = self.get_user_role(user, user_data)
|
determined_role = self.get_user_role(user, user_data)
|
||||||
if user.role != determined_role:
|
if user.role != determined_role:
|
||||||
Users.update_user_role_by_id(user.id, determined_role)
|
Users.update_user_role_by_id(user.id, determined_role, db=db)
|
||||||
# Update the user object in memory as well,
|
# Update the user object in memory as well,
|
||||||
# to avoid problems with the ENABLE_OAUTH_GROUP_MANAGEMENT check below
|
# to avoid problems with the ENABLE_OAUTH_GROUP_MANAGEMENT check below
|
||||||
user.role = determined_role
|
user.role = determined_role
|
||||||
|
|
@ -1491,14 +1491,14 @@ class OAuthManager:
|
||||||
)
|
)
|
||||||
if processed_picture_url != user.profile_image_url:
|
if processed_picture_url != user.profile_image_url:
|
||||||
Users.update_user_profile_image_url_by_id(
|
Users.update_user_profile_image_url_by_id(
|
||||||
user.id, processed_picture_url
|
user.id, processed_picture_url, db=db
|
||||||
)
|
)
|
||||||
log.debug(f"Updated profile picture for user {user.email}")
|
log.debug(f"Updated profile picture for user {user.email}")
|
||||||
else:
|
else:
|
||||||
# If the user does not exist, check if signups are enabled
|
# If the user does not exist, check if signups are enabled
|
||||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
||||||
# Check if an existing user with the same email already exists
|
# Check if an existing user with the same email already exists
|
||||||
existing_user = Users.get_user_by_email(email)
|
existing_user = Users.get_user_by_email(email, db=db)
|
||||||
if existing_user:
|
if existing_user:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
|
||||||
|
|
@ -1529,6 +1529,7 @@ class OAuthManager:
|
||||||
profile_image_url=picture_url,
|
profile_image_url=picture_url,
|
||||||
role=self.get_user_role(None, user_data),
|
role=self.get_user_role(None, user_data),
|
||||||
oauth=oauth_data,
|
oauth=oauth_data,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
if auth_manager_config.WEBHOOK_URL:
|
if auth_manager_config.WEBHOOK_URL:
|
||||||
|
|
@ -1544,8 +1545,7 @@ class OAuthManager:
|
||||||
)
|
)
|
||||||
|
|
||||||
apply_default_group_assignment(
|
apply_default_group_assignment(
|
||||||
request.app.state.config.DEFAULT_GROUP_ID,
|
request.app.state.config.DEFAULT_GROUP_ID, user.id, db=db
|
||||||
user.id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
@ -1616,15 +1616,16 @@ class OAuthManager:
|
||||||
token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
|
token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
|
||||||
|
|
||||||
# Clean up any existing sessions for this user/provider first
|
# Clean up any existing sessions for this user/provider first
|
||||||
sessions = OAuthSessions.get_sessions_by_user_id(user.id)
|
sessions = OAuthSessions.get_sessions_by_user_id(user.id, db=db)
|
||||||
for session in sessions:
|
for session in sessions:
|
||||||
if session.provider == provider:
|
if session.provider == provider:
|
||||||
OAuthSessions.delete_session_by_id(session.id)
|
OAuthSessions.delete_session_by_id(session.id, db=db)
|
||||||
|
|
||||||
session = OAuthSessions.create_session(
|
session = OAuthSessions.create_session(
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
token=token,
|
token=token,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue