refac/enh: forward user info header to reranker

This commit is contained in:
Timothy Jaeryang Baek 2025-07-14 13:59:10 +04:00
parent b4f04ff3a7
commit 0013f5c1fc
5 changed files with 54 additions and 10 deletions

View file

@ -89,6 +89,7 @@ from open_webui.routers import (
from open_webui.routers.retrieval import (
get_embedding_function,
get_reranking_function,
get_ef,
get_rf,
)
@ -878,6 +879,7 @@ app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
app.state.EMBEDDING_FUNCTION = None
app.state.RERANKING_FUNCTION = None
app.state.ef = None
app.state.rf = None
@ -906,8 +908,8 @@ except Exception as e:
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.ef,
(
embedding_function=app.state.ef,
url=(
app.state.config.RAG_OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else (
@ -916,7 +918,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
else app.state.config.RAG_AZURE_OPENAI_BASE_URL
)
),
(
key=(
app.state.config.RAG_OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else (
@ -925,7 +927,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
else app.state.config.RAG_AZURE_OPENAI_API_KEY
)
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
embedding_batch_size=app.state.config.RAG_EMBEDDING_BATCH_SIZE,
azure_api_version=(
app.state.config.RAG_AZURE_OPENAI_API_VERSION
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
@ -933,6 +935,12 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
),
)
app.state.RERANKING_FUNCTION = get_reranking_function(
app.state.config.RAG_RERANKING_ENGINE,
app.state.config.RAG_RERANKING_MODEL,
reranking_function=app.state.rf,
)
########################################
#
# CODE EXECUTION

View file

@ -1,8 +1,10 @@
import logging
import requests
from typing import Optional, List, Tuple
from urllib.parse import quote
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker
@ -21,7 +23,9 @@ class ExternalReranker(BaseReranker):
self.url = url
self.model = model
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
def predict(
self, sentences: List[Tuple[str, str]], user=None
) -> Optional[List[float]]:
query = sentences[0][0]
docs = [i[1] for i in sentences]
@ -41,6 +45,16 @@ class ExternalReranker(BaseReranker):
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json=payload,
)

View file

@ -445,6 +445,15 @@ def get_embedding_function(
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
def get_reranking_function(reranking_engine, reranking_model, reranking_function):
    """Wrap a reranker object in a uniform ``(sentences, user=None)`` callable.

    Args:
        reranking_engine: Name of the configured reranking engine. Only the
            ``"external"`` engine has the ``user`` argument forwarded to it.
        reranking_model: Name of the configured reranking model. Not used by
            this function; accepted so callers can pass the full configuration.
        reranking_function: Object exposing ``predict(sentences)`` (and
            ``predict(sentences, user=...)`` for the external engine), or
            ``None`` when no reranker has been loaded.

    Returns:
        A callable ``fn(sentences, user=None)`` that delegates to
        ``reranking_function.predict``, or ``None`` if ``reranking_function``
        is ``None``.
    """
    if reranking_function is None:
        # Without this guard a wrapper is returned even when no reranker is
        # loaded, so ``is not None`` checks downstream can never detect the
        # missing reranker and the first call dies with an AttributeError on
        # ``None.predict``.
        return None

    if reranking_engine == "external":

        def rerank_with_user(sentences, user=None):
            # The external reranker accepts the user so it can forward
            # user-info headers with its request.
            return reranking_function.predict(sentences, user=user)

        return rerank_with_user

    def rerank(sentences, user=None):
        # Other engines are called with the sentences only; ``user`` is
        # accepted purely to keep the call signature uniform.
        return reranking_function.predict(sentences)

    return rerank
def get_sources_from_items(
request,
items,
@ -925,7 +934,7 @@ class RerankCompressor(BaseDocumentCompressor):
reranking = self.reranking_function is not None
if reranking:
scores = self.reranking_function.predict(
scores = self.reranking_function(
[(query, doc.page_content) for doc in documents]
)
else:

View file

@ -70,6 +70,7 @@ from open_webui.retrieval.web.external import search_external
from open_webui.retrieval.utils import (
get_embedding_function,
get_reranking_function,
get_model_path,
query_collection,
query_collection_with_hybrid_search,
@ -824,6 +825,12 @@ async def update_rag_config(
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
True,
)
request.app.state.RERANKING_FUNCTION = get_reranking_function(
request.app.state.config.RAG_RERANKING_ENGINE,
request.app.state.config.RAG_RERANKING_MODEL,
request.app.state.rf,
)
except Exception as e:
log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
@ -2042,7 +2049,9 @@ def query_doc_handler(
query, prefix=prefix, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
sentences, user=user
),
k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER,
r=(
@ -2099,7 +2108,9 @@ def query_collection_handler(
query, prefix=prefix, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
sentences, user=user
),
k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER,
r=(

View file

@ -652,7 +652,9 @@ async def chat_completion_files_handler(
query, prefix=prefix, user=user
),
k=request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
sentences, user=user
),
k_reranker=request.app.state.config.TOP_K_RERANKER,
r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,