mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
refac/enh: forward user info header to reranker
This commit is contained in:
parent
b4f04ff3a7
commit
0013f5c1fc
5 changed files with 54 additions and 10 deletions
|
|
@ -89,6 +89,7 @@ from open_webui.routers import (
|
|||
|
||||
from open_webui.routers.retrieval import (
|
||||
get_embedding_function,
|
||||
get_reranking_function,
|
||||
get_ef,
|
||||
get_rf,
|
||||
)
|
||||
|
|
@ -878,6 +879,7 @@ app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
|
|||
app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
|
||||
|
||||
app.state.EMBEDDING_FUNCTION = None
|
||||
app.state.RERANKING_FUNCTION = None
|
||||
app.state.ef = None
|
||||
app.state.rf = None
|
||||
|
||||
|
|
@ -906,8 +908,8 @@ except Exception as e:
|
|||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
app.state.config.RAG_EMBEDDING_MODEL,
|
||||
app.state.ef,
|
||||
(
|
||||
embedding_function=app.state.ef,
|
||||
url=(
|
||||
app.state.config.RAG_OPENAI_API_BASE_URL
|
||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else (
|
||||
|
|
@ -916,7 +918,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
|||
else app.state.config.RAG_AZURE_OPENAI_BASE_URL
|
||||
)
|
||||
),
|
||||
(
|
||||
key=(
|
||||
app.state.config.RAG_OPENAI_API_KEY
|
||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else (
|
||||
|
|
@ -925,7 +927,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
|||
else app.state.config.RAG_AZURE_OPENAI_API_KEY
|
||||
)
|
||||
),
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
embedding_batch_size=app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
azure_api_version=(
|
||||
app.state.config.RAG_AZURE_OPENAI_API_VERSION
|
||||
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
||||
|
|
@ -933,6 +935,12 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
|||
),
|
||||
)
|
||||
|
||||
app.state.RERANKING_FUNCTION = get_reranking_function(
|
||||
app.state.config.RAG_RERANKING_ENGINE,
|
||||
app.state.config.RAG_RERANKING_MODEL,
|
||||
reranking_function=app.state.rf,
|
||||
)
|
||||
|
||||
########################################
|
||||
#
|
||||
# CODE EXECUTION
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
import requests
|
||||
from typing import Optional, List, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
||||
from open_webui.retrieval.models.base_reranker import BaseReranker
|
||||
|
||||
|
||||
|
|
@ -21,7 +23,9 @@ class ExternalReranker(BaseReranker):
|
|||
self.url = url
|
||||
self.model = model
|
||||
|
||||
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
|
||||
def predict(
|
||||
self, sentences: List[Tuple[str, str]], user=None
|
||||
) -> Optional[List[float]]:
|
||||
query = sentences[0][0]
|
||||
docs = [i[1] for i in sentences]
|
||||
|
||||
|
|
@ -41,6 +45,16 @@ class ExternalReranker(BaseReranker):
|
|||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
json=payload,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -445,6 +445,15 @@ def get_embedding_function(
|
|||
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
|
||||
|
||||
|
||||
def get_reranking_function(reranking_engine, reranking_model, reranking_function):
    """Wrap a reranker object in a uniform ``(sentences, user=None)`` callable.

    Args:
        reranking_engine: Name of the configured reranking engine. Only the
            ``"external"`` engine has ``user`` forwarded to ``predict``.
        reranking_model: Name of the configured reranking model. Not used by
            this function.
        reranking_function: Object exposing ``predict(sentences, ...)``, or
            ``None`` when no reranker is available.

    Returns:
        A callable ``f(sentences, user=None)`` that delegates to
        ``reranking_function.predict``, or ``None`` when ``reranking_function``
        is ``None``.
    """
    # Without a reranker there is nothing to delegate to. Returning None lets
    # callers detect that state instead of receiving a wrapper that can only
    # fail with AttributeError on its first call.
    if reranking_function is None:
        return None

    if reranking_engine == "external":
        # External rerankers accept the user so it can be forwarded with the request.
        return lambda sentences, user=None: reranking_function.predict(
            sentences, user=user
        )

    # Other rerankers are called with the sentences only; ``user`` is accepted
    # so that every returned callable has the same signature, and is dropped.
    return lambda sentences, user=None: reranking_function.predict(sentences)
|
||||
|
||||
|
||||
def get_sources_from_items(
|
||||
request,
|
||||
items,
|
||||
|
|
@ -925,7 +934,7 @@ class RerankCompressor(BaseDocumentCompressor):
|
|||
reranking = self.reranking_function is not None
|
||||
|
||||
if reranking:
|
||||
scores = self.reranking_function.predict(
|
||||
scores = self.reranking_function(
|
||||
[(query, doc.page_content) for doc in documents]
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ from open_webui.retrieval.web.external import search_external
|
|||
|
||||
from open_webui.retrieval.utils import (
|
||||
get_embedding_function,
|
||||
get_reranking_function,
|
||||
get_model_path,
|
||||
query_collection,
|
||||
query_collection_with_hybrid_search,
|
||||
|
|
@ -824,6 +825,12 @@ async def update_rag_config(
|
|||
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||
True,
|
||||
)
|
||||
|
||||
request.app.state.RERANKING_FUNCTION = get_reranking_function(
|
||||
request.app.state.config.RAG_RERANKING_ENGINE,
|
||||
request.app.state.config.RAG_RERANKING_MODEL,
|
||||
request.app.state.rf,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error loading reranking model: {e}")
|
||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
||||
|
|
@ -2042,7 +2049,9 @@ def query_doc_handler(
|
|||
query, prefix=prefix, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
|
||||
sentences, user=user
|
||||
),
|
||||
k_reranker=form_data.k_reranker
|
||||
or request.app.state.config.TOP_K_RERANKER,
|
||||
r=(
|
||||
|
|
@ -2099,7 +2108,9 @@ def query_collection_handler(
|
|||
query, prefix=prefix, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
|
||||
sentences, user=user
|
||||
),
|
||||
k_reranker=form_data.k_reranker
|
||||
or request.app.state.config.TOP_K_RERANKER,
|
||||
r=(
|
||||
|
|
|
|||
|
|
@ -652,7 +652,9 @@ async def chat_completion_files_handler(
|
|||
query, prefix=prefix, user=user
|
||||
),
|
||||
k=request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
|
||||
sentences, user=user
|
||||
),
|
||||
k_reranker=request.app.state.config.TOP_K_RERANKER,
|
||||
r=request.app.state.config.RELEVANCE_THRESHOLD,
|
||||
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
|
||||
|
|
|
|||
Loading…
Reference in a new issue