# open-webui/backend/open_webui/retrieval/utils.py
import logging
import os
from typing import Awaitable, Optional, Union
import requests
import aiohttp
import asyncio
import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
import re
from urllib.parse import quote
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.models.knowledge import Knowledges
from open_webui.models.chats import Chats
from open_webui.models.notes import Notes
from open_webui.retrieval.vector.main import GetResult
from open_webui.utils.access_control import has_access
from open_webui.utils.headers import include_user_info_headers
from open_webui.utils.misc import get_message_list
from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.loaders.youtube import YoutubeLoader
from open_webui.env import (
SRC_LOG_LEVELS,
OFFLINE_MODE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
from open_webui.config import (
RAG_EMBEDDING_QUERY_PREFIX,
RAG_EMBEDDING_CONTENT_PREFIX,
RAG_EMBEDDING_PREFIX_FIELD_NAME,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

from typing import Any

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever


def is_youtube_url(url: str) -> bool:
youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
return re.match(youtube_regex, url) is not None
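
# Illustrative behavior of the regex above (hypothetical URLs, shown only as
# documentation):
#   is_youtube_url("https://www.youtube.com/watch?v=abc123")  -> True
#   is_youtube_url("youtu.be/abc123")                         -> True
#   is_youtube_url("https://example.com/watch?v=abc123")      -> False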


def get_loader(request, url: str):
if is_youtube_url(url):
return YoutubeLoader(
url,
language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
)
else:
return get_web_loader(
url,
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
)


def get_content_from_url(request, url: str) -> tuple[str, list[Document]]:
    loader = get_loader(request, url)
    docs = loader.load()
    content = " ".join([doc.page_content for doc in docs])
    return content, docs


class VectorSearchRetriever(BaseRetriever):
    collection_name: Any
    embedding_function: Any
    top_k: int

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            run_manager: The callback handler to use.

        Returns:
            List of relevant documents.
        """
        # Sync retrieval is a no-op: this retriever is only invoked through the
        # async path (ainvoke -> _aget_relevant_documents) in the hybrid search flow.
        return []

    async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
result = VECTOR_DB_CLIENT.search(
collection_name=self.collection_name,
vectors=[embedding],
limit=self.top_k,
)
ids = result.ids[0]
metadatas = result.metadatas[0]
documents = result.documents[0]
results = []
for idx in range(len(ids)):
results.append(
Document(
metadata=metadatas[idx],
page_content=documents[idx],
)
)
return results
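
# Sketch of how this retriever is wired up elsewhere in this module (see
# query_doc_with_hybrid_search); "file-abc" is a hypothetical collection name:
#   retriever = VectorSearchRetriever(
#       collection_name="file-abc",
#       embedding_function=embedding_function,  # async callable from get_embedding_function
#       top_k=5,
#   )
#   docs = await retriever.ainvoke("what changed in Q3?")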


def query_doc(
collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
try:
log.debug(f"query_doc:doc {collection_name}")
result = VECTOR_DB_CLIENT.search(
collection_name=collection_name,
vectors=[query_embedding],
limit=k,
)
if result:
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
raise e


def get_doc(collection_name: str, user: UserModel = None):
    try:
        log.debug(f"get_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)

        if result:
            log.info(f"get_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        raise e


def get_enriched_texts(collection_result: GetResult) -> list[str]:
enriched_texts = []
for idx, text in enumerate(collection_result.documents[0]):
metadata = collection_result.metadatas[0][idx]
metadata_parts = [text]
# Add filename (repeat twice for extra weight in BM25 scoring)
if metadata.get("name"):
filename = metadata["name"]
filename_tokens = (
filename.replace("_", " ").replace("-", " ").replace(".", " ")
)
metadata_parts.append(
f"Filename: {filename} {filename_tokens} {filename_tokens}"
)
# Add title if available
if metadata.get("title"):
metadata_parts.append(f"Title: {metadata['title']}")
# Add document section headings if available (from markdown splitter)
if metadata.get("headings") and isinstance(metadata["headings"], list):
headings = " > ".join(str(h) for h in metadata["headings"])
metadata_parts.append(f"Section: {headings}")
# Add source URL/path if available
if metadata.get("source"):
metadata_parts.append(f"Source: {metadata['source']}")
# Add snippet for web search results
if metadata.get("snippet"):
metadata_parts.append(f"Snippet: {metadata['snippet']}")
enriched_texts.append(" ".join(metadata_parts))
return enriched_texts
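
# Enrichment sketch (hypothetical chunk): a chunk "Quarterly revenue grew 10%."
# with metadata {"name": "q3_report.pdf", "title": "Q3 Report"} becomes
#   "Quarterly revenue grew 10%. Filename: q3_report.pdf q3 report pdf
#    q3 report pdf Title: Q3 Report"
# The tokenized filename is repeated so BM25 gives filename matches extra weight.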


async def query_doc_with_hybrid_search(
collection_name: str,
collection_result: GetResult,
query: str,
embedding_function,
k: int,
reranking_function,
k_reranker: int,
r: float,
hybrid_bm25_weight: float,
enable_enriched_texts: bool = False,
) -> dict:
try:
# First check if collection_result has the required attributes
if (
not collection_result
or not hasattr(collection_result, "documents")
or not hasattr(collection_result, "metadatas")
):
log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
return {"documents": [], "metadatas": [], "distances": []}
# Now safely check the documents content after confirming attributes exist
if (
not collection_result.documents
or len(collection_result.documents) == 0
or not collection_result.documents[0]
):
log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
return {"documents": [], "metadatas": [], "distances": []}
log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
bm25_texts = (
get_enriched_texts(collection_result)
if enable_enriched_texts
else collection_result.documents[0]
)
bm25_retriever = BM25Retriever.from_texts(
texts=bm25_texts,
metadatas=collection_result.metadatas[0],
)
bm25_retriever.k = k
vector_search_retriever = VectorSearchRetriever(
collection_name=collection_name,
embedding_function=embedding_function,
top_k=k,
)
if hybrid_bm25_weight <= 0:
ensemble_retriever = EnsembleRetriever(
retrievers=[vector_search_retriever], weights=[1.0]
)
elif hybrid_bm25_weight >= 1:
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever], weights=[1.0]
)
else:
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_search_retriever],
weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
)
compressor = RerankCompressor(
embedding_function=embedding_function,
top_n=k_reranker,
reranking_function=reranking_function,
r_score=r,
)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=ensemble_retriever
)
result = await compression_retriever.ainvoke(query)
distances = [d.metadata.get("score") for d in result]
documents = [d.page_content for d in result]
metadatas = [d.metadata for d in result]
        # Retrieve only min(k, k_reranker) items; sort and cut by score if k < k_reranker
        if k < k_reranker:
            sorted_items = sorted(
                zip(distances, documents, metadatas), key=lambda x: x[0], reverse=True
            )
            sorted_items = sorted_items[:k]
            if sorted_items:
                distances, documents, metadatas = map(list, zip(*sorted_items))
            else:
                distances, documents, metadatas = [], [], []
result = {
"distances": [distances],
"documents": [documents],
"metadatas": [metadatas],
}
log.info(
"query_doc_with_hybrid_search:result "
+ f'{result["metadatas"]} {result["distances"]}'
)
return result
except Exception as e:
log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
raise e
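
# Weighting sketch: hybrid_bm25_weight=0.0 degenerates to pure vector search,
# 1.0 to pure BM25, and e.g. 0.4 mixes the BM25 and vector retrievers with
# weights 0.4/0.6 before the reranker trims to k_reranker and, if k < k_reranker, to k.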


def merge_get_results(get_results: list[dict]) -> dict:
# Initialize lists to store combined data
combined_documents = []
combined_metadatas = []
combined_ids = []
for data in get_results:
combined_documents.extend(data["documents"][0])
combined_metadatas.extend(data["metadatas"][0])
combined_ids.extend(data["ids"][0])
# Create the output dictionary
result = {
"documents": [combined_documents],
"metadatas": [combined_metadatas],
"ids": [combined_ids],
}
return result
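
# Concatenation sketch: two GetResult dumps with 2 and 3 entries merge into one
# dict whose "documents", "metadatas", and "ids" inner lists hold 5 entries each.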


def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
# Initialize lists to store combined data
combined = dict() # To store documents with unique document hashes
for data in query_results:
if (
len(data.get("distances", [])) == 0
or len(data.get("documents", [])) == 0
or len(data.get("metadatas", [])) == 0
):
continue
distances = data["distances"][0]
documents = data["documents"][0]
metadatas = data["metadatas"][0]
for distance, document, metadata in zip(distances, documents, metadatas):
if isinstance(document, str):
doc_hash = hashlib.sha256(
document.encode()
).hexdigest() # Compute a hash for uniqueness
                if doc_hash not in combined:
                    combined[doc_hash] = (distance, document, metadata)
                    continue  # a new doc needs no further comparison

                # if the doc is already present but the new distance is better, update
                if distance > combined[doc_hash][0]:
                    combined[doc_hash] = (distance, document, metadata)
combined = list(combined.values())
# Sort the list based on distances
combined.sort(key=lambda x: x[0], reverse=True)
# Slice to keep only the top k elements
sorted_distances, sorted_documents, sorted_metadatas = (
zip(*combined[:k]) if combined else ([], [], [])
)
# Create and return the output dictionary
return {
"distances": [list(sorted_distances)],
"documents": [list(sorted_documents)],
"metadatas": [list(sorted_metadatas)],
}
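
# Illustrative call (hypothetical scores; larger distance means a better match
# here, hence the descending sort):
#   merge_and_sort_query_results(
#       [
#           {"distances": [[0.9, 0.4]], "documents": [["a", "b"]], "metadatas": [[{}, {}]]},
#           {"distances": [[0.7]], "documents": [["a"]], "metadatas": [[{}]]},
#       ],
#       k=2,
#   )
# returns {"distances": [[0.9, 0.4]], "documents": [["a", "b"]], ...}: the
# duplicate "a" keeps its best score and each document appears once.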


def get_all_items_from_collections(collection_names: list[str]) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
try:
result = get_doc(collection_name=collection_name)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
return merge_get_results(results)


async def query_collection(
collection_names: list[str],
queries: list[str],
embedding_function,
k: int,
) -> dict:
results = []
error = False

    def process_query_collection(collection_name, query_embedding):
try:
if collection_name:
result = query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
)
if result is not None:
return result.model_dump(), None
return None, None
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
return None, e
# Generate all query embeddings (in one call)
query_embeddings = await embedding_function(
queries, prefix=RAG_EMBEDDING_QUERY_PREFIX
)
log.debug(
f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
)
with ThreadPoolExecutor() as executor:
future_results = []
for query_embedding in query_embeddings:
for collection_name in collection_names:
result = executor.submit(
process_query_collection, collection_name, query_embedding
)
future_results.append(result)
task_results = [future.result() for future in future_results]
for result, err in task_results:
if err is not None:
error = True
elif result is not None:
results.append(result)
if error and not results:
log.warning("All collection queries failed. No results returned.")
return merge_and_sort_query_results(results, k=k)
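
# Fan-out sketch: with 2 queries and 3 collections, 6 query_doc calls run on the
# thread pool; failed lookups are logged, and only successful results are merged.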


async def query_collection_with_hybrid_search(
collection_names: list[str],
queries: list[str],
embedding_function,
k: int,
reranking_function,
k_reranker: int,
r: float,
hybrid_bm25_weight: float,
enable_enriched_texts: bool = False,
) -> dict:
results = []
error = False
# Fetch collection data once per collection sequentially
# Avoid fetching the same data multiple times later
collection_results = {}
for collection_name in collection_names:
try:
log.debug(
f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
)
collection_results[collection_name] = VECTOR_DB_CLIENT.get(
collection_name=collection_name
)
except Exception as e:
log.exception(f"Failed to fetch collection {collection_name}: {e}")
collection_results[collection_name] = None
log.info(
f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
)

    async def process_query(collection_name, query):
try:
result = await query_doc_with_hybrid_search(
collection_name=collection_name,
collection_result=collection_results[collection_name],
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
hybrid_bm25_weight=hybrid_bm25_weight,
enable_enriched_texts=enable_enriched_texts,
)
return result, None
except Exception as e:
log.exception(f"Error when querying the collection with hybrid_search: {e}")
return None, e
# Prepare tasks for all collections and queries
# Avoid running any tasks for collections that failed to fetch data (have assigned None)
tasks = [
(collection_name, query)
for collection_name in collection_names
if collection_results[collection_name] is not None
for query in queries
]
# Run all queries in parallel using asyncio.gather
task_results = await asyncio.gather(
*[process_query(collection_name, query) for collection_name, query in tasks]
)
for result, err in task_results:
if err is not None:
error = True
elif result is not None:
results.append(result)
    if error and not results:
        raise Exception(
            "Hybrid search failed for all collections; falling back to non-hybrid search."
        )
return merge_and_sort_query_results(results, k=k)


def generate_openai_batch_embeddings(
model: str,
texts: list[str],
url: str = "https://api.openai.com/v1",
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.post(
f"{url}/embeddings",
headers=headers,
json=json_data,
)
r.raise_for_status()
data = r.json()
if "data" in data:
return [elem["embedding"] for elem in data["data"]]
else:
raise "Something went wrong :/"
except Exception as e:
log.exception(f"Error generating openai batch embeddings: {e}")
return None


async def agenerate_openai_batch_embeddings(
model: str,
texts: list[str],
url: str = "https://api.openai.com/v1",
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"agenerate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
)
form_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
f"{url}/embeddings", headers=headers, json=form_data
) as r:
r.raise_for_status()
data = await r.json()
if "data" in data:
return [item["embedding"] for item in data["data"]]
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(f"Error generating openai batch embeddings: {e}")
return None


def generate_azure_openai_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
version: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
)
json_data = {"input": texts}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
for _ in range(5):
headers = {
"Content-Type": "application/json",
"api-key": key,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.post(
url,
headers=headers,
json=json_data,
)
if r.status_code == 429:
retry = float(r.headers.get("Retry-After", "1"))
time.sleep(retry)
continue
r.raise_for_status()
data = r.json()
if "data" in data:
return [elem["embedding"] for elem in data["data"]]
else:
raise Exception("Something went wrong :/")
return None
except Exception as e:
log.exception(f"Error generating azure openai batch embeddings: {e}")
return None


async def agenerate_azure_openai_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
version: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"agenerate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
)
form_data = {"input": texts}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
full_url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
headers = {
"Content-Type": "application/json",
"api-key": key,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(full_url, headers=headers, json=form_data) as r:
r.raise_for_status()
data = await r.json()
if "data" in data:
return [item["embedding"] for item in data["data"]]
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(f"Error generating azure openai batch embeddings: {e}")
return None


def generate_ollama_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.post(
f"{url}/api/embed",
headers=headers,
json=json_data,
)
r.raise_for_status()
data = r.json()
if "embeddings" in data:
return data["embeddings"]
else:
raise "Something went wrong :/"
except Exception as e:
log.exception(f"Error generating ollama batch embeddings: {e}")
return None


async def agenerate_ollama_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"agenerate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
)
form_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
f"{url}/api/embed", headers=headers, json=form_data
) as r:
r.raise_for_status()
data = await r.json()
if "embeddings" in data:
return data["embeddings"]
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(f"Error generating ollama batch embeddings: {e}")
return None


def get_embedding_function(
embedding_engine,
embedding_model,
embedding_function,
url,
key,
embedding_batch_size,
azure_api_version=None,
enable_async=True,
) -> Awaitable:
if embedding_engine == "":
# Sentence transformers: CPU-bound sync operation
async def async_embedding_function(query, prefix=None, user=None):
return await asyncio.to_thread(
(
lambda query, prefix=None: embedding_function.encode(
query, **({"prompt": prefix} if prefix else {})
).tolist()
),
query,
prefix,
)
return async_embedding_function
elif embedding_engine in ["ollama", "openai", "azure_openai"]:
embedding_function = lambda query, prefix=None, user=None: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
text=query,
prefix=prefix,
url=url,
key=key,
user=user,
azure_api_version=azure_api_version,
)
async def async_embedding_function(query, prefix=None, user=None):
if isinstance(query, list):
# Create batches
batches = [
query[i : i + embedding_batch_size]
for i in range(0, len(query), embedding_batch_size)
]
if enable_async:
log.debug(
f"generate_multiple_async: Processing {len(batches)} batches in parallel"
)
# Execute all batches in parallel
tasks = [
embedding_function(batch, prefix=prefix, user=user)
for batch in batches
]
batch_results = await asyncio.gather(*tasks)
else:
log.debug(
f"generate_multiple_async: Processing {len(batches)} batches sequentially"
)
batch_results = []
for batch in batches:
batch_results.append(
await embedding_function(batch, prefix=prefix, user=user)
)
# Flatten results
embeddings = []
for batch_embeddings in batch_results:
if isinstance(batch_embeddings, list):
embeddings.extend(batch_embeddings)
log.debug(
f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches"
)
return embeddings
else:
return await embedding_function(query, prefix, user)
return async_embedding_function
else:
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
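
# Usage sketch (hedged; the url/key values are placeholders, not real defaults):
#   embed = get_embedding_function(
#       "openai", "text-embedding-3-small", None,
#       url="https://api.openai.com/v1", key="sk-...",
#       embedding_batch_size=32,
#   )
#   vectors = await embed(["chunk one", "chunk two"], prefix=RAG_EMBEDDING_CONTENT_PREFIX)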


async def generate_embeddings(
engine: str,
model: str,
text: Union[str, list[str]],
prefix: Union[str, None] = None,
**kwargs,
):
url = kwargs.get("url", "")
key = kwargs.get("key", "")
user = kwargs.get("user")
if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
if isinstance(text, list):
text = [f"{prefix}{text_element}" for text_element in text]
else:
text = f"{prefix}{text}"
if engine == "ollama":
embeddings = await agenerate_ollama_batch_embeddings(
**{
"model": model,
"texts": text if isinstance(text, list) else [text],
"url": url,
"key": key,
"prefix": prefix,
"user": user,
}
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
embeddings = await agenerate_openai_batch_embeddings(
model, text if isinstance(text, list) else [text], url, key, prefix, user
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "azure_openai":
azure_api_version = kwargs.get("azure_api_version", "")
embeddings = await agenerate_azure_openai_batch_embeddings(
model,
text if isinstance(text, list) else [text],
url,
key,
azure_api_version,
prefix,
user,
)
return embeddings[0] if isinstance(text, str) else embeddings
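
# Shape note: a single input string returns one vector, a list returns a list of
# vectors, e.g. (hypothetical local Ollama endpoint):
#   vec = await generate_embeddings(
#       "ollama", "nomic-embed-text", "hello", url="http://localhost:11434", key=""
#   )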


def get_reranking_function(reranking_engine, reranking_model, reranking_function):
if reranking_function is None:
return None
if reranking_engine == "external":
return lambda query, documents, user=None: reranking_function.predict(
[(query, doc.page_content) for doc in documents], user=user
)
else:
return lambda query, documents, user=None: reranking_function.predict(
[(query, doc.page_content) for doc in documents]
)
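
# Both branches return a callable of the same shape; only the "external" engine
# forwards the user, e.g. (hypothetical arguments):
#   rerank = get_reranking_function("external", reranking_model, reranking_function)
#   scores = rerank("query", documents, user=user)  # one score per document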


async def get_sources_from_items(
request,
items,
queries,
embedding_function,
k,
reranking_function,
k_reranker,
r,
hybrid_bm25_weight,
hybrid_search,
full_context=False,
user: Optional[UserModel] = None,
):
log.debug(
f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}"
)
extracted_collections = []
query_results = []
for item in items:
query_result = None
collection_names = []
if item.get("type") == "text":
            # Raw Text
            # Used during temporary chat file uploads or web page & YouTube attachments
if item.get("context") == "full":
if item.get("file"):
# if item has file data, use it
query_result = {
"documents": [
[item.get("file", {}).get("data", {}).get("content")]
],
"metadatas": [[item.get("file", {}).get("meta", {})]],
}
if query_result is None:
# Fallback
if item.get("collection_name"):
# If item has a collection name, use it
collection_names.append(item.get("collection_name"))
elif item.get("file"):
# If item has file data, use it
query_result = {
"documents": [
[item.get("file", {}).get("data", {}).get("content")]
],
"metadatas": [[item.get("file", {}).get("meta", {})]],
}
else:
# Fallback to item content
query_result = {
"documents": [[item.get("content")]],
"metadatas": [
[{"file_id": item.get("id"), "name": item.get("name")}]
],
}
elif item.get("type") == "note":
# Note Attached
note = Notes.get_note_by_id(item.get("id"))
if note and (
user.role == "admin"
or note.user_id == user.id
or has_access(user.id, "read", note.access_control)
):
# User has access to the note
query_result = {
"documents": [[note.data.get("content", {}).get("md", "")]],
"metadatas": [[{"file_id": note.id, "name": note.title}]],
}
elif item.get("type") == "chat":
# Chat Attached
chat = Chats.get_chat_by_id(item.get("id"))
if chat and (user.role == "admin" or chat.user_id == user.id):
messages_map = chat.chat.get("history", {}).get("messages", {})
message_id = chat.chat.get("history", {}).get("currentId")
if messages_map and message_id:
# Reconstruct the message list in order
message_list = get_message_list(messages_map, message_id)
message_history = "\n".join(
[
f"#### {m.get('role', 'user').capitalize()}\n{m.get('content')}\n"
for m in message_list
]
)
# User has access to the chat
query_result = {
"documents": [[message_history]],
"metadatas": [[{"file_id": chat.id, "name": chat.title}]],
}
elif item.get("type") == "url":
content, docs = get_content_from_url(request, item.get("url"))
if docs:
query_result = {
"documents": [[content]],
"metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
}
elif item.get("type") == "file":
file_bypassed_rag = False
if item.get("id"):
file_object = Files.get_file_by_id(item.get("id"))
if file_object and file_object.meta:
file_bypassed_rag = file_object.meta.get("bypass_rag", False)
if (
item.get("context") == "full"
or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
or file_bypassed_rag
):
if item.get("file", {}).get("data", {}).get("content", ""):
# Manual Full Mode Toggle
# Used from chat file modal, we can assume that the file content will be available from item.get("file").get("data", {}).get("content")
query_result = {
"documents": [
[item.get("file", {}).get("data", {}).get("content", "")]
],
"metadatas": [
[
{
"file_id": item.get("id"),
"name": item.get("name"),
**item.get("file")
.get("data", {})
.get("metadata", {}),
}
]
],
}
elif item.get("id"):
file_object = Files.get_file_by_id(item.get("id"))
if file_object:
query_result = {
"documents": [[file_object.data.get("content", "")]],
"metadatas": [
[
{
"file_id": item.get("id"),
"name": file_object.filename,
"source": file_object.filename,
}
]
],
}
else:
# Fallback to collection names
if item.get("legacy"):
collection_names.append(f"{item['id']}")
else:
collection_names.append(f"file-{item['id']}")
elif item.get("type") == "collection":
# Manual Full Mode Toggle for Collection
knowledge_base = Knowledges.get_knowledge_by_id(item.get("id"))
if knowledge_base and (
user.role == "admin"
or knowledge_base.user_id == user.id
or has_access(user.id, "read", knowledge_base.access_control)
):
if (
item.get("context") == "full"
or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
):
if knowledge_base and (
user.role == "admin"
or knowledge_base.user_id == user.id
or has_access(user.id, "read", knowledge_base.access_control)
):
file_ids = knowledge_base.data.get("file_ids", [])
documents = []
metadatas = []
for file_id in file_ids:
file_object = Files.get_file_by_id(file_id)
if file_object:
documents.append(file_object.data.get("content", ""))
metadatas.append(
{
"file_id": file_id,
"name": file_object.filename,
"source": file_object.filename,
}
)
query_result = {
"documents": [documents],
"metadatas": [metadatas],
}
else:
# Fallback to collection names
if item.get("legacy"):
collection_names = item.get("collection_names", [])
else:
collection_names.append(item["id"])
elif item.get("docs"):
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
query_result = {
"documents": [[doc.get("content") for doc in item.get("docs")]],
"metadatas": [[doc.get("metadata") for doc in item.get("docs")]],
}
elif item.get("collection_name"):
# Direct Collection Name
collection_names.append(item["collection_name"])
elif item.get("collection_names"):
# Collection Names List
collection_names.extend(item["collection_names"])
# If query_result is None
# Fallback to collection names and vector search the collections
if query_result is None and collection_names:
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
log.debug(f"skipping {item} as it has already been extracted")
continue
try:
if full_context:
query_result = get_all_items_from_collections(collection_names)
else:
query_result = None # Initialize to None
if hybrid_search:
try:
query_result = await query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
hybrid_bm25_weight=hybrid_bm25_weight,
enable_enriched_texts=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using non hybrid search as fallback."
)
# fallback to non-hybrid search
                    if (not hybrid_search) or (query_result is None):
query_result = await query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
)
except Exception as e:
log.exception(e)
extracted_collections.extend(collection_names)
if query_result:
if "data" in item:
del item["data"]
query_results.append({**query_result, "file": item})
sources = []
for query_result in query_results:
try:
if "documents" in query_result:
if "metadatas" in query_result:
source = {
"source": query_result["file"],
"document": query_result["documents"][0],
"metadata": query_result["metadatas"][0],
}
if "distances" in query_result and query_result["distances"]:
source["distances"] = query_result["distances"][0]
sources.append(source)
except Exception as e:
log.exception(e)
return sources


def get_model_path(model: str, update_model: bool = False):
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
local_files_only = not update_model
if OFFLINE_MODE:
local_files_only = True
snapshot_kwargs = {
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
log.debug(f"model: {model}")
log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
# Inspiration from upstream sentence_transformers
if (
os.path.exists(model)
or ("\\" in model or model.count("/") > 1)
and local_files_only
):
# If fully qualified path exists, return input, else set repo_id
return model
elif "/" not in model:
# Set valid repo_id for model short-name
model = "sentence-transformers" + "/" + model
snapshot_kwargs["repo_id"] = model
# Attempt to query the huggingface_hub library to determine the local path and/or to update
try:
model_repo_path = snapshot_download(**snapshot_kwargs)
log.debug(f"model_repo_path: {model_repo_path}")
return model_repo_path
except Exception as e:
log.exception(f"Cannot determine model snapshot path: {e}")
return model
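
# Resolution sketch: a short name like "all-MiniLM-L6-v2" is expanded to the
# repo_id "sentence-transformers/all-MiniLM-L6-v2" and resolved (or downloaded)
# via snapshot_download; an existing local path is returned unchanged.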


import operator
from typing import Optional, Sequence

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document


class RerankCompressor(BaseDocumentCompressor):
embedding_function: Any
top_n: int
reranking_function: Any
r_score: float

    class Config:
extra = "forbid"
arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Compress retrieved documents given the query context.

        Args:
            documents: The retrieved documents.
            query: The query context.
            callbacks: Optional callbacks to run during compression.

        Returns:
            The compressed documents.
        """
        # Sync compression is a no-op; scoring and filtering happen in
        # acompress_documents below, which the async retriever pipeline calls.
        return []

    async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
reranking = self.reranking_function is not None
scores = None
if reranking:
scores = self.reranking_function(query, documents)
else:
from sentence_transformers import util
query_embedding = await self.embedding_function(
query, RAG_EMBEDDING_QUERY_PREFIX
)
document_embedding = await self.embedding_function(
[doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
)
scores = util.cos_sim(query_embedding, document_embedding)[0]
if scores is not None:
docs_with_scores = list(
zip(
documents,
scores.tolist() if not isinstance(scores, list) else scores,
)
)
if self.r_score:
docs_with_scores = [
(d, s) for d, s in docs_with_scores if s >= self.r_score
]
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
final_results = []
for doc, doc_score in result[: self.top_n]:
metadata = doc.metadata
metadata["score"] = doc_score
doc = Document(
page_content=doc.page_content,
metadata=metadata,
)
final_results.append(doc)
return final_results
else:
log.warning(
"No valid scores found, check your reranking function. Returning original documents."
)
return documents
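
# Scoring sketch: with r_score=0.35 and reranker scores [0.9, 0.5, 0.2]
# (hypothetical), the 0.2 document is dropped, the remainder are sorted by
# score, and the best top_n are returned with the score kept in metadata["score"].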