open-webui/backend/open_webui/retrieval/utils.py

1324 lines
45 KiB
Python
Raw Normal View History

import logging
2024-08-27 22:10:27 +00:00
import os
2025-11-23 02:33:14 +00:00
from typing import Awaitable, Optional, Union
2024-03-09 03:26:39 +00:00
2024-08-27 22:10:27 +00:00
import requests
import aiohttp
import asyncio
2025-02-27 07:51:39 +00:00
import hashlib
2025-03-31 14:43:37 +00:00
from concurrent.futures import ThreadPoolExecutor
2025-05-20 02:58:04 +00:00
import time
2025-10-04 07:02:26 +00:00
import re
2024-09-10 01:27:50 +00:00
from urllib.parse import quote
2024-04-25 12:49:59 +00:00
from huggingface_hub import snapshot_download
2024-08-27 22:10:27 +00:00
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
2024-08-27 22:10:27 +00:00
from langchain_core.documents import Document
2024-09-10 01:27:50 +00:00
2025-01-08 08:21:50 +00:00
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
2025-02-20 19:02:45 +00:00
2025-10-04 07:02:26 +00:00
from open_webui.models.users import UserModel
2025-02-26 23:42:19 +00:00
from open_webui.models.files import Files
2025-07-11 08:00:21 +00:00
from open_webui.models.knowledge import Knowledges
2025-09-14 08:26:46 +00:00
from open_webui.models.chats import Chats
2025-07-08 21:17:25 +00:00
from open_webui.models.notes import Notes
2024-04-14 23:48:15 +00:00
2025-03-31 03:48:22 +00:00
from open_webui.retrieval.vector.main import GetResult
2025-07-11 08:00:21 +00:00
from open_webui.utils.access_control import has_access
from open_webui.utils.headers import include_user_info_headers
2025-09-14 08:26:46 +00:00
from open_webui.utils.misc import get_message_list
2025-03-31 03:48:22 +00:00
2025-10-04 07:02:26 +00:00
from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.loaders.youtube import YoutubeLoader
2025-02-05 23:15:24 +00:00
2025-02-05 08:07:45 +00:00
from open_webui.env import (
SRC_LOG_LEVELS,
OFFLINE_MODE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
2025-02-04 21:04:36 +00:00
from open_webui.config import (
2025-03-31 04:55:15 +00:00
RAG_EMBEDDING_QUERY_PREFIX,
RAG_EMBEDDING_CONTENT_PREFIX,
RAG_EMBEDDING_PREFIX_FIELD_NAME,
2025-02-04 21:04:36 +00:00
)
2024-09-10 01:27:50 +00:00
# Module-level logger for the retrieval utilities; its level follows the
# "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(name=__name__)
log.setLevel(level=SRC_LOG_LEVELS["RAG"])
2024-03-09 03:26:39 +00:00
2024-09-10 03:37:06 +00:00
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
2025-10-04 07:02:26 +00:00
# Scheme and "www." are optional; at least one character must follow the host's "/".
_YOUTUBE_URL_PATTERN = re.compile(r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$")


def is_youtube_url(url: str) -> bool:
    """Return True when *url* is a youtube.com or youtu.be link with a path."""
    return bool(_YOUTUBE_URL_PATTERN.match(url))
def get_loader(request, url: str):
    """Choose a document loader for *url*.

    YouTube links get a ``YoutubeLoader`` configured with the app's language
    and proxy settings; every other URL goes through the generic web loader
    with the app's SSL-verification, rate and trust-env settings.
    """
    config = request.app.state.config
    if not is_youtube_url(url):
        return get_web_loader(
            url,
            verify_ssl=config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
            requests_per_second=config.WEB_LOADER_CONCURRENT_REQUESTS,
            trust_env=config.WEB_SEARCH_TRUST_ENV,
        )
    return YoutubeLoader(
        url,
        language=config.YOUTUBE_LOADER_LANGUAGE,
        proxy_url=config.YOUTUBE_LOADER_PROXY_URL,
    )
def get_content_from_url(request, url: str) -> tuple[str, list[Document]]:
    """Load *url* and return its text together with the loaded documents.

    Args:
        request: Incoming request; its app config selects the loader.
        url: Page or YouTube URL to load.

    Returns:
        A ``(content, docs)`` tuple. ``content`` is every document's
        ``page_content`` joined with single spaces. ``docs`` is the list
        returned by the loader.
    """
    docs = get_loader(request, url).load()
    content = " ".join(doc.page_content for doc in docs)
    return content, docs
2024-09-10 03:37:06 +00:00
class VectorSearchRetriever(BaseRetriever):
    """LangChain retriever that runs a dense vector search on one collection."""

    collection_name: Any  # vector DB collection to search
    embedding_function: Any  # async callable, awaited as (query, prefix)
    top_k: int  # maximum number of hits requested from the vector DB

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Synchronous retrieval is not implemented; always returns no documents.

        The embedding function is awaited, so only the async path
        (``_aget_relevant_documents``, e.g. via ``ainvoke``) performs a search.

        Args:
            query: String to find relevant documents for (unused).
            run_manager: The callback handler to use (unused).

        Returns:
            An empty list.
        """
        return []

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Embed *query* and return up to ``top_k`` nearest chunks as Documents.

        Args:
            query: Search query; embedded with ``RAG_EMBEDDING_QUERY_PREFIX``.
            run_manager: The callback handler (unused).

        Returns:
            One ``Document`` per hit, or an empty list when the vector DB
            returns no result.
        """
        embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)

        result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[embedding],
            limit=self.top_k,
        )
        # The client result may be falsy (query_doc guards the same call);
        # indexing result.ids[0] would then crash.
        if not result or not result.ids:
            return []

        # Results are grouped per query vector; a single vector was sent.
        return [
            Document(metadata=metadata, page_content=document)
            for _id, metadata, document in zip(
                result.ids[0], result.metadatas[0], result.documents[0]
            )
        ]
