diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 71f6390d6b..9d6a2a79b2 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -18,9 +18,11 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT from open_webui.models.users import UserModel from open_webui.models.files import Files +from open_webui.models.knowledge import Knowledges from open_webui.models.notes import Notes from open_webui.retrieval.vector.main import GetResult +from open_webui.utils.access_control import has_access from open_webui.env import ( @@ -443,9 +445,9 @@ def get_embedding_function( raise ValueError(f"Unknown embedding engine: {embedding_engine}") -def get_sources_from_files( +def get_sources_from_items( request, - files, + items, queries, embedding_function, k, @@ -455,75 +457,90 @@ def get_sources_from_files( hybrid_bm25_weight, hybrid_search, full_context=False, + user: Optional[UserModel] = None, ): log.debug( - f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}" + f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}" ) extracted_collections = [] query_results = [] - for file in files: + for item in items: query_result = None - if file.get("docs"): + if item.get("type") == "text": + # Text File + # Used during temporary chat file uploads + query_result = { + "documents": [[item.get("content")]], + "metadatas": [[{"file_id": item.get("id"), "name": item.get("name")}]], + } + + elif item.get("type") == "note": + # Note Attached + note = Notes.get_note_by_id(item.get("id")) + + if user.role == "admin" or has_access(user.id, "read", note.access_control): + # User has access to the note + query_result = { + "documents": [[note.data.get("content", {}).get("md", "")]], + "metadatas": [[{"file_id": note.id, "name": note.title}]], + } + + elif item.get("docs"): # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL query_result = { - "documents": [[doc.get("content") for doc in file.get("docs")]], - "metadatas": [[doc.get("metadata") for doc in file.get("docs")]], + "documents": [[doc.get("content") for doc in item.get("docs")]], + "metadatas": [[doc.get("metadata") for doc in item.get("docs")]], } - elif file.get("type") == "text": - # Text File - query_result = { - "documents": [[file.get("content")]], - "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]], - } - elif file.get("type") == "note": - # Note Attached - note = Notes.get_note_by_id(file.get("id")) - query_result = { - "documents": [[note.data.get("content", {}).get("md", "")]], - "metadatas": [[{"file_id": note.id, "name": note.title}]], - } - elif file.get("context") == "full": - if file.get("type") == "file": + elif item.get("context") == "full": + if item.get("type") == "file": # Manual Full Mode Toggle + # Used from chat file modal, we can assume that the file content will be available from item.get("file").get("data", {}).get("content") query_result = { - "documents": [[file.get("file").get("data", {}).get("content")]], + "documents": [[item.get("file").get("data", {}).get("content")]], "metadatas": [ - [{"file_id": file.get("id"), "name": file.get("name")}] + [{"file_id": item.get("id"), "name": item.get("name")}] ], } - elif file.get("type") == "collection": + elif item.get("type") == "collection": # Manual Full Mode Toggle for Collection - file_ids = file.get("data", {}).get("file_ids", []) + knowledge_base = Knowledges.get_knowledge_by_id(item.get("id")) - documents = [] - metadatas = [] - for file_id in file_ids: - file_object = Files.get_file_by_id(file_id) + if knowledge_base and ( + user.role == "admin" + or has_access(user.id, "read", knowledge_base.access_control) + ): - if file_object: - documents.append(file_object.data.get("content", "")) - metadatas.append( - { - "file_id": file_id, - "name": file_object.filename, - "source": file_object.filename, - } - ) + file_ids = knowledge_base.data.get("file_ids", []) - query_result = { - "documents": [documents], - "metadatas": [metadatas], - } + documents = [] + metadatas = [] + for file_id in file_ids: + file_object = Files.get_file_by_id(file_id) + + if file_object: + documents.append(file_object.data.get("content", "")) + metadatas.append( + { + "file_id": file_id, + "name": file_object.filename, + "source": file_object.filename, + } + ) + + query_result = { + "documents": [documents], + "metadatas": [metadatas], + } elif ( - file.get("type") != "web_search" + item.get("type") != "web_search" and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL ): # BYPASS_EMBEDDING_AND_RETRIEVAL - if file.get("type") == "collection": - file_ids = file.get("data", {}).get("file_ids", []) + if item.get("type") == "collection": + file_ids = item.get("data", {}).get("file_ids", []) documents = [] metadatas = [] @@ -545,46 +562,46 @@ def get_sources_from_files( "metadatas": [metadatas], } - elif file.get("id"): - file_object = Files.get_file_by_id(file.get("id")) + elif item.get("id"): + file_object = Files.get_file_by_id(item.get("id")) if file_object: query_result = { "documents": [[file_object.data.get("content", "")]], "metadatas": [ [ { - "file_id": file.get("id"), + "file_id": item.get("id"), "name": file_object.filename, "source": file_object.filename, } ] ], } - elif file.get("file").get("data"): + elif item.get("file").get("data"): query_result = { - "documents": [[file.get("file").get("data", {}).get("content")]], + "documents": [[item.get("file").get("data", {}).get("content")]], "metadatas": [ - [file.get("file").get("data", {}).get("metadata", {})] + [item.get("file").get("data", {}).get("metadata", {})] ], } else: collection_names = [] - if file.get("type") == "collection": - if file.get("legacy"): - collection_names = file.get("collection_names", []) + if item.get("type") == "collection": + if item.get("legacy"): + collection_names = item.get("collection_names", []) else: - collection_names.append(file["id"]) - elif file.get("collection_name"): - collection_names.append(file["collection_name"]) - elif file.get("id"): - if file.get("legacy"): - collection_names.append(f"{file['id']}") + collection_names.append(item["id"]) + elif item.get("collection_name"): + collection_names.append(item["collection_name"]) + elif item.get("id"): + if item.get("legacy"): + collection_names.append(f"{item['id']}") else: - collection_names.append(f"file-{file['id']}") + collection_names.append(f"file-{item['id']}") collection_names = set(collection_names).difference(extracted_collections) if not collection_names: - log.debug(f"skipping {file} as it has already been extracted") + log.debug(f"skipping {item} as it has already been extracted") continue if full_context: @@ -596,14 +613,14 @@ def get_sources_from_files( else: try: query_result = None - if file.get("type") == "text": + if item.get("type") == "text": # Not sure when this is used, but it seems to be a fallback query_result = { "documents": [ - [file.get("file").get("data", {}).get("content")] + [item.get("file").get("data", {}).get("content")] ], "metadatas": [ - [file.get("file").get("data", {}).get("meta", {})] + [item.get("file").get("data", {}).get("meta", {})] ], } else: @@ -638,10 +655,10 @@ def get_sources_from_files( extracted_collections.extend(collection_names) if query_result: - if "data" in file: - del file["data"] + if "data" in item: + del item["data"] - query_results.append({**query_result, "file": file}) + query_results.append({**query_result, "file": item}) sources = [] for query_result in query_results: diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 4c9bd6c06b..2b3c867e3f 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -56,7 +56,7 @@ from open_webui.models.users import UserModel from open_webui.models.functions import Functions from open_webui.models.models import Models -from open_webui.retrieval.utils import get_sources_from_files +from open_webui.retrieval.utils import get_sources_from_items from open_webui.utils.chat import generate_chat_completion @@ -638,14 +638,14 @@ async def chat_completion_files_handler( queries = [get_last_user_message(body["messages"])] try: - # Offload get_sources_from_files to a separate thread + # Offload get_sources_from_items to a separate thread loop = asyncio.get_running_loop() with ThreadPoolExecutor() as executor: sources = await loop.run_in_executor( executor, - lambda: get_sources_from_files( + lambda: get_sources_from_items( request=request, - files=files, + items=files, queries=queries, embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( query, prefix=prefix, user=user @@ -657,6 +657,7 @@ async def chat_completion_files_handler( hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT, hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, full_context=request.app.state.config.RAG_FULL_CONTEXT, + user=user, ), ) except Exception as e: @@ -2152,7 +2153,9 @@ async def process_chat_response( if isinstance(tool_result, dict) or isinstance( tool_result, list ): - tool_result = json.dumps(tool_result, indent=2, ensure_ascii=False) + tool_result = json.dumps( + tool_result, indent=2, ensure_ascii=False + ) results.append( {