diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index ca811b1032..595d551d75 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -89,6 +89,7 @@ from open_webui.routers import ( from open_webui.routers.retrieval import ( get_embedding_function, + get_reranking_function, get_ef, get_rf, ) @@ -878,6 +879,7 @@ app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH app.state.EMBEDDING_FUNCTION = None +app.state.RERANKING_FUNCTION = None app.state.ef = None app.state.rf = None @@ -906,8 +908,8 @@ except Exception as e: app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, - app.state.ef, - ( + embedding_function=app.state.ef, + url=( app.state.config.RAG_OPENAI_API_BASE_URL if app.state.config.RAG_EMBEDDING_ENGINE == "openai" else ( @@ -916,7 +918,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( else app.state.config.RAG_AZURE_OPENAI_BASE_URL ) ), - ( + key=( app.state.config.RAG_OPENAI_API_KEY if app.state.config.RAG_EMBEDDING_ENGINE == "openai" else ( @@ -925,7 +927,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( else app.state.config.RAG_AZURE_OPENAI_API_KEY ) ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, + embedding_batch_size=app.state.config.RAG_EMBEDDING_BATCH_SIZE, azure_api_version=( app.state.config.RAG_AZURE_OPENAI_API_VERSION if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" @@ -933,6 +935,12 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( ), ) +app.state.RERANKING_FUNCTION = get_reranking_function( + app.state.config.RAG_RERANKING_ENGINE, + app.state.config.RAG_RERANKING_MODEL, + reranking_function=app.state.rf, +) + ######################################## # # CODE EXECUTION diff --git a/backend/open_webui/retrieval/models/external.py b/backend/open_webui/retrieval/models/external.py index 5ebc3e52ea..a9be526b6d 100644 --- a/backend/open_webui/retrieval/models/external.py +++ b/backend/open_webui/retrieval/models/external.py @@ -1,8 +1,10 @@ import logging import requests from typing import Optional, List, Tuple +from urllib.parse import quote -from open_webui.env import SRC_LOG_LEVELS + +from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS from open_webui.retrieval.models.base_reranker import BaseReranker @@ -21,7 +23,9 @@ class ExternalReranker(BaseReranker): self.url = url self.model = model - def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]: + def predict( + self, sentences: List[Tuple[str, str]], user=None + ) -> Optional[List[float]]: query = sentences[0][0] docs = [i[1] for i in sentences] @@ -41,6 +45,16 @@ class ExternalReranker(BaseReranker): headers={ "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", + **( + { + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, json=payload, ) diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index c0ad2b765b..06f8e939e3 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -445,6 +445,15 @@ def get_embedding_function( raise ValueError(f"Unknown embedding engine: {embedding_engine}") +def get_reranking_function(reranking_engine, reranking_model, reranking_function): + if reranking_engine == "external": + return lambda sentences, user=None: reranking_function.predict( + sentences, user=user + ) + else: + return lambda sentences, user=None: reranking_function.predict(sentences) + + def get_sources_from_items( request, items, @@ -925,7 +934,7 @@ class RerankCompressor(BaseDocumentCompressor): reranking = self.reranking_function is not None if reranking: - scores = self.reranking_function.predict( + scores = self.reranking_function( [(query, doc.page_content) for doc in documents] ) else: diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 34910f23ef..25e6754b08 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -70,6 +70,7 @@ from open_webui.retrieval.web.external import search_external from open_webui.retrieval.utils import ( get_embedding_function, + get_reranking_function, get_model_path, query_collection, query_collection_with_hybrid_search, @@ -824,6 +825,12 @@ async def update_rag_config( request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, True, ) + + request.app.state.RERANKING_FUNCTION = get_reranking_function( + request.app.state.config.RAG_RERANKING_ENGINE, + request.app.state.config.RAG_RERANKING_MODEL, + request.app.state.rf, + ) except Exception as e: log.error(f"Error loading reranking model: {e}") request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False @@ -2042,7 +2049,9 @@ def query_doc_handler( query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, + reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION( + sentences, user=user + ), k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER, r=( @@ -2099,7 +2108,9 @@ def query_collection_handler( query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, + reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION( + sentences, user=user + ), k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER, r=( diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 003e97e84c..e5850735b0 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -652,7 +652,9 @@ async def chat_completion_files_handler( query, prefix=prefix, user=user ), k=request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, + reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION( + sentences, user=user + ), k_reranker=request.app.state.config.TOP_K_RERANKER, r=request.app.state.config.RELEVANCE_THRESHOLD, hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,