mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
refac/perf: has_access_to_file optimization
This commit is contained in:
parent
9f6c91987f
commit
e301d1962e
2 changed files with 34 additions and 6 deletions
|
|
@ -217,6 +217,21 @@ class KnowledgeTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_knowledges_by_file_id(self, file_id: str) -> list[KnowledgeModel]:
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
knowledges = (
|
||||||
|
db.query(Knowledge)
|
||||||
|
.join(KnowledgeFile, Knowledge.id == KnowledgeFile.knowledge_id)
|
||||||
|
.filter(KnowledgeFile.file_id == file_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
KnowledgeModel.model_validate(knowledge) for knowledge in knowledges
|
||||||
|
]
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
def get_files_by_id(self, knowledge_id: str) -> list[FileModel]:
|
def get_files_by_id(self, knowledge_id: str) -> list[FileModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from fastapi import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from fastapi.responses import FileResponse, StreamingResponse
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
|
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||||
|
|
@ -34,12 +35,19 @@ from open_webui.models.files import (
|
||||||
Files,
|
Files,
|
||||||
)
|
)
|
||||||
from open_webui.models.knowledge import Knowledges
|
from open_webui.models.knowledge import Knowledges
|
||||||
|
from open_webui.models.groups import Groups
|
||||||
|
|
||||||
|
|
||||||
from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
|
from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
|
||||||
from open_webui.routers.retrieval import ProcessFileForm, process_file
|
from open_webui.routers.retrieval import ProcessFileForm, process_file
|
||||||
from open_webui.routers.audio import transcribe
|
from open_webui.routers.audio import transcribe
|
||||||
|
|
||||||
from open_webui.storage.provider import Storage
|
from open_webui.storage.provider import Storage
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
|
from open_webui.utils.access_control import has_access
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
@ -59,26 +67,31 @@ def has_access_to_file(
|
||||||
) -> bool:
|
) -> bool:
|
||||||
file = Files.get_file_by_id(file_id)
|
file = Files.get_file_by_id(file_id)
|
||||||
log.debug(f"Checking if user has {access_type} access to file")
|
log.debug(f"Checking if user has {access_type} access to file")
|
||||||
|
|
||||||
if not file:
|
if not file:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
has_access = False
|
knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id)
|
||||||
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
|
||||||
|
|
||||||
|
for knowledge_base in knowledge_bases:
|
||||||
|
if knowledge_base.user_id == user.id or has_access(
|
||||||
|
user.id, access_type, knowledge_base.access_control, user_group_ids
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
|
||||||
if knowledge_base_id:
|
if knowledge_base_id:
|
||||||
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
|
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
|
||||||
user.id, access_type
|
user.id, access_type
|
||||||
)
|
)
|
||||||
for knowledge_base in knowledge_bases:
|
for knowledge_base in knowledge_bases:
|
||||||
if knowledge_base.id == knowledge_base_id:
|
if knowledge_base.id == knowledge_base_id:
|
||||||
has_access = True
|
return True
|
||||||
break
|
|
||||||
|
|
||||||
return has_access
|
return False
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue