mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-13 21:05:19 +00:00
refac
This commit is contained in:
parent
6a3b0e3110
commit
3b9d86de0b
2 changed files with 94 additions and 74 deletions
|
|
@ -18,9 +18,11 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||||
|
|
||||||
from open_webui.models.users import UserModel
|
from open_webui.models.users import UserModel
|
||||||
from open_webui.models.files import Files
|
from open_webui.models.files import Files
|
||||||
|
from open_webui.models.knowledge import Knowledges
|
||||||
from open_webui.models.notes import Notes
|
from open_webui.models.notes import Notes
|
||||||
|
|
||||||
from open_webui.retrieval.vector.main import GetResult
|
from open_webui.retrieval.vector.main import GetResult
|
||||||
|
from open_webui.utils.access_control import has_access
|
||||||
|
|
||||||
|
|
||||||
from open_webui.env import (
|
from open_webui.env import (
|
||||||
|
|
@ -443,9 +445,9 @@ def get_embedding_function(
|
||||||
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
|
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
|
||||||
|
|
||||||
|
|
||||||
def get_sources_from_files(
|
def get_sources_from_items(
|
||||||
request,
|
request,
|
||||||
files,
|
items,
|
||||||
queries,
|
queries,
|
||||||
embedding_function,
|
embedding_function,
|
||||||
k,
|
k,
|
||||||
|
|
@ -455,75 +457,90 @@ def get_sources_from_files(
|
||||||
hybrid_bm25_weight,
|
hybrid_bm25_weight,
|
||||||
hybrid_search,
|
hybrid_search,
|
||||||
full_context=False,
|
full_context=False,
|
||||||
|
user: Optional[UserModel] = None,
|
||||||
):
|
):
|
||||||
log.debug(
|
log.debug(
|
||||||
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
|
f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}"
|
||||||
)
|
)
|
||||||
|
|
||||||
extracted_collections = []
|
extracted_collections = []
|
||||||
query_results = []
|
query_results = []
|
||||||
|
|
||||||
for file in files:
|
for item in items:
|
||||||
query_result = None
|
query_result = None
|
||||||
if file.get("docs"):
|
if item.get("type") == "text":
|
||||||
|
# Text File
|
||||||
|
# Used during temporary chat file uploads
|
||||||
|
query_result = {
|
||||||
|
"documents": [[item.get("content")]],
|
||||||
|
"metadatas": [[{"file_id": item.get("id"), "name": item.get("name")}]],
|
||||||
|
}
|
||||||
|
|
||||||
|
elif item.get("type") == "note":
|
||||||
|
# Note Attached
|
||||||
|
note = Notes.get_note_by_id(item.get("id"))
|
||||||
|
|
||||||
|
if user.role == "admin" or has_access(user.id, "read", note.access_control):
|
||||||
|
# User has access to the note
|
||||||
|
query_result = {
|
||||||
|
"documents": [[note.data.get("content", {}).get("md", "")]],
|
||||||
|
"metadatas": [[{"file_id": note.id, "name": note.title}]],
|
||||||
|
}
|
||||||
|
|
||||||
|
elif item.get("docs"):
|
||||||
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||||
query_result = {
|
query_result = {
|
||||||
"documents": [[doc.get("content") for doc in file.get("docs")]],
|
"documents": [[doc.get("content") for doc in item.get("docs")]],
|
||||||
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
|
"metadatas": [[doc.get("metadata") for doc in item.get("docs")]],
|
||||||
}
|
}
|
||||||
elif file.get("type") == "text":
|
|
||||||
# Text File
|
|
||||||
query_result = {
|
|
||||||
"documents": [[file.get("content")]],
|
|
||||||
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
|
|
||||||
}
|
|
||||||
elif file.get("type") == "note":
|
|
||||||
# Note Attached
|
|
||||||
note = Notes.get_note_by_id(file.get("id"))
|
|
||||||
|
|
||||||
query_result = {
|
elif item.get("context") == "full":
|
||||||
"documents": [[note.data.get("content", {}).get("md", "")]],
|
if item.get("type") == "file":
|
||||||
"metadatas": [[{"file_id": note.id, "name": note.title}]],
|
|
||||||
}
|
|
||||||
elif file.get("context") == "full":
|
|
||||||
if file.get("type") == "file":
|
|
||||||
# Manual Full Mode Toggle
|
# Manual Full Mode Toggle
|
||||||
|
# Used from chat file modal, we can assume that the file content will be available from item.get("file").get("data", {}).get("content")
|
||||||
query_result = {
|
query_result = {
|
||||||
"documents": [[file.get("file").get("data", {}).get("content")]],
|
"documents": [[item.get("file").get("data", {}).get("content")]],
|
||||||
"metadatas": [
|
"metadatas": [
|
||||||
[{"file_id": file.get("id"), "name": file.get("name")}]
|
[{"file_id": item.get("id"), "name": item.get("name")}]
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
elif file.get("type") == "collection":
|
elif item.get("type") == "collection":
|
||||||
# Manual Full Mode Toggle for Collection
|
# Manual Full Mode Toggle for Collection
|
||||||
file_ids = file.get("data", {}).get("file_ids", [])
|
knowledge_base = Knowledges.get_knowledge_by_id(item.get("id"))
|
||||||
|
|
||||||
documents = []
|
if knowledge_base and (
|
||||||
metadatas = []
|
user.role == "admin"
|
||||||
for file_id in file_ids:
|
or has_access(user.id, "read", knowledge_base.access_control)
|
||||||
file_object = Files.get_file_by_id(file_id)
|
):
|
||||||
|
|
||||||
if file_object:
|
file_ids = knowledge_base.data.get("file_ids", [])
|
||||||
documents.append(file_object.data.get("content", ""))
|
|
||||||
metadatas.append(
|
|
||||||
{
|
|
||||||
"file_id": file_id,
|
|
||||||
"name": file_object.filename,
|
|
||||||
"source": file_object.filename,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
query_result = {
|
documents = []
|
||||||
"documents": [documents],
|
metadatas = []
|
||||||
"metadatas": [metadatas],
|
for file_id in file_ids:
|
||||||
}
|
file_object = Files.get_file_by_id(file_id)
|
||||||
|
|
||||||
|
if file_object:
|
||||||
|
documents.append(file_object.data.get("content", ""))
|
||||||
|
metadatas.append(
|
||||||
|
{
|
||||||
|
"file_id": file_id,
|
||||||
|
"name": file_object.filename,
|
||||||
|
"source": file_object.filename,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
query_result = {
|
||||||
|
"documents": [documents],
|
||||||
|
"metadatas": [metadatas],
|
||||||
|
}
|
||||||
elif (
|
elif (
|
||||||
file.get("type") != "web_search"
|
item.get("type") != "web_search"
|
||||||
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||||
):
|
):
|
||||||
# BYPASS_EMBEDDING_AND_RETRIEVAL
|
# BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||||
if file.get("type") == "collection":
|
if item.get("type") == "collection":
|
||||||
file_ids = file.get("data", {}).get("file_ids", [])
|
file_ids = item.get("data", {}).get("file_ids", [])
|
||||||
|
|
||||||
documents = []
|
documents = []
|
||||||
metadatas = []
|
metadatas = []
|
||||||
|
|
@ -545,46 +562,46 @@ def get_sources_from_files(
|
||||||
"metadatas": [metadatas],
|
"metadatas": [metadatas],
|
||||||
}
|
}
|
||||||
|
|
||||||
elif file.get("id"):
|
elif item.get("id"):
|
||||||
file_object = Files.get_file_by_id(file.get("id"))
|
file_object = Files.get_file_by_id(item.get("id"))
|
||||||
if file_object:
|
if file_object:
|
||||||
query_result = {
|
query_result = {
|
||||||
"documents": [[file_object.data.get("content", "")]],
|
"documents": [[file_object.data.get("content", "")]],
|
||||||
"metadatas": [
|
"metadatas": [
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"file_id": file.get("id"),
|
"file_id": item.get("id"),
|
||||||
"name": file_object.filename,
|
"name": file_object.filename,
|
||||||
"source": file_object.filename,
|
"source": file_object.filename,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
elif file.get("file").get("data"):
|
elif item.get("file").get("data"):
|
||||||
query_result = {
|
query_result = {
|
||||||
"documents": [[file.get("file").get("data", {}).get("content")]],
|
"documents": [[item.get("file").get("data", {}).get("content")]],
|
||||||
"metadatas": [
|
"metadatas": [
|
||||||
[file.get("file").get("data", {}).get("metadata", {})]
|
[item.get("file").get("data", {}).get("metadata", {})]
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
collection_names = []
|
collection_names = []
|
||||||
if file.get("type") == "collection":
|
if item.get("type") == "collection":
|
||||||
if file.get("legacy"):
|
if item.get("legacy"):
|
||||||
collection_names = file.get("collection_names", [])
|
collection_names = item.get("collection_names", [])
|
||||||
else:
|
else:
|
||||||
collection_names.append(file["id"])
|
collection_names.append(item["id"])
|
||||||
elif file.get("collection_name"):
|
elif item.get("collection_name"):
|
||||||
collection_names.append(file["collection_name"])
|
collection_names.append(item["collection_name"])
|
||||||
elif file.get("id"):
|
elif item.get("id"):
|
||||||
if file.get("legacy"):
|
if item.get("legacy"):
|
||||||
collection_names.append(f"{file['id']}")
|
collection_names.append(f"{item['id']}")
|
||||||
else:
|
else:
|
||||||
collection_names.append(f"file-{file['id']}")
|
collection_names.append(f"file-{item['id']}")
|
||||||
|
|
||||||
collection_names = set(collection_names).difference(extracted_collections)
|
collection_names = set(collection_names).difference(extracted_collections)
|
||||||
if not collection_names:
|
if not collection_names:
|
||||||
log.debug(f"skipping {file} as it has already been extracted")
|
log.debug(f"skipping {item} as it has already been extracted")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if full_context:
|
if full_context:
|
||||||
|
|
@ -596,14 +613,14 @@ def get_sources_from_files(
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
query_result = None
|
query_result = None
|
||||||
if file.get("type") == "text":
|
if item.get("type") == "text":
|
||||||
# Not sure when this is used, but it seems to be a fallback
|
# Not sure when this is used, but it seems to be a fallback
|
||||||
query_result = {
|
query_result = {
|
||||||
"documents": [
|
"documents": [
|
||||||
[file.get("file").get("data", {}).get("content")]
|
[item.get("file").get("data", {}).get("content")]
|
||||||
],
|
],
|
||||||
"metadatas": [
|
"metadatas": [
|
||||||
[file.get("file").get("data", {}).get("meta", {})]
|
[item.get("file").get("data", {}).get("meta", {})]
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
|
@ -638,10 +655,10 @@ def get_sources_from_files(
|
||||||
extracted_collections.extend(collection_names)
|
extracted_collections.extend(collection_names)
|
||||||
|
|
||||||
if query_result:
|
if query_result:
|
||||||
if "data" in file:
|
if "data" in item:
|
||||||
del file["data"]
|
del item["data"]
|
||||||
|
|
||||||
query_results.append({**query_result, "file": file})
|
query_results.append({**query_result, "file": item})
|
||||||
|
|
||||||
sources = []
|
sources = []
|
||||||
for query_result in query_results:
|
for query_result in query_results:
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,7 @@ from open_webui.models.users import UserModel
|
||||||
from open_webui.models.functions import Functions
|
from open_webui.models.functions import Functions
|
||||||
from open_webui.models.models import Models
|
from open_webui.models.models import Models
|
||||||
|
|
||||||
from open_webui.retrieval.utils import get_sources_from_files
|
from open_webui.retrieval.utils import get_sources_from_items
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.chat import generate_chat_completion
|
from open_webui.utils.chat import generate_chat_completion
|
||||||
|
|
@ -638,14 +638,14 @@ async def chat_completion_files_handler(
|
||||||
queries = [get_last_user_message(body["messages"])]
|
queries = [get_last_user_message(body["messages"])]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Offload get_sources_from_files to a separate thread
|
# Offload get_sources_from_items to a separate thread
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
sources = await loop.run_in_executor(
|
sources = await loop.run_in_executor(
|
||||||
executor,
|
executor,
|
||||||
lambda: get_sources_from_files(
|
lambda: get_sources_from_items(
|
||||||
request=request,
|
request=request,
|
||||||
files=files,
|
items=files,
|
||||||
queries=queries,
|
queries=queries,
|
||||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||||||
query, prefix=prefix, user=user
|
query, prefix=prefix, user=user
|
||||||
|
|
@ -657,6 +657,7 @@ async def chat_completion_files_handler(
|
||||||
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
|
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
|
||||||
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||||
full_context=request.app.state.config.RAG_FULL_CONTEXT,
|
full_context=request.app.state.config.RAG_FULL_CONTEXT,
|
||||||
|
user=user,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -2152,7 +2153,9 @@ async def process_chat_response(
|
||||||
if isinstance(tool_result, dict) or isinstance(
|
if isinstance(tool_result, dict) or isinstance(
|
||||||
tool_result, list
|
tool_result, list
|
||||||
):
|
):
|
||||||
tool_result = json.dumps(tool_result, indent=2, ensure_ascii=False)
|
tool_result = json.dumps(
|
||||||
|
tool_result, indent=2, ensure_ascii=False
|
||||||
|
)
|
||||||
|
|
||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue