mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
refac/enh: forward user info header to reranker
This commit is contained in:
parent
b4f04ff3a7
commit
0013f5c1fc
5 changed files with 54 additions and 10 deletions
|
|
@ -89,6 +89,7 @@ from open_webui.routers import (
|
||||||
|
|
||||||
from open_webui.routers.retrieval import (
|
from open_webui.routers.retrieval import (
|
||||||
get_embedding_function,
|
get_embedding_function,
|
||||||
|
get_reranking_function,
|
||||||
get_ef,
|
get_ef,
|
||||||
get_rf,
|
get_rf,
|
||||||
)
|
)
|
||||||
|
|
@ -878,6 +879,7 @@ app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
|
||||||
app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
|
app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
|
||||||
|
|
||||||
app.state.EMBEDDING_FUNCTION = None
|
app.state.EMBEDDING_FUNCTION = None
|
||||||
|
app.state.RERANKING_FUNCTION = None
|
||||||
app.state.ef = None
|
app.state.ef = None
|
||||||
app.state.rf = None
|
app.state.rf = None
|
||||||
|
|
||||||
|
|
@ -906,8 +908,8 @@ except Exception as e:
|
||||||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
app.state.config.RAG_EMBEDDING_MODEL,
|
app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
app.state.ef,
|
embedding_function=app.state.ef,
|
||||||
(
|
url=(
|
||||||
app.state.config.RAG_OPENAI_API_BASE_URL
|
app.state.config.RAG_OPENAI_API_BASE_URL
|
||||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else (
|
else (
|
||||||
|
|
@ -916,7 +918,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||||
else app.state.config.RAG_AZURE_OPENAI_BASE_URL
|
else app.state.config.RAG_AZURE_OPENAI_BASE_URL
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
(
|
key=(
|
||||||
app.state.config.RAG_OPENAI_API_KEY
|
app.state.config.RAG_OPENAI_API_KEY
|
||||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else (
|
else (
|
||||||
|
|
@ -925,7 +927,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||||
else app.state.config.RAG_AZURE_OPENAI_API_KEY
|
else app.state.config.RAG_AZURE_OPENAI_API_KEY
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
embedding_batch_size=app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
azure_api_version=(
|
azure_api_version=(
|
||||||
app.state.config.RAG_AZURE_OPENAI_API_VERSION
|
app.state.config.RAG_AZURE_OPENAI_API_VERSION
|
||||||
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
||||||
|
|
@ -933,6 +935,12 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.state.RERANKING_FUNCTION = get_reranking_function(
|
||||||
|
app.state.config.RAG_RERANKING_ENGINE,
|
||||||
|
app.state.config.RAG_RERANKING_MODEL,
|
||||||
|
reranking_function=app.state.rf,
|
||||||
|
)
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
#
|
#
|
||||||
# CODE EXECUTION
|
# CODE EXECUTION
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
import logging
|
import logging
|
||||||
import requests
|
import requests
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
|
||||||
|
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
||||||
from open_webui.retrieval.models.base_reranker import BaseReranker
|
from open_webui.retrieval.models.base_reranker import BaseReranker
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,7 +23,9 @@ class ExternalReranker(BaseReranker):
|
||||||
self.url = url
|
self.url = url
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
|
def predict(
|
||||||
|
self, sentences: List[Tuple[str, str]], user=None
|
||||||
|
) -> Optional[List[float]]:
|
||||||
query = sentences[0][0]
|
query = sentences[0][0]
|
||||||
docs = [i[1] for i in sentences]
|
docs = [i[1] for i in sentences]
|
||||||
|
|
||||||
|
|
@ -41,6 +45,16 @@ class ExternalReranker(BaseReranker):
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
**(
|
||||||
|
{
|
||||||
|
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
||||||
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
|
}
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
},
|
},
|
||||||
json=payload,
|
json=payload,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -445,6 +445,15 @@ def get_embedding_function(
|
||||||
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
|
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_reranking_function(reranking_engine, reranking_model, reranking_function):
    """Wrap a reranker object in a uniform ``(sentences, user=None)`` callable.

    Args:
        reranking_engine: Engine identifier. Only the ``"external"`` engine
            has ``user`` passed through to its ``predict`` method.
        reranking_model: Model name. Not used by this function; kept so the
            signature stays stable for existing callers.
        reranking_function: Object exposing ``predict(sentences, ...)``, or
            ``None`` when no reranker has been loaded.

    Returns:
        A callable ``f(sentences, user=None)`` that delegates to
        ``reranking_function.predict``, or ``None`` when
        ``reranking_function`` is ``None``.
    """
    # Without a reranker object there is nothing to wrap. Returning None lets
    # callers detect "no reranker" up front instead of receiving a lambda that
    # raises AttributeError on its first call.
    if reranking_function is None:
        return None

    if reranking_engine == "external":
        # The external reranker's predict() accepts the requesting user.
        return lambda sentences, user=None: reranking_function.predict(
            sentences, user=user
        )

    # Other rerankers are called with the sentences only; ``user`` is accepted
    # for a uniform call signature and then ignored.
    return lambda sentences, user=None: reranking_function.predict(sentences)
|
||||||
|
|
||||||
|
|
||||||
def get_sources_from_items(
|
def get_sources_from_items(
|
||||||
request,
|
request,
|
||||||
items,
|
items,
|
||||||
|
|
@ -925,7 +934,7 @@ class RerankCompressor(BaseDocumentCompressor):
|
||||||
reranking = self.reranking_function is not None
|
reranking = self.reranking_function is not None
|
||||||
|
|
||||||
if reranking:
|
if reranking:
|
||||||
scores = self.reranking_function.predict(
|
scores = self.reranking_function(
|
||||||
[(query, doc.page_content) for doc in documents]
|
[(query, doc.page_content) for doc in documents]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,7 @@ from open_webui.retrieval.web.external import search_external
|
||||||
|
|
||||||
from open_webui.retrieval.utils import (
|
from open_webui.retrieval.utils import (
|
||||||
get_embedding_function,
|
get_embedding_function,
|
||||||
|
get_reranking_function,
|
||||||
get_model_path,
|
get_model_path,
|
||||||
query_collection,
|
query_collection,
|
||||||
query_collection_with_hybrid_search,
|
query_collection_with_hybrid_search,
|
||||||
|
|
@ -824,6 +825,12 @@ async def update_rag_config(
|
||||||
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
request.app.state.RERANKING_FUNCTION = get_reranking_function(
|
||||||
|
request.app.state.config.RAG_RERANKING_ENGINE,
|
||||||
|
request.app.state.config.RAG_RERANKING_MODEL,
|
||||||
|
request.app.state.rf,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error loading reranking model: {e}")
|
log.error(f"Error loading reranking model: {e}")
|
||||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
||||||
|
|
@ -2042,7 +2049,9 @@ def query_doc_handler(
|
||||||
query, prefix=prefix, user=user
|
query, prefix=prefix, user=user
|
||||||
),
|
),
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
reranking_function=request.app.state.rf,
|
reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
|
||||||
|
sentences, user=user
|
||||||
|
),
|
||||||
k_reranker=form_data.k_reranker
|
k_reranker=form_data.k_reranker
|
||||||
or request.app.state.config.TOP_K_RERANKER,
|
or request.app.state.config.TOP_K_RERANKER,
|
||||||
r=(
|
r=(
|
||||||
|
|
@ -2099,7 +2108,9 @@ def query_collection_handler(
|
||||||
query, prefix=prefix, user=user
|
query, prefix=prefix, user=user
|
||||||
),
|
),
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
reranking_function=request.app.state.rf,
|
reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
|
||||||
|
sentences, user=user
|
||||||
|
),
|
||||||
k_reranker=form_data.k_reranker
|
k_reranker=form_data.k_reranker
|
||||||
or request.app.state.config.TOP_K_RERANKER,
|
or request.app.state.config.TOP_K_RERANKER,
|
||||||
r=(
|
r=(
|
||||||
|
|
|
||||||
|
|
@ -652,7 +652,9 @@ async def chat_completion_files_handler(
|
||||||
query, prefix=prefix, user=user
|
query, prefix=prefix, user=user
|
||||||
),
|
),
|
||||||
k=request.app.state.config.TOP_K,
|
k=request.app.state.config.TOP_K,
|
||||||
reranking_function=request.app.state.rf,
|
reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
|
||||||
|
sentences, user=user
|
||||||
|
),
|
||||||
k_reranker=request.app.state.config.TOP_K_RERANKER,
|
k_reranker=request.app.state.config.TOP_K_RERANKER,
|
||||||
r=request.app.state.config.RELEVANCE_THRESHOLD,
|
r=request.app.state.config.RELEVANCE_THRESHOLD,
|
||||||
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
|
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue