This commit is contained in:
Timothy Jaeryang Baek 2025-07-11 12:00:21 +04:00
parent 6a3b0e3110
commit 3b9d86de0b
2 changed files with 94 additions and 74 deletions

View file

@ -18,9 +18,11 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel from open_webui.models.users import UserModel
from open_webui.models.files import Files from open_webui.models.files import Files
from open_webui.models.knowledge import Knowledges
from open_webui.models.notes import Notes from open_webui.models.notes import Notes
from open_webui.retrieval.vector.main import GetResult from open_webui.retrieval.vector.main import GetResult
from open_webui.utils.access_control import has_access
from open_webui.env import ( from open_webui.env import (
@ -443,9 +445,9 @@ def get_embedding_function(
raise ValueError(f"Unknown embedding engine: {embedding_engine}") raise ValueError(f"Unknown embedding engine: {embedding_engine}")
def get_sources_from_files( def get_sources_from_items(
request, request,
files, items,
queries, queries,
embedding_function, embedding_function,
k, k,
@ -455,75 +457,90 @@ def get_sources_from_files(
hybrid_bm25_weight, hybrid_bm25_weight,
hybrid_search, hybrid_search,
full_context=False, full_context=False,
user: Optional[UserModel] = None,
): ):
log.debug( log.debug(
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}" f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}"
) )
extracted_collections = [] extracted_collections = []
query_results = [] query_results = []
for file in files: for item in items:
query_result = None query_result = None
if file.get("docs"): if item.get("type") == "text":
# Text File
# Used during temporary chat file uploads
query_result = {
"documents": [[item.get("content")]],
"metadatas": [[{"file_id": item.get("id"), "name": item.get("name")}]],
}
elif item.get("type") == "note":
# Note Attached
note = Notes.get_note_by_id(item.get("id"))
if user.role == "admin" or has_access(user.id, "read", note.access_control):
# User has access to the note
query_result = {
"documents": [[note.data.get("content", {}).get("md", "")]],
"metadatas": [[{"file_id": note.id, "name": note.title}]],
}
elif item.get("docs"):
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
query_result = { query_result = {
"documents": [[doc.get("content") for doc in file.get("docs")]], "documents": [[doc.get("content") for doc in item.get("docs")]],
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]], "metadatas": [[doc.get("metadata") for doc in item.get("docs")]],
} }
elif file.get("type") == "text":
# Text File
query_result = {
"documents": [[file.get("content")]],
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
}
elif file.get("type") == "note":
# Note Attached
note = Notes.get_note_by_id(file.get("id"))
query_result = { elif item.get("context") == "full":
"documents": [[note.data.get("content", {}).get("md", "")]], if item.get("type") == "file":
"metadatas": [[{"file_id": note.id, "name": note.title}]],
}
elif file.get("context") == "full":
if file.get("type") == "file":
# Manual Full Mode Toggle # Manual Full Mode Toggle
# Used from chat file modal, we can assume that the file content will be available from item.get("file").get("data", {}).get("content")
query_result = { query_result = {
"documents": [[file.get("file").get("data", {}).get("content")]], "documents": [[item.get("file").get("data", {}).get("content")]],
"metadatas": [ "metadatas": [
[{"file_id": file.get("id"), "name": file.get("name")}] [{"file_id": item.get("id"), "name": item.get("name")}]
], ],
} }
elif file.get("type") == "collection": elif item.get("type") == "collection":
# Manual Full Mode Toggle for Collection # Manual Full Mode Toggle for Collection
file_ids = file.get("data", {}).get("file_ids", []) knowledge_base = Knowledges.get_knowledge_by_id(item.get("id"))
documents = [] if knowledge_base and (
metadatas = [] user.role == "admin"
for file_id in file_ids: or has_access(user.id, "read", knowledge_base.access_control)
file_object = Files.get_file_by_id(file_id) ):
if file_object: file_ids = knowledge_base.data.get("file_ids", [])
documents.append(file_object.data.get("content", ""))
metadatas.append(
{
"file_id": file_id,
"name": file_object.filename,
"source": file_object.filename,
}
)
query_result = { documents = []
"documents": [documents], metadatas = []
"metadatas": [metadatas], for file_id in file_ids:
} file_object = Files.get_file_by_id(file_id)
if file_object:
documents.append(file_object.data.get("content", ""))
metadatas.append(
{
"file_id": file_id,
"name": file_object.filename,
"source": file_object.filename,
}
)
query_result = {
"documents": [documents],
"metadatas": [metadatas],
}
elif ( elif (
file.get("type") != "web_search" item.get("type") != "web_search"
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
): ):
# BYPASS_EMBEDDING_AND_RETRIEVAL # BYPASS_EMBEDDING_AND_RETRIEVAL
if file.get("type") == "collection": if item.get("type") == "collection":
file_ids = file.get("data", {}).get("file_ids", []) file_ids = item.get("data", {}).get("file_ids", [])
documents = [] documents = []
metadatas = [] metadatas = []
@ -545,46 +562,46 @@ def get_sources_from_files(
"metadatas": [metadatas], "metadatas": [metadatas],
} }
elif file.get("id"): elif item.get("id"):
file_object = Files.get_file_by_id(file.get("id")) file_object = Files.get_file_by_id(item.get("id"))
if file_object: if file_object:
query_result = { query_result = {
"documents": [[file_object.data.get("content", "")]], "documents": [[file_object.data.get("content", "")]],
"metadatas": [ "metadatas": [
[ [
{ {
"file_id": file.get("id"), "file_id": item.get("id"),
"name": file_object.filename, "name": file_object.filename,
"source": file_object.filename, "source": file_object.filename,
} }
] ]
], ],
} }
elif file.get("file").get("data"): elif item.get("file").get("data"):
query_result = { query_result = {
"documents": [[file.get("file").get("data", {}).get("content")]], "documents": [[item.get("file").get("data", {}).get("content")]],
"metadatas": [ "metadatas": [
[file.get("file").get("data", {}).get("metadata", {})] [item.get("file").get("data", {}).get("metadata", {})]
], ],
} }
else: else:
collection_names = [] collection_names = []
if file.get("type") == "collection": if item.get("type") == "collection":
if file.get("legacy"): if item.get("legacy"):
collection_names = file.get("collection_names", []) collection_names = item.get("collection_names", [])
else: else:
collection_names.append(file["id"]) collection_names.append(item["id"])
elif file.get("collection_name"): elif item.get("collection_name"):
collection_names.append(file["collection_name"]) collection_names.append(item["collection_name"])
elif file.get("id"): elif item.get("id"):
if file.get("legacy"): if item.get("legacy"):
collection_names.append(f"{file['id']}") collection_names.append(f"{item['id']}")
else: else:
collection_names.append(f"file-{file['id']}") collection_names.append(f"file-{item['id']}")
collection_names = set(collection_names).difference(extracted_collections) collection_names = set(collection_names).difference(extracted_collections)
if not collection_names: if not collection_names:
log.debug(f"skipping {file} as it has already been extracted") log.debug(f"skipping {item} as it has already been extracted")
continue continue
if full_context: if full_context:
@ -596,14 +613,14 @@ def get_sources_from_files(
else: else:
try: try:
query_result = None query_result = None
if file.get("type") == "text": if item.get("type") == "text":
# Not sure when this is used, but it seems to be a fallback # Not sure when this is used, but it seems to be a fallback
query_result = { query_result = {
"documents": [ "documents": [
[file.get("file").get("data", {}).get("content")] [item.get("file").get("data", {}).get("content")]
], ],
"metadatas": [ "metadatas": [
[file.get("file").get("data", {}).get("meta", {})] [item.get("file").get("data", {}).get("meta", {})]
], ],
} }
else: else:
@ -638,10 +655,10 @@ def get_sources_from_files(
extracted_collections.extend(collection_names) extracted_collections.extend(collection_names)
if query_result: if query_result:
if "data" in file: if "data" in item:
del file["data"] del item["data"]
query_results.append({**query_result, "file": file}) query_results.append({**query_result, "file": item})
sources = [] sources = []
for query_result in query_results: for query_result in query_results:

View file

@ -56,7 +56,7 @@ from open_webui.models.users import UserModel
from open_webui.models.functions import Functions from open_webui.models.functions import Functions
from open_webui.models.models import Models from open_webui.models.models import Models
from open_webui.retrieval.utils import get_sources_from_files from open_webui.retrieval.utils import get_sources_from_items
from open_webui.utils.chat import generate_chat_completion from open_webui.utils.chat import generate_chat_completion
@ -638,14 +638,14 @@ async def chat_completion_files_handler(
queries = [get_last_user_message(body["messages"])] queries = [get_last_user_message(body["messages"])]
try: try:
# Offload get_sources_from_files to a separate thread # Offload get_sources_from_items to a separate thread
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
sources = await loop.run_in_executor( sources = await loop.run_in_executor(
executor, executor,
lambda: get_sources_from_files( lambda: get_sources_from_items(
request=request, request=request,
files=files, items=files,
queries=queries, queries=queries,
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
query, prefix=prefix, user=user query, prefix=prefix, user=user
@ -657,6 +657,7 @@ async def chat_completion_files_handler(
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT, hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=request.app.state.config.RAG_FULL_CONTEXT, full_context=request.app.state.config.RAG_FULL_CONTEXT,
user=user,
), ),
) )
except Exception as e: except Exception as e:
@ -2152,7 +2153,9 @@ async def process_chat_response(
if isinstance(tool_result, dict) or isinstance( if isinstance(tool_result, dict) or isinstance(
tool_result, list tool_result, list
): ):
tool_result = json.dumps(tool_result, indent=2, ensure_ascii=False) tool_result = json.dumps(
tool_result, indent=2, ensure_ascii=False
)
results.append( results.append(
{ {