2024-09-10 03:37:06 +00:00
2024-04-27 19:38:50 +00:00
def query_doc(
    collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
    """Run a nearest-neighbour search for one embedding in one collection.

    Args:
        collection_name: Vector DB collection to search.
        query_embedding: The query vector.
        k: Maximum number of hits to request.
        user: Accepted for API compatibility; not used here.

    Returns:
        Whatever ``VECTOR_DB_CLIENT.search`` returns (may be falsy).

    Raises:
        Any exception from the vector DB, re-raised after it is logged.
    """
    search_kwargs = {
        "collection_name": collection_name,
        "vectors": [query_embedding],
        "limit": k,
    }
    try:
        log.debug(f"query_doc:doc {collection_name}")
        search_result = VECTOR_DB_CLIENT.search(**search_kwargs)
        if search_result:
            log.info(f"query_doc:result {search_result.ids} {search_result.metadatas}")
        return search_result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
        raise
2024-04-25 21:03:00 +00:00
2025-02-19 05:14:58 +00:00
def get_doc(collection_name: str, user: UserModel = None):
    """Return every stored item of *collection_name* from the vector DB.

    Args:
        collection_name: Collection to read.
        user: Accepted for API compatibility; not used here.

    Returns:
        The result of ``VECTOR_DB_CLIENT.get`` (may be falsy).

    Raises:
        Any exception from the vector DB, re-raised after it is logged.
    """
    try:
        log.debug(f"get_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
        if result:
            log.info(f"get_doc:result {result.ids} {result.metadatas}")
        return result
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        raise
def get_enriched_texts(collection_result: GetResult) -> list[str]:
    """Build BM25 input texts by appending selected metadata to each chunk.

    For every chunk in ``collection_result.documents[0]``, the matching entry
    of ``collection_result.metadatas[0]`` contributes (when present) the
    filename plus a tokenised copy of it (repeated for extra BM25 weight), the
    title, the section headings, the source, and a web-search snippet.

    Args:
        collection_result: A get result whose ``documents[0]`` and
            ``metadatas[0]`` are parallel lists.

    Returns:
        One enriched string per chunk, in the original order.
    """
    enriched_texts = []
    metadatas = collection_result.metadatas[0]
    for idx, text in enumerate(collection_result.documents[0]):
        # Chunks stored without metadata can come back as None.
        metadata = metadatas[idx] or {}
        metadata_parts = [text]

        # Filename, plus a tokenised form repeated twice for extra BM25 weight.
        if metadata.get("name"):
            filename = str(metadata["name"])
            filename_tokens = (
                filename.replace("_", " ").replace("-", " ").replace(".", " ")
            )
            metadata_parts.append(
                f"Filename: {filename} {filename_tokens} {filename_tokens}"
            )

        if metadata.get("title"):
            metadata_parts.append(f"Title: {metadata['title']}")

        # Section headings (a list), e.g. as produced by a markdown splitter.
        if metadata.get("headings") and isinstance(metadata["headings"], list):
            headings = " > ".join(str(h) for h in metadata["headings"])
            metadata_parts.append(f"Section: {headings}")

        if metadata.get("source"):
            metadata_parts.append(f"Source: {metadata['source']}")

        # Snippet attached to web search results.
        if metadata.get("snippet"):
            metadata_parts.append(f"Snippet: {metadata['snippet']}")

        enriched_texts.append(" ".join(metadata_parts))
    return enriched_texts
async def query_doc_with_hybrid_search(
    collection_name: str,
    collection_result: GetResult,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
    enable_enriched_texts: bool = False,
) -> dict:
    """Run BM25 + vector hybrid search over one collection and rerank the hits.

    Args:
        collection_name: Collection searched by the vector retriever.
        collection_result: Pre-fetched collection contents. Its
            ``documents[0]`` and ``metadatas[0]`` feed the BM25 retriever.
        query: Search query.
        embedding_function: Async embedding callable. It is passed to the
            vector retriever and to ``RerankCompressor``.
        k: Hits requested from each retriever, and the maximum number returned.
        reranking_function: Passed to ``RerankCompressor``.
        k_reranker: Passed to ``RerankCompressor`` as ``top_n``.
        r: Passed to ``RerankCompressor`` as ``r_score`` (presumably a
            minimum relevance score — confirm against RerankCompressor).
        hybrid_bm25_weight: Weight of the BM25 retriever.
            ``<= 0`` uses vector search only and ``>= 1`` uses BM25 only.
            Any other value ensembles both with weights ``(w, 1 - w)``.
        enable_enriched_texts: Build the BM25 index from metadata-enriched
            texts (``get_enriched_texts``) instead of the raw chunks.

    Returns:
        ``{"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}``.
        Scores come from each hit's ``metadata["score"]``. When the collection
        has no documents, the three values are flat empty lists instead.

    Raises:
        Any retrieval or reranking error, re-raised after it is logged.
    """
    try:
        # Guard against a missing or malformed pre-fetched collection result.
        if (
            not collection_result
            or not hasattr(collection_result, "documents")
            or not hasattr(collection_result, "metadatas")
        ):
            log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
            return {"documents": [], "metadatas": [], "distances": []}

        # The attributes exist; make sure there is at least one document.
        if not collection_result.documents or not collection_result.documents[0]:
            log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
            return {"documents": [], "metadatas": [], "distances": []}

        log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")

        vector_search_retriever = VectorSearchRetriever(
            collection_name=collection_name,
            embedding_function=embedding_function,
            top_k=k,
        )

        if hybrid_bm25_weight <= 0:
            # Pure vector search: building a BM25 index over the whole
            # collection would be wasted work.
            ensemble_retriever = EnsembleRetriever(
                retrievers=[vector_search_retriever], weights=[1.0]
            )
        else:
            bm25_texts = (
                get_enriched_texts(collection_result)
                if enable_enriched_texts
                else collection_result.documents[0]
            )
            bm25_retriever = BM25Retriever.from_texts(
                texts=bm25_texts,
                metadatas=collection_result.metadatas[0],
            )
            bm25_retriever.k = k

            if hybrid_bm25_weight >= 1:
                ensemble_retriever = EnsembleRetriever(
                    retrievers=[bm25_retriever], weights=[1.0]
                )
            else:
                ensemble_retriever = EnsembleRetriever(
                    retrievers=[bm25_retriever, vector_search_retriever],
                    weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
                )

        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k_reranker,
            reranking_function=reranking_function,
            r_score=r,
        )
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = await compression_retriever.ainvoke(query)

        distances = [d.metadata.get("score") for d in result]
        documents = [d.page_content for d in result]
        metadatas = [d.metadata for d in result]

        # The reranker keeps up to k_reranker hits; when k is smaller, keep
        # only the k best-scoring ones.
        if k < k_reranker:
            sorted_items = sorted(
                zip(distances, documents, metadatas),
                key=lambda item: item[0],
                reverse=True,
            )[:k]
            if sorted_items:
                # Unpack in the same order as the zip() above, so that
                # documents and metadatas are not swapped.
                distances, documents, metadatas = map(list, zip(*sorted_items))
            else:
                distances, documents, metadatas = [], [], []

        result = {
            "distances": [distances],
            "documents": [documents],
            "metadatas": [metadatas],
        }

        log.info(
            "query_doc_with_hybrid_search:result "
            + f'{result["metadatas"]} {result["distances"]}'
        )
        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
        raise
2025-02-19 05:14:58 +00:00
def merge_get_results(get_results: list[dict]) -> dict:
    """Concatenate several ``get`` results into a single one-batch result.

    Each input carries ``documents``, ``metadatas`` and ``ids`` as lists whose
    first element holds the items. Those inner lists are chained in input
    order, and each output field wraps its combined list in an outer list.
    """
    merged = {"documents": [], "metadatas": [], "ids": []}
    for get_result in get_results:
        for field, bucket in merged.items():
            bucket.extend(get_result[field][0])
    return {field: [bucket] for field, bucket in merged.items()}
2025-03-25 18:09:17 +00:00
def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
    """Merge several query results, de-duplicate by content, and keep the top k.

    Args:
        query_results: Search results, each shaped like
            ``{"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}``.
            A result missing any of the three fields is skipped. That covers a
            field that is absent, ``None`` or empty.
        k: Maximum number of documents to return. ``k <= 0`` returns none,
            and ``None`` means no limit.

    Returns:
        A single result in the same nested shape, sorted by score
        (``distances``) in descending order. Identical document texts are
        merged, keeping the entry with the highest score. Non-string
        documents are dropped.
    """
    # SHA-256 of the document text -> (distance, document, metadata)
    best_by_hash: dict[str, tuple] = {}

    for data in query_results:
        distances_batches = data.get("distances")
        documents_batches = data.get("documents")
        metadatas_batches = data.get("metadatas")
        # `not x` covers a missing key, an explicit None and an empty list.
        if not distances_batches or not documents_batches or not metadatas_batches:
            continue

        for distance, document, metadata in zip(
            distances_batches[0], documents_batches[0], metadatas_batches[0]
        ):
            if not isinstance(document, str):
                continue
            doc_hash = hashlib.sha256(document.encode()).hexdigest()
            current = best_by_hash.get(doc_hash)
            # Keep the first occurrence unless a later one scores strictly higher.
            if current is None or distance > current[0]:
                best_by_hash[doc_hash] = (distance, document, metadata)

    ranked = sorted(best_by_hash.values(), key=lambda item: item[0], reverse=True)
    # A negative k would slice items off the end instead of returning none,
    # so clamp it at zero. None keeps the "no limit" behaviour of a[:None].
    top = ranked if k is None else ranked[: max(k, 0)]

    return {
        "distances": [[item[0] for item in top]],
        "documents": [[item[1] for item in top]],
        "metadatas": [[item[2] for item in top]],
    }
2024-03-09 03:26:39 +00:00
2025-02-19 05:14:58 +00:00
def get_all_items_from_collections(collection_names: list[str]) -> dict:
    """Fetch and merge the full contents of several vector DB collections.

    Falsy names are skipped. A collection that fails to load is logged and
    left out rather than aborting the whole merge.
    """
    fetched = []
    for name in filter(None, collection_names):
        try:
            doc = get_doc(collection_name=name)
            if doc is not None:
                fetched.append(doc.model_dump())
        except Exception as e:
            log.exception(f"Error when querying the collection: {e}")
    return merge_get_results(fetched)
async def query_collection(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
) -> dict:
    """Vector-search every query against every collection and merge the hits.

    Args:
        collection_names: Collections to search; falsy names are skipped.
        queries: Query strings, embedded together in a single call.
        embedding_function: Async callable awaited as
            ``(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)``. It should return
            the list of query embeddings.
        k: Hits requested per (query, collection) pair, and kept after merging.

    Returns:
        The merged, de-duplicated top-``k`` result from
        ``merge_and_sort_query_results``. A collection query that fails is
        logged and skipped.
    """

    def process_query_collection(collection_name, query_embedding):
        # Runs in a worker thread because query_doc calls the vector DB
        # client synchronously.
        try:
            if collection_name:
                result = query_doc(
                    collection_name=collection_name,
                    k=k,
                    query_embedding=query_embedding,
                )
                if result is not None:
                    return result.model_dump(), None
            return None, None
        except Exception as e:
            log.exception(f"Error when querying the collection: {e}")
            return None, e

    # Generate all query embeddings (in one call)
    query_embeddings = await embedding_function(
        queries, prefix=RAG_EMBEDDING_QUERY_PREFIX
    )
    if query_embeddings is None:
        # Guard against an embedding function that reports failure as None;
        # iterating None would raise a TypeError below.
        log.warning("query_collection: failed to generate query embeddings.")
        return merge_and_sort_query_results([], k=k)

    log.debug(
        f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
    )

    # Run the blocking DB calls in threads and await them. Calling
    # Future.result() directly here would block the event loop.
    task_results = await asyncio.gather(
        *(
            asyncio.to_thread(
                process_query_collection, collection_name, query_embedding
            )
            for query_embedding in query_embeddings
            for collection_name in collection_names
        )
    )

    results = []
    error = False
    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        log.warning("All collection queries failed. No results returned.")
    return merge_and_sort_query_results(results, k=k)
2024-04-27 19:38:50 +00:00
async def query_collection_with_hybrid_search(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
    enable_enriched_texts: bool = False,
) -> dict:
    """Run hybrid (BM25 + vector) search for every query in every collection.

    Each collection's contents are fetched once, up front. Every
    (collection, query) pair then runs ``query_doc_with_hybrid_search``
    concurrently, and the successful results are merged with
    ``merge_and_sort_query_results``. Collections whose fetch failed are
    skipped.

    Raises:
        Exception: if at least one search failed and none succeeded, so that
            the caller can fall back to non-hybrid search.
    """
    # Fetch sequentially, once per listed collection, so every query reuses
    # the same data. A failed fetch is recorded as None.
    fetched = {}
    for name in collection_names:
        try:
            log.debug(
                f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {name}"
            )
            fetched[name] = VECTOR_DB_CLIENT.get(collection_name=name)
        except Exception as e:
            log.exception(f"Failed to fetch collection {name}: {e}")
            fetched[name] = None

    log.info(
        f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
    )

    async def run_one(name, query):
        # Returns (result, None) on success and (None, error) on failure.
        try:
            outcome = await query_doc_with_hybrid_search(
                collection_name=name,
                collection_result=fetched[name],
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                k_reranker=k_reranker,
                r=r,
                hybrid_bm25_weight=hybrid_bm25_weight,
                enable_enriched_texts=enable_enriched_texts,
            )
        except Exception as e:
            log.exception(f"Error when querying the collection with hybrid_search: {e}")
            return None, e
        return outcome, None

    # One coroutine per (collection, query) pair, all awaited concurrently.
    outcomes = await asyncio.gather(
        *(
            run_one(name, query)
            for name in collection_names
            if fetched[name] is not None
            for query in queries
        )
    )

    successes = [
        result for result, err in outcomes if err is None and result is not None
    ]
    any_failed = any(err is not None for _, err in outcomes)

    if any_failed and not successes:
        raise Exception(
            "Hybrid search failed for all collections. Using Non-hybrid search as fallback."
        )
    return merge_and_sort_query_results(successes, k=k)
2024-04-14 21:55:00 +00:00
2025-03-27 08:40:28 +00:00
2025-11-23 02:33:14 +00:00
def generate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = "https://api.openai.com/v1",
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Synchronously embed *texts* with an OpenAI-compatible ``/embeddings`` API.

    Args:
        model: Embedding model name sent in the request body.
        texts: Batch of input strings.
        url: API base URL; ``/embeddings`` is appended.
        key: Bearer token for the ``Authorization`` header.
        prefix: Sent as the ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` body field,
            but only when that setting is a string.
        user: When ``ENABLE_FORWARD_USER_INFO_HEADERS`` is set, passed to
            ``include_user_info_headers`` to extend the request headers.

    Returns:
        The ``embedding`` of each item in the response's ``data`` list. Any
        failure is logged and ``None`` is returned.
    """
    try:
        log.debug(
            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        r = requests.post(
            f"{url}/embeddings",
            headers=headers,
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return [elem["embedding"] for elem in data["data"]]
        # Unexpected response shape. Raising a plain string is itself a
        # TypeError, so raise a real exception.
        raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating openai batch embeddings: {e}")
        return None
async def agenerate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = "https://api.openai.com/v1",
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Asynchronously embed *texts* through an OpenAI-compatible ``/embeddings`` API.

    The request body carries ``input`` and ``model``, plus the prefix field
    when ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` is a string. Returns the
    ``embedding`` of each ``data`` item. Any failure is logged and ``None``
    is returned.
    """
    try:
        log.debug(
            f"agenerate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        payload = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            payload[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            request_headers = include_user_info_headers(request_headers, user)

        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.post(
                f"{url}/embeddings", headers=request_headers, json=payload
            ) as response:
                response.raise_for_status()
                body = await response.json()

        if "data" not in body:
            raise Exception("Something went wrong :/")
        return [item["embedding"] for item in body["data"]]
    except Exception as e:
        log.exception(f"Error generating openai batch embeddings: {e}")
        return None
2025-11-23 02:33:14 +00:00
def generate_azure_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    version: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Synchronously embed *texts* with an Azure OpenAI deployment.

    Posts to ``{url}/openai/deployments/{model}/embeddings?api-version={version}``.
    On HTTP 429 it waits for ``Retry-After`` seconds (1 second when the header
    is absent or not a number) and retries, for up to 5 attempts in total.

    Args:
        model: Azure deployment name.
        texts: Batch of input strings.
        url: Azure resource base URL.
        key: Sent in the ``api-key`` header.
        version: ``api-version`` query parameter.
        prefix: Sent as the ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` body field,
            but only when that setting is a string.
        user: When ``ENABLE_FORWARD_USER_INFO_HEADERS`` is set, passed to
            ``include_user_info_headers``.

    Returns:
        One embedding per input text. Returns ``None`` on failure or when
        every attempt was rate-limited; both cases are logged.
    """
    max_attempts = 5
    try:
        log.debug(
            f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
        full_url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"

        # Headers do not change between attempts, so build them once.
        headers = {
            "Content-Type": "application/json",
            "api-key": key,
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        for _ in range(max_attempts):
            r = requests.post(
                full_url,
                headers=headers,
                json=json_data,
            )
            if r.status_code == 429:
                # Retry-After may also be an HTTP date. Fall back to 1s rather
                # than aborting the whole request on a float() error.
                try:
                    retry = max(0.0, float(r.headers.get("Retry-After", "1")))
                except ValueError:
                    retry = 1.0
                time.sleep(retry)
                continue
            r.raise_for_status()
            data = r.json()
            if "data" in data:
                return [elem["embedding"] for elem in data["data"]]
            raise Exception("Something went wrong :/")

        log.warning(
            f"generate_azure_openai_batch_embeddings: still rate limited after {max_attempts} attempts"
        )
        return None
    except Exception as e:
        log.exception(f"Error generating azure openai batch embeddings: {e}")
        return None
async def agenerate_azure_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    version: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Asynchronously embed *texts* with an Azure OpenAI deployment.

    Posts to ``{url}/openai/deployments/{model}/embeddings?api-version={version}``.
    Like the synchronous variant, it retries on HTTP 429: it waits for
    ``Retry-After`` seconds (1 second when the header is absent or not a
    number), for up to 5 attempts in total.

    Args:
        model: Azure deployment name.
        texts: Batch of input strings.
        url: Azure resource base URL.
        key: Sent in the ``api-key`` header.
        version: ``api-version`` query parameter.
        prefix: Sent as the ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` body field,
            but only when that setting is a string.
        user: When ``ENABLE_FORWARD_USER_INFO_HEADERS`` is set, passed to
            ``include_user_info_headers``.

    Returns:
        One embedding per input text. Returns ``None`` on failure or when
        every attempt was rate-limited; both cases are logged.
    """
    max_attempts = 5
    try:
        log.debug(
            f"agenerate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
        )
        form_data = {"input": texts}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
        full_url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
        headers = {
            "Content-Type": "application/json",
            "api-key": key,
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        async with aiohttp.ClientSession(trust_env=True) as session:
            for _ in range(max_attempts):
                async with session.post(full_url, headers=headers, json=form_data) as r:
                    if r.status != 429:
                        r.raise_for_status()
                        data = await r.json()
                        if "data" in data:
                            return [item["embedding"] for item in data["data"]]
                        raise Exception("Something went wrong :/")
                    retry_after = r.headers.get("Retry-After", "1")
                # Rate limited. Sleep outside the response context so the
                # connection is released while waiting.
                try:
                    delay = max(0.0, float(retry_after))
                except ValueError:
                    delay = 1.0
                await asyncio.sleep(delay)

        log.warning(
            f"agenerate_azure_openai_batch_embeddings: still rate limited after {max_attempts} attempts"
        )
        return None
    except Exception as e:
        log.exception(f"Error generating azure openai batch embeddings: {e}")
        return None
2025-11-23 02:33:14 +00:00
def generate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Synchronously embed *texts* through Ollama's ``/api/embed`` endpoint.

    Args:
        model: Ollama model name.
        texts: Batch of input strings.
        url: Ollama base URL; ``/api/embed`` is appended.
        key: Bearer token for the ``Authorization`` header.
        prefix: Sent as the ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` body field,
            but only when that setting is a string.
        user: When ``ENABLE_FORWARD_USER_INFO_HEADERS`` is set, passed to
            ``include_user_info_headers``.

    Returns:
        The response's ``embeddings`` list. Any failure is logged and
        ``None`` is returned.
    """
    try:
        log.debug(
            f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        r = requests.post(
            f"{url}/api/embed",
            headers=headers,
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()

        if "embeddings" in data:
            return data["embeddings"]
        # Unexpected response shape. Raising a plain string is itself a
        # TypeError, so raise a real exception.
        raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating ollama batch embeddings: {e}")
        return None
async def agenerate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Asynchronously embed *texts* through Ollama's ``/api/embed`` endpoint.

    The request body carries ``input`` and ``model``, plus the prefix field
    when ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` is a string. Returns the
    response's ``embeddings`` list. Any failure is logged and ``None`` is
    returned.
    """
    try:
        log.debug(
            f"agenerate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        payload = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            payload[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            request_headers = include_user_info_headers(request_headers, user)

        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.post(
                f"{url}/api/embed", headers=request_headers, json=payload
            ) as response:
                response.raise_for_status()
                body = await response.json()

        if "embeddings" not in body:
            raise Exception("Something went wrong :/")
        return body["embeddings"]
    except Exception as e:
        log.exception(f"Error generating ollama batch embeddings: {e}")
        return None
2024-04-27 19:38:50 +00:00
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    url,
    key,
    embedding_batch_size,
    azure_api_version=None,
    enable_async=True,
) -> Awaitable:
    """Build the async embedding callable used by the retrieval pipeline.

    Args:
        embedding_engine: ``""`` for a local model object, or one of
            ``"ollama"``, ``"openai"`` or ``"azure_openai"`` for remote APIs.
        embedding_model: Model or deployment name for the remote engines.
        embedding_function: Local model object whose ``encode(...)`` result
            exposes ``tolist()``. Used only when ``embedding_engine`` is ``""``.
        url: Remote API base URL.
        key: Remote API credential.
        embedding_batch_size: Maximum number of texts per remote request when
            a list is embedded. Values below 1 are treated as 1.
        azure_api_version: API version for ``"azure_openai"``.
        enable_async: Send remote batches concurrently when True, and one
            after another otherwise.

    Returns:
        An async callable ``(query, prefix=None, user=None)``. (The
        ``Awaitable`` annotation is kept for compatibility; the return value
        is a coroutine function.) For the remote engines and a list
        ``query``, the per-batch results are concatenated. A failed batch is
        logged and contributes nothing.

    Raises:
        ValueError: if ``embedding_engine`` is not recognised.
    """
    if embedding_engine == "":
        # Local model: encoding is CPU-bound and synchronous, so run it in a
        # worker thread to keep the event loop responsive. `user` is unused.
        def _encode(query, prefix=None):
            encode_kwargs = {"prompt": prefix} if prefix else {}
            return embedding_function.encode(query, **encode_kwargs).tolist()

        async def async_embedding_function(query, prefix=None, user=None):
            return await asyncio.to_thread(_encode, query, prefix)

        return async_embedding_function

    if embedding_engine not in ("ollama", "openai", "azure_openai"):
        raise ValueError(f"Unknown embedding engine: {embedding_engine}")

    # A zero step would make range() raise, and a negative one silently
    # yields no batches.
    batch_size = max(1, embedding_batch_size or 1)

    def _embed_batch(texts, prefix=None, user=None):
        # Returns the generate_embeddings coroutine; callers await it.
        return generate_embeddings(
            engine=embedding_engine,
            model=embedding_model,
            text=texts,
            prefix=prefix,
            url=url,
            key=key,
            user=user,
            azure_api_version=azure_api_version,
        )

    async def async_embedding_function(query, prefix=None, user=None):
        if not isinstance(query, list):
            return await _embed_batch(query, prefix, user)

        batches = [query[i : i + batch_size] for i in range(0, len(query), batch_size)]

        if enable_async:
            log.debug(
                f"generate_multiple_async: Processing {len(batches)} batches in parallel"
            )
            batch_results = await asyncio.gather(
                *(_embed_batch(batch, prefix=prefix, user=user) for batch in batches)
            )
        else:
            log.debug(
                f"generate_multiple_async: Processing {len(batches)} batches sequentially"
            )
            batch_results = []
            for batch in batches:
                batch_results.append(await _embed_batch(batch, prefix=prefix, user=user))

        # Flatten results. A failed batch (non-list) is reported instead of
        # silently shrinking the output relative to the input.
        embeddings = []
        for batch_number, batch_embeddings in enumerate(batch_results, start=1):
            if isinstance(batch_embeddings, list):
                embeddings.extend(batch_embeddings)
            else:
                log.error(
                    f"generate_multiple_async: batch {batch_number}/{len(batches)} returned no embeddings; "
                    "result will contain fewer embeddings than inputs"
                )
        log.debug(
            f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches"
        )
        return embeddings

    return async_embedding_function
2024-04-22 20:49:58 +00:00
2025-11-23 02:33:14 +00:00
async def generate_embeddings(
    engine: str,
    model: str,
    text: Union[str, list[str]],
    prefix: Union[str, None] = None,
    **kwargs,
):
    """Embed *text* with a remote engine.

    Args:
        engine: ``"ollama"``, ``"openai"`` or ``"azure_openai"``.
        model: Model or deployment name.
        text: A single string or a list of strings.
        prefix: Optional prefix. When ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` is
            None, it is prepended to each text. It is also forwarded to the
            batch helpers, which send it as a body field only when that
            setting is a string.
        **kwargs: ``url``, ``key`` and ``user``, plus ``azure_api_version``
            for Azure.

    Returns:
        A single embedding for a string input, or the list of embeddings for
        a list input. Returns ``None`` when the request failed, returned no
        embeddings, or the engine is unknown.
    """
    url = kwargs.get("url", "")
    key = kwargs.get("key", "")
    user = kwargs.get("user")

    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
        if isinstance(text, list):
            text = [f"{prefix}{text_element}" for text_element in text]
        else:
            text = f"{prefix}{text}"

    texts = text if isinstance(text, list) else [text]

    if engine == "ollama":
        embeddings = await agenerate_ollama_batch_embeddings(
            model=model, texts=texts, url=url, key=key, prefix=prefix, user=user
        )
    elif engine == "openai":
        embeddings = await agenerate_openai_batch_embeddings(
            model, texts, url, key, prefix, user
        )
    elif engine == "azure_openai":
        azure_api_version = kwargs.get("azure_api_version", "")
        embeddings = await agenerate_azure_openai_batch_embeddings(
            model,
            texts,
            url,
            key,
            azure_api_version,
            prefix,
            user,
        )
    else:
        log.warning(f"generate_embeddings: unsupported engine {engine!r}")
        return None

    if isinstance(text, str):
        # The batch helpers return None on failure. Indexing that (or an
        # empty list) would raise here instead of reporting "no embedding".
        return embeddings[0] if embeddings else None
    return embeddings
def get_reranking_function(reranking_engine, reranking_model, reranking_function):
    """Wrap *reranking_function* into a ``(query, documents, user=None)`` scorer.

    Returns None when no reranking function is configured. The wrapper pairs
    the query with each document's ``page_content`` and passes the pairs to
    ``reranking_function.predict``. Only the ``"external"`` engine also
    receives the ``user`` keyword. ``reranking_model`` is not used here.
    """
    if reranking_function is None:
        return None

    forward_user = reranking_engine == "external"

    def rerank(query, documents, user=None):
        pairs = [(query, doc.page_content) for doc in documents]
        if forward_user:
            return reranking_function.predict(pairs, user=user)
        return reranking_function.predict(pairs)

    return rerank
def _user_is_admin_or_owner(user, owner_id) -> bool:
    """Return True if ``user`` is an admin or the owner identified by ``owner_id``."""
    return user is not None and (user.role == "admin" or owner_id == user.id)


def _user_can_read(user, owner_id, access_control) -> bool:
    """Return True if ``user`` is admin/owner or granted read access by ``access_control``."""
    if user is None:
        return False
    return (
        user.role == "admin"
        or owner_id == user.id
        or has_access(user.id, "read", access_control)
    )


def _resolve_item_source(request, item: dict, user) -> tuple[Optional[dict], list]:
    """Resolve one attached item to direct content or to collections to search.

    Returns ``(query_result, collection_names)``: ``query_result`` is a dict
    with ``"documents"`` and ``"metadatas"`` (one inner list each) when the
    item's content can be used directly; otherwise it is None and
    ``collection_names`` lists vector collections to search (may be empty).
    """
    item_type = item.get("type")

    if item_type == "text":
        # Raw text: temporary chat file uploads or web page & YouTube attachments.
        # Inline file data wins in full-context mode; otherwise a collection
        # name (if any) takes precedence over the inline file data.
        file_data = item.get("file")
        if file_data and (
            item.get("context") == "full" or not item.get("collection_name")
        ):
            return {
                "documents": [[file_data.get("data", {}).get("content")]],
                "metadatas": [[file_data.get("meta", {})]],
            }, []
        if item.get("collection_name"):
            return None, [item.get("collection_name")]
        # Fall back to the item's own content.
        return {
            "documents": [[item.get("content")]],
            "metadatas": [[{"file_id": item.get("id"), "name": item.get("name")}]],
        }, []

    if item_type == "note":
        note = Notes.get_note_by_id(item.get("id"))
        if note and _user_can_read(user, note.user_id, note.access_control):
            return {
                "documents": [[note.data.get("content", {}).get("md", "")]],
                "metadatas": [[{"file_id": note.id, "name": note.title}]],
            }, []
        return None, []

    if item_type == "chat":
        chat = Chats.get_chat_by_id(item.get("id"))
        if chat and _user_is_admin_or_owner(user, chat.user_id):
            history = chat.chat.get("history", {})
            messages_map = history.get("messages", {})
            message_id = history.get("currentId")
            if messages_map and message_id:
                # Reconstruct the message list in order, ending at currentId.
                message_list = get_message_list(messages_map, message_id)
                message_history = "\n".join(
                    f"#### {m.get('role', 'user').capitalize()}\n{m.get('content')}\n"
                    for m in message_list
                )
                return {
                    "documents": [[message_history]],
                    "metadatas": [[{"file_id": chat.id, "name": chat.title}]],
                }, []
        return None, []

    if item_type == "url":
        content, docs = get_content_from_url(request, item.get("url"))
        if docs:
            return {
                "documents": [[content]],
                "metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
            }, []
        return None, []

    if item_type == "file":
        if (
            item.get("context") == "full"
            or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
        ):
            file_data = (item.get("file") or {}).get("data") or {}
            content = file_data.get("content", "")
            if content:
                # Content supplied with the item (e.g. from the chat file modal).
                return {
                    "documents": [[content]],
                    "metadatas": [
                        [
                            {
                                "file_id": item.get("id"),
                                "name": item.get("name"),
                                **(file_data.get("metadata") or {}),
                            }
                        ]
                    ],
                }, []
            if item.get("id"):
                file_object = Files.get_file_by_id(item.get("id"))
                if file_object:
                    return {
                        "documents": [[file_object.data.get("content", "")]],
                        "metadatas": [
                            [
                                {
                                    "file_id": item.get("id"),
                                    "name": file_object.filename,
                                    "source": file_object.filename,
                                }
                            ]
                        ],
                    }, []
            return None, []
        # Not full context: search the file's vector collection.
        if item.get("legacy"):
            return None, [f"{item['id']}"]
        return None, [f"file-{item['id']}"]

    if item_type == "collection":
        knowledge_base = Knowledges.get_knowledge_by_id(item.get("id"))
        if not (
            knowledge_base
            and _user_can_read(
                user, knowledge_base.user_id, knowledge_base.access_control
            )
        ):
            return None, []
        if (
            item.get("context") == "full"
            or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
        ):
            # Full mode: every file of the knowledge base becomes one document.
            documents = []
            metadatas = []
            for file in Knowledges.get_files_by_id(knowledge_base.id):
                documents.append(file.data.get("content", ""))
                metadatas.append(
                    {
                        "file_id": file.id,
                        "name": file.filename,
                        "source": file.filename,
                    }
                )
            return {"documents": [documents], "metadatas": [metadatas]}, []
        if item.get("legacy"):
            return None, list(item.get("collection_names") or [])
        return None, [item["id"]]

    # Untyped items.
    if item.get("docs"):
        # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: docs are already fetched.
        return {
            "documents": [[doc.get("content") for doc in item.get("docs")]],
            "metadatas": [[doc.get("metadata") for doc in item.get("docs")]],
        }, []
    if item.get("collection_name"):
        return None, [item["collection_name"]]
    if item.get("collection_names"):
        return None, list(item["collection_names"])
    return None, []


async def _search_item_collections(
    request,
    collection_names,
    queries,
    embedding_function,
    k,
    reranking_function,
    k_reranker,
    r,
    hybrid_bm25_weight,
    hybrid_search,
    full_context,
):
    """Fetch or search ``collection_names``; return the query result, or None on error."""
    try:
        if full_context:
            return get_all_items_from_collections(collection_names)

        query_result = None
        if hybrid_search:
            try:
                query_result = await query_collection_with_hybrid_search(
                    collection_names=collection_names,
                    queries=queries,
                    embedding_function=embedding_function,
                    k=k,
                    reranking_function=reranking_function,
                    k_reranker=k_reranker,
                    r=r,
                    hybrid_bm25_weight=hybrid_bm25_weight,
                    enable_enriched_texts=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
                )
            except Exception as e:
                log.debug(
                    f"Error when using hybrid search, using non hybrid search as fallback: {e}"
                )

        # Plain vector search: used when hybrid search is disabled and as the
        # fallback when hybrid search failed or produced no result.
        if query_result is None:
            query_result = await query_collection(
                collection_names=collection_names,
                queries=queries,
                embedding_function=embedding_function,
                k=k,
            )
        return query_result
    except Exception as e:
        log.exception(e)
        return None


def _sources_from_query_results(query_results: list) -> list:
    """Convert collected query results into the list of source dicts."""
    sources = []
    for query_result in query_results:
        try:
            if "documents" not in query_result or "metadatas" not in query_result:
                continue
            source = {
                "source": query_result["file"],
                "document": query_result["documents"][0],
                "metadata": query_result["metadatas"][0],
            }
            if query_result.get("distances"):
                source["distances"] = query_result["distances"][0]
            sources.append(source)
        except Exception as e:
            log.exception(e)
    return sources


async def get_sources_from_items(
    request,
    items,
    queries,
    embedding_function,
    k,
    reranking_function,
    k_reranker,
    r,
    hybrid_bm25_weight,
    hybrid_search,
    full_context=False,
    user: Optional[UserModel] = None,
):
    """Collect retrieval sources for the items attached to a request.

    Each item is resolved either to content used directly (inline text,
    notes, chats, URLs, full-context files/knowledge bases, pre-fetched docs)
    or to vector collections. Those collections are returned in full
    (``full_context``) or searched with ``queries``. A collection that was
    already processed for an earlier item is not searched again.

    Args:
        request: Incoming request; ``request.app.state.config`` supplies the
            retrieval flags.
        items: Attached item dicts. Items that produce a result have their
            ``"data"`` key removed in place.
        queries: Queries used for vector / hybrid search.
        embedding_function, k: Passed to both search paths.
        reranking_function, k_reranker, r, hybrid_bm25_weight: Passed to
            hybrid search only.
        hybrid_search: Try hybrid search first; on failure fall back to plain
            vector search.
        full_context: Return all items of the collections instead of searching.
        user: Requesting user for note/chat/knowledge-base access checks; when
            None, those items are treated as inaccessible.

    Returns:
        A list of dicts with ``"source"`` (the item), ``"document"``,
        ``"metadata"`` and, when available, ``"distances"``.
    """
    log.debug(
        f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}"
    )

    extracted_collections = set()
    query_results = []

    for item in items:
        query_result, collection_names = _resolve_item_source(request, item, user)

        if query_result is None and collection_names:
            pending = set(collection_names).difference(extracted_collections)
            if not pending:
                log.debug(f"skipping {item} as it has already been extracted")
                continue

            query_result = await _search_item_collections(
                request,
                pending,
                queries,
                embedding_function,
                k,
                reranking_function,
                k_reranker,
                r,
                hybrid_bm25_weight,
                hybrid_search,
                full_context,
            )
            extracted_collections.update(pending)

        if query_result:
            # Strip the item's "data" payload before it is attached as the source.
            item.pop("data", None)
            query_results.append({**query_result, "file": item})

    return _sources_from_query_results(query_results)
2024-04-04 18:07:42 +00:00
2024-04-25 12:49:59 +00:00
def get_model_path(model: str, update_model: bool = False):
    """Resolve ``model`` to a local path usable by sentence-transformers.

    An existing filesystem path is returned unchanged. The same applies to
    anything that looks like a path (contains a backslash or more than one
    ``/``) when only local files may be used. Otherwise ``model`` is treated
    as a Hugging Face repo id and resolved with ``snapshot_download``, using
    ``SENTENCE_TRANSFORMERS_HOME`` as the cache directory. A bare name gets
    the ``sentence-transformers/`` namespace.

    Args:
        model: Local path, repo id, or short model name.
        update_model: Allow fetching from the hub. Treated as False when
            ``OFFLINE_MODE`` is set.

    Returns:
        The snapshot path on success; otherwise the (possibly namespaced)
        model string.
    """
    # Offline mode always restricts huggingface_hub to the local cache.
    only_local = True if OFFLINE_MODE else not update_model

    download_kwargs = {
        "cache_dir": os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        "local_files_only": only_local,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {download_kwargs}")

    # Mirrors upstream sentence_transformers path detection.
    path_like = "\\" in model or model.count("/") > 1
    if os.path.exists(model) or (path_like and only_local):
        return model

    if "/" not in model:
        # Short name: qualify it with the sentence-transformers namespace.
        model = f"sentence-transformers/{model}"
    download_kwargs["repo_id"] = model

    # Ask huggingface_hub for the local snapshot path (downloading if allowed).
    try:
        snapshot_path = snapshot_download(**download_kwargs)
        log.debug(f"model_repo_path: {snapshot_path}")
        return snapshot_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
2024-04-25 12:49:59 +00:00
import operator
from typing import Optional, Sequence
from langchain_core.callbacks import Callbacks
2024-08-27 22:10:27 +00:00
from langchain_core.documents import BaseDocumentCompressor, Document
class RerankCompressor(BaseDocumentCompressor):
    """Re-score retrieved documents against a query and keep the best ones.

    If ``reranking_function`` is set, it produces the scores. Otherwise the
    query and documents are embedded with ``embedding_function`` and scored
    by cosine similarity. When ``r_score`` is truthy, documents scoring below
    it are dropped. The ``top_n`` highest-scoring documents are returned as
    new ``Document`` objects carrying the score in ``metadata["score"]``.
    """

    # Awaitable embedding callable invoked as (text_or_texts, prefix); only
    # used when reranking_function is None.
    embedding_function: Any
    # Number of highest-scoring documents to keep.
    top_n: int
    # Callable (query, documents) -> scores, or None to score by embeddings.
    reranking_function: Any
    # Minimum score a document must reach; a falsy value disables the cut-off.
    r_score: float

    class Config:
        extra = "forbid"
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Synchronous entry point; delegates to :meth:`acompress_documents`.

        Scoring is implemented asynchronously. With no event loop running in
        the current thread, the async implementation is driven to completion
        with ``asyncio.run``. Inside a running loop, blocking is not possible,
        so a warning is logged and an empty list is returned; callers there
        should use ``acompress_documents``.

        Args:
            documents: The retrieved documents.
            query: The query context.
            callbacks: Optional callbacks to run during compression.

        Returns:
            The compressed documents.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: safe to run the async pipeline here.
            return asyncio.run(self.acompress_documents(documents, query, callbacks))

        log.warning(
            "RerankCompressor.compress_documents called from a running event loop; "
            "use acompress_documents instead. Returning no documents."
        )
        return []

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Score ``documents`` against ``query`` and return the top ``top_n``.

        Args:
            documents: The retrieved documents.
            query: The query context.
            callbacks: Unused; accepted for the compressor interface.

        Returns:
            New documents sorted by descending score, with the score stored in
            ``metadata["score"]``. If no scores are produced, ``documents`` is
            returned unchanged.
        """
        if not documents:
            # Nothing to score; skip calling the reranker / embedder.
            return []

        if self.reranking_function is not None:
            scores = self.reranking_function(query, documents)
        else:
            from sentence_transformers import util

            query_embedding = await self.embedding_function(
                query, RAG_EMBEDDING_QUERY_PREFIX
            )
            document_embedding = await self.embedding_function(
                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        if scores is None:
            log.warning(
                "No valid scores found, check your reranking function. Returning original documents."
            )
            return documents

        # Reranker may return a plain list; cos_sim returns a tensor-like object.
        score_list = scores if isinstance(scores, list) else scores.tolist()
        docs_with_scores = list(zip(documents, score_list))

        if self.r_score:
            docs_with_scores = [
                (d, s) for d, s in docs_with_scores if s >= self.r_score
            ]

        ranked = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)

        # Copy metadata so the caller's input documents are not mutated.
        return [
            Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, "score": doc_score},
            )
            for doc, doc_score in ranked[: self.top_n]
        ]